diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 52911d3b34d6..4d150e93655b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,8 @@ * @ritchie46 -/.github/ @ritchie46 @stinodego +/.github/ @ritchie46 @stinodego /crates/ @ritchie46 @orlp /crates/polars-sql/ @ritchie46 @orlp @universalmind303 /crates/polars-time/ @ritchie46 @orlp @MarcoGorelli /py-polars/ @ritchie46 @stinodego @alexander-beedie +/docs/ @ritchie46 @c-peters @braaannigan diff --git a/.github/ISSUE_TEMPLATE/bug_report_python.yml b/.github/ISSUE_TEMPLATE/bug_report_python.yml index a90f239ca3f9..005a245e6de0 100644 --- a/.github/ISSUE_TEMPLATE/bug_report_python.yml +++ b/.github/ISSUE_TEMPLATE/bug_report_python.yml @@ -1,5 +1,5 @@ name: '🐞 Bug report - Python' -description: An issue with Python Polars +description: Report an issue with Python Polars. labels: [bug, python] body: @@ -30,6 +30,15 @@ body: validations: required: true + - type: textarea + id: logs + attributes: + label: Log output + description: > + Set the environment variable ``POLARS_VERBOSE=1`` before running the query. + Paste the output of ``stderr`` here. + render: shell + - type: textarea id: problem attributes: @@ -64,3 +73,4 @@ body: validations: required: true + diff --git a/.github/ISSUE_TEMPLATE/bug_report_rust.yml b/.github/ISSUE_TEMPLATE/bug_report_rust.yml index 2f32e7ee3a71..7d8ce6367272 100644 --- a/.github/ISSUE_TEMPLATE/bug_report_rust.yml +++ b/.github/ISSUE_TEMPLATE/bug_report_rust.yml @@ -1,5 +1,5 @@ name: '🐞 Bug report - Rust' -description: An issue with Rust Polars +description: Report an issue with Rust Polars. labels: [bug, rust] body: @@ -30,6 +30,15 @@ body: validations: required: true + - type: textarea + id: logs + attributes: + label: Log output + description: > + Set the environment variable ``POLARS_VERBOSE=1`` before running the query. + Paste the output of ``stderr`` here. + render: shell + - type: textarea id: problem attributes: diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 000000000000..3594bdb6a40e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,23 @@ +name: '📖 Documentation improvement' +description: Report an issue with the documentation. +labels: [documentation] + +body: + - type: textarea + id: description + attributes: + label: Description + description: > + Describe the issue with the documentation and how it can be fixed or improved. + validations: + required: true + + - type: input + id: link + attributes: + label: Link + description: > + Provide a link to the existing documentation, if applicable. + placeholder: ex. https://pola-rs.github.io/polars/docs/python/dev/... + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 6d70797a52b0..eed3105bf95f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,14 +1,14 @@ name: '✨ Feature request' -description: Suggest a new feature or enhancement for Polars +description: Suggest a new feature or enhancement for Polars. labels: [enhancement] body: - type: textarea id: description attributes: - label: Problem description + label: Description description: > - Please describe the feature or enhancement and explain why it should be implemented. + Describe the feature or enhancement and explain why it should be implemented. Include a code example if applicable. 
validations: required: true diff --git a/.github/deploy_manylinux.sh b/.github/deploy_manylinux.sh deleted file mode 100644 index 4c7ae774b1c9..000000000000 --- a/.github/deploy_manylinux.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -# easier debugging -set -e -pwd -ls -la - -rm py-polars/README.md -cp README.md py-polars/README.md -cd py-polars -rustup override set nightly-2023-08-26 -export RUSTFLAGS='-C target-feature=+fxsr,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma' - -# first the default release -maturin publish \ - --skip-existing \ - --username ritchie46 - -# now compile polars with bigidx feature -sed -i 's/name = "polars"/name = "polars-u64-idx"/' pyproject.toml -# a brittle hack to insert the 'bigidx' feature -sed -i 's/"dynamic_group_by",/"dynamic_group_by",\n"bigidx",/' Cargo.toml - -maturin publish \ - --skip-existing \ - --username ritchie46 - -# https://github.com/actions/checkout/issues/760 -git config --global --add safe.directory /github/workspace -# Clean up after bigidx changes -git checkout . diff --git a/.github/release-drafter-python.yml b/.github/release-drafter-python.yml index d2e17c11a905..a81ed56bd60c 100644 --- a/.github/release-drafter-python.yml +++ b/.github/release-drafter-python.yml @@ -13,3 +13,24 @@ version-resolver: - breaking - breaking python default: patch + +categories: + - title: 🏆 Highlights + labels: highlight + - title: 💥 Breaking changes + labels: + - breaking + - breaking python + - title: ⚠️ Deprecations + labels: deprecation + - title: 🚀 Performance improvements + labels: performance + - title: ✨ Enhancements + labels: enhancement + - title: 🐞 Bug fixes + labels: fix + - title: 🛠️ Other improvements + labels: + - build + - documentation + - internal diff --git a/.github/release-drafter-rust.yml b/.github/release-drafter-rust.yml index 10c3b7ddf759..2d333e2a3c41 100644 --- a/.github/release-drafter-rust.yml +++ b/.github/release-drafter-rust.yml @@ -13,3 +13,23 @@ version-resolver: - breaking - breaking rust default: patch + +categories: + - title: 🏆 Highlights + labels: highlight + - title: 💥 Breaking changes + labels: + - breaking + - breaking rust + - title: 🚀 Performance improvements + labels: performance + - title: ✨ Enhancements + labels: enhancement + - title: 🐞 Bug fixes + labels: fix + - title: 🛠️ Other improvements + labels: + - build + - deprecation + - documentation + - internal diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml index 15a62a8bd827..8216254ab6bc 100644 --- a/.github/release-drafter.yml +++ b/.github/release-drafter.yml @@ -1,22 +1,3 @@ -categories: - - title: 🏆 Highlights - labels: highlight - - title: 💥 Breaking changes - labels: breaking - - title: ⚠️ Deprecations - labels: deprecation - - title: 🚀 Performance improvements - labels: performance - - title: ✨ Enhancements - labels: enhancement - - title: 🐞 Bug fixes - labels: fix - - title: 🛠️ Other improvements - labels: - - build - - documentation - - internal - exclude-labels: - skip changelog - release diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index de2bfc326ae1..254da13172e3 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -28,7 +28,7 @@ jobs: main: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 diff --git a/.github/workflows/clear-caches.yml b/.github/workflows/clear-caches.yml index f6a001c35419..fc75374b21fb 100644 --- 
a/.github/workflows/clear-caches.yml +++ b/.github/workflows/clear-caches.yml @@ -11,7 +11,7 @@ jobs: clear-caches: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Clear all caches run: gh cache delete --all diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml new file mode 100644 index 000000000000..6e8f12bcae5e --- /dev/null +++ b/.github/workflows/docs-global.yml @@ -0,0 +1,87 @@ +name: Build documentation + +on: + pull_request: + paths: + - docs/** + - mkdocs.yml + - .github/workflows/docs-global.yml + push: + tags: + - py-** + +jobs: + markdown-link-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: gaurav-nelson/github-action-markdown-link-check@v1 + with: + folder-path: docs + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + src: docs/src/python + version: "23.9.1" + + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Create virtual environment + run: | + python -m venv .venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + pip install -r py-polars/requirements-dev.txt + pip install -r docs/requirements.txt + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + workspaces: py-polars + save-if: ${{ github.ref_name == 'main' }} + + - name: Install Polars + working-directory: py-polars + run: | + source activate + maturin develop + + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v1 + + - name: Build documentation + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: mkdocs build + + - name: Add .nojekyll + if: ${{ github.ref_type == 'tag' }} + working-directory: site + run: touch .nojekyll + + - name: Deploy docs + if: ${{ github.ref_type == 'tag' }} + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: site + clean-exclude: | + docs/ + py-polars/ + single-commit: true diff --git a/.github/workflows/docs-python.yml b/.github/workflows/docs-python.yml index 2b58f9494f69..3cc0e96c36a7 100644 --- a/.github/workflows/docs-python.yml +++ b/.github/workflows/docs-python.yml @@ -23,7 +23,7 @@ jobs: build-python-docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 diff --git a/.github/workflows/docs-rust.yml b/.github/workflows/docs-rust.yml index 26d5e94b1e9b..cd02b16ef53d 100644 --- a/.github/workflows/docs-rust.yml +++ b/.github/workflows/docs-rust.yml @@ -19,7 +19,7 @@ jobs: build-rust-docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup component add rust-docs diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index 85b86ed05f6e..2ebcc0dca3b0 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -11,7 +11,7 @@ jobs: main: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos diff --git a/.github/workflows/lint-py-polars.yml b/.github/workflows/lint-py-polars.yml index 2af80ca0b9e6..a8d75b835b70 100644 --- a/.github/workflows/lint-py-polars.yml +++ b/.github/workflows/lint-py-polars.yml @@ -30,7 +30,7 @@ jobs: 
working-directory: py-polars steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup component add rustfmt clippy @@ -46,3 +46,6 @@ jobs: - name: Run clippy run: cargo clippy --locked -- -D warnings + + - name: Compile without default features + run: cargo check --no-default-features diff --git a/.github/workflows/lint-python.yml b/.github/workflows/lint-python.yml index 325cb1e833a5..6568d52681b4 100644 --- a/.github/workflows/lint-python.yml +++ b/.github/workflows/lint-python.yml @@ -18,7 +18,7 @@ jobs: working-directory: py-polars steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 @@ -45,7 +45,7 @@ jobs: working-directory: py-polars steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 diff --git a/.github/workflows/lint-rust.yml b/.github/workflows/lint-rust.yml index de78f4e2d505..ec1bf314cdbb 100644 --- a/.github/workflows/lint-rust.yml +++ b/.github/workflows/lint-rust.yml @@ -27,7 +27,7 @@ jobs: clippy-nightly: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup component add clippy @@ -38,13 +38,13 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy with all features enabled - run: cargo clippy --workspace --all-targets --all-features -- -D warnings + run: cargo clippy -p polars --all-features -- -D warnings # Default feature set should compile on the stable toolchain clippy-stable: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup override set stable && rustup update @@ -58,13 +58,13 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy - run: cargo clippy --workspace --all-targets -- -D warnings + run: cargo clippy -p polars -- -D warnings rustfmt: if: github.ref_name != 'main' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup component add rustfmt @@ -76,7 +76,7 @@ jobs: if: github.ref_name != 'main' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup component add miri @@ -90,7 +90,6 @@ jobs: POLARS_ALLOW_EXTENSION: '1' run: > cargo miri test - --no-default-features --features object -p polars-core -p polars-arrow diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml index 0f071abb50f4..47c0f76d3a50 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yml @@ -5,6 +5,11 @@ on: branches: - main workflow_dispatch: + inputs: + # Latest commit to include with the release. If omitted, use the latest commit on the main branch. 
+ sha: + description: Commit SHA + type: string permissions: contents: write @@ -18,6 +23,7 @@ jobs: uses: release-drafter/release-drafter@v5 with: config-name: release-drafter-rust.yml + commitish: ${{ inputs.sha }} disable-autolabeler: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -26,6 +32,7 @@ jobs: uses: release-drafter/release-drafter@v5 with: config-name: release-drafter-python.yml + commitish: ${{ inputs.sha }} disable-autolabeler: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index b212d41e8750..ec7e724a89bf 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -1,177 +1,250 @@ name: Release Python on: - push: - tags: - - py-* + workflow_dispatch: + inputs: + # Latest commit to include with the release. If omitted, use the latest commit on the main branch. + sha: + description: Commit SHA + type: string + # Create the sdist and build the wheels, but do not publish to PyPI / GitHub. + dry-run: + description: Dry run + type: boolean + default: false + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true env: - RUST_TOOLCHAIN: nightly-2023-08-26 PYTHON_VERSION: '3.8' - MATURIN_VERSION: '1.2.1' - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + CARGO_INCREMENTAL: 0 + CARGO_NET_RETRY: 10 + RUSTUP_MAX_RETRIES: 10 defaults: run: shell: bash jobs: - manylinux-x64_64: + create-sdist: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + package: [polars, polars-lts-cpu, polars-u64-idx] + steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 with: - python-version: ${{ env.PYTHON_VERSION }} + ref: ${{ inputs.sha }} - - name: Fix README symlink - run: | - rm py-polars/README.md - cp README.md py-polars/README.md - - - name: Publish wheel - uses: PyO3/maturin-action@v1 - env: - RUSTFLAGS: -C target-feature=+fxsr,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma + # Avoid potential out-of-memory errors + - name: Set swap space for Linux + uses: pierotofy/set-swap-space@master with: - command: publish - args: -m py-polars/Cargo.toml --skip-existing -o wheels -u ritchie46 - maturin-version: ${{ env.MATURIN_VERSION }} - rust-toolchain: ${{ env.RUST_TOOLCHAIN }} + swap-size-gb: 10 - # Needed for Docker on Apple M1 - manylinux-aarch64: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - name: Set up Python + uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - # Needed to avoid out-of-memory error - - name: Set Swap Space - uses: pierotofy/set-swap-space@master + - name: Fix README symlink + run: rm py-polars/README.md && cp README.md py-polars/README.md + + - name: Install yq + if: matrix.package != 'polars' + run: pip install yq + - name: Update package name + if: matrix.package != 'polars' + run: tomlq -i -t ".project.name = \"${{ matrix.package }}\"" py-polars/pyproject.toml + - name: Add bigidx feature + if: matrix.package == 'polars-u64-idx' + run: tomlq -i -t '.dependencies.polars.features += ["bigidx"]' py-polars/Cargo.toml + + - name: Create source distribution + uses: PyO3/maturin-action@v1 with: - swap-size-gb: 10 + command: sdist + args: > + --manifest-path py-polars/Cargo.toml + --out dist - - name: Fix README symlink + - name: Test sdist run: | - rm py-polars/README.md - cp README.md py-polars/README.md + TOOLCHAIN=$(grep -oP 'channel = "\K[^"]+' 
rust-toolchain.toml) + rustup default $TOOLCHAIN + pip install --force-reinstall --verbose dist/*.tar.gz + python -c 'import polars' - - name: Publish wheel - uses: PyO3/maturin-action@v1 - env: - JEMALLOC_SYS_WITH_LG_PAGE: 16 + - name: Upload sdist + uses: actions/upload-artifact@v3 with: - command: publish - args: -m py-polars/Cargo.toml --skip-existing --no-sdist -o wheels -i python -u ritchie46 - target: aarch64-unknown-linux-gnu - maturin-version: ${{ env.MATURIN_VERSION }} - rust-toolchain: ${{ env.RUST_TOOLCHAIN }} + name: sdist + path: dist/*.tar.gz + + build-wheels: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + package: [polars, polars-lts-cpu, polars-u64-idx] + os: [ubuntu-latest, macos-latest, windows-32gb-ram] + architecture: [x86-64, aarch64] + exclude: + - os: windows-32gb-ram + architecture: aarch64 - manylinux-bigidx: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Fix README symlink - run: | - rm py-polars/README.md - cp README.md py-polars/README.md - - - name: Prepare bigidx - run: | - sed -i 's/name = "polars"/name = "polars-u64-idx"/' py-polars/pyproject.toml - # A brittle hack to insert the 'bigidx' feature - sed -i 's/"dynamic_group_by",/"dynamic_group_by",\n"bigidx",/' py-polars/Cargo.toml + ref: ${{ inputs.sha }} - - name: Publish wheel - uses: PyO3/maturin-action@v1 - env: - RUSTFLAGS: -C target-feature=+fxsr,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma + # Avoid potential out-of-memory errors + - name: Set swap space for Linux + if: matrix.os == 'ubuntu-latest' + uses: pierotofy/set-swap-space@master with: - command: publish - args: -m py-polars/Cargo.toml --skip-existing -o wheels -u ritchie46 - maturin-version: ${{ env.MATURIN_VERSION }} - rust-toolchain: ${{ env.RUST_TOOLCHAIN }} + swap-size-gb: 10 - manylinux-x64_64-lts-cpu: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - name: Set up Python + uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - name: Fix README symlink + run: rm py-polars/README.md && cp README.md py-polars/README.md + + - name: Install yq + if: matrix.package != 'polars' + run: pip install yq + - name: Update package name + if: matrix.package != 'polars' + run: tomlq -i -t ".project.name = \"${{ matrix.package }}\"" py-polars/pyproject.toml + - name: Add bigidx feature + if: matrix.package == 'polars-u64-idx' + run: tomlq -i -t '.dependencies.polars.features += ["bigidx"]' py-polars/Cargo.toml + + - name: Set RUSTFLAGS for x86-64 + if: matrix.architecture == 'x86-64' && matrix.package != 'polars-lts-cpu' && matrix.os != 'macos-latest' + run: echo "RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt" >> $GITHUB_ENV + - name: Set RUSTFLAGS for x86-64 MacOS + if: matrix.architecture == 'x86-64' && matrix.package != 'polars-lts-cpu' && matrix.os == 'macos-latest' + run: echo "RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma" >> $GITHUB_ENV + - name: Set RUSTFLAGS for x86-64 LTS CPU + if: matrix.architecture == 'x86-64' && matrix.package == 'polars-lts-cpu' + run: echo "RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt --cfg use_mimalloc" >> $GITHUB_ENV + + - name: Set Rust target for aarch64 + if: matrix.architecture == 'aarch64' + id: target run: | - rm py-polars/README.md - cp README.md 
py-polars/README.md + TARGET=${{ matrix.os == 'macos-latest' && 'aarch64-apple-darwin' || 'aarch64-unknown-linux-gnu'}} + echo "target=$TARGET" >> $GITHUB_OUTPUT - - name: Prepare lts-cpu - run: sed -i 's/name = "polars"/name = "polars-lts-cpu"/' py-polars/pyproject.toml + - name: Set jemalloc for aarch64 Linux + if: matrix.architecture == 'aarch64' && matrix.os == 'ubuntu-latest' + run: | + echo "JEMALLOC_SYS_WITH_LG_PAGE=16" >> $GITHUB_ENV - - name: Publish wheel + - name: Build wheel uses: PyO3/maturin-action@v1 - env: - RUSTFLAGS: -C target-feature=+fxsr,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt --cfg use_mimalloc with: - command: publish - args: -m py-polars/Cargo.toml --skip-existing -o wheels -u ritchie46 - maturin-version: ${{ env.MATURIN_VERSION }} - rust-toolchain: ${{ env.RUST_TOOLCHAIN }} - - win-macos: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [macos-latest, windows-latest] + command: build + target: ${{ steps.target.outputs.target }} + args: > + --release + --manifest-path py-polars/Cargo.toml + --out dist + manylinux: auto + + - name: Upload wheel + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist/*.whl + + publish-to-pypi: + needs: [create-sdist, build-wheels] + environment: + name: release-python + url: https://pypi.org/project/polars + runs-on: ubuntu-latest + permissions: + id-token: write steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - name: Download sdist + uses: actions/download-artifact@v3 with: - python-version: ${{ env.PYTHON_VERSION }} + name: sdist + path: dist - - name: Fix README symlink - run: | - rm py-polars/README.md - cp README.md py-polars/README.md + - name: Download wheels + uses: actions/download-artifact@v3 + with: + name: wheels + path: dist - - name: Publish wheel - uses: PyO3/maturin-action@v1 - env: - RUSTFLAGS: -C target-feature=+fxsr,+sse,+sse2,+sse3,+sse4.1,+sse4.2 + - name: Publish to PyPI + if: inputs.dry-run == false + uses: pypa/gh-action-pypi-publish@release/v1 with: - command: publish - args: -m py-polars/Cargo.toml --no-sdist --skip-existing -o wheels -i python -u ritchie46 - maturin-version: ${{ env.MATURIN_VERSION }} - rust-toolchain: ${{ env.RUST_TOOLCHAIN }} + verbose: true - macos-aarch64: - runs-on: macos-latest + publish-to-github: + needs: publish-to-pypi + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 with: - python-version: ${{ env.PYTHON_VERSION }} + ref: ${{ inputs.sha }} - - name: Fix README symlink + - name: Download sdist + uses: actions/download-artifact@v3 + with: + name: sdist + path: dist + + - name: Get version from Cargo.toml + id: version + working-directory: py-polars run: | - rm py-polars/README.md - cp README.md py-polars/README.md + VERSION=$(grep -m 1 -oP 'version = "\K[^"]+' Cargo.toml) + if [[ "$VERSION" == *"-"* ]]; then + IS_PRERELEASE=true + else + IS_PRERELEASE=false + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "is_prerelease=$IS_PRERELEASE" >> $GITHUB_OUTPUT + + - name: Create GitHub release + id: github-release + uses: release-drafter/release-drafter@v5 + with: + config-name: release-drafter-python.yml + name: Python Polars ${{ steps.version.outputs.version }} + tag: py-${{ steps.version.outputs.version }} + version: ${{ steps.version.outputs.version }} + prerelease: ${{ steps.version.outputs.is_prerelease }} + commitish: ${{ inputs.sha }} + disable-autolabeler: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Set 
up Rust targets - run: rustup target add aarch64-apple-darwin + - name: Upload sdist to GitHub release + run: gh release upload $TAG $FILES --clobber + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ steps.github-release.outputs.tag_name }} + FILES: dist/polars-*.tar.gz - - name: Publish wheel - uses: PyO3/maturin-action@v1 - with: - command: publish - args: -m py-polars/Cargo.toml --target aarch64-apple-darwin --no-sdist -o wheels -i python -u ritchie46 - maturin-version: ${{ env.MATURIN_VERSION }} + - name: Publish GitHub release + if: inputs.dry-run == false + run: gh release edit $TAG --draft=false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ steps.github-release.outputs.tag_name }} diff --git a/.github/workflows/release-rust.yml b/.github/workflows/release-rust.yml index 9f0bd891e024..ad7be2155053 100644 --- a/.github/workflows/release-rust.yml +++ b/.github/workflows/release-rust.yml @@ -11,4 +11,4 @@ jobs: if: false runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 diff --git a/.github/workflows/test-bytecode-parser.yml b/.github/workflows/test-bytecode-parser.yml index dc8e6c09ce9e..27245919efb2 100644 --- a/.github/workflows/test-bytecode-parser.yml +++ b/.github/workflows/test-bytecode-parser.yml @@ -19,7 +19,7 @@ jobs: python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 1d9517d0ed5c..25c65475cc6f 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -4,6 +4,7 @@ on: pull_request: paths: - py-polars/** + - docs/src/python/** - crates/** - .github/workflows/test-python.yml push: @@ -11,6 +12,7 @@ on: - main paths: - crates/** + - docs/src/python/** - py-polars/** - .github/workflows/test-python.yml @@ -34,13 +36,16 @@ jobs: python-version: ['3.8', '3.11'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v1 + - name: Create virtual environment run: | python -m venv .venv @@ -65,11 +70,13 @@ jobs: - name: Run tests and report coverage if: github.ref_name != 'main' - run: pytest --cov -n auto --dist loadgroup -m "not benchmark" + run: pytest --cov -n auto --dist loadgroup -m "not benchmark and not docs" - name: Run doctests if: github.ref_name != 'main' - run: python tests/docs/run_doctest.py + run: | + python tests/docs/run_doctest.py + pytest tests/docs/test_user_guide.py -m docs - name: Check import without optional dependencies if: github.ref_name != 'main' @@ -80,6 +87,7 @@ jobs: "matplotlib" "backports.zoneinfo" "connectorx" + "pyiceberg" "deltalake" "xlsx2csv" ) @@ -98,7 +106,7 @@ jobs: python-version: ['3.11'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 @@ -125,7 +133,7 @@ jobs: - name: Run tests if: github.ref_name != 'main' - run: pytest -n auto --dist loadgroup -m "not benchmark" + run: pytest -n auto --dist loadgroup -m "not benchmark and not docs" - name: Check import without optional dependencies if: github.ref_name != 'main' diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 9e2ba685baf8..aacfe061026e 100644 --- a/.github/workflows/test-rust.yml +++ 
b/.github/workflows/test-rust.yml @@ -32,7 +32,7 @@ jobs: os: [ubuntu-latest, windows-latest] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup show @@ -77,7 +77,7 @@ jobs: os: [ubuntu-latest, windows-latest] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup show @@ -97,7 +97,7 @@ jobs: check-features: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: rustup show @@ -118,7 +118,7 @@ jobs: check-wasm: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Rust run: | diff --git a/.gitignore b/.gitignore index 1dd5ecb4236f..5eb602ae7f52 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,37 @@ *.iml *.so *.ipynb -.DS_Store .ENV -.coverage .env -.hypothesis/ -.idea/ .ipynb_checkpoints/ -.mypy_cache/ -.pytest_cache/ .python-version .yarn/ -.vscode/ -__pycache__/ -AUTO_CHANGELOG.md -Cargo.lock coverage.lcov coverage.xml data/ -node_modules/ polars/vendor -target/ -venv*/ -.venv*/ + +# OS +.DS_Store + +# IDE +.idea/ +.vscode/ .vim + +# Python +.hypothesis/ +.mypy_cache/ +.pytest_cache/ +.venv/ +__pycache__/ +.coverage + +# Rust +target/ +Cargo.lock + +# Project +/docs/data/ +/docs/images/ +/docs/people.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 315ac4c8acd8..44321d2f35bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,7 +48,6 @@ You may use the issue to discuss possible solutions. ### Setting up your local environment Polars development flow relies on both Rust and Python, which means setting up your local development environment is not trivial. -For contributing to Node.js Polars, please check out the [Node.js Polars](https://github.com/pola-rs/nodejs-polars) repository. If you run into problems, please contact us on [Discord](https://discord.gg/4UfP5cfBE7). _Note that if you are a Windows user, the steps below might not work as expected; try developing using [WSL](https://learn.microsoft.com/en-us/windows/wsl/install)._ @@ -56,7 +55,7 @@ _Note that if you are a Windows user, the steps below might not work as expected Start by [forking](https://docs.github.com/en/get-started/quickstart/fork-a-repo) the Polars repository, then clone your forked repository using `git`: ```bash -git clone git@github.com:/polars.git +git clone https://github.com//polars.git cd polars ``` @@ -89,7 +88,7 @@ This will do a number of things: - Use Python to create a virtual environment in the `.venv` folder. - Use [pip](https://pip.pypa.io/) to install all Python dependencies for development, linting, and building documentation. -- Use Rust to compile and install Polars in your virtual environment. +- Use Rust to compile and install Polars in your virtual environment. _At least 8GB of RAM is recommended for this step to run smoothly._ - Use [pytest](https://docs.pytest.org/) to run the Python unittests in your virtual environment Check if linting also works correctly by running: @@ -148,12 +147,69 @@ If you are stuck or unsure about your solution, feel free to open a draft pull r ## Contributing to documentation -The most important components of Polars documentation are the [user guide](https://pola-rs.github.io/polars-book/user-guide/), the API references, and the database of questions on [StackOverflow](https://stackoverflow.com/). 
+The most important components of Polars documentation are the [user guide](https://pola-rs.github.io/polars/user-guide/), the API references, and the database of questions on [StackOverflow](https://stackoverflow.com/). ### User guide -The user guide is maintained in the [polars-book](https://github.com/pola-rs/polars-book) repository. -For contributing to the user guide, please refer to the [contributing guide](https://github.com/pola-rs/polars-book/blob/master/CONTRIBUTING.md) in that repository. +The user guide is maintained in the `docs/user-guide` folder. Before creating a PR first raise an issue to discuss what you feel is missing or could be improved. + +#### Building and serving the user guide + +The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for building the user guide by running `make requirements` in the root of the repo. + +Run `mkdocs serve` to build and serve the user guide so you can view it locally and see updates as you make changes. + +#### Creating a new user guide page + +Each user guide page is based on a `.md` markdown file. This file must be listed in `mkdocs.yml`. + +#### Adding a shell code block + +To add a code block with code to be run in a shell with tabs for Python and Rust, use the following format: + +```` +=== ":fontawesome-brands-python: Python" + + ```shell + $ pip install fsspec + ``` + +=== ":fontawesome-brands-rust: Rust" + + ```shell + $ cargo add aws_sdk_s3 + ``` +```` + +#### Adding a code block + +The snippets for Python and Rust code blocks are in the `docs/src/python/` and `docs/src/rust/` directories, respectively. To add a code snippet with Python or Rust code to a `.md` page, use the following format: + +``` +{{code_block('user-guide/io/cloud-storage','read_parquet',[read_parquet,read_csv])}} +``` + +- The first argument is a path to either or both files called `docs/src/python/user-guide/io/cloud-storage.py` and `docs/src/rust/user-guide/io/cloud-storage.rs`. +- The second argument is the name given at the start and end of each snippet in the `.py` or `.rs` file +- The third argument is a list of links to functions in the API docs. For each element of the list there must be a corresponding entry in `docs/_build/API_REFERENCE_LINKS.yml` + +If the corresponding `.py` and `.rs` snippet files both exist then each snippet named in the second argument to `code_block` above must exist or the build will fail. An empty snippet should be added to the `.py` or `.rs` file if the snippet is not needed. + +Each snippet is formatted as follows: + +```python +# --8<-- [start:read_parquet] +import polars as pl + +df = pl.read_parquet("file.parquet") +# --8<-- [end:read_parquet] +``` + +The snippet is delimited by `--8<-- [start:]` and `--8<-- [end:]`. The snippet name must match the name given in the second argument to `code_block` above. + +#### Linting + +Before committing, install `dprint` (see above) and run `dprint fmt` from the `docs` directory to lint the markdown files. ### API reference @@ -181,10 +237,6 @@ The resulting HTML files will be in `py-polars/docs/build/html`. New additions to the API should be added manually to the API reference by adding an entry to the correct `.rst` file in the `py-polars/docs/source/reference` directory. -#### Node.js - -For contributions to Node.js Polars, please refer to the official [Node.js Polars repository](https://github.com/pola-rs/nodejs-polars). 
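As an editorial aside to the "Adding a code block" section above (not part of this PR): a Rust snippet file under `docs/src/rust/` uses the same `--8<--` start/end markers, just placed inside `//` comments. A minimal sketch, assuming the hypothetical `cloud-storage` example, a local `file.parquet`, and the `parquet` feature of the `polars` crate:

```rust
// --8<-- [start:read_parquet]
use std::fs::File;

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    // Eagerly read the (hypothetical) Parquet file into a DataFrame.
    let file = File::open("file.parquet")?;
    let df = ParquetReader::new(file).finish()?;
    println!("{df}");
    Ok(())
}
// --8<-- [end:read_parquet]
```

The snippet name between the markers (`read_parquet` here) is what the second argument of `code_block` refers to.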
- ### StackOverflow We use StackOverflow to create a database of high quality questions and answers that is searchable and remains up-to-date. @@ -192,7 +244,6 @@ There is a separate tag for each language: - [Python Polars](https://stackoverflow.com/questions/tagged/python-polars) - [Rust Polars](https://stackoverflow.com/questions/tagged/rust-polars) -- [Node.js Polars](https://stackoverflow.com/questions/tagged/nodejs-polars) Contributions in the form of well-formulated questions or answers are always welcome! If you add a new question, please notify us by adding a [matching issue](https://github.com/pola-rs/polars/issues/new?&labels=question&template=question.yml) to our GitHub issue tracker. @@ -225,21 +276,14 @@ Start by bumping the version number in the source code: Directly after merging your pull request, release the new version: -8. Go back to the [releases page](https://github.com/pola-rs/polars/releases) and click _Edit_ on the appropriate draft release. -9. On the draft release page, click _Publish release_. This will create a new release and a new tag, which will trigger the GitHub Actions release workflow ([Python](https://github.com/pola-rs/polars/actions/workflows/release-python.yml) / [Rust](https://github.com/pola-rs/polars/actions/workflows/release-rust.yml)). -10. Wait for all release jobs to finish, then check [crates.io](https://crates.io/crates/polars)/[PyPI](https://pypi.org/project/polars/) to verify that the new Polars release is now available. +8. Go to the release workflow ([Python](https://github.com/pola-rs/polars/actions/workflows/release-python.yml)/[Rust](https://github.com/pola-rs/polars/actions/workflows/release-rust.yml)), click _Run workflow_ in the top right, and click the green button. This will trigger the workflow, which will build all release artifacts and publish them. +9. Wait for the workflow to finish, then check [crates.io](https://crates.io/crates/polars)/[PyPI](https://pypi.org/project/polars/)/[GitHub](https://github.com/pola-rs/polars/releases) to verify that the new Polars release is now available. ### Troubleshooting It may happen that one or multiple release jobs fail. If so, you should first try to simply re-run the failed jobs from the GitHub Actions UI. -If that doesn't help, you will have to figure out what's wrong and commit a fix. Once your fix has made it to the `main` branch, re-trigger the release workflow by updating the git tag associated with the release. Note the commit hash of your fix, and run the following command: - -```shell -git tag -f && git push -f origin -``` - -This will update the tag to point to the commit of your fix. The release workflows will re-trigger and hopefully succeed this time! +If that doesn't help, you will have to figure out what's wrong and commit a fix. Once your fix has made it to the `main` branch, simply re-trigger the release workflow. 
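A hedged aside, not part of the diff itself: because the release workflows are now started via `workflow_dispatch` with the `sha` and `dry-run` inputs introduced earlier in this PR, they can also be triggered from the command line with GitHub's `gh` CLI instead of the Actions UI. A sketch, with a placeholder commit SHA:

```shell
# Dry-run the Python release workflow against a specific commit (placeholder SHA).
gh workflow run release-python.yml --ref main -f sha=<commit-sha> -f dry-run=true

# Inspect the run that was just dispatched.
gh run list --workflow=release-python.yml --limit 1
```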
## License diff --git a/Cargo.toml b/Cargo.toml index 70e2b03da866..69dbed1471d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ exclude = [ ] [workspace.package] -version = "0.32.0" +version = "0.33.2" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -24,39 +24,73 @@ license = "MIT" ahash = "0.8" atoi = "2" bitflags = "2" -chrono = { version = "0.4", default-features = false, features = ["std"] } +bytemuck = { version = "1", features = ["derive", "extern_crate_alloc"] } +chrono = { version = "0.4.31", default-features = false, features = ["std"] } chrono-tz = "0.8.1" ciborium = "0.2" -either = "1.8" +crossbeam-channel = "0.5.1" +either = "1.9" +ethnum = "1.3.2" futures = "0.3.25" hashbrown = { version = "0.14", features = ["rayon", "ahash"] } indexmap = { version = "2", features = ["std"] } -memchr = "2.6.1" +memchr = "2.6" multiversion = "0.7" num-traits = "0.2" -object_store = { version = "0.6", default-features = false } +object_store = { version = "0.7", default-features = false } once_cell = "1" -pyo3 = "0.19" +pyo3 = "0.20" rand = "0.8" -rayon = "1.6" -regex = "1.7.1" -serde = "1.0.160" +rayon = "1.8" +regex = "1.9" +serde = "1.0.188" serde_json = "1" -simd-json = { version = "0.10", features = ["allow-non-simd", "known-key"] } +simd-json = { version = "0.11", features = ["allow-non-simd", "known-key"] } smartstring = "1" -sqlparser = "0.36" +sqlparser = "0.38" strum_macros = "0.25" thiserror = "1" -url = "2.3.1" +tokio = "1.26" +tokio-util = "0.7.8" +url = "2.4" version_check = "0.9.4" +simdutf8 = "0.1.4" +hex = "0.4.3" +base64 = "0.21.2" +fallible-streaming-iterator = "0.1.9" +streaming-iterator = "0.1.9" +itoa = "1.0.6" +ryu = "1.0.13" +lexical-core = "0.8.5" +percent-encoding = "2.3" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } +polars-core = { version = "0.33.2", path = "crates/polars-core", default-features = false } +polars-arrow = { version = "0.33.2", path = "crates/polars-arrow", default-features = false } +polars-plan = { version = "0.33.2", path = "crates/polars-plan", default-features = false } +polars-lazy = { version = "0.33.2", path = "crates/polars-lazy", default-features = false } +polars-pipe = { version = "0.33.2", path = "crates/polars-pipe", default-features = false } +polars-row = { version = "0.33.2", path = "crates/polars-row", default-features = false } +polars-ffi = { version = "0.33.2", path = "crates/polars-ffi", default-features = false } +polars-ops = { version = "0.33.2", path = "crates/polars-ops", default-features = false } +polars-sql = { version = "0.33.2", path = "crates/polars-sql", default-features = false } +polars-algo = { version = "0.33.2", path = "crates/polars-algo", default-features = false } +polars-time = { version = "0.33.2", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.33.2", path = "crates/polars-utils", default-features = false } +polars-io = { version = "0.33.2", path = "crates/polars-io", default-features = false } +polars-error = { version = "0.33.2", path = "crates/polars-error", default-features = false } +polars-json = { version = "0.33.2", path = "crates/polars-json", default-features = false } +polars = { version = "0.33.2", path = "crates/polars", default-features = false } +rand_distr = "0.4" +reqwest = { version = "0.11", default-features = false } +arrow-array = { version = ">=41", default-features = false } +arrow-buffer = { version = ">=41", default-features = false } +arrow-data = { version = ">=41", default-features = 
false } +arrow-schema = { version = ">=41", default-features = false } [workspace.dependencies.arrow] -package = "arrow2" -git = "https://github.com/jorgecarleitao/arrow2" -rev = "ba6a882bc1542b0b899774b696ebea77482b5c31" -# branch = "" -# version = "0.17.4" +package = "nano-arrow" +version = "0.1.0" +path = "crates/nano-arrow" default-features = false features = [ "compute_aggregate", diff --git a/Makefile b/Makefile index 532342913f97..67e9044143b1 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,11 @@ requirements: .venv ## Install/refresh Python project requirements $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-dev.txt $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-lint.txt $(VENV_BIN)/pip install --upgrade -r py-polars/docs/requirements-docs.txt + $(VENV_BIN)/pip install --upgrade -r docs/requirements.txt + +.PHONY: build-python +build-python: .venv ## Compile and install Python Polars for development + @$(MAKE) -s -C py-polars build .PHONY: clean clean: ## Clean up caches and build artifacts @@ -32,4 +37,4 @@ clean: ## Clean up caches and build artifacts .PHONY: help help: ## Display this help screen @echo -e "\033[1mAvailable commands:\033[0m" - @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort diff --git a/README.md b/README.md index 1b8f03a3d373..b381350fce96 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ - R | - User Guide + User Guide | Discord

@@ -58,7 +58,7 @@ Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Ru - Hybrid Streaming (larger than RAM datasets) - Rust | Python | NodeJS | R | ... -To learn more, read the [User Guide](https://pola-rs.github.io/polars-book/). +To learn more, read the [User Guide](https://pola-rs.github.io/polars/). ## Python @@ -206,7 +206,9 @@ You can also install the dependencies directly. | fsspec | Support for reading from remote file systems | | connectorx | Support for reading from SQL databases | | xlsx2csv | Support for reading from Excel files | +| openpyxl | Support for reading from Excel files with native types | | deltalake | Support for reading from Delta Lake Tables | +| pyiceberg | Support for reading from Apache Iceberg tables | | timezone | Timezone support, only needed if are on Python<3.9 or you are on Windows | Releases happen quite often (weekly / every few days) at the moment, so updating polars regularly to get the latest bugfixes / features might not be a bad idea. @@ -220,7 +222,7 @@ point to the `main` branch of this repo. polars = { git = "https://github.com/pola-rs/polars", rev = "" } ``` -Required Rust version `>=1.65`. +Required Rust version `>=1.71`. ## Contributing @@ -262,15 +264,11 @@ Don't use this unless you hit the row boundary as the default polars is faster a ## Legacy -Do you want polars to run on an old CPU (e.g. dating from before 2011)? Install `pip install polars-lts-cpu`. This polars project is -compiled without [avx](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) target features. - -## Acknowledgements - -Development of Polars is proudly powered by - -[![Xomnia](https://raw.githubusercontent.com/pola-rs/polars-static/master/sponsors/xomnia.png)](https://www.xomnia.com/) +Do you want polars to run on an old CPU (e.g. dating from before 2011), or on an `x86-64` build +of Python on Apple Silicon under Rosetta? Install `pip install polars-lts-cpu`. This version of +polars is compiled without [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) target +features. 
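An aside that is not part of the diff: on Linux, one quick way to check whether the CPU advertises the AVX feature the default wheel is built for, and therefore whether `polars-lts-cpu` is the better choice:

```shell
# Prints "avx" if the CPU supports it; if nothing is printed, prefer polars-lts-cpu.
grep -o -m 1 -w 'avx' /proc/cpuinfo || echo "no AVX detected"
```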
## Sponsors -[](https://www.xomnia.com/)   [](https://www.jetbrains.com) +[](https://www.jetbrains.com) diff --git a/_typos.toml b/_typos.toml index 12406b2f4ea8..4d9ec510b278 100644 --- a/_typos.toml +++ b/_typos.toml @@ -7,6 +7,7 @@ extend-ignore-identifiers-re = [ ba = "ba" Fo = "Fo" nd = "nd" +ND = "ND" opt_nd = "opt_nd" ser = "ser" strat = "strat" diff --git a/crates/Makefile b/crates/Makefile index ad74b606bb9a..594016dac57b 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -10,22 +10,22 @@ fmt: ## Run rustfmt and dprint .PHONY: check check: ## Run cargo check with all features - cargo check --workspace --all-targets --all-features + cargo check --workspace --all-targets --exclude nano-arrow --all-features .PHONY: clippy clippy: ## Run clippy with all features - cargo clippy --workspace --all-targets --all-features + cargo clippy -p polars --all-features .PHONY: clippy-default clippy-default: ## Run clippy with default features - cargo clippy --workspace --all-targets + cargo clippy -p polars .PHONY: pre-commit pre-commit: fmt clippy clippy-default ## Run autoformatting and linting .PHONY: check-features check-features: ## Run cargo check for feature flag combinations (warning: slow) - cargo hack check --each-feature --no-dev-deps + cargo hack check -p polars --each-feature --no-dev-deps .PHONY: miri miri: ## Run miri @@ -35,7 +35,6 @@ miri: ## Run miri MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-ignore-leaks -Zmiri-disable-stacked-borrows" \ POLARS_ALLOW_EXTENSION=1 \ cargo miri test \ - --no-default-features \ --features object \ -p polars-core \ -p polars-arrow @@ -51,10 +50,25 @@ test: ## Run tests -p polars-utils \ -p polars-row \ -p polars-sql \ + -p polars-ops \ -p polars-plan \ -- \ --test-threads=2 +.PHONY: nextest +nextest: ## Run tests with nextest + cargo nextest run --all-features \ + -p polars-lazy \ + -p polars-io \ + -p polars-core \ + -p polars-arrow \ + -p polars-time \ + -p polars-utils \ + -p polars-row \ + -p polars-sql \ + -p polars-ops \ + -p polars-plan \ + .PHONY: integration-tests integration-tests: ## Run integration tests cargo test --all-features --test it -p polars @@ -96,6 +110,7 @@ publish: ## Publish Polars crates cargo publish --allow-dirty -p polars-arrow cargo publish --allow-dirty -p polars-json cargo publish --allow-dirty -p polars-core + cargo publish --allow-dirty -p polars-ffi cargo publish --allow-dirty -p polars-ops cargo publish --allow-dirty -p polars-time cargo publish --allow-dirty -p polars-io @@ -118,6 +133,9 @@ check-wasm: ## Check wasm build without supported features --exclude-features async \ --exclude-features aws \ --exclude-features azure \ + --exclude-features cloud \ + --exclude-features cloud_write \ + --exclude-features decompress \ --exclude-features decompress-fast \ --exclude-features default \ --exclude-features docs-selection \ diff --git a/crates/nano-arrow/Cargo.toml b/crates/nano-arrow/Cargo.toml new file mode 100644 index 000000000000..641f569a86c5 --- /dev/null +++ b/crates/nano-arrow/Cargo.toml @@ -0,0 +1,200 @@ +[package] +name = "nano-arrow" +version = "0.1.0" +authors = [ + "Jorge C. Leitao ", + "Apache Arrow ", + "Ritchie Vink ", +] +edition = { workspace = true } +homepage = { workspace = true } +license = "Apache 2.0 AND MIT" +repository = { workspace = true } +description = "Minimal implementation of the Arrow specification forked from arrow2." 
+ +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytemuck = { workspace = true } +chrono = { workspace = true } +# for timezone support +chrono-tz = { workspace = true, optional = true } +dyn-clone = { version = "1" } +either = { workspace = true } +foreign_vec = { version = "0.1" } +hashbrown = { workspace = true } +num-traits = { workspace = true } +simdutf8 = { workspace = true } + +# for decimal i256 +ethnum = { workspace = true } + +# To efficiently cast numbers to strings +lexical-core = { workspace = true, optional = true } + +fallible-streaming-iterator = { workspace = true, optional = true } +regex = { workspace = true, optional = true } +regex-syntax = { version = "0.7", optional = true } +streaming-iterator = { workspace = true } + +indexmap = { workspace = true, optional = true } + +arrow-format = { version = "0.8", optional = true, features = ["ipc"] } + +hex = { workspace = true, optional = true } + +# for IPC compression +lz4 = { version = "1.24", optional = true } +zstd = { version = "0.12", optional = true } + +base64 = { workspace = true, optional = true } + +# to write to parquet as a stream +futures = { workspace = true, optional = true } + +# to read IPC as a stream +async-stream = { version = "0.3.2", optional = true } + +# avro support +avro-schema = { version = "0.3", optional = true } + +# for division/remainder optimization at runtime +strength_reduce = { version = "0.2", optional = true } + +# For instruction multiversioning +multiversion = { workspace = true, optional = true } + +# Faster hashing +ahash = { workspace = true } + +# Support conversion to/from arrow-rs +arrow-array = { workspace = true, optional = true } +arrow-buffer = { workspace = true, optional = true } +arrow-data = { workspace = true, optional = true } +arrow-schema = { workspace = true, optional = true } + +# parquet support +[dependencies.parquet2] +version = "0.17" +optional = true +default_features = false +features = ["async"] + +[dev-dependencies] +avro-rs = { version = "0.13", features = ["snappy"] } +criterion = "0.5" +crossbeam-channel = { workspace = true } +doc-comment = "0.3" +flate2 = "1" +# used to run formal property testing +proptest = { version = "1", default_features = false, features = ["std"] } +# use for flaky testing +rand = { workspace = true } +# use for generating and testing random data samples +sample-arrow2 = "0.17" +sample-std = "0.1" +sample-test = "0.1" +# used to test async readers +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { workspace = true, features = ["compat"] } + +[build-dependencies] +rustc_version = "0.4.0" + +[target.wasm32-unknown-unknown.dependencies] +getrandom = { version = "0.2", features = ["js"] } + +[features] +default = [] +full = [ + "arrow_rs", + "io_ipc", + "io_flight", + "io_ipc_write_async", + "io_ipc_read_async", + "io_ipc_compression", + "io_parquet", + "io_parquet_compression", + "io_avro", + "io_avro_compression", + "io_avro_async", + "regex-syntax", + "compute", + # parses timezones used in timestamp conversions + "chrono-tz", +] +arrow_rs = ["arrow-buffer", "arrow-schema", "arrow-data", "arrow-array"] +io_ipc = ["arrow-format"] +io_ipc_write_async = ["io_ipc", "futures"] +io_ipc_read_async = ["io_ipc", "futures", "async-stream"] +io_ipc_compression = ["lz4", "zstd"] +io_flight = ["io_ipc", "arrow-format/flight-data"] + +# base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. 
+io_parquet = ["parquet2", "io_ipc", "base64", "futures", "fallible-streaming-iterator"] + +io_parquet_compression = [ + "io_parquet_zstd", + "io_parquet_gzip", + "io_parquet_snappy", + "io_parquet_lz4", + "io_parquet_brotli", +] + +# sample testing of generated arrow data +io_parquet_sample_test = ["io_parquet"] + +# compression backends +io_parquet_zstd = ["parquet2/zstd"] +io_parquet_snappy = ["parquet2/snappy"] +io_parquet_gzip = ["parquet2/gzip"] +io_parquet_lz4 = ["parquet2/lz4"] +io_parquet_brotli = ["parquet2/brotli"] + +# parquet bloom filter functions +io_parquet_bloom_filter = ["parquet2/bloom_filter"] + +io_avro = ["avro-schema"] +io_avro_compression = [ + "avro-schema/compression", +] +io_avro_async = ["avro-schema/async"] + +# the compute kernels. Disabling this significantly reduces compile time. +compute_aggregate = ["multiversion"] +compute_arithmetics_decimal = ["strength_reduce"] +compute_arithmetics = ["strength_reduce", "compute_arithmetics_decimal"] +compute_bitwise = [] +compute_boolean = [] +compute_boolean_kleene = [] +compute_cast = ["lexical-core", "compute_take"] +compute_comparison = ["compute_take", "compute_boolean"] +compute_concatenate = [] +compute_filter = [] +compute_hash = ["multiversion"] +compute_if_then_else = [] +compute_take = [] +compute_temporal = [] +compute = [ + "compute_aggregate", + "compute_arithmetics", + "compute_bitwise", + "compute_boolean", + "compute_boolean_kleene", + "compute_cast", + "compute_comparison", + "compute_concatenate", + "compute_filter", + "compute_hash", + "compute_if_then_else", + "compute_take", + "compute_temporal", +] +simd = [] + +[package.metadata.docs.rs] +features = ["full"] +rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.cargo-all-features] +allowlist = ["compute", "compute_sort", "compute_hash", "compute_nullif"] diff --git a/crates/nano-arrow/src/README.md b/crates/nano-arrow/src/README.md new file mode 100644 index 000000000000..d6371ebc8741 --- /dev/null +++ b/crates/nano-arrow/src/README.md @@ -0,0 +1,32 @@ +# Crate's design + +This document describes the design of this module, and thus the overall crate. +Each module MAY have its own design document, that concerns specifics of that module, and if yes, +it MUST be on each module's `README.md`. + +## Equality + +Array equality is not defined in the Arrow specification. This crate follows the intent of the specification, but there is no guarantee that this no verification that this equals e.g. C++'s definition. + +There is a single source of truth about whether two arrays are equal, and that is via their +equality operators, defined on the module [`array/equal`](array/equal/mod.rs). + +Implementation MUST use these operators for asserting equality, so that all testing follows the same definition of array equality. + +## Error handling + +- Errors from an external dependency MUST be encapsulated on `External`. +- Errors from IO MUST be encapsulated on `Io`. +- This crate MAY return `NotYetImplemented` when the functionality does not exist, or it MAY panic with `unimplemented!`. + +## Logical and physical types + +There is a strict separation between physical and logical types: + +- physical types MUST be implemented via generics +- logical types MUST be implemented via variables (whose value is e.g. an `enum`) +- logical types MUST be declared and implemented on the `datatypes` module + +## Source of undefined behavior + +There is one, and only one, acceptable source of undefined behavior: FFI. 
It is impossible to prove that data passed via pointers are safe for consumption (only a promise from the specification). diff --git a/crates/nano-arrow/src/array/README.md b/crates/nano-arrow/src/array/README.md new file mode 100644 index 000000000000..af21f91e02ef --- /dev/null +++ b/crates/nano-arrow/src/array/README.md @@ -0,0 +1,73 @@ +# Array module + +This document describes the overall design of this module. + +## Notation: + +- "array" in this module denotes any struct that implements the trait `Array`. +- "mutable array" in this module denotes any struct that implements the trait `MutableArray`. +- words in `code` denote existing terms on this implementation. + +## Arrays: + +- Every arrow array with a different physical representation MUST be implemented as a struct or generic struct. + +- An array MAY have its own module. E.g. `primitive/mod.rs` + +- An array with a null bitmap MUST implement it as `Option<Bitmap>` + +- An array MUST be `#[derive(Clone)]` + +- The trait `Array` MUST only be implemented by structs in this module. + +- Every child array on the struct MUST be `Box<dyn Array>`. + +- An array MUST implement `try_new(...) -> Self`. This method MUST error iff + the data does not follow the arrow specification, including any sentinel types such as utf8. + +- An array MAY implement `unsafe try_new_unchecked` that skips validation steps that are `O(N)`. + +- An array MUST implement either `new_empty()` or `new_empty(DataType)` that returns a zero-len of `Self`. + +- An array MUST implement either `new_null(length: usize)` or `new_null(DataType, length: usize)` that returns a valid array of length `length` all of whose elements are null. + +- An array MAY implement `value(i: usize)` that returns the value at slot `i` ignoring the validity bitmap. + +- functions to create new arrays from native Rust SHOULD be named as follows: + - `from`: from a slice of optional values (e.g. `AsRef<[Option<bool>]>` for `BooleanArray`) + - `from_slice`: from a slice of values (e.g. `AsRef<[bool]>` for `BooleanArray`) + - `from_trusted_len_iter` from an iterator of trusted len of optional values + - `from_trusted_len_values_iter` from an iterator of trusted len of values + - `try_from_trusted_len_iter` from a fallible iterator of trusted len of optional values + +### Slot offsets + +- An array MUST have an `offset: usize` measuring the number of slots that the array is currently offsetted by if the specification requires. + +- An array MUST implement `fn slice(&self, offset: usize, length: usize) -> Self` that returns an offsetted and/or truncated clone of the array. This function MUST increase the array's offset if it exists. + +- Conversely, `offset` MUST only be changed by `slice`. + +The rationale for the above is that it enables us to be fully interoperable with the offset logic supported by the C data interface, while at the same time easily performing array slices +within Rust's type safety mechanism. + +### Mutable Arrays + +- An array MAY have a mutable counterpart. E.g. `MutablePrimitiveArray<T>` is the mutable counterpart of `PrimitiveArray<T>`. + +- Arrays with mutable counterparts MUST have their own module, and have the mutable counterpart declared in `{module}/mutable.rs`. + +- The trait `MutableArray` MUST only be implemented by mutable arrays in this module. + +- A mutable array MUST be `#[derive(Debug)]` + +- A mutable array with a null bitmap MUST implement it as `Option<MutableBitmap>` + +- Converting a `MutableArray` to its immutable counterpart MUST be `O(1)`.
Specifically: + - it must not allocate + - it must not cause `O(N)` data transformations + + This is achieved by converting mutable versions to immutable counterparts (e.g. `MutableBitmap -> Bitmap`). + + The rational is that `MutableArray`s can be used to perform in-place operations under + the arrow spec. diff --git a/crates/nano-arrow/src/array/binary/data.rs b/crates/nano-arrow/src/array/binary/data.rs new file mode 100644 index 000000000000..56835dec0c42 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/data.rs @@ -0,0 +1,43 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, BinaryArray}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for BinaryArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.offsets().len_proxy()) + .buffers(vec![ + self.offsets.clone().into_inner().into(), + self.values.clone().into(), + ]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let buffers = data.buffers(); + + // Safety: ArrayData is valid + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: buffers[1].clone().into(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/binary/ffi.rs b/crates/nano-arrow/src/array/binary/ffi.rs new file mode 100644 index 000000000000..3ba66cc130da --- /dev/null +++ b/crates/nano-arrow/src/array/binary/ffi.rs @@ -0,0 +1,63 @@ +use super::BinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for BinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for BinaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2) }?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/binary/fmt.rs b/crates/nano-arrow/src/array/binary/fmt.rs new file mode 100644 index 000000000000..d2a6788ce4d8 --- /dev/null +++ 
b/crates/nano-arrow/src/array/binary/fmt.rs @@ -0,0 +1,26 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BinaryArray; +use crate::offset::Offset; + +pub fn write_value(array: &BinaryArray, index: usize, f: &mut W) -> Result { + let bytes = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", bytes[index]); + + write_vec(f, writer, None, bytes.len(), "None", false) +} + +impl Debug for BinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeBinaryArray" + } else { + "BinaryArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/binary/from.rs b/crates/nano-arrow/src/array/binary/from.rs new file mode 100644 index 000000000000..73df03531594 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/from.rs @@ -0,0 +1,11 @@ +use std::iter::FromIterator; + +use super::{BinaryArray, MutableBinaryArray}; +use crate::offset::Offset; + +impl> FromIterator> for BinaryArray { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableBinaryArray::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/binary/iterator.rs b/crates/nano-arrow/src/array/binary/iterator.rs new file mode 100644 index 000000000000..3fccec58eb50 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/iterator.rs @@ -0,0 +1,42 @@ +use super::{BinaryArray, MutableBinaryValuesArray}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for BinaryArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`BinaryArray`]. +pub type BinaryValueIter<'a, O> = ArrayValuesIter<'a, BinaryArray>; + +impl<'a, O: Offset> IntoIterator for &'a BinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], BinaryValueIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// Iterator of values of an [`MutableBinaryValuesArray`]. 
+pub type MutableBinaryValuesIter<'a, O> = ArrayValuesIter<'a, MutableBinaryValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableBinaryValuesArray { + type Item = &'a [u8]; + type IntoIter = MutableBinaryValuesIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/binary/mod.rs b/crates/nano-arrow/src/array/binary/mod.rs new file mode 100644 index 000000000000..94cebf85ca4a --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mod.rs @@ -0,0 +1,424 @@ +use either::Either; + +use super::specification::try_check_offsets_bounds; +use super::{Array, GenericBinaryArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod from; +mod mutable_values; +pub use mutable_values::*; +mod mutable; +pub use mutable::*; + +#[cfg(feature = "arrow_rs")] +mod data; + +/// A [`BinaryArray`] is Arrow's semantically equivalent of an immutable `Vec>>`. +/// It implements [`Array`]. +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. +/// # Example +/// ``` +/// use arrow2::array::BinaryArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = BinaryArray::::from([Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.value(0), &[1, 2]); +/// assert_eq!(array.iter().collect::>(), vec![Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.values_iter().collect::>(), vec![[1, 2].as_ref(), &[], &[3]]); +/// // the underlying representation: +/// assert_eq!(array.values(), &Buffer::from(vec![1, 2, 3])); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 3])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutives `offsets` casted (`as`) to `usize` are valid slices of `values`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct BinaryArray { + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +impl BinaryArray { + /// Returns a [`BinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. 
+ /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "BinaryArray can only be initialized with DataType::Binary or DataType::LargeBinary", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`BinaryArray`] from slices of `&[u8]`. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Creates a new [`BinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableBinaryArray::::from(slice).into() + } + + /// Returns an iterator of `Option<&[u8]>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&[u8], BinaryValueIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref()) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> BinaryValueIter { + BinaryValueIter::new(self) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`DataType`] of this array. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the values of this [`BinaryArray`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`BinaryArray`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. 
+ /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, OffsetsBuffer, Buffer, Option) { + let Self { + data_type, + offsets, + values, + validity, + } = self; + (data_type, offsets, values, validity) + } + + /// Try to convert this `BinaryArray` to a `MutableBinaryArray` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // Safety: invariants are preserved + Left(bitmap) => Left(BinaryArray::new( + self.data_type, + self.offsets, + self.values, + Some(bitmap), + )), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values, + Some(mutable_bitmap.into()), + )), + (Left(values), Right(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets.into(), + values, + Some(mutable_bitmap.into()), + )), + (Right(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values.into(), + Some(mutable_bitmap.into()), + )), + (Right(values), Right(offsets)) => Right( + MutableBinaryArray::try_new( + self.data_type, + offsets, + values, + Some(mutable_bitmap), + ) + .unwrap(), + ), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(BinaryArray::new(self.data_type, offsets, values, None)) + }, + (Left(values), Right(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets.into(), + values, + None, + )), + (Right(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values.into(), + None, + )), + (Right(values), Right(offsets)) => Right( + MutableBinaryArray::try_new(self.data_type, offsets, values, None).unwrap(), + ), + } + } + } + + /// Creates an empty [`BinaryArray`], i.e. whose `.len` is zero. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, OffsetsBuffer::new(), Buffer::new(), None) + } + + /// Creates an null [`BinaryArray`], i.e. whose `.null_count() == .len()`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns the default [`DataType`], `DataType::Binary` or `DataType::LargeBinary` + pub fn default_data_type() -> DataType { + if O::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + } + } + + /// Alias for unwrapping [`Self::try_new`] + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Returns a [`BinaryArray`] from an iterator of trusted length. + /// + /// The [`BinaryArray`] is guaranteed to not have a validity + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableBinaryArray::::from_trusted_len_values_iter(iterator).into() + } + + /// Returns a new [`BinaryArray`] from a [`Iterator`] of `&[u8]`. 
+ /// + /// The [`BinaryArray`] is guaranteed to not have a validity + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableBinaryArray::::from_iter_values(iterator).into() + } + + /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + MutableBinaryArray::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BinaryArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + MutableBinaryArray::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`BinaryArray`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } +} + +impl Array for BinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for BinaryArray { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} diff --git a/crates/nano-arrow/src/array/binary/mutable.rs b/crates/nano-arrow/src/array/binary/mutable.rs new file mode 100644 index 000000000000..92521b400323 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mutable.rs @@ -0,0 +1,469 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{BinaryArray, MutableBinaryValuesArray, MutableBinaryValuesIter}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>>`. +/// Converting a [`MutableBinaryArray`] into a [`BinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). 
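To make the builder workflow described above concrete, here is a minimal sketch of pushing optional byte slices and freezing the result. It assumes the `arrow2`-style import paths that the doc examples in this diff still use; the actual path of the vendored `nano-arrow` crate may differ.

```rust
use arrow2::array::{BinaryArray, MutableBinaryArray};

fn main() {
    // Push optional byte slices; the validity bitmap is only allocated
    // once the first `None` is pushed.
    let mut mutable = MutableBinaryArray::<i32>::new();
    mutable.push(Some("hello".as_bytes()));
    mutable.push(None::<&[u8]>);
    mutable.push(Some("arrow".as_bytes()));

    // Converting into the immutable counterpart is O(1): buffers are moved, not copied.
    let array: BinaryArray<i32> = mutable.into();
    assert_eq!(array.len(), 3);
    assert_eq!(array.value(0), b"hello");
    assert!(array.get(1).is_none());
}
```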
+#[derive(Debug, Clone)] +pub struct MutableBinaryArray { + values: MutableBinaryValuesArray, + validity: Option, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryArray) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: BinaryArray = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableBinaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryArray { + /// Creates a new empty [`MutableBinaryArray`]. + /// # Implementation + /// This allocates a [`Vec`] of one element + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Returns a [`MutableBinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Result { + let values = MutableBinaryValuesArray::try_new(data_type, offsets, values)?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity's length must be equal to the number of values", + )); + } + + Ok(Self { values, validity }) + } + + /// Creates a new [`MutableBinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_data_type() -> DataType { + BinaryArray::::default_data_type() + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots and values. + /// # Implementation + /// This does not allocate the validity. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableBinaryValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn try_from_iter, I: IntoIterator>>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(lower); + for item in iterator { + primitive.try_push(item.as_ref())? 
+ } + Ok(primitive) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BinaryArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + impl_mutable_array_mut_validity!(); +} + +impl MutableBinaryArray { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. + pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } + + /// Returns an iterator of `Option<&[u8]>` + pub fn iter(&self) -> ZipValidity<&[u8], MutableBinaryValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Returns an iterator over the values of this array + pub fn values_iter(&self) -> MutableBinaryValuesIter { + self.values.iter() + } +} + +impl MutableArray for MutableBinaryArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BinaryArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BinaryArray = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + self.values.data_type() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableBinaryArray { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableBinaryArray { + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + let (validity, offsets, values) = trusted_len_unzip(iterator); + + Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap() + } + + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. 
+ #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + // soundness: assumed trusted len + let (mut validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + if validity.as_mut().unwrap().unset_bits() == 0 { + validity = None; + } + + Ok(Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap()) + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of trusted length. + /// This differs from `extend_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // Safety: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an `iterator` of values of trusted length. + /// This differs from `extend_trusted_len_unchecked` which accepts iterator of optional + /// values. 
+ /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // Safety: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a new [`MutableBinaryArray`] from a [`Iterator`] of `&[u8]`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + let (offsets, values) = values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableBinaryArray { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableBinaryArray { + fn try_extend>>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableBinaryArray { + fn try_push(&mut self, value: Option) -> Result<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/binary/mutable_values.rs b/crates/nano-arrow/src/array/binary/mutable_values.rs new file mode 100644 index 000000000000..08be36d6f38d --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mutable_values.rs @@ -0,0 +1,374 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{BinaryArray, MutableBinaryArray}; +use crate::array::physical_binary::*; +use crate::array::specification::try_check_offsets_bounds; +use crate::array::{ + Array, ArrayAccessor, ArrayValuesIter, MutableArray, TryExtend, 
TryExtendFromSelf, TryPush, +}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`BinaryArray`]. It differs +/// from [`MutableBinaryArray`] in that it builds non-null [`BinaryArray`]. +#[derive(Debug, Clone)] +pub struct MutableBinaryValuesArray { + data_type: DataType, + offsets: Offsets, + values: Vec, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + BinaryArray::::new( + other.data_type, + other.offsets.into(), + other.values.into(), + None, + ) + } +} + +impl From> for MutableBinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + MutableBinaryArray::::try_new(other.data_type, other.offsets, other.values, None) + .expect("MutableBinaryValuesArray is consistent with MutableBinaryArray") + } +} + +impl Default for MutableBinaryValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryValuesArray { + /// Returns an empty [`MutableBinaryValuesArray`]. + pub fn new() -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableBinaryValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new(data_type: DataType, offsets: Offsets, values: Vec) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "MutableBinaryValuesArray can only be initialized with DataType::Binary or DataType::LargeBinary", + )); + } + + Ok(Self { + data_type, + offsets, + values, + }) + } + + /// Returns the default [`DataType`] of this container: [`DataType::Utf8`] or [`DataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_data_type() -> DataType { + BinaryArray::::default_data_type() + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. 
+ /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option> { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + Some(value.to_vec()) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns an iterator of `&[u8]` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableBinaryValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableBinaryValuesArray`]. + pub fn into_inner(self) -> (DataType, Offsets, Vec) { + (self.data_type, self.offsets, self.values) + } +} + +impl MutableArray for MutableBinaryValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let (data_type, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(data_type, offsets.into(), values.into(), None).boxed() + } + + fn as_arc(&mut self) -> Arc { + let (data_type, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(data_type, offsets.into(), values.into(), None).arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(b"") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator
for MutableBinaryValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter()); + Self::try_new(Self::default_data_type(), offsets, values).unwrap() + } +} + +impl MutableBinaryValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef<[u8]>, + I: Iterator>, + { + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableBinaryValuesArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator, + { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values).unwrap() + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). 
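A short usage sketch of the non-null builder described above, again assuming the `arrow2`-style paths used by the doc comments. Collecting plain values never allocates a validity bitmap.

```rust
use arrow2::array::{BinaryArray, MutableBinaryValuesArray};

fn main() {
    // Collect plain (non-optional) values; no validity bitmap is ever allocated.
    let mutable: MutableBinaryValuesArray<i32> =
        ["a", "bc", "def"].iter().map(|s| s.as_bytes()).collect();
    assert_eq!(mutable.len(), 3);
    assert_eq!(mutable.value(1), b"bc");

    // The resulting immutable array has no validity: every slot is valid.
    let array: BinaryArray<i32> = mutable.into();
    assert!(array.validity().is_none());
}
```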
+ pub fn try_from_iter, I: IntoIterator>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableBinaryValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter(&mut self.offsets, &mut self.values, iter.into_iter()); + } +} + +impl> TryExtend for MutableBinaryValuesArray { + fn try_extend>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableBinaryValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> Result<()> { + let bytes = value.as_ref(); + self.values.extend_from_slice(bytes); + self.offsets.try_push(bytes.len()) + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableBinaryValuesArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl TryExtendFromSelf for MutableBinaryValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/nano-arrow/src/array/boolean/data.rs b/crates/nano-arrow/src/array/boolean/data.rs new file mode 100644 index 000000000000..e93aeb3b8d2b --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/data.rs @@ -0,0 +1,36 @@ +use arrow_buffer::{BooleanBuffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, BooleanArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::DataType; + +impl Arrow2Arrow for BooleanArray { + fn to_data(&self) -> ArrayData { + let buffer = NullBuffer::from(self.values.clone()); + + let builder = ArrayDataBuilder::new(arrow_schema::DataType::Boolean) + .len(buffer.len()) + .offset(buffer.offset()) + .buffers(vec![buffer.into_inner().into_inner()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + assert_eq!(data.data_type(), &arrow_schema::DataType::Boolean); + + let buffers = data.buffers(); + let buffer = BooleanBuffer::new(buffers[0].clone(), data.offset(), data.len()); + // Use NullBuffer to compute set count + let values = Bitmap::from_null_buffer(NullBuffer::new(buffer)); + + Self { + data_type: DataType::Boolean, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/boolean/ffi.rs b/crates/nano-arrow/src/array/boolean/ffi.rs new file mode 100644 index 000000000000..64f22de81d5d --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/ffi.rs @@ -0,0 +1,54 @@ +use super::BooleanArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for BooleanArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), 
+ Some(self.values.as_ptr()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for BooleanArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.bitmap(1) }?; + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/boolean/fmt.rs b/crates/nano-arrow/src/array/boolean/fmt.rs new file mode 100644 index 000000000000..229a01cd3e03 --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/fmt.rs @@ -0,0 +1,17 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BooleanArray; + +pub fn write_value(array: &BooleanArray, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for BooleanArray { + fn fmt(&self, f: &mut Formatter) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "BooleanArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/boolean/from.rs b/crates/nano-arrow/src/array/boolean/from.rs new file mode 100644 index 000000000000..81a5395ccc06 --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/from.rs @@ -0,0 +1,15 @@ +use std::iter::FromIterator; + +use super::{BooleanArray, MutableBooleanArray}; + +impl]>> From


for BooleanArray { + fn from(slice: P) -> Self { + MutableBooleanArray::from(slice).into() + } +} + +impl>> FromIterator for BooleanArray { + fn from_iter>(iter: I) -> Self { + MutableBooleanArray::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/boolean/iterator.rs b/crates/nano-arrow/src/array/boolean/iterator.rs new file mode 100644 index 000000000000..8e914c98faab --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/iterator.rs @@ -0,0 +1,55 @@ +use super::super::MutableArray; +use super::{BooleanArray, MutableBooleanArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::IntoIter; + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl IntoIterator for BooleanArray { + type Item = Option; + type IntoIter = ZipValidity; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a> IntoIterator for &'a MutableBooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableBooleanArray { + /// Returns an iterator over the optional values of this [`MutableBooleanArray`]. + #[inline] + pub fn iter(&'a self) -> ZipValidity, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator over the values of this [`MutableBooleanArray`] + #[inline] + pub fn values_iter(&'a self) -> BitmapIter<'a> { + self.values().iter() + } +} diff --git a/crates/nano-arrow/src/array/boolean/mod.rs b/crates/nano-arrow/src/array/boolean/mod.rs new file mode 100644 index 000000000000..c1dd4785231e --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/mod.rs @@ -0,0 +1,384 @@ +use either::Either; + +use super::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; + +pub use iterator::*; +pub use mutable::*; + +/// A [`BooleanArray`] is Arrow's semantically equivalent of an immutable `Vec>`. +/// It implements [`Array`]. +/// +/// One way to think about a [`BooleanArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. 
+/// # Example +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = BooleanArray::from([Some(true), None, Some(false)]); +/// assert_eq!(array.value(0), true); +/// assert_eq!(array.iter().collect::>(), vec![Some(true), None, Some(false)]); +/// assert_eq!(array.values_iter().collect::>(), vec![true, false, false]); +/// // the underlying representation +/// assert_eq!(array.values(), &Bitmap::from([true, false, false])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct BooleanArray { + data_type: DataType, + values: Bitmap, + validity: Option, +} + +impl BooleanArray { + /// The canonical method to create a [`BooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + data_type: DataType, + values: Bitmap, + validity: Option, + ) -> Result { + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Boolean { + return Err(Error::oos( + "BooleanArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Alias to `Self::try_new().unwrap()` + pub fn new(data_type: DataType, values: Bitmap, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns an iterator over the optional values of this [`BooleanArray`]. + #[inline] + pub fn iter(&self) -> ZipValidity { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator over the values of this [`BooleanArray`]. + #[inline] + pub fn values_iter(&self) -> BitmapIter { + self.values().iter() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Bitmap`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Bitmap { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the value at index `i` + /// # Panic + /// This function panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> bool { + self.values.get_bit(i) + } + + /// Returns the element at index `i` as bool + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + self.values.get_bit_unchecked(i) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase up to two ref counts. + /// # Panic + /// This function panics iff `offset + length > self.len()`. 
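As a small illustration of the in-place, `O(1)` slicing described here (a sketch, assuming the `arrow2`-style paths from the doc example above):

```rust
use arrow2::array::BooleanArray;

fn main() {
    let mut array = BooleanArray::from([Some(true), None, Some(false), Some(true)]);

    // `slice` mutates the array in place and is O(1): it only adjusts
    // offsets and reference counts, no data is copied.
    array.slice(1, 2);
    assert_eq!(array.len(), 2);
    assert_eq!(array.iter().collect::<Vec<_>>(), vec![None, Some(false)]);
}
```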
+ #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase two ref counts. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns a clone of this [`BooleanArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(&self, values: Bitmap) -> Self { + let mut out = self.clone(); + out.set_values(values); + out + } + + /// Sets the values of this [`BooleanArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Bitmap) { + assert_eq!( + values.len(), + self.len(), + "values length must be equal to this arrays length" + ); + self.values = values; + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics if the function modifies the length of the [`MutableBitmap`]. + pub fn apply_values_mut(&mut self, f: F) { + let values = std::mem::take(&mut self.values); + let mut values = values.make_mut(); + f(&mut values); + if let Some(validity) = &self.validity { + assert_eq!(validity.len(), values.len()); + } + self.values = values.into(); + } + + /// Try to convert this [`BooleanArray`] to a [`MutableBooleanArray`] + pub fn into_mut(self) -> Either { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(BooleanArray::new(self.data_type, self.values, Some(bitmap))), + Right(mutable_bitmap) => match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new( + self.data_type, + immutable, + Some(mutable_bitmap.into()), + )), + Right(mutable) => Right( + MutableBooleanArray::try_new(self.data_type, mutable, Some(mutable_bitmap)) + .unwrap(), + ), + }, + } + } else { + match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new(self.data_type, immutable, None)), + Right(mutable) => { + Right(MutableBooleanArray::try_new(self.data_type, mutable, None).unwrap()) + }, + } + } + } + + /// Returns a new empty [`BooleanArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Bitmap::new(), None) + } + + /// Returns a new [`BooleanArray`] whose all slots are null / `None`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let bitmap = Bitmap::new_zeroed(length); + Self::new(data_type, bitmap.clone(), Some(bitmap)) + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. 
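A small sketch of the constructors above (`from_slice` and `new_null`), assuming the `arrow2`-style paths used in the doc examples:

```rust
use arrow2::array::BooleanArray;
use arrow2::datatypes::DataType;

fn main() {
    // All-valid array built from a slice of `bool`; no validity bitmap is allocated.
    let valid = BooleanArray::from_slice([true, false, true]);
    assert_eq!(valid.len(), 3);
    assert!(valid.validity().is_none());

    // All-null array of a given length: every slot reads back as `None`.
    let nulls = BooleanArray::new_null(DataType::Boolean, 2);
    assert_eq!(nulls.len(), 2);
    assert!(nulls.get(0).is_none());
}
```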
+ #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + MutableBooleanArray::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + MutableBooleanArray::from_trusted_len_values_iter_unchecked(iterator).into() + } + + /// Creates a new [`BooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + MutableBooleanArray::from_slice(slice).into() + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + MutableBooleanArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + MutableBooleanArray::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter(iterator)?.into()) + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, Bitmap, Option) { + let Self { + data_type, + values, + validity, + } = self; + (data_type, values, validity) + } + + /// Creates a `[BooleanArray]` from its internal representation. + /// This is the inverted from `[BooleanArray::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. 
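A sketch of decomposing an array with `into_inner` and rebuilding it; `try_new` re-checks the invariants, whereas the `from_inner_unchecked` documented here skips exactly those checks (same `arrow2`-style paths assumed):

```rust
use arrow2::array::BooleanArray;

fn main() {
    let array = BooleanArray::from([Some(true), None, Some(false)]);

    // Decompose into (DataType, Bitmap, Option<Bitmap>) ...
    let (data_type, values, validity) = array.into_inner();

    // ... and rebuild. `try_new` re-validates the parts; the unsafe
    // `from_inner_unchecked` below skips those checks.
    let rebuilt = BooleanArray::try_new(data_type, values, validity).unwrap();
    assert_eq!(rebuilt.len(), 3);
    assert!(rebuilt.get(1).is_none());
}
```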
+ pub unsafe fn from_inner_unchecked( + data_type: DataType, + values: Bitmap, + validity: Option, + ) -> Self { + Self { + data_type, + values, + validity, + } + } +} + +impl Array for BooleanArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/boolean/mutable.rs b/crates/nano-arrow/src/array/boolean/mutable.rs new file mode 100644 index 000000000000..9961cadcb2fd --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/mutable.rs @@ -0,0 +1,564 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::BooleanArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>`, but with `1/16` of its size. +/// Converting a [`MutableBooleanArray`] into a [`BooleanArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableBooleanArray { + data_type: DataType, + values: MutableBitmap, + validity: Option, +} + +impl From for BooleanArray { + fn from(other: MutableBooleanArray) -> Self { + BooleanArray::new( + other.data_type, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl]>> From
for MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] out of a slice of Optional `bool`. + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl Default for MutableBooleanArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBooleanArray { + /// Creates an new empty [`MutableBooleanArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// The canonical method to create a [`MutableBooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + data_type: DataType, + values: MutableBitmap, + validity: Option, + ) -> Result { + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Boolean { + return Err(Error::oos( + "MutableBooleanArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Creates an new [`MutableBooleanArray`] with a capacity of values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + data_type: DataType::Boolean, + values: MutableBitmap::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Pushes a new entry to [`MutableBooleanArray`]. + pub fn push(&mut self, value: Option) { + match value { + Some(value) => { + self.values.push(value); + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(false); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + } + + /// Pop an entry from [`MutableBooleanArray`]. + /// Note If the values is empty, this method will return None. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + // Safety: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked`, which accepts in iterator of optional values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + I: Iterator, + { + let (_, upper) = iterator.size_hint(); + let additional = + upper.expect("extend_trusted_len_values_unchecked requires an upper limit"); + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + + self.values.extend_from_trusted_len_iter_unchecked(iterator) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. 
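// --- Editorial sketch, not part of the diff: incremental building with MutableBooleanArray
// (push/pop) and the O(1) freeze into BooleanArray. `arrow2::` path assumed as above; the
// wrapper fn name is illustrative only.
fn mutable_boolean_push_sketch() {
    use arrow2::array::{BooleanArray, MutableBooleanArray};

    let mut m = MutableBooleanArray::with_capacity(3);
    m.push(Some(true));
    m.push(None); // the first null lazily allocates the validity bitmap
    m.push(Some(false));
    assert_eq!(m.pop(), Some(false));

    let frozen: BooleanArray = m.into(); // O(1) conversion
    assert_eq!(frozen.len(), 2);
}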
+ #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // Safety: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values); + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + + if validity.unset_bits() > 0 { + self.validity = Some(validity); + } + } + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BooleanArray = self.into(); + Arc::new(a) + } +} + +/// Getters +impl MutableBooleanArray { + /// Returns its values. + pub fn values(&self) -> &MutableBitmap { + &self.values + } +} + +/// Setters +impl MutableBooleanArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff index is larger than `self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + self.values.set(index, value.unwrap_or_default()); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + self.validity = Some(MutableBitmap::from_trusted_len_iter( + std::iter::repeat(true).take(self.len()), + )); + } + if let Some(x) = self.validity.as_mut() { + x.set(index, value.is_some()) + } + } +} + +/// From implementations +impl MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + Self::try_new( + DataType::Boolean, + MutableBitmap::from_trusted_len_iter(iterator), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + let mut mutable = MutableBitmap::new(); + mutable.extend_from_trusted_len_iter_unchecked(iterator); + MutableBooleanArray::try_new(DataType::Boolean, mutable, None).unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. 
+ /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self::try_new(DataType::Boolean, values, validity).unwrap() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // Safety: `I` is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + let (validity, values) = try_trusted_len_unzip(iterator)?; + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + Ok(Self::try_new(DataType::Boolean, values, validity).unwrap()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + // Safety: `I` is `TrustedLen` + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Shrinks the capacity of the [`MutableBooleanArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +/// Creates a Bitmap and an optional [`MutableBitmap`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, MutableBitmap) +where + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut values = MutableBitmap::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut values); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, values) +} + +/// Extends validity [`MutableBitmap`] and values [`MutableBitmap`] from an iterator of `Option`. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
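// --- Editorial sketch, not part of the diff: building from an iterator of `Option<bool>`
// (the same pattern the `From` impl above uses) and updating a slot in place with `set`.
// `arrow2::` path assumed; `MutableArray` is imported for `len`/`validity`; the wrapper fn
// name is illustrative only.
fn mutable_boolean_set_sketch() {
    use arrow2::array::{MutableArray, MutableBooleanArray};

    let opts = [Some(true), Some(true), Some(false)];
    let mut m = MutableBooleanArray::from_trusted_len_iter(opts.iter().map(|x| x.as_ref()));
    assert_eq!(m.len(), 3);

    // Setting a slot to None initializes the validity bitmap on first use (O(N) once).
    m.set(1, None);
    assert_eq!(m.validity().map(|v| v.unset_bits()), Some(1));
}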
+#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + values: &mut MutableBitmap, +) where + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_trusted_len_unzip requires an upper limit"); + + // Length of the array before new values are pushed, + // variable created for assertion post operation + let pre_length = values.len(); + + validity.reserve(additional); + values.reserve(additional); + + for item in iterator { + let item = if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + bool::default() + }; + values.push_unchecked(item); + } + + debug_assert_eq!( + values.len(), + pre_length + additional, + "Trusted iterator length was not accurately reported" + ); +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(MutableBitmap, MutableBitmap), E> +where + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut values = MutableBitmap::with_capacity(len); + + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + false + }; + values.push(item); + } + assert_eq!( + values.len(), + len, + "Trusted iterator length was not accurately reported" + ); + values.set_len(len); + null.set_len(len); + + Ok((null, values)) +} + +impl>> FromIterator for MutableBooleanArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: MutableBitmap = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + false + } + }) + .collect(); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + MutableBooleanArray::try_new(DataType::Boolean, values, validity).unwrap() + } +} + +impl MutableArray for MutableBooleanArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BooleanArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BooleanArray = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl Extend> for MutableBooleanArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutableBooleanArray { + /// This is infalible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> Result<(), Error> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutableBooleanArray { + /// This is infalible and is implemented 
for consistency with all other types + fn try_push(&mut self, item: Option) -> Result<(), Error> { + self.push(item); + Ok(()) + } +} + +impl PartialEq for MutableBooleanArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBooleanArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + self.values + .extend_from_slice_unchecked(slice, 0, other.values.len()); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/data.rs b/crates/nano-arrow/src/array/dictionary/data.rs new file mode 100644 index 000000000000..ecc763c350b3 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/data.rs @@ -0,0 +1,49 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{ + from_data, to_data, Arrow2Arrow, DictionaryArray, DictionaryKey, PrimitiveArray, +}; +use crate::datatypes::{DataType, PhysicalType}; + +impl Arrow2Arrow for DictionaryArray { + fn to_data(&self) -> ArrayData { + let keys = self.keys.to_data(); + let builder = keys + .into_builder() + .data_type(self.data_type.clone().into()) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Dictionary is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let key = match data.data_type() { + arrow_schema::DataType::Dictionary(k, _) => k.as_ref(), + d => panic!("unsupported dictionary type {d}"), + }; + + let data_type = DataType::from(data.data_type().clone()); + assert_eq!( + data_type.to_physical_type(), + PhysicalType::Dictionary(K::KEY_TYPE) + ); + + let key_builder = ArrayDataBuilder::new(key.clone()) + .buffers(vec![data.buffers()[0].clone()]) + .offset(data.offset()) + .len(data.len()) + .nulls(data.nulls().cloned()); + + // Safety: Dictionary is valid + let key_data = unsafe { key_builder.build_unchecked() }; + let keys = PrimitiveArray::from_data(&key_data); + let values = from_data(&data.child_data()[0]); + + Self { + data_type, + keys, + values, + } + } +} diff --git a/crates/nano-arrow/src/array/dictionary/ffi.rs b/crates/nano-arrow/src/array/dictionary/ffi.rs new file mode 100644 index 000000000000..946c850c48b1 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/ffi.rs @@ -0,0 +1,41 @@ +use super::{DictionaryArray, DictionaryKey}; +use crate::array::{FromFfi, PrimitiveArray, ToFfi}; +use crate::error::Error; +use crate::ffi; + +unsafe impl ToFfi for DictionaryArray { + fn buffers(&self) -> Vec> { + self.keys.buffers() + } + + fn offset(&self) -> Option { + self.keys.offset() + } + + fn to_ffi_aligned(&self) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.to_ffi_aligned(), + values: self.values.clone(), + } + } +} + +impl FromFfi for DictionaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + // keys: similar to PrimitiveArray, but the datatype is the inner one + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + let data_type = array.data_type().clone(); + + let keys = PrimitiveArray::::try_new(K::PRIMITIVE.into(), values, validity)?; + let values = array + .dictionary()? 
+ .ok_or_else(|| Error::oos("Dictionary Array must contain a dictionary in ffi"))?; + let values = ffi::try_from(values)?; + + // the assumption of this trait + DictionaryArray::::try_new_unchecked(data_type, keys, values) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/fmt.rs b/crates/nano-arrow/src/array/dictionary/fmt.rs new file mode 100644 index 000000000000..b3ce55515902 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/fmt.rs @@ -0,0 +1,31 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::Array; + +pub fn write_value( + array: &DictionaryArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let keys = array.keys(); + let values = array.values(); + + if keys.is_valid(index) { + let key = array.key_value(index); + get_display(values.as_ref(), null)(f, key) + } else { + write!(f, "{null}") + } +} + +impl Debug for DictionaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "DictionaryArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/iterator.rs b/crates/nano-arrow/src/array/dictionary/iterator.rs new file mode 100644 index 000000000000..68e95ca86fed --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/iterator.rs @@ -0,0 +1,67 @@ +use super::{DictionaryArray, DictionaryKey}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIter<'a, K: DictionaryKey> { + array: &'a DictionaryArray, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> { + #[inline] + pub fn new(array: &'a DictionaryArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(self.array.value(old)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a, K: DictionaryKey> TrustedLen for DictionaryValuesIter<'a, K> {} + +impl<'a, K: DictionaryKey> DoubleEndedIterator for DictionaryValuesIter<'a, K> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(self.array.value(self.end)) + } + } +} + +type ValuesIter<'a, K> = DictionaryValuesIter<'a, K>; +type ZipIter<'a, K> = ZipValidity, ValuesIter<'a, K>, BitmapIter<'a>>; + +impl<'a, K: DictionaryKey> IntoIterator for &'a DictionaryArray { + type Item = Option>; + type IntoIter = ZipIter<'a, K>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/dictionary/mod.rs b/crates/nano-arrow/src/array/dictionary/mod.rs new file mode 100644 index 000000000000..48d2334509e0 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/mod.rs @@ -0,0 +1,413 @@ +use std::hash::Hash; +use std::hint::unreachable_unchecked; + +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, IntegerType}; +use crate::error::Error; +use crate::scalar::{new_scalar, 
Scalar}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +use crate::array::specification::check_indexes_unchecked; +mod typed_iterator; +mod value_map; + +pub use iterator::*; +pub use mutable::*; + +use super::primitive::PrimitiveArray; +use super::specification::check_indexes; +use super::{new_empty_array, new_null_array, Array}; +use crate::array::dictionary::typed_iterator::{DictValue, DictionaryValuesIterTyped}; + +/// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. +/// # Safety +/// +/// Any implementation of this trait must ensure that `always_fits_usize` only +/// returns `true` if all values succeeds on `value::try_into::().unwrap()`. +pub unsafe trait DictionaryKey: NativeType + TryInto + TryFrom + Hash { + /// The corresponding [`IntegerType`] of this key + const KEY_TYPE: IntegerType; + + /// Represents this key as a `usize`. + /// # Safety + /// The caller _must_ have checked that the value can be casted to `usize`. + #[inline] + unsafe fn as_usize(self) -> usize { + match self.try_into() { + Ok(v) => v, + Err(_) => unreachable_unchecked(), + } + } + + /// If the key type always can be converted to `usize`. + fn always_fits_usize() -> bool { + false + } +} + +unsafe impl DictionaryKey for i8 { + const KEY_TYPE: IntegerType = IntegerType::Int8; +} +unsafe impl DictionaryKey for i16 { + const KEY_TYPE: IntegerType = IntegerType::Int16; +} +unsafe impl DictionaryKey for i32 { + const KEY_TYPE: IntegerType = IntegerType::Int32; +} +unsafe impl DictionaryKey for i64 { + const KEY_TYPE: IntegerType = IntegerType::Int64; +} +unsafe impl DictionaryKey for u8 { + const KEY_TYPE: IntegerType = IntegerType::UInt8; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u16 { + const KEY_TYPE: IntegerType = IntegerType::UInt16; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u32 { + const KEY_TYPE: IntegerType = IntegerType::UInt32; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u64 { + const KEY_TYPE: IntegerType = IntegerType::UInt64; + + #[cfg(target_pointer_width = "64")] + fn always_fits_usize() -> bool { + true + } +} + +/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of +/// values is low compared to the length of the [`Array`]. +/// +/// # Safety +/// This struct guarantees that each item of [`DictionaryArray::keys`] is castable to `usize` and +/// its value is smaller than [`DictionaryArray::values`]`.len()`. 
In other words, you can safely +/// use `unchecked` calls to retrieve the values +#[derive(Clone)] +pub struct DictionaryArray { + data_type: DataType, + keys: PrimitiveArray, + values: Box, +} + +fn check_data_type( + key_type: IntegerType, + data_type: &DataType, + values_data_type: &DataType, +) -> Result<(), Error> { + if let DataType::Dictionary(key, value, _) = data_type.to_logical_type() { + if *key != key_type { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys", + )); + } + if value.as_ref().to_logical_type() != values_data_type.to_logical_type() { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values", + )); + } + } else { + return Err(Error::oos( + "DictionaryArray must be initialized with logical DataType::Dictionary", + )); + } + Ok(()) +} + +impl DictionaryArray { + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_new( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + if keys.null_count() != keys.len() { + if K::always_fits_usize() { + // safety: we just checked that conversion to `usize` always + // succeeds + unsafe { check_indexes_unchecked(keys.values(), values.len()) }?; + } else { + check_indexes(keys.values(), values.len())?; + } + } + + Ok(Self { + data_type, + keys, + values, + }) + } + + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_from_keys(keys: PrimitiveArray, values: Box) -> Result { + let data_type = Self::default_data_type(values.data_type().clone()); + Self::try_new(data_type, keys, values) + } + + /// Returns a new [`DictionaryArray`]. + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// # Safety + /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` + pub unsafe fn try_new_unchecked( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + Ok(Self { + data_type, + keys, + values, + }) + } + + /// Returns a new empty [`DictionaryArray`]. 
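// --- Editorial sketch, not part of the diff: constructing a DictionaryArray from keys plus a
// values array with `try_from_keys`, which validates that every key is in bounds. `arrow2::`
// path assumed as above; the wrapper fn name is illustrative only.
fn dictionary_build_sketch() {
    use arrow2::array::{DictionaryArray, PrimitiveArray, Utf8Array};

    let values = Utf8Array::<i32>::from_slice(["low", "mid", "high"]);
    let keys = PrimitiveArray::<u32>::from_slice([0u32, 0, 2, 1]);

    let dict = DictionaryArray::try_from_keys(keys, Box::new(values)).unwrap();
    assert_eq!(dict.len(), 4);
    assert_eq!(dict.key_value(2), 2); // index into `values`, i.e. "high"
}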
+ pub fn new_empty(data_type: DataType) -> Self { + let values = Self::try_get_child(&data_type).unwrap(); + let values = new_empty_array(values.clone()); + Self::try_new( + data_type, + PrimitiveArray::::new_empty(K::PRIMITIVE.into()), + values, + ) + .unwrap() + } + + /// Returns an [`DictionaryArray`] whose all elements are null + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + let values = Self::try_get_child(&data_type).unwrap(); + let values = new_null_array(values.clone(), 1); + Self::try_new( + data_type, + PrimitiveArray::::new_null(K::PRIMITIVE.into(), length), + values, + ) + .unwrap() + } + + /// Returns an iterator of [`Option>`]. + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn iter(&self) -> ZipValidity, DictionaryValuesIter, BitmapIter> { + ZipValidity::new_with_validity(DictionaryValuesIter::new(self), self.keys.validity()) + } + + /// Returns an iterator of [`Box`] + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn values_iter(&self) -> DictionaryValuesIter { + DictionaryValuesIter::new(self) + } + + /// Returns an iterator over the the values [`V::IterValue`]. + /// + /// # Panics + /// + /// Panics if the keys of this [`DictionaryArray`] have any null types. + /// If they do [`DictionaryArray::iter_typed`] should be called + pub fn values_iter_typed( + &self, + ) -> Result, Error> { + let keys = &self.keys; + assert_eq!(keys.null_count(), 0); + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + Ok(unsafe { DictionaryValuesIterTyped::new(keys, values) }) + } + + /// Returns an iterator over the the optional values of [`Option`]. + /// + /// # Panics + /// + /// This function panics if the `values` array + pub fn iter_typed( + &self, + ) -> Result, DictionaryValuesIterTyped, BitmapIter>, Error> + { + let keys = &self.keys; + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + let values_iter = unsafe { DictionaryValuesIterTyped::new(keys, values) }; + Ok(ZipValidity::new_with_validity(values_iter, self.validity())) + } + + /// Returns the [`DataType`] of this [`DictionaryArray`] + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns whether the values of this [`DictionaryArray`] are ordered + #[inline] + pub fn is_ordered(&self) -> bool { + match self.data_type.to_logical_type() { + DataType::Dictionary(_, _, is_ordered) => *is_ordered, + _ => unreachable!(), + } + } + + pub(crate) fn default_data_type(values_datatype: DataType) -> DataType { + DataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false) + } + + /// Slices this [`DictionaryArray`]. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + self.keys.slice(offset, length); + } + + /// Slices this [`DictionaryArray`]. + /// # Safety + /// Safe iff `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.keys.slice_unchecked(offset, length); + } + + impl_sliced!(); + + /// Returns this [`DictionaryArray`] with a new validity. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. 
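// --- Editorial sketch, not part of the diff: decoding a utf8 dictionary with the typed,
// allocation-free iterator described above (`iter_typed`), instead of boxing a Scalar per
// item. `arrow2::` path assumed as above; the wrapper fn name is illustrative only.
fn dictionary_iter_typed_sketch() {
    use arrow2::array::{DictionaryArray, PrimitiveArray, Utf8Array};

    let dict = DictionaryArray::try_from_keys(
        PrimitiveArray::<u32>::from_slice([0u32, 2, 1]),
        Box::new(Utf8Array::<i32>::from_slice(["x", "y", "z"])),
    )
    .unwrap();

    let decoded: Vec<Option<&str>> = dict.iter_typed::<Utf8Array<i32>>().unwrap().collect();
    assert_eq!(decoded, vec![Some("x"), Some("z"), Some("y")]);
}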
+ #[must_use] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of the keys of this [`DictionaryArray`]. + /// # Panics + /// This function panics iff `validity.len() != self.len()`. + pub fn set_validity(&mut self, validity: Option) { + self.keys.set_validity(validity); + } + + impl_into_array!(); + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.keys.len() + } + + /// The optional validity. Equivalent to `self.keys().validity()`. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + /// Returns the keys of the [`DictionaryArray`]. These keys can be used to fetch values + /// from `values`. + #[inline] + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_values_iter(&self) -> impl TrustedLen + Clone + '_ { + // safety - invariant of the struct + self.keys.values_iter().map(|x| unsafe { x.as_usize() }) + } + + /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_iter(&self) -> impl TrustedLen> + Clone + '_ { + // safety - invariant of the struct + self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() })) + } + + /// Returns the keys' value of the [`DictionaryArray`] as `usize` + /// # Panics + /// This function panics iff `index >= self.len()` + #[inline] + pub fn key_value(&self, index: usize) -> usize { + // safety - invariant of the struct + unsafe { self.keys.values()[index].as_usize() } + } + + /// Returns the values of the [`DictionaryArray`]. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the value of the [`DictionaryArray`] at position `i`. + /// # Implementation + /// This function will allocate a new [`Scalar`] and is usually not performant. + /// Consider calling `keys` and `values`, downcasting `values`, and iterating over that. 
+ /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn value(&self, index: usize) -> Box { + // safety - invariant of this struct + let index = unsafe { self.keys.value(index).as_usize() }; + new_scalar(self.values.as_ref(), index) + } + + pub(crate) fn try_get_child(data_type: &DataType) -> Result<&DataType, Error> { + Ok(match data_type.to_logical_type() { + DataType::Dictionary(_, values, _) => values.as_ref(), + _ => { + return Err(Error::oos( + "Dictionaries must be initialized with DataType::Dictionary", + )) + }, + }) + } +} + +impl Array for DictionaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/mutable.rs b/crates/nano-arrow/src/array/dictionary/mutable.rs new file mode 100644 index 000000000000..dedd6ead0eaa --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/mutable.rs @@ -0,0 +1,241 @@ +use std::hash::Hash; +use std::sync::Arc; + +use super::value_map::ValueMap; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::primitive::MutablePrimitiveArray; +use crate::array::{Array, MutableArray, TryExtend, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +/// A mutable, strong-typed version of [`DictionaryArray`]. +/// +/// # Example +/// Building a UTF8 dictionary with `i32` keys. +/// ``` +/// # use arrow2::array::{MutableDictionaryArray, MutableUtf8Array, TryPush}; +/// # fn main() -> Result<(), Box> { +/// let mut array: MutableDictionaryArray> = MutableDictionaryArray::new(); +/// array.try_push(Some("A"))?; +/// array.try_push(Some("B"))?; +/// array.push_null(); +/// array.try_push(Some("C"))?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct MutableDictionaryArray { + data_type: DataType, + map: ValueMap, + // invariant: `max(keys) < map.values().len()` + keys: MutablePrimitiveArray, +} + +impl From> for DictionaryArray { + fn from(other: MutableDictionaryArray) -> Self { + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + other.data_type, + other.keys.into(), + other.map.into_values().as_box(), + ) + .unwrap() + } + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`]. + pub fn new() -> Self { + Self::try_empty(M::default()).unwrap() + } +} + +impl Default for MutableDictionaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`] from a given empty values array. + /// # Errors + /// Errors if the array is non-empty. + pub fn try_empty(values: M) -> Result { + Ok(Self::from_value_map(ValueMap::::try_empty(values)?)) + } + + /// Creates an empty [`MutableDictionaryArray`] preloaded with a given dictionary of values. + /// Indices associated with those values are automatically assigned based on the order of + /// the values. + /// # Errors + /// Errors if there's more values than the maximum value of `K` or if values are not unique. 
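// --- Editorial sketch, not part of the diff: the value map deduplicates values as entries are
// pushed, so repeated values re-use an existing key. `arrow2::` path assumed as above; the
// wrapper fn name is illustrative only.
fn mutable_dictionary_dedup_sketch() {
    use arrow2::array::{Array, DictionaryArray, MutableDictionaryArray, MutableUtf8Array, TryPush};

    let mut m: MutableDictionaryArray<u32, MutableUtf8Array<i32>> = MutableDictionaryArray::new();
    m.try_push(Some("a")).unwrap();
    m.try_push(Some("b")).unwrap();
    m.try_push(Some("a")).unwrap(); // re-uses the key assigned to "a"
    m.push_null();

    let d: DictionaryArray<u32> = m.into();
    assert_eq!(d.len(), 4);
    assert_eq!(d.values().len(), 2); // only "a" and "b" are stored
}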
+ pub fn from_values(values: M) -> Result + where + M: Indexable, + M::Type: Eq + Hash, + { + Ok(Self::from_value_map(ValueMap::::from_values(values)?)) + } + + fn from_value_map(value_map: ValueMap) -> Self { + let keys = MutablePrimitiveArray::::new(); + let data_type = + DataType::Dictionary(K::KEY_TYPE, Box::new(value_map.data_type().clone()), false); + Self { + data_type, + map: value_map, + keys, + } + } + + /// Creates an empty [`MutableDictionaryArray`] retaining the same dictionary as the current + /// mutable dictionary array, but with no data. This may come useful when serializing the + /// array into multiple chunks, where there's a requirement that the dictionary is the same. + /// No copying is performed, the value map is moved over to the new array. + pub fn into_empty(self) -> Self { + Self::from_value_map(self.map) + } + + /// Same as `into_empty` but clones the inner value map instead of taking full ownership. + pub fn to_empty(&self) -> Self + where + M: Clone, + { + Self::from_value_map(self.map.clone()) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } + + /// returns a reference to the inner values. + pub fn values(&self) -> &M { + self.map.values() + } + + /// converts itself into [`Arc`] + pub fn into_arc(self) -> Arc { + let a: DictionaryArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: DictionaryArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.keys.reserve(additional); + } + + /// Shrinks the capacity of the [`MutableDictionaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.map.shrink_to_fit(); + self.keys.shrink_to_fit(); + } + + /// Returns the dictionary keys + pub fn keys(&self) -> &MutablePrimitiveArray { + &self.keys + } + + fn take_into(&mut self) -> DictionaryArray { + DictionaryArray::::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + self.map.take_into(), + ) + .unwrap() + } +} + +impl MutableArray for MutableDictionaryArray { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new(self.take_into()) + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.take_into()) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.keys.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryExtend>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_extend>>(&mut self, iter: II) -> Result<()> { + for value in iter { + if let Some(value) = value { + let key = self + .map + .try_push_valid(value, |arr, v| arr.try_extend(std::iter::once(Some(v))))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + } + Ok(()) + } +} + +impl TryPush> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryPush>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(value) = item { + let key = self + .map + .try_push_valid(value, |arr, v| 
arr.try_push(Some(v)))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/typed_iterator.rs b/crates/nano-arrow/src/array/dictionary/typed_iterator.rs new file mode 100644 index 000000000000..5c528beb251b --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/typed_iterator.rs @@ -0,0 +1,110 @@ +use super::DictionaryKey; +use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::error::{Error, Result}; +use crate::trusted_len::TrustedLen; +use crate::types::Offset; + +pub trait DictValue { + type IterValue<'this> + where + Self: 'this; + + /// # Safety + /// Will not do any bound checks but must check validity. + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_>; + + /// Take a [`dyn Array`] an try to downcast it to the type of `DictValue`. + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized; +} + +impl DictValue for Utf8Array { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::InvalidArgumentError("could not convert array to dictionary value".into()) + }) + .map(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + arr + }) + } +} + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> { + keys: &'a PrimitiveArray, + values: &'a V, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> { + pub(super) unsafe fn new(keys: &'a PrimitiveArray, values: &'a V) -> Self { + Self { + keys, + values, + index: 0, + end: keys.len(), + } + } +} + +impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped<'a, K, V> { + type Item = V::IterValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + unsafe { + let key = self.keys.value_unchecked(old); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {} + +impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator + for DictionaryValuesIterTyped<'a, K, V> +{ + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + unsafe { + let key = self.keys.value_unchecked(self.end); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + } +} diff --git a/crates/nano-arrow/src/array/dictionary/value_map.rs b/crates/nano-arrow/src/array/dictionary/value_map.rs new file mode 100644 index 000000000000..f9d22edfffe5 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/value_map.rs @@ -0,0 +1,169 @@ +use std::borrow::Borrow; +use std::fmt::{self, Debug}; +use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; + +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; + +use super::DictionaryKey; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::{Array, MutableArray}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// 
Hasher for pre-hashed values; similar to `hash_hasher` but with native endianness. +/// +/// We know that we'll only use it for `u64` values, so we can avoid endian conversion. +/// +/// Invariant: hash of a u64 value is always equal to itself. +#[derive(Copy, Clone, Default)] +pub struct PassthroughHasher(u64); + +impl Hasher for PassthroughHasher { + #[inline] + fn write_u64(&mut self, value: u64) { + self.0 = value; + } + + fn write(&mut self, _: &[u8]) { + unreachable!(); + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +#[derive(Clone)] +pub struct Hashed { + hash: u64, + key: K, +} + +#[inline] +fn ahash_hash(value: &T) -> u64 { + BuildHasherDefault::::default().hash_one(value) +} + +impl Hash for Hashed { + #[inline] + fn hash(&self, state: &mut H) { + self.hash.hash(state) + } +} + +#[derive(Clone)] +pub struct ValueMap { + pub values: M, + pub map: HashMap, (), BuildHasherDefault>, // NB: *only* use insert_hashed_nocheck() and no other hashmap API +} + +impl ValueMap { + pub fn try_empty(values: M) -> Result { + if !values.is_empty() { + return Err(Error::InvalidArgumentError( + "initializing value map with non-empty values array".into(), + )); + } + Ok(Self { + values, + map: HashMap::default(), + }) + } + + pub fn from_values(values: M) -> Result + where + M: Indexable, + M::Type: Eq + Hash, + { + let mut map = HashMap::, _, _>::with_capacity_and_hasher( + values.len(), + BuildHasherDefault::::default(), + ); + for index in 0..values.len() { + let key = K::try_from(index).map_err(|_| Error::Overflow)?; + // safety: we only iterate within bounds + let value = unsafe { values.value_unchecked_at(index) }; + let hash = ahash_hash(value.borrow()); + match map.raw_entry_mut().from_hash(hash, |item| { + // safety: invariant of the struct, it's always in bounds since we maintain it + let stored_value = unsafe { values.value_unchecked_at(item.key.as_usize()) }; + stored_value.borrow() == value.borrow() + }) { + RawEntryMut::Occupied(_) => { + return Err(Error::InvalidArgumentError( + "duplicate value in dictionary values array".into(), + )) + }, + RawEntryMut::Vacant(entry) => { + // NB: don't use .insert() here! + entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); + }, + } + } + Ok(Self { values, map }) + } + + pub fn data_type(&self) -> &DataType { + self.values.data_type() + } + + pub fn into_values(self) -> M { + self.values + } + + pub fn take_into(&mut self) -> Box { + let arr = self.values.as_box(); + self.map.clear(); + arr + } + + #[inline] + pub fn values(&self) -> &M { + &self.values + } + + /// Try to insert a value and return its index (it may or may not get inserted). + pub fn try_push_valid( + &mut self, + value: V, + mut push: impl FnMut(&mut M, V) -> Result<()>, + ) -> Result + where + M: Indexable, + V: AsIndexed, + M::Type: Eq + Hash, + { + let hash = ahash_hash(value.as_indexed()); + Ok( + match self.map.raw_entry_mut().from_hash(hash, |item| { + // safety: we've already checked (the inverse) when we pushed it, so it should be ok? 
+ let index = unsafe { item.key.as_usize() }; + // safety: invariant of the struct, it's always in bounds since we maintain it + let stored_value = unsafe { self.values.value_unchecked_at(index) }; + stored_value.borrow() == value.as_indexed() + }) { + RawEntryMut::Occupied(entry) => entry.key().key, + RawEntryMut::Vacant(entry) => { + let index = self.values.len(); + let key = K::try_from(index).map_err(|_| Error::Overflow)?; + entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); // NB: don't use .insert() here! + push(&mut self.values, value)?; + debug_assert_eq!(self.values.len(), index + 1); + key + }, + }, + ) + } + + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + } +} + +impl Debug for ValueMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.values.fmt(f) + } +} diff --git a/crates/nano-arrow/src/array/equal/binary.rs b/crates/nano-arrow/src/array/equal/binary.rs new file mode 100644 index 000000000000..bed8588efb59 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/binary.rs @@ -0,0 +1,6 @@ +use crate::array::BinaryArray; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &BinaryArray, rhs: &BinaryArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/boolean.rs b/crates/nano-arrow/src/array/equal/boolean.rs new file mode 100644 index 000000000000..d9c6af9b0276 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/boolean.rs @@ -0,0 +1,5 @@ +use crate::array::BooleanArray; + +pub(super) fn equal(lhs: &BooleanArray, rhs: &BooleanArray) -> bool { + lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/dictionary.rs b/crates/nano-arrow/src/array/equal/dictionary.rs new file mode 100644 index 000000000000..d65634095fb3 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/dictionary.rs @@ -0,0 +1,14 @@ +use crate::array::{DictionaryArray, DictionaryKey}; + +pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { + if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) { + return false; + }; + + // if x is not valid and y is but its child is not, the slots are equal. 
+ lhs.iter().zip(rhs.iter()).all(|(x, y)| match (&x, &y) { + (None, Some(y)) => !y.is_valid(), + (Some(x), None) => !x.is_valid(), + _ => x == y, + }) +} diff --git a/crates/nano-arrow/src/array/equal/fixed_size_binary.rs b/crates/nano-arrow/src/array/equal/fixed_size_binary.rs new file mode 100644 index 000000000000..883d5739778b --- /dev/null +++ b/crates/nano-arrow/src/array/equal/fixed_size_binary.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeBinaryArray}; + +pub(super) fn equal(lhs: &FixedSizeBinaryArray, rhs: &FixedSizeBinaryArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/fixed_size_list.rs b/crates/nano-arrow/src/array/equal/fixed_size_list.rs new file mode 100644 index 000000000000..aaf77910013f --- /dev/null +++ b/crates/nano-arrow/src/array/equal/fixed_size_list.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeListArray}; + +pub(super) fn equal(lhs: &FixedSizeListArray, rhs: &FixedSizeListArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/list.rs b/crates/nano-arrow/src/array/equal/list.rs new file mode 100644 index 000000000000..26faa1598faf --- /dev/null +++ b/crates/nano-arrow/src/array/equal/list.rs @@ -0,0 +1,6 @@ +use crate::array::{Array, ListArray}; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &ListArray, rhs: &ListArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/map.rs b/crates/nano-arrow/src/array/equal/map.rs new file mode 100644 index 000000000000..e150fb4a4b41 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/map.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, MapArray}; + +pub(super) fn equal(lhs: &MapArray, rhs: &MapArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/mod.rs b/crates/nano-arrow/src/array/equal/mod.rs new file mode 100644 index 000000000000..91fd0c2f464f --- /dev/null +++ b/crates/nano-arrow/src/array/equal/mod.rs @@ -0,0 +1,287 @@ +use super::*; +use crate::offset::Offset; +use crate::types::NativeType; + +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod union; +mod utf8; + +impl PartialEq for dyn Array + '_ { + fn eq(&self, that: &dyn Array) -> bool { + equal(self, that) + } +} + +impl PartialEq for std::sync::Arc { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for NullArray { + fn eq(&self, other: &Self) -> bool { + null::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for NullArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq<&dyn Array> for PrimitiveArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &PrimitiveArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for PrimitiveArray { + fn eq(&self, other: &Self) -> bool { + primitive::equal::(self, other) + } +} + +impl PartialEq for BooleanArray { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +impl PartialEq<&dyn 
Array> for BooleanArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for Utf8Array { + fn eq(&self, other: &Self) -> bool { + utf8::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for Utf8Array { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &Utf8Array) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for BinaryArray { + fn eq(&self, other: &Self) -> bool { + binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for BinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &BinaryArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeBinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for ListArray { + fn eq(&self, other: &Self) -> bool { + list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for ListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + struct_::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for StructArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + dictionary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for DictionaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for UnionArray { + fn eq(&self, other: &Self) -> bool { + union::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for UnionArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + map::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for MapArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +/// Logically compares two [`Array`]s. 
+/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * each of their items are equal +pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { + if lhs.data_type() != rhs.data_type() { + return false; + } + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Null => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + null::equal(lhs, rhs) + }, + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::equal(lhs, rhs) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::equal::<$T>(lhs, rhs) + }), + Utf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + List => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + LargeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + Struct => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + struct_::equal(lhs, rhs) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + dictionary::equal::<$T>(lhs, rhs) + }) + }, + FixedSizeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_binary::equal(lhs, rhs) + }, + FixedSizeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_list::equal(lhs, rhs) + }, + Union => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + union::equal(lhs, rhs) + }, + Map => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + map::equal(lhs, rhs) + }, + } +} diff --git a/crates/nano-arrow/src/array/equal/null.rs b/crates/nano-arrow/src/array/equal/null.rs new file mode 100644 index 000000000000..11ad6cc133bb --- /dev/null +++ b/crates/nano-arrow/src/array/equal/null.rs @@ -0,0 +1,6 @@ +use crate::array::{Array, NullArray}; + +#[inline] +pub(super) fn equal(lhs: &NullArray, rhs: &NullArray) -> bool { + lhs.len() == rhs.len() +} diff --git a/crates/nano-arrow/src/array/equal/primitive.rs b/crates/nano-arrow/src/array/equal/primitive.rs new file mode 100644 index 000000000000..dc90bb15da5e --- /dev/null +++ b/crates/nano-arrow/src/array/equal/primitive.rs @@ -0,0 +1,6 @@ +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +pub(super) fn equal(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff 
--git a/crates/nano-arrow/src/array/equal/struct_.rs b/crates/nano-arrow/src/array/equal/struct_.rs new file mode 100644 index 000000000000..a1741e36368c --- /dev/null +++ b/crates/nano-arrow/src/array/equal/struct_.rs @@ -0,0 +1,54 @@ +use crate::array::{Array, StructArray}; + +pub(super) fn equal(lhs: &StructArray, rhs: &StructArray) -> bool { + lhs.data_type() == rhs.data_type() + && lhs.len() == rhs.len() + && match (lhs.validity(), rhs.validity()) { + (None, None) => lhs.values().iter().eq(rhs.values().iter()), + (Some(l_validity), Some(r_validity)) => lhs + .values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().zip(r_validity.iter()).enumerate().all( + |(i, (lhs_is_valid, rhs_is_valid))| { + if lhs_is_valid && rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + lhs_is_valid == rhs_is_valid + } + }, + ) + }), + (Some(l_validity), None) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().enumerate().all(|(i, lhs_is_valid)| { + if lhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // rhs is always valid => different + false + } + }) + }) + }, + (None, Some(r_validity)) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + r_validity.iter().enumerate().all(|(i, rhs_is_valid)| { + if rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // lhs is always valid => different + false + } + }) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/equal/union.rs b/crates/nano-arrow/src/array/equal/union.rs new file mode 100644 index 000000000000..51b9d960feac --- /dev/null +++ b/crates/nano-arrow/src/array/equal/union.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, UnionArray}; + +pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/utf8.rs b/crates/nano-arrow/src/array/equal/utf8.rs new file mode 100644 index 000000000000..1327221ca331 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/utf8.rs @@ -0,0 +1,6 @@ +use crate::array::Utf8Array; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &Utf8Array, rhs: &Utf8Array) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/ffi.rs b/crates/nano-arrow/src/array/ffi.rs new file mode 100644 index 000000000000..0e9629d4fdf0 --- /dev/null +++ b/crates/nano-arrow/src/array/ffi.rs @@ -0,0 +1,86 @@ +use crate::array::*; +use crate::datatypes::PhysicalType; +use crate::error::Result; +use crate::ffi; + +/// Trait describing how a struct presents itself to the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +/// # Safety +/// Implementing this trait incorrect will lead to UB +pub(crate) unsafe trait ToFfi { + /// The pointers to the buffers. + fn buffers(&self) -> Vec>; + + /// The children + fn children(&self) -> Vec> { + vec![] + } + + /// The offset + fn offset(&self) -> Option; + + /// return a partial clone of self with an offset. + fn to_ffi_aligned(&self) -> Self; +} + +/// Trait describing how a struct imports into itself from the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub(crate) trait FromFfi: Sized { + /// Convert itself from FFI. 
+ /// # Safety + /// This function is intrinsically `unsafe` as it requires the FFI to be made according + /// to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) + unsafe fn try_from_ffi(array: T) -> Result; +} + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + None, + ) + }}; +} + +type BuffersChildren = ( + usize, + Vec>, + Vec>, + Option>, +); + +pub fn offset_buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + Some(array.values().clone()), + ) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/data.rs b/crates/nano-arrow/src/array/fixed_size_binary/data.rs new file mode 100644 index 000000000000..6eb025d91623 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/data.rs @@ -0,0 +1,37 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, FixedSizeBinaryArray}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; + +impl Arrow2Arrow for FixedSizeBinaryArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.values.clone().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + let size = match data_type { + DataType::FixedSizeBinary(size) => size, + _ => unreachable!("must be FixedSizeBinary"), + }; + + let mut values: Buffer = data.buffers()[0].clone().into(); + values.slice(data.offset() * size, data.len() * size); + + Self { + size, + data_type, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs b/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs new file mode 100644 index 000000000000..ee6e6a030df0 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs @@ -0,0 +1,56 @@ +use super::FixedSizeBinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeBinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let 
offset = self.values.offset() / self.size; + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset() / self.size; + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + size: self.size, + data_type: self.data_type.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for FixedSizeBinaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs b/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs new file mode 100644 index 000000000000..c5f9e2dd3293 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs @@ -0,0 +1,20 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::FixedSizeBinaryArray; + +pub fn write_value(array: &FixedSizeBinaryArray, index: usize, f: &mut W) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", values[index]); + + write_vec(f, writer, None, values.len(), "None", false) +} + +impl Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "{:?}", self.data_type)?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs b/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs new file mode 100644 index 000000000000..4c885c591943 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs @@ -0,0 +1,49 @@ +use super::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; +use crate::array::MutableArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +impl<'a> IntoIterator for &'a FixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns iterator over the values of [`FixedSizeBinaryArray`] + pub fn values_iter(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size) + } +} + +impl<'a> IntoIterator for &'a MutableFixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableFixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new(self.iter_values(), self.validity().map(|x| x.iter())) + } + + /// Returns iterator over the values of [`MutableFixedSizeBinaryArray`] + pub fn iter_values(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size()) 
+ } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/mod.rs b/crates/nano-arrow/src/array/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..14bd2aa1e512 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/mod.rs @@ -0,0 +1,287 @@ +use super::Array; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::Error; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +pub use mutable::*; + +/// The Arrow's equivalent to an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeBinaryArray { + size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Buffer, + validity: Option, +} + +impl FixedSizeBinaryArray { + /// Creates a new [`FixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + let size = Self::maybe_get_size(&data_type)?; + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Creates a new [`FixedSizeBinaryArray`]. + /// # Panics + /// This function panics iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn new(data_type: DataType, values: Buffer, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns a new empty [`FixedSizeBinaryArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Buffer::new(), None) + } + + /// Returns a new null [`FixedSizeBinaryArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let size = Self::maybe_get_size(&data_type).unwrap(); + Self::new( + data_type, + vec![0u8; length * size].into(), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +// must use +impl FixedSizeBinaryArray { + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
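Sketch of the O(1) slice semantics described here (illustrative values; `nano_arrow` import path assumed):

use nano_arrow::array::FixedSizeBinaryArray;

let mut array = FixedSizeBinaryArray::try_from_iter(
    vec![Some([0u8, 1]), Some([2, 3]), None],
    2,
)
.unwrap();
// Slicing only moves offsets and validity; the width (2) and the buffer are shared.
array.slice(1, 2);
assert_eq!(array.len(), 2);
assert_eq!(array.value(0), [2u8, 3].as_slice());
assert!(array.get(1).is_none());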
+ pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .slice_unchecked(offset * self.size, length * self.size); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeBinaryArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values allocated on this [`FixedSizeBinaryArray`]. + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns value at position `i`. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: invariant of the function. + self.values + .get_unchecked(i * self.size..(i + 1) * self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns a new [`FixedSizeBinaryArray`] with a different logical type. + /// This is `O(1)`. + /// # Panics + /// Panics iff the data_type is not supported for the physical type. + #[inline] + pub fn to(self, data_type: DataType) -> Self { + match ( + data_type.to_logical_type(), + self.data_type().to_logical_type(), + ) { + (DataType::FixedSizeBinary(size_a), DataType::FixedSizeBinary(size_b)) + if size_a == size_b => {}, + _ => panic!("Wrong DataType"), + } + + Self { + size: self.size, + data_type, + values: self.values, + validity: self.validity, + } + } + + /// Returns the size + pub fn size(&self) -> usize { + self.size + } +} + +impl FixedSizeBinaryArray { + pub(crate) fn maybe_get_size(data_type: &DataType) -> Result { + match data_type.to_logical_type() { + DataType::FixedSizeBinary(size) => { + if *size == 0 { + return Err(Error::oos("FixedSizeBinaryArray expects a positive size")); + } + Ok(*size) + }, + _ => Err(Error::oos( + "FixedSizeBinaryArray expects DataType::FixedSizeBinary", + )), + } + } + + pub(crate) fn get_size(data_type: &DataType) -> usize { + Self::maybe_get_size(data_type).unwrap() + } +} + +impl Array for FixedSizeBinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl FixedSizeBinaryArray { + /// Creates a [`FixedSizeBinaryArray`] from an fallible iterator of optional `[u8]`. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Result { + MutableFixedSizeBinaryArray::try_from_iter(iter, size).map(|x| x.into()) + } + + /// Creates a [`FixedSizeBinaryArray`] from an iterator of optional `[u8]`. 
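Sketch of the convenience constructors that follow (illustrative values; import path assumed):

use nano_arrow::array::FixedSizeBinaryArray;

// Dense (no nulls): `from_slice` infers the width from the item type.
let dense = FixedSizeBinaryArray::from_slice([[1u8, 2], [3, 4]]);
assert_eq!(dense.size(), 2);
assert_eq!(dense.value(1), [3u8, 4].as_slice());

// Nullable: `from` takes optional fixed-width items and builds the validity mask.
let sparse = FixedSizeBinaryArray::from([Some([1u8, 2]), None]);
assert!(sparse.get(1).is_none());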
+ pub fn from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Self { + MutableFixedSizeBinaryArray::try_from_iter(iter, size) + .unwrap() + .into() + } + + /// Creates a [`FixedSizeBinaryArray`] from a slice of arrays of bytes + pub fn from_slice>(a: P) -> Self { + let values = a.as_ref().iter().flatten().copied().collect::>(); + Self::new(DataType::FixedSizeBinary(N), values.into(), None) + } + + /// Creates a new [`FixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + MutableFixedSizeBinaryArray::from(slice).into() + } +} + +pub trait FixedSizeBinaryValues { + fn values(&self) -> &[u8]; + fn size(&self) -> usize; +} + +impl FixedSizeBinaryValues for FixedSizeBinaryArray { + #[inline] + fn values(&self) -> &[u8] { + &self.values + } + + #[inline] + fn size(&self) -> usize { + self.size + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs b/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs new file mode 100644 index 000000000000..f5a68facf681 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs @@ -0,0 +1,321 @@ +use std::sync::Arc; + +use super::{FixedSizeBinaryArray, FixedSizeBinaryValues}; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtendFromSelf}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Error; + +/// The Arrow's equivalent to a mutable `Vec>`. +/// Converting a [`MutableFixedSizeBinaryArray`] into a [`FixedSizeBinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableFixedSizeBinaryArray { + data_type: DataType, + size: usize, + values: Vec, + validity: Option, +} + +impl From for FixedSizeBinaryArray { + fn from(other: MutableFixedSizeBinaryArray) -> Self { + FixedSizeBinaryArray::new( + other.data_type, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeBinaryArray { + /// Creates a new [`MutableFixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Vec, + validity: Option, + ) -> Result { + let size = FixedSizeBinaryArray::maybe_get_size(&data_type)?; + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Creates a new empty [`MutableFixedSizeBinaryArray`]. + pub fn new(size: usize) -> Self { + Self::with_capacity(size, 0) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] with capacity for `capacity` entries. 
+ pub fn with_capacity(size: usize, capacity: usize) -> Self { + Self::try_new( + DataType::FixedSizeBinary(size), + Vec::::with_capacity(capacity * size), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + let values = slice + .as_ref() + .iter() + .copied() + .flat_map(|x| x.unwrap_or([0; N])) + .collect::>(); + let validity = slice + .as_ref() + .iter() + .map(|x| x.is_some()) + .collect::(); + Self::try_new(DataType::FixedSizeBinary(N), values, validity.into()).unwrap() + } + + /// tries to push a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Error + /// Errors iff the size of `value` is not equal to its own size. + #[inline] + pub fn try_push>(&mut self, value: Option
) -> Result<(), Error> { + match value { + Some(bytes) => { + let bytes = bytes.as_ref(); + if self.size != bytes.len() { + return Err(Error::InvalidArgumentError( + "FixedSizeBinaryArray requires every item to be of its length".to_string(), + )); + } + self.values.extend_from_slice(bytes); + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.resize(self.values.len() + self.size, 0); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } + + /// pushes a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Panics + /// Panics iff the size of `value` is not equal to its own size. + #[inline] + pub fn push>(&mut self, value: Option
) { + self.try_push(value).unwrap() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// Pop the last entry from [`MutableFixedSizeBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + if self.values.len() < self.size { + return None; + } + let value_start = self.values.len() - self.size; + let value = self.values.split_off(value_start); + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from an iterator of values. + /// # Errors + /// Errors iff the size of any of the `value` is not equal to its own size. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(size, lower); + for item in iterator { + primitive.try_push(item)? + } + Ok(primitive) + } + + /// returns the (fixed) size of the [`MutableFixedSizeBinaryArray`]. + #[inline] + pub fn size(&self) -> usize { + self.size + } + + /// Returns the capacity of this array + pub fn capacity(&self) -> usize { + self.values.capacity() / self.size + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Returns the element at index `i` as `&[u8]` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + &self.values[i * self.size..(i + 1) * self.size] + } + + /// Returns the element at index `i` as `&[u8]` + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + std::slice::from_raw_parts(self.values.as_ptr().add(i * self.size), self.size) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional * self.size); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +/// Accessors +impl MutableFixedSizeBinaryArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. 
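Sketch of the mutable builder in use (import paths assumed):

use nano_arrow::array::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray};

let mut builder = MutableFixedSizeBinaryArray::new(2);
builder.push(Some([1u8, 2]));
builder.push::<&[u8]>(None); // the first null lazily allocates the validity bitmap
assert!(builder.try_push(Some([1u8, 2, 3])).is_err()); // wrong width is rejected
let frozen: FixedSizeBinaryArray = builder.into(); // O(1) conversion
assert_eq!(frozen.len(), 2);
assert!(frozen.get(1).is_none());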
+ pub fn values_mut_slice(&mut self) -> &mut [u8] { + self.values.as_mut_slice() + } +} + +impl MutableArray for MutableFixedSizeBinaryArray { + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push::<&[u8]>(None); + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl FixedSizeBinaryValues for MutableFixedSizeBinaryArray { + #[inline] + fn values(&self) -> &[u8] { + &self.values + } + + #[inline] + fn size(&self) -> usize { + self.size + } +} + +impl PartialEq for MutableFixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableFixedSizeBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/data.rs b/crates/nano-arrow/src/array/fixed_size_list/data.rs new file mode 100644 index 000000000000..966504bf3b6c --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/data.rs @@ -0,0 +1,36 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, FixedSizeListArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::DataType; + +impl Arrow2Arrow for FixedSizeListArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + let size = match data_type { + DataType::FixedSizeList(_, size) => size, + _ => unreachable!("must be FixedSizeList type"), + }; + + let mut values = from_data(&data.child_data()[0]); + values.slice(data.offset() * size, data.len() * size); + + Self { + size, + data_type, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/ffi.rs b/crates/nano-arrow/src/array/fixed_size_list/ffi.rs new file mode 100644 index 000000000000..237001809598 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/ffi.rs @@ -0,0 +1,39 @@ +use super::FixedSizeListArray; +use crate::array::ffi::{FromFfi, ToFfi}; +use crate::array::Array; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeListArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| 
x.as_ptr())] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for FixedSizeListArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let child = unsafe { array.child(0)? }; + let values = ffi::try_from(child)?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/fmt.rs b/crates/nano-arrow/src/array/fixed_size_list/fmt.rs new file mode 100644 index 000000000000..ee7d86115a14 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::FixedSizeListArray; + +pub fn write_value( + array: &FixedSizeListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for FixedSizeListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "FixedSizeListArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/iterator.rs b/crates/nano-arrow/src/array/fixed_size_list/iterator.rs new file mode 100644 index 000000000000..123658005adc --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/iterator.rs @@ -0,0 +1,43 @@ +use super::FixedSizeListArray; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +unsafe impl<'a> ArrayAccessor<'a> for FixedSizeListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`FixedSizeListArray`]. 
+pub type FixedSizeListValuesIter<'a> = ArrayValuesIter<'a, FixedSizeListArray>; + +type ZipIter<'a> = ZipValidity, FixedSizeListValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(FixedSizeListValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> FixedSizeListValuesIter<'a> { + FixedSizeListValuesIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/mod.rs b/crates/nano-arrow/src/array/fixed_size_list/mod.rs new file mode 100644 index 000000000000..40eb5016b9b7 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/mod.rs @@ -0,0 +1,221 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeListArray { + size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Box, + validity: Option, +} + +impl FixedSizeListArray { + /// Creates a new [`FixedSizeListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeList`] + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Box, + validity: Option, + ) -> Result { + let (child, size) = Self::try_child_and_size(&data_type)?; + + let child_data_type = &child.data_type; + let values_data_type = values.data_type(); + if child_data_type != values_data_type { + return Err(Error::oos( + format!("FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}."), + )); + } + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Alias to `Self::try_new(...).unwrap()` + pub fn new(data_type: DataType, values: Box, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// Returns a new empty [`FixedSizeListArray`]. 
+ pub fn new_empty(data_type: DataType) -> Self { + let values = new_empty_array(Self::get_child_and_size(&data_type).0.data_type().clone()); + Self::new(data_type, values, None) + } + + /// Returns a new null [`FixedSizeListArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let (field, size) = Self::get_child_and_size(&data_type); + + let values = new_null_array(field.data_type().clone(), length * size); + Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + } +} + +// must use +impl FixedSizeListArray { + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .slice_unchecked(offset * self.size, length * self.size); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the inner array. + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the `Vec` at position `i`. + /// # Panic: + /// panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + self.values.sliced(i * self.size, self.size) + } + + /// Returns the `Vec` at position `i`. + /// # Safety + /// Caller must ensure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + self.values.sliced_unchecked(i * self.size, self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } +} + +impl FixedSizeListArray { + pub(crate) fn try_child_and_size(data_type: &DataType) -> Result<(&Field, usize), Error> { + match data_type.to_logical_type() { + DataType::FixedSizeList(child, size) => { + if *size == 0 { + return Err(Error::oos("FixedSizeBinaryArray expects a positive size")); + } + Ok((child.as_ref(), *size)) + }, + _ => Err(Error::oos( + "FixedSizeListArray expects DataType::FixedSizeList", + )), + } + } + + pub(crate) fn get_child_and_size(data_type: &DataType) -> (&Field, usize) { + Self::try_child_and_size(data_type).unwrap() + } + + /// Returns a [`DataType`] consistent with [`FixedSizeListArray`]. 
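Construction sketch for the fixed-size list, using the `default_datatype` helper below (paths assumed to mirror arrow2):

use nano_arrow::array::{Array, FixedSizeListArray, PrimitiveArray};
use nano_arrow::datatypes::DataType;

// Two slots of width 3, backed by one flat i32 child array.
let data_type = FixedSizeListArray::default_datatype(DataType::Int32, 3);
let values = PrimitiveArray::<i32>::from_slice([1, 2, 3, 4, 5, 6]);
let array = FixedSizeListArray::new(data_type, values.boxed(), None);
assert_eq!(array.len(), 2);
// `value(i)` is an O(1) slice of the child, returned as `Box<dyn Array>`.
assert_eq!(array.value(1).len(), 3);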
+ pub fn default_datatype(data_type: DataType, size: usize) -> DataType { + let field = Box::new(Field::new("item", data_type, true)); + DataType::FixedSizeList(field, size) + } +} + +impl Array for FixedSizeListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/mutable.rs b/crates/nano-arrow/src/array/fixed_size_list/mutable.rs new file mode 100644 index 000000000000..bef25a1cbf1f --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/mutable.rs @@ -0,0 +1,256 @@ +use std::sync::Arc; + +use super::FixedSizeListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, PushUnchecked, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; + +/// The mutable version of [`FixedSizeListArray`]. +#[derive(Debug, Clone)] +pub struct MutableFixedSizeListArray { + data_type: DataType, + size: usize, + values: M, + validity: Option, +} + +impl From> for FixedSizeListArray { + fn from(mut other: MutableFixedSizeListArray) -> Self { + FixedSizeListArray::new( + other.data_type, + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeListArray { + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new(values: M, size: usize) -> Self { + let data_type = FixedSizeListArray::default_datatype(values.data_type().clone(), size); + Self::new_from(values, data_type, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new_with_field(values: M, name: &str, nullable: bool, size: usize) -> Self { + let data_type = DataType::FixedSizeList( + Box::new(Field::new(name, values.data_type().clone(), nullable)), + size, + ); + Self::new_from(values, data_type, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`], [`DataType`] and size. + pub fn new_from(values: M, data_type: DataType, size: usize) -> Self { + assert_eq!(values.len(), 0); + match data_type { + DataType::FixedSizeList(..) => (), + _ => panic!("data type must be FixedSizeList (got {data_type:?})"), + }; + Self { + size, + data_type, + values, + validity: None, + } + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// The length of this array + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The inner values + pub fn values(&self) -> &M { + &self.values + } + + /// The values as a mutable reference + pub fn mut_values(&mut self) -> &mut M { + &mut self.values + } + + fn init_validity(&mut self) { + let len = self.values.len() / self.size; + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. 
+ pub fn try_push_valid(&mut self) -> Result<()> { + if self.values.len() % self.size != 0 { + return Err(Error::Overflow); + }; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. + pub fn push_valid(&mut self) { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| self.values.push_null()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeListArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableFixedSizeListArray { + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeListArray::new( + self.data_type.clone(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeListArray::new( + self.data_type.clone(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| { + self.values.push_null(); + }); + if let Some(validity) = &mut self.validity { + validity.push(false) + } else { + self.init_validity() + } + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_extend>>(&mut self, iter: II) -> Result<()> { + for items in iter { + self.try_push(items)?; + } + Ok(()) + } +} + +impl TryPush> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(items) = item { + self.values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + Ok(()) + } +} + +impl PushUnchecked> for MutableFixedSizeListArray +where + M: MutableArray + Extend>, + I: IntoIterator>, +{ + /// # Safety + /// The caller must ensure that the `I` iterates exactly over `size` + /// items, where `size` is the fixed size width. 
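Sketch of incremental building with the mutable variant (assumes `MutablePrimitiveArray` and the `TryPush` trait are exposed as in arrow2):

use nano_arrow::array::{
    FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryPush,
};

let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::<i32>::new(), 3);
list.try_push(Some(vec![Some(1), Some(2), Some(3)])).unwrap();
list.try_push(None::<Vec<Option<i32>>>).unwrap(); // one null slot, three inner nulls
let frozen: FixedSizeListArray = list.into();
assert_eq!(frozen.len(), 2);
assert!(frozen.get(1).is_none());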
+ #[inline] + unsafe fn push_unchecked(&mut self, item: Option) { + if let Some(items) = item { + self.values.extend(items); + self.push_valid(); + } else { + self.push_null(); + } + } +} + +impl TryExtendFromSelf for MutableFixedSizeListArray +where + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/fmt.rs b/crates/nano-arrow/src/array/fmt.rs new file mode 100644 index 000000000000..ebc6937714cc --- /dev/null +++ b/crates/nano-arrow/src/array/fmt.rs @@ -0,0 +1,181 @@ +use std::fmt::{Result, Write}; + +use super::Array; +use crate::bitmap::Bitmap; + +/// Returns a function that writes the value of the element of `array` +/// at position `index` to a [`Write`], +/// writing `null` in the null slots. +pub fn get_value_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => Box::new(move |f, _| write!(f, "{null}")), + Boolean => Box::new(|f, index| { + super::boolean::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, f) + }), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let writer = super::primitive::fmt::get_write_value::<$T, _>( + array.as_any().downcast_ref().unwrap(), + ); + Box::new(move |f, index| writer(f, index)) + }), + Binary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + FixedSizeBinary => Box::new(|f, index| { + super::fixed_size_binary::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeBinary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + Utf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeUtf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + List => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + FixedSizeList => Box::new(move |f, index| { + super::fixed_size_list::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + LargeList => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + Struct => Box::new(move |f, index| { + super::struct_::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Union => Box::new(move |f, index| { + super::union::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Map => Box::new(move |f, index| { + super::map::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + Box::new(move |f, index| { + super::dictionary::fmt::write_value::<$T,_>(array.as_any().downcast_ref().unwrap(), index, null, f) + }) + }), + } +} + +/// Returns a function that writes the element of `array` +/// at position `index` to a [`Write`], writing `null` to the null slots. 
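These helpers back the arrays' `Debug` impls; a rough sketch of the observable effect (the exact rendering here is illustrative, not guaranteed):

use nano_arrow::array::FixedSizeBinaryArray;

let array = FixedSizeBinaryArray::from([Some([1u8, 2]), None]);
// Prints something like: FixedSizeBinary(2)[[1, 2], None]
println!("{array:?}");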
+pub fn get_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + let value_display = get_value_display(array, null); + Box::new(move |f, row| { + if array.is_null(row) { + f.write_str(null) + } else { + value_display(f, row) + } + }) +} + +pub fn write_vec( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('[')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char(']')?; + Ok(()) +} + +fn write_list( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + for index in 0..len { + if index != 0 { + f.write_char(',')?; + f.write_char(if new_lines { '\n' } else { ' ' })?; + } + if let Some(val) = validity { + if val.get_bit(index) { + d(f, index) + } else { + write!(f, "{null}") + } + } else { + d(f, index) + }?; + } + Ok(()) +} + +pub fn write_map( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('{')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char('}')?; + Ok(()) +} diff --git a/crates/nano-arrow/src/array/growable/binary.rs b/crates/nano-arrow/src/array/growable/binary.rs new file mode 100644 index 000000000000..ca095f351446 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/binary.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, extend_offset_values, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, BinaryArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::offset::{Offset, Offsets}; + +/// Concrete [`Growable`] for the [`BinaryArray`]. +pub struct GrowableBinary<'a, O: Offset> { + arrays: Vec<&'a BinaryArray>, + data_type: DataType, + validity: MutableBitmap, + values: Vec, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableBinary<'a, O> { + /// Creates a new [`GrowableBinary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a BinaryArray>, mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays, + data_type, + values: Vec::with_capacity(0), + offsets: Offsets::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> BinaryArray { + let data_type = self.data_type.clone(); + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = std::mem::take(&mut self.values); + + BinaryArray::::new(data_type, offsets.into(), values.into(), validity.into()) + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableBinary<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let offsets = array.offsets(); + let values = array.values(); + + self.offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + // values + extend_offset_values::(&mut self.values, offsets.buffer(), values, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + self.to().arced() + } + + fn as_box(&mut self) -> Box { + self.to().boxed() + } +} + +impl<'a, O: Offset> From> for BinaryArray { + fn from(val: GrowableBinary<'a, O>) -> Self { + BinaryArray::::new( + val.data_type, + val.offsets.into(), + val.values.into(), + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/boolean.rs b/crates/nano-arrow/src/array/growable/boolean.rs new file mode 100644 index 000000000000..f69d66f1d696 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/boolean.rs @@ -0,0 +1,91 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, BooleanArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`BooleanArray`]. +pub struct GrowableBoolean<'a> { + arrays: Vec<&'a BooleanArray>, + data_type: DataType, + validity: MutableBitmap, + values: MutableBitmap, + extend_null_bits: Vec>, +} + +impl<'a> GrowableBoolean<'a> { + /// Creates a new [`GrowableBoolean`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a BooleanArray>, mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays, + data_type, + values: MutableBitmap::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> BooleanArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + BooleanArray::new(self.data_type.clone(), values.into(), validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableBoolean<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let values = array.values(); + + let (slice, offset, _) = values.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + self.values + .extend_from_slice_unchecked(slice, start + offset, len); + } + } + + fn extend_validity(&mut self, additional: usize) { + self.values.extend_constant(additional, false); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for BooleanArray { + fn from(val: GrowableBoolean<'a>) -> Self { + BooleanArray::new(val.data_type, val.values.into(), val.validity.into()) + } +} diff --git a/crates/nano-arrow/src/array/growable/dictionary.rs b/crates/nano-arrow/src/array/growable/dictionary.rs new file mode 100644 index 000000000000..fa85cdad6f8e --- /dev/null +++ b/crates/nano-arrow/src/array/growable/dictionary.rs @@ -0,0 +1,157 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`DictionaryArray`]. +/// # Implementation +/// This growable does not perform collision checks and instead concatenates +/// the values of each [`DictionaryArray`] one after the other. +pub struct GrowableDictionary<'a, K: DictionaryKey> { + data_type: DataType, + keys_values: Vec<&'a [K]>, + key_values: Vec, + key_validity: MutableBitmap, + offsets: Vec, + values: Box, + extend_null_bits: Vec>, +} + +fn concatenate_values( + arrays_keys: &[&PrimitiveArray], + arrays_values: &[&dyn Array], + capacity: usize, +) -> (Box, Vec) { + let mut mutable = make_growable(arrays_values, false, capacity); + let mut offsets = Vec::with_capacity(arrays_keys.len() + 1); + offsets.push(0); + for (i, values) in arrays_values.iter().enumerate() { + mutable.extend(i, 0, values.len()); + offsets.push(offsets[i] + values.len()); + } + (mutable.as_box(), offsets) +} + +impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { + /// Creates a new [`GrowableDictionary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: &[&'a DictionaryArray], mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let arrays_keys = arrays.iter().map(|array| array.keys()).collect::>(); + let keys_values = arrays_keys + .iter() + .map(|array| array.values().as_slice()) + .collect::>(); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(array.keys(), use_validity)) + .collect(); + + let arrays_values = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + + let (values, offsets) = concatenate_values(&arrays_keys, &arrays_values, capacity); + + Self { + data_type, + offsets, + values, + keys_values, + key_values: Vec::with_capacity(capacity), + key_validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + #[inline] + fn to(&mut self) -> DictionaryArray { + let validity = std::mem::take(&mut self.key_validity); + let key_values = std::mem::take(&mut self.key_values); + + #[cfg(debug_assertions)] + { + crate::array::specification::check_indexes(&key_values, self.values.len()).unwrap(); + } + let keys = + PrimitiveArray::::new(T::PRIMITIVE.into(), key_values.into(), validity.into()); + + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + self.data_type.clone(), + keys, + self.values.clone(), + ) + .unwrap() + } + } +} + +impl<'a, T: DictionaryKey> Growable<'a> for GrowableDictionary<'a, T> { + #[inline] + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.key_validity, start, len); + + let values = &self.keys_values[index][start..start + len]; + let offset = self.offsets[index]; + self.key_values.extend( + values + .iter() + // `.unwrap_or(0)` because this operation does not check for null values, which may contain any key. + .map(|x| { + let x: usize = offset + (*x).try_into().unwrap_or(0); + let x: T = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }), + ); + } + + #[inline] + fn len(&self) -> usize { + self.key_values.len() + } + + #[inline] + fn extend_validity(&mut self, additional: usize) { + self.key_values + .resize(self.key_values.len() + additional, T::default()); + self.key_validity.extend_constant(additional, false); + } + + #[inline] + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + #[inline] + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, T: DictionaryKey> From> for DictionaryArray { + #[inline] + fn from(mut val: GrowableDictionary<'a, T>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/fixed_binary.rs b/crates/nano-arrow/src/array/growable/fixed_binary.rs new file mode 100644 index 000000000000..bc6b307f97f9 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/fixed_binary.rs @@ -0,0 +1,98 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, FixedSizeBinaryArray}; +use crate::bitmap::MutableBitmap; + +/// Concrete [`Growable`] for the [`FixedSizeBinaryArray`]. +pub struct GrowableFixedSizeBinary<'a> { + arrays: Vec<&'a FixedSizeBinaryArray>, + validity: MutableBitmap, + values: Vec, + extend_null_bits: Vec>, + size: usize, // just a cache +} + +impl<'a> GrowableFixedSizeBinary<'a> { + /// Creates a new [`GrowableFixedSizeBinary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. 
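Sketch of driving a growable (the `Growable` trait must be in scope; module path assumed):

use nano_arrow::array::growable::{Growable, GrowableFixedSizeBinary};
use nano_arrow::array::FixedSizeBinaryArray;

let source = FixedSizeBinaryArray::from_slice([[1u8, 2], [3, 4]]);
let mut growable = GrowableFixedSizeBinary::new(vec![&source], true, 4);
growable.extend(0, 0, 2);    // copy both slots of `source`
growable.extend_validity(2); // then two null (zero-filled) slots
let out: FixedSizeBinaryArray = growable.into();
assert_eq!(out.len(), 4);
assert!(out.get(2).is_none());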
+ pub fn new( + arrays: Vec<&'a FixedSizeBinaryArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. + if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let size = FixedSizeBinaryArray::get_size(arrays[0].data_type()); + Self { + arrays, + values: Vec::with_capacity(0), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + size, + } + } + + fn to(&mut self) -> FixedSizeBinaryArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + FixedSizeBinaryArray::new( + self.arrays[0].data_type().clone(), + values.into(), + validity.into(), + ) + } +} + +impl<'a> Growable<'a> for GrowableFixedSizeBinary<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let values = array.values(); + + self.values + .extend_from_slice(&values[start * self.size..start * self.size + len * self.size]); + } + + fn extend_validity(&mut self, additional: usize) { + self.values + .extend_from_slice(&vec![0; self.size * additional]); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for FixedSizeBinaryArray { + fn from(val: GrowableFixedSizeBinary<'a>) -> Self { + FixedSizeBinaryArray::new( + val.arrays[0].data_type().clone(), + val.values.into(), + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/fixed_size_list.rs b/crates/nano-arrow/src/array/growable/fixed_size_list.rs new file mode 100644 index 000000000000..cacad36bb4a7 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/fixed_size_list.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, FixedSizeListArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`FixedSizeListArray`]. +pub struct GrowableFixedSizeList<'a> { + arrays: Vec<&'a FixedSizeListArray>, + validity: MutableBitmap, + values: Box + 'a>, + extend_null_bits: Vec>, + size: usize, +} + +impl<'a> GrowableFixedSizeList<'a> { + /// Creates a new [`GrowableFixedSizeList`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new( + arrays: Vec<&'a FixedSizeListArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + assert!(!arrays.is_empty()); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
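+        // The element width `size` is read from the first array's data type just below;
+        // extending `len` list slots then copies `len * size` child values (see `extend`).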
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let size = + if let DataType::FixedSizeList(_, size) = &arrays[0].data_type().to_logical_type() { + *size + } else { + unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + size, + } + } + + fn to(&mut self) -> FixedSizeListArray { + let validity = std::mem::take(&mut self.validity); + let values = self.values.as_box(); + + FixedSizeListArray::new(self.arrays[0].data_type().clone(), values, validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + self.values + .extend(index, start * self.size, len * self.size); + } + + fn extend_validity(&mut self, additional: usize) { + self.values.extend_validity(additional * self.size); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for FixedSizeListArray { + fn from(val: GrowableFixedSizeList<'a>) -> Self { + let mut values = val.values; + let values = values.as_box(); + + Self::new( + val.arrays[0].data_type().clone(), + values, + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/list.rs b/crates/nano-arrow/src/array/growable/list.rs new file mode 100644 index 000000000000..9fdf9eb047bf --- /dev/null +++ b/crates/nano-arrow/src/array/growable/list.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, ListArray}; +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +fn extend_offset_values( + growable: &mut GrowableList<'_, O>, + index: usize, + start: usize, + len: usize, +) { + let array = growable.arrays[index]; + let offsets = array.offsets(); + + growable + .offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + let end = offsets.buffer()[start + len].to_usize(); + let start = offsets.buffer()[start].to_usize(); + let len = end - start; + growable.values.extend(index, start, len); +} + +/// Concrete [`Growable`] for the [`ListArray`]. +pub struct GrowableList<'a, O: Offset> { + arrays: Vec<&'a ListArray>, + validity: MutableBitmap, + values: Box + 'a>, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableList<'a, O> { + /// Creates a new [`GrowableList`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a ListArray>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
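+        // (When slices are copied in `extend`, `extend_offset_values` above re-bases the
+        // source offsets onto `self.offsets` and grows the child values over the matching
+        // range of the source's offset buffer.)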
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + offsets: Offsets::with_capacity(capacity), + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> ListArray { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = self.values.as_box(); + + ListArray::::new( + self.arrays[0].data_type().clone(), + offsets.into(), + values, + validity.into(), + ) + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableList<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + extend_offset_values::(self, index, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, O: Offset> From> for ListArray { + fn from(mut val: GrowableList<'a, O>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/map.rs b/crates/nano-arrow/src/array/growable/map.rs new file mode 100644 index 000000000000..62f9d4c5c53a --- /dev/null +++ b/crates/nano-arrow/src/array/growable/map.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, MapArray}; +use crate::bitmap::MutableBitmap; +use crate::offset::Offsets; + +fn extend_offset_values(growable: &mut GrowableMap<'_>, index: usize, start: usize, len: usize) { + let array = growable.arrays[index]; + let offsets = array.offsets(); + + growable + .offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + let end = offsets.buffer()[start + len] as usize; + let start = offsets.buffer()[start] as usize; + let len = end - start; + growable.values.extend(index, start, len); +} + +/// Concrete [`Growable`] for the [`MapArray`]. +pub struct GrowableMap<'a> { + arrays: Vec<&'a MapArray>, + validity: MutableBitmap, + values: Box + 'a>, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a> GrowableMap<'a> { + /// Creates a new [`GrowableMap`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a MapArray>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
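+        // (A [`MapArray`] is grown like a list: offsets are re-based per copied slice, and
+        // the inner entries array, obtained via `array.field()`, is grown through a nested
+        // growable.)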
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.field().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + offsets: Offsets::with_capacity(capacity), + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> MapArray { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = self.values.as_box(); + + MapArray::new( + self.arrays[0].data_type().clone(), + offsets.into(), + values, + validity.into(), + ) + } +} + +impl<'a> Growable<'a> for GrowableMap<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + extend_offset_values(self, index, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for MapArray { + fn from(mut val: GrowableMap<'a>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/mod.rs b/crates/nano-arrow/src/array/growable/mod.rs new file mode 100644 index 000000000000..a3fe4b739451 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/mod.rs @@ -0,0 +1,149 @@ +//! Contains the trait [`Growable`] and corresponding concreate implementations, one per concrete array, +//! that offer the ability to create a new [`Array`] out of slices of existing [`Array`]s. + +use std::sync::Arc; + +use crate::array::*; +use crate::datatypes::*; + +mod binary; +pub use binary::GrowableBinary; +mod union; +pub use union::GrowableUnion; +mod boolean; +pub use boolean::GrowableBoolean; +mod fixed_binary; +pub use fixed_binary::GrowableFixedSizeBinary; +mod null; +pub use null::GrowableNull; +mod primitive; +pub use primitive::GrowablePrimitive; +mod list; +pub use list::GrowableList; +mod map; +pub use map::GrowableMap; +mod structure; +pub use structure::GrowableStruct; +mod fixed_size_list; +pub use fixed_size_list::GrowableFixedSizeList; +mod utf8; +pub use utf8::GrowableUtf8; +mod dictionary; +pub use dictionary::GrowableDictionary; + +mod utils; + +/// Describes a struct that can be extended from slices of other pre-existing [`Array`]s. +/// This is used in operations where a new array is built out of other arrays, such +/// as filter and concatenation. +pub trait Growable<'a> { + /// Extends this [`Growable`] with elements from the bounded [`Array`] at index `index` from + /// a slice starting at `start` and length `len`. + /// # Panic + /// This function panics if the range is out of bounds, i.e. if `start + len >= array.len()`. + fn extend(&mut self, index: usize, start: usize, len: usize); + + /// Extends this [`Growable`] with null elements, disregarding the bound arrays + fn extend_validity(&mut self, additional: usize); + + /// The current length of the [`Growable`]. + fn len(&self) -> usize; + + /// Converts this [`Growable`] to an [`Arc`], thereby finishing the mutation. + /// Self will be empty after such operation. 
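+    ///
+    /// # Example
+    /// A minimal sketch of the usual pattern, assuming `a` and `b` are `&dyn Array` with the
+    /// same [`DataType`]:
+    /// ```ignore
+    /// let mut growable = make_growable(&[a, b], false, a.len() + b.len());
+    /// growable.extend(0, 0, a.len()); // copy all of `a`
+    /// growable.extend(1, 0, b.len()); // then all of `b`
+    /// let concatenated = growable.as_arc(); // or `as_box()`
+    /// ```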
+ fn as_arc(&mut self) -> Arc { + self.as_box().into() + } + + /// Converts this [`Growable`] to an [`Box`], thereby finishing the mutation. + /// Self will be empty after such operation + fn as_box(&mut self) -> Box; +} + +macro_rules! dyn_growable { + ($ty:ty, $arrays:expr, $use_validity:expr, $capacity:expr) => {{ + let arrays = $arrays + .iter() + .map(|array| array.as_any().downcast_ref().unwrap()) + .collect::>(); + Box::new(<$ty>::new(arrays, $use_validity, $capacity)) + }}; +} + +/// Creates a new [`Growable`] from an arbitrary number of [`Array`]s. +/// # Panics +/// This function panics iff +/// * the arrays do not have the same [`DataType`]. +/// * `arrays.is_empty()`. +pub fn make_growable<'a>( + arrays: &[&'a dyn Array], + use_validity: bool, + capacity: usize, +) -> Box + 'a> { + assert!(!arrays.is_empty()); + let data_type = arrays[0].data_type(); + + use PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(null::GrowableNull::new(data_type.clone())), + Boolean => dyn_growable!(boolean::GrowableBoolean, arrays, use_validity, capacity), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + dyn_growable!(primitive::GrowablePrimitive::<$T>, arrays, use_validity, capacity) + }), + Utf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), + LargeUtf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), + Binary => dyn_growable!( + binary::GrowableBinary::, + arrays, + use_validity, + capacity + ), + LargeBinary => dyn_growable!( + binary::GrowableBinary::, + arrays, + use_validity, + capacity + ), + FixedSizeBinary => dyn_growable!( + fixed_binary::GrowableFixedSizeBinary, + arrays, + use_validity, + capacity + ), + List => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), + LargeList => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), + Struct => dyn_growable!(structure::GrowableStruct, arrays, use_validity, capacity), + FixedSizeList => dyn_growable!( + fixed_size_list::GrowableFixedSizeList, + arrays, + use_validity, + capacity + ), + Union => { + let arrays = arrays + .iter() + .map(|array| array.as_any().downcast_ref().unwrap()) + .collect::>(); + Box::new(union::GrowableUnion::new(arrays, capacity)) + }, + Map => dyn_growable!(map::GrowableMap, arrays, use_validity, capacity), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let arrays = arrays + .iter() + .map(|array| { + array + .as_any() + .downcast_ref::>() + .unwrap() + }) + .collect::>(); + Box::new(dictionary::GrowableDictionary::<$T>::new( + &arrays, + use_validity, + capacity, + )) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/growable/null.rs b/crates/nano-arrow/src/array/growable/null.rs new file mode 100644 index 000000000000..44e1c2488b0f --- /dev/null +++ b/crates/nano-arrow/src/array/growable/null.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use super::Growable; +use crate::array::{Array, NullArray}; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`NullArray`]. +pub struct GrowableNull { + data_type: DataType, + length: usize, +} + +impl Default for GrowableNull { + fn default() -> Self { + Self::new(DataType::Null) + } +} + +impl GrowableNull { + /// Creates a new [`GrowableNull`]. 
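+    ///
+    /// For example, `GrowableNull::new(DataType::Null)` starts at length 0; `extend` and
+    /// `extend_validity` only increase that length, and `as_box` produces a [`NullArray`]
+    /// of the accumulated length.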
+ pub fn new(data_type: DataType) -> Self { + Self { + data_type, + length: 0, + } + } +} + +impl<'a> Growable<'a> for GrowableNull { + fn extend(&mut self, _: usize, _: usize, len: usize) { + self.length += len; + } + + fn extend_validity(&mut self, additional: usize) { + self.length += additional; + } + + #[inline] + fn len(&self) -> usize { + self.length + } + + fn as_arc(&mut self) -> Arc { + Arc::new(NullArray::new(self.data_type.clone(), self.length)) + } + + fn as_box(&mut self) -> Box { + Box::new(NullArray::new(self.data_type.clone(), self.length)) + } +} + +impl From for NullArray { + fn from(val: GrowableNull) -> Self { + NullArray::new(val.data_type, val.length) + } +} diff --git a/crates/nano-arrow/src/array/growable/primitive.rs b/crates/nano-arrow/src/array/growable/primitive.rs new file mode 100644 index 000000000000..cade744a5936 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/primitive.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::types::NativeType; + +/// Concrete [`Growable`] for the [`PrimitiveArray`]. +pub struct GrowablePrimitive<'a, T: NativeType> { + data_type: DataType, + arrays: Vec<&'a [T]>, + validity: MutableBitmap, + values: Vec, + extend_null_bits: Vec>, +} + +impl<'a, T: NativeType> GrowablePrimitive<'a, T> { + /// Creates a new [`GrowablePrimitive`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new( + arrays: Vec<&'a PrimitiveArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
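+        // For example, growing from [1, 2, 3] and [4, 5] (no nulls, `use_validity == false`)
+        // keeps validity tracking off, so `build_extend_null_bits` returns a no-op closure
+        // for each input array.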
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let data_type = arrays[0].data_type().clone(); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let arrays = arrays + .iter() + .map(|array| array.values().as_slice()) + .collect::>(); + + Self { + data_type, + arrays, + values: Vec::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + #[inline] + fn to(&mut self) -> PrimitiveArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + PrimitiveArray::::new(self.data_type.clone(), values.into(), validity.into()) + } +} + +impl<'a, T: NativeType> Growable<'a> for GrowablePrimitive<'a, T> { + #[inline] + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let values = self.arrays[index]; + self.values.extend_from_slice(&values[start..start + len]); + } + + #[inline] + fn extend_validity(&mut self, additional: usize) { + self.values + .resize(self.values.len() + additional, T::default()); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + #[inline] + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, T: NativeType> From> for PrimitiveArray { + #[inline] + fn from(val: GrowablePrimitive<'a, T>) -> Self { + PrimitiveArray::::new(val.data_type, val.values.into(), val.validity.into()) + } +} diff --git a/crates/nano-arrow/src/array/growable/structure.rs b/crates/nano-arrow/src/array/growable/structure.rs new file mode 100644 index 000000000000..10afd20e7f06 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/structure.rs @@ -0,0 +1,132 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, StructArray}; +use crate::bitmap::MutableBitmap; + +/// Concrete [`Growable`] for the [`StructArray`]. +pub struct GrowableStruct<'a> { + arrays: Vec<&'a StructArray>, + validity: MutableBitmap, + values: Vec + 'a>>, + extend_null_bits: Vec>, +} + +impl<'a> GrowableStruct<'a> { + /// Creates a new [`GrowableStruct`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a StructArray>, mut use_validity: bool, capacity: usize) -> Self { + assert!(!arrays.is_empty()); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
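+        // Each child field gets its own nested growable (built below); when a source struct
+        // slot is null, `extend` pushes one null into every child growable instead of copying
+        // values, which keeps all children aligned.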
+ if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let arrays = arrays + .iter() + .map(|array| array.as_any().downcast_ref::().unwrap()) + .collect::>(); + + // ([field1, field2], [field3, field4]) -> ([field1, field3], [field2, field3]) + let values = (0..arrays[0].values().len()) + .map(|i| { + make_growable( + &arrays + .iter() + .map(|x| x.values()[i].as_ref()) + .collect::>(), + use_validity, + capacity, + ) + }) + .collect::>>(); + + Self { + arrays, + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> StructArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + let values = values.into_iter().map(|mut x| x.as_box()).collect(); + + StructArray::new(self.arrays[0].data_type().clone(), values, validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableStruct<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + if array.null_count() == 0 { + self.values + .iter_mut() + .for_each(|child| child.extend(index, start, len)) + } else { + (start..start + len).for_each(|i| { + if array.is_valid(i) { + self.values + .iter_mut() + .for_each(|child| child.extend(index, i, 1)) + } else { + self.values + .iter_mut() + .for_each(|child| child.extend_validity(1)) + } + }) + } + } + + fn extend_validity(&mut self, additional: usize) { + self.values + .iter_mut() + .for_each(|child| child.extend_validity(additional)); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + // All children should have the same indexing, so just use the first + // one. If we don't have children, we might still have a validity + // array, so use that. + if let Some(child) = self.values.get(0) { + child.len() + } else { + self.validity.len() + } + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for StructArray { + fn from(val: GrowableStruct<'a>) -> Self { + let values = val.values.into_iter().map(|mut x| x.as_box()).collect(); + + StructArray::new( + val.arrays[0].data_type().clone(), + values, + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/union.rs b/crates/nano-arrow/src/array/growable/union.rs new file mode 100644 index 000000000000..4ef39f16fbb3 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/union.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; + +use super::{make_growable, Growable}; +use crate::array::{Array, UnionArray}; + +/// Concrete [`Growable`] for the [`UnionArray`]. +pub struct GrowableUnion<'a> { + arrays: Vec<&'a UnionArray>, + types: Vec, + offsets: Option>, + fields: Vec + 'a>>, +} + +impl<'a> GrowableUnion<'a> { + /// Creates a new [`GrowableUnion`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// Panics iff + /// * `arrays` is empty. 
+ /// * any of the arrays has a different + pub fn new(arrays: Vec<&'a UnionArray>, capacity: usize) -> Self { + let first = arrays[0].data_type(); + assert!(arrays.iter().all(|x| x.data_type() == first)); + + let has_offsets = arrays[0].offsets().is_some(); + + let fields = (0..arrays[0].fields().len()) + .map(|i| { + make_growable( + &arrays + .iter() + .map(|x| x.fields()[i].as_ref()) + .collect::>(), + false, + capacity, + ) + }) + .collect::>>(); + + Self { + arrays, + fields, + offsets: if has_offsets { + Some(Vec::with_capacity(capacity)) + } else { + None + }, + types: Vec::with_capacity(capacity), + } + } + + fn to(&mut self) -> UnionArray { + let types = std::mem::take(&mut self.types); + let fields = std::mem::take(&mut self.fields); + let offsets = std::mem::take(&mut self.offsets); + let fields = fields.into_iter().map(|mut x| x.as_box()).collect(); + + UnionArray::new( + self.arrays[0].data_type().clone(), + types.into(), + fields, + offsets.map(|x| x.into()), + ) + } +} + +impl<'a> Growable<'a> for GrowableUnion<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + let array = self.arrays[index]; + + let types = &array.types()[start..start + len]; + self.types.extend(types); + if let Some(x) = self.offsets.as_mut() { + let offsets = &array.offsets().unwrap()[start..start + len]; + + // in a dense union, each slot has its own offset. We extend the fields accordingly. + for (&type_, &offset) in types.iter().zip(offsets.iter()) { + let field = &mut self.fields[type_ as usize]; + // The offset for the element that is about to be extended is the current length + // of the child field of the corresponding type. Note that this may be very + // different than the original offset from the array we are extending from as + // it is a function of the previous extensions to this child. + x.push(field.len() as i32); + field.extend(index, offset as usize, 1); + } + } else { + // in a sparse union, every field has the same length => extend all fields equally + self.fields + .iter_mut() + .for_each(|field| field.extend(index, start, len)) + } + } + + fn extend_validity(&mut self, _additional: usize) {} + + #[inline] + fn len(&self) -> usize { + self.types.len() + } + + fn as_arc(&mut self) -> Arc { + self.to().arced() + } + + fn as_box(&mut self) -> Box { + self.to().boxed() + } +} + +impl<'a> From> for UnionArray { + fn from(val: GrowableUnion<'a>) -> Self { + let fields = val.fields.into_iter().map(|mut x| x.as_box()).collect(); + + UnionArray::new( + val.arrays[0].data_type().clone(), + val.types.into(), + fields, + val.offsets.map(|x| x.into()), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/utf8.rs b/crates/nano-arrow/src/array/growable/utf8.rs new file mode 100644 index 000000000000..1ea01ffd040a --- /dev/null +++ b/crates/nano-arrow/src/array/growable/utf8.rs @@ -0,0 +1,104 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, extend_offset_values, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +/// Concrete [`Growable`] for the [`Utf8Array`]. +pub struct GrowableUtf8<'a, O: Offset> { + arrays: Vec<&'a Utf8Array>, + validity: MutableBitmap, + values: Vec, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableUtf8<'a, O> { + /// Creates a new [`GrowableUtf8`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. 
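+    ///
+    /// # Example
+    /// A minimal sketch, assuming `a` and `b` are `Utf8Array<i32>`s:
+    /// ```ignore
+    /// let mut growable = GrowableUtf8::new(vec![&a, &b], false, a.len() + b.len());
+    /// growable.extend(1, 0, b.len()); // all of `b` first
+    /// growable.extend(0, 0, a.len()); // then all of `a`
+    /// let swapped: Utf8Array<i32> = growable.into();
+    /// ```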
+ pub fn new(arrays: Vec<&'a Utf8Array>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. + if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays: arrays.to_vec(), + values: Vec::with_capacity(0), + offsets: Offsets::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> Utf8Array { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = std::mem::take(&mut self.values); + + #[cfg(debug_assertions)] + { + crate::array::specification::try_check_utf8(&offsets, &values).unwrap(); + } + + unsafe { + Utf8Array::::try_new_unchecked( + self.arrays[0].data_type().clone(), + offsets.into(), + values.into(), + validity.into(), + ) + .unwrap() + } + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableUtf8<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let offsets = array.offsets(); + let values = array.values(); + + self.offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + // values + extend_offset_values::(&mut self.values, offsets.as_slice(), values, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, O: Offset> From> for Utf8Array { + fn from(mut val: GrowableUtf8<'a, O>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/utils.rs b/crates/nano-arrow/src/array/growable/utils.rs new file mode 100644 index 000000000000..ecdfb522249f --- /dev/null +++ b/crates/nano-arrow/src/array/growable/utils.rs @@ -0,0 +1,40 @@ +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::offset::Offset; + +// function used to extend nulls from arrays. This function's lifetime is bound to the array +// because it reads nulls from it. 
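+// Each boxed closure receives the target validity bitmap plus a `(start, len)` range and
+// either copies the source's validity bits, extends with `true`s, or does nothing,
+// depending on whether the source array has a validity bitmap and on `use_validity`.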
+pub(super) type ExtendNullBits<'a> = Box; + +pub(super) fn build_extend_null_bits(array: &dyn Array, use_validity: bool) -> ExtendNullBits { + if let Some(bitmap) = array.validity() { + Box::new(move |validity, start, len| { + debug_assert!(start + len <= bitmap.len()); + let (slice, offset, _) = bitmap.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + validity.extend_from_slice_unchecked(slice, start + offset, len); + } + }) + } else if use_validity { + Box::new(|validity, _, len| { + validity.extend_constant(len, true); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +#[inline] +pub(super) fn extend_offset_values( + buffer: &mut Vec, + offsets: &[O], + values: &[u8], + start: usize, + len: usize, +) { + let start_values = offsets[start].to_usize(); + let end_values = offsets[start + len].to_usize(); + let new_values = &values[start_values..end_values]; + buffer.extend_from_slice(new_values); +} diff --git a/crates/nano-arrow/src/array/indexable.rs b/crates/nano-arrow/src/array/indexable.rs new file mode 100644 index 000000000000..d3f466722aa6 --- /dev/null +++ b/crates/nano-arrow/src/array/indexable.rs @@ -0,0 +1,194 @@ +use std::borrow::Borrow; + +use crate::array::{ + MutableArray, MutableBinaryArray, MutableBinaryValuesArray, MutableBooleanArray, + MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, MutableUtf8ValuesArray, +}; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Trait for arrays that can be indexed directly to extract a value. +pub trait Indexable { + /// The type of the element at index `i`; may be a reference type or a value type. + type Value<'a>: Borrow + where + Self: 'a; + + type Type: ?Sized; + + /// Returns the element at index `i`. + /// # Panic + /// May panic if `i >= self.len()`. + fn value_at(&self, index: usize) -> Self::Value<'_>; + + /// Returns the element at index `i`. + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + unsafe fn value_unchecked_at(&self, index: usize) -> Self::Value<'_> { + self.value_at(index) + } +} + +pub trait AsIndexed { + fn as_indexed(&self) -> &M::Type; +} + +impl Indexable for MutableBooleanArray { + type Value<'a> = bool; + type Type = bool; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.values().get(i) + } +} + +impl AsIndexed for bool { + #[inline] + fn as_indexed(&self) -> &bool { + self + } +} + +impl Indexable for MutableBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? + assert!(i < self.len()); + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? 
+ // soundness: the invariant of the function + let (start, end) = self.offsets().start_end_unchecked(i); + // soundness: the invariant of the struct + self.values().get_unchecked(start..end) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableBinaryValuesArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableFixedSizeBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // soundness: the invariant of the struct + self.value_unchecked(i) + } +} + +impl AsIndexed for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +// TODO: should NativeType derive from Hash? +impl Indexable for MutablePrimitiveArray { + type Value<'a> = T; + type Type = T; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + assert!(i < self.len()); + // TODO: add Length trait? (for both Array and MutableArray) + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + *self.values().get_unchecked(i) + } +} + +impl AsIndexed> for T { + #[inline] + fn as_indexed(&self) -> &T { + self + } +} + +impl Indexable for MutableUtf8Array { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} + +impl Indexable for MutableUtf8ValuesArray { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} diff --git a/crates/nano-arrow/src/array/iterator.rs b/crates/nano-arrow/src/array/iterator.rs new file mode 100644 index 000000000000..5e8ed44d861e --- /dev/null +++ b/crates/nano-arrow/src/array/iterator.rs @@ -0,0 +1,83 @@ +use crate::trusted_len::TrustedLen; + +mod private { + pub trait Sealed {} + + impl<'a, T: super::ArrayAccessor<'a>> Sealed for T {} +} + +/// Sealed trait representing assess to a value of an array. +/// # Safety +/// Implementers of this trait guarantee that +/// `value_unchecked` is safe when called up to `len` +pub unsafe trait ArrayAccessor<'a>: private::Sealed { + type Item: 'a; + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item; + fn len(&self) -> usize; +} + +/// Iterator of values of an [`ArrayAccessor`]. 
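+///
+/// For any `A: ArrayAccessor`, `ArrayValuesIter::new(&array)` yields `array.len()` items in
+/// order; each step checks the running index against the end before calling `value_unchecked`.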
+#[derive(Debug, Clone)] +pub struct ArrayValuesIter<'a, A: ArrayAccessor<'a>> { + array: &'a A, + index: usize, + end: usize, +} + +impl<'a, A: ArrayAccessor<'a>> ArrayValuesIter<'a, A> { + /// Creates a new [`ArrayValuesIter`] + #[inline] + pub fn new(array: &'a A) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a, A: ArrayAccessor<'a>> Iterator for ArrayValuesIter<'a, A> { + type Item = A::Item; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl<'a, A: ArrayAccessor<'a>> DoubleEndedIterator for ArrayValuesIter<'a, A> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +unsafe impl<'a, A: ArrayAccessor<'a>> TrustedLen for ArrayValuesIter<'a, A> {} +impl<'a, A: ArrayAccessor<'a>> ExactSizeIterator for ArrayValuesIter<'a, A> {} diff --git a/crates/nano-arrow/src/array/list/data.rs b/crates/nano-arrow/src/array/list/data.rs new file mode 100644 index 000000000000..6f3424c96ce6 --- /dev/null +++ b/crates/nano-arrow/src/array/list/data.rs @@ -0,0 +1,38 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, ListArray}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for ListArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.offsets.clone().into_inner().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: from_data(&data.child_data()[0]), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/list/ffi.rs b/crates/nano-arrow/src/array/list/ffi.rs new file mode 100644 index 000000000000..487b4ad40128 --- /dev/null +++ b/crates/nano-arrow/src/array/list/ffi.rs @@ -0,0 +1,68 @@ +use super::super::ffi::ToFfi; +use super::super::Array; +use super::ListArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for ListArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if 
bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for ListArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = unsafe { array.child(0)? }; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/list/fmt.rs b/crates/nano-arrow/src/array/list/fmt.rs new file mode 100644 index 000000000000..67dcd6b78786 --- /dev/null +++ b/crates/nano-arrow/src/array/list/fmt.rs @@ -0,0 +1,30 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::ListArray; +use crate::offset::Offset; + +pub fn write_value( + array: &ListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for ListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + let head = if O::IS_LARGE { + "LargeListArray" + } else { + "ListArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/list/iterator.rs b/crates/nano-arrow/src/array/list/iterator.rs new file mode 100644 index 000000000000..28552bf4bb65 --- /dev/null +++ b/crates/nano-arrow/src/array/list/iterator.rs @@ -0,0 +1,68 @@ +use super::ListArray; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for ListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`ListArray`]. 
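+///
+/// For example, `list_array.values_iter()` yields each child slice as a boxed [`Array`],
+/// while `list_array.iter()` additionally zips the validity to yield `Option`s.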
+pub type ListValuesIter<'a, O> = ArrayValuesIter<'a, ListArray>; + +type ZipIter<'a, O> = ZipValidity, ListValuesIter<'a, O>, BitmapIter<'a>>; + +impl<'a, O: Offset> IntoIterator for &'a ListArray { + type Item = Option>; + type IntoIter = ZipIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, O: Offset> ListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a, O> { + ZipValidity::new_with_validity(ListValuesIter::new(self), self.validity.as_ref()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ListValuesIter<'a, O> { + ListValuesIter::new(self) + } +} + +struct Iter>> { + current: i32, + offsets: std::vec::IntoIter, + values: I, +} + +impl> + Clone> Iterator for Iter { + type Item = Option>>; + + fn next(&mut self) -> Option { + let next = self.offsets.next(); + next.map(|next| { + let length = next - self.current; + let iter = self + .values + .clone() + .skip(self.current as usize) + .take(length as usize); + self.current = next; + Some(iter) + }) + } +} diff --git a/crates/nano-arrow/src/array/list/mod.rs b/crates/nano-arrow/src/array/list/mod.rs new file mode 100644 index 000000000000..f021deb4d7da --- /dev/null +++ b/crates/nano-arrow/src/array/list/mod.rs @@ -0,0 +1,241 @@ +use super::specification::try_check_offsets_bounds; +use super::{new_empty_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// An [`Array`] semantically equivalent to `Vec>>>` with Arrow's in-memory. +#[derive(Clone)] +pub struct ListArray { + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, +} + +impl ListArray { + /// Creates a new [`ListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + let child_data_type = Self::try_get_child(&data_type)?.data_type(); + let values_data_type = values.data_type(); + if child_data_type != values_data_type { + return Err(Error::oos( + format!("ListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}."), + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`ListArray`]. + /// + /// # Panics + /// This function panics iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. 
+ /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// # Implementation + /// This function is `O(1)` + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Returns a new empty [`ListArray`]. + pub fn new_empty(data_type: DataType) -> Self { + let values = new_empty_array(Self::get_child_type(&data_type).clone()); + Self::new(data_type, OffsetsBuffer::default(), values, None) + } + + /// Returns a new null [`ListArray`]. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + let child = Self::get_child_type(&data_type).clone(); + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + new_empty_array(child), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +impl ListArray { + /// Slices this [`ListArray`]. + /// # Panics + /// panics iff `offset + length >= self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`ListArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// Accessors +impl ListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panic + /// Panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + // Safety: invariant of this function + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // safety: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // safety: the invariant of the struct + self.values.sliced_unchecked(start, length) + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// The offsets [`Buffer`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The values. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl ListArray { + /// Returns a default [`DataType`]: inner field is named "item" and is nullable + pub fn default_datatype(data_type: DataType) -> DataType { + let field = Box::new(Field::new("item", data_type, true)); + if O::IS_LARGE { + DataType::LargeList(field) + } else { + DataType::List(field) + } + } + + /// Returns a the inner [`Field`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. 
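+    ///
+    /// For example, for `DataType::List(Box::new(Field::new("item", DataType::Int32, true)))`
+    /// this returns the inner `"item"` field.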
+ pub fn get_child_field(data_type: &DataType) -> &Field { + Self::try_get_child(data_type).unwrap() + } + + /// Returns a the inner [`Field`] + /// # Errors + /// Panics iff the logical type is not consistent with this struct. + pub fn try_get_child(data_type: &DataType) -> Result<&Field, Error> { + if O::IS_LARGE { + match data_type.to_logical_type() { + DataType::LargeList(child) => Ok(child.as_ref()), + _ => Err(Error::oos("ListArray expects DataType::LargeList")), + } + } else { + match data_type.to_logical_type() { + DataType::List(child) => Ok(child.as_ref()), + _ => Err(Error::oos("ListArray expects DataType::List")), + } + } + } + + /// Returns a the inner [`DataType`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. + pub fn get_child_type(data_type: &DataType) -> &DataType { + Self::get_child_field(data_type).data_type() + } +} + +impl Array for ListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/list/mutable.rs b/crates/nano-arrow/src/array/list/mutable.rs new file mode 100644 index 000000000000..91c36ff42d21 --- /dev/null +++ b/crates/nano-arrow/src/array/list/mutable.rs @@ -0,0 +1,315 @@ +use std::sync::Arc; + +use super::ListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The mutable version of [`ListArray`]. +#[derive(Debug, Clone)] +pub struct MutableListArray { + data_type: DataType, + offsets: Offsets, + values: M, + validity: Option, +} + +impl MutableListArray { + /// Creates a new empty [`MutableListArray`]. + pub fn new() -> Self { + let values = M::default(); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self::new_from(values, data_type, 0) + } + + /// Creates a new [`MutableListArray`] with a capacity. 
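+    ///
+    /// # Example
+    /// A minimal sketch, assuming `i32` offsets and a `MutablePrimitiveArray<i32>` values array:
+    /// ```ignore
+    /// let mut list = MutableListArray::<i32, MutablePrimitiveArray<i32>>::with_capacity(2);
+    /// list.try_push(Some(vec![Some(1), Some(2)])).unwrap();
+    /// list.try_push(None::<Vec<Option<i32>>>).unwrap();
+    /// let array: ListArray<i32> = list.into();
+    /// ```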
+ pub fn with_capacity(capacity: usize) -> Self { + let values = M::default(); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + + let offsets = Offsets::::with_capacity(capacity); + Self { + data_type, + offsets, + values, + validity: None, + } + } +} + +impl Default for MutableListArray { + fn default() -> Self { + Self::new() + } +} + +impl From> for ListArray { + fn from(mut other: MutableListArray) -> Self { + ListArray::new( + other.data_type, + other.offsets.into(), + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl TryExtend> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + fn try_extend>>(&mut self, iter: II) -> Result<()> { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for items in iter { + self.try_push(items)?; + } + Ok(()) + } +} + +impl TryPush> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(items) = item { + let values = self.mut_values(); + values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + Ok(()) + } +} + +impl TryExtendFromSelf for MutableListArray +where + O: Offset, + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values)?; + self.offsets.try_extend_from_self(&other.offsets) + } +} + +impl MutableListArray { + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_from(values: M, data_type: DataType, capacity: usize) -> Self { + let offsets = Offsets::::with_capacity(capacity); + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&data_type); + Self { + data_type, + offsets, + values, + validity: None, + } + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`]. + pub fn new_with_field(values: M, name: &str, nullable: bool) -> Self { + let field = Box::new(Field::new(name, values.data_type().clone(), nullable)); + let data_type = if O::IS_LARGE { + DataType::LargeList(field) + } else { + DataType::List(field) + }; + Self::new_from(values, data_type, 0) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_with_capacity(values: M, capacity: usize) -> Self { + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self::new_from(values, data_type, capacity) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`], [`Offsets`] and + /// [`MutableBitmap`]. + pub fn new_from_mutable( + values: M, + offsets: Offsets, + validity: Option, + ) -> Self { + assert_eq!(values.len(), offsets.last().to_usize()); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self { + data_type, + offsets, + values, + validity, + } + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. 
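+    ///
+    /// For example, one can extend `mut_values()` directly and then call `try_push_valid()`
+    /// to close the new list slot; this is exactly what `try_push` does internally.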
+ pub fn try_push_valid(&mut self) -> Result<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length.checked_sub(offset).ok_or(Error::Overflow)?; + + self.offsets.try_push(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Expand this array, using elements from the underlying backing array. + /// Assumes the expansion begins at the highest previous offset, or zero if + /// this [`MutableListArray`] is currently empty. + /// + /// Panics if: + /// - the new offsets are not in monotonic increasing order. + /// - any new offset is not in bounds of the backing array. + /// - the passed iterator has no upper bound. + pub fn try_extend_from_lengths(&mut self, iterator: II) -> Result<()> + where + II: TrustedLen> + Clone, + { + self.offsets + .try_extend_from_lengths(iterator.clone().map(|x| x.unwrap_or_default()))?; + if let Some(validity) = &mut self.validity { + validity.extend_from_trusted_len_iter(iterator.map(|x| x.is_some())) + } + assert_eq!(self.offsets.last().to_usize(), self.values.len()); + Ok(()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// The values + pub fn mut_values(&mut self) -> &mut M { + &mut self.values + } + + /// The offsets + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// The values + pub fn values(&self) -> &M { + &self.values + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::with_capacity(self.offsets.capacity()); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: ListArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: ListArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableListArray`] to fit its current length. 
+ pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableListArray { + fn len(&self) -> usize { + MutableListArray::len(self) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit(); + } +} diff --git a/crates/nano-arrow/src/array/map/data.rs b/crates/nano-arrow/src/array/map/data.rs new file mode 100644 index 000000000000..cb8862a4df3d --- /dev/null +++ b/crates/nano-arrow/src/array/map/data.rs @@ -0,0 +1,38 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, MapArray}; +use crate::bitmap::Bitmap; +use crate::offset::OffsetsBuffer; + +impl Arrow2Arrow for MapArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.offsets.clone().into_inner().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.field.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type: data.data_type().clone().into(), + offsets, + field: from_data(&data.child_data()[0]), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/map/ffi.rs b/crates/nano-arrow/src/array/map/ffi.rs new file mode 100644 index 000000000000..9193e7253753 --- /dev/null +++ b/crates/nano-arrow/src/array/map/ffi.rs @@ -0,0 +1,68 @@ +use super::super::ffi::ToFfi; +use super::super::Array; +use super::MapArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::OffsetsBuffer; + +unsafe impl ToFfi for MapArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.field.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = 
self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + field: self.field.clone(), + } + } +} + +impl FromFfi for MapArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = array.child(0)?; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Self::try_new(data_type, offsets, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/map/fmt.rs b/crates/nano-arrow/src/array/map/fmt.rs new file mode 100644 index 000000000000..60abf56e18c5 --- /dev/null +++ b/crates/nano-arrow/src/array/map/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::MapArray; + +pub fn write_value( + array: &MapArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for MapArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "MapArray")?; + write_vec(f, writer, self.validity.as_ref(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/map/iterator.rs b/crates/nano-arrow/src/array/map/iterator.rs new file mode 100644 index 000000000000..f424e91b8043 --- /dev/null +++ b/crates/nano-arrow/src/array/map/iterator.rs @@ -0,0 +1,81 @@ +use super::MapArray; +use crate::array::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an [`ListArray`]. 
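+/// Each item is the `entries` struct of one map slot, sliced to that slot's range.
+///
+/// A sketch of how it is typically driven; `map` is assumed to be an existing [`MapArray`]:
+/// ```ignore
+/// for entries in map.values_iter() {
+///     // `entries` is a `Box<dyn Array>` holding the key/value struct of one slot
+///     println!("{entries:?}");
+/// }
+/// ```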
+#[derive(Clone, Debug)] +pub struct MapValuesIter<'a> { + array: &'a MapArray, + index: usize, + end: usize, +} + +impl<'a> MapValuesIter<'a> { + #[inline] + pub fn new(array: &'a MapArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a> Iterator for MapValuesIter<'a> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + // Safety: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a> TrustedLen for MapValuesIter<'a> {} + +impl<'a> DoubleEndedIterator for MapValuesIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + // Safety: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +impl<'a> IntoIterator for &'a MapArray { + type Item = Option>; + type IntoIter = ZipValidity, MapValuesIter<'a>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MapArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipValidity, MapValuesIter<'a>, BitmapIter<'a>> { + ZipValidity::new_with_validity(MapValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> MapValuesIter<'a> { + MapValuesIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/map/mod.rs b/crates/nano-arrow/src/array/map/mod.rs new file mode 100644 index 000000000000..abc7993fd7d4 --- /dev/null +++ b/crates/nano-arrow/src/array/map/mod.rs @@ -0,0 +1,205 @@ +use super::specification::try_check_offsets_bounds; +use super::{new_empty_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::offset::OffsetsBuffer; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; + +/// An array representing a (key, value), both of arbitrary logical types. +#[derive(Clone)] +pub struct MapArray { + data_type: DataType, + // invariant: field.len() == offsets.len() + offsets: OffsetsBuffer, + field: Box, + // invariant: offsets.len() - 1 == Bitmap::len() + validity: Option, +} + +impl MapArray { + /// Returns a new [`MapArray`]. + /// # Errors + /// This function errors iff: + /// * The last offset is not equal to the field' length + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`] + /// * The fields' `data_type` is not equal to the inner field of `data_type` + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. 
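+ /// # Example
+ /// A sketch of building a two-slot map of string keys to integer values; the concrete
+ /// array constructors are assumed from the upstream `arrow2` API:
+ /// ```ignore
+ /// let keys = Utf8Array::<i32>::from_slice(["a", "b"]);
+ /// let values = PrimitiveArray::from_slice([1i32, 2]);
+ /// let entries = StructArray::new(
+ ///     DataType::Struct(vec![
+ ///         Field::new("key", DataType::Utf8, false),
+ ///         Field::new("value", DataType::Int32, false),
+ ///     ]),
+ ///     vec![keys.boxed(), values.boxed()],
+ ///     None,
+ /// );
+ /// let data_type = DataType::Map(
+ ///     Box::new(Field::new("entries", entries.data_type().clone(), false)),
+ ///     false,
+ /// );
+ /// let offsets = vec![0i32, 1, 2].try_into().unwrap();
+ /// let map = MapArray::try_new(data_type, offsets, entries.boxed(), None).unwrap();
+ /// assert_eq!(map.len(), 2);
+ /// ```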
+ pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, field.len())?; + + let inner_field = Self::try_get_field(&data_type)?; + if let DataType::Struct(inner) = inner_field.data_type() { + if inner.len() != 2 { + return Err(Error::InvalidArgumentError( + "MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "MapArray expects `DataType::Struct` as its inner logical type".to_string(), + )); + } + if field.data_type() != inner_field.data_type() { + return Err(Error::InvalidArgumentError( + "MapArray expects `field.data_type` to match its inner DataType".to_string(), + )); + } + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + Ok(Self { + data_type, + field, + offsets, + validity, + }) + } + + /// Creates a new [`MapArray`]. + /// # Panics + /// * The last offset is not equal to the field' length. + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`], + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, field, validity).unwrap() + } + + /// Returns a new null [`MapArray`] of `length`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); + Self::new( + data_type, + vec![0i32; 1 + length].try_into().unwrap(), + field, + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a new empty [`MapArray`]. + pub fn new_empty(data_type: DataType) -> Self { + let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); + Self::new(data_type, OffsetsBuffer::default(), field, None) + } +} + +impl MapArray { + /// Returns a slice of this [`MapArray`]. + /// # Panics + /// panics iff `offset + length >= self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`MapArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. 
+ #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + pub(crate) fn try_get_field(data_type: &DataType) -> Result<&Field, Error> { + if let DataType::Map(field, _) = data_type.to_logical_type() { + Ok(field.as_ref()) + } else { + Err(Error::oos( + "The data_type's logical type must be DataType::Map", + )) + } + } + + pub(crate) fn get_field(data_type: &DataType) -> &Field { + Self::try_get_field(data_type).unwrap() + } +} + +// Accessors +impl MapArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// returns the offsets + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// Returns the field (guaranteed to be a `Struct`) + #[inline] + pub fn field(&self) -> &Box { + &self.field + } + + /// Returns the element at index `i`. + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i`. + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // soundness: the invariant of the struct + self.field.sliced_unchecked(start, length) + } +} + +impl Array for MapArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/mod.rs b/crates/nano-arrow/src/array/mod.rs new file mode 100644 index 000000000000..4e8e9a2177f2 --- /dev/null +++ b/crates/nano-arrow/src/array/mod.rs @@ -0,0 +1,787 @@ +//! Contains the [`Array`] and [`MutableArray`] trait objects declaring arrays, +//! as well as concrete arrays (such as [`Utf8Array`] and [`MutableUtf8Array`]). +//! +//! Fixed-length containers with optional values +//! that are laid in memory according to the Arrow specification. +//! Each array type has its own `struct`. The following are the main array types: +//! * [`PrimitiveArray`] and [`MutablePrimitiveArray`], an array of values with a fixed length such as integers, floats, etc. +//! * [`BooleanArray`] and [`MutableBooleanArray`], an array of boolean values (stored as a bitmap) +//! * [`Utf8Array`] and [`MutableUtf8Array`], an array of variable length utf8 values +//! * [`BinaryArray`] and [`MutableBinaryArray`], an array of opaque variable length values +//! * [`ListArray`] and [`MutableListArray`], an array of arrays (e.g. `[[1, 2], None, [], [None]]`) +//! * [`StructArray`] and [`MutableStructArray`], an array of arrays identified by a string (e.g. `{"a": [1, 2], "b": [true, false]}`) +//! All immutable arrays implement the trait object [`Array`] and that can be downcasted +//! to a concrete struct based on [`PhysicalType`](crate::datatypes::PhysicalType) available from [`Array::data_type`]. +//! All immutable arrays are backed by [`Buffer`](crate::buffer::Buffer) and thus cloning and slicing them is `O(1)`. +//! +//! 
Most arrays contain a [`MutableArray`] counterpart that is neither clonable nor sliceable, but +//! can be operated in-place. +use std::any::Any; +use std::sync::Arc; + +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Result; + +pub mod physical_binary; + +/// A trait representing an immutable Arrow array. Arrow arrays are trait objects +/// that are infallibly downcasted to concrete types according to the [`Array::data_type`]. +pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { + /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. + fn as_any(&self) -> &dyn Any; + + /// Converts itself to a mutable reference of [`Any`], which enables mutable downcasting to concrete types. + fn as_any_mut(&mut self) -> &mut dyn Any; + + /// The length of the [`Array`]. Every array has a length corresponding to the number of + /// elements (slots). + fn len(&self) -> usize; + + /// whether the array is empty + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The [`DataType`] of the [`Array`]. In combination with [`Array::as_any`], this can be + /// used to downcast trait objects (`dyn Array`) to concrete arrays. + fn data_type(&self) -> &DataType; + + /// The validity of the [`Array`]: every array has an optional [`Bitmap`] that, when available + /// specifies whether the array slot is valid or not (null). + /// When the validity is [`None`], all slots are valid. + fn validity(&self) -> Option<&Bitmap>; + + /// The number of null slots on this [`Array`]. + /// # Implementation + /// This is `O(1)` since the number of null elements is pre-computed. + #[inline] + fn null_count(&self) -> usize { + if self.data_type() == &DataType::Null { + return self.len(); + }; + self.validity() + .as_ref() + .map(|x| x.unset_bits()) + .unwrap_or(0) + } + + /// Returns whether slot `i` is null. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_null(&self, i: usize) -> bool { + assert!(i < self.len()); + unsafe { self.is_null_unchecked(i) } + } + + /// Returns whether slot `i` is null. + /// # Safety + /// The caller must ensure `i < self.len()` + #[inline] + unsafe fn is_null_unchecked(&self, i: usize) -> bool { + self.validity() + .as_ref() + .map(|x| !x.get_bit_unchecked(i)) + .unwrap_or(false) + } + + /// Returns whether slot `i` is valid. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_valid(&self, i: usize) -> bool { + !self.is_null(i) + } + + /// Slices this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + fn slice(&mut self, offset: usize, length: usize); + + /// Slices the [`Array`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize); + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + #[must_use] + fn sliced(&self, offset: usize, length: usize) -> Box { + let mut new = self.to_boxed(); + new.slice(offset, length); + new + } + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`, as it amounts to increase two ref counts + /// and moving the struct to the heap. 
+ /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + #[must_use] + unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Box { + let mut new = self.to_boxed(); + new.slice_unchecked(offset, length); + new + } + + /// Clones this [`Array`] with a new new assigned bitmap. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. + fn with_validity(&self, validity: Option) -> Box; + + /// Clone a `&dyn Array` to an owned `Box`. + fn to_boxed(&self) -> Box; +} + +dyn_clone::clone_trait_object!(Array); + +/// A trait describing an array with a backing store that can be preallocated to +/// a given size. +pub(crate) trait Container { + /// Create this array with a given capacity. + fn with_capacity(capacity: usize) -> Self + where + Self: Sized; +} + +/// A trait describing a mutable array; i.e. an array whose values can be changed. +/// Mutable arrays cannot be cloned but can be mutated in place, +/// thereby making them useful to perform numeric operations without allocations. +/// As in [`Array`], concrete arrays (such as [`MutablePrimitiveArray`]) implement how they are mutated. +pub trait MutableArray: std::fmt::Debug + Send + Sync { + /// The [`DataType`] of the array. + fn data_type(&self) -> &DataType; + + /// The length of the array. + fn len(&self) -> usize; + + /// Whether the array is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The optional validity of the array. + fn validity(&self) -> Option<&MutableBitmap>; + + /// Convert itself to an (immutable) [`Array`]. + fn as_box(&mut self) -> Box; + + /// Convert itself to an (immutable) atomically reference counted [`Array`]. + // This provided implementation has an extra allocation as it first + // boxes `self`, then converts the box into an `Arc`. Implementors may wish + // to avoid an allocation by skipping the box completely. + fn as_arc(&mut self) -> std::sync::Arc { + self.as_box().into() + } + + /// Convert to `Any`, to enable dynamic casting. + fn as_any(&self) -> &dyn Any; + + /// Convert to mutable `Any`, to enable dynamic casting. + fn as_mut_any(&mut self) -> &mut dyn Any; + + /// Adds a new null element to the array. + fn push_null(&mut self); + + /// Whether `index` is valid / set. + /// # Panic + /// Panics if `index >= self.len()`. + #[inline] + fn is_valid(&self, index: usize) -> bool { + self.validity() + .as_ref() + .map(|x| x.get(index)) + .unwrap_or(true) + } + + /// Reserves additional slots to its capacity. + fn reserve(&mut self, additional: usize); + + /// Shrink the array to fit its length. + fn shrink_to_fit(&mut self); +} + +impl MutableArray for Box { + fn len(&self) -> usize { + self.as_ref().len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.as_ref().validity() + } + + fn as_box(&mut self) -> Box { + self.as_mut().as_box() + } + + fn as_arc(&mut self) -> Arc { + self.as_mut().as_arc() + } + + fn data_type(&self) -> &DataType { + self.as_ref().data_type() + } + + fn as_any(&self) -> &dyn std::any::Any { + self.as_ref().as_any() + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self.as_mut().as_mut_any() + } + + #[inline] + fn push_null(&mut self) { + self.as_mut().push_null() + } + + fn shrink_to_fit(&mut self) { + self.as_mut().shrink_to_fit(); + } + + fn reserve(&mut self, additional: usize) { + self.as_mut().reserve(additional); + } +} + +macro_rules! 
general_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ($f)(array) + }}; +} + +macro_rules! fmt_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let mut f = |x: &$ty| x.fmt($f); + general_dyn!($array, $ty, f) + }}; +} + +macro_rules! match_integer_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::IntegerType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + } +})} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, f16, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +impl std::fmt::Debug for dyn Array + '_ { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::datatypes::PhysicalType::*; + match self.data_type().to_physical_type() { + Null => fmt_dyn!(self, NullArray, f), + Boolean => fmt_dyn!(self, BooleanArray, f), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + fmt_dyn!(self, PrimitiveArray<$T>, f) + }), + Binary => fmt_dyn!(self, BinaryArray, f), + LargeBinary => fmt_dyn!(self, BinaryArray, f), + FixedSizeBinary => fmt_dyn!(self, FixedSizeBinaryArray, f), + Utf8 => fmt_dyn!(self, Utf8Array::, f), + LargeUtf8 => fmt_dyn!(self, Utf8Array::, f), + List => fmt_dyn!(self, ListArray::, f), + LargeList => fmt_dyn!(self, ListArray::, f), + FixedSizeList => fmt_dyn!(self, FixedSizeListArray, f), + Struct => fmt_dyn!(self, StructArray, f), + Union => fmt_dyn!(self, UnionArray, f), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + fmt_dyn!(self, DictionaryArray::<$T>, f) + }) + }, + Map => fmt_dyn!(self, MapArray, f), + } + } +} + +/// Creates a new [`Array`] with a [`Array::len`] of 0. 
+pub fn new_empty_array(data_type: DataType) -> Box { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(NullArray::new_empty(data_type)), + Boolean => Box::new(BooleanArray::new_empty(data_type)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_empty(data_type)) + }), + Binary => Box::new(BinaryArray::::new_empty(data_type)), + LargeBinary => Box::new(BinaryArray::::new_empty(data_type)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_empty(data_type)), + Utf8 => Box::new(Utf8Array::::new_empty(data_type)), + LargeUtf8 => Box::new(Utf8Array::::new_empty(data_type)), + List => Box::new(ListArray::::new_empty(data_type)), + LargeList => Box::new(ListArray::::new_empty(data_type)), + FixedSizeList => Box::new(FixedSizeListArray::new_empty(data_type)), + Struct => Box::new(StructArray::new_empty(data_type)), + Union => Box::new(UnionArray::new_empty(data_type)), + Map => Box::new(MapArray::new_empty(data_type)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_empty(data_type)) + }) + }, + } +} + +/// Creates a new [`Array`] of [`DataType`] `data_type` and `length`. +/// The array is guaranteed to have [`Array::null_count`] equal to [`Array::len`] +/// for all types except Union, which does not have a validity. +pub fn new_null_array(data_type: DataType, length: usize) -> Box { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(NullArray::new_null(data_type, length)), + Boolean => Box::new(BooleanArray::new_null(data_type, length)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_null(data_type, length)) + }), + Binary => Box::new(BinaryArray::::new_null(data_type, length)), + LargeBinary => Box::new(BinaryArray::::new_null(data_type, length)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_null(data_type, length)), + Utf8 => Box::new(Utf8Array::::new_null(data_type, length)), + LargeUtf8 => Box::new(Utf8Array::::new_null(data_type, length)), + List => Box::new(ListArray::::new_null(data_type, length)), + LargeList => Box::new(ListArray::::new_null(data_type, length)), + FixedSizeList => Box::new(FixedSizeListArray::new_null(data_type, length)), + Struct => Box::new(StructArray::new_null(data_type, length)), + Union => Box::new(UnionArray::new_null(data_type, length)), + Map => Box::new(MapArray::new_null(data_type, length)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_null(data_type, length)) + }) + }, + } +} + +/// Trait providing bi-directional conversion between arrow2 [`Array`] and arrow-rs [`ArrayData`] +/// +/// [`ArrayData`]: arrow_data::ArrayData +#[cfg(feature = "arrow_rs")] +pub trait Arrow2Arrow: Array { + /// Convert this [`Array`] into [`ArrayData`] + fn to_data(&self) -> arrow_data::ArrayData; + + /// Create this [`Array`] from [`ArrayData`] + fn from_data(data: &arrow_data::ArrayData) -> Self; +} + +#[cfg(feature = "arrow_rs")] +macro_rules! 
to_data_dyn { + ($array:expr, $ty:ty) => {{ + let f = |x: &$ty| x.to_data(); + general_dyn!($array, $ty, f) + }}; +} + +#[cfg(feature = "arrow_rs")] +impl From> for arrow_array::ArrayRef { + fn from(value: Box) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow_rs")] +impl From<&dyn Array> for arrow_array::ArrayRef { + fn from(value: &dyn Array) -> Self { + arrow_array::make_array(to_data(value)) + } +} + +#[cfg(feature = "arrow_rs")] +impl From for Box { + fn from(value: arrow_array::ArrayRef) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow_rs")] +impl From<&dyn arrow_array::Array> for Box { + fn from(value: &dyn arrow_array::Array) -> Self { + from_data(&value.to_data()) + } +} + +/// Convert an arrow2 [`Array`] to [`arrow_data::ArrayData`] +#[cfg(feature = "arrow_rs")] +pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => to_data_dyn!(array, NullArray), + Boolean => to_data_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + to_data_dyn!(array, PrimitiveArray<$T>) + }), + Binary => to_data_dyn!(array, BinaryArray), + LargeBinary => to_data_dyn!(array, BinaryArray), + FixedSizeBinary => to_data_dyn!(array, FixedSizeBinaryArray), + Utf8 => to_data_dyn!(array, Utf8Array::), + LargeUtf8 => to_data_dyn!(array, Utf8Array::), + List => to_data_dyn!(array, ListArray::), + LargeList => to_data_dyn!(array, ListArray::), + FixedSizeList => to_data_dyn!(array, FixedSizeListArray), + Struct => to_data_dyn!(array, StructArray), + Union => to_data_dyn!(array, UnionArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + to_data_dyn!(array, DictionaryArray::<$T>) + }) + }, + Map => to_data_dyn!(array, MapArray), + } +} + +/// Convert an [`arrow_data::ArrayData`] to arrow2 [`Array`] +#[cfg(feature = "arrow_rs")] +pub fn from_data(data: &arrow_data::ArrayData) -> Box { + use crate::datatypes::PhysicalType::*; + let data_type: DataType = data.data_type().clone().into(); + match data_type.to_physical_type() { + Null => Box::new(NullArray::from_data(data)), + Boolean => Box::new(BooleanArray::from_data(data)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::from_data(data)) + }), + Binary => Box::new(BinaryArray::::from_data(data)), + LargeBinary => Box::new(BinaryArray::::from_data(data)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::from_data(data)), + Utf8 => Box::new(Utf8Array::::from_data(data)), + LargeUtf8 => Box::new(Utf8Array::::from_data(data)), + List => Box::new(ListArray::::from_data(data)), + LargeList => Box::new(ListArray::::from_data(data)), + FixedSizeList => Box::new(FixedSizeListArray::from_data(data)), + Struct => Box::new(StructArray::from_data(data)), + Union => Box::new(UnionArray::from_data(data)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::from_data(data)) + }) + }, + Map => Box::new(MapArray::from_data(data)), + } +} + +macro_rules! clone_dyn { + ($array:expr, $ty:ty) => {{ + let f = |x: &$ty| Box::new(x.clone()); + general_dyn!($array, $ty, f) + }}; +} + +// macro implementing `sliced` and `sliced_unchecked` +macro_rules! impl_sliced { + () => { + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. 
+ #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + }; +} + +// macro implementing `with_validity` and `set_validity` +macro_rules! impl_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + } +} + +// macro implementing `with_validity`, `set_validity` and `apply_validity` for mutable arrays +macro_rules! impl_mutable_array_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + #[inline] + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + } +} + +// macro implementing `boxed` and `arced` +macro_rules! impl_into_array { + () => { + /// Boxes this array into a [`Box`]. + pub fn boxed(self) -> Box { + Box::new(self) + } + + /// Arcs this array into a [`std::sync::Arc`]. + pub fn arced(self) -> std::sync::Arc { + std::sync::Arc::new(self) + } + }; +} + +// macro implementing common methods of trait `Array` +macro_rules! impl_common_array { + () => { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + #[inline] + fn slice(&mut self, offset: usize, length: usize) { + self.slice(offset, length); + } + + #[inline] + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.slice_unchecked(offset, length); + } + + #[inline] + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } + }; +} + +/// Clones a dynamic [`Array`]. 
+/// # Implementation +/// This operation is `O(1)` over `len`, as it amounts to increase two ref counts +/// and moving the concrete struct under a `Box`. +pub fn clone(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => clone_dyn!(array, NullArray), + Boolean => clone_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + clone_dyn!(array, PrimitiveArray<$T>) + }), + Binary => clone_dyn!(array, BinaryArray), + LargeBinary => clone_dyn!(array, BinaryArray), + FixedSizeBinary => clone_dyn!(array, FixedSizeBinaryArray), + Utf8 => clone_dyn!(array, Utf8Array::), + LargeUtf8 => clone_dyn!(array, Utf8Array::), + List => clone_dyn!(array, ListArray::), + LargeList => clone_dyn!(array, ListArray::), + FixedSizeList => clone_dyn!(array, FixedSizeListArray), + Struct => clone_dyn!(array, StructArray), + Union => clone_dyn!(array, UnionArray), + Map => clone_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + clone_dyn!(array, DictionaryArray::<$T>) + }) + }, + } +} + +// see https://users.rust-lang.org/t/generic-for-dyn-a-or-box-dyn-a-or-arc-dyn-a/69430/3 +// for details +impl<'a> AsRef<(dyn Array + 'a)> for dyn Array { + fn as_ref(&self) -> &(dyn Array + 'a) { + self + } +} + +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod specification; +mod struct_; +mod union; +mod utf8; + +mod equal; +mod ffi; +mod fmt; +#[doc(hidden)] +pub mod indexable; +mod iterator; + +pub mod growable; +pub mod ord; + +pub use binary::{BinaryArray, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray}; +pub use boolean::{BooleanArray, MutableBooleanArray}; +pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; +pub use equal::equal; +pub use fixed_size_binary::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; +pub use fixed_size_list::{FixedSizeListArray, MutableFixedSizeListArray}; +pub use fmt::{get_display, get_value_display}; +pub(crate) use iterator::ArrayAccessor; +pub use iterator::ArrayValuesIter; +pub use list::{ListArray, ListValuesIter, MutableListArray}; +pub use map::MapArray; +pub use null::{MutableNullArray, NullArray}; +pub use primitive::*; +pub use struct_::{MutableStructArray, StructArray}; +pub use union::UnionArray; +pub use utf8::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array, Utf8ValuesIter}; + +pub(crate) use self::ffi::{offset_buffers_children_dictionary, FromFfi, ToFfi}; + +/// A trait describing the ability of a struct to create itself from a iterator. +/// This is similar to [`Extend`], but accepted the creation to error. +pub trait TryExtend { + /// Fallible version of [`Extend::extend`]. + fn try_extend>(&mut self, iter: I) -> Result<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait TryPush { + /// Tries to push a new element. + fn try_push(&mut self, item: A) -> Result<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait PushUnchecked { + /// Push a new element that holds the invariants of the struct. + /// # Safety + /// The items must uphold the invariants of the struct + /// Read the specific implementation of the trait to understand what these are. + unsafe fn push_unchecked(&mut self, item: A); +} + +/// A trait describing the ability of a struct to extend from a reference of itself. 
+/// Specialization of [`TryExtend`]. +pub trait TryExtendFromSelf { + /// Tries to extend itself with elements from `other`, failing only on overflow. + fn try_extend_from_self(&mut self, other: &Self) -> Result<()>; +} + +/// Trait that [`BinaryArray`] and [`Utf8Array`] implement for the purposes of DRY. +/// # Safety +/// The implementer must ensure that +/// 1. `offsets.len() > 0` +/// 2. `offsets[i] >= offsets[i-1] for all i` +/// 3. `offsets[i] < values.len() for all i` +pub unsafe trait GenericBinaryArray: Array { + /// The values of the array + fn values(&self) -> &[u8]; + /// The offsets of the array + fn offsets(&self) -> &[O]; +} diff --git a/crates/nano-arrow/src/array/null.rs b/crates/nano-arrow/src/array/null.rs new file mode 100644 index 000000000000..4bbd11e8805d --- /dev/null +++ b/crates/nano-arrow/src/array/null.rs @@ -0,0 +1,200 @@ +use std::any::Any; + +use crate::array::{Array, FromFfi, MutableArray, ToFfi}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::ffi; + +/// The concrete [`Array`] of [`DataType::Null`]. +#[derive(Clone)] +pub struct NullArray { + data_type: DataType, + length: usize, +} + +impl NullArray { + /// Returns a new [`NullArray`]. + /// # Errors + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn try_new(data_type: DataType, length: usize) -> Result { + if data_type.to_physical_type() != PhysicalType::Null { + return Err(Error::oos( + "NullArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { data_type, length }) + } + + /// Returns a new [`NullArray`]. + /// # Panics + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(data_type: DataType, length: usize) -> Self { + Self::try_new(data_type, length).unwrap() + } + + /// Returns a new empty [`NullArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, 0) + } + + /// Returns a new [`NullArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new(data_type, length) + } + + impl_sliced!(); + impl_into_array!(); +} + +impl NullArray { + /// Returns a slice of the [`NullArray`]. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) }; + } + + /// Returns a slice of the [`NullArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, _offset: usize, length: usize) { + self.length = length; + } + + #[inline] + fn len(&self) -> usize { + self.length + } +} + +impl Array for NullArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + None + } + + fn with_validity(&self, _: Option) -> Box { + panic!("cannot set validity of a null array") + } +} + +#[derive(Debug)] +/// A distinct type to disambiguate +/// clashing methods +pub struct MutableNullArray { + inner: NullArray, +} + +impl MutableNullArray { + /// Returns a new [`MutableNullArray`]. 
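+ /// The array starts out with `length` null slots; more can be appended later through
+ /// [`MutableArray::push_null`].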
+ /// # Panics + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(data_type: DataType, length: usize) -> Self { + let inner = NullArray::try_new(data_type, length).unwrap(); + Self { inner } + } +} + +impl From for NullArray { + fn from(value: MutableNullArray) -> Self { + value.inner + } +} + +impl MutableArray for MutableNullArray { + fn data_type(&self) -> &DataType { + &DataType::Null + } + + fn len(&self) -> usize { + self.inner.length + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + self.inner.clone().boxed() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn push_null(&mut self) { + self.inner.length += 1; + } + + fn reserve(&mut self, _additional: usize) { + // no-op + } + + fn shrink_to_fit(&mut self) { + // no-op + } +} + +impl std::fmt::Debug for NullArray { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NullArray({})", self.len()) + } +} + +unsafe impl ToFfi for NullArray { + fn buffers(&self) -> Vec> { + // `None` is technically not required by the specification, but older C++ implementations require it, so leaving + // it here for backward compatibility + vec![None] + } + + fn offset(&self) -> Option { + Some(0) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for NullArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + Self::try_new(data_type, array.array().len()) + } +} + +#[cfg(feature = "arrow_rs")] +mod arrow { + use arrow_data::{ArrayData, ArrayDataBuilder}; + + use super::*; + impl NullArray { + /// Convert this array into [`arrow_data::ArrayData`] + pub fn to_data(&self) -> ArrayData { + let builder = ArrayDataBuilder::new(arrow_schema::DataType::Null).len(self.len()); + + // Safety: safe by construction + unsafe { builder.build_unchecked() } + } + + /// Create this array from [`ArrayData`] + pub fn from_data(data: &ArrayData) -> Self { + Self::new(DataType::Null, data.len()) + } + } +} diff --git a/crates/nano-arrow/src/array/ord.rs b/crates/nano-arrow/src/array/ord.rs new file mode 100644 index 000000000000..3454acebdaca --- /dev/null +++ b/crates/nano-arrow/src/array/ord.rs @@ -0,0 +1,182 @@ +//! Contains functions and function factories to order values within arrays. +use std::cmp::Ordering; + +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::types::NativeType; +use crate::util::total_ord::TotalOrd; + +/// Compare the values at two arbitrary indices in two arrays. 
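+///
+/// Values of this type are produced by [`build_compare`]; see that function for a worked example.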
+pub type DynComparator = Box Ordering + Send + Sync>; + +fn compare_primitives( + left: &dyn Array, + right: &dyn Array, +) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).tot_cmp(&right.value(j))) +} + +fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(&right.value(j))) +} + +fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_binary(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_dict(left: &DictionaryArray, right: &DictionaryArray) -> Result +where + K: DictionaryKey, +{ + let left_keys = left.keys().values().clone(); + let right_keys = right.keys().values().clone(); + + let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?; + + Ok(Box::new(move |i: usize, j: usize| { + // safety: all dictionaries keys are guaranteed to be castable to usize + let key_left = unsafe { left_keys[i].as_usize() }; + let key_right = unsafe { right_keys[j].as_usize() }; + (comparator)(key_left, key_right) + })) +} + +macro_rules! dyn_dict { + ($key:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref().unwrap(); + let rhs = $rhs.as_any().downcast_ref().unwrap(); + compare_dict::<$key>(lhs, rhs)? + }}; +} + +/// returns a comparison function that compares values at two different slots +/// between two [`Array`]. +/// # Example +/// ``` +/// use arrow2::array::{ord::build_compare, PrimitiveArray}; +/// +/// # fn main() -> arrow2::error::Result<()> { +/// let array1 = PrimitiveArray::from_slice([1, 2]); +/// let array2 = PrimitiveArray::from_slice([3, 4]); +/// +/// let cmp = build_compare(&array1, &array2)?; +/// +/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) +/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); +/// # Ok(()) +/// # } +/// ``` +/// # Error +/// The arrays' [`DataType`] must be equal and the types must have a natural order. +// This is a factory of comparisons. 
+pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + Ok(match (left.data_type(), right.data_type()) { + (a, b) if a != b => { + return Err(Error::InvalidArgumentError( + "Can't compare arrays of different types".to_string(), + )); + }, + (Boolean, Boolean) => compare_boolean(left, right), + (UInt8, UInt8) => compare_primitives::(left, right), + (UInt16, UInt16) => compare_primitives::(left, right), + (UInt32, UInt32) => compare_primitives::(left, right), + (UInt64, UInt64) => compare_primitives::(left, right), + (Int8, Int8) => compare_primitives::(left, right), + (Int16, Int16) => compare_primitives::(left, right), + (Int32, Int32) + | (Date32, Date32) + | (Time32(Second), Time32(Second)) + | (Time32(Millisecond), Time32(Millisecond)) + | (Interval(YearMonth), Interval(YearMonth)) => compare_primitives::(left, right), + (Int64, Int64) + | (Date64, Date64) + | (Time64(Microsecond), Time64(Microsecond)) + | (Time64(Nanosecond), Time64(Nanosecond)) + | (Timestamp(Second, None), Timestamp(Second, None)) + | (Timestamp(Millisecond, None), Timestamp(Millisecond, None)) + | (Timestamp(Microsecond, None), Timestamp(Microsecond, None)) + | (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) + | (Duration(Second), Duration(Second)) + | (Duration(Millisecond), Duration(Millisecond)) + | (Duration(Microsecond), Duration(Microsecond)) + | (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::(left, right), + (Float32, Float32) => compare_primitives::(left, right), + (Float64, Float64) => compare_primitives::(left, right), + (Decimal(_, _), Decimal(_, _)) => compare_primitives::(left, right), + (Utf8, Utf8) => compare_string::(left, right), + (LargeUtf8, LargeUtf8) => compare_string::(left, right), + (Binary, Binary) => compare_binary::(left, right), + (LargeBinary, LargeBinary) => compare_binary::(left, right), + (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { + match (key_type_lhs, key_type_rhs) { + (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), + (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), + (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), + (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), + (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), + (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), + (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), + (IntegerType::Int64, IntegerType::Int64) => dyn_dict!(i64, left, right), + (lhs, _) => { + return Err(Error::InvalidArgumentError(format!( + "Dictionaries do not support keys of type {lhs:?}" + ))) + }, + } + }, + (lhs, _) => { + return Err(Error::InvalidArgumentError(format!( + "The data type type {lhs:?} has no natural order" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/array/physical_binary.rs b/crates/nano-arrow/src/array/physical_binary.rs new file mode 100644 index 000000000000..36e4ecf52d35 --- /dev/null +++ b/crates/nano-arrow/src/array/physical_binary.rs @@ -0,0 +1,230 @@ +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
+#[inline] +#[allow(clippy::type_complexity)] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Offsets, Vec), E> +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = Vec::::with_capacity(len + 1); + let mut values = Vec::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item? { + null.push_unchecked(true); + let s = item.as_ref(); + length += O::from_as_usize(s.len()); + values.extend_from_slice(s); + } else { + null.push_unchecked(false); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + Ok((null.into(), Offsets::new_unchecked(offsets), values)) +} + +/// Creates [`MutableBitmap`] and two [`Vec`]s from an iterator of `Option`. +/// The first buffer corresponds to a offset buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip( + iterator: I, +) -> (Option, Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + let mut validity = MutableBitmap::new(); + + extend_from_trusted_len_iter(&mut offsets, &mut values, &mut validity, iterator); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, offsets, values) +} + +/// Creates two [`Buffer`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`]. +#[inline] +pub(crate) unsafe fn trusted_len_values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + + extend_from_trusted_len_values_iter(&mut offsets, &mut values, iterator); + + (offsets, values) +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. +// # Safety +// The caller must ensure the `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let lengths = iterator.map(|item| { + let s = item.as_ref(); + // Push new entries for both `values` and `offsets` buffer + values.extend_from_slice(s); + s.len() + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. +// the return value indicates how many items were added. 
+#[inline] +pub(crate) fn extend_from_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) -> usize +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (size_hint, _) = iterator.size_hint(); + + offsets.reserve(size_hint); + + let start_index = offsets.len_proxy(); + + for item in iterator { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + offsets.try_push(bytes.len()).unwrap(); + } + offsets.len_proxy() - start_index +} + +// Populates `offsets`, `values`, and `validity` [`Vec`]s with +// information extracted from the incoming `iterator`. +// +// # Safety +// The caller must ensure that `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_iter( + offsets: &mut Offsets, + values: &mut Vec, + validity: &mut MutableBitmap, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_from_trusted_len_iter requires an upper limit"); + + offsets.reserve(additional); + validity.reserve(additional); + + let lengths = iterator.map(|item| { + if let Some(item) = item { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + validity.push_unchecked(true); + bytes.len() + } else { + validity.push_unchecked(false); + 0 + } + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +/// Creates two [`Vec`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +#[inline] +pub(crate) fn values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (lower, _) = iterator.size_hint(); + + let mut offsets = Offsets::::with_capacity(lower); + let mut values = Vec::::new(); + + for item in iterator { + let s = item.as_ref(); + values.extend_from_slice(s); + offsets.try_push(s.len()).unwrap(); + } + (offsets, values) +} + +/// Extends `validity` with all items from `other` +pub(crate) fn extend_validity( + length: usize, + validity: &mut Option, + other: &Option, +) { + if let Some(other) = other { + if let Some(validity) = validity { + let slice = other.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { validity.extend_from_slice_unchecked(slice, 0, other.len()) } + } else { + let mut new_validity = MutableBitmap::from_len_set(length); + new_validity.extend_from_slice(other.as_slice(), 0, other.len()); + *validity = Some(new_validity); + } + } +} diff --git a/crates/nano-arrow/src/array/primitive/data.rs b/crates/nano-arrow/src/array/primitive/data.rs new file mode 100644 index 000000000000..d4879f796812 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/data.rs @@ -0,0 +1,33 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, PrimitiveArray}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::types::NativeType; + +impl Arrow2Arrow for PrimitiveArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.values.clone().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + let mut values: Buffer = data.buffers()[0].clone().into(); + values.slice(data.offset(), data.len()); + + Self { + data_type, + values, + validity: 
data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())),
+        }
+    }
+}
diff --git a/crates/nano-arrow/src/array/primitive/ffi.rs b/crates/nano-arrow/src/array/primitive/ffi.rs
new file mode 100644
index 000000000000..c74c157f750f
--- /dev/null
+++ b/crates/nano-arrow/src/array/primitive/ffi.rs
@@ -0,0 +1,56 @@
+use super::PrimitiveArray;
+use crate::array::{FromFfi, ToFfi};
+use crate::bitmap::align;
+use crate::error::Result;
+use crate::ffi;
+use crate::types::NativeType;
+
+unsafe impl<T: NativeType> ToFfi for PrimitiveArray<T> {
+    fn buffers(&self) -> Vec<Option<*const u8>> {
+        vec![
+            self.validity.as_ref().map(|x| x.as_ptr()),
+            Some(self.values.as_ptr().cast::<u8>()),
+        ]
+    }
+
+    fn offset(&self) -> Option<usize> {
+        let offset = self.values.offset();
+        if let Some(bitmap) = self.validity.as_ref() {
+            if bitmap.offset() == offset {
+                Some(offset)
+            } else {
+                None
+            }
+        } else {
+            Some(offset)
+        }
+    }
+
+    fn to_ffi_aligned(&self) -> Self {
+        let offset = self.values.offset();
+
+        let validity = self.validity.as_ref().map(|bitmap| {
+            if bitmap.offset() == offset {
+                bitmap.clone()
+            } else {
+                align(bitmap, offset)
+            }
+        });
+
+        Self {
+            data_type: self.data_type.clone(),
+            validity,
+            values: self.values.clone(),
+        }
+    }
+}
+
+impl<T: NativeType, A: ffi::ArrowArrayRef> FromFfi<A> for PrimitiveArray<T> {
+    unsafe fn try_from_ffi(array: A) -> Result<Self> {
+        let data_type = array.data_type().clone();
+        let validity = unsafe { array.validity() }?;
+        let values = unsafe { array.buffer::<T>(1) }?;
+
+        Self::try_new(data_type, values, validity)
+    }
+}
diff --git a/crates/nano-arrow/src/array/primitive/fmt.rs b/crates/nano-arrow/src/array/primitive/fmt.rs
new file mode 100644
index 000000000000..3743a16a188e
--- /dev/null
+++ b/crates/nano-arrow/src/array/primitive/fmt.rs
@@ -0,0 +1,149 @@
+#![allow(clippy::redundant_closure_call)]
+use std::fmt::{Debug, Formatter, Result, Write};
+
+use super::PrimitiveArray;
+use crate::array::fmt::write_vec;
+use crate::array::Array;
+use crate::datatypes::{IntervalUnit, TimeUnit};
+use crate::temporal_conversions;
+use crate::types::{days_ms, i256, months_days_ns, NativeType};
+
+macro_rules!
dyn_primitive { + ($array:expr, $ty:ty, $expr:expr) => {{ + let array = ($array as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(move |f, index| write!(f, "{}", $expr(array.value(index)))) + }}; +} + +pub fn get_write_value<'a, T: NativeType, F: Write>( + array: &'a PrimitiveArray, +) -> Box Result + 'a> { + use crate::datatypes::DataType::*; + match array.data_type().to_logical_type() { + Int8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float16 => unreachable!(), + Float32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Date32 => { + dyn_primitive!(array, i32, temporal_conversions::date32_to_date) + }, + Date64 => { + dyn_primitive!(array, i64, temporal_conversions::date64_to_date) + }, + Time32(TimeUnit::Second) => { + dyn_primitive!(array, i32, temporal_conversions::time32s_to_time) + }, + Time32(TimeUnit::Millisecond) => { + dyn_primitive!(array, i32, temporal_conversions::time32ms_to_time) + }, + Time32(_) => unreachable!(), // remaining are not valid + Time64(TimeUnit::Microsecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64us_to_time) + }, + Time64(TimeUnit::Nanosecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64ns_to_time) + }, + Time64(_) => unreachable!(), // remaining are not valid + Timestamp(time_unit, tz) => { + if let Some(tz) = tz { + let timezone = temporal_conversions::parse_offset(tz); + match timezone { + Ok(timezone) => { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime(time, *time_unit, &timezone) + }) + }, + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(tz); + match timezone { + Ok(timezone) => dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime( + time, *time_unit, &timezone, + ) + }), + Err(_) => { + let tz = tz.clone(); + Box::new(move |f, index| { + write!(f, "{} ({})", array.value(index), tz) + }) + }, + } + }, + #[cfg(not(feature = "chrono-tz"))] + _ => { + let tz = tz.clone(); + Box::new(move |f, index| write!(f, "{} ({})", array.value(index), tz)) + }, + } + } else { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_naive_datetime(time, *time_unit) + }) + } + }, + Interval(IntervalUnit::YearMonth) => { + dyn_primitive!(array, i32, |x| format!("{x}m")) + }, + Interval(IntervalUnit::DayTime) => { + dyn_primitive!(array, days_ms, |x: days_ms| format!( + "{}d{}ms", + x.days(), + x.milliseconds() + )) + }, + Interval(IntervalUnit::MonthDayNano) => { + dyn_primitive!(array, months_days_ns, |x: months_days_ns| format!( + "{}m{}d{}ns", + x.months(), + x.days(), + x.ns() + )) + }, + Duration(TimeUnit::Second) => dyn_primitive!(array, i64, |x| format!("{x}s")), + Duration(TimeUnit::Millisecond) => dyn_primitive!(array, i64, |x| format!("{x}ms")), + Duration(TimeUnit::Microsecond) => dyn_primitive!(array, i64, |x| format!("{x}us")), + Duration(TimeUnit::Nanosecond) => dyn_primitive!(array, 
i64, |x| format!("{x}ns")),
+        Decimal(_, scale) => {
+            // The number 999.99 has a precision of 5 and scale of 2
+            let scale = *scale as u32;
+            let factor = 10i128.pow(scale);
+            let display = move |x: i128| {
+                let base = x / factor;
+                let decimals = (x - base * factor).abs();
+                format!("{base}.{decimals}")
+            };
+            dyn_primitive!(array, i128, display)
+        },
+        Decimal256(_, scale) => {
+            let scale = *scale as u32;
+            let factor = (ethnum::I256::ONE * 10).pow(scale);
+            let display = move |x: i256| {
+                let base = x.0 / factor;
+                let decimals = (x.0 - base * factor).abs();
+                format!("{base}.{decimals}")
+            };
+            dyn_primitive!(array, i256, display)
+        },
+        _ => unreachable!(),
+    }
+}
+
+impl<T: NativeType> Debug for PrimitiveArray<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        let writer = get_write_value(self);
+
+        write!(f, "{:?}", self.data_type())?;
+        write_vec(f, &*writer, self.validity(), self.len(), "None", false)
+    }
+}
diff --git a/crates/nano-arrow/src/array/primitive/from_natural.rs b/crates/nano-arrow/src/array/primitive/from_natural.rs
new file mode 100644
index 000000000000..0530c748af7e
--- /dev/null
+++ b/crates/nano-arrow/src/array/primitive/from_natural.rs
@@ -0,0 +1,16 @@
+use std::iter::FromIterator;
+
+use super::{MutablePrimitiveArray, PrimitiveArray};
+use crate::types::NativeType;
+
+impl<T: NativeType, P: AsRef<[Option<T>]>> From<P>

for PrimitiveArray { + fn from(slice: P) -> Self { + MutablePrimitiveArray::::from(slice).into() + } +} + +impl>> FromIterator for PrimitiveArray { + fn from_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/primitive/iterator.rs b/crates/nano-arrow/src/array/primitive/iterator.rs new file mode 100644 index 000000000000..9433979dad84 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/iterator.rs @@ -0,0 +1,47 @@ +use super::{MutablePrimitiveArray, PrimitiveArray}; +use crate::array::MutableArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::IntoIter as BitmapIntoIter; +use crate::buffer::IntoIter; +use crate::types::NativeType; + +impl IntoIterator for PrimitiveArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIntoIter>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a, T: NativeType> IntoIterator for &'a PrimitiveArray { + type Item = Option<&'a T>; + type IntoIter = ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T: NativeType> MutablePrimitiveArray { + /// Returns an iterator over `Option` + #[inline] + pub fn iter(&'a self) -> ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator of `T` + #[inline] + pub fn values_iter(&'a self) -> std::slice::Iter<'a, T> { + self.values().iter() + } +} diff --git a/crates/nano-arrow/src/array/primitive/mod.rs b/crates/nano-arrow/src/array/primitive/mod.rs new file mode 100644 index 000000000000..b3d649a670be --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/mod.rs @@ -0,0 +1,511 @@ +use either::Either; + +use super::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::*; +use crate::error::Error; +use crate::trusted_len::TrustedLen; +use crate::types::{days_ms, f16, i256, months_days_ns, NativeType}; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from_natural; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// A [`PrimitiveArray`] is Arrow's semantically equivalent of an immutable `Vec>` where +/// T is [`NativeType`] (e.g. [`i32`]). It implements [`Array`]. +/// +/// One way to think about a [`PrimitiveArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. 
+/// # Example +/// ``` +/// use arrow2::array::PrimitiveArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = PrimitiveArray::from([Some(1i32), None, Some(10)]); +/// assert_eq!(array.value(0), 1); +/// assert_eq!(array.iter().collect::>(), vec![Some(&1i32), None, Some(&10)]); +/// assert_eq!(array.values_iter().copied().collect::>(), vec![1, 0, 10]); +/// // the underlying representation +/// assert_eq!(array.values(), &Buffer::from(vec![1i32, 0, 10])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct PrimitiveArray { + data_type: DataType, + values: Buffer, + validity: Option, +} + +pub(super) fn check( + data_type: &DataType, + values: &[T], + validity_len: Option, +) -> Result<(), Error> { + if validity_len.map_or(false, |len| len != values.len()) { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Primitive(T::PRIMITIVE) { + return Err(Error::oos( + "PrimitiveArray can only be initialized with a DataType whose physical type is Primitive", + )); + } + Ok(()) +} + +impl PrimitiveArray { + /// The canonical method to create a [`PrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Returns a new [`PrimitiveArray`] with a different logical type. + /// + /// This function is useful to assign a different [`DataType`] to the array. + /// Used to change the arrays' logical type (see example). + /// # Example + /// ``` + /// use arrow2::array::Int32Array; + /// use arrow2::datatypes::DataType; + /// + /// let array = Int32Array::from(&[Some(1), None, Some(2)]).to(DataType::Date32); + /// assert_eq!( + /// format!("{:?}", array), + /// "Date32[1970-01-02, None, 1970-01-03]" + /// ); + /// ``` + /// # Panics + /// Panics iff the `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + #[inline] + #[must_use] + pub fn to(self, data_type: DataType) -> Self { + check( + &data_type, + &self.values, + self.validity.as_ref().map(|v| v.len()), + ) + .unwrap(); + Self { + data_type, + values: self.values, + validity: self.validity, + } + } + + /// Creates a (non-null) [`PrimitiveArray`] from a vector of values. + /// This function is `O(1)`. + /// # Examples + /// ``` + /// use arrow2::array::PrimitiveArray; + /// + /// let array = PrimitiveArray::from_vec(vec![1, 2, 3]); + /// assert_eq!(format!("{:?}", array), "Int32[1, 2, 3]"); + /// ``` + pub fn from_vec(values: Vec) -> Self { + Self::new(T::PRIMITIVE.into(), values.into(), None) + } + + /// Returns an iterator over the values and validity, `Option<&T>`. + #[inline] + pub fn iter(&self) -> ZipValidity<&T, std::slice::Iter, BitmapIter> { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator of the values, `&T`, ignoring the arrays' validity. 
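A brief sketch of the `try_new` invariants described above, written against the arrow2-style paths used by the doc examples in this patch (the crate added here is `nano-arrow`, so the exact import path is an assumption):

```
use arrow2::array::PrimitiveArray;
use arrow2::bitmap::Bitmap;
use arrow2::datatypes::DataType;

fn main() {
    // values and validity have the same length: construction succeeds
    let ok = PrimitiveArray::<i32>::try_new(
        DataType::Int32,
        vec![1i32, 2, 3].into(),
        Some(Bitmap::from([true, false, true])),
    );
    assert!(ok.is_ok());

    // validity length differs from the values length: construction errors
    let err = PrimitiveArray::<i32>::try_new(
        DataType::Int32,
        vec![1i32, 2, 3].into(),
        Some(Bitmap::from([true, false])),
    );
    assert!(err.is_err());
}
```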
+ #[inline] + pub fn values_iter(&self) -> std::slice::Iter { + self.values().iter() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Buffer`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the value at slot `i`. + /// + /// Equivalent to `self.values()[i]`. The value of a null slot is undetermined (it can be anything). + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> T { + self.values[i] + } + + /// Returns the value at index `i`. + /// The value on null slots is undetermined (it can be anything). + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> T { + *self.values.get_unchecked(i) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns this [`PrimitiveArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(mut self, values: Buffer) -> Self { + self.set_values(values); + self + } + + /// Update the values of this [`PrimitiveArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Buffer) { + assert_eq!( + values.len(), + self.len(), + "values' length must be equal to this arrays' length" + ); + self.values = values; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + /// Returns an option of a mutable reference to the values of this [`PrimitiveArray`]. 
+ pub fn get_mut_values(&mut self) -> Option<&mut [T]> { + self.values.get_mut_slice() + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, Buffer, Option) { + let Self { + data_type, + values, + validity, + } = self; + (data_type, values, validity) + } + + /// Creates a `[PrimitiveArray]` from its internal representation. + /// This is the inverted from `[PrimitiveArray::into_inner]` + pub fn from_inner( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; + Ok(unsafe { Self::from_inner_unchecked(data_type, values, validity) }) + } + + /// Creates a `[PrimitiveArray]` from its internal representation. + /// This is the inverted from `[PrimitiveArray::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Self { + Self { + data_type, + values, + validity, + } + } + + /// Try to convert this [`PrimitiveArray`] to a [`MutablePrimitiveArray`] via copy-on-write semantics. + /// + /// A [`PrimitiveArray`] is backed by a [`Buffer`] and [`Bitmap`] which are essentially `Arc>`. + /// This function returns a [`MutablePrimitiveArray`] (via [`std::sync::Arc::get_mut`]) iff both values + /// and validity have not been cloned / are unique references to their underlying vectors. + /// + /// This function is primarily used to re-use memory regions. + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(PrimitiveArray::new( + self.data_type, + self.values, + Some(bitmap), + )), + Right(mutable_bitmap) => match self.values.into_mut() { + Right(values) => Right( + MutablePrimitiveArray::try_new( + self.data_type, + values, + Some(mutable_bitmap), + ) + .unwrap(), + ), + Left(values) => Left(PrimitiveArray::new( + self.data_type, + values, + Some(mutable_bitmap.into()), + )), + }, + } + } else { + match self.values.into_mut() { + Right(values) => { + Right(MutablePrimitiveArray::try_new(self.data_type, values, None).unwrap()) + }, + Left(values) => Left(PrimitiveArray::new(self.data_type, values, None)), + } + } + } + + /// Returns a new empty (zero-length) [`PrimitiveArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Buffer::new(), None) + } + + /// Returns a new [`PrimitiveArray`] where all slots are null / `None`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + vec![T::default(); length].into(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from an iterator of values. + /// # Implementation + /// This does not assume that the iterator has a known length. + pub fn from_values>(iter: I) -> Self { + Self::new(T::PRIMITIVE.into(), Vec::::from_iter(iter).into(), None) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a slice of values. + /// # Implementation + /// This is essentially a memcopy and is thus `O(N)` + pub fn from_slice>(slice: P) -> Self { + Self::new( + T::PRIMITIVE.into(), + Vec::::from(slice.as_ref()).into(), + None, + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a [`TrustedLen`] of values. + /// # Implementation + /// This does not assume that the iterator has a known length. 
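A sketch of the copy-on-write behaviour of `into_mut` described above: a freshly built (uniquely owned) array converts into its mutable counterpart, while an array whose buffers are still shared is handed back unchanged. Import paths are assumptions in the same arrow2 style; the `either` crate provides the `Either` return type used by `into_mut`.

```
use arrow2::array::PrimitiveArray;
use either::Either;

fn main() {
    // uniquely owned buffers: we get the mutable variant back
    let unique = PrimitiveArray::<i32>::from_vec(vec![1, 2, 3]);
    assert!(matches!(unique.into_mut(), Either::Right(_)));

    // a clone keeps a second reference alive, so the buffers stay shared
    let shared = PrimitiveArray::<i32>::from_vec(vec![1, 2, 3]);
    let _keep_alive = shared.clone();
    assert!(matches!(shared.into_mut(), Either::Left(_)));
}
```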
+ pub fn from_trusted_len_values_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter(iter).into() + } + + /// Creates a new [`PrimitiveArray`] from an iterator over values + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter_unchecked(iter).into() + } + + /// Creates a [`PrimitiveArray`] from a [`TrustedLen`] of optional values. + pub fn from_trusted_len_iter>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter(iter).into() + } + + /// Creates a [`PrimitiveArray`] from an iterator of optional values. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_iter_unchecked>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked(iter).into() + } + + /// Alias for `Self::try_new(..).unwrap()`. + /// # Panics + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive`]. + pub fn new(data_type: DataType, values: Buffer, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } +} + +impl Array for PrimitiveArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +/// A type definition [`PrimitiveArray`] for `i8` +pub type Int8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i16` +pub type Int16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i32` +pub type Int32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i64` +pub type Int64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i128` +pub type Int128Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i256` +pub type Int256Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`days_ms`] +pub type DaysMsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f16` +pub type Float16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f32` +pub type Float32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f64` +pub type Float64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u8` +pub type UInt8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u16` +pub type UInt16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u32` +pub type UInt32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u64` +pub type UInt64Array = PrimitiveArray; + +/// A type definition [`MutablePrimitiveArray`] for `i8` +pub type Int8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i16` +pub type Int16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i32` +pub type Int32Vec = MutablePrimitiveArray; +/// A type definition 
[`MutablePrimitiveArray`] for `i64` +pub type Int64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i128` +pub type Int128Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i256` +pub type Int256Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`days_ms`] +pub type DaysMsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f16` +pub type Float16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f32` +pub type Float32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f64` +pub type Float64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u8` +pub type UInt8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u16` +pub type UInt16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u32` +pub type UInt32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u64` +pub type UInt64Vec = MutablePrimitiveArray; + +impl Default for PrimitiveArray { + fn default() -> Self { + PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None) + } +} diff --git a/crates/nano-arrow/src/array/primitive/mutable.rs b/crates/nano-arrow/src/array/primitive/mutable.rs new file mode 100644 index 000000000000..fc61b2e74884 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/mutable.rs @@ -0,0 +1,665 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{check, PrimitiveArray}; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +/// The Arrow's equivalent to `Vec>` where `T` is byte-size (e.g. `i32`). +/// Converting a [`MutablePrimitiveArray`] into a [`PrimitiveArray`] is `O(1)`. +#[derive(Debug, Clone)] +pub struct MutablePrimitiveArray { + data_type: DataType, + values: Vec, + validity: Option, +} + +impl From> for PrimitiveArray { + fn from(other: MutablePrimitiveArray) -> Self { + let validity = other.validity.and_then(|x| { + let bitmap: Bitmap = x.into(); + if bitmap.unset_bits() == 0 { + None + } else { + Some(bitmap) + } + }); + + PrimitiveArray::::new(other.data_type, other.values.into(), validity) + } +} + +impl]>> From

for MutablePrimitiveArray { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl MutablePrimitiveArray { + /// Creates a new empty [`MutablePrimitiveArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Creates a new [`MutablePrimitiveArray`] with a capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_from(capacity, T::PRIMITIVE.into()) + } + + /// The canonical method to create a [`MutablePrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + data_type: DataType, + values: Vec, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Extract the low-end APIs from the [`MutablePrimitiveArray`]. + pub fn into_inner(self) -> (DataType, Vec, Option) { + (self.data_type, self.values, self.validity) + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics iff `f` panics + pub fn apply_values(&mut self, f: F) { + f(&mut self.values); + } +} + +impl Default for MutablePrimitiveArray { + fn default() -> Self { + Self::new() + } +} + +impl From for MutablePrimitiveArray { + fn from(data_type: DataType) -> Self { + assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + data_type, + values: Vec::::new(), + validity: None, + } + } +} + +impl MutablePrimitiveArray { + /// Creates a new [`MutablePrimitiveArray`] from a capacity and [`DataType`]. + pub fn with_capacity_from(capacity: usize, data_type: DataType) -> Self { + assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + data_type, + values: Vec::::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Adds a new value to the array. + #[inline] + pub fn push(&mut self, value: Option) { + match value { + Some(value) => { + self.values.push(value); + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(T::default()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => { + self.init_validity(); + }, + } + }, + } + } + + /// Pop a value from the array. + /// Note if the values is empty, this method will return None. 
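A sketch of the `push` behaviour above: the validity bitmap is only materialised once the first null arrives, and freezing into a `PrimitiveArray` is `O(1)`. Import paths are assumptions in the same arrow2 style as the doc examples.

```
use arrow2::array::{Array, MutableArray, MutablePrimitiveArray, PrimitiveArray};

fn main() {
    let mut array = MutablePrimitiveArray::<i32>::new();
    array.push(Some(1));
    assert!(array.validity().is_none()); // no nulls yet, so no bitmap

    array.push(None); // first null: the validity bitmap gets initialised
    assert!(array.validity().is_some());

    let frozen: PrimitiveArray<i32> = array.into();
    assert_eq!(frozen.len(), 2);
    assert!(frozen.is_null(1));
}
```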
+ pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutablePrimitiveArray`] with a constant + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: Option) { + if let Some(value) = value { + self.values.resize(self.values.len() + additional, value); + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, true) + } + } else { + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, false) + } else { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.extend_constant(additional, false); + self.validity = Some(validity) + } + self.values + .resize(self.values.len() + additional, T::default()); + } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values) + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + self.validity = Some(validity); + } + } + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked` which accepts in iterator of optional values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + I: Iterator, + { + self.values.extend(iterator); + self.update_all_valid(); + } + + #[inline] + /// Extends the [`MutablePrimitiveArray`] from a slice + pub fn extend_from_slice(&mut self, items: &[T]) { + self.values.extend_from_slice(items); + self.update_all_valid(); + } + + fn update_all_valid(&mut self) { + // get len before mutable borrow + let len = self.len(); + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(len - validity.len(), true); + } + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Changes the arrays' [`DataType`], returning a new [`MutablePrimitiveArray`]. + /// Use to change the logical type without changing the corresponding physical Type. + /// # Implementation + /// This operation is `O(1)`. 
+ #[inline] + pub fn to(self, data_type: DataType) -> Self { + Self::try_new(data_type, self.values, self.validity).unwrap() + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: PrimitiveArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutablePrimitiveArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Returns the capacity of this [`MutablePrimitiveArray`]. + pub fn capacity(&self) -> usize { + self.values.capacity() + } +} + +/// Accessors +impl MutablePrimitiveArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. + pub fn values_mut_slice(&mut self) -> &mut [T] { + self.values.as_mut_slice() + } +} + +/// Setters +impl MutablePrimitiveArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff index is larger than `self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + assert!(index < self.len()); + // Safety: + // we just checked bounds + unsafe { self.set_unchecked(index, value) } + } + + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Safety + /// Caller must ensure `index < self.len()` + pub unsafe fn set_unchecked(&mut self, index: usize, value: Option) { + *self.values.get_unchecked_mut(index) = value.unwrap_or_default(); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + if let Some(x) = self.validity.as_mut() { + x.set_unchecked(index, value.is_some()) + } + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Sets values. + /// # Panic + /// Panics iff the values' length is not equal to the existing validity's len. 
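A sketch of `set` as documented above: writing `None` into a previously all-valid array initialises the validity bitmap, and the underlying slot falls back to `T::default()`. Paths are assumed as in the other examples.

```
use arrow2::array::{MutableArray, MutablePrimitiveArray};

fn main() {
    let mut array = MutablePrimitiveArray::<i64>::from_slice([10i64, 20, 30]);
    assert!(array.validity().is_none());

    array.set(1, None); // O(N) the first time, because the bitmap must be built
    assert!(array.validity().is_some());
    assert_eq!(array.values(), &vec![10i64, 0, 30]); // the nulled slot holds the default value
}
```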
+ pub fn set_values(&mut self, values: Vec) { + assert_eq!(values.len(), self.values.len()); + self.values = values; + } +} + +impl Extend> for MutablePrimitiveArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutablePrimitiveArray { + /// This is infalible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> Result<(), Error> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutablePrimitiveArray { + /// This is infalible and is implemented for consistency with all other types + fn try_push(&mut self, item: Option) -> Result<(), Error> { + self.push(item); + Ok(()) + } +} + +impl MutableArray for MutablePrimitiveArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + PrimitiveArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + PrimitiveArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl MutablePrimitiveArray { + /// Creates a [`MutablePrimitiveArray`] from a slice of values. + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`MutablePrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + } + } + + /// Creates a [`MutablePrimitiveArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iter: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, values) = try_trusted_len_unzip(iterator)?; + + Ok(Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + }) + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. 
+ #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutablePrimitiveArray`] out an iterator over values + pub fn from_trusted_len_values_iter>(iter: I) -> Self { + Self { + data_type: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } + + /// Creates a (non-null) [`MutablePrimitiveArray`] from a vector of values. + /// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. + pub fn from_vec(values: Vec) -> Self { + Self::try_new(T::PRIMITIVE.into(), values, None).unwrap() + } + + /// Creates a new [`MutablePrimitiveArray`] from an iterator over values + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + Self { + data_type: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } +} + +impl>> FromIterator + for MutablePrimitiveArray +{ + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: Vec = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + T::default() + } + }) + .collect(); + + let validity = Some(validity); + + Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + } + } +} + +/// Extends a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + buffer: &mut Vec, +) where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("trusted_len_unzip requires an upper limit"); + + validity.reserve(additional); + let values = iterator.map(|item| { + if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + T::default() + } + }); + buffer.extend(values); +} + +/// Creates a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, Vec) +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut buffer = Vec::::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut buffer); + + let validity = Some(validity); + + (validity, buffer) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
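These constructors lean on the `TrustedLen` contract: `size_hint().1` must report the exact length, because it is used to reserve capacity and to `set_len` without further checks. A pure-std sketch of why an arbitrary iterator cannot be trusted this way:

```
fn main() {
    // `map` preserves an exact length, so the upper size hint is reliable
    let exact = (0..4).map(|x| x * 2);
    assert_eq!(exact.size_hint(), (4, Some(4)));

    // `filter` only knows an upper bound, so its hint must not be used as a length
    let filtered = (0..4).filter(|x| x % 2 == 0);
    assert_eq!(filtered.size_hint(), (0, Some(4)));
}
```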
+#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Vec), E> +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut buffer = Vec::::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + T::default() + }; + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + null.set_len(len); + + let validity = Some(null); + + Ok((validity, buffer)) +} + +impl PartialEq for MutablePrimitiveArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutablePrimitiveArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/specification.rs b/crates/nano-arrow/src/array/specification.rs new file mode 100644 index 000000000000..efa8fe1be4a4 --- /dev/null +++ b/crates/nano-arrow/src/array/specification.rs @@ -0,0 +1,178 @@ +use crate::array::DictionaryKey; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +/// Helper trait to support `Offset` and `OffsetBuffer` +pub(crate) trait OffsetsContainer { + fn last(&self) -> usize; + fn as_slice(&self) -> &[O]; +} + +impl OffsetsContainer for OffsetsBuffer { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.buffer() + } +} + +impl OffsetsContainer for Offsets { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.as_slice() + } +} + +pub(crate) fn try_check_offsets_bounds>( + offsets: &C, + values_len: usize, +) -> Result<()> { + if offsets.last() > values_len { + Err(Error::oos("offsets must not exceed the values length")) + } else { + Ok(()) + } +} + +/// # Error +/// * any offset is larger or equal to `values_len`. 
+/// * any slice of `values` between two consecutive pairs from `offsets` is invalid `utf8`, or +pub(crate) fn try_check_utf8>( + offsets: &C, + values: &[u8], +) -> Result<()> { + if offsets.as_slice().len() == 1 { + return Ok(()); + } + + try_check_offsets_bounds(offsets, values.len())?; + + if values.is_ascii() { + Ok(()) + } else { + simdutf8::basic::from_utf8(values)?; + + // offsets can be == values.len() + // find first offset from the end that is smaller + // Example: + // values.len() = 10 + // offsets = [0, 5, 10, 10] + let offsets = offsets.as_slice(); + let last = offsets + .iter() + .enumerate() + .skip(1) + .rev() + .find_map(|(i, offset)| (offset.to_usize() < values.len()).then(|| i)); + + let last = if let Some(last) = last { + // following the example: last = 1 (offset = 5) + last + } else { + // given `l = values.len()`, this branch is hit iff either: + // * `offsets = [0, l, l, ...]`, which was covered by `from_utf8(values)` above + // * `offsets = [0]`, which never happens because offsets.as_slice().len() == 1 is short-circuited above + return Ok(()); + }; + + // truncate to relevant offsets. Note: `=last` because last was computed skipping the first item + // following the example: starts = [0, 5] + let starts = unsafe { offsets.get_unchecked(..=last) }; + + let mut any_invalid = false; + for start in starts { + let start = start.to_usize(); + + // Safety: `try_check_offsets_bounds` just checked for bounds + let b = *unsafe { values.get_unchecked(start) }; + + // A valid code-point iff it does not start with 0b10xxxxxx + // Bit-magic taken from `std::str::is_char_boundary` + if (b as i8) < -0x40 { + any_invalid = true + } + } + if any_invalid { + return Err(Error::oos("Non-valid char boundary detected")); + } + Ok(()) + } +} + +/// Check dictionary indexes without checking usize conversion. +/// # Safety +/// The caller must ensure that `K::as_usize` always succeeds. +pub(crate) unsafe fn check_indexes_unchecked( + keys: &[K], + len: usize, +) -> Result<()> { + let mut invalid = false; + + // this loop is auto-vectorized + keys.iter().for_each(|k| { + if k.as_usize() > len { + invalid = true; + } + }); + + if invalid { + let key = keys.iter().map(|k| k.as_usize()).max().unwrap(); + Err(Error::oos(format!("One of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}"))) + } else { + Ok(()) + } +} + +pub fn check_indexes(keys: &[K], len: usize) -> Result<()> +where + K: std::fmt::Debug + Copy + TryInto, +{ + keys.iter().try_for_each(|key| { + let key: usize = (*key) + .try_into() + .map_err(|_| Error::oos(format!("The dictionary key must fit in a `usize`, but {key:?} does not")))?; + if key >= len { + Err(Error::oos(format!("One of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}"))) + } else { + Ok(()) + } + }) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + pub(crate) fn binary_strategy() -> impl Strategy> { + prop::collection::vec(any::(), 1..100) + } + + proptest! 
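The `(b as i8) < -0x40` test above is the usual char-boundary check: UTF-8 continuation bytes are `0b10xxxxxx` (`0x80..=0xBF`), exactly the byte values that fall below `-0x40` when reinterpreted as `i8`. A quick stand-alone illustration (not part of the patch):

```
fn is_char_boundary_byte(b: u8) -> bool {
    // continuation bytes (0b10xxxxxx) are the only invalid offset targets
    (b as i8) >= -0x40
}

fn main() {
    let bytes = "é!".as_bytes(); // 'é' is encoded as the two bytes 0xC3 0xA9
    assert_eq!(bytes, [0xC3u8, 0xA9, 0x21]);
    assert!(is_char_boundary_byte(bytes[0]));  // lead byte: a valid offset
    assert!(!is_char_boundary_byte(bytes[1])); // continuation byte: an offset here is invalid
    assert!(is_char_boundary_byte(bytes[2]));
}
```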
{ + // a bit expensive, feel free to run it when changing the code above + // #![proptest_config(ProptestConfig::with_cases(100000))] + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well + fn check_utf8_validation(values in binary_strategy()) { + + for offset in 0..values.len() - 1 { + let offsets = vec![0, offset as i32, values.len() as i32].try_into().unwrap(); + + let mut is_valid = std::str::from_utf8(&values[..offset]).is_ok(); + is_valid &= std::str::from_utf8(&values[offset..]).is_ok(); + + assert_eq!(try_check_utf8::>(&offsets, &values).is_ok(), is_valid) + } + } + } +} diff --git a/crates/nano-arrow/src/array/struct_/data.rs b/crates/nano-arrow/src/array/struct_/data.rs new file mode 100644 index 000000000000..b96dc4ffe28b --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/data.rs @@ -0,0 +1,28 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, StructArray}; +use crate::bitmap::Bitmap; + +impl Arrow2Arrow for StructArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + Self { + data_type, + values: data.child_data().iter().map(from_data).collect(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/struct_/ffi.rs b/crates/nano-arrow/src/array/struct_/ffi.rs new file mode 100644 index 000000000000..95abe00694b2 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/ffi.rs @@ -0,0 +1,72 @@ +use super::super::ffi::ToFfi; +use super::super::{Array, FromFfi}; +use super::StructArray; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for StructArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| x.as_ptr())] + } + + fn children(&self) -> Vec> { + self.values.clone() + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for StructArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let fields = Self::get_fields(&data_type); + + let arrow_array = array.array(); + let validity = unsafe { array.validity() }?; + let len = arrow_array.len(); + let offset = arrow_array.offset(); + let values = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child).map(|arr| { + // there is a discrepancy with how arrow2 exports sliced + // struct array and how pyarrow does it. 
+ // # Pyarrow + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 1 + // length on struct array: 2 + // offset on value array: 0 + // length on value array: 3 + // # Arrow2 + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 0 + // length on struct array: 3 + // offset on value array: 1 + // length on value array: 2 + // + // this branch will ensure both can round trip + if arr.len() >= (len + offset) { + arr.sliced(offset, len) + } else { + arr + } + }) + }) + .collect::>>>()?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/struct_/fmt.rs b/crates/nano-arrow/src/array/struct_/fmt.rs new file mode 100644 index 000000000000..999cd8b67e08 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/fmt.rs @@ -0,0 +1,34 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_map, write_vec}; +use super::StructArray; + +pub fn write_value( + array: &StructArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let writer = |f: &mut W, _index| { + for (i, (field, column)) in array.fields().iter().zip(array.values()).enumerate() { + if i != 0 { + write!(f, ", ")?; + } + let writer = get_display(column.as_ref(), null); + write!(f, "{}: ", field.name)?; + writer(f, index)?; + } + Ok(()) + }; + + write_map(f, writer, None, 1, null, false) +} + +impl Debug for StructArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "StructArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/struct_/iterator.rs b/crates/nano-arrow/src/array/struct_/iterator.rs new file mode 100644 index 000000000000..cb8e6aafbb09 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/iterator.rs @@ -0,0 +1,96 @@ +use super::StructArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::{new_scalar, Scalar}; +use crate::trusted_len::TrustedLen; + +pub struct StructValueIter<'a> { + array: &'a StructArray, + index: usize, + end: usize, +} + +impl<'a> StructValueIter<'a> { + #[inline] + pub fn new(array: &'a StructArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a> Iterator for StructValueIter<'a> { + type Item = Vec>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + + // Safety: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), old)) + .collect(), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a> TrustedLen for StructValueIter<'a> {} + +impl<'a> DoubleEndedIterator for StructValueIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + + // Safety: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), self.end)) + .collect(), + ) + } + } +} + +type ValuesIter<'a> = StructValueIter<'a>; +type ZipIter<'a> = ZipValidity>, ValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a StructArray { + type Item = Option>>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> 
StructArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(StructValueIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ValuesIter<'a> { + StructValueIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/struct_/mod.rs b/crates/nano-arrow/src/array/struct_/mod.rs new file mode 100644 index 000000000000..8107c885e4ef --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/mod.rs @@ -0,0 +1,255 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +pub use mutable::*; + +/// A [`StructArray`] is a nested [`Array`] with an optional validity representing +/// multiple [`Array`] with the same number of rows. +/// # Example +/// ``` +/// use arrow2::array::*; +/// use arrow2::datatypes::*; +/// let boolean = BooleanArray::from_slice(&[false, false, true, true]).boxed(); +/// let int = Int32Array::from_slice(&[42, 28, 19, 31]).boxed(); +/// +/// let fields = vec![ +/// Field::new("b", DataType::Boolean, false), +/// Field::new("c", DataType::Int32, false), +/// ]; +/// +/// let array = StructArray::new(DataType::Struct(fields), vec![boolean, int], None); +/// ``` +#[derive(Clone)] +pub struct StructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +impl StructArray { + /// Returns a new [`StructArray`]. + /// # Errors + /// This function errors iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `data_type` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn try_new( + data_type: DataType, + values: Vec>, + validity: Option, + ) -> Result { + let fields = Self::try_get_fields(&data_type)?; + if fields.is_empty() { + return Err(Error::oos("A StructArray must contain at least one field")); + } + if fields.len() != values.len() { + return Err(Error::oos( + "A StructArray must have a number of fields in its DataType equal to the number of child values", + )); + } + + fields + .iter().map(|a| &a.data_type) + .zip(values.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a StructArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + let len = values[0].len(); + values + .iter() + .map(|a| a.len()) + .enumerate() + .try_for_each(|(index, a_len)| { + if a_len != len { + Err(Error::oos(format!( + "The children must have an equal number of values. + However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}." 
+ ))) + } else { + Ok(()) + } + })?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "The validity length of a StructArray must match its number of elements", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Returns a new [`StructArray`] + /// # Panics + /// This function panics iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `data_type` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn new(data_type: DataType, values: Vec>, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Creates an empty [`StructArray`]. + pub fn new_empty(data_type: DataType) -> Self { + if let DataType::Struct(fields) = &data_type.to_logical_type() { + let values = fields + .iter() + .map(|field| new_empty_array(field.data_type().clone())) + .collect(); + Self::new(data_type, values, None) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } + + /// Creates a null [`StructArray`] of length `length`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + if let DataType::Struct(fields) = &data_type { + let values = fields + .iter() + .map(|field| new_null_array(field.data_type().clone(), length)) + .collect(); + Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } +} + +// must use +impl StructArray { + /// Deconstructs the [`StructArray`] into its individual components. + #[must_use] + pub fn into_data(self) -> (Vec, Vec>, Option) { + let Self { + data_type, + values, + validity, + } = self; + let fields = if let DataType::Struct(fields) = data_type { + fields + } else { + unreachable!() + }; + (fields, values, validity) + } + + /// Slices this [`StructArray`]. + /// # Panics + /// * `offset + length` must be smaller than `self.len()`. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`StructArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .iter_mut() + .for_each(|x| x.slice_unchecked(offset, length)); + } + + impl_sliced!(); + + impl_mut_validity!(); + + impl_into_array!(); +} + +// Accessors +impl StructArray { + #[inline] + fn len(&self) -> usize { + self.values[0].len() + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values of this [`StructArray`]. + pub fn values(&self) -> &[Box] { + &self.values + } + + /// Returns the fields of this [`StructArray`]. 
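+    ///
+    /// # Example
+    /// A minimal sketch of how the returned fields mirror the `DataType::Struct`
+    /// definition (paths follow the arrow2-style examples elsewhere in this module):
+    /// ```
+    /// use arrow2::array::{BooleanArray, Int32Array, StructArray};
+    /// use arrow2::datatypes::{DataType, Field};
+    ///
+    /// let fields = vec![
+    ///     Field::new("b", DataType::Boolean, false),
+    ///     Field::new("c", DataType::Int32, false),
+    /// ];
+    /// let values = vec![
+    ///     BooleanArray::from_slice([false, true]).boxed(),
+    ///     Int32Array::from_slice([42, 28]).boxed(),
+    /// ];
+    /// let array = StructArray::new(DataType::Struct(fields), values, None);
+    /// assert_eq!(array.fields().len(), 2);
+    /// assert_eq!(array.fields()[0].name, "b");
+    /// ```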
+ pub fn fields(&self) -> &[Field] { + Self::get_fields(&self.data_type) + } +} + +impl StructArray { + /// Returns the fields the `DataType::Struct`. + pub(crate) fn try_get_fields(data_type: &DataType) -> Result<&[Field], Error> { + match data_type.to_logical_type() { + DataType::Struct(fields) => Ok(fields), + _ => Err(Error::oos( + "Struct array must be created with a DataType whose physical type is Struct", + )), + } + } + + /// Returns the fields the `DataType::Struct`. + pub fn get_fields(data_type: &DataType) -> &[Field] { + Self::try_get_fields(data_type).unwrap() + } +} + +impl Array for StructArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/struct_/mutable.rs b/crates/nano-arrow/src/array/struct_/mutable.rs new file mode 100644 index 000000000000..8060a698fb63 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/mutable.rs @@ -0,0 +1,245 @@ +use std::sync::Arc; + +use super::StructArray; +use crate::array::{Array, MutableArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Error; + +/// Converting a [`MutableStructArray`] into a [`StructArray`] is `O(1)`. +#[derive(Debug)] +pub struct MutableStructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +fn check( + data_type: &DataType, + values: &[Box], + validity: Option, +) -> Result<(), Error> { + let fields = StructArray::try_get_fields(data_type)?; + if fields.is_empty() { + return Err(Error::oos("A StructArray must contain at least one field")); + } + if fields.len() != values.len() { + return Err(Error::oos( + "A StructArray must have a number of fields in its DataType equal to the number of child values", + )); + } + + fields + .iter().map(|a| &a.data_type) + .zip(values.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a StructArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + let len = values[0].len(); + values + .iter() + .map(|a| a.len()) + .enumerate() + .try_for_each(|(index, a_len)| { + if a_len != len { + Err(Error::oos(format!( + "The children must have an equal number of values. + However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}." + ))) + } else { + Ok(()) + } + })?; + + if validity.map_or(false, |validity| validity != len) { + return Err(Error::oos( + "The validity length of a StructArray must match its number of elements", + )); + } + Ok(()) +} + +impl From for StructArray { + fn from(other: MutableStructArray) -> Self { + let validity = if other.validity.as_ref().map(|x| x.unset_bits()).unwrap_or(0) > 0 { + other.validity.map(|x| x.into()) + } else { + None + }; + + StructArray::new( + other.data_type, + other.values.into_iter().map(|mut v| v.as_box()).collect(), + validity, + ) + } +} + +impl MutableStructArray { + /// Creates a new [`MutableStructArray`]. + pub fn new(data_type: DataType, values: Vec>) -> Self { + Self::try_new(data_type, values, None).unwrap() + } + + /// Create a [`MutableStructArray`] out of low-end APIs. 
+ /// # Errors + /// This function errors iff: + /// * `data_type` is not [`DataType::Struct`] + /// * The inner types of `data_type` are not equal to those of `values` + /// * `validity` is not `None` and its length is different from the `values`'s length + pub fn try_new( + data_type: DataType, + values: Vec>, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Extract the low-end APIs from the [`MutableStructArray`]. + pub fn into_inner(self) -> (DataType, Vec>, Option) { + (self.data_type, self.values, self.validity) + } + + /// The mutable values + pub fn mut_values(&mut self) -> &mut Vec> { + &mut self.values + } + + /// The values + pub fn values(&self) -> &Vec> { + &self.values + } + + /// Return the `i`th child array. + pub fn value(&mut self, i: usize) -> Option<&mut A> { + self.values[i].as_mut_any().downcast_mut::() + } +} + +impl MutableStructArray { + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + for v in &mut self.values { + v.reserve(additional); + } + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Call this once for each "row" of children you push. + pub fn push(&mut self, valid: bool) { + match &mut self.validity { + Some(validity) => validity.push(valid), + None => match valid { + true => (), + false => self.init_validity(), + }, + }; + } + + fn push_null(&mut self) { + for v in &mut self.values { + v.push_null(); + } + self.push(false); + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + let len = self.len(); + if len > 0 { + validity.extend_constant(len, true); + validity.set(len - 1, false); + } + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: StructArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableStructArray`] to fit its current length. 
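+    ///
+    /// # Example
+    /// A sketch of building a [`MutableStructArray`], pushing one row and then
+    /// releasing the spare capacity; the `MutablePrimitiveArray<i32>` child and the
+    /// arrow2-style paths are illustrative assumptions:
+    /// ```
+    /// use arrow2::array::{MutableArray, MutablePrimitiveArray, MutableStructArray};
+    /// use arrow2::datatypes::{DataType, Field};
+    ///
+    /// let fields = vec![Field::new("a", DataType::Int32, true)];
+    /// let values: Vec<Box<dyn MutableArray>> =
+    ///     vec![Box::new(MutablePrimitiveArray::<i32>::with_capacity(1024))];
+    /// let mut array = MutableStructArray::new(DataType::Struct(fields), values);
+    ///
+    /// array
+    ///     .value::<MutablePrimitiveArray<i32>>(0)
+    ///     .unwrap()
+    ///     .push(Some(1));
+    /// array.push(true); // one call per pushed row
+    ///
+    /// array.shrink_to_fit();
+    /// assert_eq!(array.len(), 1);
+    /// ```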
+ pub fn shrink_to_fit(&mut self) { + for v in &mut self.values { + v.shrink_to_fit(); + } + if let Some(validity) = self.validity.as_mut() { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableStructArray { + fn len(&self) -> usize { + self.values.first().map(|v| v.len()).unwrap_or(0) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + StructArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values) + .into_iter() + .map(|mut v| v.as_box()) + .collect(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + StructArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values) + .into_iter() + .map(|mut v| v.as_box()) + .collect(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push_null() + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } +} diff --git a/crates/nano-arrow/src/array/union/data.rs b/crates/nano-arrow/src/array/union/data.rs new file mode 100644 index 000000000000..6de6c0074231 --- /dev/null +++ b/crates/nano-arrow/src/array/union/data.rs @@ -0,0 +1,70 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, UnionArray}; +use crate::buffer::Buffer; +use crate::datatypes::DataType; + +impl Arrow2Arrow for UnionArray { + fn to_data(&self) -> ArrayData { + let data_type = arrow_schema::DataType::from(self.data_type.clone()); + let len = self.len(); + + let builder = match self.offsets.clone() { + Some(offsets) => ArrayDataBuilder::new(data_type) + .len(len) + .buffers(vec![self.types.clone().into(), offsets.into()]) + .child_data(self.fields.iter().map(|x| to_data(x.as_ref())).collect()), + None => ArrayDataBuilder::new(data_type) + .len(len) + .buffers(vec![self.types.clone().into()]) + .child_data( + self.fields + .iter() + .map(|x| to_data(x.as_ref()).slice(self.offset, len)) + .collect(), + ), + }; + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + + let fields = data.child_data().iter().map(from_data).collect(); + let buffers = data.buffers(); + let mut types: Buffer = buffers[0].clone().into(); + types.slice(data.offset(), data.len()); + let offsets = match buffers.len() == 2 { + true => { + let mut offsets: Buffer = buffers[1].clone().into(); + offsets.slice(data.offset(), data.len()); + Some(offsets) + }, + false => None, + }; + + // Map from type id to array index + let map = match &data_type { + DataType::Union(_, Some(ids), _) => { + let mut map = [0; 127]; + for (pos, &id) in ids.iter().enumerate() { + map[id as usize] = pos; + } + Some(map) + }, + DataType::Union(_, None, _) => None, + _ => unreachable!("must be Union type"), + }; + + Self { + types, + map, + fields, + offsets, + data_type, + offset: data.offset(), + } + } +} diff --git a/crates/nano-arrow/src/array/union/ffi.rs b/crates/nano-arrow/src/array/union/ffi.rs new file mode 100644 index 000000000000..590afec0c6c5 --- /dev/null +++ b/crates/nano-arrow/src/array/union/ffi.rs @@ -0,0 +1,60 @@ +use super::super::ffi::ToFfi; +use 
super::super::Array; +use super::UnionArray; +use crate::array::FromFfi; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for UnionArray { + fn buffers(&self) -> Vec> { + if let Some(offsets) = &self.offsets { + vec![ + Some(self.types.as_ptr().cast::()), + Some(offsets.as_ptr().cast::()), + ] + } else { + vec![Some(self.types.as_ptr().cast::())] + } + } + + fn children(&self) -> Vec> { + self.fields.clone() + } + + fn offset(&self) -> Option { + Some(self.types.offset()) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for UnionArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let fields = Self::get_fields(&data_type); + + let mut types = unsafe { array.buffer::(0) }?; + let offsets = if Self::is_sparse(&data_type) { + None + } else { + Some(unsafe { array.buffer::(1) }?) + }; + + let length = array.array().len(); + let offset = array.array().offset(); + let fields = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child) + }) + .collect::>>>()?; + + if offset > 0 { + types.slice(offset, length); + }; + + Self::try_new(data_type, types, fields, offsets) + } +} diff --git a/crates/nano-arrow/src/array/union/fmt.rs b/crates/nano-arrow/src/array/union/fmt.rs new file mode 100644 index 000000000000..521201fffd6d --- /dev/null +++ b/crates/nano-arrow/src/array/union/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::UnionArray; + +pub fn write_value( + array: &UnionArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let (field, index) = array.index(index); + + get_display(array.fields()[field].as_ref(), null)(f, index) +} + +impl Debug for UnionArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "UnionArray")?; + write_vec(f, writer, None, self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/union/iterator.rs b/crates/nano-arrow/src/array/union/iterator.rs new file mode 100644 index 000000000000..bdcf5825af6c --- /dev/null +++ b/crates/nano-arrow/src/array/union/iterator.rs @@ -0,0 +1,59 @@ +use super::UnionArray; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +#[derive(Debug, Clone)] +pub struct UnionIter<'a> { + array: &'a UnionArray, + current: usize, +} + +impl<'a> UnionIter<'a> { + #[inline] + pub fn new(array: &'a UnionArray) -> Self { + Self { array, current: 0 } + } +} + +impl<'a> Iterator for UnionIter<'a> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.current == self.array.len() { + None + } else { + let old = self.current; + self.current += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.array.len() - self.current; + (len, Some(len)) + } +} + +impl<'a> IntoIterator for &'a UnionArray { + type Item = Box; + type IntoIter = UnionIter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> UnionArray { + /// constructs a new iterator + #[inline] + pub fn iter(&'a self) -> UnionIter<'a> { + UnionIter::new(self) + } +} + +impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {} + +unsafe impl<'a> TrustedLen for UnionIter<'a> {} diff --git a/crates/nano-arrow/src/array/union/mod.rs b/crates/nano-arrow/src/array/union/mod.rs new file mode 100644 index 
000000000000..9150920ea021 --- /dev/null +++ b/crates/nano-arrow/src/array/union/mod.rs @@ -0,0 +1,377 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::{DataType, Field, UnionMode}; +use crate::error::Error; +use crate::scalar::{new_scalar, Scalar}; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; + +type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode); + +/// [`UnionArray`] represents an array whose each slot can contain different values. +/// +// How to read a value at slot i: +// ``` +// let index = self.types()[i] as usize; +// let field = self.fields()[index]; +// let offset = self.offsets().map(|x| x[index]).unwrap_or(i); +// let field = field.as_any().downcast to correct type; +// let value = field.value(offset); +// ``` +#[derive(Clone)] +pub struct UnionArray { + // Invariant: every item in `types` is `> 0 && < fields.len()` + types: Buffer, + // Invariant: `map.len() == fields.len()` + // Invariant: every item in `map` is `> 0 && < fields.len()` + map: Option<[usize; 127]>, + fields: Vec>, + // Invariant: when set, `offsets.len() == types.len()` + offsets: Option>, + data_type: DataType, + offset: usize, +} + +impl UnionArray { + /// Returns a new [`UnionArray`]. + /// # Errors + /// This function errors iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `data_type`'s children's length + /// * The number of `fields` is larger than `i8::MAX` + /// * any of the values's data type is different from its corresponding children' data type + pub fn try_new( + data_type: DataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> Result { + let (f, ids, mode) = Self::try_get_all(&data_type)?; + + if f.len() != fields.len() { + return Err(Error::oos( + "The number of `fields` must equal the number of children fields in DataType::Union", + )); + }; + let number_of_fields: i8 = fields + .len() + .try_into() + .map_err(|_| Error::oos("The number of `fields` cannot be larger than i8::MAX"))?; + + f + .iter().map(|a| a.data_type()) + .zip(fields.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a UnionArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + if let Some(offsets) = &offsets { + if offsets.len() != types.len() { + return Err(Error::oos( + "In a UnionArray, the offsets' length must be equal to the number of types", + )); + } + } + if offsets.is_none() != mode.is_sparse() { + return Err(Error::oos( + "In a sparse UnionArray, the offsets must be set (and vice-versa)", + )); + } + + // build hash + let map = if let Some(&ids) = ids.as_ref() { + if ids.len() != fields.len() { + return Err(Error::oos( + "In a union, when the ids are set, their length must be equal to the number of fields", + )); + } + + // example: + // * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5] + // * ids = [5, 7] + // => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...] 
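+            // a slot whose type id is 7 is then resolved as fields[hash[7]] == fields[1]:
+            // the hash maps (possibly sparse) type ids back to dense indices into `fields`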
+ let mut hash = [0; 127]; + + for (pos, &id) in ids.iter().enumerate() { + if !(0..=127).contains(&id) { + return Err(Error::oos( + "In a union, when the ids are set, every id must belong to [0, 128[", + )); + } + hash[id as usize] = pos; + } + + types.iter().try_for_each(|&type_| { + if type_ < 0 { + return Err(Error::oos("In a union, when the ids are set, every type must be >= 0")); + } + let id = hash[type_ as usize]; + if id >= fields.len() { + Err(Error::oos("In a union, when the ids are set, each id must be smaller than the number of fields.")) + } else { + Ok(()) + } + })?; + + Some(hash) + } else { + // Safety: every type in types is smaller than number of fields + let mut is_valid = true; + for &type_ in types.iter() { + if type_ < 0 || type_ >= number_of_fields { + is_valid = false + } + } + if !is_valid { + return Err(Error::oos( + "Every type in `types` must be larger than 0 and smaller than the number of fields.", + )); + } + + None + }; + + Ok(Self { + data_type, + map, + fields, + offsets, + types, + offset: 0, + }) + } + + /// Returns a new [`UnionArray`]. + /// # Panics + /// This function panics iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `data_type`'s children's length + /// * any of the values's data type is different from its corresponding children' data type + pub fn new( + data_type: DataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> Self { + Self::try_new(data_type, types, fields, offsets).unwrap() + } + + /// Creates a new null [`UnionArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + if let DataType::Union(f, _, mode) = &data_type { + let fields = f + .iter() + .map(|x| new_null_array(x.data_type().clone(), length)) + .collect(); + + let offsets = if mode.is_sparse() { + None + } else { + Some((0..length as i32).collect::>().into()) + }; + + // all from the same field + let types = vec![0i8; length].into(); + + Self::new(data_type, types, fields, offsets) + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } + + /// Creates a new empty [`UnionArray`]. + pub fn new_empty(data_type: DataType) -> Self { + if let DataType::Union(f, _, mode) = data_type.to_logical_type() { + let fields = f + .iter() + .map(|x| new_empty_array(x.data_type().clone())) + .collect(); + + let offsets = if mode.is_sparse() { + None + } else { + Some(Buffer::default()) + }; + + Self { + data_type, + map: None, + fields, + offsets, + types: Buffer::new(), + offset: 0, + } + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } +} + +impl UnionArray { + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Panic + /// This function panics iff `offset + length >= self.len()`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
+ #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); + + self.types.slice_unchecked(offset, length); + if let Some(offsets) = self.offsets.as_mut() { + offsets.slice_unchecked(offset, length) + } + self.offset += offset; + } + + impl_sliced!(); + impl_into_array!(); +} + +impl UnionArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.types.len() + } + + /// The optional offsets. + pub fn offsets(&self) -> Option<&Buffer> { + self.offsets.as_ref() + } + + /// The fields. + pub fn fields(&self) -> &Vec> { + &self.fields + } + + /// The types. + pub fn types(&self) -> &Buffer { + &self.types + } + + #[inline] + unsafe fn field_slot_unchecked(&self, index: usize) -> usize { + self.offsets() + .as_ref() + .map(|x| *x.get_unchecked(index) as usize) + .unwrap_or(index + self.offset) + } + + /// Returns the index and slot of the field to select from `self.fields`. + #[inline] + pub fn index(&self, index: usize) -> (usize, usize) { + assert!(index < self.len()); + unsafe { self.index_unchecked(index) } + } + + /// Returns the index and slot of the field to select from `self.fields`. + /// The first value is guaranteed to be `< self.fields().len()` + /// # Safety + /// This function is safe iff `index < self.len`. + #[inline] + pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) { + debug_assert!(index < self.len()); + // Safety: assumption of the function + let type_ = unsafe { *self.types.get_unchecked(index) }; + // Safety: assumption of the struct + let type_ = self + .map + .as_ref() + .map(|map| unsafe { *map.get_unchecked(type_ as usize) }) + .unwrap_or(type_ as usize); + // Safety: assumption of the function + let index = self.field_slot_unchecked(index); + (type_, index) + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Panics + /// iff `index >= self.len()` + pub fn value(&self, index: usize) -> Box { + assert!(index < self.len()); + unsafe { self.value_unchecked(index) } + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Safety + /// This function is safe iff `i < self.len`. + pub unsafe fn value_unchecked(&self, index: usize) -> Box { + debug_assert!(index < self.len()); + let (type_, index) = self.index_unchecked(index); + // Safety: assumption of the struct + debug_assert!(type_ < self.fields.len()); + let field = self.fields.get_unchecked(type_).as_ref(); + new_scalar(field, index) + } +} + +impl Array for UnionArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + None + } + + fn with_validity(&self, _: Option) -> Box { + panic!("cannot set validity of a union array") + } +} + +impl UnionArray { + fn try_get_all(data_type: &DataType) -> Result { + match data_type.to_logical_type() { + DataType::Union(fields, ids, mode) => { + Ok((fields, ids.as_ref().map(|x| x.as_ref()), *mode)) + }, + _ => Err(Error::oos( + "The UnionArray requires a logical type of DataType::Union", + )), + } + } + + fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, UnionMode) { + Self::try_get_all(data_type).unwrap() + } + + /// Returns all fields from [`DataType::Union`]. + /// # Panic + /// Panics iff `data_type`'s logical type is not [`DataType::Union`]. + pub fn get_fields(data_type: &DataType) -> &[Field] { + Self::get_all(data_type).0 + } + + /// Returns whether the [`DataType::Union`] is sparse or not. + /// # Panic + /// Panics iff `data_type`'s logical type is not [`DataType::Union`]. 
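+    ///
+    /// # Example
+    /// A minimal sketch (arrow2-style paths as in the other examples of this crate):
+    /// ```
+    /// use arrow2::array::UnionArray;
+    /// use arrow2::datatypes::{DataType, Field, UnionMode};
+    ///
+    /// let fields = vec![Field::new("a", DataType::Int32, true)];
+    /// let data_type = DataType::Union(fields, None, UnionMode::Sparse);
+    /// assert!(UnionArray::is_sparse(&data_type));
+    /// ```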
+ pub fn is_sparse(data_type: &DataType) -> bool { + Self::get_all(data_type).2.is_sparse() + } +} diff --git a/crates/nano-arrow/src/array/utf8/data.rs b/crates/nano-arrow/src/array/utf8/data.rs new file mode 100644 index 000000000000..16674c969372 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/data.rs @@ -0,0 +1,42 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, Utf8Array}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for Utf8Array { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type().clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.offsets().len_proxy()) + .buffers(vec![ + self.offsets.clone().into_inner().into(), + self.values.clone().into(), + ]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let buffers = data.buffers(); + + // Safety: ArrayData is valid + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: buffers[1].clone().into(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/utf8/ffi.rs b/crates/nano-arrow/src/array/utf8/ffi.rs new file mode 100644 index 000000000000..2129a85a6f8f --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/ffi.rs @@ -0,0 +1,62 @@ +use super::Utf8Array; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for Utf8Array { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for Utf8Array { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2)? 
}; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new_unchecked(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/utf8/fmt.rs b/crates/nano-arrow/src/array/utf8/fmt.rs new file mode 100644 index 000000000000..4466444ffe3b --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/fmt.rs @@ -0,0 +1,23 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::Utf8Array; +use crate::offset::Offset; + +pub fn write_value(array: &Utf8Array, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for Utf8Array { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeUtf8Array" + } else { + "Utf8Array" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/utf8/from.rs b/crates/nano-arrow/src/array/utf8/from.rs new file mode 100644 index 000000000000..c1dcaf09b10d --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/from.rs @@ -0,0 +1,11 @@ +use std::iter::FromIterator; + +use super::{MutableUtf8Array, Utf8Array}; +use crate::offset::Offset; + +impl> FromIterator> for Utf8Array { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableUtf8Array::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/utf8/iterator.rs b/crates/nano-arrow/src/array/utf8/iterator.rs new file mode 100644 index 000000000000..262b98c10d79 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/iterator.rs @@ -0,0 +1,79 @@ +use super::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for Utf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`Utf8Array`]. +pub type Utf8ValuesIter<'a, O> = ArrayValuesIter<'a, Utf8Array>; + +impl<'a, O: Offset> IntoIterator for &'a Utf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, Utf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`MutableUtf8ValuesArray`]. 
+pub type MutableUtf8ValuesIter<'a, O> = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, MutableUtf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8ValuesArray { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8ValuesArray { + type Item = &'a str; + type IntoIter = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/utf8/mod.rs b/crates/nano-arrow/src/array/utf8/mod.rs new file mode 100644 index 000000000000..bae0169224e9 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mod.rs @@ -0,0 +1,546 @@ +use either::Either; + +use super::specification::{try_check_offsets_bounds, try_check_utf8}; +use super::{Array, GenericBinaryArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +#[cfg(feature = "arrow_rs")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; +mod mutable_values; +pub use iterator::*; +pub use mutable::*; +pub use mutable_values::MutableUtf8ValuesArray; + +// Auxiliary struct to allow presenting &str as [u8] to a generic function +pub(super) struct StrAsBytes

(P); +impl> AsRef<[u8]> for StrAsBytes { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.0.as_ref().as_bytes() + } +} + +/// A [`Utf8Array`] is arrow's semantic equivalent of an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +/// # Example +/// ``` +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// use arrow2::array::Utf8Array; +/// # fn main() { +/// let array = Utf8Array::::from([Some("hi"), None, Some("there")]); +/// assert_eq!(array.value(0), "hi"); +/// assert_eq!(array.iter().collect::>(), vec![Some("hi"), None, Some("there")]); +/// assert_eq!(array.values_iter().collect::>(), vec!["hi", "", "there"]); +/// // the underlying representation +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// assert_eq!(array.values(), &Buffer::from(b"hithere".to_vec())); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 2 + 5])); +/// # } +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutives `offsets` casted (`as`) to `usize` are valid slices of `values`. +/// * A slice of `values` taken from two consecutives `offsets` is valid `utf8`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct Utf8Array { + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +// constructors +impl Utf8Array { + /// Returns a [`Utf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_utf8(&offsets, &values)?; + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Returns a [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_values_iter`]. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Returns a new [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_iter`]. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. 
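+    ///
+    /// # Example
+    /// A minimal sketch, mirroring the struct-level example above:
+    /// ```
+    /// use arrow2::array::Utf8Array;
+    ///
+    /// let array = Utf8Array::<i32>::from([Some("hi"), None, Some("there")]);
+    /// assert_eq!(array.get(0), Some("hi"));
+    /// assert_eq!(array.get(1), None);
+    /// ```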
+ pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableUtf8Array::::from(slice).into() + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, Utf8ValuesIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> Utf8ValuesIter { + Utf8ValuesIter::new(self) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&str> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`DataType`] of this array. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the values of this [`Utf8Array`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`Utf8Array`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)` + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
+ pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, OffsetsBuffer, Buffer, Option) { + let Self { + data_type, + offsets, + values, + validity, + } = self; + (data_type, offsets, values, validity) + } + + /// Try to convert this `Utf8Array` to a `MutableUtf8Array` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // Safety: invariants are preserved + Left(bitmap) => Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + self.offsets, + self.values, + Some(bitmap), + ) + }), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets, + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Left(values), Right(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets.into(), + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Left(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets, + values.into(), + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked( + self.data_type, + offsets, + values, + Some(mutable_bitmap), + ) + }), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(unsafe { Utf8Array::new_unchecked(self.data_type, offsets, values, None) }) + }, + (Left(values), Right(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets.into(), values, None) + }), + (Right(values), Left(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets, values.into(), None) + }), + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked(self.data_type, offsets, values, None) + }), + } + } + } + + /// Returns a new empty [`Utf8Array`]. + /// + /// The array is guaranteed to have no elements nor validity. + #[inline] + pub fn new_empty(data_type: DataType) -> Self { + unsafe { Self::new_unchecked(data_type, OffsetsBuffer::new(), Buffer::new(), None) } + } + + /// Returns a new [`Utf8Array`] whose all slots are null / `None`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a default [`DataType`] of this array, which depends on the generic parameter `O`: `DataType::Utf8` or `DataType::LargeUtf8` + pub fn default_data_type() -> DataType { + if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + } + } + + /// Creates a new [`Utf8Array`] without checking for offsets monotinicity nor utf8-validity + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. 
+ /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is unsound iff: + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn try_new_unchecked( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "BinaryArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`Utf8Array`]. + /// # Panics + /// This function panics iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Creates a new [`Utf8Array`] without checking for offsets monotinicity. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is unsound iff: + /// * the offsets are not monotonically increasing + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new_unchecked(data_type, offsets, values, validity).unwrap() + } + + /// Returns a (non-null) [`Utf8Array`] created from a [`TrustedLen`] of `&str`. + /// # Implementation + /// This function is `O(N)` + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableUtf8Array::::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`Utf8Array`] from a [`Iterator`] of `&str`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8Array::::from_iter_values(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + MutableUtf8Array::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. 
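+    ///
+    /// # Example
+    /// A minimal sketch; it assumes, as in upstream arrow2, that slice iterators
+    /// (and `Cloned` adapters over them) implement this crate's [`TrustedLen`]:
+    /// ```
+    /// use arrow2::array::Utf8Array;
+    ///
+    /// let items = [Some("a"), None, Some("c")];
+    /// let array: Utf8Array<i32> = Utf8Array::from_trusted_len_iter(items.iter().cloned());
+    /// assert_eq!(array.len(), 3);
+    /// assert_eq!(array.get(1), None);
+    /// ```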
+ #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + MutableUtf8Array::::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`Utf8Array`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`Utf8Array`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter(iter).map(|x| x.into()) + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } +} + +impl Array for Utf8Array { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for Utf8Array { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} + +impl Default for Utf8Array { + fn default() -> Self { + let data_type = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + Utf8Array::new(data_type, Default::default(), Default::default(), None) + } +} diff --git a/crates/nano-arrow/src/array/utf8/mutable.rs b/crates/nano-arrow/src/array/utf8/mutable.rs new file mode 100644 index 000000000000..3fc47b3eae1d --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mutable.rs @@ -0,0 +1,549 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{MutableUtf8ValuesArray, MutableUtf8ValuesIter, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8ValuesArray`] in that it can build nullable [`Utf8Array`]s. +#[derive(Debug, Clone)] +pub struct MutableUtf8Array { + values: MutableUtf8ValuesArray, + validity: Option, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8Array) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: Utf8Array = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableUtf8Array { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8Array { + /// Initializes a new empty [`MutableUtf8Array`]. 
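+    ///
+    /// # Example
+    /// A minimal sketch of the push-based API:
+    /// ```
+    /// use arrow2::array::MutableUtf8Array;
+    ///
+    /// let mut array = MutableUtf8Array::<i32>::new();
+    /// array.push(Some("hello"));
+    /// array.push::<&str>(None);
+    /// assert_eq!(array.len(), 2);
+    /// assert_eq!(array.value(0), "hello");
+    /// ```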
+ pub fn new() -> Self { + Self { + values: Default::default(), + validity: None, + } + } + + /// Returns a [`MutableUtf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Result { + let values = MutableUtf8ValuesArray::try_new(data_type, offsets, values)?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity's length must be equal to the number of values", + )); + } + + Ok(Self { values, validity }) + } + + /// Create a [`MutableUtf8Array`] out of low-end APIs. + /// # Safety + /// The caller must ensure that every value between offsets is a valid utf8. + /// # Panics + /// This function panics iff: + /// * The `offsets` and `values` are inconsistent + /// * The validity is not `None` and its length is different from `offsets`'s length minus one. + pub unsafe fn new_unchecked( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Self { + let values = MutableUtf8ValuesArray::new_unchecked(data_type, offsets, values); + if let Some(ref validity) = validity { + assert_eq!(values.len(), validity.len()); + } + Self { values, validity } + } + + /// Creates a new [`MutableUtf8Array`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_data_type() -> DataType { + Utf8Array::::default_data_type() + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableUtf8ValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn capacity(&self) -> usize { + self.values.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. 
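+    ///
+    /// # Example
+    /// A minimal sketch; with this builder a null slot reads back as the empty
+    /// string, since validity is ignored here:
+    /// ```
+    /// use arrow2::array::MutableUtf8Array;
+    ///
+    /// let array: MutableUtf8Array<i32> = vec![Some("polars"), None].into_iter().collect();
+    /// assert_eq!(array.value(0), "polars");
+    /// assert_eq!(array.value(1), "");
+    /// ```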
+ #[inline] + pub fn value(&self, i: usize) -> &str { + self.values.value(i) + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + self.values.value_unchecked(i) + } + + /// Pop the last entry from [`MutableUtf8Array`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, MutableUtf8ValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: Utf8Array = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableUtf8Array`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Extract the low-end APIs from the [`MutableUtf8Array`]. + pub fn into_data(self) -> (DataType, Offsets, Vec, Option) { + let (data_type, offsets, values) = self.values.into_inner(); + (data_type, offsets, values, self.validity) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> MutableUtf8ValuesIter { + self.values.iter() + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } +} + +impl MutableUtf8Array { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. 
+ pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } +} + +impl MutableArray for MutableUtf8Array { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeUtf8 + } else { + &DataType::Utf8 + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableUtf8Array { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableUtf8Array { + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional + /// values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. 
+ #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + let (validity, offsets, values) = trusted_len_unzip(iterator); + + // soundness: P is `str` + Self::new_unchecked(Self::default_data_type(), offsets, values, validity) + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length of `&str`. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + MutableUtf8ValuesArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`MutableUtf8Array`] from a [`TrustedLen`] of `&str`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). + fn try_from_iter, I: IntoIterator>>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + let iterator = iterator.map(|x| x.map(|x| x.map(StrAsBytes))); + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + // soundness: P is `str` + Ok(Self::new_unchecked( + Self::default_data_type(), + offsets, + values, + validity, + )) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from a [`Iterator`] of `&str`. 
+ pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8ValuesArray::from_iter(iterator).into() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableUtf8Array { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableUtf8Array { + fn try_extend>>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableUtf8Array { + #[inline] + fn try_push(&mut self, value: Option) -> Result<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableUtf8Array { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableUtf8Array { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/utf8/mutable_values.rs b/crates/nano-arrow/src/array/utf8/mutable_values.rs new file mode 100644 index 000000000000..f500bb79877f --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mutable_values.rs @@ -0,0 +1,407 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{MutableUtf8Array, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::specification::{try_check_offsets_bounds, try_check_utf8}; +use crate::array::{Array, ArrayValuesIter, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8Array`] in that it builds non-null [`Utf8Array`]. +#[derive(Debug, Clone)] +pub struct MutableUtf8ValuesArray { + data_type: DataType, + offsets: Offsets, + values: Vec, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // Safety: + // `MutableUtf8ValuesArray` has the same invariants as `Utf8Array` and thus + // `Utf8Array` can be safely created from `MutableUtf8ValuesArray` without checks. + unsafe { + Utf8Array::::new_unchecked( + other.data_type, + other.offsets.into(), + other.values.into(), + None, + ) + } + } +} + +impl From> for MutableUtf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // Safety: + // `MutableUtf8ValuesArray` has the same invariants as `MutableUtf8Array` + unsafe { + MutableUtf8Array::::new_unchecked(other.data_type, other.offsets, other.values, None) + } + } +} + +impl Default for MutableUtf8ValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8ValuesArray { + /// Returns an empty [`MutableUtf8ValuesArray`]. 
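Taken together, the `MutableUtf8Array` additions in `mutable.rs` support push/pop of optional strings and an O(1) freeze into `Utf8Array`. A small usage sketch under the same assumptions as above (crate path `nano_arrow`, the `O: Offset` generic restored as `<i32>`):

```rust
use nano_arrow::array::{MutableUtf8Array, Utf8Array};

fn main() {
    let mut array = MutableUtf8Array::<i32>::new();
    array.push(Some("foo"));
    array.push::<&str>(None); // null entry; the validity bitmap is created here
    array.push(Some("bar"));

    assert_eq!(array.value(0), "foo");
    // `pop` removes the last slot and returns the value only if it was valid.
    assert_eq!(array.pop().as_deref(), Some("bar"));

    // Freezing reuses the offsets/values buffers; no data is copied.
    let frozen: Utf8Array<i32> = array.into();
    assert_eq!(frozen.len(), 2);
}
```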
+ pub fn new() -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new(data_type: DataType, offsets: Offsets, values: Vec) -> Result { + try_check_utf8(&offsets, &values)?; + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + }) + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Panic + /// This function does not panic iff: + /// * The last offset is equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is safe iff: + /// * the offsets are monotonically increasing + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked(data_type: DataType, offsets: Offsets, values: Vec) -> Self { + try_check_offsets_bounds(&offsets, values.len()) + .expect("The length of the values must be equal to the last offset value"); + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + panic!("MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8") + } + + Self { + data_type, + offsets, + values, + } + } + + /// Returns the default [`DataType`] of this container: [`DataType::Utf8`] or [`DataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_data_type() -> DataType { + Utf8Array::::default_data_type() + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional + 1); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. 
+ #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableUtf8ValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + // Safety: utf8 is validated on initialization + Some(unsafe { String::from_utf8_unchecked(value) }) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns an iterator of `&str` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableUtf8ValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableUtf8ValuesArray`]. + pub fn into_inner(self) -> (DataType, Offsets, Vec) { + (self.data_type, self.offsets, self.values) + } +} + +impl MutableArray for MutableUtf8ValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>("") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator
for MutableUtf8ValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter().map(StrAsBytes)); + // soundness: T: AsRef and offsets are monotonically increasing + unsafe { Self::new_unchecked(Self::default_data_type(), offsets, values) } + } +} + +impl MutableUtf8ValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8ValuesArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + let (offsets, values) = trusted_len_values_iter(iterator); + + // soundness: P is `str` and offsets are monotonically increasing + Self::new_unchecked(Self::default_data_type(), offsets, values) + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). 
+ pub fn try_from_iter, I: IntoIterator>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableUtf8ValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter( + &mut self.offsets, + &mut self.values, + iter.into_iter().map(StrAsBytes), + ); + } +} + +impl> TryExtend for MutableUtf8ValuesArray { + fn try_extend>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableUtf8ValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> Result<()> { + let bytes = value.as_ref().as_bytes(); + self.values.extend_from_slice(bytes); + self.offsets.try_push(bytes.len()) + } +} + +impl TryExtendFromSelf for MutableUtf8ValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/nano-arrow/src/bitmap/assign_ops.rs b/crates/nano-arrow/src/bitmap/assign_ops.rs new file mode 100644 index 000000000000..b4d3702c69eb --- /dev/null +++ b/crates/nano-arrow/src/bitmap/assign_ops.rs @@ -0,0 +1,190 @@ +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use crate::bitmap::{Bitmap, MutableBitmap}; + +/// Applies a function to every bit of this [`MutableBitmap`] in chunks +/// +/// This function can be for operations like `!` to a [`MutableBitmap`]. 
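`MutableUtf8ValuesArray` is the null-free counterpart: it carries no validity bitmap, so `push_null` writes an empty string instead. A sketch of the intended use, with the same path and generic assumptions as the earlier examples:

```rust
use nano_arrow::array::{MutableUtf8ValuesArray, Utf8Array};

fn main() {
    let mut values = MutableUtf8ValuesArray::<i32>::with_capacity(4);
    values.push("a");
    values.push("bc");
    values.extend(["def", "g"]); // bulk path via `extend_from_values_iter`

    assert_eq!(values.len(), 4);
    assert_eq!(values.value(2), "def");
    assert_eq!(values.pop().as_deref(), Some("g"));

    // Conversion to the immutable array skips UTF-8 and offset re-validation,
    // since the builder already upholds the same invariants.
    let array: Utf8Array<i32> = values.into();
    assert_eq!(array.len(), 3);
}
```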
+pub fn unary_assign T>(bitmap: &mut MutableBitmap, op: F) { + let mut chunks = bitmap.bitchunks_exact_mut::(); + + chunks.by_ref().for_each(|chunk| { + let new_chunk: T = match (chunk as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk); + chunk.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + if chunks.remainder().is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder)).to_ne_bytes(); + + let len = chunks.remainder().len(); + chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +impl std::ops::Not for MutableBitmap { + type Output = Self; + + #[inline] + fn not(mut self) -> Self { + unary_assign(&mut self, |a: u64| !a); + self + } +} + +fn binary_assign_impl(lhs: &mut MutableBitmap, mut rhs: I, op: F) +where + I: BitChunkIterExact, + T: BitChunk, + F: Fn(T, T) -> T, +{ + let mut lhs_chunks = lhs.bitchunks_exact_mut::(); + + lhs_chunks + .by_ref() + .zip(rhs.by_ref()) + .for_each(|(lhs, rhs)| { + let new_chunk: T = match (lhs as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk, rhs); + lhs.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs.remainder(); + if rem_lhs.is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + lhs_chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder), rem_rhs).to_ne_bytes(); + + let len = lhs_chunks.remainder().len(); + lhs_chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +/// Apply a bitwise binary operation to a [`MutableBitmap`]. +/// +/// This function can be used for operations like `&=` to a [`MutableBitmap`]. 
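`unary_assign` mutates a `MutableBitmap` in place one `BitChunk` at a time, and is what the `Not` impl above uses. A sketch, assuming the `pub use assign_ops::*;` re-export added in `bitmap/mod.rs` further down, the usual `nano_arrow` crate path, and a `T: BitChunk` generic on `unary_assign` (stripped in this rendering) with `u64` as the chunk type:

```rust
use nano_arrow::bitmap::{unary_assign, MutableBitmap};

fn main() {
    let mut bitmap = MutableBitmap::from_len_zeroed(10);

    // Flip every bit; full 64-bit chunks first, then the byte remainder.
    unary_assign(&mut bitmap, |chunk: u64| !chunk);
    assert_eq!(bitmap.unset_bits(), 0);

    // The same kernel via `std::ops::Not`, which consumes and returns the bitmap.
    let cleared = !bitmap;
    assert_eq!(cleared.unset_bits(), 10);
}
```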
+/// # Panics +/// This function panics iff `lhs.len() != `rhs.len()` +pub fn binary_assign(lhs: &mut MutableBitmap, rhs: &Bitmap, op: F) +where + F: Fn(T, T) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + + let (slice, offset, length) = rhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + binary_assign_impl(lhs, iter, op) + } else { + let rhs_chunks = rhs.chunks::(); + binary_assign_impl(lhs, rhs_chunks, op) + } +} + +#[inline] +/// Compute bitwise OR operation in-place +fn or_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), true); + } else if rhs.unset_bits() == rhs.len() { + // bitmap remains + } else { + binary_assign(lhs, rhs, |x: T, y| x | y) + } +} + +impl<'a> std::ops::BitOrAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitor_assign(&mut self, rhs: &'a Bitmap) { + or_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitOr<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitor(mut self, rhs: &'a Bitmap) -> Self { + or_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise `&` between `lhs` and `rhs`, assigning it to `lhs` +fn and_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + // bitmap remains + } + if rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), false); + } else { + binary_assign(lhs, rhs, |x: T, y| x & y) + } +} + +impl<'a> std::ops::BitAndAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitand_assign(&mut self, rhs: &'a Bitmap) { + and_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitAnd<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitand(mut self, rhs: &'a Bitmap) -> Self { + and_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise XOR operation +fn xor_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + binary_assign(lhs, rhs, |x: T, y| x ^ y) +} + +impl<'a> std::ops::BitXorAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitxor_assign(&mut self, rhs: &'a Bitmap) { + xor_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitXor<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitxor(mut self, rhs: &'a Bitmap) -> Self { + xor_assign::(&mut self, rhs); + self + } +} diff --git a/crates/nano-arrow/src/bitmap/bitmap_ops.rs b/crates/nano-arrow/src/bitmap/bitmap_ops.rs new file mode 100644 index 000000000000..c83e63255093 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/bitmap_ops.rs @@ -0,0 +1,268 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use super::Bitmap; +use crate::bitmap::MutableBitmap; +use crate::trusted_len::TrustedLen; + +/// Creates a [Vec] from an [`Iterator`] of [`BitChunk`]. +/// # Safety +/// The iterator must be [`TrustedLen`]. 
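The `BitOr`/`BitAnd`/`BitXor` impls let a `MutableBitmap` be combined with an immutable `Bitmap` while reusing the left-hand buffer. An illustrative sketch under the same crate-path assumption (the all-set and all-unset fast paths of `or_assign`/`and_assign` are not exercised here):

```rust
use nano_arrow::bitmap::{Bitmap, MutableBitmap};

fn main() {
    let rhs = Bitmap::from([true, false, true, false]);

    let ored = MutableBitmap::from([true, true, false, false]) | &rhs;
    assert_eq!(ored.iter().collect::<Vec<_>>(), vec![true, true, true, false]);

    let anded = MutableBitmap::from([true, true, false, false]) & &rhs;
    assert_eq!(anded.iter().collect::<Vec<_>>(), vec![true, false, false, false]);

    let xored = MutableBitmap::from([true, true, false, false]) ^ &rhs;
    assert_eq!(xored.iter().collect::<Vec<_>>(), vec![false, true, true, false]);
}
```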
+pub unsafe fn from_chunk_iter_unchecked>( + iterator: I, +) -> Vec { + let (_, upper) = iterator.size_hint(); + let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); + let len = upper * std::mem::size_of::(); + + let mut buffer = Vec::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let bytes = item.to_ne_bytes(); + for i in 0..std::mem::size_of::() { + std::ptr::write(dst, bytes[i]); + dst = dst.add(1); + } + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + buffer +} + +/// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. +pub fn chunk_iter_to_vec>(iter: I) -> Vec { + unsafe { from_chunk_iter_unchecked(iter) } +} + +/// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`]. +pub fn quaternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, a4: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + assert_eq!(a1.len(), a4.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + let a4_chunks = a4.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + let rem_a4 = a4_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .zip(a4_chunks) + .map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4)); + let buffer = + chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3, rem_a4)))); + + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to three inputs and return the result as a [`Bitmap`]. +pub fn ternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .map(|((a1, a2), a3)| op(a1, a2, a3)); + + let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3)))); + + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to two inputs and return the result as a [`Bitmap`]. +pub fn binary(lhs: &Bitmap, rhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64) -> u64, +{ + assert_eq!(lhs.len(), rhs.len()); + let lhs_chunks = lhs.chunks(); + let rhs_chunks = rhs.chunks(); + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs_chunks.remainder(); + + let chunks = lhs_chunks + .zip(rhs_chunks) + .map(|(left, right)| op(left, right)); + + let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_lhs, rem_rhs)))); + + let length = lhs.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +fn unary_impl(iter: I, op: F, length: usize) -> Bitmap +where + I: BitChunkIterExact, + F: Fn(u64) -> u64, +{ + let rem = op(iter.remainder()); + + let iterator = iter.map(op).chain(std::iter::once(rem)); + + let buffer = chunk_iter_to_vec(iterator); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to one input and return the result as a [`Bitmap`]. 
+pub fn unary(lhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64) -> u64, +{ + let (slice, offset, length) = lhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + unary_impl(iter, op, lhs.len()) + } else { + let iter = lhs.chunks::(); + unary_impl(iter, op, lhs.len()) + } +} + +// create a new [`Bitmap`] semantically equal to ``bitmap`` but with an offset equal to ``offset`` +pub(crate) fn align(bitmap: &Bitmap, new_offset: usize) -> Bitmap { + let length = bitmap.len(); + + let bitmap: Bitmap = std::iter::repeat(false) + .take(new_offset) + .chain(bitmap.iter()) + .collect(); + + bitmap.sliced(new_offset, length) +} + +#[inline] +/// Compute bitwise AND operation +pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == lhs.len() || rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(lhs.len()) + } else { + binary(lhs, rhs, |x, y| x & y) + } +} + +#[inline] +/// Compute bitwise OR operation +pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == 0 || rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x | y) + } +} + +#[inline] +/// Compute bitwise XOR operation +pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + let lhs_nulls = lhs.unset_bits(); + let rhs_nulls = rhs.unset_bits(); + + // all false or all true + if lhs_nulls == rhs_nulls && rhs_nulls == rhs.len() || lhs_nulls == 0 && rhs_nulls == 0 { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(rhs.len()) + } + // all false and all true or vice versa + else if (lhs_nulls == 0 && rhs_nulls == rhs.len()) + || (lhs_nulls == lhs.len() && rhs_nulls == 0) + { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x ^ y) + } +} + +fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool { + if lhs.len() != rhs.len() { + return false; + } + + let mut lhs_chunks = lhs.chunks::(); + let mut rhs_chunks = rhs.chunks::(); + + let equal_chunks = lhs_chunks + .by_ref() + .zip(rhs_chunks.by_ref()) + .all(|(left, right)| left == right); + + if !equal_chunks { + return false; + } + let lhs_remainder = lhs_chunks.remainder_iter(); + let rhs_remainder = rhs_chunks.remainder_iter(); + lhs_remainder.zip(rhs_remainder).all(|(x, y)| x == y) +} + +impl PartialEq for Bitmap { + fn eq(&self, other: &Self) -> bool { + eq(self, other) + } +} + +impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitor(self, rhs: &'b Bitmap) -> Bitmap { + or(self, rhs) + } +} + +impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitand(self, rhs: &'b Bitmap) -> Bitmap { + and(self, rhs) + } +} + +impl<'a, 'b> BitXor<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitxor(self, rhs: &'b Bitmap) -> Bitmap { + xor(self, rhs) + } +} + +impl Not for &Bitmap { + type Output = Bitmap; + + fn not(self) -> Bitmap { + unary(self, |a| !a) + } +} diff --git a/crates/nano-arrow/src/bitmap/bitmask.rs b/crates/nano-arrow/src/bitmap/bitmask.rs new file mode 100644 index 000000000000..45d4b960c92a --- /dev/null +++ b/crates/nano-arrow/src/bitmap/bitmask.rs @@ -0,0 +1,310 @@ +#[cfg(feature = "simd")] +use std::simd::ToBitMask; + +#[cfg(feature = "simd")] +use num_traits::AsPrimitive; + +use crate::bitmap::Bitmap; + +/// Returns 
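`bitmap_ops.rs` provides chunked kernels over whole `u64` words; the operator impls at the bottom of the file wire `&Bitmap & &Bitmap` and friends to `and`/`or`/`xor`, while `binary` accepts arbitrary word-level closures. A sketch with the usual `nano_arrow` path assumption:

```rust
use nano_arrow::bitmap::{binary, Bitmap};

fn main() {
    let a = Bitmap::from([true, true, false, false]);
    let b = Bitmap::from([true, false, true, false]);

    assert_eq!((&a & &b).iter().collect::<Vec<_>>(), vec![true, false, false, false]);
    assert_eq!((&a | &b).iter().collect::<Vec<_>>(), vec![true, true, true, false]);
    assert_eq!((&a ^ &b).iter().collect::<Vec<_>>(), vec![false, true, true, false]);
    assert_eq!((!&a).unset_bits(), 2);

    // Any 64-bit kernel works with `binary`, e.g. "a AND NOT b".
    let and_not = binary(&a, &b, |x, y| x & !y);
    assert_eq!(and_not.iter().collect::<Vec<_>>(), vec![false, true, false, false]);
}
```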
the nth set bit in w, if n+1 bits are set. The indexing is +/// zero-based, nth_set_bit_u32(w, 0) returns the least significant set bit in w. +fn nth_set_bit_u32(w: u32, n: u32) -> Option { + // If we have BMI2's PDEP available, we use it. It takes the lower order + // bits of the first argument and spreads it along its second argument + // where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h. + // We use this by setting the first argument to 1 << n, which means the + // first n-1 zero bits of it will spread to the first n-1 one bits of w, + // after which the one bit will exactly get copied to the nth one bit of w. + #[cfg(target_feature = "bmi2")] + { + if n >= 32 { + return None; + } + + let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) }; + if nth_set_bit == 0 { + return None; + } + + Some(nth_set_bit.trailing_zeros()) + } + + #[cfg(not(target_feature = "bmi2"))] + { + // Each block of 2/4/8/16 bits contains how many set bits there are in that block. + let set_per_2 = w - ((w >> 1) & 0x55555555); + let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333); + let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f; + let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff; + let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff; + if n >= set_per_32 { + return None; + } + + let mut idx = 0; + let mut n = n; + let next16 = set_per_16 & 0xff; + if n >= next16 { + n -= next16; + idx += 16; + } + let next8 = (set_per_8 >> idx) & 0xff; + if n >= next8 { + n -= next8; + idx += 8; + } + let next4 = (set_per_4 >> idx) & 0b1111; + if n >= next4 { + n -= next4; + idx += 4; + } + let next2 = (set_per_2 >> idx) & 0b11; + if n >= next2 { + n -= next2; + idx += 2; + } + let next1 = (w >> idx) & 0b1; + if n >= next1 { + idx += 1; + } + Some(idx) + } +} + +// Loads a u64 from the given byteslice, as if it were padded with zeros. +fn load_padded_le_u64(bytes: &[u8]) -> u64 { + let len = bytes.len(); + if len >= 8 { + return u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + } + + if len >= 4 { + let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap()); + return (lo as u64) | ((hi as u64) << (8 * (len - 4))); + } + + if len == 0 { + return 0; + } + + let lo = bytes[0] as u64; + let mid = (bytes[len / 2] as u64) << (8 * (len / 2)); + let hi = (bytes[len - 1] as u64) << (8 * (len - 1)); + lo | mid | hi +} + +#[derive(Default, Clone)] +pub struct BitMask<'a> { + bytes: &'a [u8], + offset: usize, + len: usize, +} + +impl<'a> BitMask<'a> { + pub fn from_bitmap(bitmap: &'a Bitmap) -> Self { + let (bytes, offset, len) = bitmap.as_slice(); + // Check length so we can use unsafe access in our get. + assert!(bytes.len() * 8 >= len + offset); + Self { bytes, offset, len } + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn split_at(&self, idx: usize) -> (Self, Self) { + assert!(idx <= self.len); + unsafe { self.split_at_unchecked(idx) } + } + + /// # Safety + /// The index must be in-bounds. 
+ #[inline] + pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) { + debug_assert!(idx <= self.len); + let left = Self { len: idx, ..*self }; + let right = Self { + len: self.len - idx, + offset: self.offset + idx, + ..*self + }; + (left, right) + } + + #[cfg(feature = "simd")] + #[inline] + pub fn get_simd(&self, idx: usize) -> T + where + T: ToBitMask, + ::BitMask: Copy + 'static, + u64: AsPrimitive<::BitMask>, + { + // We don't support 64-lane masks because then we couldn't load our + // bitwise mask as a u64 and then do the byteshift on it. + + let lanes = std::mem::size_of::() * 8; + assert!(lanes < 64); + + let start_byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + if idx + lanes <= self.len { + // SAFETY: fast path, we know this is completely in-bounds. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + T::from_bitmask((mask >> byte_shift).as_()) + } else if idx < self.len { + // SAFETY: we know that at least the first byte is in-bounds. + // This is partially out of bounds, we have to do extra masking. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + let num_out_of_bounds = idx + lanes - self.len; + let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift); + T::from_bitmask(shifted.as_()) + } else { + T::from_bitmask((0u64).as_()) + } + } + + #[inline] + pub fn get_u32(&self, idx: usize) -> u32 { + let start_byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + if idx + 32 <= self.len { + // SAFETY: fast path, we know this is completely in-bounds. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + (mask >> byte_shift) as u32 + } else if idx < self.len { + // SAFETY: we know that at least the first byte is in-bounds. + // This is partially out of bounds, we have to do extra masking. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1; + ((mask >> byte_shift) as u32) & out_of_bounds_mask + } else { + 0 + } + } + + /// Computes the index of the nth set bit after start. + /// + /// Both are zero-indexed, so nth_set_bit_idx(0, 0) finds the index of the + /// first bit set (which can be 0 as well). The returned index is absolute, + /// not relative to start. + pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option { + while start < self.len { + let next_u32_mask = self.get_u32(start); + if next_u32_mask == u32::MAX { + // Happy fast path for dense non-null section. + if n < 32 { + return Some(start + n); + } + n -= 32; + } else { + let ones = next_u32_mask.count_ones() as usize; + if n < ones { + let idx = unsafe { + // SAFETY: we know the nth bit is in the mask. + nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize + }; + return Some(start + idx); + } + n -= ones; + } + + start += 32; + } + + None + } + + /// Computes the index of the nth set bit before end, counting backwards. + /// + /// Both are zero-indexed, so nth_set_bit_idx_rev(0, len) finds the index of + /// the last bit set (which can be 0 as well). The returned index is + /// absolute (and starts at the beginning), not relative to end. + pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option { + while end > 0 { + // We want to find bits *before* end, so if end < 32 we must mask + // out the bits after the endth. 
+ let (u32_mask_start, u32_mask_mask) = if end >= 32 { + (end - 32, u32::MAX) + } else { + (0, (1 << end) - 1) + }; + let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask; + if next_u32_mask == u32::MAX { + // Happy fast path for dense non-null section. + if n < 32 { + return Some(end - 1 - n); + } + n -= 32; + } else { + let ones = next_u32_mask.count_ones() as usize; + if n < ones { + let rev_n = ones - 1 - n; + let idx = unsafe { + // SAFETY: we know the rev_nth bit is in the mask. + nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize + }; + return Some(u32_mask_start + idx); + } + n -= ones; + } + + end = u32_mask_start; + } + + None + } + + #[inline] + pub fn get(&self, idx: usize) -> bool { + let byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + + if idx < self.len { + // SAFETY: we know this is in-bounds. + let byte = unsafe { *self.bytes.get_unchecked(byte_idx) }; + (byte >> byte_shift) & 1 == 1 + } else { + false + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option { + for i in 0..32 { + if w & (1 << i) != 0 { + if n == 0 { + return Some(i); + } + n -= 1; + w ^= 1 << i; + } + } + None + } + + #[test] + fn test_nth_set_bit_u32() { + for n in 0..256 { + assert_eq!(nth_set_bit_u32(0, n), None); + } + + for i in 0..32 { + assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i)); + assert_eq!(nth_set_bit_u32(1 << i, 1), None); + } + + for i in 0..10000 { + let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32; + for i in 0..=32 { + assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set(rnd, i)); + } + } + } +} diff --git a/crates/nano-arrow/src/bitmap/immutable.rs b/crates/nano-arrow/src/bitmap/immutable.rs new file mode 100644 index 000000000000..c29ac7a41314 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/immutable.rs @@ -0,0 +1,471 @@ +use std::iter::FromIterator; +use std::ops::Deref; +use std::sync::Arc; + +use either::Either; + +use super::utils::{count_zeros, fmt, get_bit, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; +use super::{chunk_iter_to_vec, IntoIter, MutableBitmap}; +use crate::buffer::Bytes; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// An immutable container semantically equivalent to `Arc>` but represented as `Arc>` where +/// each boolean is represented as a single bit. 
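`BitMask` is a borrowed, offset-aware view used to scan validity cheaply; `nth_set_bit_idx` walks 32 bits at a time and delegates to the PDEP/popcount helper above. An illustrative sketch, assuming the crate path `nano_arrow` and that `Bitmap: FromIterator<bool>` from `immutable.rs` below is available:

```rust
use nano_arrow::bitmap::bitmask::BitMask;
use nano_arrow::bitmap::Bitmap;

fn main() {
    // Bits set at indices 1, 4, 5 and 9.
    let bitmap: Bitmap = (0..10).map(|i| [1, 4, 5, 9].contains(&i)).collect();
    let mask = BitMask::from_bitmap(&bitmap);

    assert!(mask.get(4));
    assert!(!mask.get(6));

    // Zero-indexed: the third set bit (n = 2) at or after position 0 is index 5.
    assert_eq!(mask.nth_set_bit_idx(2, 0), Some(5));
    // Counting backwards from the end, the second set bit (n = 1) is also index 5.
    assert_eq!(mask.nth_set_bit_idx_rev(1, mask.len()), Some(5));

    // Splitting produces two zero-copy views over the same bytes.
    let (head, tail) = mask.split_at(4);
    assert_eq!(head.len(), 4);
    assert_eq!(tail.nth_set_bit_idx(0, 0), Some(0)); // original index 4
}
```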
+/// +/// # Examples +/// ``` +/// use arrow2::bitmap::{Bitmap, MutableBitmap}; +/// +/// let bitmap = Bitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), ([0b00001101u8].as_ref(), 0, 5)); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "[0b___01101]".to_string()); +/// +/// // it supports copy-on-write semantics (to a `MutableBitmap`) +/// let bitmap: MutableBitmap = bitmap.into_mut().right().unwrap(); +/// assert_eq!(bitmap, MutableBitmap::from([true, false, true, true, false])); +/// +/// // slicing is 'O(1)' (data is shared) +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// let mut sliced = bitmap.clone(); +/// sliced.slice(1, 4); +/// assert_eq!(sliced.as_slice(), ([0b00001101u8].as_ref(), 1, 4)); // 1 here is the offset: +/// assert_eq!(format!("{:?}", sliced), "[0b___0110_]".to_string()); +/// // when sliced (or cloned), it is no longer possible to `into_mut`. +/// let same: Bitmap = sliced.into_mut().left().unwrap(); +/// ``` +#[derive(Clone)] +pub struct Bitmap { + bytes: Arc>, + // both are measured in bits. They are used to bound the bitmap to a region of Bytes. + offset: usize, + length: usize, + // this is a cache: it is computed on initialization + unset_bits: usize, +} + +impl std::fmt::Debug for Bitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (bytes, offset, len) = self.as_slice(); + fmt(bytes, offset, len, f) + } +} + +impl Default for Bitmap { + fn default() -> Self { + MutableBitmap::new().into() + } +} + +pub(super) fn check(bytes: &[u8], offset: usize, length: usize) -> Result<(), Error> { + if offset + length > bytes.len().saturating_mul(8) { + return Err(Error::InvalidArgumentError(format!( + "The offset + length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + offset + length, + bytes.len().saturating_mul(8) + ))); + } + Ok(()) +} + +impl Bitmap { + /// Initializes an empty [`Bitmap`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Initializes a new [`Bitmap`] from vector of bytes and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(bytes: Vec, length: usize) -> Result { + check(&bytes, 0, length)?; + let unset_bits = count_zeros(&bytes, 0, length); + Ok(Self { + length, + offset: 0, + bytes: Arc::new(bytes.into()), + unset_bits, + }) + } + + /// Returns the length of the [`Bitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`Bitmap`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns a new iterator of `bool` over this bitmap + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.bytes, self.offset, self.length) + } + + /// Returns an iterator over bits in bit chunks [`BitChunk`]. + /// + /// This iterator is useful to operate over multiple bits via e.g. bitwise. + pub fn chunks(&self) -> BitChunks { + BitChunks::new(&self.bytes, self.offset, self.length) + } + + /// Returns the byte slice of this [`Bitmap`]. + /// + /// The returned tuple contains: + /// * `.1`: The byte slice, truncated to the start of the first bit. 
So the start of the slice + /// is within the first 8 bits. + /// * `.2`: The start offset in bits on a range `0 <= offsets < 8`. + /// * `.3`: The length in number of bits. + #[inline] + pub fn as_slice(&self) -> (&[u8], usize, usize) { + let start = self.offset / 8; + let len = (self.offset % 8 + self.length).saturating_add(7) / 8; + ( + &self.bytes[start..start + len], + self.offset % 8, + self.length, + ) + } + + /// Returns the number of unset bits on this [`Bitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(1)` - the number of unset bits is computed when the bitmap is + /// created + pub const fn unset_bits(&self) -> usize { + self.unset_bits + } + + /// Returns the number of unset bits on this [`Bitmap`]. + #[inline] + #[deprecated(since = "0.13.0", note = "use `unset_bits` instead")] + pub fn null_count(&self) -> usize { + self.unset_bits + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(offset + length <= self.length); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + // first guard a no-op slice so that we don't do a bitcount + // if there isn't any data sliced + if !(offset == 0 && length == self.length) { + // count the smallest chunk + if length < self.length / 2 { + // count the null values in the slice + self.unset_bits = count_zeros(&self.bytes, self.offset + offset, length); + } else { + // subtract the null count of the chunks we slice off + let start_end = self.offset + offset + length; + let head_count = count_zeros(&self.bytes, self.offset, offset); + let tail_count = count_zeros(&self.bytes, start_end, self.length - length - offset); + self.unset_bits -= head_count + tail_count; + } + self.offset += offset; + self.length = length; + } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!(offset + length <= self.length); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + + /// Returns whether the bit at position `i` is set. + /// # Panics + /// Panics iff `i >= self.len()`. + #[inline] + pub fn get_bit(&self, i: usize) -> bool { + get_bit(&self.bytes, self.offset + i) + } + + /// Unsafely returns whether the bit at position `i` is set. + /// # Safety + /// Unsound iff `i >= self.len()`. 
+ #[inline] + pub unsafe fn get_bit_unchecked(&self, i: usize) -> bool { + get_bit_unchecked(&self.bytes, self.offset + i) + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn as_ptr(&self) -> *const u8 { + self.bytes.deref().as_ptr() + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn offset(&self) -> usize { + self.offset + } + + /// Converts this [`Bitmap`] to [`MutableBitmap`], returning itself if the conversion + /// is not possible + /// + /// This operation returns a [`MutableBitmap`] iff: + /// * this [`Bitmap`] is not an offsetted slice of another [`Bitmap`] + /// * this [`Bitmap`] has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * this [`Bitmap`] was not imported from the c data interface (FFI) + pub fn into_mut(mut self) -> Either { + match ( + self.offset, + Arc::get_mut(&mut self.bytes).and_then(|b| b.get_vec()), + ) { + (0, Some(v)) => { + let data = std::mem::take(v); + Either::Right(MutableBitmap::from_vec(data, self.length)) + }, + _ => Either::Left(self), + } + } + + /// Converts this [`Bitmap`] into a [`MutableBitmap`], cloning its internal + /// buffer if required (clone-on-write). + pub fn make_mut(self) -> MutableBitmap { + match self.into_mut() { + Either::Left(data) => { + if data.offset > 0 { + // re-align the bits (remove the offset) + let chunks = data.chunks::(); + let remainder = chunks.remainder(); + let vec = chunk_iter_to_vec(chunks.chain(std::iter::once(remainder))); + MutableBitmap::from_vec(vec, data.length) + } else { + MutableBitmap::from_vec(data.bytes.as_ref().to_vec(), data.length) + } + }, + Either::Right(data) => data, + } + } + + /// Initializes an new [`Bitmap`] filled with unset values. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + // don't use `MutableBitmap::from_len_zeroed().into()` + // it triggers a bitcount + let bytes = vec![0; length.saturating_add(7) / 8]; + unsafe { Bitmap::from_inner_unchecked(Arc::new(bytes.into()), 0, length, length) } + } + + /// Counts the nulls (unset bits) starting from `offset` bits and for `length` bits. + #[inline] + pub fn null_count_range(&self, offset: usize, length: usize) -> usize { + count_zeros(&self.bytes, self.offset + offset, length) + } + + /// Creates a new [`Bitmap`] from a slice and length. + /// # Panic + /// Panics iff `length <= bytes.len() * 8` + #[inline] + pub fn from_u8_slice>(slice: T, length: usize) -> Self { + Bitmap::try_new(slice.as_ref().to_vec(), length).unwrap() + } + + /// Alias for `Bitmap::try_new().unwrap()` + /// This function is `O(1)` + /// # Panic + /// This function panics iff `length <= bytes.len() * 8` + #[inline] + pub fn from_u8_vec(vec: Vec, length: usize) -> Self { + Bitmap::try_new(vec, length).unwrap() + } + + /// Returns whether the bit at position `i` is set. + #[inline] + pub fn get(&self, i: usize) -> Option { + if i < self.len() { + Some(unsafe { self.get_bit_unchecked(i) }) + } else { + None + } + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (Arc>, usize, usize, usize) { + let Self { + bytes, + offset, + length, + unset_bits, + } = self; + (bytes, offset, length, unset_bits) + } + + /// Creates a `[Bitmap]` from its internal representation. 
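`into_mut` and `make_mut` implement the copy-on-write path back to `MutableBitmap`: the buffer is only reclaimed when the bitmap is un-sliced and its `Arc` is not shared. A sketch, assuming the `either` crate (already imported by this file) is also a dependency of the caller, plus the usual `nano_arrow` path:

```rust
use either::Either;
use nano_arrow::bitmap::{Bitmap, MutableBitmap};

fn main() {
    // Un-shared and un-sliced: the buffer can be reclaimed without copying.
    let bitmap: Bitmap = MutableBitmap::from_len_set(8).into();
    match bitmap.into_mut() {
        Either::Right(mutable) => assert_eq!(mutable.len(), 8),
        Either::Left(_) => unreachable!("not shared and offset is 0"),
    }

    // A clone keeps a second `Arc` reference, so the bitmap comes back as-is.
    let original = Bitmap::new_zeroed(8);
    let shared = original.clone();
    assert!(matches!(shared.into_mut(), Either::Left(_)));

    // `make_mut` always succeeds; it copies (and re-aligns) only when it must.
    assert_eq!(original.make_mut().len(), 8);
}
```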
+ /// This is the inverted from `[Bitmap::into_inner]` + /// + /// # Safety + /// The invariants of this struct must be upheld + pub unsafe fn from_inner( + bytes: Arc>, + offset: usize, + length: usize, + unset_bits: usize, + ) -> Result { + check(&bytes, offset, length)?; + Ok(Self { + bytes, + offset, + length, + unset_bits, + }) + } + + /// Creates a `[Bitmap]` from its internal representation. + /// This is the inverted from `[Bitmap::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + bytes: Arc>, + offset: usize, + length: usize, + unset_bits: usize, + ) -> Self { + Self { + bytes, + offset, + length, + unset_bits, + } + } +} + +impl> From
for Bitmap { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl FromIterator for Bitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + MutableBitmap::from_iter(iter).into() + } +} + +impl Bitmap { + /// Creates a new [`Bitmap`] from an iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`Bitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter(iterator).into() + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + #[inline] + pub fn try_from_trusted_len_iter>>( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter(iterator)?.into()) + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked< + E, + I: Iterator>, + >( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } + + /// Create a new [`Bitmap`] from an arrow [`NullBuffer`] + /// + /// [`NullBuffer`]: arrow_buffer::buffer::NullBuffer + #[cfg(feature = "arrow_rs")] + pub fn from_null_buffer(value: arrow_buffer::buffer::NullBuffer) -> Self { + let offset = value.offset(); + let length = value.len(); + let unset_bits = value.null_count(); + Self { + offset, + length, + unset_bits, + bytes: Arc::new(crate::buffer::to_bytes(value.buffer().clone())), + } + } +} + +impl<'a> IntoIterator for &'a Bitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.bytes, self.offset, self.length) + } +} + +impl IntoIterator for Bitmap { + type Item = bool; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_buffer::buffer::NullBuffer { + fn from(value: Bitmap) -> Self { + let null_count = value.unset_bits; + let buffer = crate::buffer::to_buffer(value.bytes); + let buffer = arrow_buffer::buffer::BooleanBuffer::new(buffer, value.offset, value.length); + // Safety: null count is accurate + unsafe { arrow_buffer::buffer::NullBuffer::new_unchecked(buffer, null_count) } + } +} diff --git a/crates/nano-arrow/src/bitmap/iterator.rs b/crates/nano-arrow/src/bitmap/iterator.rs new file mode 100644 index 000000000000..93ca7fb8576a --- /dev/null +++ b/crates/nano-arrow/src/bitmap/iterator.rs @@ -0,0 +1,68 @@ +use super::Bitmap; +use crate::trusted_len::TrustedLen; + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. 
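Putting the immutable `Bitmap` pieces together: construction from raw bytes, O(1) zero-count queries, zero-copy slicing, and conversion back to a mutable bitmap. An illustrative sketch under the usual `nano_arrow` path assumption:

```rust
use nano_arrow::bitmap::{Bitmap, MutableBitmap};

fn main() {
    // Two bytes, 16 bits; bits are read LSB-first within each byte,
    // as in the `try_new` doc example above. Set bits: 0, 3, 6, 8, 11.
    let bitmap = Bitmap::try_new(vec![0b0100_1001, 0b0000_1001], 16).unwrap();
    assert_eq!(bitmap.unset_bits(), 11); // cached at construction, O(1) here
    assert_eq!(bitmap.get(3), Some(true));
    assert_eq!(bitmap.get(16), None); // out of bounds yields None, not a panic

    // Zeros in a sub-range, without slicing first.
    assert_eq!(bitmap.null_count_range(0, 4), 2);

    // Slicing shares the bytes; only offset, length and the zero count change.
    let sliced = bitmap.clone().sliced(3, 6); // original bits 3..9
    assert_eq!(sliced.unset_bits(), 3);

    // Copy-on-write back to a builder; copied here because `bitmap` is still
    // alive and `sliced` starts at a non-zero bit offset.
    let mutable: MutableBitmap = sliced.make_mut();
    assert_eq!(mutable.len(), 6);
}
```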
+#[derive(Debug, Clone)] +pub struct IntoIter { + values: Bitmap, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`IntoIter`] from a [`Bitmap`] + #[inline] + pub fn new(values: Bitmap) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.values.get_bit_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.values.get_bit_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/nano-arrow/src/bitmap/mod.rs b/crates/nano-arrow/src/bitmap/mod.rs new file mode 100644 index 000000000000..18662764dea9 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/mod.rs @@ -0,0 +1,19 @@ +//! contains [`Bitmap`] and [`MutableBitmap`], containers of `bool`. +mod immutable; +pub use immutable::*; + +mod iterator; +pub use iterator::IntoIter; + +mod mutable; +pub use mutable::MutableBitmap; + +mod bitmap_ops; +pub use bitmap_ops::*; + +mod assign_ops; +pub use assign_ops::*; + +pub mod utils; + +pub mod bitmask; diff --git a/crates/nano-arrow/src/bitmap/mutable.rs b/crates/nano-arrow/src/bitmap/mutable.rs new file mode 100644 index 000000000000..e52e39ba3200 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/mutable.rs @@ -0,0 +1,755 @@ +use std::hint::unreachable_unchecked; +use std::iter::FromIterator; +use std::sync::Arc; + +use super::utils::{ + count_zeros, fmt, get_bit, set, set_bit, BitChunk, BitChunksExactMut, BitmapIter, +}; +use super::Bitmap; +use crate::bitmap::utils::{merge_reversed, set_bit_unchecked}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// A container of booleans. [`MutableBitmap`] is semantically equivalent +/// to [`Vec`]. +/// +/// The two main differences against [`Vec`] is that each element stored as a single bit, +/// thereby: +/// * it uses 8x less memory +/// * it cannot be represented as `&[bool]` (i.e. no pointer arithmetics). +/// +/// A [`MutableBitmap`] can be converted to a [`Bitmap`] at `O(1)`. 
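`iterator.rs` adds an owning, double-ended `IntoIter` so a `Bitmap` can be consumed element-wise, while `&Bitmap` keeps yielding through the borrowing `BitmapIter`. A short sketch under the same path assumption:

```rust
use nano_arrow::bitmap::Bitmap;

fn main() {
    let bitmap = Bitmap::from([true, false, true, true]);

    // Borrowing iteration through `&Bitmap: IntoIterator`.
    let bits: Vec<bool> = (&bitmap).into_iter().collect();
    assert_eq!(bits, vec![true, false, true, true]);

    // Owning, double-ended iteration.
    let mut iter = bitmap.into_iter();
    assert_eq!(iter.next(), Some(true));      // index 0
    assert_eq!(iter.next_back(), Some(true)); // index 3
    assert_eq!(iter.nth(1), Some(true));      // skips index 1, yields index 2
    assert_eq!(iter.next(), None);
}
```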
+/// # Examples +/// ``` +/// use arrow2::bitmap::MutableBitmap; +/// +/// let bitmap = MutableBitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let mut bitmap = MutableBitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), [0b00001101u8].as_ref()); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "[0b___01101]".to_string()); +/// +/// // It supports mutation in place +/// bitmap.set(0, false); +/// assert_eq!(format!("{:?}", bitmap), "[0b___01100]".to_string()); +/// // and `O(1)` random access +/// assert_eq!(bitmap.get(0), false); +/// ``` +/// # Implementation +/// This container is internally a [`Vec`]. +#[derive(Clone)] +pub struct MutableBitmap { + buffer: Vec, + // invariant: length.saturating_add(7) / 8 == buffer.len(); + length: usize, +} + +impl std::fmt::Debug for MutableBitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(&self.buffer, 0, self.len(), f) + } +} + +impl PartialEq for MutableBitmap { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl MutableBitmap { + /// Initializes an empty [`MutableBitmap`]. + #[inline] + pub fn new() -> Self { + Self { + buffer: Vec::new(), + length: 0, + } + } + + /// Initializes a new [`MutableBitmap`] from a [`Vec`] and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(bytes: Vec, length: usize) -> Result { + if length > bytes.len().saturating_mul(8) { + return Err(Error::InvalidArgumentError(format!( + "The length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + length, + bytes.len().saturating_mul(8) + ))); + } + Ok(Self { + length, + buffer: bytes, + }) + } + + /// Initializes a [`MutableBitmap`] from a [`Vec`] and a length. + /// This function is `O(1)`. + /// # Panic + /// Panics iff the length is larger than the length of the buffer times 8. + #[inline] + pub fn from_vec(buffer: Vec, length: usize) -> Self { + Self::try_new(buffer, length).unwrap() + } + + /// Initializes a pre-allocated [`MutableBitmap`] with capacity for `capacity` bits. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buffer: Vec::with_capacity(capacity.saturating_add(7) / 8), + length: 0, + } + } + + /// Pushes a new bit to the [`MutableBitmap`], re-sizing it if necessary. + #[inline] + pub fn push(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push(0); + } + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + *byte = set(*byte, self.length % 8, value); + self.length += 1; + } + + /// Pop the last bit from the [`MutableBitmap`]. + /// Note if the [`MutableBitmap`] is empty, this method will return None. + #[inline] + pub fn pop(&mut self) -> Option { + if self.is_empty() { + return None; + } + + self.length -= 1; + let value = self.get(self.length); + if self.length % 8 == 0 { + self.buffer.pop(); + } + Some(value) + } + + /// Returns whether the position `index` is set. + /// # Panics + /// Panics iff `index >= self.len()`. + #[inline] + pub fn get(&self, index: usize) -> bool { + get_bit(&self.buffer, index) + } + + /// Sets the position `index` to `value` + /// # Panics + /// Panics iff `index >= self.len()`. 
+ #[inline] + pub fn set(&mut self, index: usize, value: bool) { + set_bit(self.buffer.as_mut_slice(), index, value) + } + + /// constructs a new iterator over the bits of [`MutableBitmap`]. + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.buffer, 0, self.length) + } + + /// Empties the [`MutableBitmap`]. + #[inline] + pub fn clear(&mut self) { + self.length = 0; + self.buffer.clear(); + } + + /// Extends [`MutableBitmap`] by `additional` values of constant `value`. + /// # Implementation + /// This function is an order of magnitude faster than pushing element by element. + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: bool) { + if additional == 0 { + return; + } + + if value { + self.extend_set(additional) + } else { + self.extend_unset(additional) + } + } + + /// Initializes a zeroed [`MutableBitmap`]. + #[inline] + pub fn from_len_zeroed(length: usize) -> Self { + Self { + buffer: vec![0; length.saturating_add(7) / 8], + length, + } + } + + /// Initializes a [`MutableBitmap`] with all values set to valid/ true. + #[inline] + pub fn from_len_set(length: usize) -> Self { + Self { + buffer: vec![u8::MAX; length.saturating_add(7) / 8], + length, + } + } + + /// Reserves `additional` bits in the [`MutableBitmap`], potentially re-allocating its buffer. + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + self.buffer + .reserve((self.length + additional).saturating_add(7) / 8 - self.buffer.len()) + } + + /// Returns the capacity of [`MutableBitmap`] in number of bits. + #[inline] + pub fn capacity(&self) -> usize { + self.buffer.capacity() * 8 + } + + /// Pushes a new bit to the [`MutableBitmap`] + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] has sufficient capacity. + #[inline] + pub unsafe fn push_unchecked(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push(0); + } + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + *byte = set(*byte, self.length % 8, value); + self.length += 1; + } + + /// Returns the number of unset bits on this [`MutableBitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(N)` + pub fn unset_bits(&self) -> usize { + count_zeros(&self.buffer, 0, self.length) + } + + /// Returns the number of unset bits on this [`MutableBitmap`]. + #[deprecated(since = "0.13.0", note = "use `unset_bits` instead")] + pub fn null_count(&self) -> usize { + self.unset_bits() + } + + /// Returns the length of the [`MutableBitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`MutableBitmap`] is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] was properly initialized up to `len`. 
+ #[inline] + pub(crate) unsafe fn set_len(&mut self, len: usize) { + self.buffer.set_len(len.saturating_add(7) / 8); + self.length = len; + } + + fn extend_set(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + + let remaining = 0b11111111u8; + let remaining = remaining >> 8usize.saturating_sub(additional); + let remaining = remaining << offset; + *last |= remaining; + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + let existing = self.length.saturating_add(7) / 8; + let required = (self.length + additional).saturating_add(7) / 8; + // add remaining as full bytes + self.buffer + .extend(std::iter::repeat(0b11111111u8).take(required - existing)); + self.length += additional; + } + } + + fn extend_unset(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + *last &= 0b11111111u8 >> (8 - offset); // unset them + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + self.buffer + .resize((self.length + additional).saturating_add(7) / 8, 0); + self.length += additional; + } + } + + /// Sets the position `index` to `value` + /// # Safety + /// Caller must ensure that `index < self.len()` + #[inline] + pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { + set_bit_unchecked(self.buffer.as_mut_slice(), index, value) + } + + /// Shrinks the capacity of the [`MutableBitmap`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + + /// Returns an iterator over mutable slices, [`BitChunksExactMut`] + pub(crate) fn bitchunks_exact_mut(&mut self) -> BitChunksExactMut { + BitChunksExactMut::new(&mut self.buffer, self.length) + } +} + +impl From for Bitmap { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + Bitmap::try_new(buffer.buffer, buffer.length).unwrap() + } +} + +impl From for Option { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + let unset_bits = buffer.unset_bits(); + if unset_bits > 0 { + // safety: + // invariants of the `MutableBitmap` equal that of `Bitmap` + let bitmap = unsafe { + Bitmap::from_inner_unchecked( + Arc::new(buffer.buffer.into()), + 0, + buffer.length, + unset_bits, + ) + }; + Some(bitmap) + } else { + None + } + } +} + +impl> From

for MutableBitmap { + #[inline] + fn from(slice: P) -> Self { + MutableBitmap::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl FromIterator for MutableBitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut iterator = iter.into_iter(); + let mut buffer = { + let byte_capacity: usize = iterator.size_hint().0.saturating_add(7) / 8; + Vec::with_capacity(byte_capacity) + }; + + let mut length = 0; + + loop { + let mut exhausted = false; + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + + //collect (up to) 8 bits into a byte + while mask != 0 { + if let Some(value) = iterator.next() { + length += 1; + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } else { + exhausted = true; + break; + } + } + + // break if the iterator was exhausted before it provided a bool for this byte + if exhausted && mask == 1 { + break; + } + + //ensure we have capacity to write the byte + if buffer.len() == buffer.capacity() { + //no capacity for new byte, allocate 1 byte more (plus however many more the iterator advertises) + let additional_byte_capacity = 1usize.saturating_add( + iterator.size_hint().0.saturating_add(7) / 8, //convert bit count to byte count, rounding up + ); + buffer.reserve(additional_byte_capacity) + } + + // Soundness: capacity was allocated above + buffer.push(byte_accum); + if exhausted { + break; + } + } + Self { buffer, length } + } +} + +// [7, 6, 5, 4, 3, 2, 1, 0], [15, 14, 13, 12, 11, 10, 9, 8] +// [00000001_00000000_00000000_00000000_...] // u64 +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_chunk_unchecked(iterator: &mut impl Iterator) -> u64 { + let mut byte = 0u64; + let mut mask; + for i in 0..8 { + mask = 1u64 << (8 * i); + for _ in 0..8 { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + } + byte +} + +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_byte_unchecked(len: usize, iterator: &mut impl Iterator) -> u8 { + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + for _ in 0..len { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + byte_accum +} + +/// Extends the [`Vec`] from `iterator` +/// # Safety +/// The iterator MUST be [`TrustedLen`]. 
+#[inline] +unsafe fn extend_aligned_trusted_iter_unchecked( + buffer: &mut Vec, + mut iterator: impl Iterator, +) -> usize { + let additional_bits = iterator.size_hint().1.unwrap(); + let chunks = additional_bits / 64; + let remainder = additional_bits % 64; + + let additional = (additional_bits + 7) / 8; + assert_eq!( + additional, + // a hint of how the following calculation will be done + chunks * 8 + remainder / 8 + (remainder % 8 > 0) as usize + ); + buffer.reserve(additional); + + // chunks of 64 bits + for _ in 0..chunks { + let chunk = get_chunk_unchecked(&mut iterator); + buffer.extend_from_slice(&chunk.to_le_bytes()); + } + + // remaining complete bytes + for _ in 0..(remainder / 8) { + let byte = unsafe { get_byte_unchecked(8, &mut iterator) }; + buffer.push(byte) + } + + // remaining bits + let remainder = remainder % 8; + if remainder > 0 { + let byte = unsafe { get_byte_unchecked(remainder, &mut iterator) }; + buffer.push(byte) + } + additional_bits +} + +impl MutableBitmap { + /// Extends `self` from a [`TrustedLen`] iterator. + #[inline] + pub fn extend_from_trusted_len_iter>(&mut self, iterator: I) { + // safety: I: TrustedLen + unsafe { self.extend_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends `self` from an iterator of trusted len. + /// # Safety + /// The caller must guarantee that the iterator has a trusted len. + #[inline] + pub unsafe fn extend_from_trusted_len_iter_unchecked>( + &mut self, + mut iterator: I, + ) { + // the length of the iterator throughout this function. + let mut length = iterator.size_hint().1.unwrap(); + + let bit_offset = self.length % 8; + + if length < 8 - bit_offset { + if bit_offset == 0 { + self.buffer.push(0); + } + // the iterator will not fill the last byte + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + let mut i = bit_offset; + for value in iterator { + *byte = set(*byte, i, value); + i += 1; + } + self.length += length; + return; + } + + // at this point we know that length will hit a byte boundary and thus + // increase the buffer. + + if bit_offset != 0 { + // we are in the middle of a byte; lets finish it + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + (bit_offset..8).for_each(|i| { + *byte = set(*byte, i, iterator.next().unwrap()); + }); + self.length += 8 - bit_offset; + length -= 8 - bit_offset; + } + + // everything is aligned; proceed with the bulk operation + debug_assert_eq!(self.length % 8, 0); + + unsafe { extend_aligned_trusted_iter_unchecked(&mut self.buffer, iterator) }; + self.length += length; + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + I: Iterator, + { + let mut buffer = Vec::::new(); + + let length = extend_aligned_trusted_iter_unchecked(&mut buffer, iterator); + + Self { buffer, length } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + I: TrustedLen, + { + // Safety: Iterator is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + I: TrustedLen>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an falible iterator of booleans. 
+ /// # Safety + /// The caller must guarantee that the iterator is `TrustedLen`. + pub unsafe fn try_from_trusted_len_iter_unchecked( + mut iterator: I, + ) -> std::result::Result + where + I: Iterator>, + { + let length = iterator.size_hint().1.unwrap(); + + let mut buffer = vec![0u8; (length + 7) / 8]; + + let chunks = length / 8; + let reminder = length % 8; + + let data = buffer.as_mut_slice(); + data[..chunks].iter_mut().try_for_each(|byte| { + (0..8).try_for_each(|i| { + *byte = set(*byte, i, iterator.next().unwrap()?); + Ok(()) + }) + })?; + + if reminder != 0 { + let last = &mut data[chunks]; + iterator.enumerate().try_for_each(|(i, value)| { + *last = set(*last, i, value?); + Ok(()) + })?; + } + + Ok(Self { buffer, length }) + } + + fn extend_unaligned(&mut self, slice: &[u8], offset: usize, length: usize) { + // e.g. + // [a, b, --101010] <- to be extended + // [00111111, 11010101] <- to extend + // [a, b, 11101010, --001111] expected result + + let aligned_offset = offset / 8; + let own_offset = self.length % 8; + debug_assert_eq!(offset % 8, 0); // assumed invariant + debug_assert!(own_offset != 0); // assumed invariant + + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + // self has some offset => we need to shift all `items`, and merge the first + let buffer = self.buffer.as_mut_slice(); + let last = &mut buffer[buffer.len() - 1]; + + // --101010 | 00111111 << 6 = 11101010 + // erase previous + *last &= 0b11111111u8 >> (8 - own_offset); // unset before setting + *last |= items[0] << own_offset; + + if length + own_offset <= 8 { + // no new bytes needed + self.length += length; + return; + } + let additional = length - (8 - own_offset); + + let remaining = [items[items.len() - 1], 0]; + let bytes = items + .windows(2) + .chain(std::iter::once(remaining.as_ref())) + .map(|w| merge_reversed(w[0], w[1], 8 - own_offset)) + .take(additional.saturating_add(7) / 8); + self.buffer.extend(bytes); + + self.length += length; + } + + fn extend_aligned(&mut self, slice: &[u8], offset: usize, length: usize) { + let aligned_offset = offset / 8; + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + self.buffer.extend_from_slice(items); + self.length += length; + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. + /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + /// # Safety + /// Caller must ensure `offset + length <= slice.len() * 8` + #[inline] + pub unsafe fn extend_from_slice_unchecked( + &mut self, + slice: &[u8], + offset: usize, + length: usize, + ) { + if length == 0 { + return; + }; + let is_aligned = self.length % 8 == 0; + let other_is_aligned = offset % 8 == 0; + match (is_aligned, other_is_aligned) { + (true, true) => self.extend_aligned(slice, offset, length), + (false, true) => self.extend_unaligned(slice, offset, length), + // todo: further optimize the other branches. + _ => self.extend_from_trusted_len_iter(BitmapIter::new(slice, offset, length)), + } + // internal invariant: + debug_assert_eq!(self.length.saturating_add(7) / 8, self.buffer.len()); + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. 
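+ /// For example (illustrative): starting from an empty [`MutableBitmap`], calling
+ /// `extend_from_slice(&[0b00001101], 0, 5)` appends the bits
+ /// `true, false, true, true, false` (the first bit is the least significant bit of the first byte).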
+ /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + #[inline] + pub fn extend_from_slice(&mut self, slice: &[u8], offset: usize, length: usize) { + assert!(offset + length <= slice.len() * 8); + // safety: invariant is asserted + unsafe { self.extend_from_slice_unchecked(slice, offset, length) } + } + + /// Extends the [`MutableBitmap`] from a [`Bitmap`]. + #[inline] + pub fn extend_from_bitmap(&mut self, bitmap: &Bitmap) { + let (slice, offset, length) = bitmap.as_slice(); + // safety: bitmap.as_slice adheres to the invariant + unsafe { + self.extend_from_slice_unchecked(slice, offset, length); + } + } + + /// Returns the slice of bytes of this [`MutableBitmap`]. + /// Note that the last byte may not be fully used. + #[inline] + pub fn as_slice(&self) -> &[u8] { + let len = (self.length).saturating_add(7) / 8; + &self.buffer[..len] + } +} + +impl Default for MutableBitmap { + fn default() -> Self { + Self::new() + } +} + +impl<'a> IntoIterator for &'a MutableBitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.buffer, 0, self.length) + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs new file mode 100644 index 000000000000..4ab9d300ba02 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -0,0 +1,101 @@ +use std::convert::TryInto; +use std::slice::ChunksExact; + +use super::{BitChunk, BitChunkIterExact}; +use crate::trusted_len::TrustedLen; + +/// An iterator over a slice of bytes in [`BitChunk`]s. +#[derive(Debug)] +pub struct BitChunksExact<'a, T: BitChunk> { + iter: ChunksExact<'a, u8>, + remainder: &'a [u8], + remainder_len: usize, + phantom: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExact<'a, T> { + /// Creates a new [`BitChunksExact`]. + #[inline] + pub fn new(bitmap: &'a [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = std::mem::size_of::(); + + let bitmap = &bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at(split); + let remainder_len = length - chunks.len() * 8; + let iter = chunks.chunks_exact(size_of); + + Self { + iter, + remainder, + remainder_len, + phantom: std::marker::PhantomData, + } + } + + /// Returns the number of chunks of this iterator + #[inline] + pub fn len(&self) -> usize { + self.iter.len() + } + + /// Returns whether there are still elements in this iterator + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the remaining [`BitChunk`]. It is zero iff `len / 8 == 0`. 
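+ /// A minimal sketch (illustrative; assumes a little-endian target and that `u16` implements [`BitChunk`]):
+ /// ```rust,ignore
+ /// use arrow2::bitmap::utils::BitChunksExact;
+ ///
+ /// // 10 bits do not fill a whole `u16` chunk, so everything lands in the remainder
+ /// let iter = BitChunksExact::<u16>::new(&[0b1111_1111, 0b0000_0011], 10);
+ /// assert_eq!(iter.len(), 0);
+ /// assert_eq!(iter.remainder(), 0b0000_0011_1111_1111u16);
+ /// ```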
+ #[inline] + pub fn remainder(&self) -> T { + let remainder_bytes = self.remainder; + if remainder_bytes.is_empty() { + return T::zero(); + } + let remainder = match remainder_bytes.try_into() { + Ok(a) => a, + Err(_) => { + let mut remainder = T::zero().to_ne_bytes(); + remainder_bytes + .iter() + .enumerate() + .for_each(|(index, b)| remainder[index] = *b); + remainder + }, + }; + T::from_ne_bytes(remainder) + } +} + +impl Iterator for BitChunksExact<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +unsafe impl TrustedLen for BitChunksExact<'_, T> {} + +impl BitChunkIterExact for BitChunksExact<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs new file mode 100644 index 000000000000..81e08df0059e --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs @@ -0,0 +1,61 @@ +use super::BitChunk; + +/// Merges 2 [`BitChunk`]s into a single [`BitChunk`] so that the new items represents +/// the bitmap where bits from `next` are placed in `current` according to `offset`. +/// # Panic +/// The caller must ensure that `0 < offset < size_of::() * 8` +/// # Example +/// ```rust,ignore +/// let current = 0b01011001; +/// let next = 0b01011011; +/// let result = merge_reversed(current, next, 1); +/// assert_eq!(result, 0b10101100); +/// ``` +#[inline] +pub fn merge_reversed(mut current: T, mut next: T, offset: usize) -> T +where + T: BitChunk, +{ + // 8 _bits_: + // current = [c0, c1, c2, c3, c4, c5, c6, c7] + // next = [n0, n1, n2, n3, n4, n5, n6, n7] + // offset = 3 + // expected = [n5, n6, n7, c0, c1, c2, c3, c4] + + // 1. unset most significants of `next` up to `offset` + let inverse_offset = std::mem::size_of::() * 8 - offset; + next <<= inverse_offset; + // next = [n5, n6, n7, 0 , 0 , 0 , 0 , 0 ] + + // 2. unset least significants of `current` up to `offset` + current >>= offset; + // current = [0 , 0 , 0 , c0, c1, c2, c3, c4] + + current | next +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merge_reversed() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10000000); + + let current = 0b01011001; + let next = 0b01011011; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10101100); + } + + #[test] + fn test_merge_reversed_offset2() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 3); + assert_eq!(result, 0b00100000); + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs new file mode 100644 index 000000000000..71f56a284274 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -0,0 +1,206 @@ +use std::convert::TryInto; + +mod chunks_exact; +mod merge; + +pub use chunks_exact::BitChunksExact; +pub(crate) use merge::merge_reversed; + +use crate::trusted_len::TrustedLen; +pub use crate::types::BitChunk; +use crate::types::BitChunkIter; + +/// Trait representing an exact iterator over bytes in [`BitChunk`]. 
+pub trait BitChunkIterExact: TrustedLen { + /// The remainder of the iterator. + fn remainder(&self) -> B; + + /// The number of items in the remainder + fn remainder_len(&self) -> usize; + + /// An iterator over individual items of the remainder + #[inline] + fn remainder_iter(&self) -> BitChunkIter { + BitChunkIter::new(self.remainder(), self.remainder_len()) + } +} + +/// This struct is used to efficiently iterate over bit masks by loading bytes on +/// the stack with alignments of `uX`. This allows efficient iteration over bitmaps. +#[derive(Debug)] +pub struct BitChunks<'a, T: BitChunk> { + chunk_iterator: std::slice::ChunksExact<'a, u8>, + current: T, + remainder_bytes: &'a [u8], + last_chunk: T, + remaining: usize, + /// offset inside a byte + bit_offset: usize, + len: usize, + phantom: std::marker::PhantomData, +} + +/// writes `bytes` into `dst`. +#[inline] +fn copy_with_merge(dst: &mut T::Bytes, bytes: &[u8], bit_offset: usize) { + bytes + .windows(2) + .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref())) + .take(std::mem::size_of::()) + .enumerate() + .for_each(|(i, w)| { + let val = merge_reversed(w[0], w[1], bit_offset); + dst[i] = val; + }); +} + +impl<'a, T: BitChunk> BitChunks<'a, T> { + /// Creates a [`BitChunks`]. + pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + assert!(offset + len <= slice.len() * 8); + + let slice = &slice[offset / 8..]; + let bit_offset = offset % 8; + let size_of = std::mem::size_of::(); + + let bytes_len = len / 8; + let bytes_upper_len = (len + bit_offset + 7) / 8; + let mut chunks = slice[..bytes_len].chunks_exact(size_of); + + let remainder = &slice[bytes_len - chunks.remainder().len()..bytes_upper_len]; + + let remainder_bytes = if chunks.len() == 0 { slice } else { remainder }; + + let last_chunk = remainder_bytes + .first() + .map(|first| { + let mut last = T::zero().to_ne_bytes(); + last[0] = *first; + T::from_ne_bytes(last) + }) + .unwrap_or_else(T::zero); + + let remaining = chunks.size_hint().0; + + let current = chunks + .next() + .map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + .unwrap_or_else(T::zero); + + Self { + chunk_iterator: chunks, + len, + current, + remaining, + remainder_bytes, + last_chunk, + bit_offset, + phantom: std::marker::PhantomData, + } + } + + #[inline] + fn load_next(&mut self) { + self.current = match self.chunk_iterator.next().unwrap().try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + } + + /// Returns the remainder [`BitChunk`]. + pub fn remainder(&self) -> T { + // remaining bytes may not fit in `size_of::()`. We complement + // them to fit by allocating T and writing to it byte by byte + let mut remainder = T::zero().to_ne_bytes(); + + let remainder = match (self.remainder_bytes.is_empty(), self.bit_offset == 0) { + (true, _) => remainder, + (false, true) => { + // all remaining bytes + self.remainder_bytes + .iter() + .take(std::mem::size_of::()) + .enumerate() + .for_each(|(i, val)| remainder[i] = *val); + + remainder + }, + (false, false) => { + // all remaining bytes + copy_with_merge::(&mut remainder, self.remainder_bytes, self.bit_offset); + remainder + }, + }; + T::from_ne_bytes(remainder) + } + + /// Returns the remainder bits in [`BitChunks::remainder`]. 
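+ /// For example (illustrative): with `len = 71` bits and `T = u64`, the iterator yields one
+ /// full 64-bit chunk and `remainder_len` is `71 - 64 = 7`.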
+ pub fn remainder_len(&self) -> usize { + self.len - (std::mem::size_of::() * ((self.len / 8) / std::mem::size_of::()) * 8) + } +} + +impl Iterator for BitChunks<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + let current = self.current; + let combined = if self.bit_offset == 0 { + // fast case where there is no offset. In this case, there is bit-alignment + // at byte boundary and thus the bytes correspond exactly. + if self.remaining >= 2 { + self.load_next(); + } + current + } else { + let next = if self.remaining >= 2 { + // case where `next` is complete and thus we can take it all + self.load_next(); + self.current + } else { + // case where the `next` is incomplete and thus we take the remaining + self.last_chunk + }; + merge_reversed(current, next, self.bit_offset) + }; + + self.remaining -= 1; + Some(combined) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // it contains always one more than the chunk_iterator, which is the last + // one where the remainder is merged into current. + (self.remaining, Some(self.remaining)) + } +} + +impl BitChunkIterExact for BitChunks<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len() + } +} + +impl ExactSizeIterator for BitChunks<'_, T> { + #[inline] + fn len(&self) -> usize { + self.chunk_iterator.len() + } +} + +unsafe impl TrustedLen for BitChunks<'_, T> {} diff --git a/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs b/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs new file mode 100644 index 000000000000..7a5a91a12805 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs @@ -0,0 +1,63 @@ +use super::BitChunk; + +/// An iterator over mutable slices of bytes of exact size. +/// +/// # Safety +/// The slices returned by this iterator are guaranteed to have length equal to +/// `std::mem::size_of::()`. +#[derive(Debug)] +pub struct BitChunksExactMut<'a, T: BitChunk> { + chunks: std::slice::ChunksExactMut<'a, u8>, + remainder: &'a mut [u8], + remainder_len: usize, + marker: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExactMut<'a, T> { + /// Returns a new [`BitChunksExactMut`] + #[inline] + pub fn new(bitmap: &'a mut [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = std::mem::size_of::(); + + let bitmap = &mut bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at_mut(split); + let remainder_len = length - chunks.len() * 8; + + let chunks = chunks.chunks_exact_mut(size_of); + Self { + chunks, + remainder, + remainder_len, + marker: std::marker::PhantomData, + } + } + + /// The remainder slice + #[inline] + pub fn remainder(&mut self) -> &mut [u8] { + self.remainder + } + + /// The length of the remainder slice in bits. 
+ #[inline] + pub fn remainder_len(&mut self) -> usize { + self.remainder_len + } +} + +impl<'a, T: BitChunk> Iterator for BitChunksExactMut<'a, T> { + type Item = &'a mut [u8]; + + #[inline] + fn next(&mut self) -> Option { + self.chunks.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.chunks.size_hint() + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/fmt.rs b/crates/nano-arrow/src/bitmap/utils/fmt.rs new file mode 100644 index 000000000000..45fe9ec9ced3 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/fmt.rs @@ -0,0 +1,72 @@ +use std::fmt::Write; + +use super::is_set; + +/// Formats `bytes` taking into account an offset and length of the form +pub fn fmt( + bytes: &[u8], + offset: usize, + length: usize, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + assert!(offset < 8); + + f.write_char('[')?; + let mut remaining = length; + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let first = bytes[0]; + let bytes = &bytes[1..]; + let empty_before = 8usize.saturating_sub(remaining + offset); + f.write_str("0b")?; + for _ in 0..empty_before { + f.write_char('_')?; + } + let until = std::cmp::min(8, offset + remaining); + for i in offset..until { + if is_set(first, offset + until - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + for _ in 0..offset { + f.write_char('_')?; + } + remaining -= until - offset; + + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let number_of_bytes = remaining / 8; + for byte in &bytes[..number_of_bytes] { + f.write_str(", ")?; + f.write_fmt(format_args!("{byte:#010b}"))?; + } + remaining -= number_of_bytes * 8; + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let last = bytes[std::cmp::min((length + offset + 7) / 8, bytes.len() - 1)]; + let remaining = (length + offset) % 8; + f.write_str(", ")?; + f.write_str("0b")?; + for _ in 0..(8 - remaining) { + f.write_char('_')?; + } + for i in 0..remaining { + if is_set(last, remaining - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + f.write_char(']') +} diff --git a/crates/nano-arrow/src/bitmap/utils/iterator.rs b/crates/nano-arrow/src/bitmap/utils/iterator.rs new file mode 100644 index 000000000000..1a35ad56b562 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/iterator.rs @@ -0,0 +1,82 @@ +use super::get_bit_unchecked; +use crate::trusted_len::TrustedLen; + +/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit), +/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`. +#[derive(Debug, Clone)] +pub struct BitmapIter<'a> { + bytes: &'a [u8], + index: usize, + end: usize, +} + +impl<'a> BitmapIter<'a> { + /// Creates a new [`BitmapIter`]. 
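+ /// A minimal usage sketch (illustrative values):
+ /// ```rust,ignore
+ /// use arrow2::bitmap::utils::BitmapIter;
+ ///
+ /// // iterate 4 bits of 0b00001101, skipping the first (least significant) bit
+ /// let bits: Vec<bool> = BitmapIter::new(&[0b00001101], 1, 4).collect();
+ /// assert_eq!(bits, vec![false, true, true, false]);
+ /// ```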
+ pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + // example: + // slice.len() = 4 + // offset = 9 + // len = 23 + // result: + let bytes = &slice[offset / 8..]; + // bytes.len() = 3 + let index = offset % 8; + // index = 9 % 8 = 1 + let end = len + index; + // end = 23 + 1 = 24 + assert!(end <= bytes.len() * 8); + // maximum read before UB in bits: bytes.len() * 8 = 24 + // the first read from the end is `end - 1`, thus, end = 24 is ok + + Self { bytes, index, end } + } +} + +impl<'a> Iterator for BitmapIter<'a> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + // See comment in `new` + Some(unsafe { get_bit_unchecked(self.bytes, old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let exact = self.end - self.index; + (exact, Some(exact)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl<'a> DoubleEndedIterator for BitmapIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + // See comment in `new`; end was first decreased + Some(unsafe { get_bit_unchecked(self.bytes, self.end) }) + } + } +} + +unsafe impl TrustedLen for BitmapIter<'_> {} +impl ExactSizeIterator for BitmapIter<'_> {} diff --git a/crates/nano-arrow/src/bitmap/utils/mod.rs b/crates/nano-arrow/src/bitmap/utils/mod.rs new file mode 100644 index 000000000000..b064ffd8bed7 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/mod.rs @@ -0,0 +1,143 @@ +//! General utilities for bitmaps representing items where LSB is the first item. 
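+//!
+//! A small sketch of the byte-level helpers defined below (illustrative):
+//! ```rust,ignore
+//! use arrow2::bitmap::utils::{bytes_for, get_bit, set_bit};
+//!
+//! let mut data = vec![0u8; bytes_for(10)]; // 2 bytes are enough to hold 10 bits
+//! set_bit(&mut data, 9, true);
+//! assert!(get_bit(&data, 9));
+//! ```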
+mod chunk_iterator; +mod chunks_exact_mut; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +use std::convert::TryInto; + +pub(crate) use chunk_iterator::merge_reversed; +pub use chunk_iterator::{BitChunk, BitChunkIterExact, BitChunks, BitChunksExact}; +pub use chunks_exact_mut::BitChunksExactMut; +pub use fmt::fmt; +pub use iterator::BitmapIter; +pub use slice_iterator::SlicesIterator; +pub use zip_validity::{ZipValidity, ZipValidityIter}; + +const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; +const UNSET_BIT_MASK: [u8; 8] = [ + 255 - 1, + 255 - 2, + 255 - 4, + 255 - 8, + 255 - 16, + 255 - 32, + 255 - 64, + 255 - 128, +]; + +/// Returns whether bit at position `i` in `byte` is set or not +#[inline] +pub fn is_set(byte: u8, i: usize) -> bool { + (byte & BIT_MASK[i]) != 0 +} + +/// Sets bit at position `i` in `byte` +#[inline] +pub fn set(byte: u8, i: usize, value: bool) -> u8 { + if value { + byte | BIT_MASK[i] + } else { + byte & UNSET_BIT_MASK[i] + } +} + +/// Sets bit at position `i` in `data` +/// # Panics +/// panics if `i >= data.len() / 8` +#[inline] +pub fn set_bit(data: &mut [u8], i: usize, value: bool) { + data[i / 8] = set(data[i / 8], i % 8, value); +} + +/// Sets bit at position `i` in `data` without doing bound checks +/// # Safety +/// caller must ensure that `i < data.len() / 8` +#[inline] +pub unsafe fn set_bit_unchecked(data: &mut [u8], i: usize, value: bool) { + let byte = data.get_unchecked_mut(i / 8); + *byte = set(*byte, i % 8, value); +} + +/// Returns whether bit at position `i` in `data` is set +/// # Panic +/// This function panics iff `i / 8 >= bytes.len()` +#[inline] +pub fn get_bit(bytes: &[u8], i: usize) -> bool { + is_set(bytes[i / 8], i % 8) +} + +/// Returns whether bit at position `i` in `data` is set or not. +/// +/// # Safety +/// `i >= data.len() * 8` results in undefined behavior +#[inline] +pub unsafe fn get_bit_unchecked(data: &[u8], i: usize) -> bool { + (*data.as_ptr().add(i >> 3) & BIT_MASK[i & 7]) != 0 +} + +/// Returns the number of bytes required to hold `bits` bits. +#[inline] +pub fn bytes_for(bits: usize) -> usize { + bits.saturating_add(7) / 8 +} + +/// Returns the number of zero bits in the slice offsetted by `offset` and a length of `length`. +/// # Panics +/// This function panics iff `(offset + len).saturating_add(7) / 8 >= slice.len()` +/// because it corresponds to the situation where `len` is beyond bounds. +pub fn count_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + }; + + let mut slice = &slice[offset / 8..(offset + len).saturating_add(7) / 8]; + let offset = offset % 8; + + if (offset + len) / 8 == 0 { + // all within a single byte + let byte = (slice[0] >> offset) << (8 - len); + return len - byte.count_ones() as usize; + } + + // slice: [a1,a2,a3,a4], [a5,a6,a7,a8] + // offset: 3 + // len: 4 + // [__,__,__,a4], [a5,a6,a7,__] + let mut set_count = 0; + if offset != 0 { + // count all ignoring the first `offset` bits + // i.e. [__,__,__,a4] + set_count += (slice[0] >> offset).count_ones() as usize; + slice = &slice[1..]; + } + if (offset + len) % 8 != 0 { + let end_offset = (offset + len) % 8; // i.e. 3 + 4 = 7 + let last_index = slice.len() - 1; + // count all ignoring the last `offset` bits + // i.e. 
[a5,a6,a7,__] + set_count += (slice[last_index] << (8 - end_offset)).count_ones() as usize; + slice = &slice[..last_index]; + } + + // finally, count any and all bytes in the middle in groups of 8 + let mut chunks = slice.chunks_exact(8); + set_count += chunks + .by_ref() + .map(|chunk| { + let a = u64::from_ne_bytes(chunk.try_into().unwrap()); + a.count_ones() as usize + }) + .sum::(); + + // and any bytes that do not fit in the group + set_count += chunks + .remainder() + .iter() + .map(|byte| byte.count_ones() as usize) + .sum::(); + + len - set_count +} diff --git a/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs b/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs new file mode 100644 index 000000000000..dc388f1d41b5 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs @@ -0,0 +1,145 @@ +use crate::bitmap::Bitmap; + +/// Internal state of [`SlicesIterator`] +#[derive(Debug, Clone, PartialEq)] +enum State { + // normal iteration + Nominal, + // nothing more to iterate. + Finished, +} + +/// Iterator over a bitmap that returns slices of set regions +/// This is the most efficient method to extract slices of values from arrays +/// with a validity bitmap. +/// For example, the bitmap `00101111` returns `[(0,4), (6,1)]` +#[derive(Debug, Clone)] +pub struct SlicesIterator<'a> { + values: std::slice::Iter<'a, u8>, + count: usize, + mask: u8, + max_len: usize, + current_byte: &'a u8, + state: State, + len: usize, + start: usize, + on_region: bool, +} + +impl<'a> SlicesIterator<'a> { + /// Creates a new [`SlicesIterator`] + pub fn new(values: &'a Bitmap) -> Self { + let (buffer, offset, _) = values.as_slice(); + let mut iter = buffer.iter(); + + let (current_byte, state) = match iter.next() { + Some(b) => (b, State::Nominal), + None => (&0, State::Finished), + }; + + Self { + state, + count: values.len() - values.unset_bits(), + max_len: values.len(), + values: iter, + mask: 1u8.rotate_left(offset as u32), + current_byte, + len: 0, + start: 0, + on_region: false, + } + } + + #[inline] + fn finish(&mut self) -> Option<(usize, usize)> { + self.state = State::Finished; + if self.on_region { + Some((self.start, self.len)) + } else { + None + } + } + + #[inline] + fn current_len(&self) -> usize { + self.start + self.len + } + + /// Returns the total number of slots. + /// It corresponds to the sum of all lengths of all slices. 
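+ /// For example (illustrative): for a [`Bitmap`] with ten bits of which three are unset,
+ /// `slots` returns `7`, matching the summed lengths of all slices yielded by the iterator.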
+ #[inline] + pub fn slots(&self) -> usize { + self.count + } +} + +impl<'a> Iterator for SlicesIterator<'a> { + type Item = (usize, usize); + + #[inline] + fn next(&mut self) -> Option { + loop { + if self.state == State::Finished { + return None; + } + if self.current_len() == self.max_len { + return self.finish(); + } + + if self.mask == 1 { + // at the beginning of a byte => try to skip it all together + match (self.on_region, self.current_byte) { + (true, &255u8) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + (false, &0) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + _ => (), // we need to run over all bits of this byte + } + }; + + let value = (self.current_byte & self.mask) != 0; + self.mask = self.mask.rotate_left(1); + + match (self.on_region, value) { + (true, true) => self.len += 1, + (false, false) => self.len += 1, + (true, false) => { + self.on_region = false; + let result = (self.start, self.len); + self.start += self.len; + self.len = 1; + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + } + return Some(result); + }, + (false, true) => { + self.start += self.len; + self.len = 1; + self.on_region = true; + }, + } + + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + match self.values.next() { + Some(v) => self.current_byte = v, + None => return self.finish(), + }; + } + } + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/zip_validity.rs b/crates/nano-arrow/src/bitmap/utils/zip_validity.rs new file mode 100644 index 000000000000..40965bab4113 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/zip_validity.rs @@ -0,0 +1,216 @@ +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::Bitmap; +use crate::trusted_len::TrustedLen; + +/// An [`Iterator`] over validity and values. +#[derive(Debug, Clone)] +pub struct ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + values: I, + validity: V, +} + +impl ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + /// Creates a new [`ZipValidityIter`]. 
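+ /// Each yielded item is `Some(value)` when the corresponding validity bit is `true` and `None` otherwise.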
+ /// # Panics + /// This function panics if the size_hints of the iterators are different + pub fn new(values: I, validity: V) -> Self { + assert_eq!(values.size_hint(), validity.size_hint()); + Self { values, validity } + } +} + +impl Iterator for ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + let value = self.values.next(); + let is_valid = self.validity.next(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.values.size_hint() + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let value = self.values.nth(n); + let is_valid = self.validity.nth(n); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +impl DoubleEndedIterator for ZipValidityIter +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + let value = self.values.next_back(); + let is_valid = self.validity.next_back(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +unsafe impl TrustedLen for ZipValidityIter +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ExactSizeIterator for ZipValidityIter +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +/// An [`Iterator`] over [`Option`] +/// This enum can be used in two distinct ways: +/// * as an iterator, via `Iterator::next` +/// * as an enum of two iterators, via `match self` +/// The latter allows specializalizing to when there are no nulls +#[derive(Debug, Clone)] +pub enum ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// There are no null values + Required(I), + /// There are null values + Optional(ZipValidityIter), +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Returns a new [`ZipValidity`] + pub fn new(values: I, validity: Option) -> Self { + match validity { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl<'a, T, I> ZipValidity> +where + I: Iterator, +{ + /// Returns a new [`ZipValidity`] and drops the `validity` if all values + /// are valid. + pub fn new_with_validity(values: I, validity: Option<&'a Bitmap>) -> Self { + // only if the validity has nulls we take the optional branch. 
+ match validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.iter())) { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl Iterator for ZipValidity +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + match self { + Self::Required(values) => values.next().map(Some), + Self::Optional(zipped) => zipped.next(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + Self::Required(values) => values.size_hint(), + Self::Optional(zipped) => zipped.size_hint(), + } + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + match self { + Self::Required(values) => values.nth(n).map(Some), + Self::Optional(zipped) => zipped.nth(n), + } + } +} + +impl DoubleEndedIterator for ZipValidity +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + match self { + Self::Required(values) => values.next_back().map(Some), + Self::Optional(zipped) => zipped.next_back(), + } + } +} + +impl ExactSizeIterator for ZipValidity +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +unsafe impl TrustedLen for ZipValidity +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Unwrap into an iterator that has no null values. + pub fn unwrap_required(self) -> I { + match self { + ZipValidity::Required(i) => i, + _ => panic!("Could not 'unwrap_required'. 'ZipValidity' iterator has nulls."), + } + } + + /// Unwrap into an iterator that has null values. + pub fn unwrap_optional(self) -> ZipValidityIter { + match self { + ZipValidity::Optional(i) => i, + _ => panic!("Could not 'unwrap_optional'. 'ZipValidity' iterator has no nulls."), + } + } +} diff --git a/crates/nano-arrow/src/buffer/immutable.rs b/crates/nano-arrow/src/buffer/immutable.rs new file mode 100644 index 000000000000..4093734a1114 --- /dev/null +++ b/crates/nano-arrow/src/buffer/immutable.rs @@ -0,0 +1,328 @@ +use std::iter::FromIterator; +use std::ops::Deref; +use std::sync::Arc; +use std::usize; + +use either::Either; + +use super::{Bytes, IntoIter}; + +/// [`Buffer`] is a contiguous memory region that can be shared across +/// thread boundaries. +/// +/// The easiest way to think about [`Buffer`] is being equivalent to +/// a `Arc>`, with the following differences: +/// * slicing and cloning is `O(1)`. +/// * it supports external allocated memory +/// +/// The easiest way to create one is to use its implementation of `From>`. +/// +/// # Examples +/// ``` +/// use arrow2::buffer::Buffer; +/// +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// assert_eq!(buffer.as_ref(), [1, 2, 3].as_ref()); +/// +/// // it supports copy-on-write semantics (i.e. back to a `Vec`) +/// let vec: Vec = buffer.into_mut().right().unwrap(); +/// assert_eq!(vec, vec![1, 2, 3]); +/// +/// // cloning and slicing is `O(1)` (data is shared) +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// let mut sliced = buffer.clone(); +/// sliced.slice(1, 1); +/// assert_eq!(sliced.as_ref(), [2].as_ref()); +/// // but cloning forbids getting mut since `slice` and `buffer` now share data +/// assert_eq!(buffer.get_mut_slice(), None); +/// ``` +#[derive(Clone)] +pub struct Buffer { + /// the internal byte buffer. + data: Arc>, + + /// The offset into the buffer. + offset: usize, + + // the length of the buffer. 
Given a region `data` of N bytes, [offset..offset+length] is visible + // to this buffer. + length: usize, +} + +impl PartialEq for Buffer { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&**self, f) + } +} + +impl Default for Buffer { + #[inline] + fn default() -> Self { + Vec::new().into() + } +} + +impl Buffer { + /// Creates an empty [`Buffer`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Auxiliary method to create a new Buffer + pub(crate) fn from_bytes(bytes: Bytes) -> Self { + let length = bytes.len(); + Buffer { + data: Arc::new(bytes), + offset: 0, + length, + } + } + + /// Returns the number of bytes in the buffer + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether the buffer is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns whether underlying data is sliced. + /// If sliced the [`Buffer`] is backed by + /// more data than the length of `Self`. + pub fn is_sliced(&self) -> bool { + self.data.len() != self.length + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[T] { + // Safety: + // invariant of this struct `offset + length <= data.len()` + debug_assert!(self.offset + self.length <= self.data.len()); + unsafe { + self.data + .get_unchecked(self.offset..self.offset + self.length) + } + } + + /// Returns the byte slice stored in this buffer + /// # Safety + /// `index` must be smaller than `len` + #[inline] + pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T { + // Safety: + // invariant of this function + debug_assert!(index < self.length); + unsafe { self.data.get_unchecked(self.offset + index) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Panics + /// Panics iff `offset + length` is larger than `len`. + #[inline] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // Safety: we just checked bounds + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices this buffer starting at `offset`. + /// # Panics + /// Panics iff `offset` is larger than `len`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // Safety: we just checked bounds + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + + /// Slices this buffer starting at `offset`. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.offset += offset; + self.length = length; + } + + /// Returns a pointer to the start of this buffer. 
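+ /// Note that for a sliced [`Buffer`] this points at the start of the underlying region,
+ /// not at `self.offset()`.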
+ #[inline] + pub(crate) fn as_ptr(&self) -> *const T { + self.data.deref().as_ptr() + } + + /// Returns the offset of this buffer. + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// # Safety + /// The caller must ensure that the buffer was properly initialized up to `len`. + #[inline] + pub unsafe fn set_len(&mut self, len: usize) { + self.length = len; + } + + /// Returns a mutable reference to its underlying [`Vec`], if possible. + /// + /// This operation returns [`Either::Right`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + #[inline] + pub fn into_mut(mut self) -> Either> { + match Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + .map(std::mem::take) + { + Some(inner) => Either::Right(inner), + None => Either::Left(self), + } + } + + /// Returns a mutable reference to its underlying `Vec`, if possible. + /// Note that only `[self.offset(), self.offset() + self.len()[` in this vector is visible + /// by this buffer. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + /// # Safety + /// The caller must ensure that the vector in the mutable reference keeps a length of at least `self.offset() + self.len() - 1`. + #[inline] + pub unsafe fn get_mut(&mut self) -> Option<&mut Vec> { + Arc::get_mut(&mut self.data).and_then(|b| b.get_vec()) + } + + /// Returns a mutable reference to its slice, if possible. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + #[inline] + pub fn get_mut_slice(&mut self) -> Option<&mut [T]> { + Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + // Safety: the invariant of this struct + .map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) }) + } + + /// Get the strong count of underlying `Arc` data buffer. + pub fn shared_count_strong(&self) -> usize { + Arc::strong_count(&self.data) + } + + /// Get the weak count of underlying `Arc` data buffer. + pub fn shared_count_weak(&self) -> usize { + Arc::weak_count(&self.data) + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (Arc>, usize, usize) { + let Self { + data, + offset, + length, + } = self; + (data, offset, length) + } + + /// Creates a `[Bitmap]` from its internal representation. + /// This is the inverted from `[Bitmap::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. 
+ pub unsafe fn from_inner_unchecked(data: Arc>, offset: usize, length: usize) -> Self { + Self { + data, + offset, + length, + } + } +} + +impl From> for Buffer { + #[inline] + fn from(p: Vec) -> Self { + let bytes: Bytes = p.into(); + Self { + offset: 0, + length: bytes.len(), + data: Arc::new(bytes), + } + } +} + +impl std::ops::Deref for Buffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl FromIterator for Buffer { + #[inline] + fn from_iter>(iter: I) -> Self { + Vec::from_iter(iter).into() + } +} + +impl IntoIterator for Buffer { + type Item = T; + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +#[cfg(feature = "arrow_rs")] +impl From for Buffer { + fn from(value: arrow_buffer::Buffer) -> Self { + Self::from_bytes(crate::buffer::to_bytes(value)) + } +} + +#[cfg(feature = "arrow_rs")] +impl From> for arrow_buffer::Buffer { + fn from(value: Buffer) -> Self { + crate::buffer::to_buffer(value.data).slice_with_length( + value.offset * std::mem::size_of::(), + value.length * std::mem::size_of::(), + ) + } +} diff --git a/crates/nano-arrow/src/buffer/iterator.rs b/crates/nano-arrow/src/buffer/iterator.rs new file mode 100644 index 000000000000..93511c480284 --- /dev/null +++ b/crates/nano-arrow/src/buffer/iterator.rs @@ -0,0 +1,68 @@ +use super::Buffer; +use crate::trusted_len::TrustedLen; + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Buffer`]. +#[derive(Debug, Clone)] +pub struct IntoIter { + values: Buffer, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`Buffer`] + #[inline] + pub fn new(values: Buffer) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(*unsafe { self.values.get_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(*unsafe { self.values.get_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/nano-arrow/src/buffer/mod.rs b/crates/nano-arrow/src/buffer/mod.rs new file mode 100644 index 000000000000..ef78d5a26e6c --- /dev/null +++ b/crates/nano-arrow/src/buffer/mod.rs @@ -0,0 +1,96 @@ +//! Contains [`Buffer`], an immutable container for all Arrow physical types (e.g. i32, f64). + +mod immutable; +mod iterator; + +use std::ops::Deref; + +use crate::ffi::InternalArrowArray; + +pub(crate) enum BytesAllocator { + InternalArrowArray(InternalArrowArray), + + #[cfg(feature = "arrow_rs")] + Arrow(arrow_buffer::Buffer), +} +pub(crate) type BytesInner = foreign_vec::ForeignVec; + +/// Bytes representation. +#[repr(transparent)] +pub struct Bytes(BytesInner); + +impl Bytes { + /// Takes ownership of an allocated memory region. 
+ /// # Panics + /// This function panics if and only if pointer is not null + /// # Safety + /// This function is safe if and only if `ptr` is valid for `length` + /// # Implementation + /// This function leaks if and only if `owner` does not deallocate + /// the region `[ptr, ptr+length[` when dropped. + #[inline] + pub(crate) unsafe fn from_foreign(ptr: *const T, length: usize, owner: BytesAllocator) -> Self { + Self(BytesInner::from_foreign(ptr, length, owner)) + } + + /// Returns a `Some` mutable reference of [`Vec`] iff this was initialized + /// from a [`Vec`] and `None` otherwise. + #[inline] + pub(crate) fn get_vec(&mut self) -> Option<&mut Vec> { + self.0.get_vec() + } +} + +impl Deref for Bytes { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From> for Bytes { + #[inline] + fn from(data: Vec) -> Self { + let inner: BytesInner = data.into(); + Bytes(inner) + } +} + +impl From> for Bytes { + #[inline] + fn from(value: BytesInner) -> Self { + Self(value) + } +} + +#[cfg(feature = "arrow_rs")] +pub(crate) fn to_buffer( + value: std::sync::Arc>, +) -> arrow_buffer::Buffer { + // This should never panic as ForeignVec pointer must be non-null + let ptr = std::ptr::NonNull::new(value.as_ptr() as _).unwrap(); + let len = value.len() * std::mem::size_of::(); + // Safety: allocation is guaranteed to be valid for `len` bytes + unsafe { arrow_buffer::Buffer::from_custom_allocation(ptr, len, value) } +} + +#[cfg(feature = "arrow_rs")] +pub(crate) fn to_bytes(value: arrow_buffer::Buffer) -> Bytes { + let ptr = value.as_ptr(); + let align = ptr.align_offset(std::mem::align_of::()); + assert_eq!(align, 0, "not aligned"); + let len = value.len() / std::mem::size_of::(); + + // Valid as `NativeType: Pod` and checked alignment above + let ptr = value.as_ptr() as *const T; + + let owner = crate::buffer::BytesAllocator::Arrow(value); + + // Safety: slice is valid for len elements of T + unsafe { Bytes::from_foreign(ptr, len, owner) } +} + +pub use immutable::Buffer; +pub(super) use iterator::IntoIter; diff --git a/crates/nano-arrow/src/chunk.rs b/crates/nano-arrow/src/chunk.rs new file mode 100644 index 000000000000..ffc857bcc134 --- /dev/null +++ b/crates/nano-arrow/src/chunk.rs @@ -0,0 +1,84 @@ +//! Contains [`Chunk`], a container of [`Array`] where every array has the +//! same length. + +use crate::array::Array; +use crate::error::{Error, Result}; + +/// A vector of trait objects of [`Array`] where every item has +/// the same length, [`Chunk::len`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Chunk> { + arrays: Vec, +} + +impl> Chunk { + /// Creates a new [`Chunk`]. + /// # Panic + /// Iff the arrays do not have the same length + pub fn new(arrays: Vec) -> Self { + Self::try_new(arrays).unwrap() + } + + /// Creates a new [`Chunk`]. 
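+ /// A minimal sketch (illustrative; `Int32Array::from_slice` and `boxed` are assumed from the array API):
+ /// ```rust,ignore
+ /// use arrow2::array::Int32Array;
+ /// use arrow2::chunk::Chunk;
+ ///
+ /// let a = Int32Array::from_slice([1, 2, 3]).boxed();
+ /// let b = Int32Array::from_slice([4, 5, 6]).boxed();
+ /// let chunk = Chunk::try_new(vec![a, b]).unwrap();
+ /// assert_eq!(chunk.len(), 3);
+ /// ```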
+ /// # Error + /// Iff the arrays do not have the same length + pub fn try_new(arrays: Vec) -> Result { + if !arrays.is_empty() { + let len = arrays.first().unwrap().as_ref().len(); + if arrays + .iter() + .map(|array| array.as_ref()) + .any(|array| array.len() != len) + { + return Err(Error::InvalidArgumentError( + "Chunk require all its arrays to have an equal number of rows".to_string(), + )); + } + } + Ok(Self { arrays }) + } + + /// returns the [`Array`]s in [`Chunk`] + pub fn arrays(&self) -> &[A] { + &self.arrays + } + + /// returns the [`Array`]s in [`Chunk`] + pub fn columns(&self) -> &[A] { + &self.arrays + } + + /// returns the number of rows of every array + pub fn len(&self) -> usize { + self.arrays + .first() + .map(|x| x.as_ref().len()) + .unwrap_or_default() + } + + /// returns whether the columns have any rows + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Consumes [`Chunk`] into its underlying arrays. + /// The arrays are guaranteed to have the same length + pub fn into_arrays(self) -> Vec { + self.arrays + } +} + +impl> From> for Vec { + fn from(c: Chunk) -> Self { + c.into_arrays() + } +} + +impl> std::ops::Deref for Chunk { + type Target = [A]; + + #[inline] + fn deref(&self) -> &[A] { + self.arrays() + } +} diff --git a/crates/nano-arrow/src/compute/README.md b/crates/nano-arrow/src/compute/README.md new file mode 100644 index 000000000000..6b5bec7e703e --- /dev/null +++ b/crates/nano-arrow/src/compute/README.md @@ -0,0 +1,32 @@ +# Design + +This document outlines the design guide lines of this module. + +This module is composed by independent operations common in analytics. Below are some design of its principles: + +- APIs MUST return an error when either: + - The arguments are incorrect + - The execution results in a predictable error (e.g. divide by zero) + +- APIs MAY error when an operation overflows (e.g. `i32 + i32`) + +- kernels MUST NOT have side-effects + +- kernels MUST NOT take ownership of any of its arguments (i.e. everything must be a reference). + +- APIs SHOULD error when an operation on variable sized containers can overflow the maximum size of `usize`. + +- Kernels SHOULD use the arrays' logical type to decide whether kernels + can be applied on an array. For example, `Date32 + Date32` is meaningless and SHOULD NOT be implemented. + +- Kernels SHOULD be implemented via `clone`, `slice` or the `iterator` API provided by `Buffer`, `Bitmap`, `Vec` or `MutableBitmap`. + +- Kernels MUST NOT use any API to read bits other than the ones provided by `Bitmap`. + +- Implementations SHOULD aim for auto-vectorization, which is usually accomplished via `from_trusted_len_iter`. + +- Implementations MUST feature-gate any implementation that requires external dependencies + +- When a kernel accepts dynamically-typed arrays, it MUST expect them as `&dyn Array`. + +- When an API returns `&dyn Array`, it MUST return `Box`. The rational is that a `Box` is mutable, while an `Arc` is not. As such, `Box` offers the most flexible API to consumers and the compiler. Users can cast a `Box` into `Arc` via `.into()`. 
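As a concrete illustration of the guidelines above (editor's sketch, not part of this diff): a minimal dynamically typed kernel that borrows its argument, has no side effects, errors on unsupported logical types, and returns `Box<dyn Array>`. It assumes the `arrow2` import paths that the doc examples in this crate still use; the kernel name `increment_i32` is hypothetical.

```rust
// Illustrative sketch only; not part of this PR.
use arrow2::array::{Array, PrimitiveArray};
use arrow2::datatypes::DataType;
use arrow2::error::{Error, Result};

/// Adds 1 to every valid element of an `Int32` array.
/// Borrows its argument, has no side effects, returns `Box<dyn Array>`.
fn increment_i32(array: &dyn Array) -> Result<Box<dyn Array>> {
    match array.data_type() {
        DataType::Int32 => {
            let array = array
                .as_any()
                .downcast_ref::<PrimitiveArray<i32>>()
                .unwrap();
            // Nulls propagate unchanged; only valid values are incremented.
            let values: Vec<Option<i32>> =
                array.iter().map(|v| v.map(|v| *v + 1)).collect();
            Ok(Box::new(PrimitiveArray::<i32>::from(values)))
        },
        other => Err(Error::InvalidArgumentError(format!(
            "increment does not support type {other:?}"
        ))),
    }
}

fn main() -> Result<()> {
    let a = PrimitiveArray::from([Some(1i32), None, Some(41)]);
    let incremented = increment_i32(&a)?;
    assert_eq!(incremented.len(), 3);
    Ok(())
}
```

Returning `Box<dyn Array>` keeps the result mutable for the caller, who can still convert it into an `Arc` with `.into()` when shared ownership is needed.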
diff --git a/crates/nano-arrow/src/compute/aggregate/memory.rs b/crates/nano-arrow/src/compute/aggregate/memory.rs new file mode 100644 index 000000000000..3af974a79b14 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/memory.rs @@ -0,0 +1,118 @@ +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; + +fn validity_size(validity: Option<&Bitmap>) -> usize { + validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) +} + +macro_rules! dyn_binary { + ($array:expr, $ty:ty, $o:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + let offsets = array.offsets().buffer(); + + // in case of Binary/Utf8/List the offsets are sliced, + // not the values buffer + let values_start = offsets[0] as usize; + let values_end = offsets[offsets.len() - 1] as usize; + + values_end - values_start + + offsets.len() * std::mem::size_of::<$o>() + + validity_size(array.validity()) + }}; +} + +/// Returns the total (heap) allocated size of the array in bytes. +/// # Implementation +/// This estimation is the sum of the size of its buffers, validity, including nested arrays. +/// Multiple arrays may share buffers and bitmaps. Therefore, the size of 2 arrays is not the +/// sum of the sizes computed from this function. In particular, [`StructArray`]'s size is an upper bound. +/// +/// When an array is sliced, its allocated size remains constant because the buffer unchanged. +/// However, this function will yield a smaller number. This is because this function returns +/// the visible size of the buffer, not its total capacity. +/// +/// FFI buffers are included in this estimation. +pub fn estimated_bytes_size(array: &dyn Array) -> usize { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => 0, + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().as_slice().0.len() + validity_size(array.validity()) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + array.values().len() * std::mem::size_of::<$T>() + validity_size(array.validity()) + }), + Binary => dyn_binary!(array, BinaryArray, i32), + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().len() + validity_size(array.validity()) + }, + LargeBinary => dyn_binary!(array, BinaryArray, i64), + Utf8 => dyn_binary!(array, Utf8Array, i32), + LargeUtf8 => dyn_binary!(array, Utf8Array, i64), + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * std::mem::size_of::() + + validity_size(array.validity()) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + estimated_bytes_size(array.values().as_ref()) + validity_size(array.validity()) + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * std::mem::size_of::() + + validity_size(array.validity()) + }, + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + array + .values() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::() + + validity_size(array.validity()) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + let types = array.types().len() * std::mem::size_of::(); + let offsets = array + .offsets() + .as_ref() + .map(|x| x.len() * std::mem::size_of::()) + 
.unwrap_or_default(); + let fields = array + .fields() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::(); + types + offsets + fields + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + estimated_bytes_size(array.keys()) + estimated_bytes_size(array.values().as_ref()) + }), + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let offsets = array.offsets().len_proxy() * std::mem::size_of::(); + offsets + estimated_bytes_size(array.field().as_ref()) + validity_size(array.validity()) + }, + } +} diff --git a/crates/nano-arrow/src/compute/aggregate/min_max.rs b/crates/nano-arrow/src/compute/aggregate/min_max.rs new file mode 100644 index 000000000000..e733c6657ccd --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/min_max.rs @@ -0,0 +1,416 @@ +#![allow(clippy::redundant_closure_call)] +use multiversion::multiversion; + +use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::scalar::*; +use crate::types::simd::*; +use crate::types::NativeType; + +/// Trait describing a type describing multiple lanes with an order relationship +/// consistent with the same order of `T`. +pub trait SimdOrd { + /// The minimum value + const MIN: T; + /// The maximum value + const MAX: T; + /// reduce itself to the minimum + fn max_element(self) -> T; + /// reduce itself to the maximum + fn min_element(self) -> T; + /// lane-wise maximum between two instances + fn max_lane(self, x: Self) -> Self; + /// lane-wise minimum between two instances + fn min_lane(self, x: Self) -> Self; + /// returns a new instance with all lanes equal to `MIN` + fn new_min() -> Self; + /// returns a new instance with all lanes equal to `MAX` + fn new_max() -> Self; +} + +#[multiversion(targets = "simd")] +fn nonnull_min_primitive(values: &[T]) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let chunks = values.chunks_exact(T::Simd::LANES); + let remainder = chunks.remainder(); + + let chunk_reduced = chunks.fold(T::Simd::new_min(), |acc, chunk| { + let chunk = T::Simd::from_chunk(chunk); + acc.min_lane(chunk) + }); + + let remainder = T::Simd::from_incomplete_chunk(remainder, T::Simd::MAX); + let reduced = chunk_reduced.min_lane(remainder); + + reduced.min_element() +} + +#[multiversion(targets = "simd")] +fn null_min_primitive_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let chunk_reduced = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::new_min(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let chunk = chunk.select(mask, T::Simd::new_min()); + acc.min_lane(chunk) + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::Simd::MAX); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::new_min()); + let reduced = chunk_reduced.min_lane(remainder); + + reduced.min_element() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. 
+fn null_min_primitive(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_min_primitive_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_min_primitive_impl(values, validity_masks) + } +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. +fn null_max_primitive(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_max_primitive_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_max_primitive_impl(values, validity_masks) + } +} + +#[multiversion(targets = "simd")] +fn nonnull_max_primitive(values: &[T]) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let chunks = values.chunks_exact(T::Simd::LANES); + let remainder = chunks.remainder(); + + let chunk_reduced = chunks.fold(T::Simd::new_max(), |acc, chunk| { + let chunk = T::Simd::from_chunk(chunk); + acc.max_lane(chunk) + }); + + let remainder = T::Simd::from_incomplete_chunk(remainder, T::Simd::MIN); + let reduced = chunk_reduced.max_lane(remainder); + + reduced.max_element() +} + +#[multiversion(targets = "simd")] +fn null_max_primitive_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let chunk_reduced = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::new_max(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let chunk = chunk.select(mask, T::Simd::new_max()); + acc.max_lane(chunk) + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::Simd::MIN); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::new_max()); + let reduced = chunk_reduced.max_lane(remainder); + + reduced.max_element() +} + +/// Returns the minimum value in the array, according to the natural order. +/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn min_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let null_count = array.null_count(); + + // Includes case array.len() == 0 + if null_count == array.len() { + return None; + } + let values = array.values(); + + Some(if let Some(validity) = array.validity() { + null_min_primitive(values, validity) + } else { + nonnull_min_primitive(values) + }) +} + +/// Returns the maximum value in the array, according to the natural order. 
+/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn max_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let null_count = array.null_count(); + + // Includes case array.len() == 0 + if null_count == array.len() { + return None; + } + let values = array.values(); + + Some(if let Some(validity) = array.validity() { + null_max_primitive(values, validity) + } else { + nonnull_max_primitive(values) + }) +} + +/// Helper to compute min/max of [`BinaryArray`] and [`Utf8Array`] +macro_rules! min_max_binary_utf8 { + ($array: expr, $cmp: expr) => { + if $array.null_count() == $array.len() { + None + } else if $array.validity().is_some() { + $array + .iter() + .reduce(|v1, v2| match (v1, v2) { + (None, v2) => v2, + (v1, None) => v1, + (Some(v1), Some(v2)) => { + if $cmp(v1, v2) { + Some(v2) + } else { + Some(v1) + } + }, + }) + .unwrap_or(None) + } else { + $array + .values_iter() + .reduce(|v1, v2| if $cmp(v1, v2) { v2 } else { v1 }) + } + }; +} + +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary_utf8!(array, |a, b| a < b) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary_utf8!(array, |a, b| a > b) +} + +/// Returns the maximum value in the string array, according to the natural order. +pub fn max_string(array: &Utf8Array) -> Option<&str> { + min_max_binary_utf8!(array, |a, b| a < b) +} + +/// Returns the minimum value in the string array, according to the natural order. +pub fn min_string(array: &Utf8Array) -> Option<&str> { + min_max_binary_utf8!(array, |a, b| a > b) +} + +/// Returns the minimum value in the boolean array. +/// +/// ``` +/// use arrow2::{ +/// array::BooleanArray, +/// compute::aggregate::min_boolean, +/// }; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(min_boolean(&a), Some(false)) +/// ``` +pub fn min_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + Some(array.values().unset_bits() == 0) + } else { + // Note the min bool is false (0), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(false)) + .flatten() + .or(Some(true)) + } +} + +/// Returns the maximum value in the boolean array +/// +/// ``` +/// use arrow2::{ +/// array::BooleanArray, +/// compute::aggregate::max_boolean, +/// }; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(max_boolean(&a), Some(true)) +/// ``` +pub fn max_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + Some(array.values().unset_bits() < array.len()) + } else { + // Note the max bool is true (1), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(true)) + .flatten() + .or(Some(false)) + } +} + +macro_rules! dyn_generic { + ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{ + let array = $array.as_any().downcast_ref::<$array_ty>().unwrap(); + Box::new(<$scalar_ty>::new($f(array))) + }}; +} + +macro_rules! 
with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(Error::InvalidArgumentError(format!( + "`min` and `max` operators do not support primitive `{:?}`", + $key_type, + ))), + } +})} + +/// Returns the maximum of [`Array`]. The scalar is null when all elements are null. +/// # Error +/// Errors iff the type does not support this operation. +pub fn max(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, max_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray, BinaryScalar, array, max_binary) + }, + PhysicalType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, max_binary) + }, + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `max` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} + +/// Returns the minimum of [`Array`]. The scalar is null when all elements are null. +/// # Error +/// Errors iff the type does not support this operation.
+pub fn min(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, min_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + }, + PhysicalType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + }, + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `max` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} + +/// Whether [`min`] supports `data_type` +pub fn can_min(data_type: &DataType) -> bool { + let physical = data_type.to_physical_type(); + if let PhysicalType::Primitive(primitive) = physical { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + use PhysicalType::*; + matches!(physical, Boolean | Utf8 | LargeUtf8 | Binary | LargeBinary) + } +} + +/// Whether [`max`] supports `data_type` +pub fn can_max(data_type: &DataType) -> bool { + can_min(data_type) +} diff --git a/crates/nano-arrow/src/compute/aggregate/mod.rs b/crates/nano-arrow/src/compute/aggregate/mod.rs new file mode 100644 index 000000000000..b513238f9fd9 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/mod.rs @@ -0,0 +1,15 @@ +//! Contains different aggregation functions +#[cfg(feature = "compute_aggregate")] +mod sum; +#[cfg(feature = "compute_aggregate")] +pub use sum::*; + +#[cfg(feature = "compute_aggregate")] +mod min_max; +#[cfg(feature = "compute_aggregate")] +pub use min_max::*; + +mod memory; +pub use memory::*; +#[cfg(feature = "compute_aggregate")] +mod simd; diff --git a/crates/nano-arrow/src/compute/aggregate/simd/mod.rs b/crates/nano-arrow/src/compute/aggregate/simd/mod.rs new file mode 100644 index 000000000000..25558e9a9e19 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/mod.rs @@ -0,0 +1,109 @@ +use std::ops::Add; + +use super::{SimdOrd, Sum}; +use crate::types::simd::{i128x8, NativeSimd}; + +macro_rules! simd_add { + ($simd:tt, $type:ty, $lanes:expr, $add:tt) => { + impl std::ops::AddAssign for $simd { + #[inline] + fn add_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self[i] = <$type>::$add(self[i], rhs[i]); + } + } + } + + impl std::ops::Add for $simd { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut result = Self::default(); + for i in 0..$lanes { + result[i] = <$type>::$add(self[i], rhs[i]); + } + result + } + } + + impl Sum<$type> for $simd { + #[inline] + fn simd_sum(self) -> $type { + let mut reduced = <$type>::default(); + (0..<$simd>::LANES).for_each(|i| { + reduced += self[i]; + }); + reduced + } + } + }; +} + +macro_rules! 
simd_ord_int { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::MIN; + const MAX: $type = <$type>::MAX; + + #[inline] + fn max_element(self) -> $type { + self.0.iter().copied().fold(Self::MIN, <$type>::max) + } + + #[inline] + fn min_element(self) -> $type { + self.0.iter().copied().fold(Self::MAX, <$type>::min) + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).max(*c)); + result + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).min(*c)); + result + } + + #[inline] + fn new_min() -> Self { + Self([Self::MAX; <$simd>::LANES]) + } + + #[inline] + fn new_max() -> Self { + Self([Self::MIN; <$simd>::LANES]) + } + } + }; +} + +pub(super) use {simd_add, simd_ord_int}; + +simd_add!(i128x8, i128, 8, add); +simd_ord_int!(i128x8, i128); + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +#[cfg_attr(docsrs, doc(cfg(feature = "simd")))] +pub use packed::*; diff --git a/crates/nano-arrow/src/compute/aggregate/simd/native.rs b/crates/nano-arrow/src/compute/aggregate/simd/native.rs new file mode 100644 index 000000000000..d6a0275f35e9 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/native.rs @@ -0,0 +1,81 @@ +use std::ops::Add; + +use super::super::min_max::SimdOrd; +use super::super::sum::Sum; +use super::{simd_add, simd_ord_int}; +use crate::types::simd::*; + +simd_add!(u8x64, u8, 64, wrapping_add); +simd_add!(u16x32, u16, 32, wrapping_add); +simd_add!(u32x16, u32, 16, wrapping_add); +simd_add!(u64x8, u64, 8, wrapping_add); +simd_add!(i8x64, i8, 64, wrapping_add); +simd_add!(i16x32, i16, 32, wrapping_add); +simd_add!(i32x16, i32, 16, wrapping_add); +simd_add!(i64x8, i64, 8, wrapping_add); +simd_add!(f32x16, f32, 16, add); +simd_add!(f64x8, f64, 8, add); + +macro_rules! 
simd_ord_float { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::NAN; + const MAX: $type = <$type>::NAN; + + #[inline] + fn max_element(self) -> $type { + self.0.iter().copied().fold(Self::MIN, <$type>::max) + } + + #[inline] + fn min_element(self) -> $type { + self.0.iter().copied().fold(Self::MAX, <$type>::min) + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).max(*c)); + result + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).min(*c)); + result + } + + #[inline] + fn new_min() -> Self { + Self([Self::MAX; <$simd>::LANES]) + } + + #[inline] + fn new_max() -> Self { + Self([Self::MIN; <$simd>::LANES]) + } + } + }; +} + +simd_ord_int!(u8x64, u8); +simd_ord_int!(u16x32, u16); +simd_ord_int!(u32x16, u32); +simd_ord_int!(u64x8, u64); +simd_ord_int!(i8x64, i8); +simd_ord_int!(i16x32, i16); +simd_ord_int!(i32x16, i32); +simd_ord_int!(i64x8, i64); +simd_ord_float!(f32x16, f32); +simd_ord_float!(f64x8, f64); diff --git a/crates/nano-arrow/src/compute/aggregate/simd/packed.rs b/crates/nano-arrow/src/compute/aggregate/simd/packed.rs new file mode 100644 index 000000000000..40094d31e239 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/packed.rs @@ -0,0 +1,116 @@ +use std::simd::{SimdFloat as _, SimdInt as _, SimdOrd as _, SimdUint as _}; + +use super::super::min_max::SimdOrd; +use super::super::sum::Sum; +use crate::types::simd::*; + +macro_rules! simd_sum { + ($simd:tt, $type:ty, $sum:tt) => { + impl Sum<$type> for $simd { + #[inline] + fn simd_sum(self) -> $type { + self.$sum() + } + } + }; +} + +simd_sum!(f32x16, f32, reduce_sum); +simd_sum!(f64x8, f64, reduce_sum); +simd_sum!(u8x64, u8, reduce_sum); +simd_sum!(u16x32, u16, reduce_sum); +simd_sum!(u32x16, u32, reduce_sum); +simd_sum!(u64x8, u64, reduce_sum); +simd_sum!(i8x64, i8, reduce_sum); +simd_sum!(i16x32, i16, reduce_sum); +simd_sum!(i32x16, i32, reduce_sum); +simd_sum!(i64x8, i64, reduce_sum); + +macro_rules! simd_ord_int { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::MIN; + const MAX: $type = <$type>::MAX; + + #[inline] + fn max_element(self) -> $type { + self.reduce_max() + } + + #[inline] + fn min_element(self) -> $type { + self.reduce_min() + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + self.simd_max(x) + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + self.simd_min(x) + } + + #[inline] + fn new_min() -> Self { + Self::splat(Self::MAX) + } + + #[inline] + fn new_max() -> Self { + Self::splat(Self::MIN) + } + } + }; +} + +macro_rules! 
simd_ord_float { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::NAN; + const MAX: $type = <$type>::NAN; + + #[inline] + fn max_element(self) -> $type { + self.reduce_max() + } + + #[inline] + fn min_element(self) -> $type { + self.reduce_min() + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + self.simd_max(x) + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + self.simd_min(x) + } + + #[inline] + fn new_min() -> Self { + Self::splat(<$type>::NAN) + } + + #[inline] + fn new_max() -> Self { + Self::splat(<$type>::NAN) + } + } + }; +} + +simd_ord_int!(u8x64, u8); +simd_ord_int!(u16x32, u16); +simd_ord_int!(u32x16, u32); +simd_ord_int!(u64x8, u64); +simd_ord_int!(i8x64, i8); +simd_ord_int!(i16x32, i16); +simd_ord_int!(i32x16, i32); +simd_ord_int!(i64x8, i64); +simd_ord_float!(f32x16, f32); +simd_ord_float!(f64x8, f64); diff --git a/crates/nano-arrow/src/compute/aggregate/sum.rs b/crates/nano-arrow/src/compute/aggregate/sum.rs new file mode 100644 index 000000000000..738440c9f0d2 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/sum.rs @@ -0,0 +1,159 @@ +use std::ops::Add; + +use multiversion::multiversion; + +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::error::{Error, Result}; +use crate::scalar::*; +use crate::types::simd::*; +use crate::types::NativeType; + +/// Object that can reduce itself to a number. This is used in the context of SIMD to reduce +/// a MD (e.g. `[f32; 16]`) into a single number (`f32`). +pub trait Sum { + /// Reduces this element to a single value. + fn simd_sum(self) -> T; +} + +#[multiversion(targets = "simd")] +/// Compute the sum of a slice +pub fn sum_slice(values: &[T]) -> T +where + T: NativeType + Simd + Add + std::iter::Sum, + T::Simd: Sum + Add, +{ + let (head, simd_vals, tail) = T::Simd::align(values); + + let mut reduced = T::Simd::from_incomplete_chunk(&[], T::default()); + for chunk in simd_vals { + reduced = reduced + *chunk; + } + + reduced.simd_sum() + head.iter().copied().sum() + tail.iter().copied().sum() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. +#[multiversion(targets = "simd")] +fn null_sum_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: Add + Sum, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let sum = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::default(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let selected = chunk.select(mask, T::Simd::default()); + acc + selected + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::default()); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::default()); + let reduced = sum + remainder; + + reduced.simd_sum() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. 
+fn null_sum(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: Add + Sum, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_sum_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_sum_impl(values, validity_masks) + } +} + +/// Returns the sum of values in the array. +/// +/// Returns `None` if the array is empty or only contains null values. +pub fn sum_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd + Add + std::iter::Sum, + T::Simd: Add + Sum, +{ + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + match array.validity() { + None => Some(sum_slice(array.values())), + Some(bitmap) => Some(null_sum(array.values(), bitmap)), + } +} + +/// Whether [`sum`] supports `data_type` +pub fn can_sum(data_type: &DataType) -> bool { + if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + false + } +} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(Error::InvalidArgumentError(format!( + "`sum` operator do not support primitive `{:?}`", + $key_type, + ))), + } +})} + +/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical +/// and logical types as `array`. +/// # Error +/// Errors iff the operation is not supported. +pub fn sum(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array))) + }), + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `sum` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/add.rs b/crates/nano-arrow/src/compute/arithmetics/basic/add.rs new file mode 100644 index 000000000000..5919b65fdbd5 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/add.rs @@ -0,0 +1,337 @@ +//! Definition of basic add operations with primitive arrays +use std::ops::Add; + +use num_traits::ops::overflowing::OverflowingAdd; +use num_traits::{CheckedAdd, SaturatingAdd, WrappingAdd}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Adds two primitive arrays with the same type. 
+/// Panics if the sum of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]); +/// let result = add(&a, &b); +/// let expected = PrimitiveArray::from([None, None, None, Some(12)]); +/// assert_eq!(result, expected) +/// ``` +pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Add, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) +} + +/// Wrapping addition of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = wrapping_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + let op = move |a: T, b: T| a.wrapping_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked addition of two primitive arrays. If the result from the sum +/// overflows, the validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = checked_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let op = move |a: T, b: T| a.checked_add(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be the saturated value. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let b = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let op = move |a: T, b: T| a.saturating_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be an array with overflowed values and a validity array indicating +/// the overflowing elements from the array. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingAdd, +{ + let op = move |a: T, b: T| a.overflowing_add(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays +impl ArrayAdd> for PrimitiveArray +where + T: NativeArithmetics + Add, +{ + fn add(&self, rhs: &PrimitiveArray) -> Self { + add(self, rhs) + } +} + +impl ArrayWrappingAdd> for PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + fn wrapping_add(&self, rhs: &PrimitiveArray) -> Self { + wrapping_add(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays +impl ArrayCheckedAdd> for PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + fn checked_add(&self, rhs: &PrimitiveArray) -> Self { + checked_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArraySaturatingAdd> for PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { + saturating_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArrayOverflowingAdd> for PrimitiveArray +where + T: NativeArithmetics + OverflowingAdd, +{ + fn overflowing_add(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_add(self, rhs) + } +} + +/// Adds a scalar T to a primitive array of type T. +/// Panics if the sum of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let result = add_scalar(&a, &1i32); +/// let expected = PrimitiveArray::from([None, Some(7), None, Some(7)]); +/// assert_eq!(result, expected) +/// ``` +pub fn add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Add, +{ + let rhs = *rhs; + unary(lhs, |a| a + rhs, lhs.data_type().clone()) +} + +/// Wrapping addition of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_add_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100)]); +/// let result = wrapping_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, Some(-56)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + unary(lhs, |a| a.wrapping_add(rhs), lhs.data_type().clone()) +} + +/// Checked addition of a scalar T to a primitive array of type T. 
If the +/// result from the sum overflows then the validity index for that value is +/// changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_add_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); +/// let result = checked_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_add(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated addition of a scalar T to a primitive array of type T. If the +/// result from the sum is larger than the possible number for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add_scalar(&a, &100i8); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_add(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing addition of a scalar T to a primitive array of type T. If the +/// result from the sum is larger than the possible number for this type, then +/// the result will be an array with overflowed values and a validity array +/// indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_add_scalar(&a, &100i8); +/// let expected = PrimitiveArray::from([Some(101i8), Some(-56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_add(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays with a scalar +impl ArrayAdd for PrimitiveArray +where + T: NativeArithmetics + Add, +{ + fn add(&self, rhs: &T) -> Self { + add_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays with a scalar +impl ArrayCheckedAdd for PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + fn checked_add(&self, rhs: &T) -> Self { + checked_add_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar +impl ArraySaturatingAdd for PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + fn saturating_add(&self, rhs: &T) -> Self { + saturating_add_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar +impl ArrayOverflowingAdd for PrimitiveArray +where + T: NativeArithmetics + OverflowingAdd, +{ + fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_add_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/div.rs 
b/crates/nano-arrow/src/compute/arithmetics/basic/div.rs new file mode 100644 index 000000000000..a5137ca9a0cc --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/div.rs @@ -0,0 +1,203 @@ +//! Definition of basic div operations with primitive arrays +use std::ops::Div; + +use num_traits::{CheckedDiv, NumCast}; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; + +use super::NativeArithmetics; +use crate::array::{Array, PrimitiveArray}; +use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; +use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; +use crate::compute::utils::check_same_len; +use crate::datatypes::PrimitiveType; + +/// Divides two primitive arrays with the same type. +/// Panics if the divisor is zero of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::div; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[Some(10), Some(1), Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, Some(6)]); +/// let result = div(&a, &b); +/// let expected = Int32Array::from(&[Some(2), None, Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Div, +{ + if rhs.null_count() == 0 { + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b) + } else { + check_same_len(lhs, rhs).unwrap(); + let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => Some(*l / *r), + _ => None, + }); + + PrimitiveArray::from_trusted_len_iter(values).to(lhs.data_type().clone()) + } +} + +/// Checked division of two primitive arrays. If the result from the division +/// overflows, the result for the operation will change the validity array +/// making this operation None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_div; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); +/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); +/// let result = checked_div(&a, &b); +/// let expected = Int8Array::from(&[Some(-1i8), None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + let op = move |a: T, b: T| a.checked_div(&b); + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays +impl ArrayDiv> for PrimitiveArray +where + T: NativeArithmetics + Div, +{ + fn div(&self, rhs: &PrimitiveArray) -> Self { + div(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays +impl ArrayCheckedDiv> for PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + fn checked_div(&self, rhs: &PrimitiveArray) -> Self { + checked_div(self, rhs) + } +} + +/// Divide a primitive array of type T by a scalar T. +/// Panics if the divisor is zero. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::div_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = div_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(3), None, Some(3)]); +/// assert_eq!(result, expected) +/// ``` +pub fn div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Div + NumCast, +{ + let rhs = *rhs; + match T::PRIMITIVE { + PrimitiveType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_div = StrengthReducedU64::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_div = StrengthReducedU32::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_div = StrengthReducedU16::new(rhs); + + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_div = StrengthReducedU8::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + _ => unary(lhs, |a| a / rhs, lhs.data_type().clone()), + } +} + +/// Checked division of a primitive array of type T by a scalar T. If the +/// divisor is zero then the validity array is changed to None. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_div_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = checked_div_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-1i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_div(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar +impl ArrayDiv for PrimitiveArray +where + T: NativeArithmetics + Div + NumCast, +{ + fn div(&self, rhs: &T) -> Self { + div_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays with a scalar +impl ArrayCheckedDiv for PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + fn checked_div(&self, rhs: &T) -> Self { + checked_div_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs b/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs new file mode 100644 index 000000000000..898a69f59536 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs @@ -0,0 +1,100 @@ +//! Contains arithmetic functions for [`PrimitiveArray`]s. +//! +//! Each operation has four variants, like the rest of Rust's ecosystem: +//! * usual, that [`panic!`]s on overflow +//! * `checked_*` that turns overflowings to `None` +//! 
* `overflowing_*` returning a [`Bitmap`](crate::bitmap::Bitmap) with items that overflow. +//! * `saturating_*` that saturates the result. +mod add; +pub use add::*; +mod div; +pub use div::*; +mod mul; +pub use mul::*; +mod pow; +pub use pow::*; +mod rem; +pub use rem::*; +mod sub; +use std::ops::Neg; + +use num_traits::{CheckedNeg, WrappingNeg}; +pub use sub::*; + +use super::super::arity::{unary, unary_checked}; +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +/// Trait describing a [`NativeType`] whose semantics of arithmetic in Arrow equals +/// the semantics in Rust. +/// A counter example is `i128`, that in arrow represents a decimal while in rust represents +/// a signed integer. +pub trait NativeArithmetics: NativeType {} +impl NativeArithmetics for u8 {} +impl NativeArithmetics for u16 {} +impl NativeArithmetics for u32 {} +impl NativeArithmetics for u64 {} +impl NativeArithmetics for i8 {} +impl NativeArithmetics for i16 {} +impl NativeArithmetics for i32 {} +impl NativeArithmetics for i64 {} +impl NativeArithmetics for f32 {} +impl NativeArithmetics for f64 {} + +/// Negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::negate; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(7)]); +/// let result = negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected) +/// ``` +pub fn negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Neg, +{ + unary(array, |a| -a, array.data_type().clone()) +} + +/// Checked negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = checked_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected); +/// assert!(!result.is_valid(2)) +/// ``` +pub fn checked_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + CheckedNeg, +{ + unary_checked(array, |a| a.checked_neg(), array.data_type().clone()) +} + +/// Wrapping negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = wrapping_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), Some(i8::MIN), Some(-7)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + WrappingNeg, +{ + unary(array, |a| a.wrapping_neg(), array.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs b/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs new file mode 100644 index 000000000000..e006abe186e5 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs @@ -0,0 +1,338 @@ +//! 
Definition of basic mul operations with primitive arrays +use std::ops::Mul; + +use num_traits::ops::overflowing::OverflowingMul; +use num_traits::{CheckedMul, SaturatingMul, WrappingMul}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayCheckedMul, ArrayMul, ArrayOverflowingMul, ArraySaturatingMul, ArrayWrappingMul, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Multiplies two primitive arrays with the same type. +/// Panics if the multiplication of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::mul; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); +/// let result = mul(&a, &b); +/// let expected = Int32Array::from(&[None, None, None, Some(36)]); +/// assert_eq!(result, expected) +/// ``` +pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b) +} + +/// Wrapping multiplication of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_mul; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8), Some(0x10i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(0x10i8), Some(0i8)]); +/// let result = wrapping_mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(0), Some(0), Some(0)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + let op = move |a: T, b: T| a.wrapping_mul(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked multiplication of two primitive arrays. If the result from the +/// multiplications overflows, the validity for that index is changed +/// returned. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(100i8), Some(100i8), Some(100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(1i8)]); +/// let result = checked_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(100i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + let op = move |a: T, b: T| a.checked_mul(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating multiplication of two primitive arrays. If the result from the +/// multiplication overflows, the result for the +/// operation will be the saturated value. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let b = Int8Array::from(&[Some(100i8)]); +/// let result = saturating_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(-128)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + let op = move |a: T, b: T| a.saturating_mul(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing multiplication of two primitive arrays. If the result from the +/// mul overflows, the result for the operation will be an array with +/// overflowed values and a validity array indicating the overflowing elements +/// from the array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(1i8), Some(-16i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingMul, +{ + let op = move |a: T, b: T| a.overflowing_mul(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayMul trait for PrimitiveArrays +impl ArrayMul> for PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + fn mul(&self, rhs: &PrimitiveArray) -> Self { + mul(self, rhs) + } +} + +impl ArrayWrappingMul> for PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + fn wrapping_mul(&self, rhs: &PrimitiveArray) -> Self { + wrapping_mul(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays +impl ArrayCheckedMul> for PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { + checked_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArraySaturatingMul> for PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { + saturating_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArrayOverflowingMul> for PrimitiveArray +where + T: NativeArithmetics + OverflowingMul, +{ + fn overflowing_mul(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_mul(self, rhs) + } +} + +/// Multiply a scalar T to a primitive array of type T. +/// Panics if the multiplication of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::mul_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = mul_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(12), None, Some(12)]); +/// assert_eq!(result, expected) +/// ``` +pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + let rhs = *rhs; + unary(lhs, |a| a * rhs, lhs.data_type().clone()) +} + +/// Wrapping multiplication of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(0x10)]); +/// let result = wrapping_mul_scalar(&a, &0x10); +/// let expected = Int8Array::from(&[None, Some(0)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + unary(lhs, |a| a.wrapping_mul(rhs), lhs.data_type().clone()) +} + +/// Checked multiplication of a scalar T to a primitive array of type T. If the +/// result from the multiplication overflows, then the validity for that index is +/// changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); +/// let result = checked_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_mul(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated multiplication of a scalar T to a primitive array of type T. If the +/// result from the mul overflows for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = saturating_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-128i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_mul(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing multiplication of a scalar T to a primitive array of type T. 
If +/// the result from the mul overflows for this type, +/// then the result will be an array with overflowed values and a validity +/// array indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(100i8), Some(16i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingMul, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_mul(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayMul trait for PrimitiveArrays with a scalar +impl ArrayMul for PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + fn mul(&self, rhs: &T) -> Self { + mul_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays with a scalar +impl ArrayCheckedMul for PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + fn checked_mul(&self, rhs: &T) -> Self { + checked_mul_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar +impl ArraySaturatingMul for PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + fn saturating_mul(&self, rhs: &T) -> Self { + saturating_mul_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar +impl ArrayOverflowingMul for PrimitiveArray +where + T: NativeArithmetics + OverflowingMul, +{ + fn overflowing_mul(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_mul_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs b/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs new file mode 100644 index 000000000000..ea8908db6a51 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs @@ -0,0 +1,49 @@ +//! Definition of basic pow operations with primitive arrays +use num_traits::{checked_pow, CheckedMul, One, Pow}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::compute::arity::{unary, unary_checked}; + +/// Raises an array of primitives to the power of exponent. Panics if one of +/// the values values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::powf_scalar; +/// use arrow2::array::Float32Array; +/// +/// let a = Float32Array::from(&[Some(2f32), None]); +/// let actual = powf_scalar(&a, 2.0); +/// let expected = Float32Array::from(&[Some(4f32), None]); +/// assert_eq!(expected, actual); +/// ``` +pub fn powf_scalar(array: &PrimitiveArray, exponent: T) -> PrimitiveArray +where + T: NativeArithmetics + Pow, +{ + unary(array, |x| x.pow(exponent), array.data_type().clone()) +} + +/// Checked operation of raising an array of primitives to the power of +/// exponent. If the result from the multiplications overflows, the validity +/// for that index is changed returned. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_powf_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), None, Some(7i8)]); +/// let actual = checked_powf_scalar(&a, 8usize); +/// let expected = Int8Array::from(&[Some(1i8), None, None]); +/// assert_eq!(expected, actual); +/// ``` +pub fn checked_powf_scalar(array: &PrimitiveArray, exponent: usize) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul + One, +{ + let op = move |a: T| checked_pow(a, exponent); + + unary_checked(array, op, array.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs b/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs new file mode 100644 index 000000000000..6c400fce2b07 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs @@ -0,0 +1,196 @@ +use std::ops::Rem; + +use num_traits::{CheckedRem, NumCast}; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; + +use super::NativeArithmetics; +use crate::array::{Array, PrimitiveArray}; +use crate::compute::arithmetics::{ArrayCheckedRem, ArrayRem}; +use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; +use crate::datatypes::PrimitiveType; + +/// Remainder of two primitive arrays with the same type. +/// Panics if the divisor is zero of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::rem; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[Some(10), Some(7)]); +/// let b = Int32Array::from(&[Some(5), Some(6)]); +/// let result = rem(&a, &b); +/// let expected = Int32Array::from(&[Some(0), Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Rem, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b) +} + +/// Checked remainder of two primitive arrays. If the result from the remainder +/// overflows, the result for the operation will change the validity array +/// making this operation None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_rem; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); +/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); +/// let result = checked_rem(&a, &b); +/// let expected = Int8Array::from(&[Some(-0i8), None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + let op = move |a: T, b: T| a.checked_rem(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +impl ArrayRem> for PrimitiveArray +where + T: NativeArithmetics + Rem, +{ + fn rem(&self, rhs: &PrimitiveArray) -> Self { + rem(self, rhs) + } +} + +impl ArrayCheckedRem> for PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + fn checked_rem(&self, rhs: &PrimitiveArray) -> Self { + checked_rem(self, rhs) + } +} + +/// Remainder a primitive array of type T by a scalar T. +/// Panics if the divisor is zero. 
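+/// For the unsigned integer types the divisor is pre-processed with the `strength_reduce`
+/// crate (see the match on `T::PRIMITIVE` below), which makes the repeated `%` by a
+/// runtime constant cheaper than a plain hardware modulo; all other types fall back to `%`.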
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::rem_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(7)]); +/// let result = rem_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(0), None, Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Rem + NumCast, +{ + let rhs = *rhs; + + match T::PRIMITIVE { + PrimitiveType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_rem = StrengthReducedU64::new(rhs); + + // small hack to avoid a transmute of `PrimitiveArray` to `PrimitiveArray` + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_rem = StrengthReducedU32::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_rem = StrengthReducedU16::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_rem = StrengthReducedU8::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()), + } +} + +/// Checked remainder of a primitive array of type T by a scalar T. If the +/// divisor is zero then the validity array is changed to None. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_rem_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = checked_rem_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(0i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_rem(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +impl ArrayRem for PrimitiveArray +where + T: NativeArithmetics + Rem + NumCast, +{ + fn rem(&self, rhs: &T) -> Self { + rem_scalar(self, rhs) + } +} + +impl ArrayCheckedRem for PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + fn checked_rem(&self, rhs: &T) -> Self { + checked_rem_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs b/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs new file mode 100644 index 000000000000..5b2dcd36cb25 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs @@ -0,0 +1,337 @@ +//! 
Definition of basic sub operations with primitive arrays +use std::ops::Sub; + +use num_traits::ops::overflowing::OverflowingSub; +use num_traits::{CheckedSub, SaturatingSub, WrappingSub}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayCheckedSub, ArrayOverflowingSub, ArraySaturatingSub, ArraySub, ArrayWrappingSub, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Subtracts two primitive arrays with the same type. +/// Panics if the subtraction of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::sub; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); +/// let result = sub(&a, &b); +/// let expected = Int32Array::from(&[None, None, None, Some(0)]); +/// assert_eq!(result, expected) +/// ``` +pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b) +} + +/// Wrapping subtraction of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_sub; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(-100i8), Some(-100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = wrapping_sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(-100i8), Some(56i8), Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + let op = move |a: T, b: T| a.wrapping_sub(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked subtraction of two primitive arrays. If the result from the +/// subtraction overflow, the validity for that index is changed +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(100i8), Some(-100i8), Some(100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(0i8)]); +/// let result = checked_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(99i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + let op = move |a: T, b: T| a.checked_sub(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating subtraction of two primitive arrays. If the result from the sub +/// is smaller than the possible number for this type, the result for the +/// operation will be the saturated value. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let b = Int8Array::from(&[Some(100i8)]); +/// let result = saturating_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(-128)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + let op = move |a: T, b: T| a.saturating_sub(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing subtraction of two primitive arrays. If the result from the sub +/// is smaller than the possible number for this type, the result for the +/// operation will be an array with overflowed values and a validity array +/// indicating the overflowing elements from the array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(0i8), Some(56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingSub, +{ + let op = move |a: T, b: T| a.overflowing_sub(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArraySub trait for PrimitiveArrays +impl ArraySub> for PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + fn sub(&self, rhs: &PrimitiveArray) -> Self { + sub(self, rhs) + } +} + +impl ArrayWrappingSub> for PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + fn wrapping_sub(&self, rhs: &PrimitiveArray) -> Self { + wrapping_sub(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays +impl ArrayCheckedSub> for PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { + checked_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArraySaturatingSub> for PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { + saturating_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArrayOverflowingSub> for PrimitiveArray +where + T: NativeArithmetics + OverflowingSub, +{ + fn overflowing_sub(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_sub(self, rhs) + } +} + +/// Subtract a scalar T to a primitive array of type T. +/// Panics if the subtraction of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::sub_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = sub_scalar(&a, &1i32); +/// let expected = Int32Array::from(&[None, Some(5), None, Some(5)]); +/// assert_eq!(result, expected) +/// ``` +pub fn sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + let rhs = *rhs; + unary(lhs, |a| a - rhs, lhs.data_type().clone()) +} + +/// Wrapping subtraction of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(-100)]); +/// let result = wrapping_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, Some(56)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + unary(lhs, |a| a.wrapping_sub(rhs), lhs.data_type().clone()) +} + +/// Checked subtraction of a scalar T to a primitive array of type T. If the +/// result from the subtraction overflows, then the validity for that index +/// is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(-100), None, Some(-100)]); +/// let result = checked_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_sub(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated subtraction of a scalar T to a primitive array of type T. If the +/// result from the sub is smaller than the possible number for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = saturating_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-128i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_sub(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing subtraction of a scalar T to a primitive array of type T. 
If +/// the result from the sub is smaller than the possible number for this type, +/// then the result will be an array with overflowed values and a validity +/// array indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let (result, overflow) = overflowing_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-99i8), Some(56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingSub, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_sub(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArraySub trait for PrimitiveArrays with a scalar +impl ArraySub for PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + fn sub(&self, rhs: &T) -> Self { + sub_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays with a scalar +impl ArrayCheckedSub for PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + fn checked_sub(&self, rhs: &T) -> Self { + checked_sub_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar +impl ArraySaturatingSub for PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + fn saturating_sub(&self, rhs: &T) -> Self { + saturating_sub_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar +impl ArrayOverflowingSub for PrimitiveArray +where + T: NativeArithmetics + OverflowingSub, +{ + fn overflowing_sub(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_sub_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs new file mode 100644 index 000000000000..dccdb6b144c1 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs @@ -0,0 +1,236 @@ +//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; +use crate::compute::arity::{binary, binary_checked}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Adds two decimal [`PrimitiveArray`] with the same precision and scale. +/// # Error +/// Errors if the precision and scale are different. +/// # Panic +/// This function panics iff the added numbers result in a number larger than +/// the possible number for the precision. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = add(&a, &b); +/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; + + assert!( + res.abs() <= max, + "Overflow in addition presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturated addition of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the sum is larger than +/// the possible number with the selected precision then the resulted number in +/// the arrow array is the maximum number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; + + if res.abs() > max { + if res > 0 { + max + } else { + -max + } + } else { + res + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked addition of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the sum is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_add(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let result: i128 = a + b; + + if result.abs() > max { + None + } else { + Some(result) + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays +impl ArrayAdd> for PrimitiveArray { + fn add(&self, rhs: &PrimitiveArray) -> Self { + add(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays +impl ArrayCheckedAdd> for PrimitiveArray { + fn checked_add(&self, rhs: &PrimitiveArray) -> Self { + checked_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArraySaturatingAdd> for PrimitiveArray { + fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { + saturating_add(self, rhs) + } +} + +/// Adaptive addition of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. 
If during the +/// addition one of the results is larger than the max possible value, the +/// result precision is changed to the precision of the max value +/// +/// ```nocode +/// 11111.11 -> 7, 2 +/// 11111.111 -> 8, 3 +/// ------------------ +/// 22222.221 -> 8, 3 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); +/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); +/// let result = adaptive_add(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l + r * shift + } else { + l * shift + r + }; + + // The precision of the resulting array will change if one of the + // sums during the iteration produces a value bigger than the + // possible value for the initial precision + + // 99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs new file mode 100644 index 000000000000..1576fc061947 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs @@ -0,0 +1,302 @@ +//! Defines the division arithmetic kernels for Decimal +//! `PrimitiveArrays`. + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; +use crate::compute::arity::{binary, binary_checked, unary}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; + +/// Divide two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the dividend is divided by 0 or None. +/// This function also panics if the division produces a number larger +/// than the possible number for the array precision. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = div(&a, &b); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + let op = move |a: i128, b: i128| { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + let numeral: i128 = a * scale; + + // The division can overflow if the dividend is divided + // by zero. + let res: i128 = numeral.checked_div(b).expect("Found division by zero"); + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +pub fn div_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128| { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + let numeral: i128 = a * scale; + + // The division can overflow if the dividend is divided + // by zero. + let res: i128 = numeral.checked_div(rhs).expect("Found division by zero"); + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Saturated division of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the division is +/// larger than the possible number with the selected precision then the +/// resulted number in the arrow array is the maximum number for the selected +/// precision. The function panics if divided by zero. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_div(&a, &b); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; + + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + }, + None => 0, + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked division of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the divisor is zero, then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_div(&a, &b); +/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; + + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => None, + _ => Some(res), + }, + None => None, + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays +impl ArrayDiv> for PrimitiveArray { + fn div(&self, rhs: &PrimitiveArray) -> Self { + div(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays +impl ArrayCheckedDiv> for PrimitiveArray { + fn checked_div(&self, rhs: &PrimitiveArray) -> Self { + checked_div(self, rhs) + } +} + +/// Adaptive division of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. If during the +/// division one of the results is larger than the max possible value, the +/// result precision is changed to the precision of the max value. The function +/// panics when divided by zero. 
+/// +/// ```nocode +/// 1000.00 -> 7, 2 +/// 10.0000 -> 6, 4 +/// ----------------- +/// 100.0000 -> 9, 4 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); +/// let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); +/// let result = adaptive_div(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + let numeral: i128 = l * shift_1; + + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + numeral.checked_div(r * shift) + } else { + (numeral * shift).checked_div(*r) + } + .expect("Found division by zero"); + + // The precision of the resulting array will change if one of the + // multiplications during the iteration produces a value bigger + // than the possible value for the initial precision + + // 10.0000 -> 6, 4 + // 00.1000 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs new file mode 100644 index 000000000000..4b412ef13c6e --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs @@ -0,0 +1,119 @@ +//! Defines the arithmetic kernels for Decimal `PrimitiveArrays`. The +//! [`Decimal`](crate::datatypes::DataType::Decimal) type specifies the +//! precision and scale parameters. These affect the arithmetic operations and +//! need to be considered while doing operations with Decimal numbers. 
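+//!
+//! Every decimal value is stored as an `i128` scaled by `10^scale`: with
+//! `DataType::Decimal(5, 2)` the logical value `1.23` is stored as `123i128`, and the
+//! largest representable value is `999.99` (`max_value(5) == 99_999`). A minimal usage
+//! sketch, following the conventions of the examples in the submodules:
+//!
+//! ```
+//! use arrow2::compute::arithmetics::decimal::add;
+//! use arrow2::array::PrimitiveArray;
+//! use arrow2::datatypes::DataType;
+//!
+//! // 1.23 + 0.77 == 2.00, both operands at precision 5 and scale 2
+//! let a = PrimitiveArray::from([Some(1_23i128)]).to(DataType::Decimal(5, 2));
+//! let b = PrimitiveArray::from([Some(0_77i128)]).to(DataType::Decimal(5, 2));
+//! let expected = PrimitiveArray::from([Some(2_00i128)]).to(DataType::Decimal(5, 2));
+//! assert_eq!(add(&a, &b), expected);
+//! ```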
+ +mod add; +pub use add::*; +mod div; +pub use div::*; +mod mul; +pub use mul::*; +mod sub; +pub use sub::*; + +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Maximum value that can exist with a selected precision +#[inline] +fn max_value(precision: usize) -> i128 { + 10i128.pow(precision as u32) - 1 +} + +// Calculates the number of digits in a i128 number +fn number_digits(num: i128) -> usize { + let mut num = num.abs(); + let mut digit: i128 = 0; + let base = 10i128; + + while num != 0 { + num /= base; + digit += 1; + } + + digit as usize +} + +fn get_parameters(lhs: &DataType, rhs: &DataType) -> Result<(usize, usize)> { + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.to_logical_type(), rhs.to_logical_type()) + { + if lhs_p == rhs_p && lhs_s == rhs_s { + Ok((*lhs_p, *lhs_s)) + } else { + Err(Error::InvalidArgumentError( + "Arrays must have the same precision and scale".to_string(), + )) + } + } else { + unreachable!() + } +} + +/// Returns the adjusted precision and scale for the lhs and rhs precision and +/// scale +fn adjusted_precision_scale( + lhs_p: usize, + lhs_s: usize, + rhs_p: usize, + rhs_s: usize, +) -> (usize, usize, usize) { + // The initial new precision and scale is based on the number of digits + // that lhs and rhs number has before and after the point. The max + // number of digits before and after the point will make the last + // precision and scale of the result + + // Digits before/after point + // before after + // 11.1111 -> 5, 4 -> 2 4 + // 11111.01 -> 7, 2 -> 5 2 + // ----------------- + // 11122.1211 -> 9, 4 -> 5 4 + let lhs_digits_before = lhs_p - lhs_s; + let rhs_digits_before = rhs_p - rhs_s; + + let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before); + + let (res_s, diff) = if lhs_s > rhs_s { + (lhs_s, lhs_s - rhs_s) + } else { + (rhs_s, rhs_s - lhs_s) + }; + + let res_p = res_digits_before + res_s; + + (res_p, res_s, diff) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_max_value() { + assert_eq!(999, max_value(3)); + assert_eq!(99999, max_value(5)); + assert_eq!(999999, max_value(6)); + } + + #[test] + fn test_number_digits() { + assert_eq!(2, number_digits(12i128)); + assert_eq!(3, number_digits(123i128)); + assert_eq!(4, number_digits(1234i128)); + assert_eq!(6, number_digits(123456i128)); + assert_eq!(7, number_digits(1234567i128)); + assert_eq!(7, number_digits(-1234567i128)); + assert_eq!(3, number_digits(-123i128)); + } + + #[test] + fn test_adjusted_precision_scale() { + // 11.1111 -> 5, 4 -> 2 4 + // 11111.01 -> 7, 2 -> 5 2 + // ----------------- + // 11122.1211 -> 9, 4 -> 5 4 + assert_eq!((9, 4, 2), adjusted_precision_scale(5, 4, 7, 2)) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs new file mode 100644 index 000000000000..a944279a133e --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs @@ -0,0 +1,314 @@ +//! Defines the multiplication arithmetic kernels for Decimal +//! `PrimitiveArrays`. 
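+//!
+//! The kernels below multiply the raw `i128` values and then divide by `10^scale` to bring
+//! the product back to the arrays' scale. An illustrative computation at scale 2:
+//! `1.50 * 2.00` is evaluated as `150 * 200 = 30_000`, and `30_000 / 10^2 = 300`,
+//! i.e. `3.00`.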
+ +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul}; +use crate::compute::arity::{binary, binary_checked, unary}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; + +/// Multiply two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + // The multiplication between i128 can overflow if they are + // very large numbers. For that reason a checked + // multiplication is used. + let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication"); + + // The multiplication is done using the numbers without scale. + // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128| { + // The multiplication between i128 can overflow if they are + // very large numbers. For that reason a checked + // multiplication is used. + let res: i128 = a + .checked_mul(rhs) + .expect("Mayor overflow for multiplication"); + + // The multiplication is done using the numbers without scale. 
+ // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Saturated multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the multiplication is +/// larger than the possible number with the selected precision then the +/// resulted number in the arrow array is the maximum number for the selected +/// precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }, + None => max, + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the mul is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_mul(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }, + None => None, + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayMul trait for PrimitiveArrays +impl ArrayMul> for PrimitiveArray { + fn mul(&self, rhs: &PrimitiveArray) -> Self { + mul(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays +impl ArrayCheckedMul> for PrimitiveArray { + fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { + checked_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArraySaturatingMul> for PrimitiveArray { + fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { + saturating_mul(self, rhs) + } +} + +/// Adaptive multiplication of two decimal primitive arrays with different +/// precision and scale. If the precision and scale is different, then the +/// smallest scale and precision is adjusted to the largest precision and +/// scale. 
If during the multiplication one of the results is larger than the +/// max possible value, the result precision is changed to the precision of the +/// max value +/// +/// ```nocode +/// 11111.0 -> 6, 1 +/// 10.002 -> 5, 3 +/// ----------------- +/// 111132.222 -> 9, 3 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(6, 1)); +/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(5, 3)); +/// let result = adaptive_mul(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(9, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l.checked_mul(r * shift) + } else { + (l * shift).checked_mul(*r) + } + .expect("Mayor overflow for multiplication"); + + let res = res / shift_1; + + // The precision of the resulting array will change if one of the + // multiplications during the iteration produces a value bigger + // than the possible value for the initial precision + + // 10.0000 -> 6, 4 + // 10.0000 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs new file mode 100644 index 000000000000..2a0f7a72da17 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs @@ -0,0 +1,238 @@ +//! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub}; +use crate::compute::arity::{binary, binary_checked}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Subtract two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. 
This function panics if the subtracted numbers result in a number +/// smaller than the possible number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + assert!( + res.abs() <= max, + "Overflow in subtract presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturated subtraction of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the sum is smaller +/// than the possible number with the selected precision then the resulted +/// number in the arrow array is the minimum number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArraySub trait for PrimitiveArrays +impl ArraySub> for PrimitiveArray { + fn sub(&self, rhs: &PrimitiveArray) -> Self { + sub(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays +impl ArrayCheckedSub> for PrimitiveArray { + fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { + checked_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArraySaturatingSub> for PrimitiveArray { + fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { + saturating_sub(self, rhs) + } +} +/// Checked subtract of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the sub is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_sub(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Adaptive subtract of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. If during the +/// addition one of the results is smaller than the min possible value, the +/// result precision is changed to the precision of the min value +/// +/// ```nocode +/// 99.9999 -> 6, 4 +/// -00.0001 -> 6, 4 +/// ----------------- +/// 100.0000 -> 7, 4 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); +/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); +/// let result = adaptive_sub(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res: i128 = if lhs_s > rhs_s { + l - r * shift + } else { + l * shift - r + }; + + // The precision of the resulting array will change if one of the + // subtraction during the iteration produces a value bigger than the + // possible value for the initial precision + + // -99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // -100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = 
max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/mod.rs b/crates/nano-arrow/src/compute/arithmetics/mod.rs new file mode 100644 index 000000000000..1d520e9ad644 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/mod.rs @@ -0,0 +1,581 @@ +//! Defines basic arithmetic kernels for [`PrimitiveArray`](crate::array::PrimitiveArray)s. +//! +//! The Arithmetics module is composed by basic arithmetics operations that can +//! be performed on [`PrimitiveArray`](crate::array::PrimitiveArray). +//! +//! Whenever possible, each operation declares variations +//! of the basic operation that offers different guarantees: +//! * plain: panics on overflowing and underflowing. +//! * checked: turns an overflowing to a null. +//! * saturating: turns the overflowing to the MAX or MIN value respectively. +//! * overflowing: returns an extra [`Bitmap`] denoting whether the operation overflowed. +//! * adaptive: for [`Decimal`](crate::datatypes::DataType::Decimal) only, +//! adjusts the precision and scale to make the resulting value fit. +#[forbid(unsafe_code)] +pub mod basic; +#[cfg(feature = "compute_arithmetics_decimal")] +pub mod decimal; +pub mod time; + +use crate::array::{Array, DictionaryArray, PrimitiveArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::scalar::{PrimitiveScalar, Scalar}; +use crate::types::NativeType; + +fn binary_dyn, &PrimitiveArray) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Array, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + op(lhs, rhs).boxed() +} + +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => binary_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_dyn::(lhs, rhs, basic::$op), + (Int64, Int64) | (Duration(_), Duration(_)) => { + binary_dyn::(lhs, rhs, basic::$op) + } + (UInt8, UInt8) => binary_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_dyn::(lhs, rhs, basic::$op), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(decimal::$op_decimal(lhs, rhs)) as Box + } + )? 
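+ // Each `$( ... )?` group below expands only when the caller supplies the
+ // matching kernel argument (e.g. `duration = add_duration`); otherwise the
+ // whole branch is omitted from the generated `match`.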
+ $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), + } + }}; +} + +fn binary_scalar, &T) -> PrimitiveArray>( + lhs: &PrimitiveArray, + rhs: &PrimitiveScalar, + op: F, +) -> PrimitiveArray { + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + op(lhs, &rhs) +} + +fn binary_scalar_dyn, &T) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Scalar, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary_scalar(lhs, rhs, op).boxed() +} + +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith_scalar { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int64, Int64) | (Duration(_), Duration(_)) => { + binary_scalar_dyn::(lhs, rhs, basic::$op) + } + (UInt8, UInt8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_scalar_dyn::(lhs, rhs, basic::$op), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + decimal::$op_decimal(lhs, rhs).boxed() + } + )? 
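+ // Same pattern as in `arith!` above: the optional branches below are emitted
+ // only when the corresponding scalar kernel is supplied.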
+ $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_duration::(lhs, rhs).boxed() + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_duration::(lhs, rhs).boxed() + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).unwrap().boxed() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).unwrap().boxed() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), + } + }}; +} + +/// Adds two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_add`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + add, + duration = add_duration, + interval = add_interval + ) +} + +/// Adds an [`Array`] and a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_add`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!( + lhs, + rhs, + add_scalar, + duration = add_duration_scalar, + interval = add_interval_scalar + ) +} + +/// Returns whether two [`DataType`]s can be added by [`add`]. +pub fn can_add(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) + ) +} + +/// Subtracts two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_sub`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + sub, + decimal = sub, + duration = subtract_duration, + timestamp = subtract_timestamps + ) +} + +/// Adds an [`Array`] and a [`Scalar`]. 
+/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_sub`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!( + lhs, + rhs, + sub_scalar, + duration = sub_duration_scalar, + timestamp = sub_timestamps_scalar + ) +} + +/// Returns whether two [`DataType`]s can be subtracted by [`sub`]. +pub fn can_sub(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, None), Timestamp(_, None)) + ) +} + +/// Multiply two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_mul`] to check) +/// * the arrays have a different length +pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, mul, decimal = mul) +} + +/// Multiply an [`Array`] with a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_mul`] to check) +pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!(lhs, rhs, mul_scalar, decimal = mul_scalar) +} + +/// Returns whether two [`DataType`]s can be multiplied by [`mul`]. +pub fn can_mul(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Decimal(_, _), Decimal(_, _)) + ) +} + +/// Divide of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_div`] to check) +/// * the arrays have a different length +pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, div, decimal = div) +} + +/// Divide an [`Array`] with a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_div`] to check) +pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!(lhs, rhs, div_scalar, decimal = div_scalar) +} + +/// Returns whether two [`DataType`]s can be divided by [`div`]. +pub fn can_div(lhs: &DataType, rhs: &DataType) -> bool { + can_mul(lhs, rhs) +} + +/// Remainder of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_rem`] to check) +/// * the arrays have a different length +pub fn rem(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, rem) +} + +/// Returns whether two [`DataType`]s "can be remainder" by [`rem`]. 
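+///
+/// For instance (an illustrative sketch, not a doctest):
+///
+/// ```nocode
+/// assert!(can_rem(&DataType::Int32, &DataType::Int32));
+/// assert!(!can_rem(&DataType::Int32, &DataType::Int64));
+/// ```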
+pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + ) +} + +macro_rules! with_match_negatable {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 | UInt16 | UInt32 | UInt64 | Float16 => todo!(), + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +/// Negates an [`Array`]. +/// # Panic +/// This function panics iff either +/// * the operation is not supported for the logical type (use [`can_neg`] to check) +/// * the operation overflows +pub fn neg(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_negatable!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + + let result = basic::negate::<$T>(array); + Box::new(result) as Box + }), + Dictionary(key) => match_integer_type!(key, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + + let values = neg(array.values().as_ref()); + + // safety - this operation only applies to values and thus preserves the dictionary's invariant + unsafe{ + DictionaryArray::<$T>::try_new_unchecked(array.data_type().clone(), array.keys().clone(), values).unwrap().boxed() + } + }), + _ => todo!(), + } +} + +/// Whether [`neg`] is supported for a given [`DataType`] +pub fn can_neg(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_neg(values.as_ref()); + } + + use crate::datatypes::PhysicalType::*; + use crate::datatypes::PrimitiveType::*; + matches!( + data_type.to_physical_type(), + Primitive(Int8) + | Primitive(Int16) + | Primitive(Int32) + | Primitive(Int64) + | Primitive(Float64) + | Primitive(Float32) + | Primitive(DaysMs) + | Primitive(MonthDayNano) + ) +} + +/// Defines basic addition operation for primitive arrays +pub trait ArrayAdd: Sized { + /// Adds itself to `rhs` + fn add(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping addition operation for primitive arrays +pub trait ArrayWrappingAdd: Sized { + /// Adds itself to `rhs` using wrapping addition + fn wrapping_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked addition operation for primitive arrays +pub trait ArrayCheckedAdd: Sized { + /// Checked add + fn checked_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating addition operation for primitive arrays +pub trait ArraySaturatingAdd: Sized { + /// Saturating add + fn saturating_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing addition operation for primitive arrays +pub trait ArrayOverflowingAdd: Sized { + /// Overflowing add + fn overflowing_add(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic subtraction operation for primitive arrays +pub trait ArraySub: Sized { + /// subtraction + fn sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping 
subtraction operation for primitive arrays +pub trait ArrayWrappingSub: Sized { + /// wrapping subtraction + fn wrapping_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked subtraction operation for primitive arrays +pub trait ArrayCheckedSub: Sized { + /// checked subtraction + fn checked_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating subtraction operation for primitive arrays +pub trait ArraySaturatingSub: Sized { + /// saturarting subtraction + fn saturating_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing subtraction operation for primitive arrays +pub trait ArrayOverflowingSub: Sized { + /// overflowing subtraction + fn overflowing_sub(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic multiplication operation for primitive arrays +pub trait ArrayMul: Sized { + /// multiplication + fn mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping multiplication operation for primitive arrays +pub trait ArrayWrappingMul: Sized { + /// wrapping multiplication + fn wrapping_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked multiplication operation for primitive arrays +pub trait ArrayCheckedMul: Sized { + /// checked multiplication + fn checked_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating multiplication operation for primitive arrays +pub trait ArraySaturatingMul: Sized { + /// saturating multiplication + fn saturating_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing multiplication operation for primitive arrays +pub trait ArrayOverflowingMul: Sized { + /// overflowing multiplication + fn overflowing_mul(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic division operation for primitive arrays +pub trait ArrayDiv: Sized { + /// division + fn div(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked division operation for primitive arrays +pub trait ArrayCheckedDiv: Sized { + /// checked division + fn checked_div(&self, rhs: &Rhs) -> Self; +} + +/// Defines basic reminder operation for primitive arrays +pub trait ArrayRem: Sized { + /// remainder + fn rem(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked reminder operation for primitive arrays +pub trait ArrayCheckedRem: Sized { + /// checked remainder + fn checked_rem(&self, rhs: &Rhs) -> Self; +} diff --git a/crates/nano-arrow/src/compute/arithmetics/time.rs b/crates/nano-arrow/src/compute/arithmetics/time.rs new file mode 100644 index 000000000000..aa2e25e3ab0f --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/time.rs @@ -0,0 +1,432 @@ +//! Defines the arithmetic kernels for adding a Duration to a Timestamp, +//! Time32, Time64, Date32 and Date64. +//! +//! For the purposes of Arrow Implementations, adding this value to a Timestamp +//! ("t1") naively (i.e. simply summing the two number) is acceptable even +//! though in some cases the resulting Timestamp (t2) would not account for +//! leap-seconds during the elapsed time between "t1" and "t2". Similarly, +//! representing the difference between two Unix timestamp is acceptable, but +//! would yield a value that is possibly a few seconds off from the true +//! elapsed time. 
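+//!
+//! The sketch below is illustrative only (plain integer arithmetic, not part of
+//! this module's API): the duration is rescaled to the unit of the time value
+//! before the naive sum described above.
+//!
+//! ```nocode
+//! // timestamp in milliseconds, duration in seconds -> scale = 1_000.0
+//! let scale = 1_000.0_f64;
+//! let timestamp_ms = 1_700_000_000_000i64;
+//! let duration_s = 30i64;
+//! assert_eq!(timestamp_ms + (duration_s as f64 * scale) as i64, 1_700_000_030_000i64);
+//! ```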
+ +use std::ops::{Add, Sub}; + +use num_traits::AsPrimitive; + +use crate::array::PrimitiveArray; +use crate::compute::arity::{binary, unary}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; +use crate::temporal_conversions; +use crate::types::{months_days_ns, NativeType}; + +/// Creates the scale required to add or subtract a Duration to a time array +/// (Timestamp, Time, or Date). The resulting scale always multiplies the rhs +/// number (Duration) so it can be added to the lhs number (time array). +fn create_scale(lhs: &DataType, rhs: &DataType) -> Result { + // Matching on both data types from both numbers to calculate the correct + // scale for the operation. The timestamp, Time and duration have a + // Timeunit enum in its data type. This enum is used to describe the + // addition of the duration. The Date32 and Date64 have different rules for + // the scaling. + let scale = match (lhs, rhs) { + (DataType::Timestamp(timeunit_a, _), DataType::Duration(timeunit_b)) + | (DataType::Time32(timeunit_a), DataType::Duration(timeunit_b)) + | (DataType::Time64(timeunit_a), DataType::Duration(timeunit_b)) => { + // The scale is based on the TimeUnit that each of the numbers have. + temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b) + }, + (DataType::Date32, DataType::Duration(timeunit)) => { + // Date32 represents the time elapsed time since UNIX epoch + // (1970-01-01) in days (32 bits). The duration value has to be + // scaled to days to be able to add the value to the Date. + temporal_conversions::timeunit_scale(TimeUnit::Second, *timeunit) + / temporal_conversions::SECONDS_IN_DAY as f64 + }, + (DataType::Date64, DataType::Duration(timeunit)) => { + // Date64 represents the time elapsed time since UNIX epoch + // (1970-01-01) in milliseconds (64 bits). The duration value has + // to be scaled to milliseconds to be able to add the value to the + // Date. + temporal_conversions::timeunit_scale(TimeUnit::Millisecond, *timeunit) + }, + _ => { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the arguments".to_string(), + )); + }, + }; + + Ok(scale) +} + +/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::add_duration; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// let timestamp = PrimitiveArray::from([ +/// Some(100000i64), +/// Some(200000i64), +/// None, +/// Some(300000i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = add_duration(×tamp, &duration); +/// let expected = PrimitiveArray::from([ +/// Some(100010i64), +/// Some(200020i64), +/// None, +/// Some(300030i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn add_duration( + time: &PrimitiveArray, + duration: &PrimitiveArray, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Add, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T, b: i64| a + (b as f64 * scale).as_(); + + binary(time, duration, time.data_type().clone(), op) +} + +/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. +pub fn add_duration_scalar( + time: &PrimitiveArray, + duration: &PrimitiveScalar, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Add, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + let duration = if let Some(duration) = *duration.value() { + duration + } else { + return PrimitiveArray::::new_null(time.data_type().clone(), time.len()); + }; + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T| a + (duration as f64 * scale).as_(); + + unary(time, op, time.data_type().clone()) +} + +/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::subtract_duration; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// let timestamp = PrimitiveArray::from([ +/// Some(100000i64), +/// Some(200000i64), +/// None, +/// Some(300000i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = subtract_duration(×tamp, &duration); +/// let expected = PrimitiveArray::from([ +/// Some(99990i64), +/// Some(199980i64), +/// None, +/// Some(299970i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// assert_eq!(result, expected); +/// +/// ``` +pub fn subtract_duration( + time: &PrimitiveArray, + duration: &PrimitiveArray, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Sub, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T, b: i64| a - (b as f64 * scale).as_(); + + binary(time, duration, time.data_type().clone(), op) +} + +/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. +pub fn sub_duration_scalar( + time: &PrimitiveArray, + duration: &PrimitiveScalar, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Sub, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + let duration = if let Some(duration) = *duration.value() { + duration + } else { + return PrimitiveArray::::new_null(time.data_type().clone(), time.len()); + }; + + let op = move |a: T| a - (duration as f64 * scale).as_(); + + unary(time, op, time.data_type().clone()) +} + +/// Calculates the difference between two timestamps returning an array of type +/// Duration. The timeunit enum is used to scale correctly both arrays; +/// subtracting seconds with seconds, or milliseconds with milliseconds. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::subtract_timestamps; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// let timestamp_a = PrimitiveArray::from([ +/// Some(100_010i64), +/// Some(200_020i64), +/// None, +/// Some(300_030i64), +/// ]) +/// .to(DataType::Timestamp(TimeUnit::Second, None)); +/// +/// let timestamp_b = PrimitiveArray::from([ +/// Some(100_000i64), +/// Some(200_000i64), +/// None, +/// Some(300_000i64), +/// ]) +/// .to(DataType::Timestamp(TimeUnit::Second, None)); +/// +/// let expected = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = subtract_timestamps(×tamp_a, &×tamp_b).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn subtract_timestamps( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + // Matching on both data types from both arrays. + // Both timestamps have a Timeunit enum in its data type. + // This enum is used to adjust the scale between the timestamps. + match (lhs.data_type(), rhs.data_type()) { + // Naive timestamp comparison. 
It doesn't take into account timezones + // from the Timestamp timeunit. + (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) => { + // Closure for the binary operation. The closure contains the scale + // required to calculate the difference between the timestamps. + let scale = temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b); + let op = move |a, b| a - (b as f64 * scale) as i64; + + Ok(binary(lhs, rhs, DataType::Duration(*timeunit_a), op)) + }, + _ => Err(Error::InvalidArgumentError( + "Incorrect data type for the arguments".to_string(), + )), + } +} + +/// Calculates the difference between two timestamps as [`DataType::Duration`] with the same time scale. +pub fn sub_timestamps_scalar( + lhs: &PrimitiveArray, + rhs: &PrimitiveScalar, +) -> Result> { + let (scale, timeunit_a) = + if let (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) = + (lhs.data_type(), rhs.data_type()) + { + ( + temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b), + timeunit_a, + ) + } else { + return Err(Error::InvalidArgumentError( + "sub_timestamps_scalar requires both arguments to be timestamps without timezone" + .to_string(), + )); + }; + + let rhs = if let Some(value) = *rhs.value() { + value + } else { + return Ok(PrimitiveArray::::new_null( + lhs.data_type().clone(), + lhs.len(), + )); + }; + + let op = move |a| a - (rhs as f64 * scale) as i64; + + Ok(unary(lhs, op, DataType::Duration(*timeunit_a))) +} + +/// Adds an interval to a [`DataType::Timestamp`]. +pub fn add_interval( + timestamp: &PrimitiveArray, + interval: &PrimitiveArray, +) -> Result> { + match timestamp.data_type().to_logical_type() { + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let time_unit = *time_unit; + let timezone = temporal_conversions::parse_offset(timezone_str); + match timezone { + Ok(timezone) => Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + )), + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(timezone_str)?; + Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + )) + }, + #[cfg(not(feature = "chrono-tz"))] + _ => Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))), + } + }, + DataType::Timestamp(time_unit, None) => { + let time_unit = *time_unit; + Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_naive_interval(timestamp, time_unit, interval) + }, + )) + }, + _ => Err(Error::InvalidArgumentError( + "Adding an interval is only supported for `DataType::Timestamp`".to_string(), + )), + } +} + +/// Adds an interval to a [`DataType::Timestamp`]. 
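+///
+/// A minimal sketch of the call shape (the names `ts` and `iv` are illustrative):
+///
+/// ```nocode
+/// // `ts`: a Timestamp PrimitiveArray, `iv`: a months_days_ns PrimitiveScalar
+/// let shifted = add_interval_scalar(&ts, &iv)?;
+/// ```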
+pub fn add_interval_scalar( + timestamp: &PrimitiveArray, + interval: &PrimitiveScalar, +) -> Result> { + let interval = if let Some(interval) = *interval.value() { + interval + } else { + return Ok(PrimitiveArray::::new_null( + timestamp.data_type().clone(), + timestamp.len(), + )); + }; + + match timestamp.data_type().to_logical_type() { + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let time_unit = *time_unit; + let timezone = temporal_conversions::parse_offset(timezone_str); + match timezone { + Ok(timezone) => Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + timestamp.data_type().clone(), + )), + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(timezone_str)?; + Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + timestamp.data_type().clone(), + )) + }, + #[cfg(not(feature = "chrono-tz"))] + _ => Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))), + } + }, + DataType::Timestamp(time_unit, None) => { + let time_unit = *time_unit; + Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_naive_interval(timestamp, time_unit, interval) + }, + timestamp.data_type().clone(), + )) + }, + _ => Err(Error::InvalidArgumentError( + "Adding an interval is only supported for `DataType::Timestamp`".to_string(), + )), + } +} diff --git a/crates/nano-arrow/src/compute/arity.rs b/crates/nano-arrow/src/compute/arity.rs new file mode 100644 index 000000000000..935970ccdf75 --- /dev/null +++ b/crates/nano-arrow/src/compute/arity.rs @@ -0,0 +1,279 @@ +//! Defines kernels suitable to perform operations to primitive arrays. + +use super::utils::{check_same_len, combine_validities}; +use crate::array::PrimitiveArray; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +/// Applies an unary and infallible function to a [`PrimitiveArray`]. This is the +/// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits +/// of a vectorized operation outweighs the cost of branching nulls and +/// non-nulls. +/// +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type or this function may panic. +#[inline] +pub fn unary(array: &PrimitiveArray, op: F, data_type: DataType) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let values = array.values().iter().map(|v| op(*v)).collect::>(); + + PrimitiveArray::::new(data_type, values.into(), array.validity().cloned()) +} + +/// Version of unary that checks for errors in the closure used to create the +/// buffer +pub fn try_unary( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> Result> +where + I: NativeType, + O: NativeType, + F: Fn(I) -> Result, +{ + let values = array + .values() + .iter() + .map(|v| op(*v)) + .collect::>>()? + .into(); + + Ok(PrimitiveArray::::new( + data_type, + values, + array.validity().cloned(), + )) +} + +/// Version of unary that returns an array and bitmap. 
Used when working with +/// overflowing operations +pub fn unary_with_bitmap( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> (PrimitiveArray, Bitmap) +where + I: NativeType, + O: NativeType, + F: Fn(I) -> (O, bool), +{ + let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); + + let values = array + .values() + .iter() + .map(|v| { + let (res, over) = op(*v); + mut_bitmap.push(over); + res + }) + .collect::>() + .into(); + + ( + PrimitiveArray::::new(data_type, values, array.validity().cloned()), + mut_bitmap.into(), + ) +} + +/// Version of unary that creates a mutable bitmap that is used to keep track +/// of checked operations. The resulting bitmap is compared with the array +/// bitmap to create the final validity array. +pub fn unary_checked( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> Option, +{ + let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); + + let values = array + .values() + .iter() + .map(|v| match op(*v) { + Some(val) => { + mut_bitmap.push(true); + val + }, + None => { + mut_bitmap.push(false); + O::default() + }, + }) + .collect::>() + .into(); + + // The validity has to be checked against the bitmap created during the + // creation of the values with the iterator. If an error was found during + // the iteration, then the validity is changed to None to mark the value + // as Null + let bitmap: Bitmap = mut_bitmap.into(); + let validity = combine_validities(array.validity(), Some(&bitmap)); + + PrimitiveArray::::new(data_type, values, validity) +} + +/// Applies a binary operations to two primitive arrays. This is the fastest +/// way to perform an operation on two primitive array when the benefits of a +/// vectorized operation outweighs the cost of branching nulls and non-nulls. +/// # Errors +/// This function errors iff the arrays have a different length. +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type. +/// The types of the arrays are not checked with this operation. The closure +/// "op" needs to handle the different types in the arrays. The datatype for the +/// resulting array has to be selected by the implementer of the function as +/// an argument for the function. +#[inline] +pub fn binary( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>() + .into(); + + PrimitiveArray::::new(data_type, values, validity) +} + +/// Version of binary that checks for errors in the closure used to create the +/// buffer +pub fn try_binary( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> Result> +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> Result, +{ + check_same_len(lhs, rhs)?; + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>>()? + .into(); + + Ok(PrimitiveArray::::new(data_type, values, validity)) +} + +/// Version of binary that returns an array and bitmap. 
Used when working with +/// overflowing operations +pub fn binary_with_bitmap( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> (PrimitiveArray, Bitmap) +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> (T, bool), +{ + check_same_len(lhs, rhs).unwrap(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + let (res, over) = op(*l, *r); + mut_bitmap.push(over); + res + }) + .collect::>() + .into(); + + ( + PrimitiveArray::::new(data_type, values, validity), + mut_bitmap.into(), + ) +} + +/// Version of binary that creates a mutable bitmap that is used to keep track +/// of checked operations. The resulting bitmap is compared with the array +/// bitmap to create the final validity array. +pub fn binary_checked( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> Option, +{ + check_same_len(lhs, rhs).unwrap(); + + let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| match op(*l, *r) { + Some(val) => { + mut_bitmap.push(true); + val + }, + None => { + mut_bitmap.push(false); + T::default() + }, + }) + .collect::>() + .into(); + + let bitmap: Bitmap = mut_bitmap.into(); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + // The validity has to be checked against the bitmap created during the + // creation of the values with the iterator. If an error was found during + // the iteration, then the validity is changed to None to mark the value + // as Null + let validity = combine_validities(validity.as_ref(), Some(&bitmap)); + + PrimitiveArray::::new(data_type, values, validity) +} diff --git a/crates/nano-arrow/src/compute/arity_assign.rs b/crates/nano-arrow/src/compute/arity_assign.rs new file mode 100644 index 000000000000..e1b358d8aebb --- /dev/null +++ b/crates/nano-arrow/src/compute/arity_assign.rs @@ -0,0 +1,96 @@ +//! Defines generics suitable to perform operations to [`PrimitiveArray`] in-place. + +use either::Either; + +use super::utils::check_same_len; +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +/// Applies an unary function to a [`PrimitiveArray`], optionally in-place. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. +/// +/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn unary(array: &mut PrimitiveArray, op: F) +where + I: NativeType, + F: Fn(I) -> I, +{ + if let Some(values) = array.get_mut_values() { + // mutate in place + values.iter_mut().for_each(|l| *l = op(*l)); + } else { + // alloc and write to new region + let values = array.values().iter().map(|l| op(*l)).collect::>(); + array.set_values(values.into()); + } +} + +/// Applies a binary function to two [`PrimitiveArray`]s, optionally in-place, returning +/// a new [`PrimitiveArray`]. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. 
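+///
+/// A minimal sketch (assuming two equal-length integer arrays `a` and `b`):
+///
+/// ```nocode
+/// // element-wise add `b` into `a`, reusing `a`'s buffer when it is uniquely owned
+/// binary(&mut a, &b, |l, r| l + r);
+/// ```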
+/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn binary(lhs: &mut PrimitiveArray, rhs: &PrimitiveArray, op: F) +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + // both for the validity and for the values + // we branch to check if we can mutate in place + // if we can, great that is fastest. + // if we cannot, we allocate a new buffer and assign values to that + // new buffer, that is benchmarked to be ~2x faster than first memcpy and assign in place + // for the validity bits it can be much faster as we might need to iterate all bits if the + // bitmap has an offset. + if let Some(rhs) = rhs.validity() { + if lhs.validity().is_none() { + lhs.set_validity(Some(rhs.clone())); + } else { + lhs.apply_validity(|bitmap| { + match bitmap.into_mut() { + Either::Left(immutable) => { + // alloc new region + &immutable & rhs + }, + Either::Right(mutable) => { + // mutate in place + (mutable & rhs).into() + }, + } + }); + } + }; + + if let Some(values) = lhs.get_mut_values() { + // mutate values in place + values + .iter_mut() + .zip(rhs.values().iter()) + .for_each(|(l, r)| *l = op(*l, *r)); + } else { + // alloc new region + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>(); + lhs.set_values(values.into()); + } +} diff --git a/crates/nano-arrow/src/compute/bitwise.rs b/crates/nano-arrow/src/compute/bitwise.rs new file mode 100644 index 000000000000..37c26542b848 --- /dev/null +++ b/crates/nano-arrow/src/compute/bitwise.rs @@ -0,0 +1,75 @@ +//! Contains bitwise operators: [`or`], [`and`], [`xor`] and [`not`]. +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use crate::array::PrimitiveArray; +use crate::compute::arity::{binary, unary}; +use crate::types::NativeType; + +/// Performs `OR` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a | b) +} + +/// Performs `XOR` operation between two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn xor(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a ^ b) +} + +/// Performs `AND` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a & b) +} + +/// Returns a new [`PrimitiveArray`] with the bitwise `not`. +pub fn not(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Not, +{ + let op = move |a: T| !a; + unary(array, op, array.data_type().clone()) +} + +/// Performs `OR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + unary(lhs, |a| a | *rhs, lhs.data_type().clone()) +} + +/// Performs `XOR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. 
+pub fn xor_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + unary(lhs, |a| a ^ *rhs, lhs.data_type().clone()) +} + +/// Performs `AND` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + unary(lhs, |a| a & *rhs, lhs.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/boolean.rs b/crates/nano-arrow/src/compute/boolean.rs new file mode 100644 index 000000000000..daf6853c3c29 --- /dev/null +++ b/crates/nano-arrow/src/compute/boolean.rs @@ -0,0 +1,288 @@ +//! null-preserving operators such as [`and`], [`or`] and [`not`]. +use super::utils::combine_validities; +use crate::array::{Array, BooleanArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::scalar::BooleanScalar; + +fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); +} + +/// Helper function to implement binary kernels +pub(crate) fn binary_boolean_kernel( + lhs: &BooleanArray, + rhs: &BooleanArray, + op: F, +) -> BooleanArray +where + F: Fn(&Bitmap, &Bitmap) -> Bitmap, +{ + assert_lengths(lhs, rhs); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + let values = op(left_buffer, right_buffer); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Performs `&&` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. +/// # Examples +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::and; +/// +/// let a = BooleanArray::from(&[Some(false), Some(true), None]); +/// let b = BooleanArray::from(&[Some(true), Some(true), Some(false)]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on both sides + (0, 0) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on left side + (l, _) if l == lhs.len() => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on right side + (_, r) if r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs & rhs) +} + +/// Performs `||` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. 
+/// # Examples +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::or; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on left side + (0, _) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `true` on right side + (_, 0) => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // all values on lhs and rhs are `false` + (l, r) if l == lhs.len() && r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs | rhs) +} + +/// Performs unary `NOT` operation on an arrays. If value is null then the result is also +/// null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::not; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let not_a = not(&a); +/// assert_eq!(not_a, BooleanArray::from(vec![Some(true), Some(false), None])); +/// ``` +pub fn not(array: &BooleanArray) -> BooleanArray { + let values = !array.values(); + let validity = array.validity().cloned(); + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::is_null; +/// # fn main() { +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let a_is_null = is_null(&a); +/// assert_eq!(a_is_null, BooleanArray::from_slice(vec![false, false, true])); +/// # } +/// ``` +pub fn is_null(input: &dyn Array) -> BooleanArray { + let len = input.len(); + + let values = match input.validity() { + None => MutableBitmap::from_len_zeroed(len).into(), + Some(buffer) => !buffer, + }; + + BooleanArray::new(DataType::Boolean, values, None) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is not null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::is_not_null; +/// +/// let a = BooleanArray::from(&vec![Some(false), Some(true), None]); +/// let a_is_not_null = is_not_null(&a); +/// assert_eq!(a_is_not_null, BooleanArray::from_slice(&vec![true, true, false])); +/// ``` +pub fn is_not_null(input: &dyn Array) -> BooleanArray { + let values = match input.validity() { + None => { + let mut mutable = MutableBitmap::new(); + mutable.extend_constant(input.len(), true); + mutable.into() + }, + Some(buffer) => buffer.clone(), + }; + BooleanArray::new(DataType::Boolean, values, None) +} + +/// Performs `AND` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. 
+/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::and_scalar; +/// use arrow2::scalar::BooleanScalar; +/// +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[false, false, true, true])); +/// +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(DataType::Boolean, values, array.validity().cloned()) + }, + None => BooleanArray::new_null(DataType::Boolean, array.len()), + } +} + +/// Performs `OR` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::or_scalar; +/// use arrow2::scalar::BooleanScalar; +/// # fn main() { +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[true, true, true, true])); +/// # } +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => { + let mut values = MutableBitmap::new(); + values.extend_constant(array.len(), true); + BooleanArray::new(DataType::Boolean, values.into(), array.validity().cloned()) + }, + Some(false) => array.clone(), + None => BooleanArray::new_null(DataType::Boolean, array.len()), + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::any; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false)]); +/// let b = BooleanArray::from(&[Some(false), Some(false)]); +/// let c = BooleanArray::from(&[None, Some(false)]); +/// +/// assert_eq!(any(&a), true); +/// assert_eq!(any(&b), false); +/// assert_eq!(any(&c), false); +/// ``` +pub fn any(array: &BooleanArray) -> bool { + if array.is_empty() { + false + } else if array.null_count() > 0 { + array.into_iter().any(|v| v == Some(true)) + } else { + let vals = array.values(); + vals.unset_bits() != vals.len() + } +} + +/// Returns whether all values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::all; +/// +/// let a = BooleanArray::from(&[Some(true), Some(true)]); +/// let b = BooleanArray::from(&[Some(false), Some(true)]); +/// let c = BooleanArray::from(&[None, Some(true)]); +/// +/// assert_eq!(all(&a), true); +/// assert_eq!(all(&b), false); +/// assert_eq!(all(&c), true); +/// ``` +pub fn all(array: &BooleanArray) -> bool { + if array.is_empty() { + true + } else if array.null_count() > 0 { + !array.into_iter().any(|v| v == Some(false)) + } else { + let vals = array.values(); + vals.unset_bits() == 0 + } +} diff --git a/crates/nano-arrow/src/compute/boolean_kleene.rs b/crates/nano-arrow/src/compute/boolean_kleene.rs new file mode 100644 index 000000000000..2983c2e31ded --- /dev/null +++ b/crates/nano-arrow/src/compute/boolean_kleene.rs @@ -0,0 +1,301 @@ +//! 
Boolean operators of [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics). +use crate::array::{Array, BooleanArray}; +use crate::bitmap::{binary, quaternary, ternary, unary, Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::scalar::BooleanScalar; + +/// Logical 'or' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::or; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(&[Some(true), None, None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + (rhs & rhs_v) | + // A = F & B = F + (!lhs & lhs_v) & (!rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + // B != U + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + rhs | + // A = F & B = F + (!lhs & lhs_v) & !rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // A = T + lhs | + // B = T + (rhs & rhs_v) | + // A = F & B = F + !lhs & (!rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(DataType::Boolean, lhs_values | rhs_values, validity) +} + +/// Logical 'and' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::and; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + (!lhs & lhs_v) | 
+ // A = T & B = T + (lhs & lhs_v) & (rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // B = F + !rhs | + // A = F + (!lhs & lhs_v) | + // A = T & B = T + (lhs & lhs_v) & rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + !lhs | + // A = T & B = T + lhs & (rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(DataType::Boolean, lhs_values & rhs_values, validity) +} + +/// Logical 'or' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::scalar::BooleanScalar; +/// use arrow2::compute::boolean_kleene::or_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(Some(false)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[Some(true), Some(false), None])); +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => { + let mut values = MutableBitmap::new(); + values.extend_constant(array.len(), true); + BooleanArray::new(DataType::Boolean, values.into(), None) + }, + Some(false) => array.clone(), + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & value), + None => unary(values, |value| value), + }; + BooleanArray::new(DataType::Boolean, values.clone(), Some(validity)) + }, + } +} + +/// Logical 'and' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::scalar::BooleanScalar; +/// use arrow2::compute::boolean_kleene::and_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(None); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(DataType::Boolean, values, None) + }, + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & !value), + None => unary(values, |value| !value), + }; + BooleanArray::new(DataType::Boolean, array.values().clone(), Some(validity)) + }, + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// The output is unknown (`None`) if the array contains any null values and +/// no `true` values. 
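+///
+/// This differs from the non-Kleene `boolean::any`, which ignores nulls and returns plain
+/// `false` for an array that contains only nulls and `false` values.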
+///
+/// # Example
+///
+/// ```
+/// use arrow2::array::BooleanArray;
+/// use arrow2::compute::boolean_kleene::any;
+///
+/// let a = BooleanArray::from(&[Some(true), Some(false)]);
+/// let b = BooleanArray::from(&[Some(false), Some(false)]);
+/// let c = BooleanArray::from(&[None, Some(false)]);
+///
+/// assert_eq!(any(&a), Some(true));
+/// assert_eq!(any(&b), Some(false));
+/// assert_eq!(any(&c), None);
+/// ```
+pub fn any(array: &BooleanArray) -> Option<bool> {
+    if array.is_empty() {
+        Some(false)
+    } else if array.null_count() > 0 {
+        if array.into_iter().any(|v| v == Some(true)) {
+            Some(true)
+        } else {
+            None
+        }
+    } else {
+        let vals = array.values();
+        Some(vals.unset_bits() != vals.len())
+    }
+}
+
+/// Returns whether all values in the array are `true`.
+///
+/// The output is unknown (`None`) if the array contains any null values and
+/// no `false` values.
+///
+/// # Example
+///
+/// ```
+/// use arrow2::array::BooleanArray;
+/// use arrow2::compute::boolean_kleene::all;
+///
+/// let a = BooleanArray::from(&[Some(true), Some(true)]);
+/// let b = BooleanArray::from(&[Some(false), Some(true)]);
+/// let c = BooleanArray::from(&[None, Some(true)]);
+///
+/// assert_eq!(all(&a), Some(true));
+/// assert_eq!(all(&b), Some(false));
+/// assert_eq!(all(&c), None);
+/// ```
+pub fn all(array: &BooleanArray) -> Option<bool> {
+    if array.is_empty() {
+        Some(true)
+    } else if array.null_count() > 0 {
+        if array.into_iter().any(|v| v == Some(false)) {
+            Some(false)
+        } else {
+            None
+        }
+    } else {
+        let vals = array.values();
+        Some(vals.unset_bits() == 0)
+    }
+}
diff --git a/crates/nano-arrow/src/compute/cast/binary_to.rs b/crates/nano-arrow/src/compute/cast/binary_to.rs
new file mode 100644
index 000000000000..4f5a1fb2b610
--- /dev/null
+++ b/crates/nano-arrow/src/compute/cast/binary_to.rs
@@ -0,0 +1,159 @@
+use super::CastOptions;
+use crate::array::*;
+use crate::datatypes::DataType;
+use crate::error::Result;
+use crate::offset::{Offset, Offsets};
+use crate::types::NativeType;
+
+/// Conversion of binary
+pub fn binary_to_large_binary(from: &BinaryArray<i32>, to_data_type: DataType) -> BinaryArray<i64> {
+    let values = from.values().clone();
+    BinaryArray::<i64>::new(
+        to_data_type,
+        from.offsets().into(),
+        values,
+        from.validity().cloned(),
+    )
+}
+
+/// Conversion of binary
+pub fn binary_large_to_binary(
+    from: &BinaryArray<i64>,
+    to_data_type: DataType,
+) -> Result<BinaryArray<i32>> {
+    let values = from.values().clone();
+    let offsets = from.offsets().try_into()?;
+    Ok(BinaryArray::<i32>::new(
+        to_data_type,
+        offsets,
+        values,
+        from.validity().cloned(),
+    ))
+}
+
+/// Conversion to utf8
+pub fn binary_to_utf8<O: Offset>(
+    from: &BinaryArray<O>,
+    to_data_type: DataType,
+) -> Result<Utf8Array<O>> {
+    Utf8Array::<O>::try_new(
+        to_data_type,
+        from.offsets().clone(),
+        from.values().clone(),
+        from.validity().cloned(),
+    )
+}
+
+/// Conversion to utf8
+/// # Errors
+/// This function errors if the values are not valid utf8
+pub fn binary_to_large_utf8(
+    from: &BinaryArray<i32>,
+    to_data_type: DataType,
+) -> Result<Utf8Array<i64>> {
+    let values = from.values().clone();
+    let offsets = from.offsets().into();
+
+    Utf8Array::<i64>::try_new(to_data_type, offsets, values, from.validity().cloned())
+}
+
+/// Casts a [`BinaryArray`] to a [`PrimitiveArray`] on a best-effort basis using `lexical_core::parse_partial`, making any uncastable value zero.
+pub fn partial_binary_to_primitive( + from: &BinaryArray, + to: &DataType, +) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse_partial(x).ok().map(|x| x.0))); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub fn binary_to_primitive(from: &BinaryArray, to: &DataType) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x).ok())); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn binary_to_primitive_dyn( + from: &dyn Array, + to: &DataType, + options: CastOptions, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + Ok(Box::new(partial_binary_to_primitive::(from, to))) + } else { + Ok(Box::new(binary_to_primitive::(from, to))) + } +} + +/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn binary_to_dictionary( + from: &BinaryArray, +) -> Result> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn binary_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let values = from.as_any().downcast_ref().unwrap(); + binary_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +fn fixed_size_to_offsets(values_len: usize, fixed_size: usize) -> Offsets { + let offsets = (0..(values_len + 1)) + .step_by(fixed_size) + .map(|v| O::from_as_usize(v)) + .collect(); + // Safety + // * every element is `>= 0` + // * element at position `i` is >= than element at position `i-1`. + unsafe { Offsets::new_unchecked(offsets) } +} + +/// Conversion of `FixedSizeBinary` to `Binary`. +pub fn fixed_size_binary_binary( + from: &FixedSizeBinaryArray, + to_data_type: DataType, +) -> BinaryArray { + let values = from.values().clone(); + let offsets = fixed_size_to_offsets(values.len(), from.size()); + BinaryArray::::new( + to_data_type, + offsets.into(), + values, + from.validity().cloned(), + ) +} + +/// Conversion of binary +pub fn binary_to_list(from: &BinaryArray, to_data_type: DataType) -> ListArray { + let values = from.values().clone(); + let values = PrimitiveArray::new(DataType::UInt8, values, None); + ListArray::::new( + to_data_type, + from.offsets().clone(), + values.boxed(), + from.validity().cloned(), + ) +} diff --git a/crates/nano-arrow/src/compute/cast/boolean_to.rs b/crates/nano-arrow/src/compute/cast/boolean_to.rs new file mode 100644 index 000000000000..8a8cf7089d8f --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/boolean_to.rs @@ -0,0 +1,48 @@ +use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::error::Result; +use crate::offset::Offset; +use crate::types::NativeType; + +pub(super) fn boolean_to_primitive_dyn(array: &dyn Array) -> Result> +where + T: NativeType + num_traits::One, +{ + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_primitive::(array))) +} + +/// Casts the [`BooleanArray`] to a [`PrimitiveArray`]. 
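+///
+/// `true` is mapped to `T::one()` and `false` to `T::default()` (zero); the input validity is
+/// preserved.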
+pub fn boolean_to_primitive(from: &BooleanArray) -> PrimitiveArray +where + T: NativeType + num_traits::One, +{ + let values = from + .values() + .iter() + .map(|x| if x { T::one() } else { T::default() }) + .collect::>(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values.into(), from.validity().cloned()) +} + +/// Casts the [`BooleanArray`] to a [`Utf8Array`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_utf8(from: &BooleanArray) -> Utf8Array { + let iter = from.values().iter().map(|x| if x { "1" } else { "0" }); + Utf8Array::from_trusted_len_values_iter(iter) +} + +pub(super) fn boolean_to_utf8_dyn(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_utf8::(array))) +} + +/// Casts the [`BooleanArray`] to a [`BinaryArray`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_binary(from: &BooleanArray) -> BinaryArray { + let iter = from.values().iter().map(|x| if x { b"1" } else { b"0" }); + BinaryArray::from_trusted_len_values_iter(iter) +} + +pub(super) fn boolean_to_binary_dyn(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_binary::(array))) +} diff --git a/crates/nano-arrow/src/compute/cast/decimal_to.rs b/crates/nano-arrow/src/compute/cast/decimal_to.rs new file mode 100644 index 000000000000..ba9995c86c12 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/decimal_to.rs @@ -0,0 +1,137 @@ +use num_traits::{AsPrimitive, Float, NumCast}; + +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +#[inline] +fn decimal_to_decimal_impl Option>( + from: &PrimitiveArray, + op: F, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + op(*x).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +/// Returns a [`PrimitiveArray`] with the casted values. 
Values are `None` on overflow +pub fn decimal_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let (from_precision, from_scale) = + if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + if to_scale == from_scale && to_precision >= from_precision { + // fast path + return from.clone().to(DataType::Decimal(to_precision, to_scale)); + } + // todo: other fast paths include increasing scale and precision by so that + // a number will never overflow (validity is preserved) + + if from_scale > to_scale { + let factor = 10_i128.pow((from_scale - to_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_div(factor), + to_precision, + to_scale, + ) + } else { + let factor = 10_i128.pow((to_scale - from_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_mul(factor), + to_precision, + to_scale, + ) + } +} + +pub(super) fn decimal_to_decimal_dyn( + from: &dyn Array, + to_precision: usize, + to_scale: usize, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_decimal(from, to_precision, to_scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_float(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let div = 10_f64.powi(from_scale as i32); + let values = from + .values() + .iter() + .map(|x| (*x as f64 / div).as_()) + .collect(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values, from.validity().cloned()) +} + +pub(super) fn decimal_to_float_dyn(from: &dyn Array) -> Result> +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_float::(from))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + NumCast, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let factor = 10_i128.pow(from_scale as u32); + let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor))); + + PrimitiveArray::from_trusted_len_iter(values) +} + +pub(super) fn decimal_to_integer_dyn(from: &dyn Array) -> Result> +where + T: NativeType + NumCast, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_integer::(from))) +} diff --git a/crates/nano-arrow/src/compute/cast/dictionary_to.rs b/crates/nano-arrow/src/compute/cast/dictionary_to.rs new file mode 100644 index 000000000000..4126e4a3d589 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/dictionary_to.rs @@ -0,0 +1,183 @@ +use super::{primitive_as_primitive, primitive_to_primitive, CastOptions}; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::compute::cast::cast; +use crate::compute::take::take; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +macro_rules! 
key_cast { + ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty, $to_datatype:expr) => {{ + let cast_keys = primitive_to_primitive::<_, $to_type>($keys, $to_keys_type); + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > $keys.null_count() { + return Err(Error::Overflow); + } + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { + DictionaryArray::try_new_unchecked($to_datatype, cast_keys, $values.clone()) + } + .map(|x| x.boxed()) + }}; +} + +/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] by keeping the +/// keys and casting the values to `values_type`. +/// # Errors +/// This function errors if the values are not castable to `values_type` +pub fn dictionary_to_dictionary_values( + from: &DictionaryArray, + values_type: &DataType, +) -> Result> { + let keys = from.keys(); + let values = from.values(); + let length = values.len(); + + let values = cast(values.as_ref(), values_type, CastOptions::default())?; + + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } +} + +/// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_values( + from: &DictionaryArray, + values_type: &DataType, +) -> Result> { + let keys = from.keys(); + let values = from.values(); + let length = values.len(); + + let values = cast( + values.as_ref(), + values_type, + CastOptions { + wrapped: true, + partial: false, + }, + )?; + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } +} + +/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] backed by a +/// different physical type of the keys, while keeping the values equal. +/// # Errors +/// Errors if any of the old keys' values is larger than the maximum value +/// supported by the new physical type. 
+pub fn dictionary_to_dictionary_keys( + from: &DictionaryArray, +) -> Result> +where + K1: DictionaryKey + num_traits::NumCast, + K2: DictionaryKey + num_traits::NumCast, +{ + let keys = from.keys(); + let values = from.values(); + let is_ordered = from.is_ordered(); + + let casted_keys = primitive_to_primitive::(keys, &K2::PRIMITIVE.into()); + + if casted_keys.null_count() > keys.null_count() { + Err(Error::Overflow) + } else { + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { DictionaryArray::try_new_unchecked(data_type, casted_keys, values.clone()) } + } +} + +/// Similar to dictionary_to_dictionary_keys, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_keys( + from: &DictionaryArray, +) -> Result> +where + K1: DictionaryKey + num_traits::AsPrimitive, + K2: DictionaryKey, +{ + let keys = from.keys(); + let values = from.values(); + let is_ordered = from.is_ordered(); + + let casted_keys = primitive_as_primitive::(keys, &K2::PRIMITIVE.into()); + + if casted_keys.null_count() > keys.null_count() { + Err(Error::Overflow) + } else { + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // some of the values may not fit in `usize` and thus this needs to be checked + DictionaryArray::try_new(data_type, casted_keys, values.clone()) + } +} + +pub(super) fn dictionary_cast_dyn( + array: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let array = array.as_any().downcast_ref::>().unwrap(); + let keys = array.keys(); + let values = array.values(); + + match to_type { + DataType::Dictionary(to_keys_type, to_values_type, _) => { + let values = cast(values.as_ref(), to_values_type, options)?; + + // create the appropriate array type + let to_key_type = (*to_keys_type).into(); + + // Safety: + // we return an error on overflow so the integers remain within bounds + match_integer_type!(to_keys_type, |$T| { + key_cast!(keys, values, array, &to_key_type, $T, to_type.clone()) + }) + }, + _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), + } +} + +// Unpack the dictionary +fn unpack_dictionary( + keys: &PrimitiveArray, + values: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> +where + K: DictionaryKey + num_traits::NumCast, +{ + // attempt to cast the dict values to the target type + // use the take kernel to expand out the dictionary + let values = cast(values, to_type, options)?; + + // take requires first casting i32 + let indices = primitive_to_primitive::<_, i32>(keys, &DataType::Int32); + + take(values.as_ref(), &indices) +} + +/// Casts a [`DictionaryArray`] to its values' [`DataType`], also known as unpacking. +/// The resulting array has the same length. +pub fn dictionary_to_values(from: &DictionaryArray) -> Box +where + K: DictionaryKey + num_traits::NumCast, +{ + // take requires first casting i64 + let indices = primitive_to_primitive::<_, i64>(from.keys(), &DataType::Int64); + + // unwrap: The dictionary guarantees that the keys are not out-of-bounds. 
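+    // e.g. keys [0, 1, 0] over values ["a", "b"] unpack to ["a", "b", "a"]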
+    take(from.values().as_ref(), &indices).unwrap()
+}
diff --git a/crates/nano-arrow/src/compute/cast/mod.rs b/crates/nano-arrow/src/compute/cast/mod.rs
new file mode 100644
index 000000000000..f13a638a9c0d
--- /dev/null
+++ b/crates/nano-arrow/src/compute/cast/mod.rs
@@ -0,0 +1,989 @@
+//! Defines different casting operators such as [`cast`] or [`primitive_to_binary`].
+
+mod binary_to;
+mod boolean_to;
+mod decimal_to;
+mod dictionary_to;
+mod primitive_to;
+mod utf8_to;
+
+pub use binary_to::*;
+pub use boolean_to::*;
+pub use decimal_to::*;
+pub use dictionary_to::*;
+pub use primitive_to::*;
+pub use utf8_to::*;
+
+use crate::array::*;
+use crate::datatypes::*;
+use crate::error::{Error, Result};
+use crate::offset::{Offset, Offsets};
+
+/// Options defining how the cast kernels behave.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct CastOptions {
+    /// Defaults to `false`.
+    /// Whether an overflowing cast should be converted to `None` (default) or be wrapped
+    /// (i.e. `256i16 as u8 = 0`, vectorized). Setting this to `true` is 5-6x faster for
+    /// numeric types.
+    pub wrapped: bool,
+    /// Defaults to `false`.
+    /// Whether to cast to an integer on a best-effort basis.
+    pub partial: bool,
+}
+
+impl CastOptions {
+    fn with_wrapped(&self, v: bool) -> Self {
+        let mut option = *self;
+        option.wrapped = v;
+        option
+    }
+}
+
+/// Returns true if this type is numeric: (UInt*, Int*, or Float*).
+fn is_numeric(t: &DataType) -> bool {
+    use DataType::*;
+    matches!(
+        t,
+        UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64
+    )
+}
+
+macro_rules! primitive_dyn {
+    ($from:expr, $expr:tt) => {{
+        let from = $from.as_any().downcast_ref().unwrap();
+        Ok(Box::new($expr(from)))
+    }};
+    ($from:expr, $expr:tt, $to:expr) => {{
+        let from = $from.as_any().downcast_ref().unwrap();
+        Ok(Box::new($expr(from, $to)))
+    }};
+    ($from:expr, $expr:tt, $from_t:expr, $to:expr) => {{
+        let from = $from.as_any().downcast_ref().unwrap();
+        Ok(Box::new($expr(from, $from_t, $to)))
+    }};
+    ($from:expr, $expr:tt, $arg1:expr, $arg2:expr, $arg3:expr) => {{
+        let from = $from.as_any().downcast_ref().unwrap();
+        Ok(Box::new($expr(from, $arg1, $arg2, $arg3)))
+    }};
+}
+
+/// Returns true if a value of type `from_type` can be cast into a
+/// value of `to_type`. Note that such a cast may be lossy.
+///
+/// If this function returns true, the `cast` kernel below is expected to support the
+/// conversion, so the two are kept consistent.
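+///
+/// # Example
+///
+/// A minimal sketch (the crate paths follow the arrow2-style paths used by the surrounding
+/// doc examples):
+///
+/// ```rust
+/// use arrow2::array::{Float64Array, Int32Array};
+/// use arrow2::compute::cast::{can_cast_types, cast, CastOptions};
+/// use arrow2::datatypes::DataType;
+///
+/// let array = Int32Array::from_slice([1, 2, 3]);
+/// assert!(can_cast_types(array.data_type(), &DataType::Float64));
+///
+/// // `cast` is expected to succeed for every pair for which `can_cast_types` returns `true`.
+/// let casted = cast(&array, &DataType::Float64, CastOptions::default()).unwrap();
+/// let casted = casted.as_any().downcast_ref::<Float64Array>().unwrap();
+/// assert_eq!(casted, &Float64Array::from_slice([1.0, 2.0, 3.0]));
+/// ```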
+pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Null, _) | (_, Null) => true, + (Struct(_), _) => false, + (_, Struct(_)) => false, + (FixedSizeList(list_from, _), List(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (FixedSizeList(list_from, _), LargeList(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (LargeList(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), List(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (LargeList(list_from), LargeList(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), LargeList(list_to)) if list_from == list_to => true, + (LargeList(list_from), List(list_to)) if list_from == list_to => true, + (_, List(list_to)) => can_cast_types(from_type, &list_to.data_type), + (_, LargeList(list_to)) if from_type != &LargeBinary => { + can_cast_types(from_type, &list_to.data_type) + }, + (Dictionary(_, from_value_type, _), Dictionary(_, to_value_type, _)) => { + can_cast_types(from_value_type, to_value_type) + }, + (Dictionary(_, value_type, _), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type, _)) => can_cast_types(from_type, value_type), + + (_, Boolean) => is_numeric(from_type), + (Boolean, _) => { + is_numeric(to_type) + || to_type == &Utf8 + || to_type == &LargeUtf8 + || to_type == &Binary + || to_type == &LargeBinary + }, + + (Utf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + LargeUtf8 | Binary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + }, + (LargeUtf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + Utf8 | LargeBinary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + }, + + (Binary, to_type) => { + is_numeric(to_type) || matches!(to_type, LargeBinary | Utf8 | LargeUtf8) + }, + (LargeBinary, to_type) => { + is_numeric(to_type) + || match to_type { + Binary | LargeUtf8 => true, + LargeList(field) => matches!(field.data_type, UInt8), + _ => false, + } + }, + (FixedSizeBinary(_), to_type) => matches!(to_type, Binary | LargeBinary), + (Timestamp(_, _), Utf8) => true, + (Timestamp(_, _), LargeUtf8) => true, + (_, Utf8) => is_numeric(from_type) || from_type == &Binary, + (_, LargeUtf8) => is_numeric(from_type) || from_type == &LargeBinary, + + (_, Binary) => is_numeric(from_type), + (_, LargeBinary) => is_numeric(from_type), + + // start numeric casts + (UInt8, UInt16) => true, + (UInt8, UInt32) => true, + (UInt8, UInt64) => true, + (UInt8, Int8) => true, + (UInt8, Int16) => true, + (UInt8, Int32) => true, + (UInt8, Int64) => true, + (UInt8, Float32) => true, + (UInt8, Float64) => true, + (UInt8, Decimal(_, _)) => true, + + (UInt16, UInt8) => true, + (UInt16, UInt32) => true, + (UInt16, UInt64) => true, + (UInt16, Int8) => true, + (UInt16, Int16) => true, + (UInt16, Int32) => true, + (UInt16, Int64) => true, + (UInt16, Float32) => true, + (UInt16, Float64) => true, + (UInt16, Decimal(_, _)) => true, + + (UInt32, UInt8) => true, + (UInt32, UInt16) => true, + (UInt32, UInt64) => true, + (UInt32, Int8) => true, + (UInt32, Int16) => true, + (UInt32, Int32) => true, + (UInt32, Int64) => true, + (UInt32, Float32) => true, + 
(UInt32, Float64) => true, + (UInt32, Decimal(_, _)) => true, + + (UInt64, UInt8) => true, + (UInt64, UInt16) => true, + (UInt64, UInt32) => true, + (UInt64, Int8) => true, + (UInt64, Int16) => true, + (UInt64, Int32) => true, + (UInt64, Int64) => true, + (UInt64, Float32) => true, + (UInt64, Float64) => true, + (UInt64, Decimal(_, _)) => true, + + (Int8, UInt8) => true, + (Int8, UInt16) => true, + (Int8, UInt32) => true, + (Int8, UInt64) => true, + (Int8, Int16) => true, + (Int8, Int32) => true, + (Int8, Int64) => true, + (Int8, Float32) => true, + (Int8, Float64) => true, + (Int8, Decimal(_, _)) => true, + + (Int16, UInt8) => true, + (Int16, UInt16) => true, + (Int16, UInt32) => true, + (Int16, UInt64) => true, + (Int16, Int8) => true, + (Int16, Int32) => true, + (Int16, Int64) => true, + (Int16, Float32) => true, + (Int16, Float64) => true, + (Int16, Decimal(_, _)) => true, + + (Int32, UInt8) => true, + (Int32, UInt16) => true, + (Int32, UInt32) => true, + (Int32, UInt64) => true, + (Int32, Int8) => true, + (Int32, Int16) => true, + (Int32, Int64) => true, + (Int32, Float32) => true, + (Int32, Float64) => true, + (Int32, Decimal(_, _)) => true, + + (Int64, UInt8) => true, + (Int64, UInt16) => true, + (Int64, UInt32) => true, + (Int64, UInt64) => true, + (Int64, Int8) => true, + (Int64, Int16) => true, + (Int64, Int32) => true, + (Int64, Float32) => true, + (Int64, Float64) => true, + (Int64, Decimal(_, _)) => true, + + (Float16, Float32) => true, + + (Float32, UInt8) => true, + (Float32, UInt16) => true, + (Float32, UInt32) => true, + (Float32, UInt64) => true, + (Float32, Int8) => true, + (Float32, Int16) => true, + (Float32, Int32) => true, + (Float32, Int64) => true, + (Float32, Float64) => true, + (Float32, Decimal(_, _)) => true, + + (Float64, UInt8) => true, + (Float64, UInt16) => true, + (Float64, UInt32) => true, + (Float64, UInt64) => true, + (Float64, Int8) => true, + (Float64, Int16) => true, + (Float64, Int32) => true, + (Float64, Int64) => true, + (Float64, Float32) => true, + (Float64, Decimal(_, _)) => true, + + ( + Decimal(_, _), + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _), + ) => true, + // end numeric casts + + // temporal casts + (Int32, Date32) => true, + (Int32, Time32(_)) => true, + (Date32, Int32) => true, + (Date32, Int64) => true, + (Time32(_), Int32) => true, + (Int64, Date64) => true, + (Int64, Time64(_)) => true, + (Date64, Int32) => true, + (Date64, Int64) => true, + (Time64(_), Int64) => true, + (Date32, Date64) => true, + (Date64, Date32) => true, + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true, + (Time32(_), Time64(_)) => true, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true, + (Time64(_), Time32(to_unit)) => { + matches!(to_unit, TimeUnit::Second | TimeUnit::Millisecond) + }, + (Timestamp(_, _), Int64) => true, + (Int64, Timestamp(_, _)) => true, + (Timestamp(_, _), Timestamp(_, _)) => true, + (Timestamp(_, _), Date32) => true, + (Timestamp(_, _), Date64) => true, + (Int64, Duration(_)) => true, + (Duration(_), Int64) => true, + (Interval(_), Interval(IntervalUnit::MonthDayNano)) => true, + (_, _) => false, + } +} + +fn cast_list( + array: &ListArray, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let values = array.values(); + let new_values = cast( + values.as_ref(), + 
ListArray::::get_child_type(to_type), + options, + )?; + + Ok(ListArray::::new( + to_type.clone(), + array.offsets().clone(), + new_values, + array.validity().cloned(), + )) +} + +fn cast_list_to_large_list(array: &ListArray, to_type: &DataType) -> ListArray { + let offsets = array.offsets().into(); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_large_to_list(array: &ListArray, to_type: &DataType) -> ListArray { + let offsets = array.offsets().try_into().expect("Convertme to error"); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_fixed_size_list_to_list( + fixed: &FixedSizeListArray, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let new_values = cast( + fixed.values().as_ref(), + ListArray::::get_child_type(to_type), + options, + )?; + + let offsets = (0..=fixed.len()) + .map(|ix| O::from_as_usize(ix * fixed.size())) + .collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + Ok(ListArray::::new( + to_type.clone(), + offsets.into(), + new_values, + fixed.validity().cloned(), + )) +} + +fn cast_list_to_fixed_size_list( + list: &ListArray, + inner: &Field, + size: usize, + options: CastOptions, +) -> Result { + let offsets = list.offsets().buffer().iter(); + let expected = (0..list.len()).map(|ix| O::from_as_usize(ix * size)); + + match offsets + .zip(expected) + .find(|(actual, expected)| *actual != expected) + { + Some(_) => Err(Error::InvalidArgumentError( + "incompatible offsets in source list".to_string(), + )), + None => { + let sliced_values = list.values().sliced( + list.offsets().first().to_usize(), + list.offsets().range().to_usize(), + ); + let new_values = cast(sliced_values.as_ref(), inner.data_type(), options)?; + Ok(FixedSizeListArray::new( + DataType::FixedSizeList(Box::new(inner.clone()), size), + new_values, + list.validity().cloned(), + )) + }, + } +} + +/// Cast `array` to the provided data type and return a new [`Array`] with +/// type `to_type`, if possible. +/// +/// Behavior: +/// * PrimitiveArray to PrimitiveArray: overflowing cast will be None +/// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to boolean: 0 returns `false`, any other value returns `true` +/// * List to List: the underlying data type is cast +/// * Fixed Size List to List: the underlying data type is cast +/// * List to Fixed Size List: the offsets are checked for valid order, then the +/// underlying type is cast. 
+/// * PrimitiveArray to List: a list array with 1 value per slot is created +/// * Date32 and Date64: precision lost when going to higher interval +/// * Time32 and Time64: precision lost when going to higher interval +/// * Timestamp and Date{32|64}: precision lost when going to higher interval +/// * Temporal to/from backing primitive: zero-copy with data type change +/// Unsupported Casts +/// * To or from `StructArray` +/// * List to primitive +/// * Utf8 to boolean +/// * Interval and duration +pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Result> { + use DataType::*; + let from_type = array.data_type(); + + // clone array if types are the same + if from_type == to_type { + return Ok(clone(array)); + } + + let as_options = options.with_wrapped(true); + match (from_type, to_type) { + (Null, _) | (_, Null) => Ok(new_null_array(to_type.clone(), array.len())), + (Struct(_), _) => Err(Error::NotYetImplemented( + "Cannot cast from struct to other types".to_string(), + )), + (_, Struct(_)) => Err(Error::NotYetImplemented( + "Cannot cast to struct from other types".to_string(), + )), + (List(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + array.as_any().downcast_ref().unwrap(), + inner.as_ref(), + *size, + options, + ) + .map(|x| x.boxed()), + (LargeList(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + array.as_any().downcast_ref().unwrap(), + inner.as_ref(), + *size, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), List(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), LargeList(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (List(_), List(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| x.boxed()) + }, + (LargeList(_), LargeList(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| x.boxed()) + }, + (List(lhs), LargeList(rhs)) if lhs == rhs => { + Ok(cast_list_to_large_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + (LargeList(lhs), List(rhs)) if lhs == rhs => { + Ok(cast_large_to_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + + (_, List(to)) => { + // cast primitive to list's primitive + let values = cast(array, &to.data_type, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i32).collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); + + Ok(Box::new(list_array)) + }, + + (_, LargeList(to)) if from_type != &LargeBinary => { + // cast primitive to list's primitive + let values = cast(array, &to.data_type, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i64).collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); + + Ok(Box::new(list_array)) + }, + + (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { + dictionary_cast_dyn::<$T>(array, to_type, options) + }), + (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { + 
cast_to_dictionary::<$T>(array, value_type, options) + }), + (_, Boolean) => match from_type { + UInt8 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt16 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt32 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int8 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int16 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float64 => primitive_to_boolean_dyn::(array, to_type.clone()), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Boolean, _) => match to_type { + UInt8 => boolean_to_primitive_dyn::(array), + UInt16 => boolean_to_primitive_dyn::(array), + UInt32 => boolean_to_primitive_dyn::(array), + UInt64 => boolean_to_primitive_dyn::(array), + Int8 => boolean_to_primitive_dyn::(array), + Int16 => boolean_to_primitive_dyn::(array), + Int32 => boolean_to_primitive_dyn::(array), + Int64 => boolean_to_primitive_dyn::(array), + Float32 => boolean_to_primitive_dyn::(array), + Float64 => boolean_to_primitive_dyn::(array), + LargeUtf8 => boolean_to_utf8_dyn::(array), + LargeBinary => boolean_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (Utf8, _) => match to_type { + UInt8 => utf8_to_primitive_dyn::(array, to_type, options), + UInt16 => utf8_to_primitive_dyn::(array, to_type, options), + UInt32 => utf8_to_primitive_dyn::(array, to_type, options), + UInt64 => utf8_to_primitive_dyn::(array, to_type, options), + Int8 => utf8_to_primitive_dyn::(array, to_type, options), + Int16 => utf8_to_primitive_dyn::(array, to_type, options), + Int32 => utf8_to_primitive_dyn::(array, to_type, options), + Int64 => utf8_to_primitive_dyn::(array, to_type, options), + Float32 => utf8_to_primitive_dyn::(array, to_type, options), + Float64 => utf8_to_primitive_dyn::(array, to_type, options), + Date32 => utf8_to_date32_dyn::(array), + Date64 => utf8_to_date64_dyn::(array), + LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( + array.as_any().downcast_ref().unwrap(), + ))), + Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, Some(tz)) => { + utf8_to_timestamp_ns_dyn::(array, tz.clone()) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeUtf8, _) => match to_type { + UInt8 => utf8_to_primitive_dyn::(array, to_type, options), + UInt16 => utf8_to_primitive_dyn::(array, to_type, options), + UInt32 => utf8_to_primitive_dyn::(array, to_type, options), + UInt64 => utf8_to_primitive_dyn::(array, to_type, options), + Int8 => utf8_to_primitive_dyn::(array, to_type, options), + Int16 => utf8_to_primitive_dyn::(array, to_type, options), + Int32 => utf8_to_primitive_dyn::(array, to_type, options), + Int64 => utf8_to_primitive_dyn::(array, to_type, options), + Float32 => utf8_to_primitive_dyn::(array, to_type, options), + Float64 => utf8_to_primitive_dyn::(array, to_type, options), + Date32 => utf8_to_date32_dyn::(array), + Date64 => utf8_to_date64_dyn::(array), + Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap()).map(|x| x.boxed()), + LargeBinary => Ok(utf8_to_binary::( + 
array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, Some(tz)) => { + utf8_to_timestamp_ns_dyn::(array, tz.clone()) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, Utf8) => match from_type { + UInt8 => primitive_to_utf8_dyn::(array), + UInt16 => primitive_to_utf8_dyn::(array), + UInt32 => primitive_to_utf8_dyn::(array), + UInt64 => primitive_to_utf8_dyn::(array), + Int8 => primitive_to_utf8_dyn::(array), + Int16 => primitive_to_utf8_dyn::(array), + Int32 => primitive_to_utf8_dyn::(array), + Int64 => primitive_to_utf8_dyn::(array), + Float32 => primitive_to_utf8_dyn::(array), + Float64 => primitive_to_utf8_dyn::(array), + Timestamp(from_unit, Some(tz)) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(timestamp_to_utf8::(from, *from_unit, tz)?)) + }, + Timestamp(from_unit, None) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(naive_timestamp_to_utf8::(from, *from_unit))) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, LargeUtf8) => match from_type { + UInt8 => primitive_to_utf8_dyn::(array), + UInt16 => primitive_to_utf8_dyn::(array), + UInt32 => primitive_to_utf8_dyn::(array), + UInt64 => primitive_to_utf8_dyn::(array), + Int8 => primitive_to_utf8_dyn::(array), + Int16 => primitive_to_utf8_dyn::(array), + Int32 => primitive_to_utf8_dyn::(array), + Int64 => primitive_to_utf8_dyn::(array), + Float32 => primitive_to_utf8_dyn::(array), + Float64 => primitive_to_utf8_dyn::(array), + LargeBinary => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + Timestamp(from_unit, Some(tz)) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(timestamp_to_utf8::(from, *from_unit, tz)?)) + }, + Timestamp(from_unit, None) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(naive_timestamp_to_utf8::(from, *from_unit))) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (Binary, _) => match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type, options), + UInt16 => binary_to_primitive_dyn::(array, to_type, options), + UInt32 => binary_to_primitive_dyn::(array, to_type, options), + UInt64 => binary_to_primitive_dyn::(array, to_type, options), + Int8 => binary_to_primitive_dyn::(array, to_type, options), + Int16 => binary_to_primitive_dyn::(array, to_type, options), + Int32 => binary_to_primitive_dyn::(array, to_type, options), + Int64 => binary_to_primitive_dyn::(array, to_type, options), + Float32 => binary_to_primitive_dyn::(array, to_type, options), + Float64 => binary_to_primitive_dyn::(array, to_type, options), + LargeBinary => Ok(Box::new(binary_to_large_binary( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ))), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (LargeBinary, _) => { + match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type, options), + UInt16 => binary_to_primitive_dyn::(array, to_type, options), + UInt32 => binary_to_primitive_dyn::(array, to_type, options), + UInt64 => binary_to_primitive_dyn::(array, to_type, options), + Int8 => binary_to_primitive_dyn::(array, 
to_type, options), + Int16 => binary_to_primitive_dyn::(array, to_type, options), + Int32 => binary_to_primitive_dyn::(array, to_type, options), + Int64 => binary_to_primitive_dyn::(array, to_type, options), + Float32 => binary_to_primitive_dyn::(array, to_type, options), + Float64 => binary_to_primitive_dyn::(array, to_type, options), + Binary => { + binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + LargeUtf8 => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + LargeList(inner) if matches!(inner.data_type, DataType::UInt8) => Ok( + binary_to_list::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .boxed(), + ), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } + }, + (FixedSizeBinary(_), _) => match to_type { + Binary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + LargeBinary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, Binary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, LargeBinary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + // start numeric casts + (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Int8) => 
primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, Int16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Decimal(p, s)) => 
integer_to_decimal_dyn::(array, *p, *s), + + (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Float16, Float32) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(f16_to_f32(from).boxed()) + }, + + (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Float32, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Decimal(_, _), UInt8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int16) => 
decimal_to_integer_dyn::(array), + (Decimal(_, _), Int32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Float32) => decimal_to_float_dyn::(array), + (Decimal(_, _), Float64) => decimal_to_float_dyn::(array), + (Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s), + // end numeric casts + + // temporal casts + (Int32, Date32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int32, Time32(TimeUnit::Second)) => primitive_to_same_primitive_dyn::(array, to_type), + (Int32, Time32(TimeUnit::Millisecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + // No support for microsecond/nanosecond with i32 + (Date32, Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Time32(_), Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Date64) => primitive_to_same_primitive_dyn::(array, to_type), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + (Int64, Time64(TimeUnit::Nanosecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + + (Date64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Date64, Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Time64(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Date64) => primitive_dyn!(array, date32_to_date64), + (Date64, Date32) => primitive_dyn!(array, date64_to_date32), + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => { + primitive_dyn!(array, time32s_to_time32ms) + }, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => { + primitive_dyn!(array, time32ms_to_time32s) + }, + (Time32(from_unit), Time64(to_unit)) => { + primitive_dyn!(array, time32_to_time64, *from_unit, *to_unit) + }, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => { + primitive_dyn!(array, time64us_to_time64ns) + }, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => { + primitive_dyn!(array, time64ns_to_time64us) + }, + (Time64(from_unit), Time32(to_unit)) => { + primitive_dyn!(array, time64_to_time32, *from_unit, *to_unit) + }, + (Timestamp(_, _), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Timestamp(_, _)) => primitive_to_same_primitive_dyn::(array, to_type), + (Timestamp(from_unit, _), Timestamp(to_unit, tz)) => { + primitive_dyn!(array, timestamp_to_timestamp, *from_unit, *to_unit, tz) + }, + (Timestamp(from_unit, _), Date32) => primitive_dyn!(array, timestamp_to_date32, *from_unit), + (Timestamp(from_unit, _), Date64) => primitive_dyn!(array, timestamp_to_date64, *from_unit), + + (Int64, Duration(_)) => primitive_to_same_primitive_dyn::(array, to_type), + (Duration(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::MonthDayNano)) => { + primitive_dyn!(array, days_ms_to_months_days_ns) + }, + (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { + primitive_dyn!(array, months_to_months_days_ns) + }, + + (_, _) => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &dyn 
Array, + dict_value_type: &DataType, + options: CastOptions, +) -> Result> { + let array = cast(array, dict_value_type, options)?; + let array = array.as_ref(); + match *dict_value_type { + DataType::Int8 => primitive_to_dictionary_dyn::(array), + DataType::Int16 => primitive_to_dictionary_dyn::(array), + DataType::Int32 => primitive_to_dictionary_dyn::(array), + DataType::Int64 => primitive_to_dictionary_dyn::(array), + DataType::UInt8 => primitive_to_dictionary_dyn::(array), + DataType::UInt16 => primitive_to_dictionary_dyn::(array), + DataType::UInt32 => primitive_to_dictionary_dyn::(array), + DataType::UInt64 => primitive_to_dictionary_dyn::(array), + DataType::Utf8 => utf8_to_dictionary_dyn::(array), + DataType::LargeUtf8 => utf8_to_dictionary_dyn::(array), + DataType::Binary => binary_to_dictionary_dyn::(array), + DataType::LargeBinary => binary_to_dictionary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Unsupported output type for dictionary packing: {dict_value_type:?}" + ))), + } +} diff --git a/crates/nano-arrow/src/compute/cast/primitive_to.rs b/crates/nano-arrow/src/compute/cast/primitive_to.rs new file mode 100644 index 000000000000..661e40bc6343 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/primitive_to.rs @@ -0,0 +1,584 @@ +use std::hash::Hash; + +use num_traits::{AsPrimitive, Float, ToPrimitive}; + +use super::CastOptions; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::compute::arity::unary; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::error::Result; +use crate::offset::{Offset, Offsets}; +use crate::temporal_conversions::*; +use crate::types::{days_ms, f16, months_days_ns, NativeType}; + +/// Returns a [`BinaryArray`] where every element is the binary representation of the number. +pub fn primitive_to_binary( + from: &PrimitiveArray, +) -> BinaryArray { + let mut values: Vec = Vec::with_capacity(from.len()); + let mut offsets: Vec = Vec::with_capacity(from.len() + 1); + offsets.push(O::default()); + + let mut offset: usize = 0; + + unsafe { + for x in from.values().iter() { + values.reserve(offset + T::FORMATTED_SIZE_DECIMAL); + + let bytes = std::slice::from_raw_parts_mut( + values.as_mut_ptr().add(offset), + values.capacity() - offset, + ); + let len = lexical_core::write_unchecked(*x, bytes).len(); + + offset += len; + offsets.push(O::from_as_usize(offset)); + } + values.set_len(offset); + values.shrink_to_fit(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + BinaryArray::::new( + BinaryArray::::default_data_type(), + offsets.into(), + values.into(), + from.validity().cloned(), + ) + } +} + +pub(super) fn primitive_to_binary_dyn(from: &dyn Array) -> Result> +where + O: Offset, + T: NativeType + lexical_core::ToLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_binary::(from))) +} + +/// Returns a [`BooleanArray`] where every element is different from zero. +/// Validity is preserved. 
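The `primitive_to_binary` and `primitive_to_utf8` kernels above materialise one contiguous `values` buffer plus a monotonically increasing `offsets` vector (the boolean cast follows below). A minimal std-only sketch of that layout, using `format!` in place of `lexical_core::write_unchecked`; the helper name is hypothetical and not part of the crate:

```rust
// Hypothetical standalone helper (not crate API): build the contiguous
// `values` buffer and the offsets vector the number-to-text kernels produce.
fn numbers_to_values_and_offsets(input: &[i64]) -> (Vec<u8>, Vec<i64>) {
    let mut values = Vec::new();
    let mut offsets = Vec::with_capacity(input.len() + 1);
    offsets.push(0_i64); // offsets always start at 0

    for x in input {
        values.extend_from_slice(x.to_string().as_bytes());
        offsets.push(values.len() as i64); // end offset of the current element
    }
    (values, offsets)
}

fn main() {
    let (values, offsets) = numbers_to_values_and_offsets(&[1, -20, 300]);
    assert_eq!(values, b"1-20300".to_vec());
    assert_eq!(offsets, vec![0, 1, 4, 7]);
    // Element i is values[offsets[i] as usize..offsets[i + 1] as usize].
}
```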
+pub fn primitive_to_boolean( + from: &PrimitiveArray, + to_type: DataType, +) -> BooleanArray { + let iter = from.values().iter().map(|v| *v != T::default()); + let values = Bitmap::from_trusted_len_iter(iter); + + BooleanArray::new(to_type, values, from.validity().cloned()) +} + +pub(super) fn primitive_to_boolean_dyn( + from: &dyn Array, + to_type: DataType, +) -> Result> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_boolean::(from, to_type))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number. +pub fn primitive_to_utf8( + from: &PrimitiveArray, +) -> Utf8Array { + let mut values: Vec = Vec::with_capacity(from.len()); + let mut offsets: Vec = Vec::with_capacity(from.len() + 1); + offsets.push(O::default()); + + let mut offset: usize = 0; + + unsafe { + for x in from.values().iter() { + values.reserve(offset + T::FORMATTED_SIZE_DECIMAL); + + let bytes = std::slice::from_raw_parts_mut( + values.as_mut_ptr().add(offset), + values.capacity() - offset, + ); + let len = lexical_core::write_unchecked(*x, bytes).len(); + + offset += len; + offsets.push(O::from_as_usize(offset)); + } + values.set_len(offset); + values.shrink_to_fit(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + Utf8Array::::new_unchecked( + Utf8Array::::default_data_type(), + offsets.into(), + values.into(), + from.validity().cloned(), + ) + } +} + +pub(super) fn primitive_to_utf8_dyn(from: &dyn Array) -> Result> +where + O: Offset, + T: NativeType + lexical_core::ToLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_utf8::(from))) +} + +pub(super) fn primitive_to_primitive_dyn( + from: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> +where + I: NativeType + num_traits::NumCast + num_traits::AsPrimitive, + O: NativeType + num_traits::NumCast, +{ + let from = from.as_any().downcast_ref::>().unwrap(); + if options.wrapped { + Ok(Box::new(primitive_as_primitive::(from, to_type))) + } else { + Ok(Box::new(primitive_to_primitive::(from, to_type))) + } +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion. +pub fn primitive_to_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::NumCast, + O: NativeType + num_traits::NumCast, +{ + let iter = from + .iter() + .map(|v| v.and_then(|x| num_traits::cast::cast::(*x))); + PrimitiveArray::::from_trusted_len_iter(iter).to(to_type.clone()) +} + +/// Returns a [`PrimitiveArray`] with the casted values. 
Values are `None` on overflow +pub fn integer_to_decimal>( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let multiplier = 10_i128.pow(to_scale as u32); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + x.as_().checked_mul(multiplier).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn integer_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(integer_to_decimal::(from, precision, scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn float_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + // 1.2 => 12 + let multiplier: T = (10_f64).powi(to_scale as i32).as_(); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + let x = (*x * multiplier).to_i128().unwrap(); + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn float_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(float_to_decimal::(from, precision, scale))) +} + +/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] +/// Same as `number as to_number_type` in rust +pub fn primitive_as_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::AsPrimitive, + O: NativeType, +{ + unary(from, num_traits::AsPrimitive::::as_, to_type.clone()) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub fn primitive_to_same_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + T: NativeType, +{ + PrimitiveArray::::new( + to_type.clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub(super) fn primitive_to_same_primitive_dyn( + from: &dyn Array, + to_type: &DataType, +) -> Result> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_same_primitive::(from, to_type))) +} + +pub(super) fn primitive_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + primitive_to_dictionary::(from).map(|x| Box::new(x) as Box) +} + +/// Cast [`PrimitiveArray`] to [`DictionaryArray`]. Also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. 
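Before the dictionary-packing helper that follows, here is a standalone sketch of the integer-to-decimal logic defined above: scale by `10^scale` with a checked multiply and null out anything that no longer fits the target precision. The bound used here is the conventional `10^precision - 1`; the kernel above derives its bound with saturating powers, and the function name is illustrative only:

```rust
// Hypothetical standalone helper (not crate API): integer -> decimal with
// plain vectors. Anything that overflows i128 or the target precision
// becomes a null.
fn int_to_decimal(values: &[Option<i64>], precision: u32, scale: u32) -> Vec<Option<i128>> {
    let multiplier = 10_i128.pow(scale);
    let max_for_precision = 10_i128.pow(precision) - 1;
    let min_for_precision = -max_for_precision;

    values
        .iter()
        .map(|&v| {
            v.and_then(|x| (x as i128).checked_mul(multiplier)) // overflow -> null
                .filter(|x| (min_for_precision..=max_for_precision).contains(x))
        })
        .collect()
}

fn main() {
    // Decimal(5, 2): 12 scales to 1200 (12.00) and fits; 1234 scales to
    // 123400, which exceeds 99_999 and becomes null.
    assert_eq!(
        int_to_decimal(&[Some(12), Some(1234), None], 5, 2),
        vec![Some(1200), None, None]
    );
}
```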
+pub fn primitive_to_dictionary( + from: &PrimitiveArray, +) -> Result> { + let iter = from.iter().map(|x| x.copied()); + let mut array = MutableDictionaryArray::::try_empty(MutablePrimitiveArray::::from( + from.data_type().clone(), + ))?; + array.try_extend(iter)?; + + Ok(array.into()) +} + +/// Get the time unit as a multiple of a second +const fn time_unit_multiple(unit: TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +/// Conversion of dates +pub fn date32_to_date64(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x as i64 * MILLISECONDS_IN_DAY, DataType::Date64) +} + +/// Conversion of dates +pub fn date64_to_date32(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| (x / MILLISECONDS_IN_DAY) as i32, DataType::Date32) +} + +/// Conversion of times +pub fn time32s_to_time32ms(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x * 1000, DataType::Time32(TimeUnit::Millisecond)) +} + +/// Conversion of times +pub fn time32ms_to_time32s(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x / 1000, DataType::Time32(TimeUnit::Second)) +} + +/// Conversion of times +pub fn time64us_to_time64ns(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x * 1000, DataType::Time64(TimeUnit::Nanosecond)) +} + +/// Conversion of times +pub fn time64ns_to_time64us(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x / 1000, DataType::Time64(TimeUnit::Microsecond)) +} + +/// Conversion of timestamp +pub fn timestamp_to_date64(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = MILLISECONDS; + let to_type = DataType::Date64; + + // Scale time_array by (to_size / from_size) using a + // single integer operation, but need to avoid integer + // math rounding down to zero + + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => unary(from, |x| (x / (from_size / to_size)), to_type), + std::cmp::Ordering::Equal => primitive_to_same_primitive(from, &to_type), + std::cmp::Ordering::Greater => unary(from, |x| (x * (to_size / from_size)), to_type), + } +} + +/// Conversion of timestamp +pub fn timestamp_to_date32(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; + unary(from, |x| (x / from_size) as i32, DataType::Date32) +} + +/// Conversion of time +pub fn time32_to_time64( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = to_size / from_size; + unary(from, |x| (x as i64 * divisor), DataType::Time64(to_unit)) +} + +/// Conversion of time +pub fn time64_to_time32( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = from_size / to_size; + unary(from, |x| (x / divisor) as i32, DataType::Time32(to_unit)) +} + +/// Conversion of timestamp +pub fn timestamp_to_timestamp( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, + tz: &Option, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let to_type = DataType::Timestamp(to_unit, tz.clone()); + // we either divide or multiply, 
depending on size of each unit + if from_size >= to_size { + unary(from, |x| (x / (from_size / to_size)), to_type) + } else { + unary(from, |x| (x * (to_size / from_size)), to_type) + } +} + +fn timestamp_to_utf8_impl( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone: T, +) -> Utf8Array +where + T::Offset: std::fmt::Display, +{ + match time_unit { + TimeUnit::Nanosecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_ns_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Microsecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_us_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Millisecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_ms_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Second => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_s_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + } +} + +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +fn chrono_tz_timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, +) -> Result> { + let timezone = parse_offset_tz(timezone_str)?; + Ok(timestamp_to_utf8_impl::( + from, time_unit, timezone, + )) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz_timestamp_to_utf8( + _: &PrimitiveArray, + _: TimeUnit, + timezone_str: &str, +) -> Result> { + use crate::error::Error; + Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. +pub fn timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, +) -> Result> { + let timezone = parse_offset(timezone_str); + + if let Ok(timezone) = timezone { + Ok(timestamp_to_utf8_impl::( + from, time_unit, timezone, + )) + } else { + chrono_tz_timestamp_to_utf8(from, time_unit, timezone_str) + } +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. 
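The timestamp, date, and time kernels above all rescale by expressing both units as a multiple of one second and then either multiplying or dividing by the ratio, so the integer scale factor is never zero. A self-contained sketch of that rule (the `Unit` and `rescale` names are made up for the example):

```rust
// Hypothetical standalone sketch (not crate API) of the unit-rescaling rule
// used by `timestamp_to_timestamp`, `timestamp_to_date64`, and the time casts.
enum Unit {
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
}

fn multiple(unit: Unit) -> i64 {
    match unit {
        Unit::Second => 1,
        Unit::Millisecond => 1_000,
        Unit::Microsecond => 1_000_000,
        Unit::Nanosecond => 1_000_000_000,
    }
}

fn rescale(value: i64, from: Unit, to: Unit) -> i64 {
    let (from_size, to_size) = (multiple(from), multiple(to));
    if from_size >= to_size {
        value / (from_size / to_size) // to a coarser unit: divide (truncating)
    } else {
        value * (to_size / from_size) // to a finer unit: multiply
    }
}

fn main() {
    assert_eq!(rescale(1_500, Unit::Millisecond, Unit::Second), 1);
    assert_eq!(rescale(2, Unit::Second, Unit::Nanosecond), 2_000_000_000);
    assert_eq!(rescale(3_000_000, Unit::Microsecond, Unit::Millisecond), 3_000);
}
```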
+pub fn naive_timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, +) -> Utf8Array { + match time_unit { + TimeUnit::Nanosecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_ns_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Microsecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_us_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Millisecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_ms_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Second => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_s_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + } +} + +#[inline] +fn days_ms_to_months_days_ns_scalar(from: days_ms) -> months_days_ns { + months_days_ns::new(0, from.days(), from.milliseconds() as i64 * 1000) +} + +/// Casts [`days_ms`]s to [`months_days_ns`]. This operation is infalible and lossless. +pub fn days_ms_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + days_ms_to_months_days_ns_scalar, + DataType::Interval(IntervalUnit::MonthDayNano), + ) +} + +#[inline] +fn months_to_months_days_ns_scalar(from: i32) -> months_days_ns { + months_days_ns::new(from, 0, 0) +} + +/// Casts months represented as [`i32`]s to [`months_days_ns`]. This operation is infalible and lossless. +pub fn months_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + months_to_months_days_ns_scalar, + DataType::Interval(IntervalUnit::MonthDayNano), + ) +} + +/// Casts f16 into f32 +pub fn f16_to_f32(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x.to_f32(), DataType::Float32) +} diff --git a/crates/nano-arrow/src/compute/cast/utf8_to.rs b/crates/nano-arrow/src/compute/cast/utf8_to.rs new file mode 100644 index 000000000000..9c86ff85da54 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/utf8_to.rs @@ -0,0 +1,176 @@ +use chrono::Datelike; + +use super::CastOptions; +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::offset::Offset; +use crate::temporal_conversions::{ + utf8_to_naive_timestamp_ns as utf8_to_naive_timestamp_ns_, + utf8_to_timestamp_ns as utf8_to_timestamp_ns_, EPOCH_DAYS_FROM_CE, +}; +use crate::types::NativeType; + +const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; + +/// Casts a [`Utf8Array`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub fn utf8_to_primitive(from: &Utf8Array, to: &DataType) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x.as_bytes()).ok())); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +/// Casts a [`Utf8Array`] to a [`PrimitiveArray`] at best-effort using `lexical_core::parse_partial`, making any uncastable value as zero. 
+pub fn partial_utf8_to_primitive( + from: &Utf8Array, + to: &DataType, +) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from.iter().map(|x| { + x.and_then::(|x| lexical_core::parse_partial(x.as_bytes()).ok().map(|x| x.0)) + }); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn utf8_to_primitive_dyn( + from: &dyn Array, + to: &DataType, + options: CastOptions, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + Ok(Box::new(partial_utf8_to_primitive::(from, to))) + } else { + Ok(Box::new(utf8_to_primitive::(from, to))) + } +} + +/// Casts a [`Utf8Array`] to a Date32 primitive, making any uncastable value a Null. +pub fn utf8_to_date32(from: &Utf8Array) -> PrimitiveArray { + let iter = from.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + }) + }); + PrimitiveArray::::from_trusted_len_iter(iter).to(DataType::Date32) +} + +pub(super) fn utf8_to_date32_dyn(from: &dyn Array) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_date32::(from))) +} + +/// Casts a [`Utf8Array`] to a Date64 primitive, making any uncastable value a Null. +pub fn utf8_to_date64(from: &Utf8Array) -> PrimitiveArray { + let iter = from.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| (x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) as i64 * 86400000) + }) + }); + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Date64) +} + +pub(super) fn utf8_to_date64_dyn(from: &dyn Array) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_date64::(from))) +} + +pub(super) fn utf8_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let values = from.as_any().downcast_ref().unwrap(); + utf8_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. 
+pub fn utf8_to_dictionary( + from: &Utf8Array, +) -> Result> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn utf8_to_naive_timestamp_ns_dyn( + from: &dyn Array, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_naive_timestamp_ns::(from))) +} + +/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting +pub fn utf8_to_naive_timestamp_ns(from: &Utf8Array) -> PrimitiveArray { + utf8_to_naive_timestamp_ns_(from, RFC3339) +} + +pub(super) fn utf8_to_timestamp_ns_dyn( + from: &dyn Array, + timezone: String, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + utf8_to_timestamp_ns::(from, timezone) + .map(Box::new) + .map(|x| x as Box) +} + +/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting +pub fn utf8_to_timestamp_ns( + from: &Utf8Array, + timezone: String, +) -> Result> { + utf8_to_timestamp_ns_(from, RFC3339, timezone) +} + +/// Conversion of utf8 +pub fn utf8_to_large_utf8(from: &Utf8Array) -> Utf8Array { + let data_type = Utf8Array::::default_data_type(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + + let offsets = from.offsets().into(); + // Safety: sound because `values` fulfills the same invariants as `from.values()` + unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } +} + +/// Conversion of utf8 +pub fn utf8_large_to_utf8(from: &Utf8Array) -> Result> { + let data_type = Utf8Array::::default_data_type(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + let offsets = from.offsets().try_into()?; + + // Safety: sound because `values` fulfills the same invariants as `from.values()` + Ok(unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) }) +} + +/// Conversion to binary +pub fn utf8_to_binary(from: &Utf8Array, to_data_type: DataType) -> BinaryArray { + // Safety: erasure of an invariant is always safe + unsafe { + BinaryArray::::new( + to_data_type, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/binary.rs b/crates/nano-arrow/src/compute/comparison/binary.rs new file mode 100644 index 000000000000..af87362a7841 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/binary.rs @@ -0,0 +1,238 @@ +//! Comparison functions for [`BinaryArray`] +use super::super::utils::combine_validities; +use crate::array::{BinaryArray, BooleanArray}; +use crate::bitmap::Bitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + assert_eq!(lhs.len(), rhs.len()); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`] and scalar using +/// a specified comparison function. 
+fn compare_op_scalar(lhs: &BinaryArray, rhs: &[u8], op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + let validity = lhs.validity().cloned(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length. +pub fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and include validities in comparison. +/// # Panic +/// iff the arrays do not have the same length. +pub fn eq_and_validity(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a == b); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar. +pub fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a == b); + + finish_eq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length. +pub fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length and include validities in comparison. +pub fn neq_and_validity(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + + let out = compare_op(&lhs, &rhs, |a, b| a != b); + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar. +pub fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison. +pub fn neq_scalar_and_validity(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a != b); + + finish_neq_validities(out, validity, None) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`]. +pub fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`] and a scalar. +pub fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`]. 
+pub fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`] and a scalar. +pub fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`]. +pub fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`] and a scalar. +pub fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`]. +pub fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`] and a scalar. +pub fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &BinaryArray) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: Vec<&[u8]>, + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let rhs = BinaryArray::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs), expected); + } + + fn test_generic_scalar, &[u8]) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: &[u8], + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + vec![b"flight", b"flight", b"flight", b"flight"], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + b"flight", + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + neq_scalar, + vec![false, true, true, true], + ) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/boolean.rs b/crates/nano-arrow/src/compute/comparison/boolean.rs new file mode 100644 index 000000000000..6b62f7fc6b00 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/boolean.rs @@ -0,0 +1,172 @@ +//! Comparison functions for [`BooleanArray`] +use super::super::utils::combine_validities; +use crate::array::BooleanArray; +use crate::bitmap::{binary, unary, Bitmap}; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; + +/// Evaluate `op(lhs, rhs)` for [`BooleanArray`]s using a specified +/// comparison function. 
+fn compare_op(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray +where + F: Fn(u64, u64) -> u64, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = binary(lhs.values(), rhs.values(), op); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using +/// a specified comparison function. +pub fn compare_op_scalar(lhs: &BooleanArray, rhs: bool, op: F) -> BooleanArray +where + F: Fn(u64, u64) -> u64, +{ + let rhs = if rhs { !0 } else { 0 }; + + let values = unary(lhs.values(), |x| op(x, rhs)); + BooleanArray::new(DataType::Boolean, values, lhs.validity().cloned()) +} + +/// Perform `lhs == rhs` operation on two [`BooleanArray`]s. +pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !(a ^ b)) +} + +/// Perform `lhs == rhs` operation on two [`BooleanArray`]s and include validities in comparison. +pub fn eq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| !(a ^ b)); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value. +pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + lhs.clone() + } else { + compare_op_scalar(lhs, rhs, |a, _| !a) + } +} + +/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + if rhs { + finish_eq_validities(lhs, validity, None) + } else { + let lhs = lhs.with_validity(None); + + let out = compare_op_scalar(&lhs, rhs, |a, _| !a); + + finish_eq_validities(out, validity, None) + } +} + +/// `lhs != rhs` for [`BooleanArray`] +pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a ^ b) +} + +/// `lhs != rhs` for [`BooleanArray`] and include validities in comparison. +pub fn neq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a ^ b); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + eq_scalar(lhs, !rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = eq_scalar(&lhs, !rhs); + finish_neq_validities(out, validity, None) +} + +/// Perform `left < right` operation on two arrays. +pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !a & b) +} + +/// Perform `left < right` operation on an array and a scalar value. 
+pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + compare_op_scalar(lhs, rhs, |a, _| !a) + } else { + BooleanArray::new( + DataType::Boolean, + Bitmap::new_zeroed(lhs.len()), + lhs.validity().cloned(), + ) + } +} + +/// Perform `left <= right` operation on two arrays. +pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !a | b) +} + +/// Perform `left <= right` operation on an array and a scalar value. +/// Null values are less than non-null values. +pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + let all_ones = !0; + compare_op_scalar(lhs, rhs, |_, _| all_ones) + } else { + compare_op_scalar(lhs, rhs, |a, _| !a) + } +} + +/// Perform `left > right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a & !b) +} + +/// Perform `left > right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + BooleanArray::new( + DataType::Boolean, + Bitmap::new_zeroed(lhs.len()), + lhs.validity().cloned(), + ) + } else { + lhs.clone() + } +} + +/// Perform `left >= right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a | !b) +} + +/// Perform `left >= right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + lhs.clone() + } else { + let all_ones = !0; + compare_op_scalar(lhs, rhs, |_, _| all_ones) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/mod.rs b/crates/nano-arrow/src/compute/comparison/mod.rs new file mode 100644 index 000000000000..96627ef2a5e1 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/mod.rs @@ -0,0 +1,613 @@ +//! Contains comparison operators +//! +//! The module contains functions that compare either an [`Array`] and a [`Scalar`] +//! or two [`Array`]s (of the same [`DataType`]). The scalar-oriented functions are +//! suffixed with `_scalar`. +//! +//! The functions are organized in two variants: +//! * statically typed +//! * dynamically typed +//! The statically typed are available under each module of this module (e.g. [`primitive::eq`], [`primitive::lt_scalar`]) +//! The dynamically typed are available in this module (e.g. [`eq`] or [`lt_scalar`]). +//! +//! # Examples +//! +//! Compare two [`PrimitiveArray`]s: +//! ``` +//! use arrow2::array::{BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::primitive::gt; +//! +//! let array1 = PrimitiveArray::::from([Some(1), None, Some(2)]); +//! let array2 = PrimitiveArray::::from([Some(1), Some(3), Some(1)]); +//! let result = gt(&array1, &array2); +//! assert_eq!(result, BooleanArray::from([Some(false), None, Some(true)])); +//! ``` +//! +//! Compare two dynamically-typed [`Array`]s (trait objects): +//! ``` +//! use arrow2::array::{Array, BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::eq; +//! +//! let array1: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(20.0)]); +//! let array2: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(10.0)]); +//! let result = eq(array1, array2); +//! 
assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); +//! ``` +//! +//! Compare (not equal) a [`Utf8Array`] to a word: +//! ``` +//! use arrow2::array::{BooleanArray, Utf8Array}; +//! use arrow2::compute::comparison::utf8::neq_scalar; +//! +//! let array = Utf8Array::::from([Some("compute"), None, Some("compare")]); +//! let result = neq_scalar(&array, "compare"); +//! assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); +//! ``` + +use crate::array::*; +use crate::datatypes::{DataType, IntervalUnit}; +use crate::scalar::*; + +pub mod binary; +pub mod boolean; +pub mod primitive; +pub mod utf8; + +mod simd; +pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; + +use super::take::take_boolean; +use crate::bitmap::{binary, Bitmap}; +use crate::compute; + +macro_rules! match_eq_ord {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::i256; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => todo!(), + MonthDayNano => todo!(), + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => todo!(), + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +macro_rules! match_eq {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, f16, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::$op(lhs, rhs) + }, + Primitive(primitive) => $p!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::<$T>(lhs, rhs) + }), + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::$op::(lhs, rhs) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::$op::(lhs, rhs) + }, + _ => todo!( + "Comparison between {:?} are not yet supported", + lhs.data_type() + ), + } + }}; +} + +/// `==` between two [`Array`]s. 
+/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, eq, match_eq) +} + +/// `==` between two [`Array`]s and includes validities in comparison. +/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn eq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, eq_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`eq`]. +pub fn can_eq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + +/// `!=` between two [`Array`]s. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, neq, match_eq) +} + +/// `!=` between two [`Array`]s and includes validities in comparison. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn neq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, neq_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`neq`]. +pub fn can_neq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + +/// `<` between two [`Array`]s. +/// Use [`can_lt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt`]. +pub fn can_lt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `<=` between two [`Array`]s. +/// Use [`can_lt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt_eq, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt_eq`]. +pub fn can_lt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `>` between two [`Array`]s.
+/// Use [`can_gt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`gt`]. +pub fn can_gt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `>=` between two [`Array`]s. +/// Use [`can_gt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt_eq, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`gt_eq`]. +pub fn can_gt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +macro_rules! compare_scalar { + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + if !rhs.is_valid() { + return BooleanArray::new_null(DataType::Boolean, lhs.len()); + } + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + // validity checked above + boolean::$op(lhs, rhs.value().unwrap()) + }, + Primitive(primitive) => $p!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::<$T>(lhs, rhs.value().unwrap()) + }), + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + utf8::$op::(lhs, rhs.value().unwrap()) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + binary::$op::(lhs, rhs.value().unwrap()) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let values = $op(lhs.values().as_ref(), rhs); + + take_boolean(&values, lhs.keys()) + }) + }, + _ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()), + } + }}; +} + +/// `==` between an [`Array`] and a [`Scalar`]. +/// Use [`can_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, eq_scalar, match_eq) +} + +/// `==` between an [`Array`] and a [`Scalar`] and includes validities in comparison. +/// Use [`can_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn eq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, eq_scalar_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`eq_scalar`]. 
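The `Dictionary` branch of `compare_scalar!` above compares the scalar against the dictionary values once and then gathers a per-row answer through the keys, which is the role `take_boolean` plays. A plain-`Vec` sketch of that idea (the helper name is invented for the example):

```rust
// Hypothetical standalone helper (not crate API): scalar comparison over a
// dictionary-encoded column.
fn dict_eq_scalar(dict_values: &[&str], keys: &[usize], scalar: &str) -> Vec<bool> {
    // One comparison per *distinct* value...
    let per_value: Vec<bool> = dict_values.iter().map(|v| *v == scalar).collect();
    // ...then one lookup per row.
    keys.iter().map(|&k| per_value[k]).collect()
}

fn main() {
    let dict_values = ["low", "mid", "high"];
    let keys = [0, 2, 2, 1, 0];
    assert_eq!(
        dict_eq_scalar(&dict_values, &keys, "high"),
        vec![false, true, true, false, false]
    );
}
```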
+pub fn can_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + +/// `!=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_neq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, neq_scalar, match_eq) +} + +/// `!=` between an [`Array`] and a [`Scalar`] and includes validities in comparison. +/// Use [`can_neq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn neq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, neq_scalar_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`neq_scalar`]. +pub fn can_neq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + +/// `<` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt_scalar`]. +pub fn can_lt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `<=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt_eq_scalar`]. +pub fn can_lt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `>` between an [`Array`] and a [`Scalar`]. +/// Use [`can_gt_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`gt_scalar`]. +pub fn can_gt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `>=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_gt_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`gt_eq_scalar`]. +pub fn can_gt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +// The list of operations currently supported.
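The `finish_eq_validities` helper defined a little further below corrects an equality that was computed on masked-out data, so that null == null reads as true and null vs value as false. A sketch of the same correction with the bitmaps modelled as `bool` slices (the semantics are taken from the comments in that helper; the function name here is illustrative):

```rust
// Hypothetical standalone sketch (not crate API) of the validity correction.
fn finish_eq(raw_eq: &[bool], lhs_valid: &[bool], rhs_valid: &[bool]) -> Vec<bool> {
    raw_eq
        .iter()
        .zip(lhs_valid)
        .zip(rhs_valid)
        .map(|((&eq, &lv), &rv)| {
            let both_invalid = !(lv | rv);
            // Valid slots keep the data comparison; a valid/invalid pair is
            // never equal; two invalid slots are always equal.
            (eq && lv == rv) || both_invalid
        })
        .collect()
}

fn main() {
    // data: [1, ?, 3, ?] vs [1, 2, ?, ?]  ('?' = null, arbitrary bytes underneath)
    let raw_eq = [true, false, false, true];
    let lhs_valid = [true, false, true, false];
    let rhs_valid = [true, true, false, false];
    assert_eq!(
        finish_eq(&raw_eq, &lhs_valid, &rhs_valid),
        vec![true, false, false, true]
    );
}
```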
+fn can_partial_eq_and_ord_scalar(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_partial_eq_and_ord_scalar(values.as_ref()); + } + can_partial_eq_and_ord(data_type) +} + +// The list of operations currently supported. +fn can_partial_eq_and_ord(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) + | DataType::Int64 + | DataType::Timestamp(_, _) + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Decimal(_, _) + | DataType::Binary + | DataType::LargeBinary + ) +} + +// The list of operations currently supported. +fn can_partial_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Float16 + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} + +// The list of operations currently supported. +fn can_partial_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} + +/// Utility for low level end users that implement their own comparison functions +/// A comparison on the data column can be applied on masked out values +/// This function will correct equality for the validities. +pub fn finish_eq_validities( + output_without_validities: BooleanArray, + validity_lhs: Option, + validity_rhs: Option, +) -> BooleanArray { + match (validity_lhs, validity_rhs) { + (None, None) => output_without_validities, + (Some(lhs), None) => compute::boolean::and( + &BooleanArray::new(DataType::Boolean, lhs, None), + &output_without_validities, + ), + (None, Some(rhs)) => compute::boolean::and( + &output_without_validities, + &BooleanArray::new(DataType::Boolean, rhs, None), + ), + (Some(lhs), Some(rhs)) => { + let lhs_validity_unset_bits = lhs.unset_bits(); + let rhs_validity_unset_bits = rhs.unset_bits(); + + // this branch is a bit more complicated as both arrays can have masked out values + // these masked out values might differ and lead to a `eq == false` that has to + // be corrected as both should be `null == null = true` + + let lhs = BooleanArray::new(DataType::Boolean, lhs, None); + let rhs = BooleanArray::new(DataType::Boolean, rhs, None); + let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs); + + // validity_bits are equal AND values are equal + let equal = compute::boolean::and(&output_without_validities, &eq_validities); + + match (lhs_validity_unset_bits, rhs_validity_unset_bits) { + // there is at least one side with all values valid + // so we don't have to correct. + (0, _) | (_, 0) => equal, + _ => { + // we use the binary kernel here to save allocations + // and apply `!(lhs | rhs)` in one step + let both_sides_invalid = + compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| { + binary(lhs, rhs, |lhs, rhs| !(lhs | rhs)) + }); + // this still might include incorrect masked out values + // under the validity bits, so we must correct for that + + // if not all true, e.g. at least one is set. 
+ // then we propagate that null as `true` in equality + if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() { + compute::boolean::or(&equal, &both_sides_invalid) + } else { + equal + } + }, + } + }, + } +} + +/// Utility for low level end users that implement their own comparison functions +/// A comparison on the data column can be applied on masked out values +/// This function will correct non-equality for the validities. +pub fn finish_neq_validities( + output_without_validities: BooleanArray, + validity_lhs: Option, + validity_rhs: Option, +) -> BooleanArray { + match (validity_lhs, validity_rhs) { + (None, None) => output_without_validities, + (Some(lhs), None) => { + let lhs_negated = + compute::boolean::not(&BooleanArray::new(DataType::Boolean, lhs, None)); + compute::boolean::or(&lhs_negated, &output_without_validities) + }, + (None, Some(rhs)) => { + let rhs_negated = + compute::boolean::not(&BooleanArray::new(DataType::Boolean, rhs, None)); + compute::boolean::or(&output_without_validities, &rhs_negated) + }, + (Some(lhs), Some(rhs)) => { + let lhs_validity_unset_bits = lhs.unset_bits(); + let rhs_validity_unset_bits = rhs.unset_bits(); + + // this branch is a bit more complicated as both arrays can have masked out values + // these masked out values might differ and lead to a `neq == true` that has to + // be corrected as both should be `null != null = false` + let lhs = BooleanArray::new(DataType::Boolean, lhs, None); + let rhs = BooleanArray::new(DataType::Boolean, rhs, None); + let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs); + + // validity_bits are not equal OR values not equal + let or = compute::boolean::or(&output_without_validities, &neq_validities); + + match (lhs_validity_unset_bits, rhs_validity_unset_bits) { + // there is at least one side with all values valid + // so we don't have to correct. + (0, _) | (_, 0) => or, + _ => { + // we use the binary kernel here to save allocations + // and apply `!(lhs | rhs)` in one step + let both_sides_invalid = + compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| { + binary(lhs, rhs, |lhs, rhs| !(lhs | rhs)) + }); + // this still might include incorrect masked out values + // under the validity bits, so we must correct for that + + // if not all true, e.g. at least one is set. + // then we propagate that null as `false` as the nulls are equal + if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() { + // we use the `binary` kernel directly to save allocations + // and apply `lhs & !rhs)` in one shot. + + compute::boolean::binary_boolean_kernel( + &or, + &both_sides_invalid, + |lhs, rhs| binary(lhs, rhs, |lhs, rhs| (lhs & !rhs)), + ) + } else { + or + } + }, + } + }, + } +} diff --git a/crates/nano-arrow/src/compute/comparison/primitive.rs b/crates/nano-arrow/src/compute/comparison/primitive.rs new file mode 100644 index 000000000000..5ecda063cd22 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/primitive.rs @@ -0,0 +1,590 @@ +//! 
Comparison functions for [`PrimitiveArray`] +use super::super::utils::combine_validities; +use super::simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use crate::array::{BooleanArray, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::types::NativeType; + +pub(crate) fn compare_values_op(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + assert_eq!(lhs.len(), rhs.len()); + + let lhs_chunks_iter = lhs.chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + let rhs_chunks_iter = rhs.chunks_exact(8); + let rhs_remainder = rhs_chunks_iter.remainder(); + + let mut values = Vec::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| { + let lhs = T::Simd::from_chunk(lhs); + let rhs = T::Simd::from_chunk(rhs); + op(lhs, rhs) + }); + values.extend(iterator); + + if !lhs_remainder.is_empty() { + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default()); + values.push(op(lhs, rhs)) + }; + MutableBitmap::from_vec(values, lhs.len()) +} + +pub(crate) fn compare_values_op_scalar(lhs: &[T], rhs: T, op: F) -> MutableBitmap +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let rhs = T::Simd::from_chunk(&[rhs; 8]); + + let lhs_chunks_iter = lhs.chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + + let mut values = Vec::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.map(|lhs| { + let lhs = T::Simd::from_chunk(lhs); + op(lhs, rhs) + }); + values.extend(iterator); + + if !lhs_remainder.is_empty() { + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + values.push(op(lhs, rhs)) + }; + + MutableBitmap::from_vec(values, lhs.len()) +} + +/// Evaluate `op(lhs, rhs)` for [`PrimitiveArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: F) -> BooleanArray +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = compare_values_op(lhs.values(), rhs.values(), op); + + BooleanArray::new(DataType::Boolean, values.into(), validity) +} + +/// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using +/// a specified comparison function. +pub fn compare_op_scalar(lhs: &PrimitiveArray, rhs: T, op: F) -> BooleanArray +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let validity = lhs.validity().cloned(); + + let values = compare_values_op_scalar(lhs.values(), rhs, op); + + BooleanArray::new(DataType::Boolean, values.into(), validity) +} + +/// Perform `lhs == rhs` operation on two arrays. +pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op(lhs, rhs, |a, b| a.eq(b)) +} + +/// Perform `lhs == rhs` operation on two arrays and include validities in comparison. 
+pub fn eq_and_validity(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a.eq(b)); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left == right` operation on an array and a scalar value. +pub fn eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op_scalar(lhs, rhs, |a, b| a.eq(b)) +} + +/// Perform `left == right` operation on an array and a scalar value and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a.eq(b)); + + finish_eq_validities(out, validity, None) +} + +/// Perform `left != right` operation on two arrays. +pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op(lhs, rhs, |a, b| a.neq(b)) +} + +/// Perform `left != right` operation on two arrays and include validities in comparison. +pub fn neq_and_validity(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a.neq(b)); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op_scalar(lhs, rhs, |a, b| a.neq(b)) +} + +/// Perform `left != right` operation on an array and a scalar value and include validities in comparison. +pub fn neq_scalar_and_validity(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a.neq(b)); + + finish_neq_validities(out, validity, None) +} + +/// Perform `left < right` operation on two arrays. +pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.lt(b)) +} + +/// Perform `left < right` operation on an array and a scalar value. +pub fn lt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.lt(b)) +} + +/// Perform `left <= right` operation on two arrays. +pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.lt_eq(b)) +} + +/// Perform `left <= right` operation on an array and a scalar value. +/// Null values are less than non-null values. 
+pub fn lt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b)) +} + +/// Perform `left > right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.gt(b)) +} + +/// Perform `left > right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.gt(b)) +} + +/// Perform `left >= right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.gt_eq(b)) +} + +/// Perform `left >= right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) +} + +// disable wrapping inside literal vectors used for test data and assertions +#[rustfmt::skip::macros(vec)] +#[cfg(test)] +mod tests { + use super::*; + use crate::array::{Int64Array, Int8Array}; + + /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. + /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>`. + /// `EXPECTED` can be either `Vec` or `Vec>`. + /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. + macro_rules! cmp_i64 { + ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + let a = Int64Array::from_slice($A_VEC); + let b = Int64Array::from_slice($B_VEC); + let c = $KERNEL(&a, &b); + assert_eq!(BooleanArray::from_slice($EXPECTED), c); + }; + } + + macro_rules! cmp_i64_options { + ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + let a = Int64Array::from($A_VEC); + let b = Int64Array::from($B_VEC); + let c = $KERNEL(&a, &b); + assert_eq!(BooleanArray::from($EXPECTED), c); + }; + } + + /// Evaluate `KERNEL` with one vectors and one scalar as inputs and assert against the expected output. + /// `A_VEC` can be of type `Vec` or `Vec>`. + /// `EXPECTED` can be either `Vec` or `Vec>`. + /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. + macro_rules! cmp_i64_scalar_options { + ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => { + let a = Int64Array::from($A_VEC); + let c = $KERNEL(&a, $B); + assert_eq!(BooleanArray::from($EXPECTED), c); + }; + } + + macro_rules! 
cmp_i64_scalar { + ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => { + let a = Int64Array::from_slice($A_VEC); + let c = $KERNEL(&a, $B); + assert_eq!(BooleanArray::from_slice($EXPECTED), c); + }; + } + + #[test] + fn test_primitive_array_eq() { + cmp_i64!( + eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); + } + + #[test] + fn test_primitive_array_eq_scalar() { + cmp_i64_scalar!( + eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, true, false, false, false, false, true, false, false] + ); + } + + #[test] + fn test_primitive_array_eq_with_slice() { + let a = Int64Array::from_slice([6, 7, 8, 8, 10]); + let mut b = Int64Array::from_slice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + b.slice(5, 5); + let d = eq(&b, &a); + assert_eq!(d, BooleanArray::from_slice([true, true, true, false, true])); + } + + #[test] + fn test_primitive_array_neq() { + cmp_i64!( + neq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, false, true, true, true, true, false, true, true] + ); + } + + #[test] + fn test_primitive_array_neq_scalar() { + cmp_i64_scalar!( + neq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, false, true, true, true, true, false, true, true] + ); + } + + #[test] + fn test_primitive_array_lt() { + cmp_i64!( + lt, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, false, true, true, false, false, false, true, true] + ); + } + + #[test] + fn test_primitive_array_lt_scalar() { + cmp_i64_scalar!( + lt_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, false, false, false, true, true, false, false, false] + ); + } + + #[test] + fn test_primitive_array_lt_nulls() { + cmp_i64_options!( + lt, + &[None, None, Some(1), Some(1), None, None, Some(2), Some(2),], + &[None, Some(1), None, Some(1), None, Some(3), None, Some(3),], + vec![None, None, None, Some(false), None, None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_lt_scalar_nulls() { + cmp_i64_scalar_options!( + lt_scalar, + &[None, Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3), Some(2), None], + 2, + vec![None, Some(true), Some(false), Some(false), None, Some(true), Some(false), Some(false), Some(false), None] + ); + } + + #[test] + fn test_primitive_array_lt_eq() { + cmp_i64!( + lt_eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, true, true, false, false, true, true, true] + ); + } + + #[test] + fn test_primitive_array_lt_eq_scalar() { + cmp_i64_scalar!( + lt_eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, true, false, false, true, true, true, false, false] + ); + } + + #[test] + fn test_primitive_array_lt_eq_nulls() { + cmp_i64_options!( + lt_eq, + &[ + None, + None, + Some(1), + None, + None, + Some(1), + None, + None, + Some(1) + ], + &[ + None, + Some(1), + Some(0), + None, + Some(1), + Some(2), + None, + None, + Some(3) + ], + vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_lt_eq_scalar_nulls() { + cmp_i64_scalar_options!( + lt_eq_scalar, + &[None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)], + 1, + vec![None, Some(true), Some(false), None, Some(true), Some(false), None, Some(true), Some(false)] + ); + } + + #[test] + fn test_primitive_array_gt() { + cmp_i64!( + gt, + &[8, 8, 8, 8, 8, 8, 
8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, false, false, false, true, true, false, false, false] + ); + } + + #[test] + fn test_primitive_array_gt_scalar() { + cmp_i64_scalar!( + gt_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, false, true, true, false, false, false, true, true] + ); + } + + #[test] + fn test_primitive_array_gt_nulls() { + cmp_i64_options!( + gt, + &[ + None, + None, + Some(1), + None, + None, + Some(2), + None, + None, + Some(3) + ], + &[ + None, + Some(1), + Some(1), + None, + Some(1), + Some(1), + None, + Some(1), + Some(1) + ], + vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_gt_scalar_nulls() { + cmp_i64_scalar_options!( + gt_scalar, + &[None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)], + 1, + vec![None, Some(false), Some(true), None, Some(false), Some(true), None, Some(false), Some(true)] + ); + } + + #[test] + fn test_primitive_array_gt_eq() { + cmp_i64!( + gt_eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, true, false, false, true, true, true, false, false] + ); + } + + #[test] + fn test_primitive_array_gt_eq_scalar() { + cmp_i64_scalar!( + gt_eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, true, true, true, false, false, true, true, true] + ); + } + + #[test] + fn test_primitive_array_gt_eq_nulls() { + cmp_i64_options!( + gt_eq, + vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)], + vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)], + vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)] + ); + } + + #[test] + fn test_primitive_array_gt_eq_scalar_nulls() { + cmp_i64_scalar_options!( + gt_eq_scalar, + vec![None, Some(1), Some(2), None, Some(2), Some(3), None, Some(3), Some(4)], + 2, + vec![None, Some(false), Some(true), None, Some(true), Some(true), None, Some(true), Some(true)] + ); + } + + #[test] + fn test_primitive_array_compare_slice() { + let mut a = (0..100).map(Some).collect::>(); + a.slice(50, 50); + let mut b = (100..200).map(Some).collect::>(); + b.slice(50, 50); + let actual = lt(&a, &b); + let expected: BooleanArray = (0..50).map(|_| Some(true)).collect(); + assert_eq!(expected, actual); + } + + #[test] + fn test_primitive_array_compare_scalar_slice() { + let mut a = (0..100).map(Some).collect::>(); + a.slice(50, 50); + let actual = lt_scalar(&a, 200); + let expected: BooleanArray = (0..50).map(|_| Some(true)).collect(); + assert_eq!(expected, actual); + } + + #[test] + fn test_length_of_result_buffer() { + // `item_count` is chosen to not be a multiple of 64. + const ITEM_COUNT: usize = 130; + + let array_a = Int8Array::from_slice([1; ITEM_COUNT]); + let array_b = Int8Array::from_slice([2; ITEM_COUNT]); + let expected = BooleanArray::from_slice([false; ITEM_COUNT]); + let result = gt_eq(&array_a, &array_b); + + assert_eq!(result, expected) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/simd/mod.rs b/crates/nano-arrow/src/compute/comparison/simd/mod.rs new file mode 100644 index 000000000000..30d9773cd4c9 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/mod.rs @@ -0,0 +1,133 @@ +use crate::types::NativeType; + +/// [`NativeType`] that supports a representation of 8 lanes +pub trait Simd8: NativeType { + /// The 8 lane representation of `Self` + type Simd: Simd8Lanes; +} + +/// Trait declaring an 8-lane multi-data. 
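// Illustration of the bitmask convention used by the traits declared here:
// comparing an 8-lane chunk produces a `u8` in which bit `i` holds the result
// for lane `i`. This is a hand-rolled sketch, equivalent in spirit to the
// `set` helper defined further below, not code from the patch itself.
fn eq8_mask(lhs: [i32; 8], rhs: [i32; 8]) -> u8 {
    let mut mask = 0u8;
    for i in 0..8 {
        if lhs[i] == rhs[i] {
            mask |= 1 << i;
        }
    }
    mask
}
// e.g. eq8_mask([8, 7, 8, 0, 8, 1, 2, 8], [8; 8]) == 0b1001_0101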
+pub trait Simd8Lanes: Copy { + /// loads a complete chunk + fn from_chunk(v: &[T]) -> Self; + /// loads an incomplete chunk, filling the remaining items with `remaining`. + fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialEq]. +pub trait Simd8PartialEq: Copy { + /// Equal + fn eq(self, other: Self) -> u8; + /// Not equal + fn neq(self, other: Self) -> u8; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialOrd]. +pub trait Simd8PartialOrd: Copy { + /// Less than or equal to + fn lt_eq(self, other: Self) -> u8; + /// Less than + fn lt(self, other: Self) -> u8; + /// Greater than + fn gt(self, other: Self) -> u8; + /// Greater than or equal to + fn gt_eq(self, other: Self) -> u8; +} + +#[inline] +pub(super) fn set bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 { + let mut byte = 0u8; + lhs.iter() + .zip(rhs.iter()) + .enumerate() + .for_each(|(i, (lhs, rhs))| { + byte |= if op(*lhs, *rhs) { 1 << i } else { 0 }; + }); + byte +} + +/// Types that implement Simd8 +macro_rules! simd8_native { + ($type:ty) => { + impl Simd8 for $type { + type Simd = [$type; 8]; + } + + impl Simd8Lanes<$type> for [$type; 8] { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + v.try_into().unwrap() + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + a + } + } + }; +} + +/// Types that implement PartialEq +macro_rules! simd8_native_partial_eq { + ($type:ty) => { + impl Simd8PartialEq for [$type; 8] { + #[inline] + fn eq(self, other: Self) -> u8 { + set(self, other, |x, y| x == y) + } + + #[inline] + fn neq(self, other: Self) -> u8 { + #[allow(clippy::float_cmp)] + set(self, other, |x, y| x != y) + } + } + }; +} + +/// Types that implement PartialOrd +macro_rules! simd8_native_partial_ord { + ($type:ty) => { + impl Simd8PartialOrd for [$type; 8] { + #[inline] + fn lt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x <= y) + } + + #[inline] + fn lt(self, other: Self) -> u8 { + set(self, other, |x, y| x < y) + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x >= y) + } + + #[inline] + fn gt(self, other: Self) -> u8 { + set(self, other, |x, y| x > y) + } + } + }; +} + +/// Types that implement simd8, PartialEq and PartialOrd +macro_rules! simd8_native_all { + ($type:ty) => { + simd8_native! {$type} + simd8_native_partial_eq! {$type} + simd8_native_partial_ord! 
{$type} + }; +} + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +pub use packed::*; diff --git a/crates/nano-arrow/src/compute/comparison/simd/native.rs b/crates/nano-arrow/src/compute/comparison/simd/native.rs new file mode 100644 index 000000000000..b8bbf9b17d66 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/native.rs @@ -0,0 +1,23 @@ +use std::convert::TryInto; + +use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use crate::types::{days_ms, f16, i256, months_days_ns}; + +simd8_native_all!(u8); +simd8_native_all!(u16); +simd8_native_all!(u32); +simd8_native_all!(u64); +simd8_native_all!(i8); +simd8_native_all!(i16); +simd8_native_all!(i32); +simd8_native_all!(i128); +simd8_native_all!(i256); +simd8_native_all!(i64); +simd8_native!(f16); +simd8_native_partial_eq!(f16); +simd8_native_all!(f32); +simd8_native_all!(f64); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/crates/nano-arrow/src/compute/comparison/simd/packed.rs b/crates/nano-arrow/src/compute/comparison/simd/packed.rs new file mode 100644 index 000000000000..707d875deef0 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/packed.rs @@ -0,0 +1,81 @@ +use std::convert::TryInto; +use std::simd::{SimdPartialEq, SimdPartialOrd, ToBitMask}; + +use super::*; +use crate::types::simd::*; +use crate::types::{days_ms, f16, i256, months_days_ns}; + +macro_rules! simd8 { + ($type:ty, $md:ty) => { + impl Simd8 for $type { + type Simd = $md; + } + + impl Simd8Lanes<$type> for $md { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + <$md>::from_slice(v) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + Self::from_array(a) + } + } + + impl Simd8PartialEq for $md { + #[inline] + fn eq(self, other: Self) -> u8 { + self.simd_eq(other).to_bitmask() + } + + #[inline] + fn neq(self, other: Self) -> u8 { + self.simd_ne(other).to_bitmask() + } + } + + impl Simd8PartialOrd for $md { + #[inline] + fn lt_eq(self, other: Self) -> u8 { + self.simd_le(other).to_bitmask() + } + + #[inline] + fn lt(self, other: Self) -> u8 { + self.simd_lt(other).to_bitmask() + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + self.simd_ge(other).to_bitmask() + } + + #[inline] + fn gt(self, other: Self) -> u8 { + self.simd_gt(other).to_bitmask() + } + } + }; +} + +simd8!(u8, u8x8); +simd8!(u16, u16x8); +simd8!(u32, u32x8); +simd8!(u64, u64x8); +simd8!(i8, i8x8); +simd8!(i16, i16x8); +simd8!(i32, i32x8); +simd8!(i64, i64x8); +simd8_native_all!(i128); +simd8_native_all!(i256); +simd8_native!(f16); +simd8_native_partial_eq!(f16); +simd8!(f32, f32x8); +simd8!(f64, f64x8); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/crates/nano-arrow/src/compute/comparison/utf8.rs b/crates/nano-arrow/src/compute/comparison/utf8.rs new file mode 100644 index 000000000000..cba683c7b869 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/utf8.rs @@ -0,0 +1,291 @@ +//! 
Comparison functions for [`Utf8Array`] +use super::super::utils::combine_validities; +use crate::array::{BooleanArray, Utf8Array}; +use crate::bitmap::Bitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// Evaluate `op(lhs, rhs)` for [`Utf8Array`]s using a specified +/// comparison function. +fn compare_op(lhs: &Utf8Array, rhs: &Utf8Array, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&str, &str) -> bool, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(lhs, rhs)` for [`Utf8Array`] and scalar using +/// a specified comparison function. +fn compare_op_scalar(lhs: &Utf8Array, rhs: &str, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&str, &str) -> bool, +{ + let validity = lhs.validity().cloned(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`]. +pub fn eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and include validities in comparison. +pub fn eq_and_validity(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a == b); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and include validities in comparison. +pub fn neq_and_validity(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a != b); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and a scalar. +pub fn eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and a scalar. Also includes null values in comparison. +pub fn eq_scalar_and_validity(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a == b); + + finish_eq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and a scalar. Also includes null values in comparison. +pub fn neq_scalar_and_validity(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a != b); + + finish_neq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`]. +pub fn neq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and a scalar. 
+pub fn neq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs < rhs` operation on [`Utf8Array`]. +pub fn lt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`Utf8Array`] and a scalar. +pub fn lt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`Utf8Array`]. +pub fn lt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`Utf8Array`] and a scalar. +pub fn lt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`Utf8Array`]. +pub fn gt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`Utf8Array`] and a scalar. +pub fn gt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`Utf8Array`]. +pub fn gt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`Utf8Array`] and a scalar. +pub fn gt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &Utf8Array) -> BooleanArray>( + lhs: Vec<&str>, + rhs: Vec<&str>, + op: F, + expected: Vec, + ) { + let lhs = Utf8Array::::from_slice(lhs); + let rhs = Utf8Array::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs), expected); + } + + fn test_generic_scalar, &str) -> BooleanArray>( + lhs: Vec<&str>, + rhs: &str, + op: F, + expected: Vec, + ) { + let lhs = Utf8Array::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec!["arrow", "arrow", "arrow", "arrow"], + vec!["arrow", "parquet", "datafusion", "flight"], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec!["arrow", "arrow", "arrow", "arrow"], + vec!["arrow", "parquet", "datafusion", "flight"], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + neq_scalar, + vec![false, true, true, true], + ) + } + + /* + test_utf8!( + test_utf8_array_lt, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + lt_utf8, + vec![true, true, false, false] + ); + test_utf8_scalar!( + test_utf8_array_lt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_utf8_scalar, + vec![true, true, 
false, false] + ); + + test_utf8!( + test_utf8_array_lt_eq, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + lt_eq_utf8, + vec![true, true, true, false] + ); + test_utf8_scalar!( + test_utf8_array_lt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_eq_utf8_scalar, + vec![true, true, true, false] + ); + + test_utf8!( + test_utf8_array_gt, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_utf8, + vec![false, false, false, true] + ); + test_utf8_scalar!( + test_utf8_array_gt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_utf8_scalar, + vec![false, false, false, true] + ); + + test_utf8!( + test_utf8_array_gt_eq, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_eq_utf8, + vec![false, false, true, true] + ); + test_utf8_scalar!( + test_utf8_array_gt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_eq_utf8_scalar, + vec![false, false, true, true] + ); + */ +} diff --git a/crates/nano-arrow/src/compute/concatenate.rs b/crates/nano-arrow/src/compute/concatenate.rs new file mode 100644 index 000000000000..48db6c141e27 --- /dev/null +++ b/crates/nano-arrow/src/compute/concatenate.rs @@ -0,0 +1,69 @@ +//! Contains the concatenate kernel +//! +//! Example: +//! +//! ``` +//! use arrow2::array::Utf8Array; +//! use arrow2::compute::concatenate::concatenate; +//! +//! let arr = concatenate(&[ +//! &Utf8Array::::from_slice(["hello", "world"]), +//! &Utf8Array::::from_slice(["!"]), +//! ]).unwrap(); +//! assert_eq!(arr.len(), 3); +//! ``` + +use crate::array::growable::make_growable; +use crate::array::Array; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::error::{Error, Result}; + +/// Concatenate multiple [Array] of the same type into a single [`Array`]. +pub fn concatenate(arrays: &[&dyn Array]) -> Result> { + if arrays.is_empty() { + return Err(Error::InvalidArgumentError( + "concat requires input of at least one array".to_string(), + )); + } + + if arrays + .iter() + .any(|array| array.data_type() != arrays[0].data_type()) + { + return Err(Error::InvalidArgumentError( + "It is not possible to concatenate arrays of different data types.".to_string(), + )); + } + + let lengths = arrays.iter().map(|array| array.len()).collect::>(); + let capacity = lengths.iter().sum(); + + let mut mutable = make_growable(arrays, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + Ok(mutable.as_box()) +} + +/// Concatenate the validities of multiple [Array]s into a single Bitmap. +pub fn concatenate_validities(arrays: &[&dyn Array]) -> Option { + let null_count: usize = arrays.iter().map(|a| a.null_count()).sum(); + if null_count == 0 { + return None; + } + + let total_size: usize = arrays.iter().map(|a| a.len()).sum(); + let mut bitmap = MutableBitmap::with_capacity(total_size); + for arr in arrays { + if arr.null_count() == arr.len() { + bitmap.extend_constant(arr.len(), false); + } else if arr.null_count() == 0 { + bitmap.extend_constant(arr.len(), true); + } else { + bitmap.extend_from_bitmap(arr.validity().unwrap()); + } + } + Some(bitmap.into()) +} diff --git a/crates/nano-arrow/src/compute/filter.rs b/crates/nano-arrow/src/compute/filter.rs new file mode 100644 index 000000000000..90ddf4b4d158 --- /dev/null +++ b/crates/nano-arrow/src/compute/filter.rs @@ -0,0 +1,321 @@ +//! 
Contains operators to filter arrays such as [`filter`]. +use crate::array::growable::{make_growable, Growable}; +use crate::array::*; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact, SlicesIterator}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::chunk::Chunk; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::simd::Simd; +use crate::types::{BitChunkOnes, NativeType}; + +/// Function that can filter arbitrary arrays +pub type Filter<'a> = Box Box + 'a + Send + Sync>; + +#[inline] +fn get_leading_ones(chunk: u64) -> u32 { + if cfg!(target_endian = "little") { + chunk.trailing_ones() + } else { + chunk.leading_ones() + } +} + +/// # Safety +/// This assumes that the `mask_chunks` contains a number of set/true items equal +/// to `filter_count` +unsafe fn nonnull_filter_impl(values: &[T], mut mask_chunks: I, filter_count: usize) -> Vec +where + T: NativeType + Simd, + I: BitChunkIterExact, +{ + let mut chunks = values.chunks_exact(64); + let mut new = Vec::::with_capacity(filter_count); + let mut dst = new.as_mut_ptr(); + + chunks + .by_ref() + .zip(mask_chunks.by_ref()) + .for_each(|(chunk, mask_chunk)| { + let ones = mask_chunk.count_ones(); + let leading_ones = get_leading_ones(mask_chunk); + + if ones == leading_ones { + let size = leading_ones as usize; + unsafe { + std::ptr::copy(chunk.as_ptr(), dst, size); + dst = dst.add(size); + } + return; + } + + let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); + for pos in ones_iter { + dst.write(*chunk.get_unchecked(pos)); + dst = dst.add(1); + } + }); + + chunks + .remainder() + .iter() + .zip(mask_chunks.remainder_iter()) + .for_each(|(value, b)| { + if b { + unsafe { + dst.write(*value); + dst = dst.add(1); + }; + } + }); + + unsafe { new.set_len(filter_count) }; + new +} + +/// # Safety +/// This assumes that the `mask_chunks` contains a number of set/true items equal +/// to `filter_count` +unsafe fn null_filter_impl( + values: &[T], + validity: &Bitmap, + mut mask_chunks: I, + filter_count: usize, +) -> (Vec, MutableBitmap) +where + T: NativeType + Simd, + I: BitChunkIterExact, +{ + let mut chunks = values.chunks_exact(64); + + let mut validity_chunks = validity.chunks::(); + + let mut new = Vec::::with_capacity(filter_count); + let mut dst = new.as_mut_ptr(); + let mut new_validity = MutableBitmap::with_capacity(filter_count); + + chunks + .by_ref() + .zip(validity_chunks.by_ref()) + .zip(mask_chunks.by_ref()) + .for_each(|((chunk, validity_chunk), mask_chunk)| { + let ones = mask_chunk.count_ones(); + let leading_ones = get_leading_ones(mask_chunk); + + if ones == leading_ones { + let size = leading_ones as usize; + unsafe { + std::ptr::copy(chunk.as_ptr(), dst, size); + dst = dst.add(size); + + // safety: invariant offset + length <= slice.len() + new_validity.extend_from_slice_unchecked( + validity_chunk.to_ne_bytes().as_ref(), + 0, + size, + ); + } + return; + } + + // this triggers a bitcount + let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); + for pos in ones_iter { + dst.write(*chunk.get_unchecked(pos)); + dst = dst.add(1); + new_validity.push_unchecked(validity_chunk & (1 << pos) > 0); + } + }); + + chunks + .remainder() + .iter() + .zip(validity_chunks.remainder_iter()) + .zip(mask_chunks.remainder_iter()) + .for_each(|((value, is_valid), is_selected)| { + if is_selected { + unsafe { + dst.write(*value); + dst = dst.add(1); + new_validity.push_unchecked(is_valid); + }; + } + }); + + unsafe { new.set_len(filter_count) }; + (new, 
new_validity) +} + +fn null_filter_simd( + values: &[T], + validity: &Bitmap, + mask: &Bitmap, +) -> (Vec, MutableBitmap) { + assert_eq!(values.len(), mask.len()); + let filter_count = mask.len() - mask.unset_bits(); + + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + } else { + let mask_chunks = mask.chunks::(); + unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + } +} + +fn nonnull_filter_simd(values: &[T], mask: &Bitmap) -> Vec { + assert_eq!(values.len(), mask.len()); + let filter_count = mask.len() - mask.unset_bits(); + + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } + } else { + let mask_chunks = mask.chunks::(); + unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } + } +} + +fn filter_nonnull_primitive( + array: &PrimitiveArray, + mask: &Bitmap, +) -> PrimitiveArray { + assert_eq!(array.len(), mask.len()); + + if let Some(validity) = array.validity() { + let (values, validity) = null_filter_simd(array.values(), validity, mask); + PrimitiveArray::::new(array.data_type().clone(), values.into(), validity.into()) + } else { + let values = nonnull_filter_simd(array.values(), mask); + PrimitiveArray::::new(array.data_type().clone(), values.into(), None) + } +} + +fn filter_primitive( + array: &PrimitiveArray, + mask: &BooleanArray, +) -> PrimitiveArray { + // todo: branch on mask.validity() + filter_nonnull_primitive(array, mask.values()) +} + +fn filter_growable<'a>(growable: &mut impl Growable<'a>, chunks: &[(usize, usize)]) { + chunks + .iter() + .for_each(|(start, len)| growable.extend(0, *start, *len)); +} + +/// Returns a prepared function optimized to filter multiple arrays. +/// Creating this function requires time, but using it is faster than [filter] when the +/// same filter needs to be applied to multiple arrays (e.g. a multiple columns). +pub fn build_filter(filter: &BooleanArray) -> Result { + let iter = SlicesIterator::new(filter.values()); + let filter_count = iter.slots(); + let chunks = iter.collect::>(); + + use crate::datatypes::PhysicalType::*; + Ok(Box::new(move |array: &dyn Array| { + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + let mut growable = + growable::GrowablePrimitive::<$T>::new(vec![array], false, filter_count); + filter_growable(&mut growable, &chunks); + let array: PrimitiveArray<$T> = growable.into(); + Box::new(array) + }), + LargeUtf8 => { + let array = array.as_any().downcast_ref::>().unwrap(); + let mut growable = growable::GrowableUtf8::new(vec![array], false, filter_count); + filter_growable(&mut growable, &chunks); + let array: Utf8Array = growable.into(); + Box::new(array) + }, + _ => { + let mut mutable = make_growable(&[array], false, filter_count); + chunks + .iter() + .for_each(|(start, len)| mutable.extend(0, *start, *len)); + mutable.as_box() + }, + } + })) +} + +/// Filters an [Array], returning elements matching the filter (i.e. where the values are true). +/// +/// Note that the nulls of `filter` are interpreted as `false` will lead to these elements being +/// masked out. 
+/// +/// # Example +/// ```rust +/// # use arrow2::array::{Int32Array, PrimitiveArray, BooleanArray}; +/// # use arrow2::error::Result; +/// # use arrow2::compute::filter::filter; +/// # fn main() -> Result<()> { +/// let array = PrimitiveArray::from_slice([5, 6, 7, 8, 9]); +/// let filter_array = BooleanArray::from_slice(&vec![true, false, false, true, false]); +/// let c = filter(&array, &filter_array)?; +/// let c = c.as_any().downcast_ref::().unwrap(); +/// assert_eq!(c, &PrimitiveArray::from_slice(vec![5, 8])); +/// # Ok(()) +/// # } +/// ``` +pub fn filter(array: &dyn Array, filter: &BooleanArray) -> Result> { + // The validities may be masking out `true` bits, making the filter operation + // based on the values incorrect + if let Some(validities) = filter.validity() { + let values = filter.values(); + let new_values = values & validities; + let filter = BooleanArray::new(DataType::Boolean, new_values, None); + return crate::compute::filter::filter(array, &filter); + } + + let false_count = filter.values().unset_bits(); + if false_count == filter.len() { + assert_eq!(array.len(), filter.len()); + return Ok(new_empty_array(array.data_type().clone())); + } + if false_count == 0 { + assert_eq!(array.len(), filter.len()); + return Ok(array.to_boxed()); + } + + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(filter_primitive::<$T>(array, filter))) + }), + _ => { + let iter = SlicesIterator::new(filter.values()); + let mut mutable = make_growable(&[array], false, iter.slots()); + iter.for_each(|(start, len)| mutable.extend(0, start, len)); + Ok(mutable.as_box()) + }, + } +} + +/// Returns a new [Chunk] with arrays containing only values matching the filter. +/// This is a convenience function: filter multiple columns is embarrassingly parallel. +pub fn filter_chunk>( + columns: &Chunk, + filter_values: &BooleanArray, +) -> Result>> { + let arrays = columns.arrays(); + + let num_columns = arrays.len(); + + let filtered_arrays = match num_columns { + 1 => { + vec![filter(columns.arrays()[0].as_ref(), filter_values)?] + }, + _ => { + let filter = build_filter(filter_values)?; + arrays.iter().map(|a| filter(a.as_ref())).collect() + }, + }; + Chunk::try_new(filtered_arrays) +} diff --git a/crates/nano-arrow/src/compute/if_then_else.rs b/crates/nano-arrow/src/compute/if_then_else.rs new file mode 100644 index 000000000000..86c46b29d040 --- /dev/null +++ b/crates/nano-arrow/src/compute/if_then_else.rs @@ -0,0 +1,75 @@ +//! Contains the operator [`if_then_else`]. +use crate::array::{growable, Array, BooleanArray}; +use crate::bitmap::utils::SlicesIterator; +use crate::error::{Error, Result}; + +/// Returns the values from `lhs` if the predicate is `true` or from the `rhs` if the predicate is false +/// Returns `None` if the predicate is `None`. 
+/// # Example +/// ```rust +/// # use arrow2::error::Result; +/// use arrow2::compute::if_then_else::if_then_else; +/// use arrow2::array::{Int32Array, BooleanArray}; +/// +/// # fn main() -> Result<()> { +/// let lhs = Int32Array::from_slice(&[1, 2, 3]); +/// let rhs = Int32Array::from_slice(&[4, 5, 6]); +/// let predicate = BooleanArray::from(&[Some(true), None, Some(false)]); +/// let result = if_then_else(&predicate, &lhs, &rhs)?; +/// +/// let expected = Int32Array::from(&[Some(1), None, Some(6)]); +/// +/// assert_eq!(expected, result.as_ref()); +/// # Ok(()) +/// # } +/// ``` +pub fn if_then_else( + predicate: &BooleanArray, + lhs: &dyn Array, + rhs: &dyn Array, +) -> Result> { + if lhs.data_type() != rhs.data_type() { + return Err(Error::InvalidArgumentError(format!( + "If then else requires the arguments to have the same datatypes ({:?} != {:?})", + lhs.data_type(), + rhs.data_type() + ))); + } + if (lhs.len() != rhs.len()) | (lhs.len() != predicate.len()) { + return Err(Error::InvalidArgumentError(format!( + "If then else requires all arguments to have the same length (predicate = {}, lhs = {}, rhs = {})", + predicate.len(), + lhs.len(), + rhs.len() + ))); + } + + let result = if predicate.null_count() > 0 { + let mut growable = growable::make_growable(&[lhs, rhs], true, lhs.len()); + for (i, v) in predicate.iter().enumerate() { + match v { + Some(v) => growable.extend(!v as usize, i, 1), + None => growable.extend_validity(1), + } + } + growable.as_box() + } else { + let mut growable = growable::make_growable(&[lhs, rhs], false, lhs.len()); + let mut start_falsy = 0; + let mut total_len = 0; + for (start, len) in SlicesIterator::new(predicate.values()) { + if start != start_falsy { + growable.extend(1, start_falsy, start - start_falsy); + total_len += start - start_falsy; + }; + growable.extend(0, start, len); + total_len += len; + start_falsy = start + len; + } + if total_len != lhs.len() { + growable.extend(1, total_len, lhs.len() - total_len); + } + growable.as_box() + }; + Ok(result) +} diff --git a/crates/nano-arrow/src/compute/mod.rs b/crates/nano-arrow/src/compute/mod.rs new file mode 100644 index 000000000000..a40e4dcbb558 --- /dev/null +++ b/crates/nano-arrow/src/compute/mod.rs @@ -0,0 +1,52 @@ +//! contains a wide range of compute operations (e.g. +//! [`arithmetics`], [`aggregate`], +//! [`filter`], [`comparison`], and [`sort`]) +//! +//! This module's general design is +//! that each operator has two interfaces, a statically-typed version and a dynamically-typed +//! version. +//! The statically-typed version expects concrete arrays (such as [`PrimitiveArray`](crate::array::PrimitiveArray)); +//! the dynamically-typed version expects `&dyn Array` and errors if the the type is not +//! supported. +//! Some dynamically-typed operators have an auxiliary function, `can_*`, that returns +//! true if the operator can be applied to the particular `DataType`. 
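// Sketch of the two interfaces described above for a single operator. The
// `nano_arrow` crate name and the public `comparison::primitive` path are
// assumptions read off the files added in this diff (they mirror arrow2's
// layout) and may differ in the actual crate.
fn comparison_interfaces_sketch() {
    use nano_arrow::array::Int32Array;
    use nano_arrow::compute::comparison;

    let a = Int32Array::from_slice([1, 2, 3]);
    // statically typed: concrete array and concrete scalar value
    let _mask = comparison::primitive::eq_scalar(&a, 2i32);
    // dynamically typed: check support first, then call through `&dyn Array`
    assert!(comparison::can_eq_scalar(a.data_type()));
}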
+ +#[cfg(any(feature = "compute_aggregate", feature = "io_parquet"))] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_aggregate")))] +pub mod aggregate; +#[cfg(feature = "compute_arithmetics")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_arithmetics")))] +pub mod arithmetics; +pub mod arity; +pub mod arity_assign; +#[cfg(feature = "compute_bitwise")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_bitwise")))] +pub mod bitwise; +#[cfg(feature = "compute_boolean")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean")))] +pub mod boolean; +#[cfg(feature = "compute_boolean_kleene")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean_kleene")))] +pub mod boolean_kleene; +#[cfg(feature = "compute_cast")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_cast")))] +pub mod cast; +#[cfg(feature = "compute_comparison")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_comparison")))] +pub mod comparison; +#[cfg(feature = "compute_concatenate")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_concatenate")))] +pub mod concatenate; +#[cfg(feature = "compute_filter")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_filter")))] +pub mod filter; +#[cfg(feature = "compute_if_then_else")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_if_then_else")))] +pub mod if_then_else; +#[cfg(feature = "compute_take")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_take")))] +pub mod take; +#[cfg(feature = "compute_temporal")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_temporal")))] +pub mod temporal; +mod utils; diff --git a/crates/nano-arrow/src/compute/take/binary.rs b/crates/nano-arrow/src/compute/take/binary.rs new file mode 100644 index 000000000000..0e6460206f0e --- /dev/null +++ b/crates/nano-arrow/src/compute/take/binary.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
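// Plain-Vec illustration of what `take` on an offset-encoded (binary/utf8)
// array has to do: gather each indexed byte range and rebuild the offsets.
// Hypothetical helper names; the real kernels below work on
// `OffsetsBuffer`/`Buffer` and additionally handle validity.
fn take_binary_sketch(
    offsets: &[usize],
    values: &[u8],
    indices: &[usize],
) -> (Vec<usize>, Vec<u8>) {
    let mut new_offsets = Vec::with_capacity(indices.len() + 1);
    let mut new_values = Vec::new();
    new_offsets.push(0);
    for &i in indices {
        // copy the i-th value's byte range and record the new end offset
        let (start, end) = (offsets[i], offsets[i + 1]);
        new_values.extend_from_slice(&values[start..end]);
        new_offsets.push(new_values.len());
    }
    (new_offsets, new_values)
}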
+ +use super::generic_binary::*; +use super::Index; +use crate::array::{Array, BinaryArray, PrimitiveArray}; +use crate::offset::Offset; + +/// `take` implementation for utf8 arrays +pub fn take( + values: &BinaryArray, + indices: &PrimitiveArray, +) -> BinaryArray { + let data_type = values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => { + take_no_validity::(values.offsets(), values.values(), indices.values()) + }, + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.offsets(), values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + BinaryArray::::new(data_type, offsets, values, validity) +} diff --git a/crates/nano-arrow/src/compute/take/boolean.rs b/crates/nano-arrow/src/compute/take/boolean.rs new file mode 100644 index 000000000000..62be88e46226 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/boolean.rs @@ -0,0 +1,138 @@ +use super::Index; +use crate::array::{Array, BooleanArray, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; + +// take implementation when neither values nor indices contain nulls +fn take_no_validity(values: &Bitmap, indices: &[I]) -> (Bitmap, Option) { + let values = indices.iter().map(|index| values.get_bit(index.to_usize())); + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, None) +} + +// take implementation when only values contain nulls +fn take_values_validity( + values: &BooleanArray, + indices: &[I], +) -> (Bitmap, Option) { + let validity_values = values.validity().unwrap(); + let validity = indices + .iter() + .map(|index| validity_values.get_bit(index.to_usize())); + let validity = Bitmap::from_trusted_len_iter(validity); + + let values_values = values.values(); + let values = indices + .iter() + .map(|index| values_values.get_bit(index.to_usize())); + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, validity.into()) +} + +// take implementation when only indices contain nulls +fn take_indices_validity( + values: &Bitmap, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let validity = indices.validity().unwrap(); + + let values = indices.values().iter().enumerate().map(|(i, index)| { + let index = index.to_usize(); + match values.get(index) { + Some(value) => value, + None => { + if !validity.get_bit(i) { + false + } else { + panic!("Out-of-bounds index {index}") + } + }, + } + }); + + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, indices.validity().cloned()) +} + +// take implementation when both values and indices contain nulls +fn take_values_indices_validity( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let mut validity = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + + let values_values = values.values(); + let values = indices.iter().map(|index| match index { + Some(index) => { + let index = index.to_usize(); + validity.push(values_validity.get_bit(index)); + values_values.get_bit(index) + }, + None => { + validity.push(false); + false + }, + }); + let values = Bitmap::from_trusted_len_iter(values); + (values, validity.into()) +} + +/// `take` implementation for boolean arrays +pub fn take(values: &BooleanArray, indices: &PrimitiveArray) -> BooleanArray { + let data_type = 
values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => take_no_validity(values.values(), indices.values()), + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + + BooleanArray::new(data_type, values, validity) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Int32Array; + + fn _all_cases() -> Vec<(Int32Array, BooleanArray, BooleanArray)> { + vec![ + ( + Int32Array::from(&[Some(1), Some(0)]), + BooleanArray::from(vec![Some(true), Some(false)]), + BooleanArray::from(vec![Some(false), Some(true)]), + ), + ( + Int32Array::from(&[Some(1), None]), + BooleanArray::from(vec![Some(true), Some(false)]), + BooleanArray::from(vec![Some(false), None]), + ), + ( + Int32Array::from(&[Some(1), Some(0)]), + BooleanArray::from(vec![None, Some(false)]), + BooleanArray::from(vec![Some(false), None]), + ), + ( + Int32Array::from(&[Some(1), None, Some(0)]), + BooleanArray::from(vec![None, Some(false)]), + BooleanArray::from(vec![Some(false), None, None]), + ), + ] + } + + #[test] + fn all_cases() { + let cases = _all_cases(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + } +} diff --git a/crates/nano-arrow/src/compute/take/dict.rs b/crates/nano-arrow/src/compute/take/dict.rs new file mode 100644 index 000000000000..bb60c09193f7 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/dict.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
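// Dictionary `take` gathers only the keys; the dictionary values are reused
// untouched. Hand-rolled sketch with plain Vecs and a hypothetical helper
// name (the kernel below reuses the primitive take on the key array and
// rewraps the result):
fn take_dict_keys_sketch(keys: &[u32], indices: &[usize]) -> Vec<u32> {
    indices.iter().map(|&i| keys[i]).collect()
}
// values ["a", "b", "c"] with keys [2, 0, 1] taken at indices [1, 1, 2]
// yields keys [0, 0, 1], i.e. the logical values ["a", "a", "b"].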
+
+use super::primitive::take as take_primitive;
+use super::Index;
+use crate::array::{DictionaryArray, DictionaryKey, PrimitiveArray};
+
+/// `take` implementation for dictionary arrays
+///
+/// applies `take` to the keys of the dictionary array and returns a new dictionary array
+/// with the same dictionary values and reordered keys
+pub fn take<K, I>(values: &DictionaryArray<K>, indices: &PrimitiveArray<I>) -> DictionaryArray<K>
+where
+    K: DictionaryKey,
+    I: Index,
+{
+    let keys = take_primitive::<K, I>(values.keys(), indices);
+    // safety - this operation takes a subset of keys and thus preserves the dictionary's invariant
+    unsafe {
+        DictionaryArray::<K>::try_new_unchecked(
+            values.data_type().clone(),
+            keys,
+            values.values().clone(),
+        )
+        .unwrap()
+    }
+}
diff --git a/crates/nano-arrow/src/compute/take/fixed_size_list.rs b/crates/nano-arrow/src/compute/take/fixed_size_list.rs
new file mode 100644
index 000000000000..6e7e74b91720
--- /dev/null
+++ b/crates/nano-arrow/src/compute/take/fixed_size_list.rs
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
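
Editor's sketch: the dictionary kernel above only re-gathers the keys and reuses the dictionary values untouched. The snippet below shows what that means for callers; it assumes the upstream `arrow2` import paths and the `DictionaryArray::try_from_keys` constructor, which are not part of this diff.

```rust
use arrow2::array::{DictionaryArray, Int32Array, Utf8Array};
use arrow2::compute::take::take;

fn main() -> arrow2::error::Result<()> {
    // A dictionary with three distinct values and four keys.
    let dict = DictionaryArray::try_from_keys(
        Int32Array::from_slice([0, 1, 2, 1]),
        Utf8Array::<i32>::from_slice(["a", "b", "c"]).boxed(),
    )?;

    // Taking indices [3, 0] only re-gathers the keys; the values are shared as-is.
    let taken = take(&dict, &Int32Array::from_slice([3, 0]))?;
    let taken = taken.as_any().downcast_ref::<DictionaryArray<i32>>().unwrap();
    assert_eq!(taken.keys(), &Int32Array::from_slice([1, 0]));
    Ok(())
}
```
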
+ +use super::Index; +use crate::array::growable::{Growable, GrowableFixedSizeList}; +use crate::array::{FixedSizeListArray, PrimitiveArray}; + +/// `take` implementation for FixedSizeListArrays +pub fn take( + values: &FixedSizeListArray, + indices: &PrimitiveArray, +) -> FixedSizeListArray { + let mut capacity = 0; + let arrays = indices + .values() + .iter() + .map(|index| { + let index = index.to_usize(); + let slice = values.clone().sliced(index, 1); + capacity += slice.len(); + slice + }) + .collect::>(); + + let arrays = arrays.iter().collect(); + + if let Some(validity) = indices.validity() { + let mut growable: GrowableFixedSizeList = + GrowableFixedSizeList::new(arrays, true, capacity); + + for index in 0..indices.len() { + if validity.get_bit(index) { + growable.extend(index, 0, 1); + } else { + growable.extend_validity(1) + } + } + + growable.into() + } else { + let mut growable: GrowableFixedSizeList = + GrowableFixedSizeList::new(arrays, false, capacity); + for index in 0..indices.len() { + growable.extend(index, 0, 1); + } + + growable.into() + } +} diff --git a/crates/nano-arrow/src/compute/take/generic_binary.rs b/crates/nano-arrow/src/compute/take/generic_binary.rs new file mode 100644 index 000000000000..9f6658c7d5a0 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/generic_binary.rs @@ -0,0 +1,155 @@ +use super::Index; +use crate::array::{GenericBinaryArray, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::buffer::Buffer; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +pub fn take_values( + length: O, + starts: &[O], + offsets: &OffsetsBuffer, + values: &[u8], +) -> Buffer { + let new_len = length.to_usize(); + let mut buffer = Vec::with_capacity(new_len); + starts + .iter() + .map(|start| start.to_usize()) + .zip(offsets.lengths()) + .for_each(|(start, length)| { + let end = start + length; + buffer.extend_from_slice(&values[start..end]); + }); + buffer.into() +} + +// take implementation when neither values nor indices contain nulls +pub fn take_no_validity( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let mut buffer = Vec::::new(); + let lengths = indices.iter().map(|index| index.to_usize()).map(|index| { + let (start, end) = offsets.start_end(index); + // todo: remove this bound check + buffer.extend_from_slice(&values[start..end]); + end - start + }); + let offsets = Offsets::try_from_lengths(lengths).expect(""); + + (offsets.into(), buffer.into(), None) +} + +// take implementation when only values contain nulls +pub fn take_values_validity>( + values: &A, + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let validity_values = values.validity().unwrap(); + let validity = indices + .iter() + .map(|index| validity_values.get_bit(index.to_usize())); + let validity = Bitmap::from_trusted_len_iter(validity); + + let mut length = O::default(); + + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.iter().map(|index| { + let index = index.to_usize(); + let start = offsets[index]; + length += offsets[index + 1] - start; + starts.push(start); + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, starts.as_slice(), &offsets, values_values); + + (offsets, buffer, 
validity.into()) +} + +// take implementation when only indices contain nulls +pub fn take_indices_validity( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut length = O::default(); + + let offsets = offsets.buffer(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.values().iter().map(|index| { + let index = index.to_usize(); + match offsets.get(index + 1) { + Some(&next) => { + let start = offsets[index]; + length += next - start; + starts.push(start); + }, + None => starts.push(O::default()), + }; + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, &starts, &offsets, values); + + (offsets, buffer, indices.validity().cloned()) +} + +// take implementation when both indices and values contain nulls +pub fn take_values_indices_validity>( + values: &A, + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut length = O::default(); + let mut validity = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.iter().map(|index| { + match index { + Some(index) => { + let index = index.to_usize(); + if values_validity.get_bit(index) { + validity.push(true); + length += offsets[index + 1] - offsets[index]; + starts.push(offsets[index]); + } else { + validity.push(false); + starts.push(O::default()); + } + }, + None => { + validity.push(false); + starts.push(O::default()); + }, + }; + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, &starts, &offsets, values_values); + + (offsets, buffer, validity.into()) +} diff --git a/crates/nano-arrow/src/compute/take/list.rs b/crates/nano-arrow/src/compute/take/list.rs new file mode 100644 index 000000000000..58fb9d6fd788 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/list.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
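
Editor's sketch: the helpers in `generic_binary.rs` above rebuild a compact offsets/values pair for the gathered slots; `utf8.rs` and `binary.rs` below are thin wrappers around them. The snippet shows the observable effect on a large-UTF-8 array, again assuming the upstream `arrow2` import paths.

```rust
use arrow2::array::{Int64Array, Utf8Array};
use arrow2::compute::take::take;

fn main() -> arrow2::error::Result<()> {
    // offsets [0, 3, 6, 11] over the values buffer "onetwothree"
    let values = Utf8Array::<i64>::from_slice(["one", "two", "three"]);
    let indices = Int64Array::from_slice([2, 0]);

    // The kernel copies the referenced byte ranges and rebuilds offsets [0, 5, 8]
    // over a new values buffer "threeone".
    let taken = take(&values, &indices)?;
    let taken = taken.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
    assert_eq!(taken, &Utf8Array::<i64>::from_slice(["three", "one"]));
    Ok(())
}
```
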
+ +use super::Index; +use crate::array::growable::{Growable, GrowableList}; +use crate::array::{ListArray, PrimitiveArray}; +use crate::offset::Offset; + +/// `take` implementation for ListArrays +pub fn take( + values: &ListArray, + indices: &PrimitiveArray, +) -> ListArray { + let mut capacity = 0; + let arrays = indices + .values() + .iter() + .map(|index| { + let index = index.to_usize(); + let slice = values.clone().sliced(index, 1); + capacity += slice.len(); + slice + }) + .collect::>>(); + + let arrays = arrays.iter().collect(); + + if let Some(validity) = indices.validity() { + let mut growable: GrowableList = GrowableList::new(arrays, true, capacity); + + for index in 0..indices.len() { + if validity.get_bit(index) { + growable.extend(index, 0, 1); + } else { + growable.extend_validity(1) + } + } + + growable.into() + } else { + let mut growable: GrowableList = GrowableList::new(arrays, false, capacity); + for index in 0..indices.len() { + growable.extend(index, 0, 1); + } + + growable.into() + } +} diff --git a/crates/nano-arrow/src/compute/take/mod.rs b/crates/nano-arrow/src/compute/take/mod.rs new file mode 100644 index 000000000000..d526713a4327 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/mod.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines take kernel for [`Array`] + +use crate::array::{new_empty_array, Array, NullArray, PrimitiveArray}; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::Index; + +mod binary; +mod boolean; +mod dict; +mod fixed_size_list; +mod generic_binary; +mod list; +mod primitive; +mod structure; +mod utf8; + +pub(crate) use boolean::take as take_boolean; + +/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. +/// The returned array has a length equal to `indices.len()`. 
+pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result> { + if indices.len() == 0 { + return Ok(new_empty_array(values.data_type().clone())); + } + + use crate::datatypes::PhysicalType::*; + match values.data_type().to_physical_type() { + Null => Ok(Box::new(NullArray::new( + values.data_type().clone(), + indices.len(), + ))), + Boolean => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean::take::(values, indices))) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive::take::<$T, _>(&values, indices))) + }), + LargeUtf8 => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8::take::(values, indices))) + }, + LargeBinary => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(binary::take::(values, indices))) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(dict::take::<$T, _>(&values, indices))) + }) + }, + Struct => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(structure::take::<_>(array, indices)?)) + }, + LargeList => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(list::take::(array, indices))) + }, + FixedSizeList => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(fixed_size_list::take::(array, indices))) + }, + t => unimplemented!("Take not supported for data type {:?}", t), + } +} + +/// Checks if an array of type `datatype` can perform take operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::take::can_take; +/// use arrow2::datatypes::{DataType}; +/// +/// let data_type = DataType::Int8; +/// assert_eq!(can_take(&data_type), true); +/// ``` +pub fn can_take(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(_) + | DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Dictionary(..) 
+ ) +} diff --git a/crates/nano-arrow/src/compute/take/primitive.rs b/crates/nano-arrow/src/compute/take/primitive.rs new file mode 100644 index 000000000000..5ce53ba7cc20 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/primitive.rs @@ -0,0 +1,112 @@ +use super::Index; +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::buffer::Buffer; +use crate::types::NativeType; + +// take implementation when neither values nor indices contain nulls +fn take_no_validity( + values: &[T], + indices: &[I], +) -> (Buffer, Option) { + let values = indices + .iter() + .map(|index| values[index.to_usize()]) + .collect::>(); + + (values.into(), None) +} + +// take implementation when only values contain nulls +fn take_values_validity( + values: &PrimitiveArray, + indices: &[I], +) -> (Buffer, Option) { + let values_validity = values.validity().unwrap(); + + let validity = indices + .iter() + .map(|index| values_validity.get_bit(index.to_usize())); + let validity = MutableBitmap::from_trusted_len_iter(validity); + + let values_values = values.values(); + + let values = indices + .iter() + .map(|index| values_values[index.to_usize()]) + .collect::>(); + + (values.into(), validity.into()) +} + +// take implementation when only indices contain nulls +fn take_indices_validity( + values: &[T], + indices: &PrimitiveArray, +) -> (Buffer, Option) { + let validity = indices.validity().unwrap(); + let values = indices + .values() + .iter() + .enumerate() + .map(|(i, index)| { + let index = index.to_usize(); + match values.get(index) { + Some(value) => *value, + None => { + if !validity.get_bit(i) { + T::default() + } else { + panic!("Out-of-bounds index {index}") + } + }, + } + }) + .collect::>(); + + (values.into(), indices.validity().cloned()) +} + +// take implementation when both values and indices contain nulls +fn take_values_indices_validity( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> (Buffer, Option) { + let mut bitmap = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + + let values_values = values.values(); + let values = indices + .iter() + .map(|index| match index { + Some(index) => { + let index = index.to_usize(); + bitmap.push(values_validity.get_bit(index)); + values_values[index] + }, + None => { + bitmap.push(false); + T::default() + }, + }) + .collect::>(); + (values.into(), bitmap.into()) +} + +/// `take` implementation for primitive arrays +pub fn take( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> PrimitiveArray { + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + let (buffer, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => take_no_validity::(values.values(), indices.values()), + (true, false) => take_values_validity::(values, indices.values()), + (false, true) => take_indices_validity::(values.values(), indices), + (true, true) => take_values_indices_validity::(values, indices), + }; + + PrimitiveArray::::new(values.data_type().clone(), buffer, validity) +} diff --git a/crates/nano-arrow/src/compute/take/structure.rs b/crates/nano-arrow/src/compute/take/structure.rs new file mode 100644 index 000000000000..e0a2717f5746 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/structure.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Index; +use crate::array::{Array, PrimitiveArray, StructArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::error::Result; + +#[inline] +fn take_validity( + validity: Option<&Bitmap>, + indices: &PrimitiveArray, +) -> Result> { + let indices_validity = indices.validity(); + match (validity, indices_validity) { + (None, _) => Ok(indices_validity.cloned()), + (Some(validity), None) => { + let iter = indices.values().iter().map(|index| { + let index = index.to_usize(); + validity.get_bit(index) + }); + Ok(MutableBitmap::from_trusted_len_iter(iter).into()) + }, + (Some(validity), _) => { + let iter = indices.iter().map(|x| match x { + Some(index) => { + let index = index.to_usize(); + validity.get_bit(index) + }, + None => false, + }); + Ok(MutableBitmap::from_trusted_len_iter(iter).into()) + }, + } +} + +pub fn take(array: &StructArray, indices: &PrimitiveArray) -> Result { + let values: Vec> = array + .values() + .iter() + .map(|a| super::take(a.as_ref(), indices)) + .collect::>()?; + let validity = take_validity(array.validity(), indices)?; + Ok(StructArray::new( + array.data_type().clone(), + values, + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/take/utf8.rs b/crates/nano-arrow/src/compute/take/utf8.rs new file mode 100644 index 000000000000..3f5f5877c12f --- /dev/null +++ b/crates/nano-arrow/src/compute/take/utf8.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
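
Editor's sketch: for struct arrays the kernel above applies `take` to every child with the same indices and derives the result validity from the struct's validity and the indices' validity. A small sketch under the same assumption of upstream `arrow2` paths; the field names are made up for illustration.

```rust
use arrow2::array::{Array, Int32Array, StructArray, Utf8Array};
use arrow2::compute::take::take;
use arrow2::datatypes::{DataType, Field};

fn main() -> arrow2::error::Result<()> {
    let fields = vec![
        Field::new("id", DataType::Int32, true),
        Field::new("name", DataType::LargeUtf8, true),
    ];
    let array = StructArray::new(
        DataType::Struct(fields),
        vec![
            Int32Array::from_slice([1, 2, 3]).boxed(),
            Utf8Array::<i64>::from_slice(["a", "b", "c"]).boxed(),
        ],
        None,
    );

    // Every child is gathered with the same indices; a null index yields a null row.
    let taken = take(&array, &Int32Array::from(&[Some(2), None]))?;
    let taken = taken.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(taken.len(), 2);
    assert_eq!(taken.null_count(), 1);
    Ok(())
}
```
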
+ +use super::generic_binary::*; +use super::Index; +use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::offset::Offset; + +/// `take` implementation for utf8 arrays +pub fn take( + values: &Utf8Array, + indices: &PrimitiveArray, +) -> Utf8Array { + let data_type = values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => { + take_no_validity::(values.offsets(), values.values(), indices.values()) + }, + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.offsets(), values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Int32Array; + + fn _all_cases() -> Vec<(Int32Array, Utf8Array, Utf8Array)> { + vec![ + ( + Int32Array::from(&[Some(1), Some(0)]), + Utf8Array::::from(vec![Some("one"), Some("two")]), + Utf8Array::::from(vec![Some("two"), Some("one")]), + ), + ( + Int32Array::from(&[Some(1), None]), + Utf8Array::::from(vec![Some("one"), Some("two")]), + Utf8Array::::from(vec![Some("two"), None]), + ), + ( + Int32Array::from(&[Some(1), Some(0)]), + Utf8Array::::from(vec![None, Some("two")]), + Utf8Array::::from(vec![Some("two"), None]), + ), + ( + Int32Array::from(&[Some(1), None, Some(0)]), + Utf8Array::::from(vec![None, Some("two")]), + Utf8Array::::from(vec![Some("two"), None, None]), + ), + ] + } + + #[test] + fn all_cases() { + let cases = _all_cases::(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + let cases = _all_cases::(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + } +} diff --git a/crates/nano-arrow/src/compute/temporal.rs b/crates/nano-arrow/src/compute/temporal.rs new file mode 100644 index 000000000000..132492f58b6e --- /dev/null +++ b/crates/nano-arrow/src/compute/temporal.rs @@ -0,0 +1,410 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines temporal kernels for time and date related functions. 
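
Editor's sketch: the temporal kernels defined below extract calendar and clock components from date, time, and timestamp arrays. A minimal example for `year` on a `Date32` array; the `arrow2` import paths are assumed, and the day offsets are ordinary days-since-epoch values.

```rust
use arrow2::array::Int32Array;
use arrow2::compute::temporal::{can_year, year};
use arrow2::datatypes::DataType;

fn main() -> arrow2::error::Result<()> {
    // 18628 = 2021-01-01 and 18992 = 2021-12-31, as days since the UNIX epoch.
    let dates = Int32Array::from(&[Some(18628), None, Some(18992)]).to(DataType::Date32);

    assert!(can_year(&DataType::Date32));
    let years = year(&dates)?;
    // Nulls are preserved; valid slots carry the extracted year.
    assert_eq!(years, Int32Array::from(&[Some(2021), None, Some(2021)]));
    Ok(())
}
```
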
+ +use chrono::{Datelike, Timelike}; + +use super::arity::unary; +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::temporal_conversions::*; +use crate::types::NativeType; + +// Create and implement a trait that converts chrono's `Weekday` +// type into `u32` +trait U32Weekday: Datelike { + fn u32_weekday(&self) -> u32 { + self.weekday().number_from_monday() + } +} + +impl U32Weekday for chrono::NaiveDateTime {} +impl U32Weekday for chrono::DateTime {} + +// Create and implement a trait that converts chrono's `IsoWeek` +// type into `u32` +trait U32IsoWeek: Datelike { + fn u32_iso_week(&self) -> u32 { + self.iso_week().week() + } +} + +impl U32IsoWeek for chrono::NaiveDateTime {} +impl U32IsoWeek for chrono::DateTime {} + +// Macro to avoid repetition in functions, that apply +// `chrono::Datelike` methods on Arrays +macro_rules! date_like { + ($extract:ident, $array:ident, $data_type:path) => { + match $array.data_type().to_logical_type() { + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + date_variants($array, $data_type, |x| x.$extract()) + }, + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str) { + Ok(extract_impl(array, *time_unit, timezone, |x| x.$extract())) + } else { + chrono_tz(array, *time_unit, timezone_str, |x| x.$extract()) + } + }, + dt => Err(Error::NotYetImplemented(format!( + "\"{}\" does not support type {:?}", + stringify!($extract), + dt + ))), + } + }; +} + +/// Extracts the years of a temporal array as [`PrimitiveArray`]. +/// Use [`can_year`] to check if this operation is supported for the target [`DataType`]. +pub fn year(array: &dyn Array) -> Result> { + date_like!(year, array, DataType::Int32) +} + +/// Extracts the months of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 1 to 12. +/// Use [`can_month`] to check if this operation is supported for the target [`DataType`]. +pub fn month(array: &dyn Array) -> Result> { + date_like!(month, array, DataType::UInt32) +} + +/// Extracts the days of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 1 to 32 (Last day depends on month). +/// Use [`can_day`] to check if this operation is supported for the target [`DataType`]. +pub fn day(array: &dyn Array) -> Result> { + date_like!(day, array, DataType::UInt32) +} + +/// Extracts weekday of a temporal array as [`PrimitiveArray`]. +/// Monday is 1, Tuesday is 2, ..., Sunday is 7. +/// Use [`can_weekday`] to check if this operation is supported for the target [`DataType`] +pub fn weekday(array: &dyn Array) -> Result> { + date_like!(u32_weekday, array, DataType::UInt32) +} + +/// Extracts ISO week of a temporal array as [`PrimitiveArray`] +/// Value ranges from 1 to 53 (Last week depends on the year). +/// Use [`can_iso_week`] to check if this operation is supported for the target [`DataType`] +pub fn iso_week(array: &dyn Array) -> Result> { + date_like!(u32_iso_week, array, DataType::UInt32) +} + +// Macro to avoid repetition in functions, that apply +// `chrono::Timelike` methods on Arrays +macro_rules! 
time_like { + ($extract:ident, $array:ident, $data_type:path) => { + match $array.data_type().to_logical_type() { + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + date_variants($array, $data_type, |x| x.$extract()) + }, + DataType::Time32(_) | DataType::Time64(_) => { + time_variants($array, DataType::UInt32, |x| x.$extract()) + }, + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str) { + Ok(extract_impl(array, *time_unit, timezone, |x| x.$extract())) + } else { + chrono_tz(array, *time_unit, timezone_str, |x| x.$extract()) + } + }, + dt => Err(Error::NotYetImplemented(format!( + "\"{}\" does not support type {:?}", + stringify!($extract), + dt + ))), + } + }; +} + +/// Extracts the hours of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 23. +/// Use [`can_hour`] to check if this operation is supported for the target [`DataType`]. +pub fn hour(array: &dyn Array) -> Result> { + time_like!(hour, array, DataType::UInt32) +} + +/// Extracts the minutes of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_minute`] to check if this operation is supported for the target [`DataType`]. +pub fn minute(array: &dyn Array) -> Result> { + time_like!(minute, array, DataType::UInt32) +} + +/// Extracts the seconds of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_second`] to check if this operation is supported for the target [`DataType`]. +pub fn second(array: &dyn Array) -> Result> { + time_like!(second, array, DataType::UInt32) +} + +/// Extracts the nanoseconds of a temporal array as [`PrimitiveArray`]. +/// Use [`can_nanosecond`] to check if this operation is supported for the target [`DataType`]. 
+pub fn nanosecond(array: &dyn Array) -> Result> { + time_like!(nanosecond, array, DataType::UInt32) +} + +fn date_variants(array: &dyn Array, data_type: DataType, op: F) -> Result> +where + O: NativeType, + F: Fn(chrono::NaiveDateTime) -> O, +{ + match array.data_type().to_logical_type() { + DataType::Date32 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date32_to_datetime(x)), data_type)) + }, + DataType::Date64 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date64_to_datetime(x)), data_type)) + }, + DataType::Timestamp(time_unit, None) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let func = match time_unit { + TimeUnit::Second => timestamp_s_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + }; + Ok(unary(array, |x| op(func(x)), data_type)) + }, + _ => unreachable!(), + } +} + +fn time_variants(array: &dyn Array, data_type: DataType, op: F) -> Result> +where + O: NativeType, + F: Fn(chrono::NaiveTime) -> O, +{ + match array.data_type().to_logical_type() { + DataType::Time32(TimeUnit::Second) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32s_to_time(x)), data_type)) + }, + DataType::Time32(TimeUnit::Millisecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32ms_to_time(x)), data_type)) + }, + DataType::Time64(TimeUnit::Microsecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64us_to_time(x)), data_type)) + }, + DataType::Time64(TimeUnit::Nanosecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64ns_to_time(x)), data_type)) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "chrono-tz")] +fn chrono_tz( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, + op: F, +) -> Result> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + let timezone = parse_offset_tz(timezone_str)?; + Ok(extract_impl(array, time_unit, timezone, op)) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz( + _: &PrimitiveArray, + _: TimeUnit, + timezone_str: &str, + _: F, +) -> Result> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))) +} + +fn extract_impl( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone: T, + extract: F, +) -> PrimitiveArray +where + T: chrono::TimeZone, + A: NativeType, + F: Fn(chrono::DateTime) -> A, +{ + match time_unit { + TimeUnit::Second => { + let op = |x| { + let datetime = timestamp_s_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Millisecond => { + let op = |x| { + let datetime = timestamp_ms_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Microsecond => { + let op = |x| { + let datetime = timestamp_us_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + 
extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Nanosecond => { + let op = |x| { + let datetime = timestamp_ns_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + } +} + +/// Checks if an array of type `datatype` can perform year operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::temporal::can_year; +/// use arrow2::datatypes::{DataType}; +/// +/// assert_eq!(can_year(&DataType::Date32), true); +/// assert_eq!(can_year(&DataType::Int8), false); +/// ``` +pub fn can_year(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `datatype` can perform month operation +pub fn can_month(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `datatype` can perform day operation +pub fn can_day(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `data_type` can perform weekday operation +pub fn can_weekday(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `data_type` can perform ISO week operation +pub fn can_iso_week(data_type: &DataType) -> bool { + can_date(data_type) +} + +fn can_date(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) + ) +} + +/// Checks if an array of type `datatype` can perform hour operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::temporal::can_hour; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// assert_eq!(can_hour(&DataType::Time32(TimeUnit::Second)), true); +/// assert_eq!(can_hour(&DataType::Int8), false); +/// ``` +pub fn can_hour(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform minute operation +pub fn can_minute(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform second operation +pub fn can_second(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform nanosecond operation +pub fn can_nanosecond(data_type: &DataType) -> bool { + can_time(data_type) +} + +fn can_time(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + ) +} diff --git a/crates/nano-arrow/src/compute/utils.rs b/crates/nano-arrow/src/compute/utils.rs new file mode 100644 index 000000000000..e06acdcd470c --- /dev/null +++ b/crates/nano-arrow/src/compute/utils.rs @@ -0,0 +1,23 @@ +use crate::array::Array; +use crate::bitmap::Bitmap; +use crate::error::{Error, Result}; + +pub fn combine_validities(lhs: Option<&Bitmap>, rhs: Option<&Bitmap>) -> Option { + match (lhs, rhs) { + (Some(lhs), None) => Some(lhs.clone()), + (None, Some(rhs)) => Some(rhs.clone()), + (None, None) => None, + (Some(lhs), Some(rhs)) => Some(lhs & rhs), + } +} + +// Errors iff the two arrays have a different length. 
+#[inline] +pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> Result<()> { + if lhs.len() != rhs.len() { + return Err(Error::InvalidArgumentError( + "Arrays must have the same length".to_string(), + )); + } + Ok(()) +} diff --git a/crates/nano-arrow/src/datatypes/field.rs b/crates/nano-arrow/src/datatypes/field.rs new file mode 100644 index 000000000000..a32396780cdf --- /dev/null +++ b/crates/nano-arrow/src/datatypes/field.rs @@ -0,0 +1,96 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +use super::{DataType, Metadata}; + +/// Represents Arrow's metadata of a "column". +/// +/// A [`Field`] is the closest representation of the traditional "column": a logical type +/// ([`DataType`]) with a name and nullability. +/// A Field has optional [`Metadata`] that can be used to annotate the field with custom metadata. +/// +/// Almost all IO in this crate uses [`Field`] to represent logical information about the data +/// to be serialized. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub struct Field { + /// Its name + pub name: String, + /// Its logical [`DataType`] + pub data_type: DataType, + /// Its nullability + pub is_nullable: bool, + /// Additional custom (opaque) metadata. + pub metadata: Metadata, +} + +impl Field { + /// Creates a new [`Field`]. + pub fn new>(name: T, data_type: DataType, is_nullable: bool) -> Self { + Field { + name: name.into(), + data_type, + is_nullable, + metadata: Default::default(), + } + } + + /// Creates a new [`Field`] with metadata. + #[inline] + pub fn with_metadata(self, metadata: Metadata) -> Self { + Self { + name: self.name, + data_type: self.data_type, + is_nullable: self.is_nullable, + metadata, + } + } + + /// Returns the [`Field`]'s [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_schema::Field { + fn from(value: Field) -> Self { + Self::new(value.name, value.data_type.into(), value.is_nullable) + .with_metadata(value.metadata.into_iter().collect()) + } +} + +#[cfg(feature = "arrow_rs")] +impl From for Field { + fn from(value: arrow_schema::Field) -> Self { + (&value).into() + } +} + +#[cfg(feature = "arrow_rs")] +impl From<&arrow_schema::Field> for Field { + fn from(value: &arrow_schema::Field) -> Self { + let data_type = value.data_type().clone().into(); + let metadata = value + .metadata() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + Self::new(value.name(), data_type, value.is_nullable()).with_metadata(metadata) + } +} + +#[cfg(feature = "arrow_rs")] +impl From for Field { + fn from(value: arrow_schema::FieldRef) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow_rs")] +impl From<&arrow_schema::FieldRef> for Field { + fn from(value: &arrow_schema::FieldRef) -> Self { + value.as_ref().into() + } +} diff --git a/crates/nano-arrow/src/datatypes/mod.rs b/crates/nano-arrow/src/datatypes/mod.rs new file mode 100644 index 000000000000..95ba5e69bff8 --- /dev/null +++ b/crates/nano-arrow/src/datatypes/mod.rs @@ -0,0 +1,513 @@ +#![forbid(unsafe_code)] +//! Contains all metadata, such as [`PhysicalType`], [`DataType`], [`Field`] and [`Schema`]. 
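
Editor's sketch: the `datatypes` module below distinguishes a logical `DataType` from its in-memory `PhysicalType`, and lets `Extension` types desugar to the logical type they wrap. The assertions below follow the variants and match arms defined later in this file; the `arrow2` import path and the extension name are assumptions for illustration.

```rust
use arrow2::datatypes::{DataType, IntegerType, PhysicalType, PrimitiveType, TimeUnit};

fn main() {
    // Several logical types share one in-memory representation.
    assert_eq!(
        DataType::Date32.to_physical_type(),
        PhysicalType::Primitive(PrimitiveType::Int32)
    );
    assert_eq!(
        DataType::Timestamp(TimeUnit::Millisecond, None).to_physical_type(),
        PhysicalType::Primitive(PrimitiveType::Int64)
    );

    // Extension types carry a name and metadata but desugar to their inner type.
    let ext = DataType::Extension("geo.point".to_string(), Box::new(DataType::Float64), None);
    assert_eq!(ext.to_logical_type(), &DataType::Float64);
    assert_eq!(ext.to_physical_type(), PhysicalType::Primitive(PrimitiveType::Float64));

    // Dictionary arrays are keyed by an IntegerType.
    let dict = DataType::Dictionary(IntegerType::UInt32, Box::new(DataType::Utf8), false);
    assert_eq!(dict.to_physical_type(), PhysicalType::Dictionary(IntegerType::UInt32));
}
```
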
+ +mod field; +mod physical_type; +mod schema; + +use std::collections::BTreeMap; +use std::sync::Arc; + +pub use field::Field; +pub use physical_type::*; +pub use schema::Schema; +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +/// typedef for [BTreeMap] denoting [`Field`]'s and [`Schema`]'s metadata. +pub type Metadata = BTreeMap; +/// typedef for [Option<(String, Option)>] descr +pub(crate) type Extension = Option<(String, Option)>; + +/// The set of supported logical types in this crate. +/// +/// Each variant uniquely identifies a logical type, which define specific semantics to the data +/// (e.g. how it should be represented). +/// Each variant has a corresponding [`PhysicalType`], obtained via [`DataType::to_physical_type`], +/// which declares the in-memory representation of data. +/// The [`DataType::Extension`] is special in that it augments a [`DataType`] with metadata to support custom types. +/// Use `to_logical_type` to desugar such type and return its corresponding logical type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum DataType { + /// Null type + Null, + /// `true` and `false`. + Boolean, + /// An [`i8`] + Int8, + /// An [`i16`] + Int16, + /// An [`i32`] + Int32, + /// An [`i64`] + Int64, + /// An [`u8`] + UInt8, + /// An [`u16`] + UInt16, + /// An [`u32`] + UInt32, + /// An [`u64`] + UInt64, + /// An 16-bit float + Float16, + /// A [`f32`] + Float32, + /// A [`f64`] + Float64, + /// A [`i64`] representing a timestamp measured in [`TimeUnit`] with an optional timezone. + /// + /// Time is measured as a Unix epoch, counting the seconds from + /// 00:00:00.000 on 1 January 1970, excluding leap seconds, + /// as a 64-bit signed integer. + /// + /// The time zone is a string indicating the name of a time zone, one of: + /// + /// * As used in the Olson time zone database (the "tz database" or + /// "tzdata"), such as "America/New_York" + /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + /// When the timezone is not specified, the timestamp is considered to have no timezone + /// and is represented _as is_ + Timestamp(TimeUnit, Option), + /// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in days. + Date32, + /// An [`i64`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in milliseconds. Values are evenly divisible by 86400000. + Date64, + /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Second`] and [`TimeUnit::Millisecond`] are supported on this variant. + Time32(TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Microsecond`] and [`TimeUnit::Nanosecond`] are supported on this variant. + Time64(TimeUnit), + /// Measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.) + Duration(TimeUnit), + /// A "calendar" interval modeling elapsed time that takes into account calendar shifts. + /// For example an interval of 1 day may represent more than 24 hours. + Interval(IntervalUnit), + /// Opaque binary data of variable length whose offsets are represented as [`i32`]. + Binary, + /// Opaque binary data of fixed size. + /// Enum parameter specifies the number of bytes per value. + FixedSizeBinary(usize), + /// Opaque binary data of variable length whose offsets are represented as [`i64`]. 
+ LargeBinary, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i32`]. + Utf8, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`]. + LargeUtf8, + /// A list of some logical data type whose offsets are represented as [`i32`]. + List(Box), + /// A list of some logical data type with a fixed number of elements. + FixedSizeList(Box, usize), + /// A list of some logical data type whose offsets are represented as [`i64`]. + LargeList(Box), + /// A nested [`DataType`] with a given number of [`Field`]s. + Struct(Vec), + /// A nested datatype that can represent slots of differing types. + /// Third argument represents mode + Union(Vec, Option>, UnionMode), + /// A nested type that is represented as + /// + /// List> + /// + /// In this layout, the keys and values are each respectively contiguous. We do + /// not constrain the key and value types, so the application is responsible + /// for ensuring that the keys are hashable and unique. Whether the keys are sorted + /// may be set in the metadata for this field. + /// + /// In a field with Map type, the field has a child Struct field, which then + /// has two children: key type and the second the value type. The names of the + /// child fields may be respectively "entries", "key", and "value", but this is + /// not enforced. + /// + /// Map + /// ```text + /// - child[0] entries: Struct + /// - child[0] key: K + /// - child[1] value: V + /// ``` + /// Neither the "entries" field nor the "key" field may be nullable. + /// + /// The metadata is structured so that Arrow systems without special handling + /// for Map can make Map an alias for List. The "layout" attribute for the Map + /// field must have the same contents as a List. + Map(Box, bool), + /// A dictionary encoded array (`key_type`, `value_type`), where + /// each array element is an index of `key_type` into an + /// associated dictionary of `value_type`. + /// + /// Dictionary arrays are used to store columns of `value_type` + /// that contain many repeated values using less memory, but with + /// a higher CPU overhead for some operations. + /// + /// This type mostly used to represent low cardinality string + /// arrays or a limited set of primitive types as integers. + /// + /// The `bool` value indicates the `Dictionary` is sorted if set to `true`. + Dictionary(IntegerType, Box, bool), + /// Decimal value with precision and scale + /// precision is the number of digits in the number and + /// scale is the number of decimal places. + /// The number 999.99 has a precision of 5 and scale of 2. + Decimal(usize, usize), + /// Decimal backed by 256 bits + Decimal256(usize, usize), + /// Extension type. 
+ Extension(String, Box, Option), +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_schema::DataType { + fn from(value: DataType) -> Self { + use arrow_schema::{Field as ArrowField, UnionFields}; + + match value { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 => Self::Int8, + DataType::Int16 => Self::Int16, + DataType::Int32 => Self::Int32, + DataType::Int64 => Self::Int64, + DataType::UInt8 => Self::UInt8, + DataType::UInt16 => Self::UInt16, + DataType::UInt32 => Self::UInt32, + DataType::UInt64 => Self::UInt64, + DataType::Float16 => Self::Float16, + DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float64, + DataType::Timestamp(unit, tz) => Self::Timestamp(unit.into(), tz.map(Into::into)), + DataType::Date32 => Self::Date32, + DataType::Date64 => Self::Date64, + DataType::Time32(unit) => Self::Time32(unit.into()), + DataType::Time64(unit) => Self::Time64(unit.into()), + DataType::Duration(unit) => Self::Duration(unit.into()), + DataType::Interval(unit) => Self::Interval(unit.into()), + DataType::Binary => Self::Binary, + DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), + DataType::LargeBinary => Self::LargeBinary, + DataType::Utf8 => Self::Utf8, + DataType::LargeUtf8 => Self::LargeUtf8, + DataType::List(f) => Self::List(Arc::new((*f).into())), + DataType::FixedSizeList(f, size) => { + Self::FixedSizeList(Arc::new((*f).into()), size as _) + }, + DataType::LargeList(f) => Self::LargeList(Arc::new((*f).into())), + DataType::Struct(f) => Self::Struct(f.into_iter().map(ArrowField::from).collect()), + DataType::Union(fields, Some(ids), mode) => { + let ids = ids.into_iter().map(|x| x as _); + let fields = fields.into_iter().map(ArrowField::from); + Self::Union(UnionFields::new(ids, fields), mode.into()) + }, + DataType::Union(fields, None, mode) => { + let ids = 0..fields.len() as i8; + let fields = fields.into_iter().map(ArrowField::from); + Self::Union(UnionFields::new(ids, fields), mode.into()) + }, + DataType::Map(f, ordered) => Self::Map(Arc::new((*f).into()), ordered), + DataType::Dictionary(key, value, _) => Self::Dictionary( + Box::new(DataType::from(key).into()), + Box::new((*value).into()), + ), + DataType::Decimal(precision, scale) => Self::Decimal128(precision as _, scale as _), + DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), + DataType::Extension(_, d, _) => (*d).into(), + } + } +} + +#[cfg(feature = "arrow_rs")] +impl From for DataType { + fn from(value: arrow_schema::DataType) -> Self { + use arrow_schema::DataType; + match value { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 => Self::Int8, + DataType::Int16 => Self::Int16, + DataType::Int32 => Self::Int32, + DataType::Int64 => Self::Int64, + DataType::UInt8 => Self::UInt8, + DataType::UInt16 => Self::UInt16, + DataType::UInt32 => Self::UInt32, + DataType::UInt64 => Self::UInt64, + DataType::Float16 => Self::Float16, + DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float64, + DataType::Timestamp(unit, tz) => { + Self::Timestamp(unit.into(), tz.map(|x| x.to_string())) + }, + DataType::Date32 => Self::Date32, + DataType::Date64 => Self::Date64, + DataType::Time32(unit) => Self::Time32(unit.into()), + DataType::Time64(unit) => Self::Time64(unit.into()), + DataType::Duration(unit) => Self::Duration(unit.into()), + DataType::Interval(unit) => Self::Interval(unit.into()), + DataType::Binary => Self::Binary, + DataType::FixedSizeBinary(size) => 
Self::FixedSizeBinary(size as _), + DataType::LargeBinary => Self::LargeBinary, + DataType::Utf8 => Self::Utf8, + DataType::LargeUtf8 => Self::LargeUtf8, + DataType::List(f) => Self::List(Box::new(f.into())), + DataType::FixedSizeList(f, size) => Self::FixedSizeList(Box::new(f.into()), size as _), + DataType::LargeList(f) => Self::LargeList(Box::new(f.into())), + DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()), + DataType::Union(fields, mode) => { + let ids = fields.iter().map(|(x, _)| x as _).collect(); + let fields = fields.iter().map(|(_, f)| f.into()).collect(); + Self::Union(fields, Some(ids), mode.into()) + }, + DataType::Map(f, ordered) => Self::Map(Box::new(f.into()), ordered), + DataType::Dictionary(key, value) => { + let key = match *key { + DataType::Int8 => IntegerType::Int8, + DataType::Int16 => IntegerType::Int16, + DataType::Int32 => IntegerType::Int32, + DataType::Int64 => IntegerType::Int64, + DataType::UInt8 => IntegerType::UInt8, + DataType::UInt16 => IntegerType::UInt16, + DataType::UInt32 => IntegerType::UInt32, + DataType::UInt64 => IntegerType::UInt64, + d => panic!("illegal dictionary key type: {d}"), + }; + Self::Dictionary(key, Box::new((*value).into()), false) + }, + DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _), + DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), + DataType::RunEndEncoded(_, _) => panic!("Run-end encoding not supported by arrow2"), + } + } +} + +/// Mode of [`DataType::Union`] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum UnionMode { + /// Dense union + Dense, + /// Sparse union + Sparse, +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_schema::UnionMode { + fn from(value: UnionMode) -> Self { + match value { + UnionMode::Dense => Self::Dense, + UnionMode::Sparse => Self::Sparse, + } + } +} + +#[cfg(feature = "arrow_rs")] +impl From for UnionMode { + fn from(value: arrow_schema::UnionMode) -> Self { + match value { + arrow_schema::UnionMode::Dense => Self::Dense, + arrow_schema::UnionMode::Sparse => Self::Sparse, + } + } +} + +impl UnionMode { + /// Constructs a [`UnionMode::Sparse`] if the input bool is true, + /// or otherwise constructs a [`UnionMode::Dense`] + pub fn sparse(is_sparse: bool) -> Self { + if is_sparse { + Self::Sparse + } else { + Self::Dense + } + } + + /// Returns whether the mode is sparse + pub fn is_sparse(&self) -> bool { + matches!(self, Self::Sparse) + } + + /// Returns whether the mode is dense + pub fn is_dense(&self) -> bool { + matches!(self, Self::Dense) + } +} + +/// The time units defined in Arrow. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum TimeUnit { + /// Time in seconds. + Second, + /// Time in milliseconds. + Millisecond, + /// Time in microseconds. + Microsecond, + /// Time in nanoseconds. 
+ Nanosecond, +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_schema::TimeUnit { + fn from(value: TimeUnit) -> Self { + match value { + TimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Second => Self::Second, + } + } +} + +#[cfg(feature = "arrow_rs")] +impl From for TimeUnit { + fn from(value: arrow_schema::TimeUnit) -> Self { + match value { + arrow_schema::TimeUnit::Nanosecond => Self::Nanosecond, + arrow_schema::TimeUnit::Millisecond => Self::Millisecond, + arrow_schema::TimeUnit::Microsecond => Self::Microsecond, + arrow_schema::TimeUnit::Second => Self::Second, + } + } +} + +/// Interval units defined in Arrow +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum IntervalUnit { + /// The number of elapsed whole months. + YearMonth, + /// The number of elapsed days and milliseconds, + /// stored as 2 contiguous `i32` + DayTime, + /// The number of elapsed months (i32), days (i32) and nanoseconds (i64). + MonthDayNano, +} + +#[cfg(feature = "arrow_rs")] +impl From for arrow_schema::IntervalUnit { + fn from(value: IntervalUnit) -> Self { + match value { + IntervalUnit::YearMonth => Self::YearMonth, + IntervalUnit::DayTime => Self::DayTime, + IntervalUnit::MonthDayNano => Self::MonthDayNano, + } + } +} + +#[cfg(feature = "arrow_rs")] +impl From for IntervalUnit { + fn from(value: arrow_schema::IntervalUnit) -> Self { + match value { + arrow_schema::IntervalUnit::YearMonth => Self::YearMonth, + arrow_schema::IntervalUnit::DayTime => Self::DayTime, + arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano, + } + } +} + +impl DataType { + /// the [`PhysicalType`] of this [`DataType`]. 
+ pub fn to_physical_type(&self) -> PhysicalType { + use DataType::*; + match self { + Null => PhysicalType::Null, + Boolean => PhysicalType::Boolean, + Int8 => PhysicalType::Primitive(PrimitiveType::Int8), + Int16 => PhysicalType::Primitive(PrimitiveType::Int16), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + PhysicalType::Primitive(PrimitiveType::Int32) + }, + Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => { + PhysicalType::Primitive(PrimitiveType::Int64) + }, + Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128), + Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256), + UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8), + UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16), + UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32), + UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64), + Float16 => PhysicalType::Primitive(PrimitiveType::Float16), + Float32 => PhysicalType::Primitive(PrimitiveType::Float32), + Float64 => PhysicalType::Primitive(PrimitiveType::Float64), + Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs), + Interval(IntervalUnit::MonthDayNano) => { + PhysicalType::Primitive(PrimitiveType::MonthDayNano) + }, + Binary => PhysicalType::Binary, + FixedSizeBinary(_) => PhysicalType::FixedSizeBinary, + LargeBinary => PhysicalType::LargeBinary, + Utf8 => PhysicalType::Utf8, + LargeUtf8 => PhysicalType::LargeUtf8, + List(_) => PhysicalType::List, + FixedSizeList(_, _) => PhysicalType::FixedSizeList, + LargeList(_) => PhysicalType::LargeList, + Struct(_) => PhysicalType::Struct, + Union(_, _, _) => PhysicalType::Union, + Map(_, _) => PhysicalType::Map, + Dictionary(key, _, _) => PhysicalType::Dictionary(*key), + Extension(_, key, _) => key.to_physical_type(), + } + } + + /// Returns `&self` for all but [`DataType::Extension`]. For [`DataType::Extension`], + /// (recursively) returns the inner [`DataType`]. + /// Never returns the variant [`DataType::Extension`]. + pub fn to_logical_type(&self) -> &DataType { + use DataType::*; + match self { + Extension(_, key, _) => key.to_logical_type(), + _ => self, + } + } +} + +impl From for DataType { + fn from(item: IntegerType) -> Self { + match item { + IntegerType::Int8 => DataType::Int8, + IntegerType::Int16 => DataType::Int16, + IntegerType::Int32 => DataType::Int32, + IntegerType::Int64 => DataType::Int64, + IntegerType::UInt8 => DataType::UInt8, + IntegerType::UInt16 => DataType::UInt16, + IntegerType::UInt32 => DataType::UInt32, + IntegerType::UInt64 => DataType::UInt64, + } + } +} + +impl From for DataType { + fn from(item: PrimitiveType) -> Self { + match item { + PrimitiveType::Int8 => DataType::Int8, + PrimitiveType::Int16 => DataType::Int16, + PrimitiveType::Int32 => DataType::Int32, + PrimitiveType::Int64 => DataType::Int64, + PrimitiveType::UInt8 => DataType::UInt8, + PrimitiveType::UInt16 => DataType::UInt16, + PrimitiveType::UInt32 => DataType::UInt32, + PrimitiveType::UInt64 => DataType::UInt64, + PrimitiveType::Int128 => DataType::Decimal(32, 32), + PrimitiveType::Int256 => DataType::Decimal256(32, 32), + PrimitiveType::Float16 => DataType::Float16, + PrimitiveType::Float32 => DataType::Float32, + PrimitiveType::Float64 => DataType::Float64, + PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime), + PrimitiveType::MonthDayNano => DataType::Interval(IntervalUnit::MonthDayNano), + } + } +} + +/// typedef for [`Arc`]. 
+pub type SchemaRef = Arc; + +/// support get extension for metadata +pub fn get_extension(metadata: &Metadata) -> Extension { + if let Some(name) = metadata.get("ARROW:extension:name") { + let metadata = metadata.get("ARROW:extension:metadata").cloned(); + Some((name.clone(), metadata)) + } else { + None + } +} diff --git a/crates/nano-arrow/src/datatypes/physical_type.rs b/crates/nano-arrow/src/datatypes/physical_type.rs new file mode 100644 index 000000000000..1e57fcf936bc --- /dev/null +++ b/crates/nano-arrow/src/datatypes/physical_type.rs @@ -0,0 +1,76 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +pub use crate::types::PrimitiveType; + +/// The set of physical types: unique in-memory representations of an Arrow array. +/// A physical type has a one-to-many relationship with a [`crate::datatypes::DataType`] and +/// a one-to-one mapping to each struct in this crate that implements [`crate::array::Array`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum PhysicalType { + /// A Null with no allocation. + Null, + /// A boolean represented as a single bit. + Boolean, + /// An array where each slot has a known compile-time size. + Primitive(PrimitiveType), + /// Opaque binary data of variable length. + Binary, + /// Opaque binary data of fixed size. + FixedSizeBinary, + /// Opaque binary data of variable length and 64-bit offsets. + LargeBinary, + /// A variable-length string in Unicode with UTF-8 encoding. + Utf8, + /// A variable-length string in Unicode with UFT-8 encoding and 64-bit offsets. + LargeUtf8, + /// A list of some data type with variable length. + List, + /// A list of some data type with fixed length. + FixedSizeList, + /// A list of some data type with variable length and 64-bit offsets. + LargeList, + /// A nested type that contains an arbitrary number of fields. + Struct, + /// A nested type that represents slots of differing types. + Union, + /// A nested type. + Map, + /// A dictionary encoded array by `IntegerType`. + Dictionary(IntegerType), +} + +impl PhysicalType { + /// Whether this physical type equals [`PhysicalType::Primitive`] of type `primitive`. + pub fn eq_primitive(&self, primitive: PrimitiveType) -> bool { + if let Self::Primitive(o) = self { + o == &primitive + } else { + false + } + } +} + +/// the set of valid indices types of a dictionary-encoded Array. +/// Each type corresponds to a variant of [`crate::array::DictionaryArray`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum IntegerType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, +} diff --git a/crates/nano-arrow/src/datatypes/schema.rs b/crates/nano-arrow/src/datatypes/schema.rs new file mode 100644 index 000000000000..d01f1937d2ed --- /dev/null +++ b/crates/nano-arrow/src/datatypes/schema.rs @@ -0,0 +1,60 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +use super::{Field, Metadata}; + +/// An ordered sequence of [`Field`]s with associated [`Metadata`]. +/// +/// [`Schema`] is an abstraction used to read from, and write to, Arrow IPC format, +/// Apache Parquet, and Apache Avro. 
All these formats have a concept of a schema +/// with fields and metadata. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub struct Schema { + /// The fields composing this schema. + pub fields: Vec, + /// Optional metadata. + pub metadata: Metadata, +} + +impl Schema { + /// Attaches a [`Metadata`] to [`Schema`] + #[inline] + pub fn with_metadata(self, metadata: Metadata) -> Self { + Self { + fields: self.fields, + metadata, + } + } + + /// Returns a new [`Schema`] with a subset of all fields whose `predicate` + /// evaluates to true. + pub fn filter bool>(self, predicate: F) -> Self { + let fields = self + .fields + .into_iter() + .enumerate() + .filter_map(|(index, f)| { + if (predicate)(index, &f) { + Some(f) + } else { + None + } + }) + .collect(); + + Schema { + fields, + metadata: self.metadata, + } + } +} + +impl From> for Schema { + fn from(fields: Vec) -> Self { + Self { + fields, + ..Default::default() + } + } +} diff --git a/crates/nano-arrow/src/doc/lib.md b/crates/nano-arrow/src/doc/lib.md new file mode 100644 index 000000000000..a1b57945c020 --- /dev/null +++ b/crates/nano-arrow/src/doc/lib.md @@ -0,0 +1,87 @@ +Welcome to arrow2's documentation. Thanks for checking it out! + +This is a library for efficient in-memory data operations with +[Arrow in-memory format](https://arrow.apache.org/docs/format/Columnar.html). +It is a re-write from the bottom up of the official `arrow` crate with soundness +and type safety in mind. + +Check out [the guide](https://jorgecarleitao.github.io/arrow2/main/guide/) for an introduction. +Below is an example of some of the things you can do with it: + +```rust +use std::sync::Arc; + +use arrow2::array::*; +use arrow2::datatypes::{Field, DataType, Schema}; +use arrow2::compute::arithmetics; +use arrow2::error::Result; +use arrow2::io::parquet::write::*; +use arrow2::chunk::Chunk; + +fn main() -> Result<()> { + // declare arrays + let a = Int32Array::from(&[Some(1), None, Some(3)]); + let b = Int32Array::from(&[Some(2), None, Some(6)]); + + // compute (probably the fastest implementation of a nullable op you can find out there) + let c = arithmetics::basic::mul_scalar(&a, &2); + assert_eq!(c, b); + + // declare a schema with fields + let schema = Schema::from(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]); + + // declare chunk + let chunk = Chunk::new(vec![a.arced(), b.arced()]); + + // write to parquet (probably the fastest implementation of writing to parquet out there) + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Snappy, + version: Version::V1, + data_pagesize_limit: None, + }; + + let row_groups = RowGroupIterator::try_new( + vec![Ok(chunk)].into_iter(), + &schema, + options, + vec![vec![Encoding::Plain], vec![Encoding::Plain]], + )?; + + // anything implementing `std::io::Write` works + let mut file = vec![]; + + let mut writer = FileWriter::try_new(file, schema, options)?; + + // Write the file. + for group in row_groups { + writer.write(group?)?; + } + let _ = writer.end(None)?; + Ok(()) +} +``` + +## Cargo features + +This crate has a significant number of cargo features to reduce compilation +time and number of dependencies. 
The feature `"full"` activates most +functionality, such as: + +- `io_ipc`: to interact with the Arrow IPC format +- `io_ipc_compression`: to read and write compressed Arrow IPC (v2) +- `io_csv` to read and write CSV +- `io_json` to read and write JSON +- `io_flight` to read and write to Arrow's Flight protocol +- `io_parquet` to read and write parquet +- `io_parquet_compression` to read and write compressed parquet +- `io_print` to write batches to formatted ASCII tables +- `compute` to operate on arrays (addition, sum, sort, etc.) + +The feature `simd` (not part of `full`) produces more explicit SIMD instructions +via [`std::simd`](https://doc.rust-lang.org/nightly/std/simd/index.html), but requires the +nightly channel. diff --git a/crates/nano-arrow/src/error.rs b/crates/nano-arrow/src/error.rs new file mode 100644 index 000000000000..e6455d6f055d --- /dev/null +++ b/crates/nano-arrow/src/error.rs @@ -0,0 +1,100 @@ +//! Defines [`Error`], representing all errors returned by this crate. +use std::fmt::{Debug, Display, Formatter}; + +/// Enum with all errors in this crate. +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// Returned when functionality is not yet available. + NotYetImplemented(String), + /// Wrapper for an error triggered by a dependency + External(String, Box), + /// Wrapper for IO errors + Io(std::io::Error), + /// When an invalid argument is passed to a function. + InvalidArgumentError(String), + /// Error during import or export to/from a format + ExternalFormat(String), + /// Whenever pushing to a container fails because it does not support more entries. + /// The solution is usually to use a higher-capacity container-backing type. + Overflow, + /// Whenever incoming data from the C data interface, IPC or Flight does not fulfil the Arrow specification. + OutOfSpec(String), +} + +impl Error { + /// Wraps an external error in an `Error`. 
+ pub fn from_external_error(error: impl std::error::Error + Send + Sync + 'static) -> Self { + Self::External("".to_string(), Box::new(error)) + } + + pub(crate) fn oos>(msg: A) -> Self { + Self::OutOfSpec(msg.into()) + } + + #[allow(dead_code)] + pub(crate) fn nyi>(msg: A) -> Self { + Self::NotYetImplemented(msg.into()) + } +} + +impl From<::std::io::Error> for Error { + fn from(error: std::io::Error) -> Self { + Error::Io(error) + } +} + +impl From for Error { + fn from(error: std::str::Utf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(error: std::string::FromUtf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(error: simdutf8::basic::Utf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(_: std::collections::TryReserveError) -> Error { + Error::Overflow + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Error::NotYetImplemented(source) => { + write!(f, "Not yet implemented: {}", &source) + }, + Error::External(message, source) => { + write!(f, "External error{}: {}", message, &source) + }, + Error::Io(desc) => write!(f, "Io error: {desc}"), + Error::InvalidArgumentError(desc) => { + write!(f, "Invalid argument error: {desc}") + }, + Error::ExternalFormat(desc) => { + write!(f, "External format error: {desc}") + }, + Error::Overflow => { + write!(f, "Operation overflew the backing container.") + }, + Error::OutOfSpec(message) => { + write!(f, "{message}") + }, + } + } +} + +impl std::error::Error for Error {} + +/// Typedef for a [`std::result::Result`] of an [`Error`]. +pub type Result = std::result::Result; diff --git a/crates/nano-arrow/src/ffi/array.rs b/crates/nano-arrow/src/ffi/array.rs new file mode 100644 index 000000000000..f87f7e66a10c --- /dev/null +++ b/crates/nano-arrow/src/ffi/array.rs @@ -0,0 +1,568 @@ +//! Contains functionality to load an ArrayData from the C Data Interface +use std::sync::Arc; + +use super::ArrowArray; +use crate::array::*; +use crate::bitmap::utils::{bytes_for, count_zeros}; +use crate::bitmap::Bitmap; +use crate::buffer::{Buffer, Bytes, BytesAllocator}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::{Error, Result}; +use crate::ffi::schema::get_child; +use crate::types::NativeType; + +/// Reads a valid `ffi` interface into a `Box` +/// # Errors +/// If and only if: +/// * the interface is not valid (e.g. a null pointer) +pub unsafe fn try_from(array: A) -> Result> { + use PhysicalType::*; + Ok(match array.data_type().to_physical_type() { + Null => Box::new(NullArray::try_from_ffi(array)?), + Boolean => Box::new(BooleanArray::try_from_ffi(array)?), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::try_from_ffi(array)?) 
+ }), + Utf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + LargeUtf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + Binary => Box::new(BinaryArray::::try_from_ffi(array)?), + LargeBinary => Box::new(BinaryArray::::try_from_ffi(array)?), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::try_from_ffi(array)?), + List => Box::new(ListArray::::try_from_ffi(array)?), + LargeList => Box::new(ListArray::::try_from_ffi(array)?), + FixedSizeList => Box::new(FixedSizeListArray::try_from_ffi(array)?), + Struct => Box::new(StructArray::try_from_ffi(array)?), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::try_from_ffi(array)?) + }) + }, + Union => Box::new(UnionArray::try_from_ffi(array)?), + Map => Box::new(MapArray::try_from_ffi(array)?), + }) +} + +// Sound because the arrow specification does not allow multiple implementations +// to change this struct +// This is intrinsically impossible to prove because the implementations agree +// on this as part of the Arrow specification +unsafe impl Send for ArrowArray {} +unsafe impl Sync for ArrowArray {} + +impl Drop for ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +// callback used to drop [ArrowArray] when it is exported +unsafe extern "C" fn c_release_array(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +#[allow(dead_code)] +struct PrivateData { + array: Box, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, +} + +impl ArrowArray { + /// creates a new `ArrowArray` from existing data. + /// # Safety + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. 
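+    ///
+    /// Illustrative usage sketch (not part of this diff; paths and the `boxed` helper follow the
+    /// vendored arrow2 naming): exporting normally goes through the public wrapper
+    /// `export_array_to_c` in the `ffi` module, which calls this constructor internally.
+    /// ```ignore
+    /// use arrow2::array::Int32Array;
+    /// use arrow2::ffi::export_array_to_c;
+    ///
+    /// let array = Int32Array::from(&[Some(1), None, Some(3)]).boxed();
+    /// let ffi_array = export_array_to_c(array);
+    /// // The consumer that imports `ffi_array` is now responsible for invoking its `release` callback.
+    /// ```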
+ pub(crate) fn new(array: Box) -> Self { + let (offset, buffers, children, dictionary) = + offset_buffers_children_dictionary(array.as_ref()); + + let buffers_ptr = buffers + .iter() + .map(|maybe_buffer| match maybe_buffer { + Some(b) => *b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers.len() as i64; + + let children_ptr = children + .into_iter() + .map(|child| Box::into_raw(Box::new(ArrowArray::new(child)))) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = + dictionary.map(|array| Box::into_raw(Box::new(ArrowArray::new(array)))); + + let length = array.len() as i64; + let null_count = array.null_count() as i64; + + let mut private_data = Box::new(PrivateData { + array, + buffers_ptr, + children_ptr, + dictionary_ptr, + }); + + Self { + length, + null_count, + offset: offset as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_array), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } + } + + /// creates an empty [`ArrowArray`], which can be used to import data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + pub(crate) fn len(&self) -> usize { + self.length as usize + } + + /// the offset of the array + pub(crate) fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + pub(crate) fn null_count(&self) -> usize { + self.null_count as usize + } +} + +/// # Safety +/// The caller must ensure that the buffer at index `i` is not mutably shared. +unsafe fn get_buffer_ptr( + array: &ArrowArray, + data_type: &DataType, + index: usize, +) -> Result<*mut T> { + if array.buffers.is_null() { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} must have non-null buffers" + ))); + } + + if array + .buffers + .align_offset(std::mem::align_of::<*mut *const u8>()) + != 0 + { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have buffer {index} aligned to type {}", + std::any::type_name::<*mut *const u8>() + ))); + } + let buffers = array.buffers as *mut *const u8; + + if index >= array.n_buffers as usize { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have buffer {index}." + ))); + } + + let ptr = *buffers.add(index); + if ptr.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null buffer {index}" + ))); + } + + // note: we can't prove that this pointer is not mutably shared - part of the safety invariant + Ok(ptr as *mut T) +} + +/// returns the buffer `i` of `array` interpreted as a [`Buffer`]. 
+/// # Safety +/// This function is safe iff: +/// * the buffers up to position `index` are valid for the declared length +/// * the buffers' pointers are not mutably shared for the lifetime of `owner` +unsafe fn create_buffer( + array: &ArrowArray, + data_type: &DataType, + owner: InternalArrowArray, + index: usize, +) -> Result> { + let len = buffer_len(array, data_type, index)?; + + if len == 0 { + return Ok(Buffer::new()); + } + + let offset = buffer_offset(array, data_type, index); + let ptr: *mut T = get_buffer_ptr(array, data_type, index)?; + + // We have to check alignment. + // This is the zero-copy path. + if ptr.align_offset(std::mem::align_of::()) == 0 { + let bytes = Bytes::from_foreign(ptr, len, BytesAllocator::InternalArrowArray(owner)); + Ok(Buffer::from_bytes(bytes).sliced(offset, len - offset)) + } + // This is the path where alignment isn't correct. + // We copy the data to a new vec + else { + let buf = std::slice::from_raw_parts(ptr, len - offset).to_vec(); + Ok(Buffer::from(buf)) + } +} + +/// returns the buffer `i` of `array` interpreted as a [`Bitmap`]. +/// # Safety +/// This function is safe iff: +/// * the buffer at position `index` is valid for the declared length +/// * the buffers' pointer is not mutable for the lifetime of `owner` +unsafe fn create_bitmap( + array: &ArrowArray, + data_type: &DataType, + owner: InternalArrowArray, + index: usize, + // if this is the validity bitmap + // we can use the null count directly + is_validity: bool, +) -> Result { + let len: usize = array.length.try_into().expect("length to fit in `usize`"); + if len == 0 { + return Ok(Bitmap::new()); + } + let ptr = get_buffer_ptr(array, data_type, index)?; + + // Pointer of u8 has alignment 1, so we don't have to check alignment. + + let offset: usize = array.offset.try_into().expect("offset to fit in `usize`"); + let bytes_len = bytes_for(offset + len); + let bytes = Bytes::from_foreign(ptr, bytes_len, BytesAllocator::InternalArrowArray(owner)); + + let null_count: usize = if is_validity { + array.null_count() + } else { + count_zeros(bytes.as_ref(), offset, len) + }; + Bitmap::from_inner(Arc::new(bytes), offset, len, null_count) +} + +fn buffer_offset(array: &ArrowArray, data_type: &DataType, i: usize) -> usize { + use PhysicalType::*; + match (data_type.to_physical_type(), i) { + (LargeUtf8, 2) | (LargeBinary, 2) | (Utf8, 2) | (Binary, 2) => 0, + (FixedSizeBinary, 1) => { + if let DataType::FixedSizeBinary(size) = data_type.to_logical_type() { + let offset: usize = array.offset.try_into().expect("Offset to fit in `usize`"); + offset * *size + } else { + unreachable!() + } + }, + _ => array.offset.try_into().expect("Offset to fit in `usize`"), + } +} + +/// Returns the length, in slots, of the buffer `i` (indexed according to the C data interface) +unsafe fn buffer_len(array: &ArrowArray, data_type: &DataType, i: usize) -> Result { + Ok(match (data_type.to_physical_type(), i) { + (PhysicalType::FixedSizeBinary, 1) => { + if let DataType::FixedSizeBinary(size) = data_type.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::FixedSizeList, 1) => { + if let DataType::FixedSizeList(_, size) = data_type.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::Utf8, 1) + | (PhysicalType::LargeUtf8, 1) + | (PhysicalType::Binary, 1) + | (PhysicalType::LargeBinary, 1) + | (PhysicalType::List, 1) + | (PhysicalType::LargeList, 1) + | 
(PhysicalType::Map, 1) => { + // the len of the offset buffer (buffer 1) equals length + 1 + array.offset as usize + array.length as usize + 1 + }, + (PhysicalType::Utf8, 2) | (PhysicalType::Binary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, data_type, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i32 + let offset_buffer = offset_buffer as *const i32; + // get last offset + + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + (PhysicalType::LargeUtf8, 2) | (PhysicalType::LargeBinary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, data_type, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i64 + let offset_buffer = offset_buffer as *const i64; + // get last offset + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + // buffer len of primitive types + _ => array.offset as usize + array.length as usize, + }) +} + +/// Safety +/// This function is safe iff: +/// * `array.children` at `index` is valid +/// * `array.children` is not mutably shared for the lifetime of `parent` +/// * the pointer of `array.children` at `index` is valid +/// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` +unsafe fn create_child( + array: &ArrowArray, + data_type: &DataType, + parent: InternalArrowArray, + index: usize, +) -> Result> { + let data_type = get_child(data_type, index)?; + + // catch what we can + if array.children.is_null() { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} must have non-null children" + ))); + } + + if index >= array.n_children as usize { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have child {index}." + ))); + } + + // Safety - part of the invariant + let arr_ptr = unsafe { *array.children.add(index) }; + + // catch what we can + if arr_ptr.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null child {index}" + ))); + } + + // Safety - invariant of this function + let arr_ptr = unsafe { &*arr_ptr }; + Ok(ArrowArrayChild::new(arr_ptr, data_type, parent)) +} + +/// Safety +/// This function is safe iff: +/// * `array.dictionary` is valid +/// * `array.dictionary` is not mutably shared for the lifetime of `parent` +unsafe fn create_dictionary( + array: &ArrowArray, + data_type: &DataType, + parent: InternalArrowArray, +) -> Result>> { + if let DataType::Dictionary(_, values, _) = data_type { + let data_type = values.as_ref().clone(); + // catch what we can + if array.dictionary.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null dictionary" + ))); + } + + // safety: part of the invariant + let array = unsafe { &*array.dictionary }; + Ok(Some(ArrowArrayChild::new(array, data_type, parent))) + } else { + Ok(None) + } +} + +pub trait ArrowArrayRef: std::fmt::Debug { + fn owner(&self) -> InternalArrowArray { + (*self.parent()).clone() + } + + /// returns the null bit buffer. + /// Rust implementation uses a buffer that is not part of the array of buffers. + /// The C Data interface's null buffer is part of the array of buffers. 
+ /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a bitmap. + /// This function assumes that the bitmap created from FFI is valid; this is impossible to prove. + unsafe fn validity(&self) -> Result> { + if self.array().null_count() == 0 { + Ok(None) + } else { + create_bitmap(self.array(), self.data_type(), self.owner(), 0, true).map(Some) + } + } + + /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a buffer. + /// This function assumes that the buffer created from FFI is valid; this is impossible to prove. + unsafe fn buffer(&self, index: usize) -> Result> { + create_buffer::(self.array(), self.data_type(), self.owner(), index) + } + + /// # Safety + /// This function is safe iff: + /// * the buffer at position `index` is valid for the declared length + /// * the buffers' pointer is not mutable for the lifetime of `owner` + unsafe fn bitmap(&self, index: usize) -> Result { + create_bitmap(self.array(), self.data_type(), self.owner(), index, false) + } + + /// # Safety + /// * `array.children` at `index` is valid + /// * `array.children` is not mutably shared for the lifetime of `parent` + /// * the pointer of `array.children` at `index` is valid + /// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` + unsafe fn child(&self, index: usize) -> Result { + create_child(self.array(), self.data_type(), self.parent().clone(), index) + } + + unsafe fn dictionary(&self) -> Result> { + create_dictionary(self.array(), self.data_type(), self.parent().clone()) + } + + fn n_buffers(&self) -> usize; + + fn parent(&self) -> &InternalArrowArray; + fn array(&self) -> &ArrowArray; + fn data_type(&self) -> &DataType; +} + +/// Struct used to move an Array from and to the C Data Interface. +/// Its main responsibility is to expose functionality that requires +/// both [ArrowArray] and [ArrowSchema]. +/// +/// This struct has two main paths: +/// +/// ## Import from the C Data Interface +/// * [InternalArrowArray::empty] to allocate memory to be filled by an external call +/// * [InternalArrowArray::try_from_raw] to consume two non-null allocated pointers +/// ## Export to the C Data Interface +/// * [InternalArrowArray::try_new] to create a new [InternalArrowArray] from Rust-specific information +/// * [InternalArrowArray::into_raw] to expose two pointers for [ArrowArray] and [ArrowSchema]. +/// +/// # Safety +/// Whoever creates this struct is responsible for releasing their resources. Specifically, +/// consumers *must* call [InternalArrowArray::into_raw] and take ownership of the individual pointers, +/// calling [ArrowArray::release] and [ArrowSchema::release] accordingly. +/// +/// Furthermore, this struct assumes that the incoming data agrees with the C data interface. 
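+///
+/// A minimal import sketch (illustrative only; it mirrors the public `import_array_from_c`
+/// defined in the `ffi` module, and assumes `array` is a valid FFI [`ArrowArray`] and
+/// `data_type` its already-imported `DataType`):
+/// ```ignore
+/// let internal = InternalArrowArray::new(array, data_type);
+/// // Safety: the producer guarantees `array` fulfils the C data interface.
+/// let imported: Box<dyn Array> = unsafe { try_from(internal) }?;
+/// ```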
+#[derive(Debug, Clone)] +pub struct InternalArrowArray { + // Arc is used for sharability since this is immutable + array: Arc, + // Arced to reduce cost of cloning + data_type: Arc, +} + +impl InternalArrowArray { + pub fn new(array: ArrowArray, data_type: DataType) -> Self { + Self { + array: Arc::new(array), + data_type: Arc::new(data_type), + } + } +} + +impl ArrowArrayRef for InternalArrowArray { + /// the data_type as declared in the schema + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn parent(&self) -> &InternalArrowArray { + self + } + + fn array(&self) -> &ArrowArray { + self.array.as_ref() + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } +} + +#[derive(Debug)] +pub struct ArrowArrayChild<'a> { + array: &'a ArrowArray, + data_type: DataType, + parent: InternalArrowArray, +} + +impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { + /// the data_type as declared in the schema + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn parent(&self) -> &InternalArrowArray { + &self.parent + } + + fn array(&self) -> &ArrowArray { + self.array + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } +} + +impl<'a> ArrowArrayChild<'a> { + fn new(array: &'a ArrowArray, data_type: DataType, parent: InternalArrowArray) -> Self { + Self { + array, + data_type, + parent, + } + } +} diff --git a/crates/nano-arrow/src/ffi/bridge.rs b/crates/nano-arrow/src/ffi/bridge.rs new file mode 100644 index 000000000000..7a7b9a86ca3a --- /dev/null +++ b/crates/nano-arrow/src/ffi/bridge.rs @@ -0,0 +1,39 @@ +use crate::array::*; + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let a = $array.as_any().downcast_ref::<$ty>().unwrap(); + if a.offset().is_some() { + $array + } else { + Box::new(a.to_ffi_aligned()) + } + }}; +} + +pub fn align_to_c_data_interface(array: Box) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + ffi_dyn!(array, DictionaryArray<$T>) + }) + }, + } +} diff --git a/crates/nano-arrow/src/ffi/generated.rs b/crates/nano-arrow/src/ffi/generated.rs new file mode 100644 index 000000000000..cd4953b7198a --- /dev/null +++ b/crates/nano-arrow/src/ffi/generated.rs @@ -0,0 +1,55 @@ +/* automatically generated by rust-bindgen 0.59.2 */ + +/// ABI-compatible struct for [`ArrowSchema`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowSchema { + pub(super) format: *const ::std::os::raw::c_char, + pub(super) name: *const ::std::os::raw::c_char, + pub(super) metadata: *const ::std::os::raw::c_char, + pub(super) flags: i64, + pub(super) n_children: i64, + pub(super) children: *mut *mut ArrowSchema, + pub(super) dictionary: *mut ArrowSchema, + pub(super) release: 
::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArray`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArray { + pub(super) length: i64, + pub(super) null_count: i64, + pub(super) offset: i64, + pub(super) n_buffers: i64, + pub(super) n_children: i64, + pub(super) buffers: *mut *const ::std::os::raw::c_void, + pub(super) children: *mut *mut ArrowArray, + pub(super) dictionary: *mut ArrowArray, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArrayStream`](https://arrow.apache.org/docs/format/CStreamInterface.html). +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArrayStream { + pub(super) get_schema: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowSchema, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_next: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowArray, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_last_error: ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char, + >, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} diff --git a/crates/nano-arrow/src/ffi/mmap.rs b/crates/nano-arrow/src/ffi/mmap.rs new file mode 100644 index 000000000000..03c1ac9aa30a --- /dev/null +++ b/crates/nano-arrow/src/ffi/mmap.rs @@ -0,0 +1,164 @@ +//! Functionality to mmap in-memory data regions. +use std::sync::Arc; + +use super::{ArrowArray, InternalArrowArray}; +use crate::array::{BooleanArray, FromFfi, PrimitiveArray}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::types::NativeType; + +#[allow(dead_code)] +struct PrivateData { + // the owner of the pointers' regions + data: T, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, +} + +pub(crate) unsafe fn create_array< + T: AsRef<[u8]>, + I: Iterator>, + II: Iterator, +>( + data: Arc, + num_rows: usize, + null_count: usize, + buffers: I, + children: II, + dictionary: Option, + offset: Option, +) -> ArrowArray { + let buffers_ptr = buffers + .map(|maybe_buffer| match maybe_buffer { + Some(b) => b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers_ptr.len() as i64; + + let children_ptr = children + .map(|child| Box::into_raw(Box::new(child))) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = dictionary.map(|array| Box::into_raw(Box::new(array))); + + let mut private_data = Box::new(PrivateData::> { + data, + buffers_ptr, + children_ptr, + dictionary_ptr, + }); + + ArrowArray { + length: num_rows as i64, + null_count: null_count as i64, + offset: offset.unwrap_or(0) as i64, // Unwrap: IPC files are by definition not offset + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(release::>), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } +} + +/// callback used to drop [`ArrowArray`] when it is exported specified for [`PrivateData`]. 
+unsafe extern "C" fn release(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +/// Creates a (non-null) [`PrimitiveArray`] from a slice of values. +/// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned PrimitiveArray's lifetime is bound to the lifetime +/// of the slice. The returned [`PrimitiveArray`] _must not_ outlive the passed slice. +pub unsafe fn slice(slice: &[T]) -> PrimitiveArray { + let num_rows = slice.len(); + let null_count = 0; + let validity = None; + + let data: &[u8] = bytemuck::cast_slice(slice); + let ptr = data.as_ptr(); + let data = Arc::new(data); + + // safety: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + num_rows, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + None, + ); + let array = InternalArrowArray::new(array, T::PRIMITIVE.into()); + + // safety: we just created a valid array + unsafe { PrimitiveArray::::try_from_ffi(array) }.unwrap() +} + +/// Creates a (non-null) [`BooleanArray`] from a slice of bits. +/// This does not have memcopy and is the fastest way to create a [`BooleanArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// The `offset` indicates where the first bit starts in the first byte. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned BooleanArrays's lifetime is bound to the lifetime +/// of the slice. The returned [`BooleanArray`] _must not_ outlive the passed slice. +pub unsafe fn bitmap(data: &[u8], offset: usize, length: usize) -> Result { + if offset >= 8 { + return Err(Error::InvalidArgumentError("offset should be < 8".into())); + }; + if length > data.len() * 8 - offset { + return Err(Error::InvalidArgumentError("given length is oob".into())); + } + let null_count = 0; + let validity = None; + + let ptr = data.as_ptr(); + let data = Arc::new(data); + + // safety: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + length, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + Some(offset), + ); + let array = InternalArrowArray::new(array, DataType::Boolean); + + // safety: we just created a valid array + Ok(unsafe { BooleanArray::try_from_ffi(array) }.unwrap()) +} diff --git a/crates/nano-arrow/src/ffi/mod.rs b/crates/nano-arrow/src/ffi/mod.rs new file mode 100644 index 000000000000..b1a1ac3c1210 --- /dev/null +++ b/crates/nano-arrow/src/ffi/mod.rs @@ -0,0 +1,46 @@ +//! contains FFI bindings to import and export [`Array`](crate::array::Array) via +//! 
Arrow's [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +mod array; +mod bridge; +mod generated; +pub mod mmap; +mod schema; +mod stream; + +pub(crate) use array::{try_from, ArrowArrayRef, InternalArrowArray}; +pub use generated::{ArrowArray, ArrowArrayStream, ArrowSchema}; +pub use stream::{export_iterator, ArrowArrayStreamReader}; + +use self::schema::to_field; +use crate::array::Array; +use crate::datatypes::{DataType, Field}; +use crate::error::Result; + +/// Exports an [`Box`] to the C data interface. +pub fn export_array_to_c(array: Box) -> ArrowArray { + ArrowArray::new(bridge::align_to_c_data_interface(array)) +} + +/// Exports a [`Field`] to the C data interface. +pub fn export_field_to_c(field: &Field) -> ArrowSchema { + ArrowSchema::new(field) +} + +/// Imports a [`Field`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowSchema`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_field_from_c(field: &ArrowSchema) -> Result { + to_field(field) +} + +/// Imports an [`Array`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowArray`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_array_from_c( + array: ArrowArray, + data_type: DataType, +) -> Result> { + try_from(InternalArrowArray::new(array, data_type)) +} diff --git a/crates/nano-arrow/src/ffi/schema.rs b/crates/nano-arrow/src/ffi/schema.rs new file mode 100644 index 000000000000..332410b0b6c5 --- /dev/null +++ b/crates/nano-arrow/src/ffi/schema.rs @@ -0,0 +1,633 @@ +use std::collections::BTreeMap; +use std::convert::TryInto; +use std::ffi::{CStr, CString}; +use std::ptr; + +use super::ArrowSchema; +use crate::datatypes::{ + DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode, +}; +use crate::error::{Error, Result}; + +#[allow(dead_code)] +struct SchemaPrivateData { + name: CString, + format: CString, + metadata: Option>, + children_ptr: Box<[*mut ArrowSchema]>, + dictionary: Option<*mut ArrowSchema>, +} + +// callback used to drop [ArrowSchema] when it is exported. 
+unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) { + if schema.is_null() { + return; + } + let schema = &mut *schema; + + let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary { + let _ = Box::from_raw(ptr); + } + + schema.release = None; +} + +/// allocate (and hold) the children +fn schema_children(data_type: &DataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> { + match data_type { + DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + DataType::Map(field, is_sorted) => { + *flags += (*is_sorted as i64) * 4; + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + DataType::Struct(fields) | DataType::Union(fields, _, _) => fields + .iter() + .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field)))) + .collect::>(), + DataType::Extension(_, inner, _) => schema_children(inner, flags), + _ => Box::new([]), + } +} + +impl ArrowSchema { + /// creates a new [ArrowSchema] + pub(crate) fn new(field: &Field) -> Self { + let format = to_format(field.data_type()); + let name = field.name.clone(); + + let mut flags = field.is_nullable as i64 * 2; + + // note: this cannot be done along with the above because the above is fallible and this op leaks. + let children_ptr = schema_children(field.data_type(), &mut flags); + let n_children = children_ptr.len() as i64; + + let dictionary = if let DataType::Dictionary(_, values, is_ordered) = field.data_type() { + flags += *is_ordered as i64; + // we do not store field info in the dict values, so can't recover it all :( + let field = Field::new("", values.as_ref().clone(), true); + Some(Box::new(ArrowSchema::new(&field))) + } else { + None + }; + + let metadata = &field.metadata; + + let metadata = if let DataType::Extension(name, _, extension_metadata) = field.data_type() { + // append extension information. 
+ let mut metadata = metadata.clone(); + + // metadata + if let Some(extension_metadata) = extension_metadata { + metadata.insert( + "ARROW:extension:metadata".to_string(), + extension_metadata.clone(), + ); + } + + metadata.insert("ARROW:extension:name".to_string(), name.clone()); + + Some(metadata_to_bytes(&metadata)) + } else if !metadata.is_empty() { + Some(metadata_to_bytes(metadata)) + } else { + None + }; + + let name = CString::new(name).unwrap(); + let format = CString::new(format).unwrap(); + + let mut private = Box::new(SchemaPrivateData { + name, + format, + metadata, + children_ptr, + dictionary: dictionary.map(Box::into_raw), + }); + + // + Self { + format: private.format.as_ptr(), + name: private.name.as_ptr(), + metadata: private + .metadata + .as_ref() + .map(|x| x.as_ptr()) + .unwrap_or(std::ptr::null()) as *const ::std::os::raw::c_char, + flags, + n_children, + children: private.children_ptr.as_mut_ptr(), + dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_schema), + private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, + } + } + + /// create an empty [ArrowSchema] + pub fn empty() -> Self { + Self { + format: std::ptr::null_mut(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// returns the format of this schema. + pub(crate) fn format(&self) -> &str { + assert!(!self.format.is_null()); + // safe because the lifetime of `self.format` equals `self` + unsafe { CStr::from_ptr(self.format) } + .to_str() + .expect("The external API has a non-utf8 as format") + } + + /// returns the name of this schema. + /// + /// Since this field is optional, `""` is returned if it is not set (as per the spec). + pub(crate) fn name(&self) -> &str { + if self.name.is_null() { + return ""; + } + // safe because the lifetime of `self.name` equals `self` + unsafe { CStr::from_ptr(self.name) }.to_str().unwrap() + } + + pub(crate) fn child(&self, index: usize) -> &'static Self { + assert!(index < self.n_children as usize); + unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } + } + + pub(crate) fn dictionary(&self) -> Option<&'static Self> { + if self.dictionary.is_null() { + return None; + }; + Some(unsafe { self.dictionary.as_ref().unwrap() }) + } + + pub(crate) fn nullable(&self) -> bool { + (self.flags / 2) & 1 == 1 + } +} + +impl Drop for ArrowSchema { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> Result { + let dictionary = schema.dictionary(); + let data_type = if let Some(dictionary) = dictionary { + let indices = to_integer_type(schema.format())?; + let values = to_field(dictionary)?; + let is_ordered = schema.flags & 1 == 1; + DataType::Dictionary(indices, Box::new(values.data_type().clone()), is_ordered) + } else { + to_data_type(schema)? 
+ }; + let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) }; + + let data_type = if let Some((name, extension_metadata)) = extension { + DataType::Extension(name, Box::new(data_type), extension_metadata) + } else { + data_type + }; + + Ok(Field::new(schema.name(), data_type, schema.nullable()).with_metadata(metadata)) +} + +fn to_integer_type(format: &str) -> Result { + use IntegerType::*; + Ok(match format { + "c" => Int8, + "C" => UInt8, + "s" => Int16, + "S" => UInt16, + "i" => Int32, + "I" => UInt32, + "l" => Int64, + "L" => UInt64, + _ => { + return Err(Error::OutOfSpec( + "Dictionary indices can only be integers".to_string(), + )) + }, + }) +} + +unsafe fn to_data_type(schema: &ArrowSchema) -> Result { + Ok(match schema.format() { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), + "tiM" => DataType::Interval(IntervalUnit::YearMonth), + "tiD" => DataType::Interval(IntervalUnit::DayTime), + "+l" => { + let child = schema.child(0); + DataType::List(Box::new(to_field(child)?)) + }, + "+L" => { + let child = schema.child(0); + DataType::LargeList(Box::new(to_field(child)?)) + }, + "+m" => { + let child = schema.child(0); + + let is_sorted = (schema.flags & 4) != 0; + DataType::Map(Box::new(to_field(child)?), is_sorted) + }, + "+s" => { + let children = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + DataType::Struct(children) + }, + other => { + match other.splitn(2, ':').collect::>()[..] 
{ + // Timestamps with no timezone + ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + + // Timestamps with timezone + ["tss", tz] => DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())), + ["tsm", tz] => DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())), + ["tsu", tz] => DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())), + ["tsn", tz] => DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())), + + ["w", size_raw] => { + // Example: "w:42" fixed-width binary [42 bytes] + let size = size_raw + .parse::() + .map_err(|_| Error::OutOfSpec("size is not a valid integer".to_string()))?; + DataType::FixedSizeBinary(size) + }, + ["+w", size_raw] => { + // Example: "+w:123" fixed-sized list [123 items] + let size = size_raw + .parse::() + .map_err(|_| Error::OutOfSpec("size is not a valid integer".to_string()))?; + let child = to_field(schema.child(0))?; + DataType::FixedSizeList(Box::new(child), size) + }, + ["d", raw] => { + // Decimal + let (precision, scale) = match raw.split(',').collect::>()[..] { + [precision_raw, scale_raw] => { + // Example: "d:19,10" decimal128 [precision 19, scale 10] + (precision_raw, scale_raw) + }, + [precision_raw, scale_raw, width_raw] => { + // Example: "d:19,10,NNN" decimal bitwidth = NNN [precision 19, scale 10] + // Only bitwdth of 128 currently supported + let bit_width = width_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal bit width is not a valid integer".to_string(), + ) + })?; + if bit_width == 256 { + return Ok(DataType::Decimal256( + precision_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal precision is not a valid integer".to_string(), + ) + })?, + scale_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal scale is not a valid integer".to_string(), + ) + })?, + )); + } + (precision_raw, scale_raw) + }, + _ => { + return Err(Error::OutOfSpec( + "Decimal must contain 2 or 3 comma-separated values".to_string(), + )); + }, + }; + + DataType::Decimal( + precision.parse::().map_err(|_| { + Error::OutOfSpec("Decimal precision is not a valid integer".to_string()) + })?, + scale.parse::().map_err(|_| { + Error::OutOfSpec("Decimal scale is not a valid integer".to_string()) + })?, + ) + }, + [union_type @ "+us", union_parts] | [union_type @ "+ud", union_parts] => { + // union, sparse + // Example "+us:I,J,..." sparse union with type ids I,J... + // Example: "+ud:I,J,..." dense union with type ids I,J... 
+ let mode = UnionMode::sparse(union_type == "+us"); + let type_ids = union_parts + .split(',') + .map(|x| { + x.parse::().map_err(|_| { + Error::OutOfSpec("Union type id is not a valid integer".to_string()) + }) + }) + .collect::>>()?; + let fields = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + DataType::Union(fields, Some(type_ids), mode) + }, + _ => { + return Err(Error::OutOfSpec(format!( + "The datatype \"{other}\" is still not supported in Rust implementation", + ))); + }, + } + }, + }) +} + +/// the inverse of [to_field] +fn to_format(data_type: &DataType) -> String { + match data_type { + DataType::Null => "n".to_string(), + DataType::Boolean => "b".to_string(), + DataType::Int8 => "c".to_string(), + DataType::UInt8 => "C".to_string(), + DataType::Int16 => "s".to_string(), + DataType::UInt16 => "S".to_string(), + DataType::Int32 => "i".to_string(), + DataType::UInt32 => "I".to_string(), + DataType::Int64 => "l".to_string(), + DataType::UInt64 => "L".to_string(), + DataType::Float16 => "e".to_string(), + DataType::Float32 => "f".to_string(), + DataType::Float64 => "g".to_string(), + DataType::Binary => "z".to_string(), + DataType::LargeBinary => "Z".to_string(), + DataType::Utf8 => "u".to_string(), + DataType::LargeUtf8 => "U".to_string(), + DataType::Date32 => "tdD".to_string(), + DataType::Date64 => "tdm".to_string(), + DataType::Time32(TimeUnit::Second) => "tts".to_string(), + DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), + DataType::Time32(_) => { + unreachable!("Time32 is only supported for seconds and milliseconds") + }, + DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), + DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), + DataType::Time64(_) => { + unreachable!("Time64 is only supported for micro and nanoseconds") + }, + DataType::Duration(TimeUnit::Second) => "tDs".to_string(), + DataType::Duration(TimeUnit::Millisecond) => "tDm".to_string(), + DataType::Duration(TimeUnit::Microsecond) => "tDu".to_string(), + DataType::Duration(TimeUnit::Nanosecond) => "tDn".to_string(), + DataType::Interval(IntervalUnit::YearMonth) => "tiM".to_string(), + DataType::Interval(IntervalUnit::DayTime) => "tiD".to_string(), + DataType::Interval(IntervalUnit::MonthDayNano) => { + todo!("Spec for FFI for MonthDayNano still not defined.") + }, + DataType::Timestamp(unit, tz) => { + let unit = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "m", + TimeUnit::Microsecond => "u", + TimeUnit::Nanosecond => "n", + }; + format!( + "ts{}:{}", + unit, + tz.as_ref().map(|x| x.as_ref()).unwrap_or("") + ) + }, + DataType::Decimal(precision, scale) => format!("d:{precision},{scale}"), + DataType::Decimal256(precision, scale) => format!("d:{precision},{scale},256"), + DataType::List(_) => "+l".to_string(), + DataType::LargeList(_) => "+L".to_string(), + DataType::Struct(_) => "+s".to_string(), + DataType::FixedSizeBinary(size) => format!("w:{size}"), + DataType::FixedSizeList(_, size) => format!("+w:{size}"), + DataType::Union(f, ids, mode) => { + let sparsness = if mode.is_sparse() { 's' } else { 'd' }; + let mut r = format!("+u{sparsness}:"); + let ids = if let Some(ids) = ids { + ids.iter() + .fold(String::new(), |a, b| a + &b.to_string() + ",") + } else { + (0..f.len()).fold(String::new(), |a, b| a + &b.to_string() + ",") + }; + let ids = &ids[..ids.len() - 1]; // take away last "," + r.push_str(ids); + r + }, + DataType::Map(_, _) => "+m".to_string(), + DataType::Dictionary(index, _, _) => 
to_format(&(*index).into()), + DataType::Extension(_, inner, _) => to_format(inner.as_ref()), + } +} + +pub(super) fn get_child(data_type: &DataType, index: usize) -> Result { + match (index, data_type) { + (0, DataType::List(field)) => Ok(field.data_type().clone()), + (0, DataType::FixedSizeList(field, _)) => Ok(field.data_type().clone()), + (0, DataType::LargeList(field)) => Ok(field.data_type().clone()), + (0, DataType::Map(field, _)) => Ok(field.data_type().clone()), + (index, DataType::Struct(fields)) => Ok(fields[index].data_type().clone()), + (index, DataType::Union(fields, _, _)) => Ok(fields[index].data_type().clone()), + (index, DataType::Extension(_, subtype, _)) => get_child(subtype, index), + (child, data_type) => Err(Error::OutOfSpec(format!( + "Requested child {child} to type {data_type:?} that has no such child", + ))), + } +} + +fn metadata_to_bytes(metadata: &BTreeMap) -> Vec { + let a = (metadata.len() as i32).to_ne_bytes().to_vec(); + metadata.iter().fold(a, |mut acc, (key, value)| { + acc.extend((key.len() as i32).to_ne_bytes()); + acc.extend(key.as_bytes()); + acc.extend((value.len() as i32).to_ne_bytes()); + acc.extend(value.as_bytes()); + acc + }) +} + +unsafe fn read_ne_i32(ptr: *const u8) -> i32 { + let slice = std::slice::from_raw_parts(ptr, 4); + i32::from_ne_bytes(slice.try_into().unwrap()) +} + +unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str { + let slice = std::slice::from_raw_parts(ptr, len); + simdutf8::basic::from_utf8(slice).unwrap() +} + +unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) { + let mut data = data as *const u8; // u8 = i8 + if data.is_null() { + return (Metadata::default(), None); + }; + let len = read_ne_i32(data); + data = data.add(4); + + let mut result = BTreeMap::new(); + let mut extension_name = None; + let mut extension_metadata = None; + for _ in 0..len { + let key_len = read_ne_i32(data) as usize; + data = data.add(4); + let key = read_bytes(data, key_len); + data = data.add(key_len); + let value_len = read_ne_i32(data) as usize; + data = data.add(4); + let value = read_bytes(data, value_len); + data = data.add(value_len); + match key { + "ARROW:extension:name" => { + extension_name = Some(value.to_string()); + }, + "ARROW:extension:metadata" => { + extension_metadata = Some(value.to_string()); + }, + _ => { + result.insert(key.to_string(), value.to_string()); + }, + }; + } + let extension = extension_name.map(|name| (name, extension_metadata)); + (result, extension) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all() { + let mut dts = vec![ + DataType::Null, + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Decimal(5, 5), + DataType::Utf8, + DataType::LargeUtf8, + DataType::Binary, + DataType::LargeBinary, + DataType::FixedSizeBinary(2), + DataType::List(Box::new(Field::new("example", DataType::Boolean, false))), + DataType::FixedSizeList(Box::new(Field::new("example", DataType::Boolean, false)), 2), + DataType::LargeList(Box::new(Field::new("example", DataType::Boolean, false))), + DataType::Struct(vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + 
DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ]), + DataType::Map(Box::new(Field::new("a", DataType::Int64, true)), true), + DataType::Union( + vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ], + Some(vec![1, 2]), + UnionMode::Dense, + ), + DataType::Union( + vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ], + Some(vec![0, 1]), + UnionMode::Sparse, + ), + ]; + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + dts.push(DataType::Timestamp(time_unit, None)); + dts.push(DataType::Timestamp(time_unit, Some("00:00".to_string()))); + dts.push(DataType::Duration(time_unit)); + } + for interval_type in [ + IntervalUnit::DayTime, + IntervalUnit::YearMonth, + //IntervalUnit::MonthDayNano, // not yet defined on the C data interface + ] { + dts.push(DataType::Interval(interval_type)); + } + + for expected in dts { + let field = Field::new("a", expected.clone(), true); + let schema = ArrowSchema::new(&field); + let result = unsafe { super::to_data_type(&schema).unwrap() }; + assert_eq!(result, expected); + } + } +} diff --git a/crates/nano-arrow/src/ffi/stream.rs b/crates/nano-arrow/src/ffi/stream.rs new file mode 100644 index 000000000000..4776014bca54 --- /dev/null +++ b/crates/nano-arrow/src/ffi/stream.rs @@ -0,0 +1,226 @@ +use std::ffi::{CStr, CString}; +use std::ops::DerefMut; + +use super::{ + export_array_to_c, export_field_to_c, import_array_from_c, import_field_from_c, ArrowArray, + ArrowArrayStream, ArrowSchema, +}; +use crate::array::Array; +use crate::datatypes::Field; +use crate::error::Error; + +impl Drop for ArrowArrayStream { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +impl ArrowArrayStream { + /// Creates an empty [`ArrowArrayStream`] used to import from a producer. + pub fn empty() -> Self { + Self { + get_schema: None, + get_next: None, + get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + } + } +} + +unsafe fn handle_error(iter: &mut ArrowArrayStream) -> Error { + let error = unsafe { (iter.get_last_error.unwrap())(&mut *iter) }; + + if error.is_null() { + return Error::External( + "C stream".to_string(), + Box::new(Error::ExternalFormat("an unspecified error".to_string())), + ); + } + + let error = unsafe { CStr::from_ptr(error) }; + Error::External( + "C stream".to_string(), + Box::new(Error::ExternalFormat(error.to_str().unwrap().to_string())), + ) +} + +/// Implements an iterator of [`Array`] consumed from the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html). +pub struct ArrowArrayStreamReader> { + iter: Iter, + field: Field, +} + +impl> ArrowArrayStreamReader { + /// Returns a new [`ArrowArrayStreamReader`] + /// # Error + /// Errors iff the [`ArrowArrayStream`] is out of specification, + /// or was already released prior to calling this function. + /// # Safety + /// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream` + /// contains a valid Arrow C stream interface. 
+ /// In particular: + /// * The `ArrowArrayStream` fulfills the invariants of the C stream interface + /// * The schema `get_schema` produces fulfills the C data interface + pub unsafe fn try_new(mut iter: Iter) -> Result { + if iter.release.is_none() { + return Err(Error::InvalidArgumentError( + "The C stream was already released".to_string(), + )); + }; + + if iter.get_next.is_none() { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_next".to_string(), + )); + }; + + if iter.get_last_error.is_none() { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_last_error".to_string(), + )); + }; + + let mut field = ArrowSchema::empty(); + let status = if let Some(f) = iter.get_schema { + unsafe { (f)(&mut *iter, &mut field) } + } else { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_schema".to_string(), + )); + }; + + if status != 0 { + return Err(unsafe { handle_error(&mut iter) }); + } + + let field = unsafe { import_field_from_c(&field)? }; + + Ok(Self { iter, field }) + } + + /// Returns the field provided by the stream + pub fn field(&self) -> &Field { + &self.field + } + + /// Advances this iterator by one array + /// # Error + /// Errors iff: + /// * The C stream interface returns an error + /// * The C stream interface returns an invalid array (that we can identify, see Safety below) + /// # Safety + /// Calling this iterator's `next` assumes that the [`ArrowArrayStream`] produces arrow arrays + /// that fulfill the C data interface + pub unsafe fn next(&mut self) -> Option, Error>> { + let mut array = ArrowArray::empty(); + let status = unsafe { (self.iter.get_next.unwrap())(&mut *self.iter, &mut array) }; + + if status != 0 { + return Some(Err(unsafe { handle_error(&mut self.iter) })); + } + + // last paragraph of https://arrow.apache.org/docs/format/CStreamInterface.html#c.ArrowArrayStream.get_next + array.release?; + + // Safety: assumed from the C stream interface + unsafe { import_array_from_c(array, self.field.data_type.clone()) } + .map(Some) + .transpose() + } +} + +struct PrivateData { + iter: Box, Error>>>, + field: Field, + error: Option, +} + +unsafe extern "C" fn get_next(iter: *mut ArrowArrayStream, array: *mut ArrowArray) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + match private.iter.next() { + Some(Ok(item)) => { + // check that the array has the same data_type as field + let item_dt = item.data_type(); + let expected_dt = private.field.data_type(); + if item_dt != expected_dt { + private.error = Some(CString::new(format!("The iterator produced an item of data type {item_dt:?} but the producer expects data type {expected_dt:?}").as_bytes().to_vec()).unwrap()); + return 2001; // custom application specific error (since this is never a result of this interface) + } + + std::ptr::write(array, export_array_to_c(item)); + + private.error = None; + 0 + }, + Some(Err(err)) => { + private.error = Some(CString::new(err.to_string().as_bytes().to_vec()).unwrap()); + 2001 // custom application specific error (since this is never a result of this interface) + }, + None => { + let a = ArrowArray::empty(); + std::ptr::write_unaligned(array, a); + private.error = None; + 0 + }, + } +} + +unsafe extern "C" fn get_schema(iter: *mut ArrowArrayStream, schema: *mut ArrowSchema) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + std::ptr::write(schema, 
export_field_to_c(&private.field));
+    0
+}
+
+unsafe extern "C" fn get_last_error(iter: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char {
+    if iter.is_null() {
+        return std::ptr::null();
+    }
+    let private = &mut *((*iter).private_data as *mut PrivateData);
+
+    private
+        .error
+        .as_ref()
+        .map(|x| x.as_ptr())
+        .unwrap_or(std::ptr::null())
+}
+
+unsafe extern "C" fn release(iter: *mut ArrowArrayStream) {
+    if iter.is_null() {
+        return;
+    }
+    let _ = Box::from_raw((*iter).private_data as *mut PrivateData);
+    (*iter).release = None;
+    // private drops automatically
+}
+
+/// Exports an iterator to the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html)
+pub fn export_iterator(
+    iter: Box<dyn Iterator<Item = Result<Box<dyn Array>, Error>>>,
+    field: Field,
+) -> ArrowArrayStream {
+    let private_data = Box::new(PrivateData {
+        iter,
+        field,
+        error: None,
+    });
+
+    ArrowArrayStream {
+        get_schema: Some(get_schema),
+        get_next: Some(get_next),
+        get_last_error: Some(get_last_error),
+        release: Some(release),
+        private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void,
+    }
+}
diff --git a/crates/nano-arrow/src/io/README.md b/crates/nano-arrow/src/io/README.md
new file mode 100644
index 000000000000..a3c7599b8bdf
--- /dev/null
+++ b/crates/nano-arrow/src/io/README.md
@@ -0,0 +1,24 @@
+# IO module
+
+This document describes the overall design of this module.
+
+## Rules:
+
+- Each directory in this module corresponds to a specific format such as `csv` and `json`.
+- Directories that depend on external dependencies MUST be feature gated, with a feature named with the prefix `io_`.
+- Modules MUST re-export any API of external dependencies they require as part of their public API.
+  E.g.
+  - if a module has an API `write(writer: &mut csv::Writer, ...)`, it MUST contain `pub use csv::Writer;`.
+
+  The rationale is that adding this crate to `Cargo.toml` must be sufficient to use it.
+- Each directory SHOULD contain two directories, `read` and `write`, corresponding
+  to functionality for reading from and writing to the format, respectively.
+- The base module SHOULD contain `pub use read;` and `pub use write;`.
+- Implementations SHOULD separate reading of "data" from reading of "metadata". Examples:
+  - schema read or inference SHOULD be a separate function
+  - functions that read "data" SHOULD consume a schema typically pre-read.
+- Implementations SHOULD separate IO-bounded operations from CPU-bounded operations.
+  I.e. implementations SHOULD:
+  - contain functions that consume a `Read` implementor and output a "raw" struct, i.e. a struct that is e.g. compressed and serialized
+  - contain functions that consume a "raw" struct and convert it into Arrow.
+  - offer each of these functions as independent public APIs, so that consumers can decide how to balance CPU-bounds and IO-bounds.
diff --git a/crates/nano-arrow/src/io/avro/mod.rs b/crates/nano-arrow/src/io/avro/mod.rs
new file mode 100644
index 000000000000..bf7bda85f197
--- /dev/null
+++ b/crates/nano-arrow/src/io/avro/mod.rs
@@ -0,0 +1,42 @@
+//! Read and write from and to Apache Avro
+
+pub use avro_schema;
+
+impl From<avro_schema::error::Error> for crate::error::Error {
+    fn from(error: avro_schema::error::Error) -> Self {
+        Self::ExternalFormat(error.to_string())
+    }
+}
+
+pub mod read;
+pub mod write;
+
+// macros that can operate in sync and async code.
+macro_rules!
avro_decode { + ($reader:ident $($_await:tt)*) => { + { + let mut i = 0u64; + let mut buf = [0u8; 1]; + let mut j = 0; + loop { + if j > 9 { + // if j * 7 > 64 + return Err(Error::ExternalFormat( + "zigzag decoding failed - corrupt avro file".to_string(), + )); + } + $reader.read_exact(&mut buf[..])$($_await)*?; + i |= (u64::from(buf[0] & 0x7F)) << (j * 7); + if (buf[0] >> 7) == 0 { + break; + } else { + j += 1; + } + } + + Ok(i) + } + } +} + +pub(crate) use avro_decode; diff --git a/crates/nano-arrow/src/io/avro/read/deserialize.rs b/crates/nano-arrow/src/io/avro/read/deserialize.rs new file mode 100644 index 000000000000..6cafd9d8c4c1 --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/deserialize.rs @@ -0,0 +1,526 @@ +use std::convert::TryInto; + +use avro_schema::file::Block; +use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; + +use super::nested::*; +use super::util; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::types::months_days_ns; + +fn make_mutable( + data_type: &DataType, + avro_field: Option<&AvroSchema>, + capacity: usize, +) -> Result> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => { + Box::new(MutableBooleanArray::with_capacity(capacity)) as Box + }, + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }), + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::Utf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::Dictionary(_) => { + if let Some(AvroSchema::Enum(Enum { symbols, .. 
})) = avro_field { + let values = Utf8Array::::from_slice(symbols); + Box::new(FixedItemsUtf8Dictionary::with_capacity(values, capacity)) + as Box + } else { + unreachable!() + } + }, + _ => match data_type { + DataType::List(inner) => { + let values = make_mutable(inner.data_type(), None, 0)?; + Box::new(DynMutableListArray::::new_from( + values, + data_type.clone(), + capacity, + )) as Box + }, + DataType::FixedSizeBinary(size) => { + Box::new(MutableFixedSizeBinaryArray::with_capacity(*size, capacity)) + as Box + }, + DataType::Struct(fields) => { + let values = fields + .iter() + .map(|field| make_mutable(field.data_type(), None, capacity)) + .collect::>>()?; + Box::new(DynMutableStructArray::new(values, data_type.clone())) + as Box + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Deserializing type {other:#?} is still not implemented" + ))) + }, + }, + }) +} + +fn is_union_null_first(avro_field: &AvroSchema) -> bool { + if let AvroSchema::Union(schemas) = avro_field { + schemas[0] == AvroSchema::Null + } else { + unreachable!() + } +} + +fn deserialize_item<'a>( + array: &mut dyn MutableArray, + is_nullable: bool, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> Result<&'a [u8]> { + if is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + array.push_null(); + return Ok(block); + } + } + deserialize_value(array, avro_field, block) +} + +fn deserialize_value<'a>( + array: &mut dyn MutableArray, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> Result<&'a [u8]> { + let data_type = array.data_type(); + match data_type { + DataType::List(inner) => { + let is_nullable = inner.is_nullable; + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + // Arrays are encoded as a series of blocks. + loop { + // Each block consists of a long count value, followed by that many array items. + let len = util::zigzag_i64(&mut block)?; + let len = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let _ = util::zigzag_i64(&mut block)?; + + -len + } else { + len + }; + + // A block with count zero indicates the end of the array. + if len == 0 { + break; + } + + // Each item is encoded per the array’s item schema. + let values = array.mut_values(); + for _ in 0..len { + block = deserialize_item(values, is_nullable, avro_inner, block)?; + } + } + array.try_push_valid()?; + }, + DataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. 
})] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let is_nullable = inner_fields + .iter() + .map(|x| x.is_nullable) + .collect::>(); + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + + for (index, (field, is_nullable)) in fields.iter().zip(is_nullable.iter()).enumerate() { + let values = array.mut_values(index); + block = deserialize_item(values, *is_nullable, &field.schema, block)?; + } + array.try_push_valid()?; + }, + _ => match data_type.to_physical_type() { + PhysicalType::Boolean => { + let is_valid = block[0] == 1; + block = &block[1..]; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push(Some(is_valid)) + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let value = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int64 => { + let value = util::zigzag_i64(&mut block)?; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float32 => { + let value = + f32::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); + block = &block[std::mem::size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float64 => { + let value = + f64::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); + block = &block[std::mem::size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::MonthDayNano => { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + let data = &block[..12]; + block = &block[12..]; + + let value = months_days_ns::new( + i32::from_le_bytes([data[0], data[1], data[2], data[3]]), + i32::from_le_bytes([data[4], data[5], data[6], data[7]]), + i32::from_le_bytes([data[8], data[9], data[10], data[11]]) as i64 + * 1_000_000, + ); + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })? 
+ }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + if len > 16 { + return Err(Error::ExternalFormat( + "Avro decimal bytes return more than 16 bytes".to_string(), + )); + } + let mut bytes = [0u8; 16]; + bytes[..len].copy_from_slice(&block[..len]); + block = &block[len..]; + let data = i128::from_be_bytes(bytes) >> (8 * (16 - len)); + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + let data = simdutf8::basic::from_utf8(&block[..len])?; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + let data = &block[..len]; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)); + }, + PhysicalType::FixedSizeBinary => { + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + let len = array.size(); + let data = &block[..len]; + block = &block[len..]; + array.push(Some(data)); + }, + PhysicalType::Dictionary(_) => { + let index = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push_valid(index); + }, + _ => todo!(), + }, + }; + Ok(block) +} + +fn skip_item<'a>(field: &Field, avro_field: &AvroSchema, mut block: &'a [u8]) -> Result<&'a [u8]> { + if field.is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + return Ok(block); + } + } + match &field.data_type { + DataType::List(inner) => { + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + loop { + let len = util::zigzag_i64(&mut block)?; + let (len, bytes) = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let bytes = util::zigzag_i64(&mut block)?; + + (-len, Some(bytes)) + } else { + (len, None) + }; + + let bytes: Option = bytes + .map(|bytes| { + bytes + .try_into() + .map_err(|_| Error::oos("Avro block size negative or too large")) + }) + .transpose()?; + + if len == 0 { + break; + } + + if let Some(bytes) = bytes { + block = &block[bytes..]; + } else { + for _ in 0..len { + block = skip_item(inner, avro_inner, block)?; + } + } + } + }, + DataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. 
})] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + for (field, avro_field) in inner_fields.iter().zip(fields.iter()) { + block = skip_item(field, &avro_field.schema, block)?; + } + }, + _ => match field.data_type.to_physical_type() { + PhysicalType::Boolean => { + let _ = block[0] == 1; + block = &block[1..]; + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Int64 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Float32 => { + block = &block[std::mem::size_of::()..]; + }, + PrimitiveType::Float64 => { + block = &block[std::mem::size_of::()..]; + }, + PrimitiveType::MonthDayNano => { + block = &block[12..]; + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })? + }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + block = &block[len..]; + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 | PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + block = &block[len..]; + }, + PhysicalType::FixedSizeBinary => { + let len = if let DataType::FixedSizeBinary(len) = &field.data_type { + *len + } else { + unreachable!() + }; + + block = &block[len..]; + }, + PhysicalType::Dictionary(_) => { + let _ = util::zigzag_i64(&mut block)? as i32; + }, + _ => todo!(), + }, + } + Ok(block) +} + +/// Deserializes a [`Block`] assumed to be encoded according to [`AvroField`] into [`Chunk`], +/// using `projection` to ignore `avro_fields`. +/// # Panics +/// `fields`, `avro_fields` and `projection` must have the same length. +pub fn deserialize( + block: &Block, + fields: &[Field], + avro_fields: &[AvroField], + projection: &[bool], +) -> Result>> { + assert_eq!(fields.len(), avro_fields.len()); + assert_eq!(fields.len(), projection.len()); + + let rows = block.number_of_rows; + let mut block = block.data.as_ref(); + + // create mutables, one per field + let mut arrays: Vec> = fields + .iter() + .zip(avro_fields.iter()) + .zip(projection.iter()) + .map(|((field, avro_field), projection)| { + if *projection { + make_mutable(&field.data_type, Some(&avro_field.schema), rows) + } else { + // just something; we are not going to use it + make_mutable(&DataType::Int32, None, 0) + } + }) + .collect::>()?; + + // this is _the_ expensive transpose (rows -> columns) + for _ in 0..rows { + let iter = arrays + .iter_mut() + .zip(fields.iter()) + .zip(avro_fields.iter()) + .zip(projection.iter()); + + for (((array, field), avro_field), projection) in iter { + block = if *projection { + deserialize_item(array.as_mut(), field.is_nullable, &avro_field.schema, block) + } else { + skip_item(field, &avro_field.schema, block) + }? 
+ } + } + Chunk::try_new( + arrays + .iter_mut() + .zip(projection.iter()) + .filter_map(|x| x.1.then(|| x.0)) + .map(|array| array.as_box()) + .collect(), + ) +} diff --git a/crates/nano-arrow/src/io/avro/read/mod.rs b/crates/nano-arrow/src/io/avro/read/mod.rs new file mode 100644 index 000000000000..5014499c12a6 --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/mod.rs @@ -0,0 +1,67 @@ +//! APIs to read from Avro format to arrow. +use std::io::Read; + +use avro_schema::file::FileMetadata; +use avro_schema::read::fallible_streaming_iterator::FallibleStreamingIterator; +use avro_schema::read::{block_iterator, BlockStreamingIterator}; +use avro_schema::schema::Field as AvroField; + +mod deserialize; +pub use deserialize::deserialize; +mod nested; +mod schema; +mod util; + +pub use schema::infer_schema; + +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Field; +use crate::error::Result; + +/// Single threaded, blocking reader of Avro; [`Iterator`] of [`Chunk`]. +pub struct Reader { + iter: BlockStreamingIterator, + avro_fields: Vec, + fields: Vec, + projection: Vec, +} + +impl Reader { + /// Creates a new [`Reader`]. + pub fn new( + reader: R, + metadata: FileMetadata, + fields: Vec, + projection: Option>, + ) -> Self { + let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect()); + + Self { + iter: block_iterator(reader, metadata.compression, metadata.marker), + avro_fields: metadata.record.fields, + fields, + projection, + } + } + + /// Deconstructs itself into its internal reader + pub fn into_inner(self) -> R { + self.iter.into_inner() + } +} + +impl Iterator for Reader { + type Item = Result>>; + + fn next(&mut self) -> Option { + let fields = &self.fields[..]; + let avro_fields = &self.avro_fields; + let projection = &self.projection; + + self.iter + .next() + .transpose() + .map(|maybe_block| deserialize(maybe_block?, fields, avro_fields, projection)) + } +} diff --git a/crates/nano-arrow/src/io/avro/read/nested.rs b/crates/nano-arrow/src/io/avro/read/nested.rs new file mode 100644 index 000000000000..fd5bb6b7dbbd --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/nested.rs @@ -0,0 +1,309 @@ +use crate::array::*; +use crate::bitmap::*; +use crate::datatypes::*; +use crate::error::*; +use crate::offset::{Offset, Offsets}; + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableListArray { + data_type: DataType, + offsets: Offsets, + values: Box, + validity: Option, +} + +impl DynMutableListArray { + pub fn new_from(values: Box, data_type: DataType, capacity: usize) -> Self { + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&data_type); + Self { + data_type, + offsets: Offsets::::with_capacity(capacity), + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self) -> &mut dyn MutableArray { + self.values.as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> Result<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length.checked_sub(offset).ok_or(Error::Overflow)?; + + self.offsets.try_push(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::new(); + 
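// Lazily create the validity bitmap: every slot pushed before was valid, so mark all slots true and then flip the one just pushed (the first null). +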
validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableListArray { + fn len(&self) -> usize { + self.offsets.len_proxy() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> std::sync::Arc { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +#[derive(Debug)] +pub struct FixedItemsUtf8Dictionary { + data_type: DataType, + keys: MutablePrimitiveArray, + values: Utf8Array, +} + +impl FixedItemsUtf8Dictionary { + pub fn with_capacity(values: Utf8Array, capacity: usize) -> Self { + Self { + data_type: DataType::Dictionary( + IntegerType::Int32, + Box::new(values.data_type().clone()), + false, + ), + keys: MutablePrimitiveArray::::with_capacity(capacity), + values, + } + } + + pub fn push_valid(&mut self, key: i32) { + self.keys.push(Some(key)) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } +} + +impl MutableArray for FixedItemsUtf8Dictionary { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn as_arc(&mut self) -> std::sync::Arc { + std::sync::Arc::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableStructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +impl DynMutableStructArray { + pub fn new(values: Vec>, data_type: DataType) -> Self { + Self { + data_type, + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self, field: usize) -> &mut dyn MutableArray { + self.values[field].as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> Result<()> { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.values.iter_mut().for_each(|x| x.push_null()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.len(); + + let mut validity = MutableBitmap::new(); + 
validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableStructArray { + fn len(&self) -> usize { + self.values[0].len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + Box::new(StructArray::new( + self.data_type.clone(), + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn as_arc(&mut self) -> std::sync::Arc { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + std::sync::Arc::new(StructArray::new( + self.data_type.clone(), + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} diff --git a/crates/nano-arrow/src/io/avro/read/schema.rs b/crates/nano-arrow/src/io/avro/read/schema.rs new file mode 100644 index 000000000000..ca50c59ca9fa --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/schema.rs @@ -0,0 +1,145 @@ +use avro_schema::schema::{Enum, Fixed, Record, Schema as AvroSchema}; + +use crate::datatypes::*; +use crate::error::{Error, Result}; + +fn external_props(schema: &AvroSchema) -> Metadata { + let mut props = Metadata::new(); + match &schema { + AvroSchema::Record(Record { + doc: Some(ref doc), .. + }) + | AvroSchema::Enum(Enum { + doc: Some(ref doc), .. + }) => { + props.insert("avro::doc".to_string(), doc.clone()); + }, + _ => {}, + } + props +} + +/// Infers an [`Schema`] from the root [`Record`]. +/// This +pub fn infer_schema(record: &Record) -> Result { + Ok(record + .fields + .iter() + .map(|field| { + schema_to_field( + &field.schema, + Some(&field.name), + external_props(&field.schema), + ) + }) + .collect::>>()? 
+ .into()) +} + +fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) -> Result { + let mut nullable = false; + let data_type = match schema { + AvroSchema::Null => DataType::Null, + AvroSchema::Boolean => DataType::Boolean, + AvroSchema::Int(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::IntLogical::Date => DataType::Date32, + avro_schema::schema::IntLogical::Time => DataType::Time32(TimeUnit::Millisecond), + }, + None => DataType::Int32, + }, + AvroSchema::Long(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::LongLogical::Time => DataType::Time64(TimeUnit::Microsecond), + avro_schema::schema::LongLogical::TimestampMillis => { + DataType::Timestamp(TimeUnit::Millisecond, Some("00:00".to_string())) + }, + avro_schema::schema::LongLogical::TimestampMicros => { + DataType::Timestamp(TimeUnit::Microsecond, Some("00:00".to_string())) + }, + avro_schema::schema::LongLogical::LocalTimestampMillis => { + DataType::Timestamp(TimeUnit::Millisecond, None) + }, + avro_schema::schema::LongLogical::LocalTimestampMicros => { + DataType::Timestamp(TimeUnit::Microsecond, None) + }, + }, + None => DataType::Int64, + }, + AvroSchema::Float => DataType::Float32, + AvroSchema::Double => DataType::Float64, + AvroSchema::Bytes(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::BytesLogical::Decimal(precision, scale) => { + DataType::Decimal(*precision, *scale) + }, + }, + None => DataType::Binary, + }, + AvroSchema::String(_) => DataType::Utf8, + AvroSchema::Array(item_schema) => DataType::List(Box::new(schema_to_field( + item_schema, + Some("item"), // default name for list items + Metadata::default(), + )?)), + AvroSchema::Map(_) => todo!("Avro maps are mapped to MapArrays"), + AvroSchema::Union(schemas) => { + // If there are only two variants and one of them is null, set the other type as the field data type + let has_nullable = schemas.iter().any(|x| x == &AvroSchema::Null); + if has_nullable && schemas.len() == 2 { + nullable = true; + if let Some(schema) = schemas + .iter() + .find(|&schema| !matches!(schema, AvroSchema::Null)) + { + schema_to_field(schema, None, Metadata::default())?.data_type + } else { + return Err(Error::NotYetImplemented(format!( + "Can't read avro union {schema:?}" + ))); + } + } else { + let fields = schemas + .iter() + .map(|s| schema_to_field(s, None, Metadata::default())) + .collect::>>()?; + DataType::Union(fields, None, UnionMode::Dense) + } + }, + AvroSchema::Record(Record { fields, .. }) => { + let fields = fields + .iter() + .map(|field| { + let mut props = Metadata::new(); + if let Some(doc) = &field.doc { + props.insert("avro::doc".to_string(), doc.clone()); + } + schema_to_field(&field.schema, Some(&field.name), props) + }) + .collect::>()?; + DataType::Struct(fields) + }, + AvroSchema::Enum { .. } => { + return Ok(Field::new( + name.unwrap_or_default(), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false), + false, + )) + }, + AvroSchema::Fixed(Fixed { size, logical, .. 
}) => match logical { + Some(logical) => match logical { + avro_schema::schema::FixedLogical::Decimal(precision, scale) => { + DataType::Decimal(*precision, *scale) + }, + avro_schema::schema::FixedLogical::Duration => { + DataType::Interval(IntervalUnit::MonthDayNano) + }, + }, + None => DataType::FixedSizeBinary(*size), + }, + }; + + let name = name.unwrap_or_default(); + + Ok(Field::new(name, data_type, nullable).with_metadata(props)) +} diff --git a/crates/nano-arrow/src/io/avro/read/util.rs b/crates/nano-arrow/src/io/avro/read/util.rs new file mode 100644 index 000000000000..a26ee0e005ee --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/util.rs @@ -0,0 +1,17 @@ +use std::io::Read; + +use super::super::avro_decode; +use crate::error::{Error, Result}; + +pub fn zigzag_i64(reader: &mut R) -> Result { + let z = decode_variable(reader)?; + Ok(if z & 0x1 == 0 { + (z >> 1) as i64 + } else { + !(z >> 1) as i64 + }) +} + +fn decode_variable(reader: &mut R) -> Result { + avro_decode!(reader) +} diff --git a/crates/nano-arrow/src/io/avro/write/mod.rs b/crates/nano-arrow/src/io/avro/write/mod.rs new file mode 100644 index 000000000000..6448782bb44e --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/mod.rs @@ -0,0 +1,28 @@ +//! APIs to write to Avro format. +use avro_schema::file::Block; + +mod schema; +pub use schema::to_record; +mod serialize; +pub use serialize::{can_serialize, new_serializer, BoxSerializer}; + +/// consumes a set of [`BoxSerializer`] into an [`Block`]. +/// # Panics +/// Panics iff the number of items in any of the serializers is not equal to the number of rows +/// declared in the `block`. +pub fn serialize(serializers: &mut [BoxSerializer], block: &mut Block) { + let Block { + data, + number_of_rows, + } = block; + + data.clear(); // restart it + + // _the_ transpose (columns -> rows) + for _ in 0..*number_of_rows { + for serializer in &mut *serializers { + let item_data = serializer.next().unwrap(); + data.extend(item_data); + } + } +} diff --git a/crates/nano-arrow/src/io/avro/write/schema.rs b/crates/nano-arrow/src/io/avro/write/schema.rs new file mode 100644 index 000000000000..b81cdc77ce3a --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/schema.rs @@ -0,0 +1,91 @@ +use avro_schema::schema::{ + BytesLogical, Field as AvroField, Fixed, FixedLogical, IntLogical, LongLogical, Record, + Schema as AvroSchema, +}; + +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// Converts a [`Schema`] to an Avro [`Record`]. +pub fn to_record(schema: &Schema) -> Result { + let mut name_counter: i32 = 0; + let fields = schema + .fields + .iter() + .map(|f| field_to_field(f, &mut name_counter)) + .collect::>()?; + Ok(Record { + name: "".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields, + }) +} + +fn field_to_field(field: &Field, name_counter: &mut i32) -> Result { + let schema = type_to_schema(field.data_type(), field.is_nullable, name_counter)?; + Ok(AvroField::new(&field.name, schema)) +} + +fn type_to_schema( + data_type: &DataType, + is_nullable: bool, + name_counter: &mut i32, +) -> Result { + Ok(if is_nullable { + AvroSchema::Union(vec![ + AvroSchema::Null, + _type_to_schema(data_type, name_counter)?, + ]) + } else { + _type_to_schema(data_type, name_counter)? 
+ }) +} + +fn _get_field_name(name_counter: &mut i32) -> String { + *name_counter += 1; + format!("r{name_counter}") +} + +fn _type_to_schema(data_type: &DataType, name_counter: &mut i32) -> Result { + Ok(match data_type.to_logical_type() { + DataType::Null => AvroSchema::Null, + DataType::Boolean => AvroSchema::Boolean, + DataType::Int32 => AvroSchema::Int(None), + DataType::Int64 => AvroSchema::Long(None), + DataType::Float32 => AvroSchema::Float, + DataType::Float64 => AvroSchema::Double, + DataType::Binary => AvroSchema::Bytes(None), + DataType::LargeBinary => AvroSchema::Bytes(None), + DataType::Utf8 => AvroSchema::String(None), + DataType::LargeUtf8 => AvroSchema::String(None), + DataType::LargeList(inner) | DataType::List(inner) => AvroSchema::Array(Box::new( + type_to_schema(&inner.data_type, inner.is_nullable, name_counter)?, + )), + DataType::Struct(fields) => AvroSchema::Record(Record::new( + _get_field_name(name_counter), + fields + .iter() + .map(|f| field_to_field(f, name_counter)) + .collect::>>()?, + )), + DataType::Date32 => AvroSchema::Int(Some(IntLogical::Date)), + DataType::Time32(TimeUnit::Millisecond) => AvroSchema::Int(Some(IntLogical::Time)), + DataType::Time64(TimeUnit::Microsecond) => AvroSchema::Long(Some(LongLogical::Time)), + DataType::Timestamp(TimeUnit::Millisecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMillis)) + }, + DataType::Timestamp(TimeUnit::Microsecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMicros)) + }, + DataType::Interval(IntervalUnit::MonthDayNano) => { + let mut fixed = Fixed::new("", 12); + fixed.logical = Some(FixedLogical::Duration); + AvroSchema::Fixed(fixed) + }, + DataType::FixedSizeBinary(size) => AvroSchema::Fixed(Fixed::new("", *size)), + DataType::Decimal(p, s) => AvroSchema::Bytes(Some(BytesLogical::Decimal(*p, *s))), + other => return Err(Error::NotYetImplemented(format!("write {other:?} to avro"))), + }) +} diff --git a/crates/nano-arrow/src/io/avro/write/serialize.rs b/crates/nano-arrow/src/io/avro/write/serialize.rs new file mode 100644 index 000000000000..888861db376a --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/serialize.rs @@ -0,0 +1,535 @@ +use avro_schema::schema::{Record, Schema as AvroSchema}; +use avro_schema::write::encode; + +use super::super::super::iterator::*; +use crate::array::*; +use crate::bitmap::utils::ZipValidity; +use crate::datatypes::{DataType, IntervalUnit, PhysicalType, PrimitiveType}; +use crate::offset::Offset; +use crate::types::months_days_ns; + +// Zigzag representation of false and true respectively. +const IS_NULL: u8 = 0; +const IS_VALID: u8 = 2; + +/// A type alias for a boxed [`StreamingIterator`], used to write arrays into avro rows +/// (i.e. 
a column -> row transposition of types known at run-time) +pub type BoxSerializer<'a> = Box + 'a + Send + Sync>; + +fn utf8_required(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + }, + vec![], + )) +} + +fn utf8_optional(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn binary_required(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn binary_optional(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn fixed_size_binary_required(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn fixed_size_binary_optional(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn list_required<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + }, + vec![], + )) +} + +fn list_optional<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + let lengths = ZipValidity::new_with_validity(lengths, array.validity()); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + if let Some(length) = length { + buf.push(IS_VALID); + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn struct_required<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + Box::new(BufStreamingIterator::new( + 0..array.len(), + move |_, buf| { + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + }, + 
vec![], + )) +} + +fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + let iterator = ZipValidity::new_with_validity(0..array.len(), array.validity()); + + Box::new(BufStreamingIterator::new( + iterator, + move |maybe, buf| { + if maybe.is_some() { + buf.push(IS_VALID); + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + } else { + buf.push(IS_NULL); + // skip the item + inner.iter_mut().for_each(|item| { + let _ = item.next().unwrap(); + }); + } + }, + vec![], + )) +} + +/// Creates a [`StreamingIterator`] trait object that presents items from `array` +/// encoded according to `schema`. +/// # Panic +/// This function panics iff the `data_type` is not supported (use [`can_serialize`] to check) +/// # Implementation +/// This function performs minimal CPU work: it dynamically dispatches based on the schema +/// and arrow type. +pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSerializer<'a> { + let data_type = array.data_type().to_physical_type(); + + match (data_type, schema) { + (PhysicalType::Boolean, AvroSchema::Boolean) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.values_iter(), + |x, buf| { + buf.push(x as u8); + }, + vec![], + )) + }, + (PhysicalType::Boolean, AvroSchema::Union(_)) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.extend_from_slice(&[IS_VALID, x as u8]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Utf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Utf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Union(_)) => { + fixed_size_binary_optional(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Fixed(_)) => { + fixed_size_binary_required(array.as_any().downcast_ref().unwrap()) + }, + + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x as i64, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Int(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + 
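// Required (non-union) Avro int: values are written as plain zigzag-encoded varints, with no union branch byte. +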
Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x as i64, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Long(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Float) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Double) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Bytes(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + let len = ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + let len = + ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Fixed(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + interval_write, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + 
interval_write(x, buf) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + + (PhysicalType::List, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::LargeList, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::List, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::LargeList, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::Struct, AvroSchema::Record(inner)) => { + struct_required(array.as_any().downcast_ref().unwrap(), inner) + }, + (PhysicalType::Struct, AvroSchema::Union(inner)) => { + let inner = if let AvroSchema::Record(inner) = &inner[1] { + inner + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + struct_optional(array.as_any().downcast_ref().unwrap(), inner) + }, + (a, b) => todo!("{:?} -> {:?} not supported", a, b), + } +} + +/// Whether [`new_serializer`] supports `data_type`. +pub fn can_serialize(data_type: &DataType) -> bool { + use DataType::*; + match data_type.to_logical_type() { + List(inner) => return can_serialize(&inner.data_type), + LargeList(inner) => return can_serialize(&inner.data_type), + Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.data_type)), + _ => {}, + }; + + matches!( + data_type, + Boolean + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _) + | Utf8 + | Binary + | FixedSizeBinary(_) + | LargeUtf8 + | LargeBinary + | Interval(IntervalUnit::MonthDayNano) + ) +} + +#[inline] +fn interval_write(x: &months_days_ns, buf: &mut Vec) { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + buf.reserve(12); + buf.extend(x.months().to_le_bytes()); + buf.extend(x.days().to_le_bytes()); + buf.extend(((x.ns() / 1_000_000) as i32).to_le_bytes()); +} diff --git a/crates/nano-arrow/src/io/flight/mod.rs b/crates/nano-arrow/src/io/flight/mod.rs new file mode 100644 index 000000000000..0cce1774568f --- /dev/null +++ b/crates/nano-arrow/src/io/flight/mod.rs @@ -0,0 +1,243 @@ +//! Serialization and deserialization to Arrow's flight protocol + +use arrow_format::flight::data::{FlightData, SchemaResult}; +use arrow_format::ipc; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::ipc::read::Dictionaries; +pub use super::ipc::write::default_ipc_fields; +use super::ipc::{IpcField, IpcSchema}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +pub use crate::io::ipc::write::common::WriteOptions; +use crate::io::ipc::write::common::{encode_chunk, DictionaryTracker, EncodedData}; +use crate::io::ipc::{read, write}; + +/// Serializes [`Chunk`] to a vector of [`FlightData`] representing the serialized dictionaries +/// and a [`FlightData`] representing the batch. 
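+///
+/// A minimal usage sketch (not part of the original code; it assumes a `schema` and a matching
+/// `chunk` already exist, and uses the `arrow2::` paths that the other module docs use):
+/// ```ignore
+/// use arrow2::io::flight::{default_ipc_fields, serialize_batch, WriteOptions};
+///
+/// // one IpcField per column, in the same order as the schema
+/// let fields = default_ipc_fields(&schema.fields);
+/// let options = WriteOptions { compression: None };
+/// // one FlightData message per dictionary, plus one for the record batch itself
+/// let (dictionaries, batch) = serialize_batch(&chunk, &fields, &options)?;
+/// ```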
+/// # Errors +/// This function errors iff `fields` is not consistent with `columns` +pub fn serialize_batch( + chunk: &Chunk>, + fields: &[IpcField], + options: &WriteOptions, +) -> Result<(Vec, FlightData)> { + if fields.len() != chunk.arrays().len() { + return Err(Error::InvalidArgumentError("The argument `fields` must be consistent with the columns' schema. Use e.g. &arrow2::io::flight::default_ipc_fields(&schema.fields)".to_string())); + } + + let mut dictionary_tracker = DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }; + + let (encoded_dictionaries, encoded_batch) = + encode_chunk(chunk, fields, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + Ok((flight_dictionaries, flight_batch)) +} + +impl From for FlightData { + fn from(data: EncodedData) -> Self { + FlightData { + data_header: data.ipc_message, + data_body: data.arrow_data, + ..Default::default() + } + } +} + +/// Serializes a [`Schema`] to [`SchemaResult`]. +pub fn serialize_schema_to_result( + schema: &Schema, + ipc_fields: Option<&[IpcField]>, +) -> SchemaResult { + SchemaResult { + schema: _serialize_schema(schema, ipc_fields), + } +} + +/// Serializes a [`Schema`] to [`FlightData`]. +pub fn serialize_schema(schema: &Schema, ipc_fields: Option<&[IpcField]>) -> FlightData { + FlightData { + data_header: _serialize_schema(schema, ipc_fields), + ..Default::default() + } +} + +/// Convert a [`Schema`] to bytes in the format expected in [`arrow_format::flight::data::FlightInfo`]. +pub fn serialize_schema_to_info( + schema: &Schema, + ipc_fields: Option<&[IpcField]>, +) -> Result> { + let encoded_data = if let Some(ipc_fields) = ipc_fields { + schema_as_encoded_data(schema, ipc_fields) + } else { + let ipc_fields = default_ipc_fields(&schema.fields); + schema_as_encoded_data(schema, &ipc_fields) + }; + + let mut schema = vec![]; + write::common_sync::write_message(&mut schema, &encoded_data)?; + Ok(schema) +} + +fn _serialize_schema(schema: &Schema, ipc_fields: Option<&[IpcField]>) -> Vec { + if let Some(ipc_fields) = ipc_fields { + write::schema_to_bytes(schema, ipc_fields) + } else { + let ipc_fields = default_ipc_fields(&schema.fields); + write::schema_to_bytes(schema, &ipc_fields) + } +} + +fn schema_as_encoded_data(schema: &Schema, ipc_fields: &[IpcField]) -> EncodedData { + EncodedData { + ipc_message: write::schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + } +} + +/// Deserialize an IPC message into [`Schema`], [`IpcSchema`]. +/// Use to deserialize [`FlightData::data_header`] and [`SchemaResult::schema`]. +pub fn deserialize_schemas(bytes: &[u8]) -> Result<(Schema, IpcSchema)> { + read::deserialize_schema(bytes) +} + +/// Deserializes [`FlightData`] representing a record batch message to [`Chunk`]. 
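+///
+/// A rough sketch of the intended call sequence (illustrative only; `schema_msg`, `dict_msg` and
+/// `batch_msg` are placeholder [`FlightData`] values, not part of this API):
+/// ```ignore
+/// use arrow2::io::flight::{deserialize_batch, deserialize_dictionary, deserialize_schemas};
+///
+/// // 1. decode the schema message that the producer sends first
+/// let (schema, ipc_schema) = deserialize_schemas(&schema_msg.data_header)?;
+/// // 2. upsert any dictionary messages into the dictionary cache
+/// let mut dictionaries = Default::default();
+/// deserialize_dictionary(&dict_msg, &schema.fields, &ipc_schema, &mut dictionaries)?;
+/// // 3. decode each record batch message into a Chunk
+/// let chunk = deserialize_batch(&batch_msg, &schema.fields, &ipc_schema, &dictionaries)?;
+/// ```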
+pub fn deserialize_batch( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &read::Dictionaries, +) -> Result>> { + // check that the data_header is a record batch message + let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header) + .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {err:?}")))?; + + let length = data.data_body.len(); + let mut reader = std::io::Cursor::new(&data.data_body); + + match message.header()?.ok_or_else(|| { + Error::oos("Unable to convert flight data header to a record batch".to_string()) + })? { + ipc::MessageHeaderRef::RecordBatch(batch) => read::read_record_batch( + batch, + fields, + ipc_schema, + None, + None, + dictionaries, + message.version()?, + &mut reader, + 0, + length as u64, + &mut Default::default(), + ), + _ => Err(Error::nyi( + "flight currently only supports reading RecordBatch messages", + )), + } +} + +/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`. +pub fn deserialize_dictionary( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut read::Dictionaries, +) -> Result<()> { + let message = ipc::MessageRef::read_as_root(&data.data_header)?; + + let chunk = if let ipc::MessageHeaderRef::DictionaryBatch(chunk) = message + .header()? + .ok_or_else(|| Error::oos("Header is required"))? + { + chunk + } else { + return Ok(()); + }; + + let length = data.data_body.len(); + let mut reader = std::io::Cursor::new(&data.data_body); + read::read_dictionary( + chunk, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + + Ok(()) +} + +/// Deserializes [`FlightData`] into either a [`Chunk`] (when the message is a record batch) +/// or by upserting into `dictionaries` (when the message is a dictionary) +pub fn deserialize_message( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, +) -> Result>>> { + let FlightData { + data_header, + data_body, + .. + } = data; + + let message = arrow_format::ipc::MessageRef::read_as_root(data_header)?; + let header = message + .header()? + .ok_or_else(|| Error::oos("IPC Message must contain a header"))?; + + match header { + ipc::MessageHeaderRef::RecordBatch(batch) => { + let length = data_body.len(); + let mut reader = std::io::Cursor::new(data_body); + + let chunk = read::read_record_batch( + batch, + fields, + ipc_schema, + None, + None, + dictionaries, + arrow_format::ipc::MetadataVersion::V5, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + + Ok(chunk.into()) + }, + ipc::MessageHeaderRef::DictionaryBatch(dict_batch) => { + let length = data_body.len(); + let mut reader = std::io::Cursor::new(data_body); + + read::read_dictionary( + dict_batch, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + Ok(None) + }, + t => Err(Error::nyi(format!( + "Reading types other than record batches not yet supported, unable to read {t:?}" + ))), + } +} diff --git a/crates/nano-arrow/src/io/ipc/append/mod.rs b/crates/nano-arrow/src/io/ipc/append/mod.rs new file mode 100644 index 000000000000..1acb39a931ef --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/append/mod.rs @@ -0,0 +1,72 @@ +//! 
A struct adapter of Read+Seek+Write to append to IPC files +// read header and convert to writer information +// seek to first byte of header - 1 +// write new batch +// write new footer +use std::io::{Read, Seek, SeekFrom, Write}; + +use super::endianness::is_native_little_endian; +use super::read::{self, FileMetadata}; +use super::write::common::DictionaryTracker; +use super::write::writer::*; +use super::write::*; +use crate::error::{Error, Result}; + +impl FileWriter { + /// Creates a new [`FileWriter`] from an existing file, seeking to the last message + /// and appending new messages afterwards. Users call `finish` to write the footer (with both) + /// the existing and appended messages on it. + /// # Error + /// This function errors iff: + /// * the file's endianness is not the native endianness (not yet supported) + /// * the file is not a valid Arrow IPC file + pub fn try_from_file( + mut writer: R, + metadata: FileMetadata, + options: WriteOptions, + ) -> Result> { + if metadata.ipc_schema.is_little_endian != is_native_little_endian() { + return Err(Error::nyi( + "Appending to a file of a non-native endianness is still not supported", + )); + } + + let dictionaries = + read::read_file_dictionaries(&mut writer, &metadata, &mut Default::default())?; + + let last_block = metadata.blocks.last().ok_or_else(|| { + Error::oos("An Arrow IPC file must have at least 1 message (the schema message)") + })?; + let offset: u64 = last_block + .offset + .try_into() + .map_err(|_| Error::oos("The block's offset must be a positive number"))?; + let meta_data_length: u64 = last_block + .meta_data_length + .try_into() + .map_err(|_| Error::oos("The block's meta length must be a positive number"))?; + let body_length: u64 = last_block + .body_length + .try_into() + .map_err(|_| Error::oos("The block's body length must be a positive number"))?; + let offset: u64 = offset + meta_data_length + body_length; + + writer.seek(SeekFrom::Start(offset))?; + + Ok(FileWriter { + writer, + options, + schema: metadata.schema, + ipc_fields: metadata.ipc_schema.fields, + block_offsets: offset as usize, + dictionary_blocks: metadata.dictionaries.unwrap_or_default(), + record_blocks: metadata.blocks, + state: State::Started, // file already exists, so we are ready + dictionary_tracker: DictionaryTracker { + dictionaries, + cannot_replace: true, + }, + encoded_message: Default::default(), + }) + } +} diff --git a/crates/nano-arrow/src/io/ipc/compression.rs b/crates/nano-arrow/src/io/ipc/compression.rs new file mode 100644 index 000000000000..9a69deb8248a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/compression.rs @@ -0,0 +1,91 @@ +use crate::error::Result; + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_lz4(input_buf: &[u8], output_buf: &mut [u8]) -> Result<()> { + use std::io::Read; + let mut decoder = lz4::Decoder::new(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_zstd(input_buf: &[u8], output_buf: &mut [u8]) -> Result<()> { + use std::io::Read; + let mut decoder = zstd::Decoder::new(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_lz4(_input_buf: &[u8], _output_buf: &mut [u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. 
Use `io_ipc_compression` to read compressed IPC.".to_string())) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_zstd(_input_buf: &[u8], _output_buf: &mut [u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to read compressed IPC.".to_string())) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_lz4(input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + use std::io::Write; + + use crate::error::Error; + let mut encoder = lz4::EncoderBuilder::new() + .build(output_buf) + .map_err(Error::from)?; + encoder.write_all(input_buf)?; + encoder.finish().1.map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_zstd(input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + zstd::stream::copy_encode(input_buf, output_buf, 0).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_lz4(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string())) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_zstd(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // ZSTD uses foreign calls that miri does not support + fn round_trip_zstd() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_zstd(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_zstd(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // LZ4 uses foreign calls that miri does not support + fn round_trip_lz4() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_lz4(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_lz4(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } +} diff --git a/crates/nano-arrow/src/io/ipc/endianness.rs b/crates/nano-arrow/src/io/ipc/endianness.rs new file mode 100644 index 000000000000..61b3f9b7c51c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/endianness.rs @@ -0,0 +1,11 @@ +#[cfg(target_endian = "little")] +#[inline] +pub fn is_native_little_endian() -> bool { + true +} + +#[cfg(target_endian = "big")] +#[inline] +pub fn is_native_little_endian() -> bool { + false +} diff --git a/crates/nano-arrow/src/io/ipc/mod.rs b/crates/nano-arrow/src/io/ipc/mod.rs new file mode 100644 index 000000000000..7da03e5c0abb --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/mod.rs @@ -0,0 +1,104 @@ +//! APIs to read from and write to Arrow's IPC format. +//! +//! Inter-process communication is a method through which different processes +//! share and pass data between them. Its use-cases include parallel +//! processing of chunks of data across different CPU cores, transferring +//! data between different Apache Arrow implementations in other languages and +//! more. 
Under the hood Apache Arrow uses [FlatBuffers](https://google.github.io/flatbuffers/) +//! as its binary protocol, so every Arrow-centered streaming or serialization +//! problem that could be solved using FlatBuffers could probably be solved +//! using the more integrated approach that is exposed in this module. +//! +//! [Arrow's IPC protocol](https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc) +//! allows only batch or dictionary columns to be passed +//! around due to its reliance on a pre-defined data scheme. This constraint +//! provides a large performance gain because serialized data will always have a +//! known structure, i.e. the same fields and datatypes, with the only variance +//! being the number of rows and the actual data inside the Batch. This dramatically +//! increases the deserialization rate, as the bytes in the file or stream are already +//! structured "correctly". +//! +//! Reading and writing IPC messages is done using one of two variants - either +//! [`FileReader`](read::FileReader) <-> [`FileWriter`](struct@write::FileWriter) or +//! [`StreamReader`](read::StreamReader) <-> [`StreamWriter`](struct@write::StreamWriter). +//! These two variants wrap a type `T` that implements [`Read`](std::io::Read), and in +//! the case of the `File` variant it also implements [`Seek`](std::io::Seek). In +//! practice it means that `File`s can be arbitrarily accessed while `Stream`s are only +//! read in a certain order - the one they were written in (first in, first out). +//! +//! # Examples +//! Read and write to a file: +//! ``` +//! use arrow2::io::ipc::{{read::{FileReader, read_file_metadata}}, {write::{FileWriter, WriteOptions}}}; +//! # use std::fs::File; +//! # use arrow2::datatypes::{Field, Schema, DataType}; +//! # use arrow2::array::{Int32Array, Array}; +//! # use arrow2::chunk::Chunk; +//! # use arrow2::error::Error; +//! // Setup the writer +//! let path = "example.arrow".to_string(); +//! let mut file = File::create(&path)?; +//! let x_coord = Field::new("x", DataType::Int32, false); +//! let y_coord = Field::new("y", DataType::Int32, false); +//! let schema = Schema::from(vec![x_coord, y_coord]); +//! let options = WriteOptions {compression: None}; +//! let mut writer = FileWriter::try_new(file, schema, None, options)?; +//! +//! // Setup the data +//! let x_data = Int32Array::from_slice([-1i32, 1]); +//! let y_data = Int32Array::from_slice([1i32, -1]); +//! let chunk = Chunk::try_new(vec![x_data.boxed(), y_data.boxed()])?; +//! +//! // Write the messages and finalize the stream +//! for _ in 0..5 { +//! writer.write(&chunk, None)?; +//! } +//! writer.finish()?; +//! +//! // Fetch some of the data and get the reader back +//! let mut reader = File::open(&path)?; +//! let metadata = read_file_metadata(&mut reader)?; +//! let mut reader = FileReader::new(reader, metadata, None, None); +//! let row1 = reader.next().unwrap(); // [[-1, 1], [1, -1]] +//! let row2 = reader.next().unwrap(); // [[-1, 1], [1, -1]] +//! let mut reader = reader.into_inner(); +//! // Do more stuff with the reader, like seeking ahead. +//! # Ok::<(), Error>(()) +//! ``` +//! +//! For further information and examples please consult the +//! [user guide](https://jorgecarleitao.github.io/arrow2/io/index.html). +//! For even more examples check the `examples` folder in the main repository +//! ([1](https://github.com/jorgecarleitao/arrow2/blob/main/examples/ipc_file_read.rs), +//! 
[2](https://github.com/jorgecarleitao/arrow2/blob/main/examples/ipc_file_write.rs), +//! [3](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow)). + +mod compression; +mod endianness; + +pub mod append; +pub mod read; +pub mod write; + +const ARROW_MAGIC_V1: [u8; 4] = [b'F', b'E', b'A', b'1']; +const ARROW_MAGIC_V2: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +pub(crate) const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Struct containing `dictionary_id` and nested `IpcField`, allowing users +/// to specify the dictionary ids of the IPC fields when writing to IPC. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct IpcField { + /// optional children + pub fields: Vec, + /// dictionary id + pub dictionary_id: Option, +} + +/// Struct containing fields and whether the file is written in little or big endian. +#[derive(Debug, Clone, PartialEq)] +pub struct IpcSchema { + /// The fields in the schema + pub fields: Vec, + /// Endianness of the file + pub is_little_endian: bool, +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/binary.rs b/crates/nano-arrow/src/io/ipc/read/array/binary.rs new file mode 100644 index 000000000000..52a5c4b7b7b0 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/binary.rs @@ -0,0 +1,91 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::BinaryArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_binary( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + BinaryArray::::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for binary. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/boolean.rs b/crates/nano-arrow/src/io/ipc/read/array/boolean.rs new file mode 100644 index 000000000000..6d78c184b168 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/boolean.rs @@ -0,0 +1,72 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::BooleanArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_boolean( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let values = read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + BooleanArray::try_new(data_type, values, validity) +} + +pub fn skip_boolean( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for boolean. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs b/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs new file mode 100644 index 000000000000..554e6d32dcbf --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs @@ -0,0 +1,65 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use ahash::HashSet; + +use super::super::{Compression, Dictionaries, IpcBuffer, Node}; +use super::{read_primitive, skip_primitive}; +use crate::array::{DictionaryArray, DictionaryKey}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + field_nodes: &mut VecDeque, + data_type: DataType, + id: Option, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + compression: Option, + limit: Option, + is_little_endian: bool, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let id = if let Some(id) = id { + id + } else { + return Err(Error::OutOfSpec("Dictionary has no id.".to_string())); + }; + let values = dictionaries + .get(&id) + .ok_or_else(|| { + let valid_ids = dictionaries.keys().collect::>(); + Error::OutOfSpec(format!( + "Dictionary id {id} not found. Valid ids: {valid_ids:?}" + )) + })? + .clone(); + + let keys = read_primitive( + field_nodes, + T::PRIMITIVE.into(), + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + DictionaryArray::::try_new(data_type, keys, values) +} + +pub fn skip_dictionary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + skip_primitive(field_nodes, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs new file mode 100644 index 000000000000..ed0d0049ffb2 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs @@ -0,0 +1,76 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::FixedSizeBinaryArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_binary( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." 
+ )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let length = length.saturating_mul(FixedSizeBinaryArray::maybe_get_size(&data_type)?); + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + FixedSizeBinaryArray::try_new(data_type, values, validity) +} + +pub fn skip_fixed_size_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos( + "IPC: unable to fetch the field for fixed-size binary. The file or stream is corrupted.", + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs new file mode 100644 index 000000000000..5553c1f478ff --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -0,0 +1,83 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::FixedSizeListArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_list( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); + + let limit = limit.map(|x| x.saturating_mul(size)); + + let values = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + )?; + FixedSizeListArray::try_new(data_type, values, validity) +} + +pub fn skip_fixed_size_list( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos( + "IPC: unable to fetch the field for fixed-size list. 
The file or stream is corrupted.", + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + + let (field, _) = FixedSizeListArray::get_child_and_size(data_type); + + skip(field_nodes, field.data_type(), buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/list.rs b/crates/nano-arrow/src/io/ipc/read/array/list.rs new file mode 100644 index 000000000000..83809cf995c1 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/list.rs @@ -0,0 +1,108 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::ListArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_list( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + + let field = ListArray::::get_child_field(&data_type); + + let values = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + ListArray::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_list( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for list. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + + let data_type = ListArray::::get_child_type(data_type); + + skip(field_nodes, data_type, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/map.rs b/crates/nano-arrow/src/io/ipc/read/array/map.rs new file mode 100644 index 000000000000..cf383407a8c0 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/map.rs @@ -0,0 +1,103 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::MapArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_map( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![0i32])))?; + + let field = MapArray::get_field(&data_type); + + let last_offset: usize = offsets.last().copied().unwrap() as usize; + + let field = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + MapArray::try_new(data_type, offsets.try_into()?, field, validity) +} + +pub fn skip_map( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for map. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + + let data_type = MapArray::get_field(data_type).data_type(); + + skip(field_nodes, data_type, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/mod.rs b/crates/nano-arrow/src/io/ipc/read/array/mod.rs new file mode 100644 index 000000000000..249e5e05e165 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/mod.rs @@ -0,0 +1,24 @@ +mod primitive; +pub use primitive::*; +mod boolean; +pub use boolean::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod fixed_size_binary; +pub use fixed_size_binary::*; +mod list; +pub use list::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod struct_; +pub use struct_::*; +mod null; +pub use null::*; +mod dictionary; +pub use dictionary::*; +mod union; +pub use union::*; +mod map; +pub use map::*; diff --git a/crates/nano-arrow/src/io/ipc/read/array/null.rs b/crates/nano-arrow/src/io/ipc/read/array/null.rs new file mode 100644 index 000000000000..e56f1886112d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/null.rs @@ -0,0 +1,28 @@ +use std::collections::VecDeque; + +use super::super::{Node, OutOfSpecKind}; +use crate::array::NullArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +pub fn read_null(field_nodes: &mut VecDeque, data_type: DataType) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + NullArray::try_new(data_type, length) +} + +pub fn skip_null(field_nodes: &mut VecDeque) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for null. The file or stream is corrupted.") + })?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/primitive.rs b/crates/nano-arrow/src/io/ipc/read/array/primitive.rs new file mode 100644 index 000000000000..d6ccb581ffe5 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/primitive.rs @@ -0,0 +1,77 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::PrimitiveArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::types::NativeType; + +#[allow(clippy::too_many_arguments)] +pub fn read_primitive( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." 
+ )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + PrimitiveArray::::try_new(data_type, values, validity) +} + +pub fn skip_primitive( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for primitive. The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/struct_.rs b/crates/nano-arrow/src/io/ipc/read/array/struct_.rs new file mode 100644 index 000000000000..9a5084a8783f --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/struct_.rs @@ -0,0 +1,88 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::StructArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_struct( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let fields = StructArray::get_fields(&data_type); + + let values = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + }) + .collect::>>()?; + + StructArray::try_new(data_type, values, validity) +} + +pub fn skip_struct( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for struct. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + + let fields = StructArray::get_fields(data_type); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/union.rs b/crates/nano-arrow/src/io/ipc/read/array/union.rs new file mode 100644 index 000000000000..ac1eb9b02527 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/union.rs @@ -0,0 +1,125 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::UnionArray; +use crate::datatypes::DataType; +use crate::datatypes::UnionMode::Dense; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_union( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + if version != Version::V5 { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + }; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let types = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + let offsets = if let DataType::Union(_, _, mode) = data_type { + if !mode.is_sparse() { + Some(read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + None + } + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(&data_type); + + let fields = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + None, + version, + scratch, + ) + }) + .collect::>>()?; + + UnionArray::try_new(data_type, types, fields, offsets) +} + +pub fn skip_union( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for struct. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + if let DataType::Union(_, _, Dense) = data_type { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(data_type); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/utf8.rs b/crates/nano-arrow/src/io/ipc/read/array/utf8.rs new file mode 100644 index 000000000000..21e54480e48e --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/utf8.rs @@ -0,0 +1,92 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::Utf8Array; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_utf8( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + Utf8Array::::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_utf8( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for utf8. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/common.rs b/crates/nano-arrow/src/io/ipc/read/common.rs new file mode 100644 index 000000000000..f890562ed41c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/common.rs @@ -0,0 +1,363 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use ahash::AHashMap; +use arrow_format; + +use super::deserialize::{read, skip}; +use super::Dictionaries; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; +use crate::io::ipc::read::OutOfSpecKind; +use crate::io::ipc::{IpcField, IpcSchema}; + +#[derive(Debug, Eq, PartialEq, Hash)] +enum ProjectionResult { + Selected(A), + NotSelected(A), +} + +/// An iterator adapter that will return `Some(x)` or `None` +/// # Panics +/// The iterator panics iff the `projection` is not strictly increasing. +struct ProjectionIter<'a, A, I: Iterator> { + projection: &'a [usize], + iter: I, + current_count: usize, + current_projection: usize, +} + +impl<'a, A, I: Iterator> ProjectionIter<'a, A, I> { + /// # Panics + /// iff `projection` is empty + pub fn new(projection: &'a [usize], iter: I) -> Self { + Self { + projection: &projection[1..], + iter, + current_count: 0, + current_projection: projection[0], + } + } +} + +impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { + type Item = ProjectionResult; + + fn next(&mut self) -> Option { + if let Some(item) = self.iter.next() { + let result = if self.current_count == self.current_projection { + if !self.projection.is_empty() { + assert!(self.projection[0] > self.current_projection); + self.current_projection = self.projection[0]; + self.projection = &self.projection[1..]; + } else { + self.current_projection = 0 // a value that most likely already passed + }; + Some(ProjectionResult::Selected(item)) + } else { + Some(ProjectionResult::NotSelected(item)) + }; + self.current_count += 1; + result + } else { + None + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +/// Returns a [`Chunk`] from a reader. +/// # Panic +/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) +#[allow(clippy::too_many_arguments)] +pub fn read_record_batch( + batch: arrow_format::ipc::RecordBatchRef, + fields: &[Field], + ipc_schema: &IpcSchema, + projection: Option<&[usize]>, + limit: Option, + dictionaries: &Dictionaries, + version: arrow_format::ipc::MetadataVersion, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> Result>> { + assert_eq!(fields.len(), ipc_schema.fields.len()); + let buffers = batch + .buffers() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageBuffers))?; + let mut buffers: VecDeque = buffers.iter().collect(); + + // check that the sum of the sizes of all buffers is <= than the size of the file + let buffers_size = buffers + .iter() + .map(|buffer| { + let buffer_size: u64 = buffer + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + Ok(buffer_size) + }) + .sum::>()?; + if buffers_size > file_size { + return Err(Error::from(OutOfSpecKind::InvalidBuffersLength { + buffers_size, + file_size, + })); + } + + let field_nodes = batch + .nodes() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageNodes))?; + let mut field_nodes = field_nodes.iter().collect::>(); + + let columns = if let Some(projection) = projection { + let projection = + ProjectionIter::new(projection, fields.iter().zip(ipc_schema.fields.iter())); + + projection + .map(|maybe_field| match maybe_field { + ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read( + &mut field_nodes, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + )?)), + ProjectionResult::NotSelected((field, _)) => { + skip(&mut field_nodes, &field.data_type, &mut buffers)?; + Ok(None) + }, + }) + .filter_map(|x| x.transpose()) + .collect::>>()? + } else { + fields + .iter() + .zip(ipc_schema.fields.iter()) + .map(|(field, ipc_field)| { + read( + &mut field_nodes, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + ) + }) + .collect::>>()? + }; + Chunk::try_new(columns) +} + +fn find_first_dict_field_d<'a>( + id: i64, + data_type: &'a DataType, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + use DataType::*; + match data_type { + Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field), + List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => { + find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0]) + }, + Union(fields, ..) 
| Struct(fields) => { + for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) { + if let Some(f) = find_first_dict_field(id, field, ipc_field) { + return Some(f); + } + } + None + }, + _ => None, + } +} + +fn find_first_dict_field<'a>( + id: i64, + field: &'a Field, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + if let Some(field_id) = ipc_field.dictionary_id { + if id == field_id { + return Some((field, ipc_field)); + } + } + find_first_dict_field_d(id, &field.data_type, ipc_field) +} + +pub(crate) fn first_dict_field<'a>( + id: i64, + fields: &'a [Field], + ipc_fields: &'a [IpcField], +) -> Result<(&'a Field, &'a IpcField)> { + assert_eq!(fields.len(), ipc_fields.len()); + for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) { + if let Some(field) = find_first_dict_field(id, field, ipc_field) { + return Ok(field); + } + } + Err(Error::from(OutOfSpecKind::InvalidId { requested_id: id })) +} + +/// Reads a dictionary from the reader, +/// updating `dictionaries` with the resulting dictionary +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + batch: arrow_format::ipc::DictionaryBatchRef, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> Result<()> { + if batch + .is_delta() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferIsDelta(err)))? + { + return Err(Error::NotYetImplemented( + "delta dictionary batches not supported".to_string(), + )); + } + + let id = batch + .id() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?; + + let batch = batch + .data() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingData))?; + + let value_type = + if let DataType::Dictionary(_, value_type, _) = first_field.data_type.to_logical_type() { + value_type.as_ref() + } else { + return Err(Error::from(OutOfSpecKind::InvalidIdDataType { + requested_id: id, + })); + }; + + // Make a fake schema for the dictionary batch. 
+ let fields = vec![Field::new("", value_type.clone(), false)]; + let ipc_schema = IpcSchema { + fields: vec![first_ipc_field.clone()], + is_little_endian: ipc_schema.is_little_endian, + }; + let chunk = read_record_batch( + batch, + &fields, + &ipc_schema, + None, + None, // we must read the whole dictionary + dictionaries, + arrow_format::ipc::MetadataVersion::V5, + reader, + block_offset, + file_size, + scratch, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn project_iter() { + let iter = 1..6; + let iter = ProjectionIter::new(&[0, 2, 4], iter); + let result: Vec<_> = iter.collect(); + use ProjectionResult::*; + assert_eq!( + result, + vec![ + Selected(1), + NotSelected(2), + Selected(3), + NotSelected(4), + Selected(5) + ] + ) + } +} + +pub fn prepare_projection( + fields: &[Field], + mut projection: Vec, +) -> (Vec, AHashMap, Vec) { + let fields = projection.iter().map(|x| fields[*x].clone()).collect(); + + // todo: find way to do this more efficiently + let mut indices = (0..projection.len()).collect::>(); + indices.sort_unstable_by_key(|&i| &projection[i]); + let map = indices.iter().copied().enumerate().fold( + AHashMap::default(), + |mut acc, (index, new_index)| { + acc.insert(index, new_index); + acc + }, + ); + projection.sort_unstable(); + + // check unique + if !projection.is_empty() { + let mut previous = projection[0]; + + for &i in &projection[1..] { + assert!( + previous < i, + "The projection on IPC must not contain duplicates" + ); + previous = i; + } + } + + (projection, map, fields) +} + +pub fn apply_projection( + chunk: Chunk>, + map: &AHashMap, +) -> Chunk> { + // re-order according to projection + let arrays = chunk.into_arrays(); + let mut new_arrays = arrays.clone(); + + map.iter() + .for_each(|(old, new)| new_arrays[*new] = arrays[*old].clone()); + + Chunk::new(new_arrays) +} diff --git a/crates/nano-arrow/src/io/ipc/read/deserialize.rs b/crates/nano-arrow/src/io/ipc/read/deserialize.rs new file mode 100644 index 000000000000..28f8b9e68191 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/deserialize.rs @@ -0,0 +1,251 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use arrow_format::ipc::{BodyCompressionRef, MetadataVersion}; + +use super::array::*; +use super::{Dictionaries, IpcBuffer, Node}; +use crate::array::*; +use crate::datatypes::{DataType, Field, PhysicalType}; +use crate::error::Result; +use crate::io::ipc::IpcField; + +#[allow(clippy::too_many_arguments)] +pub fn read( + field_nodes: &mut VecDeque, + field: &Field, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: MetadataVersion, + scratch: &mut Vec, +) -> Result> { + use PhysicalType::*; + let data_type = field.data_type.clone(); + + match data_type.to_physical_type() { + Null => read_null(field_nodes, data_type).map(|x| x.boxed()), + Boolean => read_boolean( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + read_primitive::<$T, _>( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()) + }), + Binary => read_binary::( + field_nodes, + data_type, + buffers, + reader, + 
block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeBinary => read_binary::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeBinary => read_fixed_size_binary( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Utf8 => read_utf8::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeUtf8 => read_utf8::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + List => read_list::( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + LargeList => read_list::( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeList => read_fixed_size_list( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Struct => read_struct( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + read_dictionary::<$T, _>( + field_nodes, + data_type, + ipc_field.dictionary_id, + buffers, + reader, + dictionaries, + block_offset, + compression, + limit, + is_little_endian, + scratch, + ) + .map(|x| x.boxed()) + }) + }, + Union => read_union( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Map => read_map( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + } +} + +pub fn skip( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + use PhysicalType::*; + match data_type.to_physical_type() { + Null => skip_null(field_nodes), + Boolean => skip_boolean(field_nodes, buffers), + Primitive(_) => skip_primitive(field_nodes, buffers), + LargeBinary | Binary => skip_binary(field_nodes, buffers), + LargeUtf8 | Utf8 => skip_utf8(field_nodes, buffers), + FixedSizeBinary => skip_fixed_size_binary(field_nodes, buffers), + List => skip_list::(field_nodes, data_type, buffers), + LargeList => skip_list::(field_nodes, data_type, buffers), + FixedSizeList => skip_fixed_size_list(field_nodes, data_type, buffers), + Struct => skip_struct(field_nodes, data_type, buffers), + Dictionary(_) => skip_dictionary(field_nodes, buffers), + Union => skip_union(field_nodes, data_type, buffers), + Map => skip_map(field_nodes, data_type, buffers), + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/error.rs b/crates/nano-arrow/src/io/ipc/read/error.rs new file mode 100644 index 000000000000..cbac69aef2e3 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/error.rs @@ -0,0 +1,112 @@ 
+use crate::error::Error; + +/// The different types of errors that reading from IPC can cause +#[derive(Debug)] +#[non_exhaustive] +pub enum OutOfSpecKind { + /// The IPC file does not start with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidHeader, + /// The IPC file does not end with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidFooter, + /// The first 4 bytes of the last 10 bytes is < 0 + NegativeFooterLength, + /// The footer is an invalid flatbuffer + InvalidFlatbufferFooter(arrow_format::ipc::planus::Error), + /// The file's footer does not contain record batches + MissingRecordBatches, + /// The footer's record batches is an invalid flatbuffer + InvalidFlatbufferRecordBatches(arrow_format::ipc::planus::Error), + /// The file's footer does not contain a schema + MissingSchema, + /// The footer's schema is an invalid flatbuffer + InvalidFlatbufferSchema(arrow_format::ipc::planus::Error), + /// The file's schema does not contain fields + MissingFields, + /// The footer's dictionaries is an invalid flatbuffer + InvalidFlatbufferDictionaries(arrow_format::ipc::planus::Error), + /// The block is an invalid flatbuffer + InvalidFlatbufferBlock(arrow_format::ipc::planus::Error), + /// The dictionary message is an invalid flatbuffer + InvalidFlatbufferMessage(arrow_format::ipc::planus::Error), + /// The message does not contain a header + MissingMessageHeader, + /// The message's header is an invalid flatbuffer + InvalidFlatbufferHeader(arrow_format::ipc::planus::Error), + /// Relative positions in the file is < 0 + UnexpectedNegativeInteger, + /// dictionaries can only contain dictionary messages; record batches can only contain records + UnexpectedMessageType, + /// RecordBatch messages do not contain buffers + MissingMessageBuffers, + /// The message's buffers is an invalid flatbuffer + InvalidFlatbufferBuffers(arrow_format::ipc::planus::Error), + /// RecordBatch messages does not contain nodes + MissingMessageNodes, + /// The message's nodes is an invalid flatbuffer + InvalidFlatbufferNodes(arrow_format::ipc::planus::Error), + /// The message's body length is an invalid flatbuffer + InvalidFlatbufferBodyLength(arrow_format::ipc::planus::Error), + /// The message does not contain data + MissingData, + /// The message's data is an invalid flatbuffer + InvalidFlatbufferData(arrow_format::ipc::planus::Error), + /// The version is an invalid flatbuffer + InvalidFlatbufferVersion(arrow_format::ipc::planus::Error), + /// The compression is an invalid flatbuffer + InvalidFlatbufferCompression(arrow_format::ipc::planus::Error), + /// The record contains a number of buffers that does not match the required number by the data type + ExpectedBuffer, + /// A buffer's size is smaller than the required for the number of elements + InvalidBuffer { + /// Declared number of elements in the buffer + length: usize, + /// The name of the `NativeType` + type_name: &'static str, + /// Bytes required for the `length` and `type` + required_number_of_bytes: usize, + /// The size of the IPC buffer + buffer_length: usize, + }, + /// A buffer's size is larger than the file size + InvalidBuffersLength { + /// number of bytes of all buffers in the record + buffers_size: u64, + /// the size of the file + file_size: u64, + }, + /// A bitmap's size is smaller than the required for the number of elements + InvalidBitmap { + /// Declared length of the bitmap + length: usize, + /// Number of bits on the IPC buffer + number_of_bits: usize, + }, + /// The dictionary is_delta is an invalid flatbuffer + 
InvalidFlatbufferIsDelta(arrow_format::ipc::planus::Error), + /// The dictionary id is an invalid flatbuffer + InvalidFlatbufferId(arrow_format::ipc::planus::Error), + /// Invalid dictionary id + InvalidId { + /// The requested dictionary id + requested_id: i64, + }, + /// Field id is not a dictionary + InvalidIdDataType { + /// The requested dictionary id + requested_id: i64, + }, + /// FixedSizeBinaryArray has invalid datatype. + InvalidDataType, +} + +impl From for Error { + fn from(kind: OutOfSpecKind) -> Self { + Error::OutOfSpec(format!("{kind:?}")) + } +} + +impl From for Error { + fn from(error: arrow_format::ipc::planus::Error) -> Self { + Error::OutOfSpec(error.to_string()) + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/file.rs b/crates/nano-arrow/src/io/ipc/read/file.rs new file mode 100644 index 000000000000..ec0084a08614 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/file.rs @@ -0,0 +1,321 @@ +use std::convert::TryInto; +use std::io::{Read, Seek, SeekFrom}; + +use ahash::AHashMap; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; +use super::common::*; +use super::schema::fb_to_schema; +use super::{Dictionaries, OutOfSpecKind}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; +use crate::io::ipc::IpcSchema; + +/// Metadata of an Arrow IPC file, written in the footer of the file. +#[derive(Debug, Clone)] +pub struct FileMetadata { + /// The schema that is read from the file footer + pub schema: Schema, + + /// The files' [`IpcSchema`] + pub ipc_schema: IpcSchema, + + /// The blocks in the file + /// + /// A block indicates the regions in the file to read to get data + pub blocks: Vec, + + /// Dictionaries associated to each dict_id + pub(crate) dictionaries: Option>, + + /// The total size of the file in bytes + pub size: u64, +} + +fn read_dictionary_message( + reader: &mut R, + offset: u64, + data: &mut Vec, +) -> Result<()> { + let mut message_size: [u8; 4] = [0; 4]; + reader.seek(SeekFrom::Start(offset))?; + reader.read_exact(&mut message_size)?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size)?; + }; + let message_length = i32::from_le_bytes(message_size); + + let message_length: usize = message_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + data.clear(); + data.try_reserve(message_length)?; + reader + .by_ref() + .take(message_length as u64) + .read_to_end(data)?; + + Ok(()) +} + +pub(crate) fn get_dictionary_batch<'a>( + message: &'a arrow_format::ipc::MessageRef, +) -> Result> { + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => Ok(batch), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +fn read_dictionary_block( + reader: &mut R, + metadata: &FileMetadata, + block: &arrow_format::ipc::Block, + dictionaries: &mut Dictionaries, + message_scratch: &mut Vec, + dictionary_scratch: &mut Vec, +) -> Result<()> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + read_dictionary_message(reader, offset, message_scratch)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_dictionary_batch(&message)?; + + read_dictionary( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + dictionaries, + reader, + offset + length, + metadata.size, + dictionary_scratch, + ) +} + +/// Reads all file's dictionaries, if any +/// This function is IO-bounded +pub fn read_file_dictionaries( + reader: &mut R, + metadata: &FileMetadata, + scratch: &mut Vec, +) -> Result { + let mut dictionaries = Default::default(); + + let blocks = if let Some(blocks) = &metadata.dictionaries { + blocks + } else { + return Ok(AHashMap::new()); + }; + // use a temporary smaller scratch for the messages + let mut message_scratch = Default::default(); + + for block in blocks { + read_dictionary_block( + reader, + metadata, + block, + &mut dictionaries, + &mut message_scratch, + scratch, + )?; + } + Ok(dictionaries) +} + +/// Reads the footer's length and magic number in footer +fn read_footer_len(reader: &mut R) -> Result<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10))? + 10; + + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer)?; + let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); + + if footer[4..] != ARROW_MAGIC_V2 { + return Err(Error::from(OutOfSpecKind::InvalidFooter)); + } + let footer_len = footer_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok((end, footer_len)) +} + +pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> Result { + let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let blocks = footer + .record_batches() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingRecordBatches))?; + + let blocks = blocks + .iter() + .map(|block| { + block + .try_into() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + }) + .collect::>>()?; + + let ipc_schema = footer + .schema() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferSchema(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingSchema))?; + let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; + + let dictionaries = footer + .dictionaries() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferDictionaries(err)))? 
+ .map(|dictionaries| { + dictionaries + .into_iter() + .map(|block| { + block.try_into().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) + }) + .collect::>>() + }) + .transpose()?; + + Ok(FileMetadata { + schema, + ipc_schema, + blocks, + dictionaries, + size, + }) +} + +/// Read the Arrow IPC file's metadata +pub fn read_file_metadata(reader: &mut R) -> Result { + // check if header contain the correct magic bytes + let mut magic_buffer: [u8; 6] = [0; 6]; + let start = reader.stream_position()?; + reader.read_exact(&mut magic_buffer)?; + if magic_buffer != ARROW_MAGIC_V2 { + if magic_buffer[..4] == ARROW_MAGIC_V1 { + return Err(Error::NotYetImplemented("feather v1 not supported".into())); + } + return Err(Error::from(OutOfSpecKind::InvalidHeader)); + } + + let (end, footer_len) = read_footer_len(reader)?; + + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64))?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + reader + .by_ref() + .take(footer_len as u64) + .read_to_end(&mut serialized_footer)?; + + deserialize_footer(&serialized_footer, end - start) +} + +pub(crate) fn get_record_batch( + message: arrow_format::ipc::MessageRef, +) -> Result { + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => Ok(batch), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// Reads the record batch at position `index` from the reader. +/// +/// This function is useful for random access to the file. For example, if +/// you have indexed the file somewhere else, this allows pruning +/// certain parts of the file. 
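For orientation, a minimal sketch of the random-access pattern this enables, assuming `metadata` was produced by `read_file_metadata` and that the crate items used below are in scope; the helper name is hypothetical.

use std::io::{Read, Seek};

// Decode only block `index`, assuming dictionaries have not been read yet.
fn read_single_block<R: Read + Seek>(
    reader: &mut R,
    metadata: &FileMetadata,
    index: usize,
) -> Result<Chunk<Box<dyn Array>>> {
    let mut data_scratch = vec![];
    let mut message_scratch = vec![];
    // Dictionary batches must be materialized before any record batch can be decoded.
    let dictionaries = read_file_dictionaries(reader, metadata, &mut data_scratch)?;
    read_batch(
        reader,
        &dictionaries,
        metadata,
        None, // no projection
        None, // no row limit
        index,
        &mut message_scratch,
        &mut data_scratch,
    )
}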
+/// # Panics +/// This function panics iff `index >= metadata.blocks.len()` +#[allow(clippy::too_many_arguments)] +pub fn read_batch( + reader: &mut R, + dictionaries: &Dictionaries, + metadata: &FileMetadata, + projection: Option<&[usize]>, + limit: Option, + index: usize, + message_scratch: &mut Vec, + data_scratch: &mut Vec, +) -> Result>> { + let block = metadata.blocks[index]; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + // read length + reader.seek(SeekFrom::Start(offset))?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next + reader.read_exact(&mut meta_buf)?; + } + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + message_scratch.clear(); + message_scratch.try_reserve(meta_len)?; + reader + .by_ref() + .take(meta_len as u64) + .read_to_end(message_scratch)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_record_batch(message)?; + + read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection, + limit, + dictionaries, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, + reader, + offset + length, + metadata.size, + data_scratch, + ) +} diff --git a/crates/nano-arrow/src/io/ipc/read/file_async.rs b/crates/nano-arrow/src/io/ipc/read/file_async.rs new file mode 100644 index 000000000000..df1895021282 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/file_async.rs @@ -0,0 +1,349 @@ +//! Async reader for Arrow IPC files +use std::io::SeekFrom; + +use ahash::AHashMap; +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, MessageHeaderRef}; +use futures::stream::BoxStream; +use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt}; + +use super::common::{apply_projection, prepare_projection, read_dictionary, read_record_batch}; +use super::file::{deserialize_footer, get_record_batch}; +use super::{Dictionaries, FileMetadata, OutOfSpecKind}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::{Field, Schema}; +use crate::error::{Error, Result}; +use crate::io::ipc::{IpcSchema, ARROW_MAGIC_V2, CONTINUATION_MARKER}; + +/// Async reader for Arrow IPC files +pub struct FileStream<'a> { + stream: BoxStream<'a, Result>>>, + schema: Option, + metadata: FileMetadata, +} + +impl<'a> FileStream<'a> { + /// Create a new IPC file reader. + /// + /// # Examples + /// See [`FileSink`](crate::io::ipc::write::file_async::FileSink). 
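A minimal sketch of driving this reader end to end, assuming `futures::StreamExt` for `next()`, a `'static` async source, and the crate's `Result`; the function name and the decision to collect everything into memory are illustrative only.

use futures::{AsyncRead, AsyncSeek, StreamExt};

// Collect every chunk of an IPC file from an async, seekable reader.
async fn collect_file<R>(mut reader: R) -> Result<Vec<Chunk<Box<dyn Array>>>>
where
    R: AsyncRead + AsyncSeek + Unpin + Send + 'static,
{
    // The footer metadata is read first, exactly as in the blocking reader.
    let metadata = read_file_metadata_async(&mut reader).await?;
    let mut stream = FileStream::new(reader, metadata, None, None);
    let mut chunks = vec![];
    while let Some(chunk) = stream.next().await {
        chunks.push(chunk?);
    }
    Ok(chunks)
}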
+ pub fn new( + reader: R, + metadata: FileMetadata, + projection: Option>, + limit: Option, + ) -> Self + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + let (projection, schema) = if let Some(projection) = projection { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (Some((p, h)), Some(schema)) + } else { + (None, None) + }; + + let stream = Self::stream(reader, None, metadata.clone(), projection, limit); + Self { + stream, + metadata, + schema, + } + } + + /// Get the metadata from the IPC file. + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Get the projected schema from the IPC file. + pub fn schema(&self) -> &Schema { + self.schema.as_ref().unwrap_or(&self.metadata.schema) + } + + fn stream( + mut reader: R, + mut dictionaries: Option, + metadata: FileMetadata, + projection: Option<(Vec, AHashMap)>, + limit: Option, + ) -> BoxStream<'a, Result>>> + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + async_stream::try_stream! { + // read dictionaries + cached_read_dictionaries(&mut reader, &metadata, &mut dictionaries).await?; + + let mut meta_buffer = Default::default(); + let mut block_buffer = Default::default(); + let mut scratch = Default::default(); + let mut remaining = limit.unwrap_or(usize::MAX); + for block in 0..metadata.blocks.len() { + let chunk = read_batch( + &mut reader, + dictionaries.as_mut().unwrap(), + &metadata, + projection.as_ref().map(|x| x.0.as_ref()), + Some(remaining), + block, + &mut meta_buffer, + &mut block_buffer, + &mut scratch + ).await?; + remaining -= chunk.len(); + + let chunk = if let Some((_, map)) = &projection { + // re-order according to projection + apply_projection(chunk, map) + } else { + chunk + }; + + yield chunk; + } + } + .boxed() + } +} + +impl<'a> Stream for FileStream<'a> { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().stream.poll_next_unpin(cx) + } +} + +/// Reads the footer's length and magic number in footer +async fn read_footer_len(reader: &mut R) -> Result { + // read footer length and magic number in footer + reader.seek(SeekFrom::End(-10)).await?; + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer).await?; + let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); + + if footer[4..] != ARROW_MAGIC_V2 { + return Err(Error::from(OutOfSpecKind::InvalidFooter)); + } + footer_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength)) +} + +/// Read the metadata from an IPC file. 
+pub async fn read_file_metadata_async(reader: &mut R) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let footer_size = read_footer_len(reader).await?; + // Read footer + reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?; + + let mut footer = vec![]; + footer.try_reserve(footer_size)?; + reader + .take(footer_size as u64) + .read_to_end(&mut footer) + .await?; + + deserialize_footer(&footer, u64::MAX) +} + +#[allow(clippy::too_many_arguments)] +async fn read_batch( + mut reader: R, + dictionaries: &mut Dictionaries, + metadata: &FileMetadata, + projection: Option<&[usize]>, + limit: Option, + block: usize, + meta_buffer: &mut Vec, + block_buffer: &mut Vec, + scratch: &mut Vec, +) -> Result>> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let block = metadata.blocks[block]; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(offset)).await?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf).await?; + if meta_buf == CONTINUATION_MARKER { + reader.read_exact(&mut meta_buf).await?; + } + + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + meta_buffer.clear(); + meta_buffer.try_reserve(meta_len)?; + (&mut reader) + .take(meta_len as u64) + .read_to_end(meta_buffer) + .await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(meta_buffer) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_record_batch(message)?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + block_buffer.clear(); + block_buffer.try_reserve(block_length)?; + reader + .take(block_length as u64) + .read_to_end(block_buffer) + .await?; + + let mut cursor = std::io::Cursor::new(&block_buffer); + + read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection, + limit, + dictionaries, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, + &mut cursor, + 0, + metadata.size, + scratch, + ) +} + +async fn read_dictionaries( + mut reader: R, + fields: &[Field], + ipc_schema: &IpcSchema, + blocks: &[Block], + scratch: &mut Vec, +) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut dictionaries = Default::default(); + let mut data: Vec = vec![]; + let mut buffer: Vec = vec![]; + + for block in blocks { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: usize = block + .body_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + read_dictionary_message(&mut reader, offset, &mut data).await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(data.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + match header { + MessageHeaderRef::DictionaryBatch(batch) => { + buffer.clear(); + buffer.try_reserve(length)?; + (&mut reader) + .take(length as u64) + .read_to_end(&mut buffer) + .await?; + let mut cursor = std::io::Cursor::new(&buffer); + read_dictionary( + batch, + fields, + ipc_schema, + &mut dictionaries, + &mut cursor, + 0, + u64::MAX, + scratch, + )?; + }, + _ => return Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } + } + Ok(dictionaries) +} + +async fn read_dictionary_message(mut reader: R, offset: u64, data: &mut Vec) -> Result<()> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut message_size = [0; 4]; + reader.seek(SeekFrom::Start(offset)).await?; + reader.read_exact(&mut message_size).await?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + } + let footer_size = i32::from_le_bytes(message_size); + + let footer_size: usize = footer_size + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + data.clear(); + data.try_reserve(footer_size)?; + (&mut reader) + .take(footer_size as u64) + .read_to_end(data) + .await?; + + Ok(()) +} + +async fn cached_read_dictionaries( + reader: &mut R, + metadata: &FileMetadata, + dictionaries: &mut Option, +) -> Result<()> { + match (&dictionaries, metadata.dictionaries.as_deref()) { + (None, Some(blocks)) => { + let new_dictionaries = read_dictionaries( + reader, + &metadata.schema.fields, + &metadata.ipc_schema, + blocks, + &mut Default::default(), + ) + .await?; + *dictionaries = Some(new_dictionaries); + }, + (None, None) => { + *dictionaries = Some(Default::default()); + }, + _ => {}, + }; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/mod.rs b/crates/nano-arrow/src/io/ipc/read/mod.rs new file mode 100644 index 000000000000..887cf7b36258 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/mod.rs @@ -0,0 +1,45 @@ +//! APIs to read Arrow's IPC format. +//! +//! The two important structs here are the [`FileReader`](reader::FileReader), +//! which provides arbitrary access to any of its messages, and the +//! [`StreamReader`](stream::StreamReader), which only supports reading +//! data in the order it was written in. 
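To make the two entry points concrete, here is a minimal sketch of the blocking file path, assuming the re-exports defined below (`read_file_metadata`, `FileReader`) and the crate's `Result`; the streaming path follows the same shape with `read_stream_metadata` and `StreamReader`. The helper name is hypothetical.

use std::fs::File;
use std::io::BufReader;

// Read every record batch of an Arrow IPC file into memory.
fn read_ipc_file(file: File) -> Result<Vec<Chunk<Box<dyn Array>>>> {
    let mut reader = BufReader::new(file);
    // Parse the footer: schema, IPC schema and the byte ranges of every block.
    let metadata = read_file_metadata(&mut reader)?;
    // `None, None` disables projection and the row limit.
    FileReader::new(reader, metadata, None, None).collect()
}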
+use ahash::AHashMap; + +use crate::array::Array; + +mod array; +mod common; +mod deserialize; +mod error; +pub(crate) mod file; +mod read_basic; +mod reader; +mod schema; +mod stream; + +pub use error::OutOfSpecKind; + +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod stream_async; + +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod file_async; + +pub(crate) use common::first_dict_field; +#[cfg(feature = "io_flight")] +pub(crate) use common::{read_dictionary, read_record_batch}; +pub use file::{read_batch, read_file_dictionaries, read_file_metadata, FileMetadata}; +pub use reader::FileReader; +pub use schema::deserialize_schema; +pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState}; + +/// how dictionaries are tracked in this crate +pub type Dictionaries = AHashMap>; + +pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; +pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; +pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; +pub(crate) type Version = arrow_format::ipc::MetadataVersion; diff --git a/crates/nano-arrow/src/io/ipc/read/read_basic.rs b/crates/nano-arrow/src/io/ipc/read/read_basic.rs new file mode 100644 index 000000000000..a56ebc81b3c4 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/read_basic.rs @@ -0,0 +1,291 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek, SeekFrom}; + +use super::super::compression; +use super::super::endianness::is_native_little_endian; +use super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::error::{Error, Result}; +use crate::types::NativeType; + +fn read_swapped( + reader: &mut R, + length: usize, + buffer: &mut Vec, + is_little_endian: bool, +) -> Result<()> { + // slow case where we must reverse bits + let mut slice = vec![0u8; length * std::mem::size_of::()]; + reader.read_exact(&mut slice)?; + + let chunks = slice.chunks_exact(std::mem::size_of::()); + if !is_little_endian { + // machine is little endian, file is big endian + buffer + .as_mut_slice() + .iter_mut() + .zip(chunks) + .try_for_each(|(slot, chunk)| { + let a: T::Bytes = match chunk.try_into() { + Ok(a) => a, + Err(_) => unreachable!(), + }; + *slot = T::from_be_bytes(a); + Result::Ok(()) + })?; + } else { + // machine is big endian, file is little endian + return Err(Error::NotYetImplemented( + "Reading little endian files from big endian machines".to_string(), + )); + } + Ok(()) +} + +fn read_uncompressed_buffer( + reader: &mut R, + buffer_length: usize, + length: usize, + is_little_endian: bool, +) -> Result> { + let required_number_of_bytes = length.saturating_mul(std::mem::size_of::()); + if required_number_of_bytes > buffer_length { + return Err(Error::from(OutOfSpecKind::InvalidBuffer { + length, + type_name: std::any::type_name::(), + required_number_of_bytes, + buffer_length, + })); + // todo: move this to the error's Display + /* + return Err(Error::OutOfSpec( + format!("The slots of the array times the physical size must \ + be smaller or equal to the length of the IPC buffer. 
\ + However, this array reports {} slots, which, for physical type \"{}\", corresponds to {} bytes, \ + which is larger than the buffer length {}", + length, + std::any::type_name::(), + bytes, + buffer_length, + ), + )); + */ + } + + // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + if is_native_little_endian() == is_little_endian { + // fast case where we can just copy the contents + let slice = bytemuck::cast_slice_mut(&mut buffer); + reader.read_exact(slice)?; + } else { + read_swapped(reader, length, &mut buffer, is_little_endian)?; + } + Ok(buffer) +} + +fn read_compressed_buffer( + reader: &mut R, + buffer_length: usize, + length: usize, + is_little_endian: bool, + compression: Compression, + scratch: &mut Vec, +) -> Result> { + if is_little_endian != is_native_little_endian() { + return Err(Error::NotYetImplemented( + "Reading compressed and big endian IPC".to_string(), + )); + } + + // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + // decompress first + scratch.clear(); + scratch.try_reserve(buffer_length)?; + reader + .by_ref() + .take(buffer_length as u64) + .read_to_end(scratch)?; + + let out_slice = bytemuck::cast_slice_mut(&mut buffer); + + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], out_slice)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], out_slice)?; + }, + } + Ok(buffer) +} + +pub fn read_buffer( + buf: &mut VecDeque, + length: usize, // in slots + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + scratch: &mut Vec, +) -> Result> { + let buf = buf + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let buffer_length: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + if let Some(compression) = compression { + Ok(read_compressed_buffer( + reader, + buffer_length, + length, + is_little_endian, + compression, + scratch, + )? 
+ .into()) + } else { + Ok(read_uncompressed_buffer(reader, buffer_length, length, is_little_endian)?.into()) + } +} + +fn read_uncompressed_bitmap( + length: usize, + bytes: usize, + reader: &mut R, +) -> Result> { + if length > bytes * 8 { + return Err(Error::from(OutOfSpecKind::InvalidBitmap { + length, + number_of_bits: bytes * 8, + })); + } + + let mut buffer = vec![]; + buffer.try_reserve(bytes)?; + reader + .by_ref() + .take(bytes as u64) + .read_to_end(&mut buffer)?; + + Ok(buffer) +} + +fn read_compressed_bitmap( + length: usize, + bytes: usize, + compression: Compression, + reader: &mut R, + scratch: &mut Vec, +) -> Result> { + let mut buffer = vec![0; (length + 7) / 8]; + + scratch.clear(); + scratch.try_reserve(bytes)?; + reader.by_ref().take(bytes as u64).read_to_end(scratch)?; + + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], &mut buffer)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], &mut buffer)?; + }, + } + Ok(buffer) +} + +pub fn read_bitmap( + buf: &mut VecDeque, + length: usize, + reader: &mut R, + block_offset: u64, + _: bool, + compression: Option, + scratch: &mut Vec, +) -> Result { + let buf = buf + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let bytes: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + let buffer = if let Some(compression) = compression { + read_compressed_bitmap(length, bytes, compression, reader, scratch) + } else { + read_uncompressed_bitmap(length, bytes, reader) + }?; + + Bitmap::try_new(buffer, length) +} + +#[allow(clippy::too_many_arguments)] +pub fn read_validity( + buffers: &mut VecDeque, + field_node: Node, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + Ok(if field_node.null_count() > 0 { + Some(read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + None + }) +} diff --git a/crates/nano-arrow/src/io/ipc/read/reader.rs b/crates/nano-arrow/src/io/ipc/read/reader.rs new file mode 100644 index 000000000000..80c900fd9a76 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/reader.rs @@ -0,0 +1,137 @@ +use std::io::{Read, Seek}; + +use ahash::AHashMap; + +use super::common::*; +use super::{read_batch, read_file_dictionaries, Dictionaries, FileMetadata}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Result; + +/// An iterator of [`Chunk`]s from an Arrow IPC file. 
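One detail of `read_compressed_buffer` and `read_compressed_bitmap` in read_basic.rs above worth spelling out: the decompressors are handed `&scratch[8..]` because, per the Arrow IPC specification, every compressed body buffer is prefixed with its uncompressed length as a little-endian `i64`. A small sketch of that framing, with a hypothetical helper name:

use std::convert::TryInto;

// Split an IPC body buffer into its 8-byte length prefix and the payload.
// Per the spec, a prefix of -1 signals that the payload is stored uncompressed;
// the readers above assume a compressed payload.
fn split_compressed_ipc_buffer(raw: &[u8]) -> (i64, &[u8]) {
    let uncompressed_len = i64::from_le_bytes(raw[..8].try_into().unwrap());
    (uncompressed_len, &raw[8..])
}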
+pub struct FileReader { + reader: R, + metadata: FileMetadata, + // the dictionaries are going to be read + dictionaries: Option, + current_block: usize, + projection: Option<(Vec, AHashMap, Schema)>, + remaining: usize, + data_scratch: Vec, + message_scratch: Vec, +} + +impl FileReader { + /// Creates a new [`FileReader`]. Use `projection` to only take certain columns. + /// # Panic + /// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) + pub fn new( + reader: R, + metadata: FileMetadata, + projection: Option>, + limit: Option, + ) -> Self { + let projection = projection.map(|projection| { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (p, h, schema) + }); + Self { + reader, + metadata, + dictionaries: Default::default(), + projection, + remaining: limit.unwrap_or(usize::MAX), + current_block: 0, + data_scratch: Default::default(), + message_scratch: Default::default(), + } + } + + /// Return the schema of the file + pub fn schema(&self) -> &Schema { + self.projection + .as_ref() + .map(|x| &x.2) + .unwrap_or(&self.metadata.schema) + } + + /// Returns the [`FileMetadata`] + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Consumes this FileReader, returning the underlying reader + pub fn into_inner(self) -> R { + self.reader + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn get_scratches(&mut self) -> (Vec, Vec) { + ( + std::mem::take(&mut self.data_scratch), + std::mem::take(&mut self.message_scratch), + ) + } + + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. 
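An illustrative pattern for `get_scratches` above and `set_scratches` below, assuming two files are read back to back and their metadata was already obtained via `read_file_metadata`; the helper name is hypothetical.

use std::fs::File;

// Reuse the first reader's internal buffers when constructing the second reader.
fn read_two(a: File, metadata_a: FileMetadata, b: File, metadata_b: FileMetadata) -> Result<()> {
    let mut first = FileReader::new(a, metadata_a, None, None);
    for chunk in first.by_ref() {
        let _ = chunk?;
    }
    // Hand the grown scratch buffers over instead of re-allocating them.
    let scratches = first.get_scratches();
    let mut second = FileReader::new(b, metadata_b, None, None);
    second.set_scratches(scratches);
    for chunk in second {
        let _ = chunk?;
    }
    Ok(())
}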
+ pub fn set_scratches(&mut self, scratches: (Vec, Vec)) { + (self.data_scratch, self.message_scratch) = scratches; + } + + fn read_dictionaries(&mut self) -> Result<()> { + if self.dictionaries.is_none() { + self.dictionaries = Some(read_file_dictionaries( + &mut self.reader, + &self.metadata, + &mut self.data_scratch, + )?); + }; + Ok(()) + } +} + +impl Iterator for FileReader { + type Item = Result>>; + + fn next(&mut self) -> Option { + // get current block + if self.current_block == self.metadata.blocks.len() { + return None; + } + + match self.read_dictionaries() { + Ok(_) => {}, + Err(e) => return Some(Err(e)), + }; + + let block = self.current_block; + self.current_block += 1; + + let chunk = read_batch( + &mut self.reader, + self.dictionaries.as_ref().unwrap(), + &self.metadata, + self.projection.as_ref().map(|x| x.0.as_ref()), + Some(self.remaining), + block, + &mut self.message_scratch, + &mut self.data_scratch, + ); + self.remaining -= chunk.as_ref().map(|x| x.len()).unwrap_or_default(); + + let chunk = if let Some((_, map, _)) = &self.projection { + // re-order according to projection + chunk.map(|chunk| apply_projection(chunk, map)) + } else { + chunk + }; + Some(chunk) + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/schema.rs b/crates/nano-arrow/src/io/ipc/read/schema.rs new file mode 100644 index 000000000000..1b6687f30c95 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/schema.rs @@ -0,0 +1,429 @@ +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef}; + +use super::super::{IpcField, IpcSchema}; +use super::{OutOfSpecKind, StreamMetadata}; +use crate::datatypes::{ + get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, + TimeUnit, UnionMode, +}; +use crate::error::{Error, Result}; + +fn try_unzip_vec>>(iter: I) -> Result<(Vec, Vec)> { + let mut a = vec![]; + let mut b = vec![]; + for maybe_item in iter { + let (a_i, b_i) = maybe_item?; + a.push(a_i); + b.push(b_i); + } + + Ok((a, b)) +} + +fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> Result<(Field, IpcField)> { + let metadata = read_metadata(&ipc_field)?; + + let extension = get_extension(&metadata); + + let (data_type, ipc_field_) = get_data_type(ipc_field, extension, true)?; + + let field = Field { + name: ipc_field + .name()? + .ok_or_else(|| Error::oos("Every field in IPC must have a name"))? + .to_string(), + data_type, + is_nullable: ipc_field.nullable()?, + metadata, + }; + + Ok((field, ipc_field_)) +} + +fn read_metadata(field: &arrow_format::ipc::FieldRef) -> Result { + Ok(if let Some(list) = field.custom_metadata()? { + let mut metadata_map = Metadata::new(); + for kv in list { + let kv = kv?; + if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) { + metadata_map.insert(k.to_string(), v.to_string()); + } + } + metadata_map + } else { + Metadata::default() + }) +} + +fn deserialize_integer(int: arrow_format::ipc::IntRef) -> Result { + Ok(match (int.bit_width()?, int.is_signed()?) 
{ + (8, true) => IntegerType::Int8, + (8, false) => IntegerType::UInt8, + (16, true) => IntegerType::Int16, + (16, false) => IntegerType::UInt16, + (32, true) => IntegerType::Int32, + (32, false) => IntegerType::UInt32, + (64, true) => IntegerType::Int64, + (64, false) => IntegerType::UInt64, + _ => return Err(Error::oos("IPC: indexType can only be 8, 16, 32 or 64.")), + }) +} + +fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> Result { + use arrow_format::ipc::TimeUnit::*; + Ok(match time_unit { + Second => TimeUnit::Second, + Millisecond => TimeUnit::Millisecond, + Microsecond => TimeUnit::Microsecond, + Nanosecond => TimeUnit::Nanosecond, + }) +} + +fn deserialize_time(time: TimeRef) -> Result<(DataType, IpcField)> { + let unit = deserialize_timeunit(time.unit()?)?; + + let data_type = match (time.bit_width()?, unit) { + (32, TimeUnit::Second) => DataType::Time32(TimeUnit::Second), + (32, TimeUnit::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + (64, TimeUnit::Microsecond) => DataType::Time64(TimeUnit::Microsecond), + (64, TimeUnit::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + (bits, precision) => { + return Err(Error::nyi(format!( + "Time type with bit width of {bits} and unit of {precision:?}" + ))) + }, + }; + Ok((data_type, IpcField::default())) +} + +fn deserialize_timestamp(timestamp: TimestampRef) -> Result<(DataType, IpcField)> { + let timezone = timestamp.timezone()?.map(|tz| tz.to_string()); + let time_unit = deserialize_timeunit(timestamp.unit()?)?; + Ok(( + DataType::Timestamp(time_unit, timezone), + IpcField::default(), + )) +} + +fn deserialize_union(union_: UnionRef, field: FieldRef) -> Result<(DataType, IpcField)> { + let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse); + let ids = union_.type_ids()?.map(|x| x.iter().collect()); + + let fields = field + .children()? + .ok_or_else(|| Error::oos("IPC: Union must contain children"))?; + if fields.is_empty() { + return Err(Error::oos("IPC: Union must contain at least one child")); + } + + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok((DataType::Union(fields, ids, mode), ipc_field)) +} + +fn deserialize_map(map: MapRef, field: FieldRef) -> Result<(DataType, IpcField)> { + let is_sorted = map.keys_sorted()?; + + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: Map must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: Map must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let data_type = DataType::Map(Box::new(field), is_sorted); + Ok(( + data_type, + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_struct(field: FieldRef) -> Result<(DataType, IpcField)> { + let fields = field + .children()? 
+ .ok_or_else(|| Error::oos("IPC: Struct must contain children"))?; + if fields.is_empty() { + return Err(Error::oos("IPC: Struct must contain at least one child")); + } + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok((DataType::Struct(fields), ipc_field)) +} + +fn deserialize_list(field: FieldRef) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + DataType::List(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_large_list(field: FieldRef) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + DataType::LargeList(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_fixed_size_list( + list: FixedSizeListRef, + field: FieldRef, +) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: FixedSizeList must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: FixedSizeList must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let size = list + .list_size()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok(( + DataType::FixedSizeList(Box::new(field), size), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +/// Get the Arrow data type from the flatbuffer Field table +fn get_data_type( + field: arrow_format::ipc::FieldRef, + extension: Extension, + may_be_dictionary: bool, +) -> Result<(DataType, IpcField)> { + if let Some(dictionary) = field.dictionary()? { + if may_be_dictionary { + let int = dictionary + .index_type()? + .ok_or_else(|| Error::oos("indexType is mandatory in Dictionary."))?; + let index_type = deserialize_integer(int)?; + let (inner, mut ipc_field) = get_data_type(field, extension, false)?; + ipc_field.dictionary_id = Some(dictionary.id()?); + return Ok(( + DataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?), + ipc_field, + )); + } + } + + if let Some(extension) = extension { + let (name, metadata) = extension; + let (data_type, fields) = get_data_type(field, None, false)?; + return Ok(( + DataType::Extension(name, Box::new(data_type), metadata), + fields, + )); + } + + let type_ = field + .type_()? 
+ .ok_or_else(|| Error::oos("IPC: field type is mandatory"))?; + + use arrow_format::ipc::TypeRef::*; + Ok(match type_ { + Null(_) => (DataType::Null, IpcField::default()), + Bool(_) => (DataType::Boolean, IpcField::default()), + Int(int) => { + let data_type = deserialize_integer(int)?.into(); + (data_type, IpcField::default()) + }, + Binary(_) => (DataType::Binary, IpcField::default()), + LargeBinary(_) => (DataType::LargeBinary, IpcField::default()), + Utf8(_) => (DataType::Utf8, IpcField::default()), + LargeUtf8(_) => (DataType::LargeUtf8, IpcField::default()), + FixedSizeBinary(fixed) => ( + DataType::FixedSizeBinary( + fixed + .byte_width()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?, + ), + IpcField::default(), + ), + FloatingPoint(float) => { + let data_type = match float.precision()? { + arrow_format::ipc::Precision::Half => DataType::Float16, + arrow_format::ipc::Precision::Single => DataType::Float32, + arrow_format::ipc::Precision::Double => DataType::Float64, + }; + (data_type, IpcField::default()) + }, + Date(date) => { + let data_type = match date.unit()? { + arrow_format::ipc::DateUnit::Day => DataType::Date32, + arrow_format::ipc::DateUnit::Millisecond => DataType::Date64, + }; + (data_type, IpcField::default()) + }, + Time(time) => deserialize_time(time)?, + Timestamp(timestamp) => deserialize_timestamp(timestamp)?, + Interval(interval) => { + let data_type = match interval.unit()? { + arrow_format::ipc::IntervalUnit::YearMonth => { + DataType::Interval(IntervalUnit::YearMonth) + }, + arrow_format::ipc::IntervalUnit::DayTime => { + DataType::Interval(IntervalUnit::DayTime) + }, + arrow_format::ipc::IntervalUnit::MonthDayNano => { + DataType::Interval(IntervalUnit::MonthDayNano) + }, + }; + (data_type, IpcField::default()) + }, + Duration(duration) => { + let time_unit = deserialize_timeunit(duration.unit()?)?; + (DataType::Duration(time_unit), IpcField::default()) + }, + Decimal(decimal) => { + let bit_width: usize = decimal + .bit_width()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let precision: usize = decimal + .precision()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let scale: usize = decimal + .scale()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_type = match bit_width { + 128 => DataType::Decimal(precision, scale), + 256 => DataType::Decimal256(precision, scale), + _ => return Err(Error::from(OutOfSpecKind::NegativeFooterLength)), + }; + + (data_type, IpcField::default()) + }, + List(_) => deserialize_list(field)?, + LargeList(_) => deserialize_large_list(field)?, + FixedSizeList(list) => deserialize_fixed_size_list(list, field)?, + Struct(_) => deserialize_struct(field)?, + Union(union_) => deserialize_union(union_, field)?, + Map(map) => deserialize_map(map, field)?, + }) +} + +/// Deserialize an flatbuffers-encoded Schema message into [`Schema`] and [`IpcSchema`]. +pub fn deserialize_schema(message: &[u8]) -> Result<(Schema, IpcSchema)> { + let message = arrow_format::ipc::MessageRef::read_as_root(message) + .map_err(|err| Error::oos(format!("Unable deserialize message: {err:?}")))?; + + let schema = match message + .header()? + .ok_or_else(|| Error::oos("Unable to convert header to a schema".to_string()))? 
+ { + arrow_format::ipc::MessageHeaderRef::Schema(schema) => Ok(schema), + _ => Err(Error::nyi("The message is expected to be a Schema message")), + }?; + + fb_to_schema(schema) +} + +/// Deserialize the raw Schema table from IPC format to Schema data type +pub(super) fn fb_to_schema(schema: arrow_format::ipc::SchemaRef) -> Result<(Schema, IpcSchema)> { + let fields = schema + .fields()? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingFields))?; + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + + let is_little_endian = match schema.endianness()? { + arrow_format::ipc::Endianness::Little => true, + arrow_format::ipc::Endianness::Big => false, + }; + + let mut metadata = Metadata::default(); + if let Some(md_fields) = schema.custom_metadata()? { + for kv in md_fields { + let kv = kv?; + let k_str = kv.key()?; + let v_str = kv.value()?; + if let Some(k) = k_str { + if let Some(v) = v_str { + metadata.insert(k.to_string(), v.to_string()); + } + } + } + } + + Ok(( + Schema { fields, metadata }, + IpcSchema { + fields: ipc_fields, + is_little_endian, + }, + )) +} + +pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> Result { + let message = arrow_format::ipc::MessageRef::read_as_root(meta) + .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {err:?}")))?; + let version = message.version()?; + // message header is a Schema, so read it + let header = message + .header()? + .ok_or_else(|| Error::oos("Unable to read the first IPC message"))?; + let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { + schema + } else { + return Err(Error::oos( + "The first IPC message of the stream must be a schema", + )); + }; + let (schema, ipc_schema) = fb_to_schema(schema)?; + + Ok(StreamMetadata { + schema, + version, + ipc_schema, + }) +} diff --git a/crates/nano-arrow/src/io/ipc/read/stream.rs b/crates/nano-arrow/src/io/ipc/read/stream.rs new file mode 100644 index 000000000000..848bf5acb938 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/stream.rs @@ -0,0 +1,318 @@ +use std::io::Read; + +use ahash::AHashMap; +use arrow_format; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::super::CONTINUATION_MARKER; +use super::common::*; +use super::schema::deserialize_stream_metadata; +use super::{Dictionaries, OutOfSpecKind}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; +use crate::io::ipc::IpcSchema; + +/// Metadata of an Arrow IPC stream, written at the start of the stream +#[derive(Debug, Clone)] +pub struct StreamMetadata { + /// The schema that is read from the stream's first message + pub schema: Schema, + + /// The IPC version of the stream + pub version: arrow_format::ipc::MetadataVersion, + + /// The IPC fields tracking dictionaries + pub ipc_schema: IpcSchema, +} + +/// Reads the metadata of the stream +pub fn read_stream_metadata(reader: &mut R) -> Result { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size)?; + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. 
+ if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size)?; + } + i32::from_le_bytes(meta_size) + }; + + let length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut buffer = vec![]; + buffer.try_reserve(length)?; + reader + .by_ref() + .take(length as u64) + .read_to_end(&mut buffer)?; + + deserialize_stream_metadata(&buffer) +} + +/// Encodes the stream's status after each read. +/// +/// A stream is an iterator, and an iterator returns `Option`. The `Item` +/// type in the [`StreamReader`] case is `StreamState`, which means that an Arrow +/// stream may yield one of three values: (1) `None`, which signals that the stream +/// is done; (2) [`StreamState::Some`], which signals that there was +/// data waiting in the stream and we read it; and finally (3) +/// [`Some(StreamState::Waiting)`], which means that the stream is still "live", it +/// just doesn't hold any data right now. +pub enum StreamState { + /// A live stream without data + Waiting, + /// Next item in the stream + Some(Chunk>), +} + +impl StreamState { + /// Return the data inside this wrapper. + /// + /// # Panics + /// + /// If the `StreamState` was `Waiting`. + pub fn unwrap(self) -> Chunk> { + if let StreamState::Some(batch) = self { + batch + } else { + panic!("The batch is not available") + } + } +} + +/// Reads the next item, yielding `None` if the stream is done, +/// and a [`StreamState`] otherwise. +fn read_next( + reader: &mut R, + metadata: &StreamMetadata, + dictionaries: &mut Dictionaries, + message_buffer: &mut Vec, + data_buffer: &mut Vec, + projection: &Option<(Vec, AHashMap, Schema)>, + scratch: &mut Vec, +) -> Result> { + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match reader.read_exact(&mut meta_length) { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting)) + } else { + Err(Error::from(e)) + }; + }, + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + reader.read_exact(&mut meta_length)?; + } + i32::from_le_bytes(meta_length) + }; + + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + message_buffer.clear(); + message_buffer.try_reserve(meta_length)?; + reader + .by_ref() + .take(meta_length as u64) + .read_to_end(message_buffer)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? 
+ .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + + let mut reader = std::io::Cursor::new(data_buffer); + + let chunk = read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection.as_ref().map(|x| x.0.as_ref()), + None, + dictionaries, + metadata.version, + &mut reader, + 0, + file_size, + scratch, + ); + + if let Some((_, map, _)) = projection { + // re-order according to projection + chunk + .map(|chunk| apply_projection(chunk, map)) + .map(|x| Some(StreamState::Some(x))) + } else { + chunk.map(|x| Some(StreamState::Some(x))) + } + }, + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + let mut dict_reader = std::io::Cursor::new(&data_buffer); + + read_dictionary( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + dictionaries, + &mut dict_reader, + 0, + file_size, + scratch, + )?; + + // read the next message until we encounter a RecordBatch message + read_next( + reader, + metadata, + dictionaries, + message_buffer, + data_buffer, + projection, + scratch, + ) + }, + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// Arrow Stream reader. +/// +/// An [`Iterator`] over an Arrow stream that yields a result of [`StreamState`]s. +/// This is the recommended way to read an arrow stream (by iterating over its data). +/// +/// For a more thorough walkthrough consult [this example](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow). +pub struct StreamReader { + reader: R, + metadata: StreamMetadata, + dictionaries: Dictionaries, + finished: bool, + data_buffer: Vec, + message_buffer: Vec, + projection: Option<(Vec, AHashMap, Schema)>, + scratch: Vec, +} + +impl StreamReader { + /// Try to create a new stream reader + /// + /// The first message in the stream is the schema, the reader will fail if it does not + /// encounter a schema. 
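Putting `read_stream_metadata`, `StreamReader` and the `StreamState` protocol described above together, a minimal consumption loop might look like the following sketch; the helper name and the polling interval are arbitrary choices.

use std::io::Read;
use std::time::Duration;

// Drain a live Arrow stream, waiting briefly whenever no message is available yet.
fn consume_stream<R: Read>(mut reader: R) -> Result<()> {
    let metadata = read_stream_metadata(&mut reader)?;
    for state in StreamReader::new(reader, metadata, None) {
        match state? {
            StreamState::Some(chunk) => println!("chunk with {} rows", chunk.len()),
            StreamState::Waiting => std::thread::sleep(Duration::from_millis(100)),
        }
    }
    Ok(())
}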
+ /// To check if the reader is done, use `is_finished(self)` + pub fn new(reader: R, metadata: StreamMetadata, projection: Option>) -> Self { + let projection = projection.map(|projection| { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (p, h, schema) + }); + + Self { + reader, + metadata, + dictionaries: Default::default(), + finished: false, + data_buffer: Default::default(), + message_buffer: Default::default(), + projection, + scratch: Default::default(), + } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } + + /// Return the schema of the file + pub fn schema(&self) -> &Schema { + self.projection + .as_ref() + .map(|x| &x.2) + .unwrap_or(&self.metadata.schema) + } + + /// Check if the stream is finished + pub fn is_finished(&self) -> bool { + self.finished + } + + fn maybe_next(&mut self) -> Result> { + if self.finished { + return Ok(None); + } + let batch = read_next( + &mut self.reader, + &self.metadata, + &mut self.dictionaries, + &mut self.message_buffer, + &mut self.data_buffer, + &self.projection, + &mut self.scratch, + )?; + if batch.is_none() { + self.finished = true; + } + Ok(batch) + } +} + +impl Iterator for StreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.maybe_next().transpose() + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/stream_async.rs b/crates/nano-arrow/src/io/ipc/read/stream_async.rs new file mode 100644 index 000000000000..f87f84a8d317 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/stream_async.rs @@ -0,0 +1,237 @@ +//! APIs to read Arrow streams asynchronously + +use arrow_format::ipc::planus::ReadAsRoot; +use futures::future::BoxFuture; +use futures::{AsyncRead, AsyncReadExt, FutureExt, Stream}; + +use super::super::CONTINUATION_MARKER; +use super::common::{read_dictionary, read_record_batch}; +use super::schema::deserialize_stream_metadata; +use super::{Dictionaries, OutOfSpecKind, StreamMetadata}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::error::{Error, Result}; + +/// A (private) state of stream messages +struct ReadState { + pub reader: R, + pub metadata: StreamMetadata, + pub dictionaries: Dictionaries, + /// The internal buffer to read data inside the messages (records and dictionaries) to + pub data_buffer: Vec, + /// The internal buffer to read messages to + pub message_buffer: Vec, +} + +/// The state of an Arrow stream +enum StreamState { + /// The stream does not contain new chunks (and it has not been closed) + Waiting(ReadState), + /// The stream contain a new chunk + Some((ReadState, Chunk>)), +} + +/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously +pub async fn read_stream_metadata_async( + reader: &mut R, +) -> Result { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size).await?; + let meta_len = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. 
+ if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size).await?; + } + i32::from_le_bytes(meta_size) + }; + + let meta_len: usize = meta_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut meta_buffer = vec![]; + meta_buffer.try_reserve(meta_len)?; + reader + .take(meta_len as u64) + .read_to_end(&mut meta_buffer) + .await?; + + deserialize_stream_metadata(&meta_buffer) +} + +/// Reads the next item, yielding `None` if the stream has been closed, +/// or a [`StreamState`] otherwise. +async fn maybe_next( + mut state: ReadState, +) -> Result>> { + let mut scratch = Default::default(); + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match state.reader.read_exact(&mut meta_length).await { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting(state))) + } else { + Err(Error::from(e)) + }; + }, + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + state.reader.read_exact(&mut meta_length).await?; + } + i32::from_le_bytes(meta_length) + }; + + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + state.message_buffer.clear(); + state.message_buffer.try_reserve(meta_length)?; + (&mut state.reader) + .take(meta_length as u64) + .read_to_end(&mut state.message_buffer) + .await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(state.message_buffer.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? 
+ .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + state.data_buffer.clear(); + state.data_buffer.try_reserve(block_length)?; + (&mut state.reader) + .take(block_length as u64) + .read_to_end(&mut state.data_buffer) + .await?; + + read_record_batch( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + None, + None, + &state.dictionaries, + state.metadata.version, + &mut std::io::Cursor::new(&state.data_buffer), + 0, + state.data_buffer.len() as u64, + &mut scratch, + ) + .map(|chunk| Some(StreamState::Some((state, chunk)))) + }, + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + state.data_buffer.clear(); + state.data_buffer.try_reserve(block_length)?; + (&mut state.reader) + .take(block_length as u64) + .read_to_end(&mut state.data_buffer) + .await?; + + let file_size = state.data_buffer.len() as u64; + + let mut dict_reader = std::io::Cursor::new(&state.data_buffer); + + read_dictionary( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + &mut state.dictionaries, + &mut dict_reader, + 0, + file_size, + &mut scratch, + )?; + + // read the next message until we encounter a Chunk> message + Ok(Some(StreamState::Waiting(state))) + }, + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s. +pub struct AsyncStreamReader<'a, R: AsyncRead + Unpin + Send + 'a> { + metadata: StreamMetadata, + future: Option>>>>, +} + +impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncStreamReader<'a, R> { + /// Creates a new [`AsyncStreamReader`] + pub fn new(reader: R, metadata: StreamMetadata) -> Self { + let state = ReadState { + reader, + metadata: metadata.clone(), + dictionaries: Default::default(), + data_buffer: Default::default(), + message_buffer: Default::default(), + }; + let future = Some(maybe_next(state).boxed()); + Self { metadata, future } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } +} + +impl<'a, R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<'a, R> { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::pin::Pin; + use std::task::Poll; + let me = Pin::into_inner(self); + + match &mut me.future { + Some(fut) => match fut.as_mut().poll(cx) { + Poll::Ready(Ok(None)) => { + me.future = None; + Poll::Ready(None) + }, + Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => { + me.future = Some(Box::pin(maybe_next(state))); + Poll::Ready(Some(Ok(batch))) + }, + Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending, + Poll::Ready(Err(err)) => { + me.future = None; + Poll::Ready(Some(Err(err))) + }, + Poll::Pending => Poll::Pending, + }, + None => Poll::Ready(None), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/common.rs b/crates/nano-arrow/src/io/ipc/write/common.rs new file mode 100644 index 000000000000..4684bd7f658d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common.rs @@ -0,0 +1,448 @@ +use std::borrow::{Borrow, Cow}; + +use arrow_format::ipc::planus::Builder; + +use super::super::IpcField; +use super::{write, write_dictionary}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::endianness::is_native_little_endian; +use 
crate::io::ipc::read::Dictionaries; + +/// Compression codec +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Compression { + /// LZ4 (framed) + LZ4, + /// ZSTD + ZSTD, +} + +/// Options declaring the behaviour of writing to IPC +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct WriteOptions { + /// Whether the buffers should be compressed and which codec to use. + /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`. + pub compression: Option, +} + +fn encode_dictionary( + field: &IpcField, + array: &dyn Array, + options: &WriteOptions, + dictionary_tracker: &mut DictionaryTracker, + encoded_dictionaries: &mut Vec, +) -> Result<()> { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null + | FixedSizeBinary => Ok(()), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let dict_id = field.dictionary_id + .ok_or_else(|| Error::InvalidArgumentError("Dictionaries must have an associated id".to_string()))?; + + let emit = dictionary_tracker.insert(dict_id, array)?; + + let array = array.as_any().downcast_ref::>().unwrap(); + let values = array.values(); + encode_dictionary(field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries + )?; + + if emit { + encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>( + dict_id, + array, + options, + is_native_little_endian(), + )); + }; + Ok(()) + }), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = field.fields.as_slice(); + if array.fields().len() != fields.len() { + return Err(Error::InvalidArgumentError( + "The number of fields in a struct must equal the number of children in IpcField".to_string(), + )); + } + fields + .iter() + .zip(array.values().iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + }, + List => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + LargeList => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + FixedSizeList => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + Union => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .fields(); + let fields = &field.fields[..]; // todo: error instead + if values.len() != fields.len() { + return Err(Error::InvalidArgumentError( + "The number of fields in a union must equal the number of children in IpcField" + .to_string(), + )); + } + fields + .iter() + .zip(values.iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + }, + Map => { + let values = array.as_any().downcast_ref::().unwrap().field(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + 
dictionary_tracker, + encoded_dictionaries, + ) + }, + } +} + +pub fn encode_chunk( + chunk: &Chunk>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, +) -> Result<(Vec, EncodedData)> { + let mut encoded_message = EncodedData::default(); + let encoded_dictionaries = encode_chunk_amortized( + chunk, + fields, + dictionary_tracker, + options, + &mut encoded_message, + )?; + Ok((encoded_dictionaries, encoded_message)) +} + +// Amortizes `EncodedData` allocation. +pub fn encode_chunk_amortized( + chunk: &Chunk>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) -> Result> { + let mut encoded_dictionaries = vec![]; + + for (field, array) in fields.iter().zip(chunk.as_ref()) { + encode_dictionary( + field, + array.as_ref(), + options, + dictionary_tracker, + &mut encoded_dictionaries, + )?; + } + + chunk_to_bytes_amortized(chunk, options, encoded_message); + + Ok(encoded_dictionaries) +} + +fn serialize_compression( + compression: Option, +) -> Option> { + if let Some(compression) = compression { + let codec = match compression { + Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame, + Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd, + }; + Some(Box::new(arrow_format::ipc::BodyCompression { + codec, + method: arrow_format::ipc::BodyCompressionMethod::Buffer, + })) + } else { + None + } +} + +/// Write [`Chunk`] into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the batch's data +fn chunk_to_bytes_amortized( + chunk: &Chunk>, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data); + arrow_data.clear(); + + let mut offset = 0; + for array in chunk.arrays() { + write( + array.as_ref(), + &mut buffers, + &mut arrow_data, + &mut nodes, + &mut offset, + is_native_little_endian(), + options.compression, + ) + } + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new( + arrow_format::ipc::RecordBatch { + length: chunk.len() as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + }, + ))), + body_length: arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + encoded_message.ipc_message = ipc_message.to_vec(); + encoded_message.arrow_data = arrow_data +} + +/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the data +fn dictionary_batch_to_bytes( + dict_id: i64, + array: &DictionaryArray, + options: &WriteOptions, + is_little_endian: bool, +) -> EncodedData { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + + let length = write_dictionary( + array, + &mut buffers, + &mut arrow_data, + &mut nodes, + &mut 0, + is_little_endian, + options.compression, + false, + ); + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new( + arrow_format::ipc::DictionaryBatch { + id: dict_id, + data: 
Some(Box::new(arrow_format::ipc::RecordBatch { + length: length as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + })), + is_delta: false, + }, + ))), + body_length: arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + + EncodedData { + ipc_message: ipc_message.to_vec(), + arrow_data, + } +} + +/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary +/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which +/// isn't allowed in the `FileWriter`. +pub struct DictionaryTracker { + pub dictionaries: Dictionaries, + pub cannot_replace: bool, +} + +impl DictionaryTracker { + /// Keep track of the dictionary with the given ID and values. Behavior: + /// + /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate + /// that the dictionary was not actually inserted (because it's already been seen). + /// * If this ID has been written already but with different data, and this tracker is + /// configured to return an error, return an error. + /// * If the tracker has not been configured to error on replacement or this dictionary + /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just + /// inserted. + pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> Result { + let values = match array.data_type() { + DataType::Dictionary(key_type, _, _) => { + match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + array.values() + }) + }, + _ => unreachable!(), + }; + + // If a dictionary with this id was already emitted, check if it was the same. + if let Some(last) = self.dictionaries.get(&dict_id) { + if last.as_ref() == values.as_ref() { + // Same dictionary values => no need to emit it again + return Ok(false); + } else if self.cannot_replace { + return Err(Error::InvalidArgumentError( + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches." + .to_string(), + )); + } + }; + + self.dictionaries.insert(dict_id, values.clone()); + Ok(true) + } +} + +/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data +#[derive(Debug, Default)] +pub struct EncodedData { + /// An encoded ipc::Schema::Message + pub ipc_message: Vec, + /// Arrow buffers to be written, should be an empty vec for schema messages + pub arrow_data: Vec, +} + +/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes +#[inline] +pub(crate) fn pad_to_64(len: usize) -> usize { + ((len + 63) & !63) - len +} + +/// An array [`Chunk`] with optional accompanying IPC fields. +#[derive(Debug, Clone, PartialEq)] +pub struct Record<'a> { + columns: Cow<'a, Chunk>>, + fields: Option>, +} + +impl<'a> Record<'a> { + /// Get the IPC fields for this record. + pub fn fields(&self) -> Option<&[IpcField]> { + self.fields.as_deref() + } + + /// Get the Arrow columns in this record. 
+ pub fn columns(&self) -> &Chunk> { + self.columns.borrow() + } +} + +impl From>> for Record<'static> { + fn from(columns: Chunk>) -> Self { + Self { + columns: Cow::Owned(columns), + fields: None, + } + } +} + +impl<'a, F> From<(Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (Chunk>, Option)) -> Self { + Self { + columns: Cow::Owned(columns), + fields: fields.map(|f| f.into()), + } + } +} + +impl<'a, F> From<(&'a Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (&'a Chunk>, Option)) -> Self { + Self { + columns: Cow::Borrowed(columns), + fields: fields.map(|f| f.into()), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/common_async.rs b/crates/nano-arrow/src/io/ipc/write/common_async.rs new file mode 100644 index 000000000000..397391cd24ee --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common_async.rs @@ -0,0 +1,66 @@ +use futures::{AsyncWrite, AsyncWriteExt}; + +use super::super::CONTINUATION_MARKER; +use super::common::{pad_to_64, EncodedData}; +use crate::error::Result; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub async fn write_message( + mut writer: W, + encoded: EncodedData, +) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + + let a = 64 - 1; + let buffer = encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; // the message length + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(&mut writer, (aligned_size - prefix_size) as i32).await?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(&buffer).await?; + } + // write padding + writer.write_all(&vec![0; padding_bytes]).await?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data).await? 
+ } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub async fn write_continuation( + mut writer: W, + total_len: i32, +) -> Result { + writer.write_all(&CONTINUATION_MARKER).await?; + writer.write_all(&total_len.to_le_bytes()[..]).await?; + Ok(8) +} + +async fn write_body_buffers( + mut writer: W, + data: &[u8], +) -> Result { + let len = data.len(); + let pad_len = pad_to_64(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data).await?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..]).await?; + } + + Ok(total_len) +} diff --git a/crates/nano-arrow/src/io/ipc/write/common_sync.rs b/crates/nano-arrow/src/io/ipc/write/common_sync.rs new file mode 100644 index 000000000000..b20196419b2c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common_sync.rs @@ -0,0 +1,59 @@ +use std::io::Write; + +use super::super::CONTINUATION_MARKER; +use super::common::{pad_to_64, EncodedData}; +use crate::error::Result; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub fn write_message(writer: &mut W, encoded: &EncodedData) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + + let a = 8 - 1; + let buffer = &encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(writer, (aligned_size - prefix_size) as i32)?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(buffer)?; + } + // write padding + // aligned to a 8 byte boundary, so maximum is [u8;8] + const PADDING_MAX: [u8; 8] = [0u8; 8]; + writer.write_all(&PADDING_MAX[..padding_bytes])?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data)? + } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { + let len = data.len(); + let pad_len = pad_to_64(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data)?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..])?; + } + + Ok(total_len) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub fn write_continuation(writer: &mut W, total_len: i32) -> Result { + writer.write_all(&CONTINUATION_MARKER)?; + writer.write_all(&total_len.to_le_bytes()[..])?; + Ok(8) +} diff --git a/crates/nano-arrow/src/io/ipc/write/file_async.rs b/crates/nano-arrow/src/io/ipc/write/file_async.rs new file mode 100644 index 000000000000..93a1715282e2 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/file_async.rs @@ -0,0 +1,252 @@ +//! Async writer for IPC files. 
+ +use std::task::Poll; + +use arrow_format::ipc::planus::Builder; +use arrow_format::ipc::{Block, Footer, MetadataVersion}; +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; + +use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_async::{write_continuation, write_message}; +use super::schema::serialize_schema; +use super::{default_ipc_fields, schema_to_bytes, Record}; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::{IpcField, ARROW_MAGIC_V2}; + +type WriteOutput = (usize, Option, Vec, Option); + +/// Sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC file. +/// +/// The file header is automatically written before writing the first chunk, and the file footer is +/// automatically written when the sink is closed. +/// +/// # Examples +/// +/// ``` +/// use futures::{SinkExt, TryStreamExt, io::Cursor}; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::ipc::write::file_async::FileSink; +/// use arrow2::io::ipc::read::file_async::{read_file_metadata_async, FileStream}; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = Cursor::new(vec![]); +/// let mut sink = FileSink::new( +/// &mut buffer, +/// schema, +/// None, +/// Default::default(), +/// ); +/// +/// // Write chunks to file +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// drop(sink); +/// +/// // Read chunks from file +/// buffer.set_position(0); +/// let metadata = read_file_metadata_async(&mut buffer).await?; +/// let mut stream = FileStream::new(buffer, metadata, None, None); +/// let chunks = stream.try_collect::>().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + writer: Option, + task: Option>>>, + options: WriteOptions, + dictionary_tracker: DictionaryTracker, + offset: usize, + fields: Vec, + record_blocks: Vec, + dictionary_blocks: Vec, + schema: Schema, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new file writer. 
+ pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let encoded = EncodedData { + ipc_message: schema_to_bytes(&schema, &fields), + arrow_data: vec![], + }; + let task = Some(Self::start(writer, encoded).boxed()); + Self { + writer: None, + task, + options, + fields, + offset: 0, + schema, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: true, + }, + record_blocks: vec![], + dictionary_blocks: vec![], + } + } + + async fn start(mut writer: W, encoded: EncodedData) -> Result> { + writer.write_all(&ARROW_MAGIC_V2[..]).await?; + writer.write_all(&[0, 0]).await?; + let (meta, data) = write_message(&mut writer, encoded).await?; + + Ok((meta + data + 8, None, vec![], Some(writer))) + } + + async fn write( + mut writer: W, + mut offset: usize, + record: EncodedData, + dictionaries: Vec, + ) -> Result> { + let mut dict_blocks = vec![]; + for dict in dictionaries { + let (meta, data) = write_message(&mut writer, dict).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + dict_blocks.push(block); + offset += meta + data; + } + let (meta, data) = write_message(&mut writer, record).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + offset += meta + data; + Ok((offset, Some(block), dict_blocks, Some(writer))) + } + + async fn finish(mut writer: W, footer: Footer) -> Result> { + write_continuation(&mut writer, 0).await?; + let footer = { + let mut builder = Builder::new(); + builder.finish(&footer, None).to_owned() + }; + writer.write_all(&footer[..]).await?; + writer + .write_all(&(footer.len() as i32).to_le_bytes()) + .await?; + writer.write_all(&ARROW_MAGIC_V2).await?; + writer.close().await?; + + Ok((0, None, vec![], None)) + } + + fn poll_write(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok((offset, record, mut dictionaries, writer)) => { + self.task = None; + self.writer = writer; + self.offset = offset; + if let Some(block) = record { + self.record_blocks.push(block); + } + self.dictionary_blocks.append(&mut dictionaries); + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink> for FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Record<'_>) -> Result<()> { + let this = self.get_mut(); + + if let Some(writer) = this.writer.take() { + let fields = item.fields().unwrap_or_else(|| &this.fields[..]); + + let (dictionaries, record) = encode_chunk( + item.columns(), + fields, + &mut this.dictionary_tracker, + &this.options, + )?; + + this.task = Some(Self::write(writer, this.offset, record, dictionaries).boxed()); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer is closed", + ))) + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, 
+ cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_write(cx)) { + Ok(()) => { + if let Some(writer) = this.writer.take() { + let schema = serialize_schema(&this.schema, &this.fields); + let footer = Footer { + version: MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut this.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut this.record_blocks)), + custom_metadata: None, + }; + this.task = Some(Self::finish(writer, footer).boxed()); + this.poll_write(cx) + } else { + Poll::Ready(Ok(())) + } + }, + Err(error) => Poll::Ready(Err(error)), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/mod.rs b/crates/nano-arrow/src/io/ipc/write/mod.rs new file mode 100644 index 000000000000..55672a85da3c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/mod.rs @@ -0,0 +1,70 @@ +//! APIs to write to Arrow's IPC format. +pub(crate) mod common; +mod schema; +mod serialize; +mod stream; +pub(crate) mod writer; + +pub use common::{Compression, Record, WriteOptions}; +pub use schema::schema_to_bytes; +pub use serialize::write; +use serialize::write_dictionary; +pub use stream::StreamWriter; +pub use writer::FileWriter; + +pub(crate) mod common_sync; + +#[cfg(feature = "io_ipc_write_async")] +mod common_async; +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod stream_async; + +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod file_async; + +use super::IpcField; +use crate::datatypes::{DataType, Field}; + +fn default_ipc_field(data_type: &DataType, current_id: &mut i64) -> IpcField { + use crate::datatypes::DataType::*; + match data_type.to_logical_type() { + // single child => recurse + Map(inner, ..) | FixedSizeList(inner, _) | LargeList(inner) | List(inner) => IpcField { + fields: vec![default_ipc_field(inner.data_type(), current_id)], + dictionary_id: None, + }, + // multiple children => recurse + Union(fields, ..) | Struct(fields) => IpcField { + fields: fields + .iter() + .map(|f| default_ipc_field(f.data_type(), current_id)) + .collect(), + dictionary_id: None, + }, + // dictionary => current_id + Dictionary(_, data_type, _) => { + let dictionary_id = Some(*current_id); + *current_id += 1; + IpcField { + fields: vec![default_ipc_field(data_type, current_id)], + dictionary_id, + } + }, + // no children => do nothing + _ => IpcField { + fields: vec![], + dictionary_id: None, + }, + } +} + +/// Assigns every dictionary field a unique ID +pub fn default_ipc_fields(fields: &[Field]) -> Vec { + let mut dictionary_id = 0i64; + fields + .iter() + .map(|field| default_ipc_field(field.data_type().to_logical_type(), &mut dictionary_id)) + .collect() +} diff --git a/crates/nano-arrow/src/io/ipc/write/schema.rs b/crates/nano-arrow/src/io/ipc/write/schema.rs new file mode 100644 index 000000000000..dd6f44bbd33a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/schema.rs @@ -0,0 +1,333 @@ +use arrow_format::ipc::planus::Builder; + +use super::super::IpcField; +use crate::datatypes::{ + DataType, Field, IntegerType, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, +}; +use crate::io::ipc::endianness::is_native_little_endian; + +/// Converts a [Schema] and [IpcField]s to a flatbuffers-encoded [arrow_format::ipc::Message]. 
+pub fn schema_to_bytes(schema: &Schema, ipc_fields: &[IpcField]) -> Vec { + let schema = serialize_schema(schema, ipc_fields); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(schema))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&message, None); + footer_data.to_vec() +} + +pub fn serialize_schema(schema: &Schema, ipc_fields: &[IpcField]) -> arrow_format::ipc::Schema { + let endianness = if is_native_little_endian() { + arrow_format::ipc::Endianness::Little + } else { + arrow_format::ipc::Endianness::Big + }; + + let fields = schema + .fields + .iter() + .zip(ipc_fields.iter()) + .map(|(field, ipc_field)| serialize_field(field, ipc_field)) + .collect::>(); + + let mut custom_metadata = vec![]; + for (key, value) in &schema.metadata { + custom_metadata.push(arrow_format::ipc::KeyValue { + key: Some(key.clone()), + value: Some(value.clone()), + }); + } + let custom_metadata = if custom_metadata.is_empty() { + None + } else { + Some(custom_metadata) + }; + + arrow_format::ipc::Schema { + endianness, + fields: Some(fields), + custom_metadata, + features: None, // todo add this one + } +} + +fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec) { + for (k, v) in metadata { + if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { + let entry = arrow_format::ipc::KeyValue { + key: Some(k.clone()), + value: Some(v.clone()), + }; + kv_vec.push(entry); + } + } +} + +fn write_extension( + name: &str, + metadata: &Option, + kv_vec: &mut Vec, +) { + // metadata + if let Some(metadata) = metadata { + let entry = arrow_format::ipc::KeyValue { + key: Some("ARROW:extension:metadata".to_string()), + value: Some(metadata.clone()), + }; + kv_vec.push(entry); + } + + // name + let entry = arrow_format::ipc::KeyValue { + key: Some("ARROW:extension:name".to_string()), + value: Some(name.to_string()), + }; + kv_vec.push(entry); +} + +/// Create an IPC Field from an Arrow Field +pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_format::ipc::Field { + // custom metadata. 
+ let mut kv_vec = vec![]; + if let DataType::Extension(name, _, metadata) = field.data_type() { + write_extension(name, metadata, &mut kv_vec); + } + + let type_ = serialize_type(field.data_type()); + let children = serialize_children(field.data_type(), ipc_field); + + let dictionary = if let DataType::Dictionary(index_type, inner, is_ordered) = field.data_type() + { + if let DataType::Extension(name, _, metadata) = inner.as_ref() { + write_extension(name, metadata, &mut kv_vec); + } + Some(serialize_dictionary( + index_type, + ipc_field + .dictionary_id + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + )) + } else { + None + }; + + write_metadata(&field.metadata, &mut kv_vec); + + let custom_metadata = if !kv_vec.is_empty() { + Some(kv_vec) + } else { + None + }; + + arrow_format::ipc::Field { + name: Some(field.name.clone()), + nullable: field.is_nullable, + type_: Some(type_), + dictionary: dictionary.map(Box::new), + children: Some(children), + custom_metadata, + } +} + +fn serialize_time_unit(unit: &TimeUnit) -> arrow_format::ipc::TimeUnit { + match unit { + TimeUnit::Second => arrow_format::ipc::TimeUnit::Second, + TimeUnit::Millisecond => arrow_format::ipc::TimeUnit::Millisecond, + TimeUnit::Microsecond => arrow_format::ipc::TimeUnit::Microsecond, + TimeUnit::Nanosecond => arrow_format::ipc::TimeUnit::Nanosecond, + } +} + +fn serialize_type(data_type: &DataType) -> arrow_format::ipc::Type { + use arrow_format::ipc; + use DataType::*; + match data_type { + Null => ipc::Type::Null(Box::new(ipc::Null {})), + Boolean => ipc::Type::Bool(Box::new(ipc::Bool {})), + UInt8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: false, + })), + UInt16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: false, + })), + UInt32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: false, + })), + UInt64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: false, + })), + Int8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: true, + })), + Int16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: true, + })), + Int32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: true, + })), + Int64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: true, + })), + Float16 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Half, + })), + Float32 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Single, + })), + Float64 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Double, + })), + Decimal(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 128, + })), + Decimal256(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 256, + })), + Binary => ipc::Type::Binary(Box::new(ipc::Binary {})), + LargeBinary => ipc::Type::LargeBinary(Box::new(ipc::LargeBinary {})), + Utf8 => ipc::Type::Utf8(Box::new(ipc::Utf8 {})), + LargeUtf8 => ipc::Type::LargeUtf8(Box::new(ipc::LargeUtf8 {})), + FixedSizeBinary(size) => ipc::Type::FixedSizeBinary(Box::new(ipc::FixedSizeBinary { + byte_width: *size as i32, + })), + Date32 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Day, + })), + Date64 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Millisecond, + })), + Duration(unit) => 
ipc::Type::Duration(Box::new(ipc::Duration { + unit: serialize_time_unit(unit), + })), + Time32(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 32, + })), + Time64(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 64, + })), + Timestamp(unit, tz) => ipc::Type::Timestamp(Box::new(ipc::Timestamp { + unit: serialize_time_unit(unit), + timezone: tz.as_ref().cloned(), + })), + Interval(unit) => ipc::Type::Interval(Box::new(ipc::Interval { + unit: match unit { + IntervalUnit::YearMonth => ipc::IntervalUnit::YearMonth, + IntervalUnit::DayTime => ipc::IntervalUnit::DayTime, + IntervalUnit::MonthDayNano => ipc::IntervalUnit::MonthDayNano, + }, + })), + List(_) => ipc::Type::List(Box::new(ipc::List {})), + LargeList(_) => ipc::Type::LargeList(Box::new(ipc::LargeList {})), + FixedSizeList(_, size) => ipc::Type::FixedSizeList(Box::new(ipc::FixedSizeList { + list_size: *size as i32, + })), + Union(_, type_ids, mode) => ipc::Type::Union(Box::new(ipc::Union { + mode: match mode { + UnionMode::Dense => ipc::UnionMode::Dense, + UnionMode::Sparse => ipc::UnionMode::Sparse, + }, + type_ids: type_ids.clone(), + })), + Map(_, keys_sorted) => ipc::Type::Map(Box::new(ipc::Map { + keys_sorted: *keys_sorted, + })), + Struct(_) => ipc::Type::Struct(Box::new(ipc::Struct {})), + Dictionary(_, v, _) => serialize_type(v), + Extension(_, v, _) => serialize_type(v), + } +} + +fn serialize_children(data_type: &DataType, ipc_field: &IpcField) -> Vec { + use DataType::*; + match data_type { + Null + | Boolean + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Timestamp(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Duration(_) + | Interval(_) + | Binary + | FixedSizeBinary(_) + | LargeBinary + | Utf8 + | LargeUtf8 + | Decimal(_, _) + | Decimal256(_, _) => vec![], + FixedSizeList(inner, _) | LargeList(inner) | List(inner) | Map(inner, _) => { + vec![serialize_field(inner, &ipc_field.fields[0])] + }, + Union(fields, _, _) | Struct(fields) => fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc)| serialize_field(field, ipc)) + .collect(), + Dictionary(_, inner, _) => serialize_children(inner, ipc_field), + Extension(_, inner, _) => serialize_children(inner, ipc_field), + } +} + +/// Create an IPC dictionary encoding +pub(crate) fn serialize_dictionary( + index_type: &IntegerType, + dict_id: i64, + dict_is_ordered: bool, +) -> arrow_format::ipc::DictionaryEncoding { + use IntegerType::*; + let is_signed = match index_type { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => false, + }; + + let bit_width = match index_type { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 => 32, + Int64 | UInt64 => 64, + }; + + let index_type = arrow_format::ipc::Int { + bit_width, + is_signed, + }; + + arrow_format::ipc::DictionaryEncoding { + id: dict_id, + index_type: Some(Box::new(index_type)), + is_ordered: dict_is_ordered, + dictionary_kind: arrow_format::ipc::DictionaryKind::DenseArray, + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/serialize.rs b/crates/nano-arrow/src/io/ipc/write/serialize.rs new file mode 100644 index 000000000000..f5bad22d6fe4 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/serialize.rs @@ -0,0 +1,763 @@ +#![allow(clippy::ptr_arg)] // false positive in clippy, see https://github.com/rust-lang/rust-clippy/issues/8463 +use arrow_format::ipc; + +use super::super::compression; 
+use super::super::endianness::is_native_little_endian; +use super::common::{pad_to_64, Compression}; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +use crate::offset::{Offset, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +fn write_primitive( + array: &PrimitiveArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + + write_buffer( + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ) +} + +fn write_boolean( + array: &BooleanArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bitmap( + Some(&array.values().clone()), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); +} + +#[allow(clippy::too_many_arguments)] +fn write_generic_binary( + validity: Option<&Bitmap>, + offsets: &OffsetsBuffer, + values: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = offsets.buffer(); + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::default() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write_bytes( + &values[first.to_usize()..last.to_usize()], + buffers, + arrow_data, + offset, + compression, + ); +} + +fn write_binary( + array: &BinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} + +fn write_utf8( + array: &Utf8Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} + +fn write_fixed_size_binary( + array: &FixedSizeBinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bytes(array.values(), buffers, arrow_data, offset, compression); +} + +fn write_list( + array: &ListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::zero() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + 
write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .values() + .sliced(first.to_usize(), last.to_usize() - first.to_usize()) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +pub fn write_struct( + array: &StructArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + array.values().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }); +} + +pub fn write_union( + array: &UnionArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_buffer( + array.types(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + + if let Some(offsets) = array.offsets() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + array.fields().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ) + }); +} + +fn write_map( + array: &MapArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == 0 { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .field() + .sliced(first as usize, last as usize - first as usize) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +fn write_fixed_size_list( + array: &FixedSizeListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +// use `write_keys` to either write keys or values +#[allow(clippy::too_many_arguments)] +pub(super) fn write_dictionary( + array: &DictionaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, + write_keys: bool, +) -> usize { + if write_keys { + write_primitive( + array.keys(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + array.keys().len() + } else { + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + array.values().len() + } +} + +/// Writes an [`Array`] to `arrow_data` +pub fn write( + array: &dyn Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: 
Option, +) { + nodes.push(ipc::FieldNode { + length: array.len() as i64, + null_count: array.null_count() as i64, + }); + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => (), + Boolean => write_boolean( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + write_primitive::<$T>(array, buffers, arrow_data, offset, is_little_endian, compression) + }), + Binary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeBinary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + FixedSizeBinary => write_fixed_size_binary( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Utf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeUtf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + List => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + LargeList => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + FixedSizeList => write_fixed_size_list( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Struct => write_struct( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + write_dictionary::<$T>( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + true, + ); + }), + Union => { + write_union( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + Map => { + write_map( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + } +} + +#[inline] +fn pad_buffer_to_64(buffer: &mut Vec, length: usize) { + let pad_len = pad_to_64(length); + buffer.extend_from_slice(&vec![0u8; pad_len]); +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. 
+fn write_bytes( + bytes: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + arrow_data.extend_from_slice(bytes); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn write_bitmap( + bitmap: Option<&Bitmap>, + length: usize, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + match bitmap { + Some(bitmap) => { + assert_eq!(bitmap.len(), length); + let (slice, slice_offset, _) = bitmap.as_slice(); + if slice_offset != 0 { + // case where we can't slice the bitmap as the offsets are not multiple of 8 + let bytes = Bitmap::from_trusted_len_iter(bitmap.iter()); + let (slice, _, _) = bytes.as_slice(); + write_bytes(slice, buffers, arrow_data, offset, compression) + } else { + write_bytes(slice, buffers, arrow_data, offset, compression) + } + }, + None => { + buffers.push(ipc::Buffer { + offset: *offset, + length: 0, + }); + }, + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +fn write_buffer( + buffer: &[T], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + _write_compressed_buffer(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer(buffer, arrow_data, is_little_endian); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +#[inline] +fn _write_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, +) { + let len = buffer.size_hint().0; + arrow_data.reserve(len * std::mem::size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } +} + +#[inline] +fn _write_compressed_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + let len = buffer.size_hint().0; + let mut swapped = Vec::with_capacity(len * std::mem::size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())); + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())) + }; + arrow_data.extend_from_slice(&(swapped.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(&swapped, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(&swapped, arrow_data).unwrap(); + }, + } +} + +fn _write_buffer(buffer: &[T], arrow_data: &mut Vec, is_little_endian: bool) { + if is_little_endian == is_native_little_endian() { + // in native endianness we can use the bytes directly. 
+ let buffer = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(buffer); + } else { + _write_buffer_from_iter(buffer.iter().copied(), arrow_data, is_little_endian) + } +} + +fn _write_compressed_buffer( + buffer: &[T], + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + if is_little_endian == is_native_little_endian() { + let bytes = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + todo!() + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +#[inline] +fn write_buffer_from_iter>( + buffer: I, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + + if let Some(compression) = compression { + _write_compressed_buffer_from_iter(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer_from_iter(buffer, arrow_data, is_little_endian); + } + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn finish_buffer(arrow_data: &mut Vec, start: usize, offset: &mut i64) -> ipc::Buffer { + let buffer_len = (arrow_data.len() - start) as i64; + + pad_buffer_to_64(arrow_data, arrow_data.len() - start); + let total_len = (arrow_data.len() - start) as i64; + + let buffer = ipc::Buffer { + offset: *offset, + length: buffer_len, + }; + *offset += total_len; + buffer +} diff --git a/crates/nano-arrow/src/io/ipc/write/stream.rs b/crates/nano-arrow/src/io/ipc/write/stream.rs new file mode 100644 index 000000000000..3fe7e143e02d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/stream.rs @@ -0,0 +1,113 @@ +//! Arrow IPC File and Stream Writers +//! +//! The `FileWriter` and `StreamWriter` have similar interfaces, +//! however the `FileWriter` expects a reader that supports `Seek`ing + +use std::io::Write; + +use super::super::IpcField; +use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema_to_bytes}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// Arrow stream writer +/// +/// The data written by this writer must be read in order. To signal that no more +/// data is arriving through the stream call [`self.finish()`](StreamWriter::finish); +/// +/// For a usage walkthrough consult [this example](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow). +pub struct StreamWriter { + /// The object to write to + writer: W, + /// IPC write options + write_options: WriteOptions, + /// Whether the stream has been finished + finished: bool, + /// Keeps track of dictionaries that have been written + dictionary_tracker: DictionaryTracker, + + ipc_fields: Option>, +} + +impl StreamWriter { + /// Creates a new [`StreamWriter`] + pub fn new(writer: W, write_options: WriteOptions) -> Self { + Self { + writer, + write_options, + finished: false, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }, + ipc_fields: None, + } + } + + /// Starts the stream by writing a Schema message to it. 
+ /// Use `ipc_fields` to declare dictionary ids in the schema, for dictionary-reuse + pub fn start(&mut self, schema: &Schema, ipc_fields: Option>) -> Result<()> { + self.ipc_fields = Some(if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(&schema.fields) + }); + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()), + arrow_data: vec![], + }; + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Writes [`Chunk`] to the stream + pub fn write( + &mut self, + columns: &Chunk>, + ipc_fields: Option<&[IpcField]>, + ) -> Result<()> { + if self.finished { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Cannot write to a finished stream".to_string(), + ))); + } + + // we can't make it a closure because it borrows (and it can't borrow mut and non-mut below) + #[allow(clippy::or_fun_call)] + let fields = ipc_fields.unwrap_or(self.ipc_fields.as_ref().unwrap()); + + let (encoded_dictionaries, encoded_message) = encode_chunk( + columns, + fields, + &mut self.dictionary_tracker, + &self.write_options, + )?; + + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, &encoded_dictionary)?; + } + + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Write continuation bytes, and mark the stream as done + pub fn finish(&mut self) -> Result<()> { + write_continuation(&mut self.writer, 0)?; + + self.finished = true; + + Ok(()) + } + + /// Consumes itself, returning the inner writer. + pub fn into_inner(self) -> W { + self.writer + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/stream_async.rs b/crates/nano-arrow/src/io/ipc/write/stream_async.rs new file mode 100644 index 000000000000..7af62682935a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/stream_async.rs @@ -0,0 +1,188 @@ +//! `async` writing of arrow streams + +use std::pin::Pin; +use std::task::Poll; + +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; + +use super::super::IpcField; +pub use super::common::WriteOptions; +use super::common::{encode_chunk, DictionaryTracker, EncodedData}; +use super::common_async::{write_continuation, write_message}; +use super::{default_ipc_fields, schema_to_bytes, Record}; +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// A sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC stream. +/// +/// The stream header is automatically written before writing the first chunk. 
+/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// # use arrow2::io::ipc::write::stream_async::StreamSink; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = vec![]; +/// let mut sink = StreamSink::new( +/// &mut buffer, +/// &schema, +/// None, +/// Default::default(), +/// ); +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + writer: Option, + task: Option>>>, + options: WriteOptions, + dictionary_tracker: DictionaryTracker, + fields: Vec, +} + +impl<'a, W> StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new [`StreamSink`]. + pub fn new( + writer: W, + schema: &Schema, + ipc_fields: Option>, + write_options: WriteOptions, + ) -> Self { + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let task = Some(Self::start(writer, schema, &fields[..])); + Self { + writer: None, + task, + fields, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }, + options: write_options, + } + } + + fn start( + mut writer: W, + schema: &Schema, + ipc_fields: &[IpcField], + ) -> BoxFuture<'a, Result>> { + let message = EncodedData { + ipc_message: schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + }; + async move { + write_message(&mut writer, message).await?; + Ok(Some(writer)) + } + .boxed() + } + + fn write(&mut self, record: Record<'_>) -> Result<()> { + let fields = record.fields().unwrap_or(&self.fields[..]); + let (dictionaries, message) = encode_chunk( + record.columns(), + fields, + &mut self.dictionary_tracker, + &self.options, + )?; + + if let Some(mut writer) = self.writer.take() { + self.task = Some( + async move { + for d in dictionaries { + write_message(&mut writer, d).await?; + } + write_message(&mut writer, message).await?; + Ok(Some(writer)) + } + .boxed(), + ); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_complete(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(writer) => { + self.writer = writer; + self.task = None; + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink> for StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Record<'_>) -> Result<()> { + self.get_mut().write(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.poll_complete(cx) { + Poll::Ready(Ok(())) => { + if let Some(mut writer) = this.writer.take() { + 
this.task = Some( + async move { + write_continuation(&mut writer, 0).await?; + writer.flush().await?; + writer.close().await?; + Ok(None) + } + .boxed(), + ); + this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + }, + res => res, + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/writer.rs b/crates/nano-arrow/src/io/ipc/write/writer.rs new file mode 100644 index 000000000000..8fcdd2a8bd66 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/writer.rs @@ -0,0 +1,210 @@ +use std::io::Write; + +use arrow_format::ipc::planus::Builder; + +use super::super::{IpcField, ARROW_MAGIC_V2}; +use super::common::{DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema, schema_to_bytes}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::write::common::encode_chunk_amortized; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) enum State { + None, + Started, + Finished, +} + +/// Arrow file writer +pub struct FileWriter { + /// The object to write to + pub(crate) writer: W, + /// IPC write options + pub(crate) options: WriteOptions, + /// A reference to the schema, used in validating record batches + pub(crate) schema: Schema, + pub(crate) ipc_fields: Vec, + /// The number of bytes between each block of bytes, as an offset for random access + pub(crate) block_offsets: usize, + /// Dictionary blocks that will be written as part of the IPC footer + pub(crate) dictionary_blocks: Vec, + /// Record blocks that will be written as part of the IPC footer + pub(crate) record_blocks: Vec, + /// Whether the writer footer has been written, and the writer is finished + pub(crate) state: State, + /// Keeps track of dictionaries that have been written + pub(crate) dictionary_tracker: DictionaryTracker, + /// Buffer/scratch that is reused between writes + pub(crate) encoded_message: EncodedData, +} + +impl FileWriter { + /// Creates a new [`FileWriter`] and writes the header to `writer` + pub fn try_new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Result { + let mut slf = Self::new(writer, schema, ipc_fields, options); + slf.start()?; + + Ok(slf) + } + + /// Creates a new [`FileWriter`]. + pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(&schema.fields) + }; + + Self { + writer, + options, + schema, + ipc_fields, + block_offsets: 0, + dictionary_blocks: vec![], + record_blocks: vec![], + state: State::None, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: true, + }, + encoded_message: Default::default(), + } + } + + /// Consumes itself into the inner writer + pub fn into_inner(self) -> W { + self.writer + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn get_scratches(&mut self) -> EncodedData { + std::mem::take(&mut self.encoded_message) + } + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn set_scratches(&mut self, scratches: EncodedData) { + self.encoded_message = scratches; + } + + /// Writes the header and first (schema) message to the file. 
+ /// # Errors + /// Errors if the file has been started or has finished. + pub fn start(&mut self) -> Result<()> { + if self.state != State::None { + return Err(Error::oos("The IPC file can only be started once")); + } + // write magic to header + self.writer.write_all(&ARROW_MAGIC_V2[..])?; + // create an 8-byte boundary after the header + self.writer.write_all(&[0, 0])?; + // write the schema, set the written bytes to the schema + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields), + arrow_data: vec![], + }; + + let (meta, data) = write_message(&mut self.writer, &encoded_message)?; + self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment + self.state = State::Started; + Ok(()) + } + + /// Writes [`Chunk`] to the file + pub fn write( + &mut self, + chunk: &Chunk>, + ipc_fields: Option<&[IpcField]>, + ) -> Result<()> { + if self.state != State::Started { + return Err(Error::oos( + "The IPC file must be started before it can be written to. Call `start` before `write`", + )); + } + + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + self.ipc_fields.as_ref() + }; + let encoded_dictionaries = encode_chunk_amortized( + chunk, + ipc_fields, + &mut self.dictionary_tracker, + &self.options, + &mut self.encoded_message, + )?; + + // add all dictionaries + for encoded_dictionary in encoded_dictionaries { + let (meta, data) = write_message(&mut self.writer, &encoded_dictionary)?; + + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; + } + + let (meta, data) = write_message(&mut self.writer, &self.encoded_message)?; + // add a record block for the footer + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, // TODO: is this still applicable? + body_length: data as i64, + }; + self.record_blocks.push(block); + self.block_offsets += meta + data; + Ok(()) + } + + /// Write footer and closing tag, then mark the writer as done + pub fn finish(&mut self) -> Result<()> { + if self.state != State::Started { + return Err(Error::oos( + "The IPC file must be started before it can be finished. Call `start` before `finish`", + )); + } + + // write EOS + write_continuation(&mut self.writer, 0)?; + + let schema = schema::serialize_schema(&self.schema, &self.ipc_fields); + + let root = arrow_format::ipc::Footer { + version: arrow_format::ipc::MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut self.record_blocks)), + custom_metadata: None, + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&root, None); + self.writer.write_all(footer_data)?; + self.writer + .write_all(&(footer_data.len() as i32).to_le_bytes())?; + self.writer.write_all(&ARROW_MAGIC_V2)?; + self.writer.flush()?; + self.state = State::Finished; + + Ok(()) + } +} diff --git a/crates/nano-arrow/src/io/iterator.rs b/crates/nano-arrow/src/io/iterator.rs new file mode 100644 index 000000000000..91ec86fc2e04 --- /dev/null +++ b/crates/nano-arrow/src/io/iterator.rs @@ -0,0 +1,65 @@ +pub use streaming_iterator::StreamingIterator; + +/// A [`StreamingIterator`] with an internal buffer of [`Vec`] used to efficiently +/// present items of type `T` as `&[u8]`. 
+/// It is generic over the type `T` and the transformation `F: T -> &[u8]`. +pub struct BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + iterator: I, + f: F, + buffer: Vec, + is_valid: bool, +} + +impl BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + #[inline] + pub fn new(iterator: I, f: F, buffer: Vec) -> Self { + Self { + iterator, + f, + buffer, + is_valid: false, + } + } +} + +impl StreamingIterator for BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + type Item = [u8]; + + #[inline] + fn advance(&mut self) { + let a = self.iterator.next(); + if let Some(a) = a { + self.is_valid = true; + self.buffer.clear(); + (self.f)(a, &mut self.buffer); + } else { + self.is_valid = false; + } + } + + #[inline] + fn get(&self) -> Option<&Self::Item> { + if self.is_valid { + Some(&self.buffer) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iterator.size_hint() + } +} diff --git a/crates/nano-arrow/src/io/mod.rs b/crates/nano-arrow/src/io/mod.rs new file mode 100644 index 000000000000..72bf37ba9ea5 --- /dev/null +++ b/crates/nano-arrow/src/io/mod.rs @@ -0,0 +1,21 @@ +#![forbid(unsafe_code)] +//! Contains modules to interface with other formats such as [`csv`], +//! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. + +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod ipc; + +#[cfg(feature = "io_flight")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_flight")))] +pub mod flight; + +#[cfg(feature = "io_parquet")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_parquet")))] +pub mod parquet; + +#[cfg(feature = "io_avro")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))] +pub mod avro; + +pub mod iterator; diff --git a/crates/nano-arrow/src/io/parquet/mod.rs b/crates/nano-arrow/src/io/parquet/mod.rs new file mode 100644 index 000000000000..04e5693fcfe6 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/mod.rs @@ -0,0 +1,31 @@ +//! APIs to read from and write to Parquet format. +use crate::error::Error; + +pub mod read; +pub mod write; + +#[cfg(feature = "io_parquet_bloom_filter")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_parquet_bloom_filter")))] +pub use parquet2::bloom_filter; + +const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; + +impl From for Error { + fn from(error: parquet2::error::Error) -> Self { + match error { + parquet2::error::Error::FeatureNotActive(_, _) => { + let message = "Failed to read a compressed parquet file. \ + Use the cargo feature \"io_parquet_compression\" to read compressed parquet files." + .to_string(); + Error::ExternalFormat(message) + }, + _ => Error::ExternalFormat(error.to_string()), + } + } +} + +impl From for parquet2::error::Error { + fn from(error: Error) -> Self { + parquet2::error::Error::OutOfSpec(error.to_string()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/README.md b/crates/nano-arrow/src/io/parquet/read/README.md new file mode 100644 index 000000000000..c36aaafaf79a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/README.md @@ -0,0 +1,36 @@ +## Observations + +### LSB equivalence between definition levels and bitmaps + +When the maximum repetition level is 0 and the maximum definition level is 1, +the RLE-encoded definition levels correspond exactly to Arrow's bitmap and can be +memcopied without further transformations. 
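As a hedged illustration of this observation (the helper below is hypothetical, not part of this crate): with a maximum definition level of 1, each definition level is a single bit and bit `i` is set exactly when value `i` is non-null, which is already Arrow's validity-bitmap layout.

```rust
// Illustration only: assumes `def_level_bits` already holds the bit-packed
// definition levels of a page with max rep level 0 and max def level 1.
// Bit i == 1 exactly when value i is non-null, so the buffer already has the
// layout of an Arrow validity bitmap and a plain copy (or a zero-copy wrap) suffices.
fn def_levels_as_validity(def_level_bits: &[u8], num_values: usize) -> Vec<u8> {
    let num_bytes = (num_values + 7) / 8;
    def_level_bits[..num_bytes].to_vec()
}
```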
+ +## Nested parquet groups are deserialized recursively + +Reading a parquet nested field is done by reading each primitive +column sequentially, and building the nested struct recursively. + +Rows of nested parquet groups are encoded in the repetition and definition levels. +In Arrow, they correspond to: + +- list's offsets and validity +- struct's validity + +The implementation in this module leverages this observation: + +Nested parquet fields are initially recursed over to gather +whether the type is a Struct or List, and whether it is required or optional, which we store +in `nested_info: Vec<Box<dyn Nested>>`. `Nested` is a trait object that receives definition +and repetition levels depending on the type and nullability of the nested item. +We process the definition and repetition levels into `nested_info`. + +When we finish a field, we recursively pop from `nested_info` as we build +the `StructArray` or `ListArray`. + +With this approach, the only difference vs. the flat case is: + +1. we do not leverage the bitmap optimization, and instead need to deserialize the repetition + and definition levels to `i32`. +2. we deserialize definition levels twice, once to extend the values/nullability and + once to extend `nested_info`. diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/README.md b/crates/nano-arrow/src/io/parquet/read/deserialize/README.md new file mode 100644 index 000000000000..5b985bac8e9b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/README.md @@ -0,0 +1,71 @@ +# Design + +## Non-nested types + +Let's start with the design used for non-nested arrays. The (private) entry point of this +module for non-nested arrays is `simple::page_iter_to_arrays`. + +This function expects: + +- a (fallible) streaming iterator of decompressed and encoded pages, `Pages` +- the source (parquet) column type, including its logical information +- the target (arrow) `DataType` +- the chunk size + +and returns an iterator of `Array`, `ArrayIter`. + +This design is shared among _all_ implemented `(parquet, arrow)` tuples. Their main +difference is how they are deserialized, which depends on the source and target types. + +When the array iterator is pulled the first time, the following happens: + +- a page from `Pages` is pulled +- a `PageState<'a>` is built from the page +- the `PageState` is consumed into a mutable array: + - if `chunk_size` is larger than the number of rows in the page, the mutable array state is preserved and a new page is pulled and the process repeated until we fill a chunk. + - if `chunk_size` is smaller than the number of rows in the page, the mutable array state + is returned and the remainder of the page is consumed into multiple mutable arrays of length `chunk_size` that are pushed to a FIFO queue. + +Subsequent pulls of arrays will first try to pull from the FIFO queue. Once the queue is empty, +a new page is pulled. + +### `PageState` + +As mentioned above, the iterator leverages the idea that we attach a state to a page. Recall +that a page is essentially `[header][data]`. The `data` part contains encoded +`[rep levels][def levels][non-null values]`. Some pages have an associated dictionary page, +in which case the `non-null values` represent the indices. + +Irrespective of the physical type, the main idea is to split the page into two iterators: + +- An iterator over `def levels` +- An iterator over `non-null values` + +and progress the iterators as needed; a minimal sketch follows below.
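A rough sketch of that split, with made-up names and signatures (the crate's real `extend_from_decoder` has a different signature and consumes the validity in runs rather than one slot at a time):

```rust
// Sketch: `validity` yields one bool per slot, `non_null_values` yields only
// the non-null values decoded from the page.
fn extend_sketch<T: Default>(
    validity: impl Iterator<Item = bool>,
    mut non_null_values: impl Iterator<Item = T>,
    out_values: &mut Vec<T>,
    out_validity: &mut Vec<bool>,
) {
    for is_valid in validity {
        out_validity.push(is_valid);
        if is_valid {
            // A well-formed page has one non-null value per set validity bit.
            out_values.push(non_null_values.next().unwrap_or_default());
        } else {
            // Null slots are padded with the physical type's default value.
            out_values.push(T::default());
        }
    }
}
```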
In particular, for non-nested types, `def levels` is +a bitmap with the same representation as Arrow, in which case the validity is extended directly. + +The `non-null values` are "expanded" by filling null values with the default value of each physical +type. + +## Nested types + +For nested types with N+1 levels (1 is the primitive), we need to build the nesting information of each +of the N levels plus the non-nested Arrow array. + +This is done by first traversing the parquet types and using them to initialize, per chunk, the N levels. + +The per-chunk execution is then similar, but `chunk_size` only drives the number of retrieved +rows from the outermost parquet group (the field). Each of these pulls knows how many items need +to be pulled from the inner groups, all the way to the primitive type. This works because +in parquet a row cannot be split between two pages and thus each page is guaranteed +to contain a full row. + +The `PageState` of nested types is composed of the following iterators: + +- A (zipped) iterator over `rep levels` and `def levels` +- An iterator over `def levels` +- An iterator over `non-null values` + +The idea is that an iterator of `rep, def` contains all the information to decode the +nesting structure of an Arrow array. The other two iterators are equivalent to the non-nested +types, with the exception that `def levels` are not equivalent to Arrow bitmaps. diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs new file mode 100644 index 000000000000..6008dd9de005 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs @@ -0,0 +1,516 @@ +use std::collections::VecDeque; +use std::default::Default; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{delta_length_byte_array, hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + extend_from_decoder, get_selected_rows, next, DecodedState, FilteredOptionalPageValidity, + MaybeNext, OptionalPageValidity, +}; +use super::super::{utils, Pages}; +use super::utils::*; +use crate::array::{Array, BinaryArray, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[derive(Debug)] +pub(super) struct Required<'a> { + pub values: SizedBinaryIter<'a>, +} + +impl<'a> Required<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + let values = SizedBinaryIter::new(values, page.num_values()); + + Ok(Self { values }) + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct Delta<'a> { + pub lengths: std::vec::IntoIter, + pub values: &'a [u8], +} + +impl<'a> Delta<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + + let mut lengths_iter = delta_length_byte_array::Decoder::try_new(values)?; + + #[allow(clippy::needless_collect)] // we need to consume it to get the values + let lengths = lengths_iter + .by_ref() + .map(|x| x.map(|x| x as usize).map_err(Error::from)) + .collect::>>()?; + + let values = lengths_iter.into_values(); + Ok(Self { + lengths: lengths.into_iter(), + values, + }) + } + + pub fn len(&self) -> usize { + self.lengths.size_hint().0 + } +} + +impl<'a> Iterator for Delta<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + let
length = self.lengths.next()?; + let (item, remaining) = self.values.split_at(length); + self.values = remaining; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + self.lengths.size_hint() + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequired<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + pub fn new(page: &'a DataPage) -> Self { + let values = SizedBinaryIter::new(page.buffer(), page.num_values()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Self { values } + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredDelta<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredDelta<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let values = Delta::try_new(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +pub(super) type Dict = Vec>; + +#[derive(Debug)] +pub(super) struct RequiredDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> RequiredDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequiredDictionary<'a> { + pub values: SliceFilteredIter>, + pub dict: &'a Dict, +} + +impl<'a> FilteredRequiredDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values, dict }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct ValuesDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> ValuesDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +enum State<'a> { + Optional(OptionalPageValidity<'a>, BinaryIter<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalPageValidity<'a>, ValuesDictionary<'a>), + Delta(Delta<'a>), + OptionalDelta(OptionalPageValidity<'a>, Delta<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredDelta(FilteredDelta<'a>), + FilteredOptionalDelta(FilteredOptionalPageValidity<'a>, Delta<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, BinaryIter<'a>), + FilteredRequiredDictionary(FilteredRequiredDictionary<'a>), + FilteredOptionalDictionary(FilteredOptionalPageValidity<'a>, ValuesDictionary<'a>), +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(validity, _) => validity.len(), + State::Required(state) => state.len(), + State::Delta(state) => state.len(), + State::OptionalDelta(state, _) => state.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(optional, _) => optional.len(), + State::FilteredRequired(state) => state.len(), + State::FilteredOptional(validity, _) => validity.len(), + 
State::FilteredDelta(state) => state.len(), + State::FilteredOptionalDelta(state, _) => state.len(), + State::FilteredRequiredDictionary(values) => values.len(), + State::FilteredOptionalDictionary(optional, _) => optional.len(), + } + } +} + +impl DecodedState for (Binary, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + phantom_o: std::marker::PhantomData, +} + +impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dict = Dict; + type DecodedState = (Binary, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => Ok( + State::RequiredDictionary(RequiredDictionary::try_new(page, dict)?), + ), + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + Ok(State::OptionalDictionary( + OptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, true) => { + FilteredRequiredDictionary::try_new(page, dict) + .map(State::FilteredRequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, true) => { + Ok(State::FilteredOptionalDictionary( + FilteredOptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Optional( + OptionalPageValidity::try_new(page)?, + values, + )) + }, + (Encoding::Plain, _, false, false) => Ok(State::Required(Required::try_new(page)?)), + (Encoding::Plain, _, false, true) => { + Ok(State::FilteredRequired(FilteredRequired::new(page))) + }, + (Encoding::Plain, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + BinaryIter::new(values), + )) + }, + (Encoding::DeltaLengthByteArray, _, false, false) => { + Delta::try_new(page).map(State::Delta) + }, + (Encoding::DeltaLengthByteArray, _, true, false) => Ok(State::OptionalDelta( + OptionalPageValidity::try_new(page)?, + Delta::try_new(page)?, + )), + (Encoding::DeltaLengthByteArray, _, false, true) => { + FilteredDelta::try_new(page).map(State::FilteredDelta) + }, + (Encoding::DeltaLengthByteArray, _, true, true) => Ok(State::FilteredOptionalDelta( + FilteredOptionalPageValidity::try_new(page)?, + Delta::try_new(page)?, + )), + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Binary::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + additional: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values, + ), + State::Required(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::Delta(page) => { + values.extend_lengths(page.lengths.by_ref().take(additional), &mut 
page.values); + }, + State::OptionalDelta(page_validity, page_values) => { + let Binary { + offsets, + values: values_, + } = values; + + let last_offset = *offsets.last(); + extend_from_decoder( + validity, + page_validity, + Some(additional), + offsets, + page_values.lengths.by_ref(), + ); + + let length = *offsets.last() - last_offset; + + let (consumed, remaining) = page_values.values.split_at(length.to_usize()); + page_values.values = remaining; + values_.extend_from_slice(consumed); + }, + State::FilteredRequired(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::FilteredDelta(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::OptionalDictionary(page_validity, page_values) => { + let page_dict = &page_values.dict; + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + &mut page_values + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()), + ) + }, + State::RequiredDictionary(page) => { + let page_dict = &page.dict; + + for x in page + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()) + .take(additional) + { + values.push(x) + } + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values.by_ref(), + ); + }, + State::FilteredOptionalDelta(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values.by_ref(), + ); + }, + State::FilteredRequiredDictionary(page) => { + let page_dict = &page.dict; + for x in page + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()) + .take(additional) + { + values.push(x) + } + }, + State::FilteredOptionalDictionary(page_validity, page_values) => { + let page_dict = &page_values.dict; + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + &mut page_values + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()), + ) + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + deserialize_plain(&page.buffer, page.num_values) + } +} + +pub(super) fn finish( + data_type: &DataType, + mut values: Binary, + mut validity: MutableBitmap, +) -> Result> { + values.offsets.shrink_to_fit(); + values.values.shrink_to_fit(); + validity.shrink_to_fit(); + + match data_type.to_physical_type() { + PhysicalType::Binary | PhysicalType::LargeBinary => BinaryArray::::try_new( + data_type.clone(), + values.offsets.into(), + values.values.into(), + validity.into(), + ) + .map(|x| x.boxed()), + PhysicalType::Utf8 | PhysicalType::LargeUtf8 => Utf8Array::::try_new( + data_type.clone(), + values.offsets.into(), + values.values.into(), + validity.into(), + ) + .map(|x| x.boxed()), + _ => unreachable!(), + } +} + +pub struct Iter { + iter: I, + data_type: DataType, + items: VecDeque<(Binary, MutableBitmap)>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, chunk_size: Option, num_rows: usize) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + 
&BinaryDecoder::::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(finish(&self.data_type, values, validity)) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +pub(super) fn deserialize_plain(values: &[u8], num_values: usize) -> Dict { + SizedBinaryIter::new(values, num_values) + .map(|x| x.to_vec()) + .collect() +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs new file mode 100644 index 000000000000..0fb3615de050 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs @@ -0,0 +1,174 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; + +use super::super::dictionary::*; +use super::super::utils::MaybeNext; +use super::super::Pages; +use super::utils::{Binary, SizedBinaryIter}; +use crate::array::{Array, BinaryArray, DictionaryArray, DictionaryKey, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Result; +use crate::io::parquet::read::deserialize::nested_utils::{InitNested, NestedState}; +use crate::offset::Offset; + +/// An iterator adapter over [`Pages`] assumed to be encoded as parquet's dictionary-encoded binary representation +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + phantom: std::marker::PhantomData, +} + +impl DictIter +where + K: DictionaryKey, + O: Offset, + I: Pages, +{ + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + phantom: std::marker::PhantomData, + } + } +} + +fn read_dict(data_type: DataType, dict: &DictPage) -> Box { + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + + let values = SizedBinaryIter::new(&dict.buffer, dict.num_values); + + let mut data = Binary::::with_capacity(dict.num_values); + data.values = Vec::with_capacity(dict.buffer.len() - 4 * dict.num_values); + for item in values { + data.push(item) + } + + match data_type.to_physical_type() { + PhysicalType::Utf8 | PhysicalType::LargeUtf8 => { + Utf8Array::::new(data_type, data.offsets.into(), data.values.into(), None).boxed() + }, + PhysicalType::Binary | PhysicalType::LargeBinary => { + BinaryArray::::new(data_type, data.offsets.into(), data.values.into(), None).boxed() + }, + _ => unreachable!(), + } +} + +impl Iterator for DictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`] +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + iter: I, + init: Vec, + data_type: DataType, + 
values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, + phantom: std::marker::PhantomData, +} + +impl NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + phantom: Default::default(), + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs new file mode 100644 index 000000000000..c48bfe276bcc --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs @@ -0,0 +1,8 @@ +mod basic; +mod dictionary; +mod nested; +mod utils; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs new file mode 100644 index 000000000000..43ea90161cd7 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs @@ -0,0 +1,191 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::nested_utils::*; +use super::super::utils; +use super::super::utils::MaybeNext; +use super::basic::{deserialize_plain, finish, Dict, ValuesDictionary}; +use super::utils::*; +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::Pages; +use crate::offset::Offset; + +#[derive(Debug)] +enum State<'a> { + Optional(BinaryIter<'a>), + Required(BinaryIter<'a>), + RequiredDictionary(ValuesDictionary<'a>), + OptionalDictionary(ValuesDictionary<'a>), +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(validity) => validity.size_hint().0, + State::Required(state) => state.size_hint().0, + State::RequiredDictionary(required) => required.len(), + State::OptionalDictionary(optional) => optional.len(), + } + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + phantom_o: std::marker::PhantomData, +} + +impl<'a, O: Offset> NestedDecoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dictionary = Dict; + type DecodedState = (Binary, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, 
Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + ValuesDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Optional(values)) + }, + (Encoding::Plain, _, false, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Required(values)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Binary::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page) => { + let value = page.next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page) => { + let value = page.next().unwrap_or_default(); + values.push(value); + }, + State::RequiredDictionary(page) => { + let dict_values = &page.dict; + let item = page + .values + .next() + .map(|index| dict_values[index.unwrap() as usize].as_ref()) + .unwrap_or_default(); + values.push(item); + }, + State::OptionalDictionary(page) => { + let dict_values = &page.dict; + let item = page + .values + .next() + .map(|index| dict_values[index.unwrap() as usize].as_ref()) + .unwrap_or_default(); + values.push(item); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(&[]); + validity.push(false); + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + deserialize_plain(&page.buffer, page.num_values) + } +} + +pub struct NestedIter { + iter: I, + data_type: DataType, + init: Vec, + items: VecDeque<(NestedState, (Binary, MutableBitmap))>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl NestedIter { + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + data_type, + init, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, Box)>; + + fn next(&mut self) -> Option { + loop { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &BinaryDecoder::::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((nested, decoded))) => { + return Some( + finish(&self.data_type, decoded.0, decoded.1).map(|array| (nested, array)), + ) + }, + MaybeNext::Some(Err(e)) => return Some(Err(e)), + MaybeNext::None => return None, + MaybeNext::More => continue, // Using continue in a loop instead of calling next helps prevent stack overflow. + } + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs new file mode 100644 index 000000000000..961268db2beb --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs @@ -0,0 +1,169 @@ +use super::super::utils::Pushable; +use crate::offset::{Offset, Offsets}; + +/// [`Pushable`] for variable length binary data. 
+#[derive(Debug)] +pub struct Binary { + pub offsets: Offsets, + pub values: Vec, +} + +impl Pushable for Offsets { + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + #[inline] + fn len(&self) -> usize { + self.len_proxy() + } + + #[inline] + fn push(&mut self, value: usize) { + self.try_push(value).unwrap() + } + + #[inline] + fn push_null(&mut self) { + self.extend_constant(1); + } + + #[inline] + fn extend_constant(&mut self, additional: usize, _: usize) { + self.extend_constant(additional) + } +} + +impl Binary { + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + offsets: Offsets::with_capacity(capacity), + values: Vec::with_capacity(capacity.min(100) * 24), + } + } + + #[inline] + pub fn push(&mut self, v: &[u8]) { + if self.offsets.len_proxy() == 100 && self.offsets.capacity() > 100 { + let bytes_per_row = self.values.len() / 100 + 1; + let bytes_estimate = bytes_per_row * self.offsets.capacity(); + if bytes_estimate > self.values.capacity() { + self.values.reserve(bytes_estimate - self.values.capacity()); + } + } + + self.values.extend(v); + self.offsets.try_push(v.len()).unwrap() + } + + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + } + + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + #[inline] + pub fn extend_lengths>(&mut self, lengths: I, values: &mut &[u8]) { + let current_offset = *self.offsets.last(); + self.offsets.try_extend_from_lengths(lengths).unwrap(); + let new_offset = *self.offsets.last(); + let length = new_offset.to_usize() - current_offset.to_usize(); + let (consumed, remaining) = values.split_at(length); + *values = remaining; + self.values.extend_from_slice(consumed); + } +} + +impl<'a, O: Offset> Pushable<&'a [u8]> for Binary { + #[inline] + fn reserve(&mut self, additional: usize) { + let avg_len = self.values.len() / std::cmp::max(self.offsets.last().to_usize(), 1); + self.values.reserve(additional * avg_len); + self.offsets.reserve(additional); + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push_null(&mut self) { + self.push(&[]) + } + + #[inline] + fn push(&mut self, value: &[u8]) { + self.push(value) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: &[u8]) { + assert_eq!(value.len(), 0); + self.extend_constant(additional) + } +} + +#[derive(Debug)] +pub struct BinaryIter<'a> { + values: &'a [u8], +} + +impl<'a> BinaryIter<'a> { + pub fn new(values: &'a [u8]) -> Self { + Self { values } + } +} + +impl<'a> Iterator for BinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.values.is_empty() { + return None; + } + let (length, remaining) = self.values.split_at(4); + let length = u32::from_le_bytes(length.try_into().unwrap()) as usize; + let (result, remaining) = remaining.split_at(length); + self.values = remaining; + Some(result) + } +} + +#[derive(Debug)] +pub struct SizedBinaryIter<'a> { + iter: BinaryIter<'a>, + remaining: usize, +} + +impl<'a> SizedBinaryIter<'a> { + pub fn new(values: &'a [u8], size: usize) -> Self { + let iter = BinaryIter::new(values); + Self { + iter, + remaining: size, + } + } +} + +impl<'a> Iterator for SizedBinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } else { + self.remaining -= 1 + }; + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} 
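For orientation, a hedged usage sketch of the `BinaryIter` defined above (the buffer is made up; it only assumes the iterator is in scope). Parquet's PLAIN encoding of byte arrays stores each value as a 4-byte little-endian length followed by the value bytes, which is exactly the layout `BinaryIter::next` walks:

```rust
// Two PLAIN-encoded byte-array values, "ab" and "xyz", built by hand.
fn binary_iter_example() {
    let mut buffer = Vec::new();
    buffer.extend_from_slice(&2u32.to_le_bytes());
    buffer.extend_from_slice(b"ab");
    buffer.extend_from_slice(&3u32.to_le_bytes());
    buffer.extend_from_slice(b"xyz");

    let mut iter = BinaryIter::new(&buffer);
    assert_eq!(iter.next(), Some(&b"ab"[..]));
    assert_eq!(iter.next(), Some(&b"xyz"[..]));
    assert_eq!(iter.next(), None);
}
```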
diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs new file mode 100644 index 000000000000..dd3ac9eb52c5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs @@ -0,0 +1,229 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + extend_from_decoder, get_selected_rows, next, DecodedState, Decoder, + FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, +}; +use super::super::{utils, Pages}; +use crate::array::BooleanArray; +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +#[derive(Debug)] +struct Values<'a>(BitmapIter<'a>); + +impl<'a> Values<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + + Ok(Self(BitmapIter::new(values, 0, values.len() * 8))) + } +} + +// The state of a required DataPage with a boolean physical type +#[derive(Debug)] +struct Required<'a> { + values: &'a [u8], + // invariant: offset <= length; + offset: usize, + length: usize, +} + +impl<'a> Required<'a> { + pub fn new(page: &'a DataPage) -> Self { + Self { + values: page.buffer(), + offset: 0, + length: page.num_values(), + } + } +} + +#[derive(Debug)] +struct FilteredRequired<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + // todo: replace this by an iterator over slices, for faster deserialization + let values = BitmapIter::new(values, 0, page.num_values()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +// The state of a `DataPage` of `Boolean` parquet boolean type +#[derive(Debug)] +enum State<'a> { + Optional(OptionalPageValidity<'a>, Values<'a>), + Required(Required<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(validity, _) => validity.len(), + State::Required(page) => page.length - page.offset, + State::FilteredRequired(page) => page.len(), + State::FilteredOptional(optional, _) => optional.len(), + } + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +impl DecodedState for (MutableBitmap, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +#[derive(Default)] +struct BooleanDecoder {} + +impl<'a> Decoder<'a> for BooleanDecoder { + type State = State<'a>; + type Dict = (); + type DecodedState = (MutableBitmap, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, _: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::Plain, true, false) => Ok(State::Optional( + OptionalPageValidity::try_new(page)?, + Values::try_new(page)?, + )), + (Encoding::Plain, false, false) => Ok(State::Required(Required::new(page))), + (Encoding::Plain, true, true) => 
Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + Values::try_new(page)?, + )), + (Encoding::Plain, false, true) => { + Ok(State::FilteredRequired(FilteredRequired::try_new(page)?)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + MutableBitmap::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.0, + ), + State::Required(page) => { + let remaining = remaining.min(page.length - page.offset); + values.extend_from_slice(page.values, page.offset, remaining); + page.offset += remaining; + }, + State::FilteredRequired(page) => { + values.reserve(remaining); + for item in page.values.by_ref().take(remaining) { + values.push(item) + } + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.0.by_ref(), + ); + }, + } + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dict {} +} + +fn finish(data_type: &DataType, values: MutableBitmap, validity: MutableBitmap) -> BooleanArray { + BooleanArray::new(data_type.clone(), values.into(), validity.into()) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct Iter { + iter: I, + data_type: DataType, + items: VecDeque<(MutableBitmap, MutableBitmap)>, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, chunk_size: Option, num_rows: usize) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + self.chunk_size, + &BooleanDecoder::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs new file mode 100644 index 000000000000..dc00cc2a4249 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use nested::NestedIter; + +pub use self::basic::Iter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs new file mode 100644 index 000000000000..f3e684ab9fe3 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs @@ -0,0 +1,153 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::nested_utils::*; +use super::super::utils::MaybeNext; +use super::super::{utils, Pages}; +use crate::array::BooleanArray; +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::MutableBitmap; +use 
crate::datatypes::DataType; +use crate::error::Result; + +// The state of a `DataPage` of `Boolean` parquet boolean type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum State<'a> { + Optional(BitmapIter<'a>), + Required(BitmapIter<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(iter) => iter.size_hint().0, + State::Required(iter) => iter.size_hint().0, + } + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +#[derive(Default)] +struct BooleanDecoder {} + +impl<'a> NestedDecoder<'a> for BooleanDecoder { + type State = State<'a>; + type Dictionary = (); + type DecodedState = (MutableBitmap, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + _: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::Plain, true, false) => { + let (_, _, values) = split_buffer(page)?; + let values = BitmapIter::new(values, 0, values.len() * 8); + + Ok(State::Optional(values)) + }, + (Encoding::Plain, false, false) => { + let (_, _, values) = split_buffer(page)?; + let values = BitmapIter::new(values, 0, values.len() * 8); + + Ok(State::Required(values)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + MutableBitmap::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let value = page_values.next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page_values) => { + let value = page_values.next().unwrap_or_default(); + values.push(value); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(false); + validity.push(false); + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dictionary {} +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct NestedIter { + iter: I, + init: Vec, + items: VecDeque<(NestedState, (MutableBitmap, MutableBitmap))>, + remaining: usize, + chunk_size: Option, +} + +impl NestedIter { + pub fn new(iter: I, init: Vec, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + init, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + } + } +} + +fn finish(data_type: &DataType, values: MutableBitmap, validity: MutableBitmap) -> BooleanArray { + BooleanArray::new(data_type.clone(), values.into(), validity.into()) +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, BooleanArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + &self.init, + self.chunk_size, + &BooleanDecoder::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((nested, (values, validity)))) => { + Some(Ok((nested, finish(&DataType::Boolean, values, validity)))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs 
b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs new file mode 100644 index 000000000000..7826f5856c0e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs @@ -0,0 +1,314 @@ +mod nested; + +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::utils::{ + self, dict_indices_decoder, extend_from_decoder, get_selected_rows, DecodedState, Decoder, + FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, +}; +use super::Pages; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, HybridRleDecoder<'a>), +} + +#[derive(Debug)] +pub struct Required<'a> { + values: HybridRleDecoder<'a>, +} + +impl<'a> Required<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + Ok(Self { values }) + } +} + +#[derive(Debug)] +pub struct FilteredRequired<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } +} + +#[derive(Debug)] +pub struct Optional<'a> { + values: HybridRleDecoder<'a>, + validity: OptionalPageValidity<'a>, +} + +impl<'a> Optional<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + }) + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(optional) => optional.validity.len(), + State::Required(required) => required.values.size_hint().0, + State::FilteredRequired(required) => required.values.size_hint().0, + State::FilteredOptional(validity, _) => validity.len(), + } + } +} + +#[derive(Debug)] +pub struct PrimitiveDecoder +where + K: DictionaryKey, +{ + phantom_k: std::marker::PhantomData, +} + +impl Default for PrimitiveDecoder +where + K: DictionaryKey, +{ + #[inline] + fn default() -> Self { + Self { + phantom_k: std::marker::PhantomData, + } + } +} + +impl<'a, K> utils::Decoder<'a> for PrimitiveDecoder +where + K: DictionaryKey, +{ + type State = State<'a>; + type Dict = (); + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, _: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, false, false) => { + Required::try_new(page).map(State::Required) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, true, false) => { + Optional::try_new(page).map(State::Optional) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, false, true) => { + FilteredRequired::try_new(page).map(State::FilteredRequired) + }, + 
(Encoding::PlainDictionary | Encoding::RleDictionary, true, true) => { + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + dict_indices_decoder(page)?, + )) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + &mut page.values.by_ref().map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => panic!("The maximum key is too small"), + } + }), + ), + State::Required(page) => { + values.extend( + page.values + .by_ref() + .map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }) + .take(remaining), + ); + }, + State::FilteredOptional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.by_ref().map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }), + ), + State::FilteredRequired(page) => { + values.extend( + page.values + .by_ref() + .map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. 
+ Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }) + .take(remaining), + ); + }, + } + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dict {} +} + +fn finish_key(values: Vec, validity: MutableBitmap) -> PrimitiveArray { + PrimitiveArray::new(K::PRIMITIVE.into(), values.into(), validity.into()) +} + +#[inline] +pub(super) fn next_dict Box>( + iter: &mut I, + items: &mut VecDeque<(Vec, MutableBitmap)>, + dict: &mut Option>, + data_type: DataType, + remaining: &mut usize, + chunk_size: Option, + read_dict: F, +) -> MaybeNext>> { + if items.len() > 1 { + let (values, validity) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + return MaybeNext::Some(DictionaryArray::try_new( + data_type, + keys, + dict.clone().unwrap(), + )); + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let (page, dict) = match (&dict, page) { + (None, Page::Data(_)) => { + return MaybeNext::Some(Err(Error::nyi( + "dictionary arrays from non-dict-encoded pages", + ))); + }, + (_, Page::Dict(dict_page)) => { + *dict = Some(read_dict(dict_page)); + return next_dict( + iter, items, dict, data_type, remaining, chunk_size, read_dict, + ); + }, + (Some(dict), Page::Data(page)) => (page, dict), + }; + + // there is a new page => consume the page from the start + let maybe_page = PrimitiveDecoder::::default().build_state(page, None); + let page = match maybe_page { + Ok(page) => page, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + utils::extend_from_new_page( + page, + chunk_size, + items, + remaining, + &PrimitiveDecoder::::default(), + ); + + if items.front().unwrap().len() < chunk_size.unwrap_or(usize::MAX) { + MaybeNext::More + } else { + let (values, validity) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + MaybeNext::Some(DictionaryArray::try_new(data_type, keys, dict.clone())) + } + }, + Ok(None) => { + if let Some((values, validity)) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(values.len() <= chunk_size.unwrap_or(usize::MAX)); + + let keys = finish_key(values, validity); + MaybeNext::Some(DictionaryArray::try_new( + data_type, + keys, + dict.clone().unwrap(), + )) + } else { + MaybeNext::None + } + }, + } +} + +pub use nested::next_dict as nested_next_dict; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs new file mode 100644 index 000000000000..1fb1919d1504 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs @@ -0,0 +1,213 @@ +use std::collections::VecDeque; + +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::super::super::Pages; +use super::super::nested_utils::*; +use super::super::utils::{dict_indices_decoder, not_implemented, MaybeNext, PageState}; +use super::finish_key; +use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +// The state of a required DataPage with a boolean physical type +#[derive(Debug)] +pub struct Required<'a> { + values: HybridRleDecoder<'a>, + length: usize, +} + +impl<'a> Required<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = 
dict_indices_decoder(page)?; + let length = page.num_values(); + Ok(Self { values, length }) + } +} + +// The state of a `DataPage` of a `Dictionary` type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum State<'a> { + Optional(HybridRleDecoder<'a>), + Required(Required<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(page) => page.len(), + State::Required(page) => page.length, + } + } +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +#[derive(Debug)] +pub struct DictionaryDecoder +where + K: DictionaryKey, +{ + phantom_k: std::marker::PhantomData, +} + +impl Default for DictionaryDecoder +where + K: DictionaryKey, +{ + #[inline] + fn default() -> Self { + Self { + phantom_k: std::marker::PhantomData, + } + } +} + +impl<'a, K: DictionaryKey> NestedDecoder<'a> for DictionaryDecoder { + type State = State<'a>; + type Dictionary = (); + type DecodedState = (Vec, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + _: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::RleDictionary | Encoding::PlainDictionary, true, false) => { + dict_indices_decoder(page).map(State::Optional) + }, + (Encoding::RleDictionary | Encoding::PlainDictionary, false, false) => { + Required::try_new(page).map(State::Required) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let key = page_values.next().transpose()?; + // todo: convert unwrap to error + let key = match K::try_from(key.unwrap_or_default() as usize) { + Ok(key) => key, + Err(_) => todo!(), + }; + values.push(key); + validity.push(true); + }, + State::Required(page_values) => { + let key = page_values.values.next().transpose()?; + let key = match K::try_from(key.unwrap_or_default() as usize) { + Ok(key) => key, + Err(_) => todo!(), + }; + values.push(key); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(K::default()); + validity.push(false) + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dictionary {} +} + +#[allow(clippy::too_many_arguments)] +pub fn next_dict Box>( + iter: &mut I, + items: &mut VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: &mut usize, + init: &[InitNested], + dict: &mut Option>, + data_type: DataType, + chunk_size: Option, + read_dict: F, +) -> MaybeNext)>> { + if items.len() > 1 { + let (nested, (values, validity)) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone().unwrap()); + return MaybeNext::Some(dict.map(|dict| (nested, dict))); + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let (page, dict) = match (&dict, page) { + (None, Page::Data(_)) => { + return MaybeNext::Some(Err(Error::nyi( + "dictionary arrays from non-dict-encoded pages", + ))); + }, + (_, Page::Dict(dict_page)) => { + *dict = Some(read_dict(dict_page)); + 
return next_dict( + iter, items, remaining, init, dict, data_type, chunk_size, read_dict, + ); + }, + (Some(dict), Page::Data(page)) => (page, dict), + }; + + let error = extend( + page, + init, + items, + None, + remaining, + &DictionaryDecoder::::default(), + chunk_size, + ); + match error { + Ok(_) => {}, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + if items.front().unwrap().0.len() < chunk_size.unwrap_or(usize::MAX) { + MaybeNext::More + } else { + let (nested, (values, validity)) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone()); + MaybeNext::Some(dict.map(|dict| (nested, dict))) + } + }, + Ok(None) => { + if let Some((nested, (values, validity))) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(values.len() <= chunk_size.unwrap_or(usize::MAX)); + + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone().unwrap()); + MaybeNext::Some(dict.map(|dict| (nested, dict))) + } else { + MaybeNext::None + } + }, + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs new file mode 100644 index 000000000000..aee3116ed64e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs @@ -0,0 +1,322 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + dict_indices_decoder, extend_from_decoder, get_selected_rows, next, not_implemented, + DecodedState, Decoder, FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, + PageState, Pushable, +}; +use super::super::Pages; +use super::utils::FixedSizeBinary; +use crate::array::FixedSizeBinaryArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +pub(super) type Dict = Vec; + +#[derive(Debug)] +pub(super) struct Optional<'a> { + pub(super) values: std::slice::ChunksExact<'a, u8>, + pub(super) validity: OptionalPageValidity<'a>, +} + +impl<'a> Optional<'a> { + pub(super) fn try_new(page: &'a DataPage, size: usize) -> Result { + let (_, _, values) = split_buffer(page)?; + + let values = values.chunks_exact(size); + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + }) + } +} + +#[derive(Debug)] +pub(super) struct Required<'a> { + pub values: std::slice::ChunksExact<'a, u8>, +} + +impl<'a> Required<'a> { + pub(super) fn new(page: &'a DataPage, size: usize) -> Self { + let values = page.buffer(); + assert_eq!(values.len() % size, 0); + let values = values.chunks_exact(size); + Self { values } + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequired<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + fn new(page: &'a DataPage, size: usize) -> Self { + let values = page.buffer(); + assert_eq!(values.len() % size, 0); + let values = values.chunks_exact(size); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Self { values } + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + 
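// For fixed-size binary columns a dictionary page is simply the distinct values laid out
// back-to-back in one byte buffer (`deserialize_dict` below clones `page.buffer` verbatim),
// so key `i` resolves to the slice `dict[i * size..(i + 1) * size]`. A minimal sketch of
// that lookup, with an illustrative helper name that is not part of this patch:
//
//     fn dict_value(dict: &[u8], size: usize, index: usize) -> &[u8] {
//         &dict[index * size..(index + 1) * size]
//     }
//
// The dictionary decoders below inline exactly this slicing when materializing values.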
+#[derive(Debug)] +pub(super) struct RequiredDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> RequiredDictionary<'a> { + pub(super) fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct OptionalDictionary<'a> { + pub(super) values: hybrid_rle::HybridRleDecoder<'a>, + pub(super) validity: OptionalPageValidity<'a>, + pub(super) dict: &'a Dict, +} + +impl<'a> OptionalDictionary<'a> { + pub(super) fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + dict, + }) + } +} + +#[derive(Debug)] +enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalDictionary<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional( + FilteredOptionalPageValidity<'a>, + std::slice::ChunksExact<'a, u8>, + ), +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(state) => state.validity.len(), + State::Required(state) => state.len(), + State::RequiredDictionary(state) => state.len(), + State::OptionalDictionary(state) => state.validity.len(), + State::FilteredRequired(state) => state.len(), + State::FilteredOptional(state, _) => state.len(), + } + } +} + +struct BinaryDecoder { + size: usize, +} + +impl DecodedState for (FixedSizeBinary, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +impl<'a> Decoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dict = Dict; + type DecodedState = (FixedSizeBinary, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::Plain, _, true, false) => { + Ok(State::Optional(Optional::try_new(page, self.size)?)) + }, + (Encoding::Plain, _, false, false) => { + Ok(State::Required(Required::new(page, self.size))) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + RequiredDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + OptionalDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, None, false, true) => Ok(State::FilteredRequired( + FilteredRequired::new(page, self.size), + )), + (Encoding::Plain, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + values.chunks_exact(self.size), + )) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + FixedSizeBinary::with_capacity(capacity, self.size), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + &mut 
page.values, + ), + State::Required(page) => { + for x in page.values.by_ref().take(remaining) { + values.push(x) + } + }, + State::FilteredRequired(page) => { + for x in page.values.by_ref().take(remaining) { + values.push(x) + } + }, + State::OptionalDictionary(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + page.values.by_ref().map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }), + ), + State::RequiredDictionary(page) => { + for x in page + .values + .by_ref() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .take(remaining) + { + values.push(x) + } + }, + State::FilteredOptional(page_validity, page_values) => { + extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.by_ref(), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + page.buffer.clone() + } +} + +pub fn finish( + data_type: &DataType, + values: FixedSizeBinary, + validity: MutableBitmap, +) -> FixedSizeBinaryArray { + FixedSizeBinaryArray::new(data_type.clone(), values.values.into(), validity.into()) +} + +pub struct Iter { + iter: I, + data_type: DataType, + size: usize, + items: VecDeque<(FixedSizeBinary, MutableBitmap)>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + let size = FixedSizeBinaryArray::get_size(&data_type); + Self { + iter, + data_type, + size, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &BinaryDecoder { size: self.size }, + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs new file mode 100644 index 000000000000..3f5455b0bdb8 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs @@ -0,0 +1,150 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; + +use super::super::dictionary::*; +use super::super::utils::MaybeNext; +use super::super::Pages; +use crate::array::{Array, DictionaryArray, DictionaryKey, FixedSizeBinaryArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::nested_utils::{InitNested, NestedState}; + +/// An iterator adapter over [`Pages`] assumed to be encoded as parquet's dictionary-encoded binary representation +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + K: DictionaryKey, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, +} + +impl DictIter +where + K: DictionaryKey, + I: Pages, +{ + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + remaining: 
num_rows, + chunk_size, + } + } +} + +fn read_dict(data_type: DataType, dict: &DictPage) -> Box { + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + + let values = dict.buffer.clone(); + + FixedSizeBinaryArray::try_new(data_type, values.into(), None) + .unwrap() + .boxed() +} + +impl Iterator for DictIter +where + I: Pages, + K: DictionaryKey, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`]. +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + iter: I, + init: Vec, + data_type: DataType, + values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, +} + +impl NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + remaining: num_rows, + items: VecDeque::new(), + chunk_size, + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..c48bfe276bcc --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs @@ -0,0 +1,8 @@ +mod basic; +mod dictionary; +mod nested; +mod utils; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs new file mode 100644 index 000000000000..f2b65380baad --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs @@ -0,0 +1,189 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{not_implemented, MaybeNext, PageState}; +use super::utils::FixedSizeBinary; +use crate::array::FixedSizeBinaryArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::fixed_size_binary::basic::{ + finish, Dict, Optional, OptionalDictionary, Required, RequiredDictionary, +}; +use crate::io::parquet::read::deserialize::nested_utils::{next, 
NestedDecoder}; +use crate::io::parquet::read::deserialize::utils::Pushable; +use crate::io::parquet::read::{InitNested, NestedState, Pages}; + +#[derive(Debug)] +enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalDictionary<'a>), +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(state) => state.validity.len(), + State::Required(state) => state.len(), + State::RequiredDictionary(state) => state.len(), + State::OptionalDictionary(state) => state.validity.len(), + } + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + size: usize, +} + +impl<'a> NestedDecoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dictionary = Dict; + type DecodedState = (FixedSizeBinary, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::Plain, _, true, false) => { + Ok(State::Optional(Optional::try_new(page, self.size)?)) + }, + (Encoding::Plain, _, false, false) => { + Ok(State::Required(Required::new(page, self.size))) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + RequiredDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + OptionalDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + FixedSizeBinary::with_capacity(capacity, self.size), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page) => { + let value = page.values.by_ref().next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page) => { + let value = page.values.by_ref().next().unwrap_or_default(); + values.push(value); + }, + State::RequiredDictionary(page) => { + let item = page + .values + .by_ref() + .next() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .unwrap_or_default(); + values.push(item); + }, + State::OptionalDictionary(page) => { + let item = page + .values + .by_ref() + .next() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .unwrap_or_default(); + values.push(item); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push_null(); + validity.push(false); + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + page.buffer.clone() + } +} + +pub struct NestedIter { + iter: I, + data_type: DataType, + size: usize, + init: Vec, + items: VecDeque<(NestedState, (FixedSizeBinary, MutableBitmap))>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl NestedIter { + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + let size = FixedSizeBinaryArray::get_size(&data_type); + 
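// `get_size` extracts the fixed width `n` from `DataType::FixedSizeBinary(n)`; every
// value decoded by this iterator is exactly `n` bytes.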
Self { + iter, + data_type, + size, + init, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, FixedSizeBinaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &BinaryDecoder { size: self.size }, + ); + match maybe_state { + MaybeNext::Some(Ok((nested, decoded))) => { + Some(Ok((nested, finish(&self.data_type, decoded.0, decoded.1)))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs new file mode 100644 index 000000000000..f718ce1bdc2b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs @@ -0,0 +1,58 @@ +use super::super::utils::Pushable; + +/// A [`Pushable`] for fixed sized binary data +#[derive(Debug)] +pub struct FixedSizeBinary { + pub values: Vec, + pub size: usize, +} + +impl FixedSizeBinary { + #[inline] + pub fn with_capacity(capacity: usize, size: usize) -> Self { + Self { + values: Vec::with_capacity(capacity * size), + size, + } + } + + #[inline] + pub fn push(&mut self, value: &[u8]) { + debug_assert_eq!(value.len(), self.size); + self.values.extend(value); + } + + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + self.values + .resize(self.values.len() + additional * self.size, 0); + } +} + +impl<'a> Pushable<&'a [u8]> for FixedSizeBinary { + #[inline] + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional * self.size); + } + #[inline] + fn push(&mut self, value: &[u8]) { + debug_assert_eq!(value.len(), self.size); + self.push(value); + } + + #[inline] + fn push_null(&mut self) { + self.values.extend(std::iter::repeat(0).take(self.size)) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: &[u8]) { + assert_eq!(value.len(), 0); + self.extend_constant(additional) + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs new file mode 100644 index 000000000000..098430b3d154 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs @@ -0,0 +1,212 @@ +//! APIs to read from Parquet format. +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod nested; +mod nested_utils; +mod null; +mod primitive; +mod simple; +mod struct_; +mod utils; + +use parquet2::read::get_page_iterator as _get_page_iterator; +use parquet2::schema::types::PrimitiveType; +use simple::page_iter_to_arrays; + +pub use self::nested_utils::{init_nested, InitNested, NestedArrayIter, NestedState}; +pub use self::struct_::StructIterator; +use super::*; +use crate::array::{Array, DictionaryKey, FixedSizeListArray, ListArray, MapArray}; +use crate::datatypes::{DataType, Field, IntervalUnit}; +use crate::error::Result; +use crate::offset::Offsets; + +/// Creates a new iterator of compressed pages. 
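// A thin wrapper around parquet2's `get_page_iterator`: it forwards all arguments and
// only converts the error type into this crate's `Result` via `?`.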
+pub fn get_page_iterator( + column_metadata: &ColumnChunkMetaData, + reader: R, + pages_filter: Option, + buffer: Vec, + max_header_size: usize, +) -> Result> { + Ok(_get_page_iterator( + column_metadata, + reader, + pages_filter, + buffer, + max_header_size, + )?) +} + +/// Creates a new [`ListArray`] or [`FixedSizeListArray`]. +pub fn create_list( + data_type: DataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (mut offsets, validity) = nested.nested.pop().unwrap().inner(); + match data_type.to_logical_type() { + DataType::List(_) => { + offsets.push(values.len() as i64); + + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(ListArray::::new( + data_type, + offsets.into(), + values, + validity.and_then(|x| x.into()), + )) + }, + DataType::LargeList(_) => { + offsets.push(values.len() as i64); + + Box::new(ListArray::::new( + data_type, + offsets.try_into().expect("List too large"), + values, + validity.and_then(|x| x.into()), + )) + }, + DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new( + data_type, + values, + validity.and_then(|x| x.into()), + )), + _ => unreachable!(), + } +} + +/// Creates a new [`MapArray`]. +pub fn create_map( + data_type: DataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (mut offsets, validity) = nested.nested.pop().unwrap().inner(); + match data_type.to_logical_type() { + DataType::Map(_, _) => { + offsets.push(values.len() as i64); + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(MapArray::new( + data_type, + offsets.into(), + values, + validity.and_then(|x| x.into()), + )) + }, + _ => unreachable!(), + } +} + +fn is_primitive(data_type: &DataType) -> bool { + matches!( + data_type.to_physical_type(), + crate::datatypes::PhysicalType::Primitive(_) + | crate::datatypes::PhysicalType::Null + | crate::datatypes::PhysicalType::Boolean + | crate::datatypes::PhysicalType::Utf8 + | crate::datatypes::PhysicalType::LargeUtf8 + | crate::datatypes::PhysicalType::Binary + | crate::datatypes::PhysicalType::LargeBinary + | crate::datatypes::PhysicalType::FixedSizeBinary + | crate::datatypes::PhysicalType::Dictionary(_) + ) +} + +fn columns_to_iter_recursive<'a, I: 'a>( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + init: Vec, + num_rows: usize, + chunk_size: Option, +) -> Result> +where + I: Pages, +{ + if init.is_empty() && is_primitive(&field.data_type) { + return Ok(Box::new( + page_iter_to_arrays( + columns.pop().unwrap(), + types.pop().unwrap(), + field.data_type, + chunk_size, + num_rows, + )? + .map(|x| Ok((NestedState::new(vec![]), x?))), + )); + } + + nested::columns_to_iter_recursive(columns, types, field, init, num_rows, chunk_size) +} + +/// Returns the number of (parquet) columns that a [`DataType`] contains. 
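// For example, a `Struct` with two primitive fields spans two parquet leaf columns, while
// a `List` of a primitive spans one: nesting is encoded through repetition/definition
// levels rather than through additional columns.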
+pub fn n_columns(data_type: &DataType) -> usize { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 => 1, + List | FixedSizeList | LargeList => { + let a = data_type.to_logical_type(); + if let DataType::List(inner) = a { + n_columns(&inner.data_type) + } else if let DataType::LargeList(inner) = a { + n_columns(&inner.data_type) + } else if let DataType::FixedSizeList(inner, _) = a { + n_columns(&inner.data_type) + } else { + unreachable!() + } + }, + Map => { + let a = data_type.to_logical_type(); + if let DataType::Map(inner, _) = a { + n_columns(&inner.data_type) + } else { + unreachable!() + } + }, + Struct => { + if let DataType::Struct(fields) = data_type.to_logical_type() { + fields.iter().map(|inner| n_columns(&inner.data_type)).sum() + } else { + unreachable!() + } + }, + _ => todo!(), + } +} + +/// An iterator adapter that maps multiple iterators of [`Pages`] into an iterator of [`Array`]s. +/// +/// For a non-nested datatypes such as [`DataType::Int32`], this function requires a single element in `columns` and `types`. +/// For nested types, `columns` must be composed by all parquet columns with associated types `types`. +/// +/// The arrays are guaranteed to be at most of size `chunk_size` and data type `field.data_type`. +pub fn column_iter_to_arrays<'a, I: 'a>( + columns: Vec, + types: Vec<&PrimitiveType>, + field: Field, + chunk_size: Option, + num_rows: usize, +) -> Result> +where + I: Pages, +{ + Ok(Box::new( + columns_to_iter_recursive(columns, types, field, vec![], num_rows, chunk_size)? + .map(|x| x.map(|x| x.1)), + )) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs new file mode 100644 index 000000000000..14f75fa8d672 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs @@ -0,0 +1,590 @@ +use ethnum::I256; +use parquet2::schema::types::PrimitiveType; + +use super::nested_utils::{InitNested, NestedArrayIter}; +use super::*; +use crate::array::PrimitiveArray; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn remove_nested<'a, I>(iter: I) -> NestedArrayIter<'a> +where + I: Iterator)>> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| { + x.map(|(mut nested, array)| { + let _ = nested.nested.pop().unwrap(); // the primitive + (nested, array) + }) + })) +} + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn primitive<'a, A, I>(iter: I) -> NestedArrayIter<'a> +where + A: Array, + I: Iterator> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| { + x.map(|(mut nested, array)| { + let _ = nested.nested.pop().unwrap(); // the primitive + (nested, Box::new(array) as _) + }) + })) +} + +pub fn columns_to_iter_recursive<'a, I: 'a>( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + mut init: Vec, + num_rows: usize, + chunk_size: Option, +) -> Result> +where + I: Pages, +{ + use crate::datatypes::PhysicalType::*; + use crate::datatypes::PrimitiveType::*; + + Ok(match field.data_type().to_physical_type() { + Null => { + // physical type is i32 + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(null::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + Boolean 
=> { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(boolean::NestedIter::new( + columns.pop().unwrap(), + init, + num_rows, + chunk_size, + )) + }, + Primitive(Int8) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as i8, + )) + }, + Primitive(Int16) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as i16, + )) + }, + Primitive(Int32) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x, + )) + }, + Primitive(Int64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x, + )) + }, + Primitive(UInt8) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u8, + )) + }, + Primitive(UInt16) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u16, + )) + }, + Primitive(UInt32) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u32, + )), + // some implementations of parquet write arrow's u32 into i64. 
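// Assuming the writer only stored values in the u32 range, the `as u32` cast in the
// closure below is lossless.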
+ PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x as u32, + )), + other => { + return Err(Error::nyi(format!( + "Deserializing UInt32 from {other:?}'s parquet" + ))) + }, + } + }, + Primitive(UInt64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x as u64, + )) + }, + Primitive(Float32) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: f32| x, + )) + }, + Primitive(Float64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: f64| x, + )) + }, + Binary | Utf8 => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + remove_nested(binary::NestedIter::::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + LargeBinary | LargeUtf8 => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + remove_nested(binary::NestedIter::::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + _ => match field.data_type().to_logical_type() { + DataType::Dictionary(key_type, _, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + let iter = columns.pop().unwrap(); + let data_type = field.data_type().clone(); + match_integer_type!(key_type, |$K| { + dict_read::<$K, _>(iter, init, type_, data_type, num_rows, chunk_size) + })? + }, + DataType::List(inner) + | DataType::LargeList(inner) + | DataType::FixedSizeList(inner, _) => { + init.push(InitNested::List(field.is_nullable)); + let iter = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + num_rows, + chunk_size, + )?; + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let array = create_list(field.data_type().clone(), &mut nested, array); + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + DataType::Decimal(_, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i32| x as i128, + )), + PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i64| x as i128, + )), + PhysicalType::FixedLenByteArray(n) if n > 16 => { + return Err(Error::InvalidArgumentError(format!( + "Can't decode Decimal128 type from `FixedLenByteArray` of len {n}" + ))) + }, + PhysicalType::FixedLenByteArray(n) => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. 
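// `convert_i128` reads the `n` big-endian bytes as a sign-extended two's-complement
// integer, so decimals whose precision fits in at most 16 bytes are preserved exactly.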
+ let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| super::super::convert_i128(value, n)) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) + }, + _ => { + return Err(Error::nyi(format!( + "Deserializing type for Decimal {:?} from parquet", + type_.physical_type + ))) + }, + } + }, + DataType::Decimal256(_, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i32| i256(I256::new(x as i128)), + )), + PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i64| i256(I256::new(x as i128)), + )), + PhysicalType::FixedLenByteArray(n) if n <= 16 => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(|value| i256(I256::new(super::super::convert_i128(value, n)))) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + + PhysicalType::FixedLenByteArray(n) if n <= 32 => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. 
+ let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_i256) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + PhysicalType::FixedLenByteArray(n) => { + return Err(Error::InvalidArgumentError(format!( + "Can't decode Decimal256 type from from `FixedLenByteArray` of len {n}" + ))) + }, + _ => { + return Err(Error::nyi(format!( + "Deserializing type for Decimal {:?} from parquet", + type_.physical_type + ))) + }, + } + }, + DataType::Struct(fields) => { + let columns = fields + .iter() + .rev() + .map(|f| { + let mut init = init.clone(); + init.push(InitNested::Struct(field.is_nullable)); + let n = n_columns(&f.data_type); + let columns = columns.drain(columns.len() - n..).collect(); + let types = types.drain(types.len() - n..).collect(); + columns_to_iter_recursive( + columns, + types, + f.clone(), + init, + num_rows, + chunk_size, + ) + }) + .collect::>>()?; + let columns = columns.into_iter().rev().collect(); + Box::new(struct_::StructIterator::new(columns, fields.clone())) + }, + DataType::Map(inner, _) => { + init.push(InitNested::List(field.is_nullable)); + let iter = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + num_rows, + chunk_size, + )?; + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let array = create_map(field.data_type().clone(), &mut nested, array); + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + other => { + return Err(Error::nyi(format!( + "Deserializing type {other:?} from parquet" + ))) + }, + }, + }) +} + +fn dict_read<'a, K: DictionaryKey, I: 'a + Pages>( + iter: I, + init: Vec, + _type_: &PrimitiveType, + data_type: DataType, + num_rows: usize, + chunk_size: Option, +) -> Result> { + use DataType::*; + let values_data_type = if let Dictionary(_, v, _) = &data_type { + v.as_ref() + } else { + panic!() + }; + + Ok(match values_data_type.to_logical_type() { + UInt8 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + )), + UInt16 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + )), + UInt32 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + )), + Int8 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + )), + Int16 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + )), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x, + )) + }, + Int64 | Date64 | Time64(_) | Duration(_) => { + primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i64| x as i32, + )) + }, + Float32 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: f32| x, + )), + Float64 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: f64| x, + )), + Utf8 | Binary => 
primitive(binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + LargeUtf8 | LargeBinary => primitive(binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + FixedSizeBinary(_) => primitive(fixed_size_binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + /* + + Timestamp(time_unit, _) => { + let time_unit = *time_unit; + return timestamp_dict::( + iter, + physical_type, + logical_type, + data_type, + chunk_size, + time_unit, + ); + } + */ + other => { + return Err(Error::nyi(format!( + "Reading nested dictionaries of type {other:?}" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs new file mode 100644 index 000000000000..595f161bb73e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs @@ -0,0 +1,553 @@ +use std::collections::VecDeque; + +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::page::{split_buffer, DataPage, DictPage, Page}; +use parquet2::read::levels::get_bit_width; + +use super::super::Pages; +pub use super::utils::Zip; +use super::utils::{DecodedState, MaybeNext, PageState}; +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::error::Result; + +/// trait describing deserialized repetition and definition levels +pub trait Nested: std::fmt::Debug + Send + Sync { + fn inner(&mut self) -> (Vec, Option); + + fn push(&mut self, length: i64, is_valid: bool); + + fn is_nullable(&self) -> bool; + + fn is_repeated(&self) -> bool { + false + } + + // Whether the Arrow container requires all items to be filled. + fn is_required(&self) -> bool; + + /// number of rows + fn len(&self) -> usize; + + /// number of values associated to the primitive type this nested tracks + fn num_values(&self) -> usize; +} + +#[derive(Debug, Default)] +pub struct NestedPrimitive { + is_nullable: bool, + length: usize, +} + +impl NestedPrimitive { + pub fn new(is_nullable: bool) -> Self { + Self { + is_nullable, + length: 0, + } + } +} + +impl Nested for NestedPrimitive { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), Default::default()) + } + + fn is_nullable(&self) -> bool { + self.is_nullable + } + + fn is_required(&self) -> bool { + false + } + + fn push(&mut self, _value: i64, _is_valid: bool) { + self.length += 1 + } + + fn len(&self) -> usize { + self.length + } + + fn num_values(&self) -> usize { + self.length + } +} + +#[derive(Debug, Default)] +pub struct NestedOptional { + pub validity: MutableBitmap, + pub offsets: Vec, +} + +impl Nested for NestedOptional { + fn inner(&mut self) -> (Vec, Option) { + let offsets = std::mem::take(&mut self.offsets); + let validity = std::mem::take(&mut self.validity); + (offsets, Some(validity)) + } + + fn is_nullable(&self) -> bool { + true + } + + fn is_repeated(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + // it may be for FixedSizeList + false + } + + fn push(&mut self, value: i64, is_valid: bool) { + self.offsets.push(value); + self.validity.push(is_valid); + } + + fn len(&self) -> usize { + self.offsets.len() + } + + fn num_values(&self) -> usize { + self.offsets.last().copied().unwrap_or(0) as usize + } +} + +impl NestedOptional { + pub fn with_capacity(capacity: usize) -> Self { + let offsets = Vec::::with_capacity(capacity + 1); + let validity = MutableBitmap::with_capacity(capacity); + Self { validity, offsets } + } +} + +#[derive(Debug, 
Default)] +pub struct NestedValid { + pub offsets: Vec, +} + +impl Nested for NestedValid { + fn inner(&mut self) -> (Vec, Option) { + let offsets = std::mem::take(&mut self.offsets); + (offsets, None) + } + + fn is_nullable(&self) -> bool { + false + } + + fn is_repeated(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + // it may be for FixedSizeList + false + } + + fn push(&mut self, value: i64, _is_valid: bool) { + self.offsets.push(value); + } + + fn len(&self) -> usize { + self.offsets.len() + } + + fn num_values(&self) -> usize { + self.offsets.last().copied().unwrap_or(0) as usize + } +} + +impl NestedValid { + pub fn with_capacity(capacity: usize) -> Self { + let offsets = Vec::::with_capacity(capacity + 1); + Self { offsets } + } +} + +#[derive(Debug, Default)] +pub struct NestedStructValid { + length: usize, +} + +impl NestedStructValid { + pub fn new() -> Self { + Self { length: 0 } + } +} + +impl Nested for NestedStructValid { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), None) + } + + fn is_nullable(&self) -> bool { + false + } + + fn is_required(&self) -> bool { + true + } + + fn push(&mut self, _value: i64, _is_valid: bool) { + self.length += 1; + } + + fn len(&self) -> usize { + self.length + } + + fn num_values(&self) -> usize { + self.length + } +} + +#[derive(Debug, Default)] +pub struct NestedStruct { + validity: MutableBitmap, +} + +impl NestedStruct { + pub fn with_capacity(capacity: usize) -> Self { + Self { + validity: MutableBitmap::with_capacity(capacity), + } + } +} + +impl Nested for NestedStruct { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), Some(std::mem::take(&mut self.validity))) + } + + fn is_nullable(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + true + } + + fn push(&mut self, _value: i64, is_valid: bool) { + self.validity.push(is_valid) + } + + fn len(&self) -> usize { + self.validity.len() + } + + fn num_values(&self) -> usize { + self.validity.len() + } +} + +/// A decoder that knows how to map `State` -> Array +pub(super) trait NestedDecoder<'a> { + type State: PageState<'a>; + type Dictionary; + type DecodedState: DecodedState; + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result; + + /// Initializes a new state + fn with_capacity(&self, capacity: usize) -> Self::DecodedState; + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()>; + fn push_null(&self, decoded: &mut Self::DecodedState); + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary; +} + +/// The initial info of nested data types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitNested { + /// Primitive data types + Primitive(bool), + /// List data types + List(bool), + /// Struct data types + Struct(bool), +} + +/// Initialize [`NestedState`] from `&[InitNested]`. 
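// The `init` slice mirrors the path from the root of the field to its leaf, as built by
// `columns_to_iter_recursive`: for example, a column nested as `List<Struct<Int32>>` is
// initialized with `[InitNested::List(..), InitNested::Struct(..), InitNested::Primitive(..)]`,
// one entry per nesting level with the leaf last.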
+pub fn init_nested(init: &[InitNested], capacity: usize) -> NestedState { + let container = init + .iter() + .map(|init| match init { + InitNested::Primitive(is_nullable) => { + Box::new(NestedPrimitive::new(*is_nullable)) as Box + }, + InitNested::List(is_nullable) => { + if *is_nullable { + Box::new(NestedOptional::with_capacity(capacity)) as Box + } else { + Box::new(NestedValid::with_capacity(capacity)) as Box + } + }, + InitNested::Struct(is_nullable) => { + if *is_nullable { + Box::new(NestedStruct::with_capacity(capacity)) as Box + } else { + Box::new(NestedStructValid::new()) as Box + } + }, + }) + .collect(); + NestedState::new(container) +} + +pub struct NestedPage<'a> { + iter: std::iter::Peekable, HybridRleDecoder<'a>>>, +} + +impl<'a> NestedPage<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (rep_levels, def_levels, _) = split_buffer(page)?; + + let max_rep_level = page.descriptor.max_rep_level; + let max_def_level = page.descriptor.max_def_level; + + let reps = + HybridRleDecoder::try_new(rep_levels, get_bit_width(max_rep_level), page.num_values())?; + let defs = + HybridRleDecoder::try_new(def_levels, get_bit_width(max_def_level), page.num_values())?; + + let iter = reps.zip(defs).peekable(); + + Ok(Self { iter }) + } + + // number of values (!= number of rows) + pub fn len(&self) -> usize { + self.iter.size_hint().0 + } +} + +/// The state of nested data types. +#[derive(Debug)] +pub struct NestedState { + /// The nesteds composing `NestedState`. + pub nested: Vec>, +} + +impl NestedState { + /// Creates a new [`NestedState`]. + pub fn new(nested: Vec>) -> Self { + Self { nested } + } + + /// The number of rows in this state + pub fn len(&self) -> usize { + // outermost is the number of rows + self.nested[0].len() + } +} + +/// Extends `items` by consuming `page`, first trying to complete the last `item` +/// and extending it if more are needed +pub(super) fn extend<'a, D: NestedDecoder<'a>>( + page: &'a DataPage, + init: &[InitNested], + items: &mut VecDeque<(NestedState, D::DecodedState)>, + dict: Option<&'a D::Dictionary>, + remaining: &mut usize, + decoder: &D, + chunk_size: Option, +) -> Result<()> { + let mut values_page = decoder.build_state(page, dict)?; + let mut page = NestedPage::try_new(page)?; + + let capacity = chunk_size.unwrap_or(0); + // chunk_size = None, remaining = 44 => chunk_size = 44 + let chunk_size = chunk_size.unwrap_or(usize::MAX); + + let (mut nested, mut decoded) = if let Some((nested, decoded)) = items.pop_back() { + (nested, decoded) + } else { + // there is no state => initialize it + (init_nested(init, capacity), decoder.with_capacity(0)) + }; + let existing = nested.len(); + + let additional = (chunk_size - existing).min(*remaining); + + // extend the current state + extend_offsets2( + &mut page, + &mut values_page, + &mut nested.nested, + &mut decoded, + decoder, + additional, + )?; + *remaining -= nested.len() - existing; + items.push_back((nested, decoded)); + + while page.len() > 0 && *remaining > 0 { + let additional = chunk_size.min(*remaining); + + let mut nested = init_nested(init, additional); + let mut decoded = decoder.with_capacity(0); + extend_offsets2( + &mut page, + &mut values_page, + &mut nested.nested, + &mut decoded, + decoder, + additional, + )?; + *remaining -= nested.len(); + items.push_back((nested, decoded)); + } + Ok(()) +} + +fn extend_offsets2<'a, D: NestedDecoder<'a>>( + page: &mut NestedPage<'a>, + values_state: &mut D::State, + nested: &mut [Box], + decoded: &mut D::DecodedState, + decoder: &D, + 
additional: usize, +) -> Result<()> { + let max_depth = nested.len(); + + let mut cum_sum = vec![0u32; max_depth + 1]; + for (i, nest) in nested.iter().enumerate() { + let delta = nest.is_nullable() as u32 + nest.is_repeated() as u32; + cum_sum[i + 1] = cum_sum[i] + delta; + } + + let mut cum_rep = vec![0u32; max_depth + 1]; + for (i, nest) in nested.iter().enumerate() { + let delta = nest.is_repeated() as u32; + cum_rep[i + 1] = cum_rep[i] + delta; + } + + let mut rows = 0; + while let Some((rep, def)) = page.iter.next() { + let rep = rep?; + let def = def?; + if rep == 0 { + rows += 1; + } + + let mut is_required = false; + for depth in 0..max_depth { + let right_level = rep <= cum_rep[depth] && def >= cum_sum[depth]; + if is_required || right_level { + let length = nested + .get(depth + 1) + .map(|x| x.len() as i64) + // the last depth is the leaf, which is always increased by 1 + .unwrap_or(1); + + let nest = &mut nested[depth]; + + let is_valid = nest.is_nullable() && def > cum_sum[depth]; + nest.push(length, is_valid); + is_required = nest.is_required() && !is_valid; + + if depth == max_depth - 1 { + // the leaf / primitive + let is_valid = (def != cum_sum[depth]) || !nest.is_nullable(); + if right_level && is_valid { + decoder.push_valid(values_state, decoded)?; + } else { + decoder.push_null(decoded); + } + } + } + } + + let next_rep = *page + .iter + .peek() + .map(|x| x.0.as_ref()) + .transpose() + .unwrap() // todo: fix this + .unwrap_or(&0); + + if next_rep == 0 && rows == additional { + break; + } + } + Ok(()) +} + +#[inline] +pub(super) fn next<'a, I, D>( + iter: &'a mut I, + items: &mut VecDeque<(NestedState, D::DecodedState)>, + dict: &'a mut Option, + remaining: &mut usize, + init: &[InitNested], + chunk_size: Option, + decoder: &D, +) -> MaybeNext> +where + I: Pages, + D: NestedDecoder<'a>, +{ + // front[a1, a2, a3, ...]back + if items.len() > 1 { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if *remaining == 0 { + return match items.pop_front() { + Some(decoded) => MaybeNext::Some(Ok(decoded)), + None => MaybeNext::None, + }; + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(None) => { + if let Some(decoded) = items.pop_front() { + MaybeNext::Some(Ok(decoded)) + } else { + MaybeNext::None + } + }, + Ok(Some(page)) => { + let page = match page { + Page::Data(page) => page, + Page::Dict(dict_page) => { + *dict = Some(decoder.deserialize_dict(dict_page)); + return MaybeNext::More; + }, + }; + + // there is a new page => consume the page from the start + let error = extend( + page, + init, + items, + dict.as_ref(), + remaining, + decoder, + chunk_size, + ); + match error { + Ok(_) => {}, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + if (items.len() == 1) + && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) + { + MaybeNext::Some(Ok(items.pop_front().unwrap())) + } else { + MaybeNext::More + } + }, + } +} + +/// Type def for a sharable, boxed dyn [`Iterator`] of NestedStates and arrays +pub type NestedArrayIter<'a> = + Box)>> + Send + Sync + 'a>; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs new file mode 100644 index 000000000000..576db09d364b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs @@ -0,0 +1,104 @@ +mod nested; + +pub(super) use nested::NestedIter; +use parquet2::page::Page; + +use super::super::{ArrayIter, Pages}; +use crate::array::NullArray; +use 
crate::datatypes::DataType; + +/// Converts [`Pages`] to an [`ArrayIter`] +pub fn iter_to_arrays<'a, I>( + mut iter: I, + data_type: DataType, + chunk_size: Option, + num_rows: usize, +) -> ArrayIter<'a> +where + I: 'a + Pages, +{ + let mut len = 0usize; + + while let Ok(Some(page)) = iter.next() { + match page { + Page::Dict(_) => continue, + Page::Data(page) => { + let rows = page.num_values(); + len = (len + rows).min(num_rows); + if len == num_rows { + break; + } + }, + } + } + + if len == 0 { + return Box::new(std::iter::empty()); + } + + let chunk_size = chunk_size.unwrap_or(len); + + let complete_chunks = len / chunk_size; + + let remainder = len - (complete_chunks * chunk_size); + let i_data_type = data_type.clone(); + let complete = (0..complete_chunks) + .map(move |_| Ok(NullArray::new(i_data_type.clone(), chunk_size).boxed())); + if len % chunk_size == 0 { + Box::new(complete) + } else { + let array = NullArray::new(data_type, remainder); + Box::new(complete.chain(std::iter::once(Ok(array.boxed())))) + } +} + +#[cfg(test)] +mod tests { + use parquet2::encoding::Encoding; + use parquet2::error::Error as ParquetError; + use parquet2::metadata::Descriptor; + use parquet2::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; + use parquet2::schema::types::{PhysicalType, PrimitiveType}; + + use super::iter_to_arrays; + use crate::array::NullArray; + use crate::datatypes::DataType; + use crate::error::Error; + + #[test] + fn limit() { + let new_page = |values: i32| { + Page::Data(DataPage::new( + DataPageHeader::V1(DataPageHeaderV1 { + num_values: values, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Plain.into(), + repetition_level_encoding: Encoding::Plain.into(), + statistics: None, + }), + vec![], + Descriptor { + primitive_type: PrimitiveType::from_physical( + "a".to_string(), + PhysicalType::Int32, + ), + max_def_level: 0, + max_rep_level: 0, + }, + None, + )) + }; + + let p1 = new_page(100); + let p2 = new_page(100); + let pages = vec![Result::<_, ParquetError>::Ok(&p1), Ok(&p2)]; + let pages = fallible_streaming_iterator::convert(pages.into_iter()); + let arrays = iter_to_arrays(pages, DataType::Null, Some(10), 101); + + let arrays = arrays.collect::, Error>>().unwrap(); + let expected = std::iter::repeat(NullArray::new(DataType::Null, 10).boxed()) + .take(10) + .chain(std::iter::once(NullArray::new(DataType::Null, 1).boxed())); + assert_eq!(arrays, expected.collect::>()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs new file mode 100644 index 000000000000..9528720e73be --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs @@ -0,0 +1,126 @@ +use std::collections::VecDeque; + +use parquet2::page::{DataPage, DictPage}; + +use super::super::nested_utils::*; +use super::super::{utils, Pages}; +use crate::array::NullArray; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::utils::DecodedState; + +impl<'a> utils::PageState<'a> for usize { + fn len(&self) -> usize { + *self + } +} + +#[derive(Debug)] +struct NullDecoder {} + +impl DecodedState for usize { + fn len(&self) -> usize { + *self + } +} + +impl<'a> NestedDecoder<'a> for NullDecoder { + type State = usize; + type Dictionary = usize; + type DecodedState = usize; + + fn build_state( + &self, + _page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + if let Some(n) = dict { + return 
Ok(*n); + } + Ok(1) + } + + /// Initializes a new state + fn with_capacity(&self, _capacity: usize) -> Self::DecodedState { + 0 + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + *decoded += *state; + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let length = decoded; + *length += 1; + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + page.num_values + } +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as null arrays +#[derive(Debug)] +pub struct NestedIter +where + I: Pages, +{ + iter: I, + init: Vec, + data_type: DataType, + items: VecDeque<(NestedState, usize)>, + remaining: usize, + chunk_size: Option, + decoder: NullDecoder, +} + +impl NestedIter +where + I: Pages, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + decoder: NullDecoder {}, + } + } +} + +impl Iterator for NestedIter +where + I: Pages, +{ + type Item = Result<(NestedState, NullArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + &self.init, + self.chunk_size, + &self.decoder, + ); + match maybe_state { + utils::MaybeNext::Some(Ok((nested, state))) => { + Some(Ok((nested, NullArray::new(self.data_type.clone(), state)))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs new file mode 100644 index 000000000000..200c9a517dd0 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs @@ -0,0 +1,370 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::{decode, NativeType as ParquetNativeType}; + +use super::super::utils::{get_selected_rows, FilteredOptionalPageValidity, OptionalPageValidity}; +use super::super::{utils, Pages}; +use crate::array::MutablePrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +#[derive(Debug)] +pub(super) struct FilteredRequiredValues<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequiredValues<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + assert_eq!(values.len() % std::mem::size_of::
<P>
(), 0); + + let values = values.chunks_exact(std::mem::size_of::
<P>
()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct Values<'a> { + pub values: std::slice::ChunksExact<'a, u8>, +} + +impl<'a> Values<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + assert_eq!(values.len() % std::mem::size_of::
<P>
(), 0); + Ok(Self { + values: values.chunks_exact(std::mem::size_of::
<P>
()), + }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct ValuesDictionary<'a, T> +where + T: NativeType, +{ + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Vec, +} + +impl<'a, T> ValuesDictionary<'a, T> +where + T: NativeType, +{ + pub fn try_new(page: &'a DataPage, dict: &'a Vec) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub(super) enum State<'a, T> +where + T: NativeType, +{ + Optional(OptionalPageValidity<'a>, Values<'a>), + Required(Values<'a>), + RequiredDictionary(ValuesDictionary<'a, T>), + OptionalDictionary(OptionalPageValidity<'a>, ValuesDictionary<'a, T>), + FilteredRequired(FilteredRequiredValues<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Optional(optional, _) => optional.len(), + State::Required(values) => values.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(optional, _) => optional.len(), + State::FilteredRequired(values) => values.len(), + State::FilteredOptional(optional, _) => optional.len(), + } + } +} + +#[derive(Debug)] +pub(super) struct PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData
<P>
, + pub op: F, +} + +impl PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + #[inline] + pub(super) fn new(op: F) -> Self { + Self { + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData, + op, + } + } +} + +impl utils::DecodedState for (Vec, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +impl<'a, T, P, F> utils::Decoder<'a> for PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dict = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + Ok(State::OptionalDictionary( + OptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true, false) => { + let validity = OptionalPageValidity::try_new(page)?; + let values = Values::try_new::
<P>
(page)?; + + Ok(State::Optional(validity, values)) + }, + (Encoding::Plain, _, false, false) => Ok(State::Required(Values::try_new::
<P>
(page)?)), + (Encoding::Plain, _, false, true) => { + FilteredRequiredValues::try_new::
<P>
(page).map(State::FilteredRequired) + }, + (Encoding::Plain, _, true, true) => Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + Values::try_new::
<P>
(page)?, + )), + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.values.by_ref().map(decode).map(self.op), + ), + State::Required(page) => { + values.extend( + page.values + .by_ref() + .map(decode) + .map(self.op) + .take(remaining), + ); + }, + State::OptionalDictionary(page_validity, page_values) => { + let op1 = |index: u32| page_values.dict[index as usize]; + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.values.by_ref().map(|x| x.unwrap()).map(op1), + ) + }, + State::RequiredDictionary(page) => { + let op1 = |index: u32| page.dict[index as usize]; + values.extend( + page.values + .by_ref() + .map(|x| x.unwrap()) + .map(op1) + .take(remaining), + ); + }, + State::FilteredRequired(page) => { + values.extend( + page.values + .by_ref() + .map(decode) + .map(self.op) + .take(remaining), + ); + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.values.by_ref().map(decode).map(self.op), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + deserialize_plain(&page.buffer, self.op) + } +} + +pub(super) fn finish( + data_type: &DataType, + values: Vec, + validity: MutableBitmap, +) -> MutablePrimitiveArray { + let validity = if validity.is_empty() { + None + } else { + Some(validity) + }; + MutablePrimitiveArray::try_new(data_type.clone(), values, validity).unwrap() +} + +/// An [`Iterator`] adapter over [`Pages`] assumed to be encoded as primitive arrays +#[derive(Debug)] +pub struct Iter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + dict: Option>, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl Iter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for Iter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = utils::next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &PrimitiveDecoder::new(self.op), + ); + match maybe_state { + utils::MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} + +pub(super) fn deserialize_plain(values: &[u8], op: F) -> Vec +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + values + .chunks_exact(std::mem::size_of::
<P>
()) + .map(decode) + .map(op) + .collect::>() +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs new file mode 100644 index 000000000000..35293d582d10 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs @@ -0,0 +1,190 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::dictionary::{nested_next_dict, *}; +use super::super::nested_utils::{InitNested, NestedState}; +use super::super::utils::MaybeNext; +use super::super::Pages; +use super::basic::deserialize_plain; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +fn read_dict(data_type: DataType, op: F, dict: &DictPage) -> Box +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + let values = deserialize_plain(&dict.buffer, op); + + Box::new(PrimitiveArray::new(data_type, values.into(), None)) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl DictIter +where + K: DictionaryKey, + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for DictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), self.op, dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`] +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + init: Vec, + data_type: DataType, + values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl NestedDictIter +where + K: DictionaryKey, + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), self.op, dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs new file mode 100644 index 000000000000..ac6c0bac0c1f --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs @@ -0,0 +1,262 @@ +use std::collections::VecDeque; + +use num_traits::AsPrimitive; +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::delta_bitpacked::Decoder; +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::{utils, Pages}; +use super::basic::{finish, PrimitiveDecoder, State as PrimitiveState}; +use crate::array::MutablePrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::deserialize::utils::{ + get_selected_rows, FilteredOptionalPageValidity, OptionalPageValidity, +}; +use crate::types::NativeType; + +/// The state of a [`DataPage`] of an integer parquet type (i32 or i64) +#[derive(Debug)] +enum State<'a, T> +where + T: NativeType, +{ + Common(PrimitiveState<'a, T>), + DeltaBinaryPackedRequired(Decoder<'a>), + DeltaBinaryPackedOptional(OptionalPageValidity<'a>, Decoder<'a>), + FilteredDeltaBinaryPackedRequired(SliceFilteredIter>), + FilteredDeltaBinaryPackedOptional(FilteredOptionalPageValidity<'a>, Decoder<'a>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Common(state) => state.len(), + State::DeltaBinaryPackedRequired(state) => state.size_hint().0, + State::DeltaBinaryPackedOptional(state, _) => state.len(), + State::FilteredDeltaBinaryPackedRequired(state) => state.size_hint().0, + State::FilteredDeltaBinaryPackedOptional(state, _) => state.len(), + } + } +} + +/// Decoder of integer parquet type +#[derive(Debug)] +struct IntDecoder(PrimitiveDecoder) +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Fn(P) -> T; + +impl IntDecoder +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Fn(P) -> T, +{ + #[inline] + fn new(op: F) -> Self { + Self(PrimitiveDecoder::new(op)) + } +} + +impl<'a, T, P, F> utils::Decoder<'a> for IntDecoder +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dict = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::DeltaBinaryPacked, _, false, false) => { + let (_, _, values) = split_buffer(page)?; + Decoder::try_new(values) + .map(State::DeltaBinaryPackedRequired) + .map_err(Error::from) + }, + (Encoding::DeltaBinaryPacked, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + Ok(State::DeltaBinaryPackedOptional( + OptionalPageValidity::try_new(page)?, + Decoder::try_new(values)?, + )) + }, + (Encoding::DeltaBinaryPacked, _, false, true) => { + let (_, _, values) = split_buffer(page)?; + let values = Decoder::try_new(values)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(State::FilteredDeltaBinaryPackedRequired(values)) + }, + (Encoding::DeltaBinaryPacked, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + let values = Decoder::try_new(values)?; + + Ok(State::FilteredDeltaBinaryPackedOptional( + FilteredOptionalPageValidity::try_new(page)?, + values, + )) + }, + _ => self.0.build_state(page, dict).map(State::Common), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + self.0.with_capacity(capacity) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Common(state) => self.0.extend_from_state(state, decoded, remaining), + State::DeltaBinaryPackedRequired(state) => { + values.extend( + state + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op) + .take(remaining), + ); + }, + State::DeltaBinaryPackedOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op), + ) + }, + State::FilteredDeltaBinaryPackedRequired(page) => { + values.extend( + page.by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op) + .take(remaining), + ); + }, + State::FilteredDeltaBinaryPackedOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + self.0.deserialize_dict(page) + } +} + +/// An [`Iterator`] adapter over [`Pages`] assumed to be encoded as primitive arrays +/// encoded as parquet integer types +#[derive(Debug)] +pub struct IntegerIter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + dict: Option>, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl IntegerIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for IntegerIter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = utils::next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &IntDecoder::new(self.op), + ); + match maybe_state { + utils::MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs new file mode 100644 index 000000000000..27d9c27c3186 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs @@ -0,0 +1,9 @@ +mod basic; +mod dictionary; +mod integer; +mod nested; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use integer::IntegerIter; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs new file mode 100644 index 000000000000..405e2d9a7c09 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs @@ -0,0 +1,244 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::{decode, NativeType as ParquetNativeType}; + +use super::super::nested_utils::*; +use super::super::{utils, Pages}; +use super::basic::{deserialize_plain, Values, ValuesDictionary}; +use crate::array::PrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum State<'a, T> +where + T: NativeType, +{ + Optional(Values<'a>), + Required(Values<'a>), + RequiredDictionary(ValuesDictionary<'a, T>), + OptionalDictionary(ValuesDictionary<'a, T>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Optional(values) => values.len(), + State::Required(values) => values.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(values) => values.len(), + } + } +} + +#[derive(Debug)] +struct PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData
<P>
, + op: F, +} + +impl PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + #[inline] + fn new(op: F) -> Self { + Self { + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData, + op, + } + } +} + +impl<'a, T, P, F> NestedDecoder<'a> for PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dictionary = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + ValuesDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, _, true, false) => Values::try_new::
<P>
(page).map(State::Optional), + (Encoding::Plain, _, false, false) => Values::try_new::
<P>
(page).map(State::Required), + _ => Err(utils::not_implemented(page)), + } + } + + /// Initializes a new state + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let value = page_values.values.by_ref().next().map(decode).map(self.op); + // convert unwrap to error + values.push(value.unwrap_or_default()); + validity.push(true); + }, + State::Required(page_values) => { + let value = page_values.values.by_ref().next().map(decode).map(self.op); + // convert unwrap to error + values.push(value.unwrap_or_default()); + }, + State::RequiredDictionary(page) => { + let value = page + .values + .next() + .map(|index| page.dict[index.unwrap() as usize]); + + values.push(value.unwrap_or_default()); + }, + State::OptionalDictionary(page) => { + let value = page + .values + .next() + .map(|index| page.dict[index.unwrap() as usize]); + + values.push(value.unwrap_or_default()); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(T::default()); + validity.push(false) + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + deserialize_plain(&page.buffer, self.op) + } +} + +fn finish( + data_type: &DataType, + values: Vec, + validity: MutableBitmap, +) -> PrimitiveArray { + PrimitiveArray::new(data_type.clone(), values.into(), validity.into()) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + iter: I, + init: Vec, + data_type: DataType, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + dict: Option>, + remaining: usize, + chunk_size: Option, + decoder: PrimitiveDecoder, +} + +impl NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + init, + data_type, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + decoder: PrimitiveDecoder::new(op), + } + } +} + +impl Iterator for NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result<(NestedState, PrimitiveArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &self.decoder, + ); + match maybe_state { + utils::MaybeNext::Some(Ok((nested, state))) => { + Some(Ok((nested, finish(&self.data_type, state.0, state.1)))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs new file mode 100644 index 000000000000..83d9d8fbae8a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs @@ -0,0 +1,651 @@ +use ethnum::I256; +use parquet2::schema::types::{ + PhysicalType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; +use 
parquet2::types::int96_to_i64_ns; + +use super::super::{ArrayIter, Pages}; +use super::{binary, boolean, fixed_size_binary, null, primitive}; +use crate::array::{Array, DictionaryKey, MutablePrimitiveArray, PrimitiveArray}; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::error::{Error, Result}; +use crate::types::{days_ms, i256, NativeType}; + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn dyn_iter<'a, A, I>(iter: I) -> ArrayIter<'a> +where + A: Array, + I: Iterator> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| x.map(|x| Box::new(x) as Box))) +} + +/// Converts an iterator of [MutablePrimitiveArray] into an iterator of [PrimitiveArray] +#[inline] +fn iden(iter: I) -> impl Iterator>> +where + T: NativeType, + I: Iterator>>, +{ + iter.map(|x| x.map(|x| x.into())) +} + +#[inline] +fn op(iter: I, op: F) -> impl Iterator>> +where + T: NativeType, + I: Iterator>>, + F: Fn(T) -> T + Copy, +{ + iter.map(move |x| { + x.map(move |mut x| { + x.values_mut_slice().iter_mut().for_each(|x| *x = op(*x)); + x.into() + }) + }) +} + +/// An iterator adapter that maps an iterator of Pages into an iterator of Arrays +/// of [`DataType`] `data_type` and length `chunk_size`. +pub fn page_iter_to_arrays<'a, I: Pages + 'a>( + pages: I, + type_: &PrimitiveType, + data_type: DataType, + chunk_size: Option, + num_rows: usize, +) -> Result> { + use DataType::*; + + let physical_type = &type_.physical_type; + let logical_type = &type_.logical_type; + + Ok(match (physical_type, data_type.to_logical_type()) { + (_, Null) => null::iter_to_arrays(pages, data_type, chunk_size, num_rows), + (PhysicalType::Boolean, Boolean) => { + dyn_iter(boolean::Iter::new(pages, data_type, chunk_size, num_rows)) + }, + (PhysicalType::Int32, UInt8) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + ))), + (PhysicalType::Int32, UInt16) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + ))), + (PhysicalType::Int32, UInt32) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + ))), + (PhysicalType::Int64, UInt32) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as u32, + ))), + (PhysicalType::Int32, Int8) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + ))), + (PhysicalType::Int32, Int16) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + ))), + (PhysicalType::Int32, Int32 | Date32 | Time32(_)) => dyn_iter(iden( + primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i32| x), + )), + (PhysicalType::Int64 | PhysicalType::Int96, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp( + pages, + physical_type, + logical_type, + data_type, + num_rows, + chunk_size, + time_unit, + ); + }, + (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => dyn_iter( + fixed_size_binary::Iter::new(pages, data_type, num_rows, chunk_size), + ), + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::YearMonth)) => { + let n = 12; + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + 
.chunks_exact(n) + .map(|value: &[u8]| i32::from_le_bytes(value[..4].try_into().unwrap())) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::DayTime)) => { + let n = 12; + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_days_ms) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::Int32, Decimal(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i128, + ))), + (PhysicalType::Int64, Decimal(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as i128, + ))), + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) if *n > 16 => { + return Err(Error::NotYetImplemented(format!( + "Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}" + ))) + }, + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| super::super::convert_i128(value, n)) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::Int32, Decimal256(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| i256(I256::new(x as i128)), + ))), + (PhysicalType::Int64, Decimal256(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| i256(I256::new(x as i128)), + ))), + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 16 => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| i256(I256::new(super::super::convert_i128(value, n)))) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 32 => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_i256) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); 
+ + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n > 32 => { + return Err(Error::NotYetImplemented(format!( + "Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}" + ))) + }, + (PhysicalType::Int32, Date64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i64 * 86400000, + ))), + (PhysicalType::Int64, Date64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x, + ))), + (PhysicalType::Int64, Int64 | Time64(_) | Duration(_)) => dyn_iter(iden( + primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i64| x), + )), + (PhysicalType::Int64, UInt64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as u64, + ))), + (PhysicalType::Float, Float32) => dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: f32| x, + ))), + (PhysicalType::Double, Float64) => dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: f64| x, + ))), + + (PhysicalType::ByteArray, Utf8 | Binary) => Box::new(binary::Iter::::new( + pages, data_type, chunk_size, num_rows, + )), + (PhysicalType::ByteArray, LargeBinary | LargeUtf8) => Box::new( + binary::Iter::::new(pages, data_type, chunk_size, num_rows), + ), + + (_, Dictionary(key_type, _, _)) => { + return match_integer_type!(key_type, |$K| { + dict_read::<$K, _>(pages, physical_type, logical_type, data_type, num_rows, chunk_size) + }) + }, + (from, to) => { + return Err(Error::NotYetImplemented(format!( + "Reading parquet type {from:?} to {to:?} still not implemented" + ))) + }, + }) +} + +/// Unify the timestamp unit from parquet TimeUnit into arrow's TimeUnit +/// Returns (a int64 factor, is_multiplier) +fn unify_timestamp_unit( + logical_type: &Option, + time_unit: TimeUnit, +) -> (i64, bool) { + if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { + match (*unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => (1, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) + | (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => (1000, false), + + (ParquetTimeUnit::Microseconds, TimeUnit::Second) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => (1_000_000, false), + + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => (1_000_000_000, false), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => (1_000, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => (1_000_000, true), + } + } else { + (1, true) + } +} + +#[inline] +pub fn int96_to_i64_us(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MICROS_PER_SECOND: i64 = 1_000_000; + + let day = value[2] as i64; + let microseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MICROS_PER_SECOND + microseconds +} + +#[inline] +pub fn int96_to_i64_ms(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MILLIS_PER_SECOND: i64 = 1_000; + + let day = value[2] as i64; + let milliseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MILLIS_PER_SECOND + milliseconds +} + +#[inline] +pub fn int96_to_i64_s(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + + let day = value[2] as i64; + let seconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000_000; + let day_seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + day_seconds + seconds +} + +fn timestamp<'a, I: Pages + 'a>( + pages: I, + physical_type: &PhysicalType, + logical_type: &Option, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + time_unit: TimeUnit, +) -> Result> { + if physical_type == &PhysicalType::Int96 { + return match time_unit { + TimeUnit::Nanosecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_ns, + )))), + TimeUnit::Microsecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_us, + )))), + TimeUnit::Millisecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_ms, + )))), + TimeUnit::Second => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_s, + )))), + }; + }; + + if physical_type != &PhysicalType::Int64 { + return Err(Error::nyi( + "Can't decode a timestamp from a non-int64 parquet type", + )); + } + + let iter = primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i64| x); + let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); + match (factor, is_multiplier) { + (1, _) => Ok(dyn_iter(iden(iter))), + (a, true) => Ok(dyn_iter(op(iter, move |x| x * a))), + (a, false) => Ok(dyn_iter(op(iter, move |x| x / a))), + } +} + +fn timestamp_dict<'a, K: DictionaryKey, I: Pages + 'a>( + pages: I, + physical_type: &PhysicalType, + logical_type: &Option, + 
data_type: DataType, + num_rows: usize, + chunk_size: Option, + time_unit: TimeUnit, +) -> Result> { + if physical_type == &PhysicalType::Int96 { + let logical_type = PrimitiveLogicalType::Timestamp { + unit: ParquetTimeUnit::Nanoseconds, + is_adjusted_to_utc: false, + }; + let (factor, is_multiplier) = unify_timestamp_unit(&Some(logical_type), time_unit); + return match (factor, is_multiplier) { + (a, true) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + DataType::Timestamp(TimeUnit::Nanosecond, None), + num_rows, + chunk_size, + move |x| int96_to_i64_ns(x) * a, + ))), + (a, false) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + DataType::Timestamp(TimeUnit::Nanosecond, None), + num_rows, + chunk_size, + move |x| int96_to_i64_ns(x) / a, + ))), + }; + }; + + let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); + match (factor, is_multiplier) { + (a, true) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + data_type, + num_rows, + chunk_size, + move |x: i64| x * a, + ))), + (a, false) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + data_type, + num_rows, + chunk_size, + move |x: i64| x / a, + ))), + } +} + +fn dict_read<'a, K: DictionaryKey, I: Pages + 'a>( + iter: I, + physical_type: &PhysicalType, + logical_type: &Option, + data_type: DataType, + num_rows: usize, + chunk_size: Option, +) -> Result> { + use DataType::*; + let values_data_type = if let Dictionary(_, v, _) = &data_type { + v.as_ref() + } else { + panic!() + }; + + Ok(match (physical_type, values_data_type.to_logical_type()) { + (PhysicalType::Int32, UInt8) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + )), + (PhysicalType::Int32, UInt16) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + )), + (PhysicalType::Int32, UInt32) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + )), + (PhysicalType::Int64, UInt64) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i64| x as u64, + )), + (PhysicalType::Int32, Int8) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + )), + (PhysicalType::Int32, Int16) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + )), + (PhysicalType::Int32, Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth)) => { + dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x, + )) + }, + + (PhysicalType::Int64, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp_dict::( + iter, + physical_type, + logical_type, + data_type, + num_rows, + chunk_size, + time_unit, + ); + }, + + (PhysicalType::Int64, Int64 | Date64 | Time64(_) | Duration(_)) => { + dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i64| x, + )) + }, + (PhysicalType::Float, Float32) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: f32| x, + )), + (PhysicalType::Double, Float64) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: f64| x, + )), + + (PhysicalType::ByteArray, Utf8 | Binary) => dyn_iter(binary::DictIter::::new( + iter, data_type, num_rows, chunk_size, + )), + (PhysicalType::ByteArray, LargeUtf8 | LargeBinary) => dyn_iter( + binary::DictIter::::new(iter, data_type, 
num_rows, chunk_size), + ), + (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => dyn_iter( + fixed_size_binary::DictIter::::new(iter, data_type, num_rows, chunk_size), + ), + other => { + return Err(Error::nyi(format!( + "Reading dictionaries of type {other:?}" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs new file mode 100644 index 000000000000..947e7f1141e5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs @@ -0,0 +1,58 @@ +use super::nested_utils::{NestedArrayIter, NestedState}; +use crate::array::{Array, StructArray}; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +/// An iterator adapter over [`NestedArrayIter`] assumed to be encoded as Struct arrays +pub struct StructIterator<'a> { + iters: Vec>, + fields: Vec, +} + +impl<'a> StructIterator<'a> { + /// Creates a new [`StructIterator`] with `iters` and `fields`. + pub fn new(iters: Vec>, fields: Vec) -> Self { + assert_eq!(iters.len(), fields.len()); + Self { iters, fields } + } +} + +impl<'a> Iterator for StructIterator<'a> { + type Item = Result<(NestedState, Box), Error>; + + fn next(&mut self) -> Option { + let values = self + .iters + .iter_mut() + .map(|iter| iter.next()) + .collect::>(); + + if values.iter().any(|x| x.is_none()) { + return None; + } + + // todo: unzip of Result not yet supported in stable Rust + let mut nested = vec![]; + let mut new_values = vec![]; + for x in values { + match x.unwrap() { + Ok((nest, values)) => { + new_values.push(values); + nested.push(nest); + }, + Err(e) => return Some(Err(e)), + } + } + let mut nested = nested.pop().unwrap(); + let (_, validity) = nested.nested.pop().unwrap().inner(); + + Some(Ok(( + nested, + Box::new(StructArray::new( + DataType::Struct(self.fields.clone()), + new_values, + validity.and_then(|x| x.into()), + )), + ))) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs new file mode 100644 index 000000000000..a39a7506d8e1 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs @@ -0,0 +1,524 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::{ + FilteredHybridEncoded, FilteredHybridRleDecoderIter, HybridDecoderBitmapIter, HybridEncoded, +}; +use parquet2::encoding::hybrid_rle; +use parquet2::indexes::Interval; +use parquet2::page::{split_buffer, DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::super::Pages; +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::MutableBitmap; +use crate::error::Error; + +pub fn not_implemented(page: &DataPage) -> Error { + let is_optional = page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + let required = if is_optional { "optional" } else { "required" }; + let is_filtered = if is_filtered { ", index-filtered" } else { "" }; + Error::NotYetImplemented(format!( + "Decoding {:?} \"{:?}\"-encoded {} {} parquet pages", + page.descriptor.primitive_type.physical_type, + page.encoding(), + required, + is_filtered, + )) +} + +/// A private trait representing structs that can receive elements. 
+pub(super) trait Pushable: Sized { + fn reserve(&mut self, additional: usize); + fn push(&mut self, value: T); + fn len(&self) -> usize; + fn push_null(&mut self); + fn extend_constant(&mut self, additional: usize, value: T); +} + +impl Pushable for MutableBitmap { + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBitmap::reserve(self, additional) + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push(&mut self, value: bool) { + self.push(value) + } + + #[inline] + fn push_null(&mut self) { + self.push(false) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: bool) { + self.extend_constant(additional, value) + } +} + +impl Pushable for Vec { + #[inline] + fn reserve(&mut self, additional: usize) { + Vec::reserve(self, additional) + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push_null(&mut self) { + self.push(A::default()) + } + + #[inline] + fn push(&mut self, value: A) { + self.push(value) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: A) { + self.resize(self.len() + additional, value); + } +} + +/// The state of a partially deserialized page +pub(super) trait PageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option>; +} + +#[derive(Debug, Clone)] +pub struct FilteredOptionalPageValidity<'a> { + iter: FilteredHybridRleDecoderIter<'a>, + current: Option<(FilteredHybridEncoded<'a>, usize)>, +} + +impl<'a> FilteredOptionalPageValidity<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, validity, _) = split_buffer(page)?; + + let iter = hybrid_rle::Decoder::new(validity, 1); + let iter = HybridDecoderBitmapIter::new(iter, page.num_values()); + let selected_rows = get_selected_rows(page); + let iter = FilteredHybridRleDecoderIter::new(iter, selected_rows); + + Ok(Self { + iter, + current: None, + }) + } + + pub fn len(&self) -> usize { + self.iter.len() + } +} + +pub fn get_selected_rows(page: &DataPage) -> VecDeque { + page.selected_rows() + .unwrap_or(&[Interval::new(0, page.num_values())]) + .iter() + .copied() + .collect() +} + +impl<'a> PageValidity<'a> for FilteredOptionalPageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option> { + let (run, own_offset) = if let Some((run, offset)) = self.current { + (run, offset) + } else { + // a new run + let run = self.iter.next()?.unwrap(); // no run -> None + self.current = Some((run, 0)); + return self.next_limited(limit); + }; + + match run { + FilteredHybridEncoded::Bitmap { + values, + offset, + length, + } => { + let run_length = length - own_offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, own_offset + length)); + } + + Some(FilteredHybridEncoded::Bitmap { + values, + offset, + length, + }) + }, + FilteredHybridEncoded::Repeated { is_set, length } => { + let run_length = length - own_offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, own_offset + length)); + } + + Some(FilteredHybridEncoded::Repeated { is_set, length }) + }, + FilteredHybridEncoded::Skipped(set) => { + self.current = None; + Some(FilteredHybridEncoded::Skipped(set)) + }, + } + } +} + +pub struct Zip { + validity: V, + values: I, +} + +impl Zip { + pub fn new(validity: V, values: I) -> Self { + Self { validity, values } + } +} + +impl, I: Iterator> Iterator for Zip { + type Item = Option; + + #[inline] + fn next(&mut self) -> 
Option { + self.validity + .next() + .map(|x| if x { self.values.next() } else { None }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.validity.size_hint() + } +} + +#[derive(Debug, Clone)] +pub struct OptionalPageValidity<'a> { + iter: HybridDecoderBitmapIter<'a>, + current: Option<(HybridEncoded<'a>, usize)>, +} + +impl<'a> OptionalPageValidity<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, validity, _) = split_buffer(page)?; + + let iter = hybrid_rle::Decoder::new(validity, 1); + let iter = HybridDecoderBitmapIter::new(iter, page.num_values()); + Ok(Self { + iter, + current: None, + }) + } + + /// Number of items remaining + pub fn len(&self) -> usize { + self.iter.len() + + self + .current + .as_ref() + .map(|(run, offset)| run.len() - offset) + .unwrap_or_default() + } + + fn next_limited(&mut self, limit: usize) -> Option> { + let (run, offset) = if let Some((run, offset)) = self.current { + (run, offset) + } else { + // a new run + let run = self.iter.next()?.unwrap(); // no run -> None + self.current = Some((run, 0)); + return self.next_limited(limit); + }; + + match run { + HybridEncoded::Bitmap(values, length) => { + let run_length = length - offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, offset + length)); + } + + Some(FilteredHybridEncoded::Bitmap { + values, + offset, + length, + }) + }, + HybridEncoded::Repeated(is_set, run_length) => { + let run_length = run_length - offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, offset + length)); + } + + Some(FilteredHybridEncoded::Repeated { is_set, length }) + }, + } + } +} + +impl<'a> PageValidity<'a> for OptionalPageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option> { + self.next_limited(limit) + } +} + +/// Extends a [`Pushable`] from an iterator of non-null values and an hybrid-rle decoder +pub(super) fn extend_from_decoder, I: Iterator>( + validity: &mut MutableBitmap, + page_validity: &mut dyn PageValidity, + limit: Option, + pushable: &mut P, + mut values_iter: I, +) { + let limit = limit.unwrap_or(usize::MAX); + + let mut runs = vec![]; + let mut remaining = limit; + let mut reserve_pushable = 0; + + // first do a scan so that we know how much to reserve up front + while remaining > 0 { + let run = page_validity.next_limited(remaining); + let run = if let Some(run) = run { run } else { break }; + + match run { + FilteredHybridEncoded::Bitmap { length, .. } => { + reserve_pushable += length; + remaining -= length; + }, + FilteredHybridEncoded::Repeated { length, .. 
} => { + reserve_pushable += length; + remaining -= length; + }, + _ => {}, + }; + runs.push(run) + } + pushable.reserve(reserve_pushable); + validity.reserve(reserve_pushable); + + // then a second loop to really fill the buffers + for run in runs { + match run { + FilteredHybridEncoded::Bitmap { + values, + offset, + length, + } => { + // consume `length` items + let iter = BitmapIter::new(values, offset, length); + let iter = Zip::new(iter, &mut values_iter); + + for item in iter { + if let Some(item) = item { + pushable.push(item) + } else { + pushable.push_null() + } + } + validity.extend_from_slice(values, offset, length); + }, + FilteredHybridEncoded::Repeated { is_set, length } => { + validity.extend_constant(length, is_set); + if is_set { + for v in (&mut values_iter).take(length) { + pushable.push(v) + } + } else { + pushable.extend_constant(length, T::default()); + } + }, + FilteredHybridEncoded::Skipped(valids) => for _ in values_iter.by_ref().take(valids) {}, + }; + } +} + +/// The state of a partially deserialized page +pub(super) trait PageState<'a>: std::fmt::Debug { + fn len(&self) -> usize; +} + +/// The state of a partially deserialized page +pub(super) trait DecodedState: std::fmt::Debug { + // the number of values that the state already has + fn len(&self) -> usize; +} + +/// A decoder that knows how to map `State` -> Array +pub(super) trait Decoder<'a> { + /// The state that this decoder derives from a [`DataPage`]. This is bound to the page. + type State: PageState<'a>; + /// The dictionary representation that the decoder uses + type Dict; + /// The target state that this Decoder decodes into. + type DecodedState: DecodedState; + + /// Creates a new `Self::State` + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dict>, + ) -> Result; + + /// Initializes a new [`Self::DecodedState`]. + fn with_capacity(&self, capacity: usize) -> Self::DecodedState; + + /// extends [`Self::DecodedState`] by deserializing items in [`Self::State`]. + /// It guarantees that the length of `decoded` is at most `decoded.len() + remaining`. + fn extend_from_state( + &self, + page: &mut Self::State, + decoded: &mut Self::DecodedState, + additional: usize, + ); + + /// Deserializes a [`DictPage`] into [`Self::Dict`]. + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict; +} + +pub(super) fn extend_from_new_page<'a, T: Decoder<'a>>( + mut page: T::State, + chunk_size: Option, + items: &mut VecDeque, + remaining: &mut usize, + decoder: &T, +) { + let capacity = chunk_size.unwrap_or(0); + let chunk_size = chunk_size.unwrap_or(usize::MAX); + + let mut decoded = if let Some(decoded) = items.pop_back() { + decoded + } else { + // there is no state => initialize it + decoder.with_capacity(capacity) + }; + let existing = decoded.len(); + + let additional = (chunk_size - existing).min(*remaining); + + decoder.extend_from_state(&mut page, &mut decoded, additional); + *remaining -= decoded.len() - existing; + items.push_back(decoded); + + while page.len() > 0 && *remaining > 0 { + let additional = chunk_size.min(*remaining); + + let mut decoded = decoder.with_capacity(additional); + decoder.extend_from_state(&mut page, &mut decoded, additional); + *remaining -= decoded.len(); + items.push_back(decoded) + } +} + +/// Represents what happened when a new page was consumed +#[derive(Debug)] +pub enum MaybeNext
<P>
{ + /// Whether the page was sufficient to fill `chunk_size` + Some(P), + /// whether there are no more pages or intermediary decoded states + None, + /// Whether the page was insufficient to fill `chunk_size` and a new page is required + More, +} + +#[inline] +pub(super) fn next<'a, I: Pages, D: Decoder<'a>>( + iter: &'a mut I, + items: &'a mut VecDeque, + dict: &'a mut Option, + remaining: &'a mut usize, + chunk_size: Option, + decoder: &'a D, +) -> MaybeNext> { + // front[a1, a2, a3, ...]back + if items.len() > 1 { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if (items.len() == 1) && items.front().unwrap().len() == chunk_size.unwrap_or(usize::MAX) { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if *remaining == 0 { + return match items.pop_front() { + Some(decoded) => MaybeNext::Some(Ok(decoded)), + None => MaybeNext::None, + }; + } + + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let page = match page { + Page::Data(page) => page, + Page::Dict(dict_page) => { + *dict = Some(decoder.deserialize_dict(dict_page)); + return MaybeNext::More; + }, + }; + + // there is a new page => consume the page from the start + let maybe_page = decoder.build_state(page, dict.as_ref()); + let page = match maybe_page { + Ok(page) => page, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + extend_from_new_page(page, chunk_size, items, remaining, decoder); + + if (items.len() == 1) && items.front().unwrap().len() < chunk_size.unwrap_or(usize::MAX) + { + MaybeNext::More + } else { + let decoded = items.pop_front().unwrap(); + MaybeNext::Some(Ok(decoded)) + } + }, + Ok(None) => { + if let Some(decoded) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(decoded.len() <= chunk_size.unwrap_or(usize::MAX)); + MaybeNext::Some(Ok(decoded)) + } else { + MaybeNext::None + } + }, + } +} + +#[inline] +pub(super) fn dict_indices_decoder(page: &DataPage) -> Result { + let (_, _, indices_buffer) = split_buffer(page)?; + + // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32), + // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width). + let bit_width = indices_buffer[0]; + let indices_buffer = &indices_buffer[1..]; + + hybrid_rle::HybridRleDecoder::try_new(indices_buffer, bit_width as u32, page.num_values()) + .map_err(Error::from) +} diff --git a/crates/nano-arrow/src/io/parquet/read/file.rs b/crates/nano-arrow/src/io/parquet/read/file.rs new file mode 100644 index 000000000000..750340c60ef7 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/file.rs @@ -0,0 +1,205 @@ +use std::io::{Read, Seek}; + +use parquet2::indexes::FilteredPage; + +use super::{RowGroupDeserializer, RowGroupMetaData}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Result; +use crate::io::parquet::read::read_columns_many; + +/// An iterator of [`Chunk`]s coming from row groups of a parquet file. +/// +/// This can be thought of a flatten chain of [`Iterator`] - each row group is sequentially +/// mapped to an [`Iterator`] and each iterator is iterated upon until either the limit +/// or the last iterator ends. +/// # Implementation +/// This iterator is single threaded on both IO-bounded and CPU-bounded tasks, and mixes them. 
+pub struct FileReader { + row_groups: RowGroupReader, + remaining_rows: usize, + current_row_group: Option, +} + +impl FileReader { + /// Returns a new [`FileReader`]. + pub fn new( + reader: R, + row_groups: Vec, + schema: Schema, + chunk_size: Option, + limit: Option, + page_indexes: Option>>>>, + ) -> Self { + let row_groups = + RowGroupReader::new(reader, schema, row_groups, chunk_size, limit, page_indexes); + + Self { + row_groups, + remaining_rows: limit.unwrap_or(usize::MAX), + current_row_group: None, + } + } + + fn next_row_group(&mut self) -> Result> { + let result = self.row_groups.next().transpose()?; + + // If current_row_group is None, then there will be no elements to remove. + if self.current_row_group.is_some() { + self.remaining_rows = self.remaining_rows.saturating_sub( + result + .as_ref() + .map(|x| x.num_rows()) + .unwrap_or(self.remaining_rows), + ); + } + Ok(result) + } + + /// Returns the [`Schema`] associated to this file. + pub fn schema(&self) -> &Schema { + &self.row_groups.schema + } +} + +impl Iterator for FileReader { + type Item = Result>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + // reached the limit + return None; + } + + if let Some(row_group) = &mut self.current_row_group { + match row_group.next() { + // no more chunks in the current row group => try a new one + None => match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + // new found => pull again + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + }, + other => other, + } + } else { + match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + } + } + } +} + +/// An [`Iterator`] from row groups of a parquet file. +/// +/// # Implementation +/// Advancing this iterator is IO-bounded - each iteration reads all the column chunks from the file +/// to memory and attaches [`RowGroupDeserializer`] to them so that they can be iterated in chunks. +pub struct RowGroupReader { + reader: R, + schema: Schema, + row_groups: std::vec::IntoIter, + chunk_size: Option, + remaining_rows: usize, + page_indexes: Option>>>>, +} + +impl RowGroupReader { + /// Returns a new [`RowGroupReader`] + pub fn new( + reader: R, + schema: Schema, + row_groups: Vec, + chunk_size: Option, + limit: Option, + page_indexes: Option>>>>, + ) -> Self { + if let Some(pages) = &page_indexes { + assert_eq!(pages.len(), row_groups.len()) + } + Self { + reader, + schema, + row_groups: row_groups.into_iter(), + chunk_size, + remaining_rows: limit.unwrap_or(usize::MAX), + page_indexes: page_indexes.map(|pages| pages.into_iter()), + } + } + + #[inline] + fn _next(&mut self) -> Result> { + if self.schema.fields.is_empty() { + return Ok(None); + } + if self.remaining_rows == 0 { + // reached the limit + return Ok(None); + } + + let row_group = if let Some(row_group) = self.row_groups.next() { + row_group + } else { + return Ok(None); + }; + + let pages = self.page_indexes.as_mut().and_then(|iter| iter.next()); + + // the number of rows depends on whether indexes are selected or not. 
+ let num_rows = pages + .as_ref() + .map(|x| { + // first field, first column within that field + x[0][0] + .iter() + .map(|page| { + page.selected_rows + .iter() + .map(|interval| interval.length) + .sum::() + }) + .sum() + }) + .unwrap_or_else(|| row_group.num_rows()); + + let column_chunks = read_columns_many( + &mut self.reader, + &row_group, + self.schema.fields.clone(), + self.chunk_size, + Some(self.remaining_rows), + pages, + )?; + + let result = RowGroupDeserializer::new(column_chunks, num_rows, Some(self.remaining_rows)); + self.remaining_rows = self.remaining_rows.saturating_sub(num_rows); + Ok(Some(result)) + } +} + +impl Iterator for RowGroupReader { + type Item = Result; + + fn next(&mut self) -> Option { + self._next().transpose() + } + + fn size_hint(&self) -> (usize, Option) { + self.row_groups.size_hint() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs b/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs new file mode 100644 index 000000000000..9a7c7c4ca90b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs @@ -0,0 +1,40 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{Array, BinaryArray, PrimitiveArray, Utf8Array}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +pub fn deserialize( + indexes: &[PageIndex>], + data_type: &DataType, +) -> Result { + Ok(ColumnPageStatistics { + min: deserialize_binary_iter(indexes.iter().map(|index| index.min.as_ref()), data_type)?, + max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type)?, + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + }) +} + +fn deserialize_binary_iter<'a, I: TrustedLen>>>( + iter: I, + data_type: &DataType, +) -> Result, Error> { + match data_type.to_physical_type() { + PhysicalType::LargeBinary => Ok(Box::new(BinaryArray::::from_iter(iter))), + PhysicalType::Utf8 => { + let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); + Ok(Box::new(Utf8Array::::try_from_trusted_len_iter(iter)?)) + }, + PhysicalType::LargeUtf8 => { + let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); + Ok(Box::new(Utf8Array::::try_from_trusted_len_iter(iter)?)) + }, + _ => Ok(Box::new(BinaryArray::::from_iter(iter))), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs b/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs new file mode 100644 index 000000000000..70977197d103 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs @@ -0,0 +1,20 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{BooleanArray, PrimitiveArray}; + +pub fn deserialize(indexes: &[PageIndex]) -> ColumnPageStatistics { + ColumnPageStatistics { + min: Box::new(BooleanArray::from_trusted_len_iter( + indexes.iter().map(|index| index.min), + )), + max: Box::new(BooleanArray::from_trusted_len_iter( + indexes.iter().map(|index| index.max), + )), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs b/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs new file mode 100644 index 000000000000..26002e5857d5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs @@ 
-0,0 +1,67 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{Array, FixedSizeBinaryArray, MutableFixedSizeBinaryArray, PrimitiveArray}; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::trusted_len::TrustedLen; +use crate::types::{i256, NativeType}; + +pub fn deserialize(indexes: &[PageIndex>], data_type: DataType) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_binary_iter( + indexes.iter().map(|index| index.min.as_ref()), + data_type.clone(), + ), + max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +fn deserialize_binary_iter<'a, I: TrustedLen>>>( + iter: I, + data_type: DataType, +) -> Box { + match data_type.to_physical_type() { + PhysicalType::Primitive(PrimitiveType::Int128) => { + Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { + v.map(|x| { + // Copy the fixed-size byte value to the start of a 16 byte stack + // allocated buffer, then use an arithmetic right shift to fill in + // MSBs, which accounts for leading 1's in negative (two's complement) + // values. + let n = x.len(); + let mut bytes = [0u8; 16]; + bytes[..n].copy_from_slice(x); + i128::from_be_bytes(bytes) >> (8 * (16 - n)) + }) + }))) + }, + PhysicalType::Primitive(PrimitiveType::Int256) => { + Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { + v.map(|x| { + let n = x.len(); + let mut bytes = [0u8; 32]; + bytes[..n].copy_from_slice(x); + i256::from_be_bytes(bytes) + }) + }))) + }, + _ => { + let mut a = MutableFixedSizeBinaryArray::try_new( + data_type, + Vec::with_capacity(iter.size_hint().0), + None, + ) + .unwrap(); + for item in iter { + a.push(item); + } + let a: FixedSizeBinaryArray = a.into(); + Box::new(a) + }, + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs b/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs new file mode 100644 index 000000000000..b60b717ebfd5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs @@ -0,0 +1,381 @@ +//! API to perform page-level filtering (also known as indexes) +use parquet2::error::Error as ParquetError; +use parquet2::indexes::{ + select_pages, BooleanIndex, ByteIndex, FixedLenByteIndex, Index as ParquetIndex, NativeIndex, + PageLocation, +}; +use parquet2::metadata::{ColumnChunkMetaData, RowGroupMetaData}; +use parquet2::read::{read_columns_indexes as _read_columns_indexes, read_pages_locations}; +use parquet2::schema::types::PhysicalType as ParquetPhysicalType; + +mod binary; +mod boolean; +mod fixed_len_binary; +mod primitive; + +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +pub use parquet2::indexes::{FilteredPage, Interval}; + +use super::get_field_pages; +use crate::array::{Array, UInt64Array}; +use crate::datatypes::{DataType, Field, PhysicalType, PrimitiveType}; +use crate::error::Error; + +/// Page statistics of an Arrow field. +#[derive(Debug, PartialEq)] +pub enum FieldPageStatistics { + /// Variant used for fields with a single parquet column (e.g. primitives, dictionaries, list) + Single(ColumnPageStatistics), + /// Variant used for fields with multiple parquet columns (e.g. 
Struct, Map) + Multiple(Vec), +} + +impl From for FieldPageStatistics { + fn from(column: ColumnPageStatistics) -> Self { + Self::Single(column) + } +} + +/// [`ColumnPageStatistics`] contains the minimum, maximum, and null_count +/// of each page of a parquet column, as an [`Array`]. +/// This struct has the following invariants: +/// * `min`, `max` and `null_count` have the same length (equal to the number of pages in the column) +/// * `min`, `max` and `null_count` are guaranteed to be non-null +/// * `min` and `max` have the same logical type +#[derive(Debug, PartialEq)] +pub struct ColumnPageStatistics { + /// The minimum values in the pages + pub min: Box, + /// The maximum values in the pages + pub max: Box, + /// The number of null values in the pages. + pub null_count: UInt64Array, +} + +/// Given a sequence of [`ParquetIndex`] representing the page indexes of each column in the +/// parquet file, returns the page-level statistics as a [`FieldPageStatistics`]. +/// +/// This function maps timestamps, decimal types, etc. accordingly. +/// # Implementation +/// This function is CPU-bounded `O(P)` where `P` is the total number of pages on all columns. +/// # Error +/// This function errors iff the value is not deserializable to arrow (e.g. invalid utf-8) +fn deserialize( + indexes: &mut VecDeque<&Box>, + data_type: DataType, +) -> Result { + match data_type.to_physical_type() { + PhysicalType::Boolean => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + Ok(boolean::deserialize(&index.indexes).into()) + }, + PhysicalType::Primitive(PrimitiveType::Int128) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int32 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) + }, + parquet2::schema::types::PhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::FixedLenByteArray(_) => { + let index = index.as_any().downcast_ref::().unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::Int256) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int32 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) + }, + parquet2::schema::types::PhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::FixedLenByteArray(_) => { + let index = index.as_any().downcast_ref::().unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::UInt8) + | PhysicalType::Primitive(PrimitiveType::UInt16) + | PhysicalType::Primitive(PrimitiveType::UInt32) + | PhysicalType::Primitive(PrimitiveType::Int32) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, 
data_type).into()) + }, + PhysicalType::Primitive(PrimitiveType::UInt64) + | PhysicalType::Primitive(PrimitiveType::Int64) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::Int96 => { + let index = index + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_i96(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::Float32) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_id(&index.indexes, data_type).into()) + }, + PhysicalType::Primitive(PrimitiveType::Float64) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_id(&index.indexes, data_type).into()) + }, + PhysicalType::Binary + | PhysicalType::LargeBinary + | PhysicalType::Utf8 + | PhysicalType::LargeUtf8 => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + binary::deserialize(&index.indexes, &data_type).map(|x| x.into()) + }, + PhysicalType::FixedSizeBinary => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + PhysicalType::Dictionary(_) => { + if let DataType::Dictionary(_, inner, _) = data_type.to_logical_type() { + deserialize(indexes, (**inner).clone()) + } else { + unreachable!() + } + }, + PhysicalType::List => { + if let DataType::List(inner) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::LargeList => { + if let DataType::LargeList(inner) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::Map => { + if let DataType::Map(inner, _) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::Struct => { + let children_fields = if let DataType::Struct(children) = data_type.to_logical_type() { + children + } else { + unreachable!() + }; + let children = children_fields + .iter() + .map(|child| deserialize(indexes, child.data_type.clone())) + .collect::, Error>>()?; + + Ok(FieldPageStatistics::Multiple(children)) + }, + + other => Err(Error::nyi(format!( + "Deserialize into arrow's {other:?} page index" + ))), + } +} + +/// Checks whether the row group have page index information (page statistics) +pub fn has_indexes(row_group: &RowGroupMetaData) -> bool { + row_group + .columns() + .iter() + .all(|chunk| chunk.column_chunk().column_index_offset.is_some()) +} + +/// Reads the column indexes from the reader assuming a valid set of derived Arrow fields +/// for all parquet the columns in the file. +/// +/// It returns one [`FieldPageStatistics`] per field in `fields` +/// +/// This function is expected to be used to filter out parquet pages. +/// +/// # Implementation +/// This function is IO-bounded and calls `reader.read_exact` exactly once. +/// # Error +/// Errors iff the indexes can't be read or their deserialization to arrow is incorrect (e.g. 
invalid utf-8) +pub fn read_columns_indexes( + reader: &mut R, + chunks: &[ColumnChunkMetaData], + fields: &[Field], +) -> Result, Error> { + let indexes = _read_columns_indexes(reader, chunks)?; + + fields + .iter() + .map(|field| { + let indexes = get_field_pages(chunks, &indexes, &field.name); + let mut indexes = indexes.into_iter().collect(); + + deserialize(&mut indexes, field.data_type.clone()) + }) + .collect() +} + +/// Returns the set of (row) intervals of the pages. +pub fn compute_page_row_intervals( + locations: &[PageLocation], + num_rows: usize, +) -> Result, ParquetError> { + if locations.is_empty() { + return Ok(vec![]); + }; + + let last = (|| { + let start: usize = locations.last().unwrap().first_row_index.try_into()?; + let length = num_rows - start; + Result::<_, ParquetError>::Ok(Interval::new(start, length)) + })(); + + let pages_lengths = locations + .windows(2) + .map(|x| { + let start = usize::try_from(x[0].first_row_index)?; + let length = usize::try_from(x[1].first_row_index - x[0].first_row_index)?; + Ok(Interval::new(start, length)) + }) + .chain(std::iter::once(last)); + pages_lengths.collect() +} + +/// Reads all page locations and index locations (IO-bounded) and uses `predicate` to compute +/// the set of [`FilteredPage`] that fulfill the predicate. +/// +/// The non-trivial argument of this function is `predicate`, that controls which pages are selected. +/// Its signature contains 2 arguments: +/// * 0th argument (indexes): contains one [`ColumnPageStatistics`] (page statistics) per field. +/// Use it to evaluate the predicate against +/// * 1th argument (intervals): contains one [`Vec>`] (row positions) per field. +/// For each field, the outermost vector corresponds to each parquet column: +/// a primitive field contains 1 column, a struct field with 2 primitive fields contain 2 columns. +/// The inner `Vec` contains one [`Interval`] per page: its length equals the length of [`ColumnPageStatistics`]. +/// It returns a single [`Vec`] denoting the set of intervals that the predicate selects (over all columns). +/// +/// This returns one item per `field`. For each field, there is one item per column (for non-nested types it returns one column) +/// and finally [`Vec`], that corresponds to the set of selected pages. 
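To make the predicate's shape concrete, a short sketch of a call site follows; `reader`, `row_group` and `fields` are assumed to be in scope, and the predicate shown is deliberately trivial: it keeps every page by returning, unchanged, the row intervals of the first column of the first field.

// Sketch: select all pages. A useful predicate would instead inspect the per-page
// `min`/`max` arrays in `statistics` and return only the intervals of pages that
// can contain matching rows.
let selected_pages = read_filtered_pages(
    &mut reader,
    &row_group,
    &fields,
    |_statistics: &[FieldPageStatistics], intervals: &[Vec<Vec<Interval>>]| {
        intervals[0][0].clone()
    },
)?;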
+pub fn read_filtered_pages< + R: Read + Seek, + F: Fn(&[FieldPageStatistics], &[Vec>]) -> Vec, +>( + reader: &mut R, + row_group: &RowGroupMetaData, + fields: &[Field], + predicate: F, + //is_intersection: bool, +) -> Result>>, Error> { + let num_rows = row_group.num_rows(); + + // one vec per column + let locations = read_pages_locations(reader, row_group.columns())?; + // one Vec> per field (non-nested contain a single entry on the first column) + let locations = fields + .iter() + .map(|field| get_field_pages(row_group.columns(), &locations, &field.name)) + .collect::>(); + + // one ColumnPageStatistics per field + let indexes = read_columns_indexes(reader, row_group.columns(), fields)?; + + let intervals = locations + .iter() + .map(|locations| { + locations + .iter() + .map(|locations| Ok(compute_page_row_intervals(locations, num_rows)?)) + .collect::, Error>>() + }) + .collect::, Error>>()?; + + let intervals = predicate(&indexes, &intervals); + + locations + .into_iter() + .map(|locations| { + locations + .into_iter() + .map(|locations| Ok(select_pages(&intervals, locations, num_rows)?)) + .collect::, Error>>() + }) + .collect() +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs b/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs new file mode 100644 index 000000000000..90e52e4a4aaf --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs @@ -0,0 +1,222 @@ +use ethnum::I256; +use parquet2::indexes::PageIndex; +use parquet2::schema::types::{PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit}; +use parquet2::types::int96_to_i64_ns; + +use super::ColumnPageStatistics; +use crate::array::{Array, MutablePrimitiveArray, PrimitiveArray}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::trusted_len::TrustedLen; +use crate::types::{i256, NativeType}; + +#[inline] +fn deserialize_int32>>( + iter: I, + data_type: DataType, +) -> Box { + use DataType::*; + match data_type.to_logical_type() { + UInt8 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u8))) + .to(data_type), + ) as _, + UInt16 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u16))) + .to(data_type), + ), + UInt32 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u32))) + .to(data_type), + ), + Decimal(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) + .to(data_type), + ), + Decimal256(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter( + iter.map(|x| x.map(|x| i256(I256::new(x.into())))), + ) + .to(data_type), + ) as _, + _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), + } +} + +#[inline] +fn timestamp( + array: &mut MutablePrimitiveArray, + time_unit: TimeUnit, + logical_type: Option, +) { + let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { + unit + } else { + return; + }; + + match (unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => {}, + (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x *= 1_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => {}, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x *= 1_000_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => {}, + } +} + +#[inline] +fn deserialize_int64>>( + iter: I, + primitive_type: &PrimitiveType, + data_type: DataType, +) -> Box { + use DataType::*; + match data_type.to_logical_type() { + UInt64 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u64))) + .to(data_type), + ) as _, + Decimal(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) + .to(data_type), + ) as _, + Decimal256(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter( + iter.map(|x| x.map(|x| i256(I256::new(x.into())))), + ) + .to(data_type), + ) as _, + Timestamp(time_unit, _) => { + let mut array = + MutablePrimitiveArray::::from_trusted_len_iter(iter).to(data_type.clone()); + + timestamp(&mut array, *time_unit, primitive_type.logical_type); + + let array: PrimitiveArray = array.into(); + + Box::new(array) + }, + _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), + } +} + +#[inline] +fn deserialize_int96>>( + iter: I, + data_type: DataType, +) -> Box { + Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(int96_to_i64_ns))) + .to(data_type), + ) +} + +#[inline] +fn deserialize_id_s>>( + iter: I, + data_type: DataType, +) -> Box { + Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)) +} + +pub fn deserialize_i32(indexes: &[PageIndex], data_type: DataType) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int32(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_int32(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_i64( + indexes: &[PageIndex], + primitive_type: &PrimitiveType, + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int64( + indexes.iter().map(|index| index.min), + primitive_type, + data_type.clone(), + ), + max: deserialize_int64( + indexes.iter().map(|index| index.max), + 
primitive_type, + data_type, + ), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_i96( + indexes: &[PageIndex<[u32; 3]>], + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int96(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_int96(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_id( + indexes: &[PageIndex], + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_id_s(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_id_s(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/mod.rs b/crates/nano-arrow/src/io/parquet/read/mod.rs new file mode 100644 index 000000000000..52a4d07d922e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/mod.rs @@ -0,0 +1,95 @@ +//! APIs to read from Parquet format. +#![allow(clippy::type_complexity)] + +mod deserialize; +mod file; +pub mod indexes; +mod row_group; +pub mod schema; +pub mod statistics; + +use std::io::{Read, Seek}; + +pub use deserialize::{ + column_iter_to_arrays, create_list, create_map, get_page_iterator, init_nested, n_columns, + InitNested, NestedArrayIter, NestedState, StructIterator, +}; +pub use file::{FileReader, RowGroupReader}; +use futures::{AsyncRead, AsyncSeek}; +// re-exports of parquet2's relevant APIs +pub use parquet2::{ + error::Error as ParquetError, + fallible_streaming_iterator, + metadata::{ColumnChunkMetaData, ColumnDescriptor, RowGroupMetaData}, + page::{CompressedDataPage, DataPageHeader, Page}, + read::{ + decompress, get_column_iterator, get_page_stream, + read_columns_indexes as _read_columns_indexes, read_metadata as _read_metadata, + read_metadata_async as _read_metadata_async, read_pages_locations, BasicDecompressor, + Decompressor, MutStreamingIterator, PageFilter, PageReader, ReadColumnIterator, State, + }, + schema::types::{ + GroupLogicalType, ParquetType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, + TimeUnit as ParquetTimeUnit, + }, + types::int96_to_i64_ns, + FallibleStreamingIterator, +}; +pub use row_group::*; +pub use schema::{infer_schema, FileMetaData}; + +use crate::array::Array; +use crate::error::Result; +use crate::types::{i256, NativeType}; + +/// Trait describing a [`FallibleStreamingIterator`] of [`Page`] +pub trait Pages: + FallibleStreamingIterator + Send + Sync +{ +} + +impl + Send + Sync> Pages for I {} + +/// Type def for a sharable, boxed dyn [`Iterator`] of arrays +pub type ArrayIter<'a> = Box>> + Send + Sync + 'a>; + +/// Reads parquets' metadata synchronously. +pub fn read_metadata(reader: &mut R) -> Result { + Ok(_read_metadata(reader)?) +} + +/// Reads parquets' metadata asynchronously. +pub async fn read_metadata_async( + reader: &mut R, +) -> Result { + Ok(_read_metadata_async(reader).await?) 
+} + +fn convert_days_ms(value: &[u8]) -> crate::types::days_ms { + crate::types::days_ms( + i32::from_le_bytes(value[4..8].try_into().unwrap()), + i32::from_le_bytes(value[8..12].try_into().unwrap()), + ) +} + +fn convert_i128(value: &[u8], n: usize) -> i128 { + // Copy the fixed-size byte value to the start of a 16 byte stack + // allocated buffer, then use an arithmetic right shift to fill in + // MSBs, which accounts for leading 1's in negative (two's complement) + // values. + let mut bytes = [0u8; 16]; + bytes[..n].copy_from_slice(value); + i128::from_be_bytes(bytes) >> (8 * (16 - n)) +} + +fn convert_i256(value: &[u8]) -> i256 { + if value[0] >= 128 { + let mut neg_bytes = [255u8; 32]; + neg_bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(neg_bytes) + } else { + let mut bytes = [0u8; 32]; + bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(bytes) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/row_group.rs b/crates/nano-arrow/src/io/parquet/read/row_group.rs new file mode 100644 index 000000000000..26480a1a2602 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/row_group.rs @@ -0,0 +1,248 @@ +use std::io::{Read, Seek}; + +use parquet2::indexes::FilteredPage; +use parquet2::metadata::ColumnChunkMetaData; +use parquet2::read::{BasicDecompressor, IndexedPageReader, PageMetaData, PageReader}; + +use super::{ArrayIter, RowGroupMetaData}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Field; +use crate::error::Result; +use crate::io::parquet::read::column_iter_to_arrays; + +/// An [`Iterator`] of [`Chunk`] that (dynamically) adapts a vector of iterators of [`Array`] into +/// an iterator of [`Chunk`]. +/// +/// This struct tracks advances each of the iterators individually and combines the +/// result in a single [`Chunk`]. +/// +/// # Implementation +/// This iterator is single-threaded and advancing it is CPU-bounded. +pub struct RowGroupDeserializer { + num_rows: usize, + remaining_rows: usize, + column_chunks: Vec>, +} + +impl RowGroupDeserializer { + /// Creates a new [`RowGroupDeserializer`]. + /// + /// # Panic + /// This function panics iff any of the `column_chunks` + /// do not return an array with an equal length. + pub fn new( + column_chunks: Vec>, + num_rows: usize, + limit: Option, + ) -> Self { + Self { + num_rows, + remaining_rows: limit.unwrap_or(usize::MAX).min(num_rows), + column_chunks, + } + } + + /// Returns the number of rows on this row group + pub fn num_rows(&self) -> usize { + self.num_rows + } +} + +impl Iterator for RowGroupDeserializer { + type Item = Result>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + return None; + } + let chunk = self + .column_chunks + .iter_mut() + .map(|iter| iter.next().unwrap()) + .collect::>>() + .and_then(Chunk::try_new); + self.remaining_rows = self.remaining_rows.saturating_sub( + chunk + .as_ref() + .map(|x| x.len()) + .unwrap_or(self.remaining_rows), + ); + + Some(chunk) + } +} + +/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. +/// For non-nested parquet types, this returns a single column +pub fn get_field_columns<'a>( + columns: &'a [ColumnChunkMetaData], + field_name: &str, +) -> Vec<&'a ColumnChunkMetaData> { + columns + .iter() + .filter(|x| x.descriptor().path_in_schema[0] == field_name) + .collect() +} + +/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. 
+/// For non-nested parquet types, this returns a single column +pub fn get_field_pages<'a, T>( + columns: &'a [ColumnChunkMetaData], + items: &'a [T], + field_name: &str, +) -> Vec<&'a T> { + columns + .iter() + .zip(items) + .filter(|(metadata, _)| metadata.descriptor().path_in_schema[0] == field_name) + .map(|(_, item)| item) + .collect() +} + +/// Reads all columns that are part of the parquet field `field_name` +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns associated to +/// the field (one for non-nested types) +pub fn read_columns<'a, R: Read + Seek>( + reader: &mut R, + columns: &'a [ColumnChunkMetaData], + field_name: &str, +) -> Result)>> { + get_field_columns(columns, field_name) + .into_iter() + .map(|meta| _read_single_column(reader, meta)) + .collect() +} + +fn _read_single_column<'a, R>( + reader: &mut R, + meta: &'a ColumnChunkMetaData, +) -> Result<(&'a ColumnChunkMetaData, Vec)> +where + R: Read + Seek, +{ + let (start, length) = meta.byte_range(); + reader.seek(std::io::SeekFrom::Start(start))?; + + let mut chunk = vec![]; + chunk.try_reserve(length as usize)?; + reader.by_ref().take(length).read_to_end(&mut chunk)?; + Ok((meta, chunk)) +} + +type Pages = Box< + dyn Iterator> + + Sync + + Send, +>; + +/// Converts a vector of columns associated with the parquet field whose name is [`Field`] +/// to an iterator of [`Array`], [`ArrayIter`] of chunk size `chunk_size`. +pub fn to_deserializer<'a>( + columns: Vec<(&ColumnChunkMetaData, Vec)>, + field: Field, + num_rows: usize, + chunk_size: Option, + pages: Option>>, +) -> Result> { + let chunk_size = chunk_size.map(|c| c.min(num_rows)); + + let (columns, types) = if let Some(pages) = pages { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .zip(pages) + .map(|((column_meta, chunk), mut pages)| { + // de-offset the start, since we read in chunks (and offset is from start of file) + let mut meta: PageMetaData = column_meta.into(); + pages + .iter_mut() + .for_each(|page| page.start -= meta.column_start); + meta.column_start = 0; + let pages = IndexedPageReader::new_with_page_meta( + std::io::Cursor::new(chunk), + meta, + pages, + vec![], + vec![], + ); + let pages = Box::new(pages) as Pages; + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + (columns, types) + } else { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .map(|(column_meta, chunk)| { + let len = chunk.len(); + let pages = PageReader::new( + std::io::Cursor::new(chunk), + column_meta, + std::sync::Arc::new(|_, _| true), + vec![], + len * 2 + 1024, + ); + let pages = Box::new(pages) as Pages; + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + (columns, types) + }; + + column_iter_to_arrays(columns, types, field, chunk_size, num_rows) +} + +/// Returns a vector of iterators of [`Array`] ([`ArrayIter`]) corresponding to the top +/// level parquet fields whose name matches `fields`'s names. +/// +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns in the row group - +/// it reads all the columns to memory from the row group associated to the requested fields. +/// +/// This operation is single-threaded. For readers with stronger invariants +/// (e.g. implement [`Clone`]) you can use [`read_columns`] to read multiple columns at once +/// and convert them to [`ArrayIter`] via [`to_deserializer`]. 
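As a sketch of the split the previous paragraph alludes to (IO-bounded `read_columns` first, CPU-bounded `to_deserializer` afterwards), the fields could be fetched sequentially and then deserialized independently; `reader`, `row_group`, `fields`, `num_rows` and `chunk_size` are assumed to be in scope, and handing the second step to worker threads is omitted for brevity.

// IO: one pass over the file per field; each entry owns the raw bytes of that
// field's column chunks, so nothing below needs `reader` anymore.
let field_columns = fields
    .iter()
    .map(|field| read_columns(&mut reader, row_group.columns(), &field.name))
    .collect::<Result<Vec<_>>>()?;

// CPU: build one ArrayIter per field; these calls are independent of each other
// and of the reader, so they could be spawned on separate threads.
let array_iters = field_columns
    .into_iter()
    .zip(fields.iter().cloned())
    .map(|(columns, field)| to_deserializer(columns, field, num_rows, chunk_size, None))
    .collect::<Result<Vec<_>>>()?;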
+pub fn read_columns_many<'a, R: Read + Seek>( + reader: &mut R, + row_group: &RowGroupMetaData, + fields: Vec, + chunk_size: Option, + limit: Option, + pages: Option>>>, +) -> Result>> { + let num_rows = row_group.num_rows(); + let num_rows = limit.map(|limit| limit.min(num_rows)).unwrap_or(num_rows); + + // reads all the necessary columns for all fields from the row group + // This operation is IO-bounded `O(C)` where C is the number of columns in the row group + let field_columns = fields + .iter() + .map(|field| read_columns(reader, row_group.columns(), &field.name)) + .collect::>>()?; + + if let Some(pages) = pages { + field_columns + .into_iter() + .zip(fields) + .zip(pages) + .map(|((columns, field), pages)| { + to_deserializer(columns, field, num_rows, chunk_size, Some(pages)) + }) + .collect() + } else { + field_columns + .into_iter() + .zip(fields) + .map(|(columns, field)| to_deserializer(columns, field, num_rows, chunk_size, None)) + .collect() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/convert.rs b/crates/nano-arrow/src/io/parquet/read/schema/convert.rs new file mode 100644 index 000000000000..4f55ffb872d3 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/convert.rs @@ -0,0 +1,1091 @@ +//! This module has entry points, [`parquet_to_arrow_schema`] and the more configurable [`parquet_to_arrow_schema_with_options`]. +use parquet2::schema::types::{ + FieldInfo, GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; +use parquet2::schema::Repetition; + +use crate::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use crate::io::parquet::read::schema::SchemaInferenceOptions; + +/// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain +/// any physical column. +pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> Vec { + parquet_to_arrow_schema_with_options(fields, &None) +} + +/// Like [`parquet_to_arrow_schema`] but with configurable options which affect the behavior of schema inference +pub fn parquet_to_arrow_schema_with_options( + fields: &[ParquetType], + options: &Option, +) -> Vec { + fields + .iter() + .filter_map(|f| to_field(f, options.as_ref().unwrap_or(&Default::default()))) + .collect::>() +} + +fn from_int32( + logical_type: Option, + converted_type: Option, +) -> DataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(t)), _) => match t { + IntegerType::Int8 => DataType::Int8, + IntegerType::Int16 => DataType::Int16, + IntegerType::Int32 => DataType::Int32, + IntegerType::UInt8 => DataType::UInt8, + IntegerType::UInt16 => DataType::UInt16, + IntegerType::UInt32 => DataType::UInt32, + // The above are the only possible annotations for parquet's int32. Anything else + // is a deviation to the parquet specification and we ignore + _ => DataType::Int32, + }, + (Some(Decimal(precision, scale)), _) => DataType::Decimal(precision, scale), + (Some(Date), _) => DataType::Date32, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Milliseconds => DataType::Time32(TimeUnit::Millisecond), + // MILLIS is the only possible annotation for parquet's int32. 
Anything else + // is a deviation to the parquet specification and we ignore + _ => DataType::Int32, + }, + // handle converted types: + (_, Some(PrimitiveConvertedType::Uint8)) => DataType::UInt8, + (_, Some(PrimitiveConvertedType::Uint16)) => DataType::UInt16, + (_, Some(PrimitiveConvertedType::Uint32)) => DataType::UInt32, + (_, Some(PrimitiveConvertedType::Int8)) => DataType::Int8, + (_, Some(PrimitiveConvertedType::Int16)) => DataType::Int16, + (_, Some(PrimitiveConvertedType::Int32)) => DataType::Int32, + (_, Some(PrimitiveConvertedType::Date)) => DataType::Date32, + (_, Some(PrimitiveConvertedType::TimeMillis)) => DataType::Time32(TimeUnit::Millisecond), + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + (_, _) => DataType::Int32, + } +} + +fn from_int64( + logical_type: Option, + converted_type: Option, +) -> DataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(integer)), _) => match integer { + IntegerType::UInt64 => DataType::UInt64, + IntegerType::Int64 => DataType::Int64, + _ => DataType::Int64, + }, + ( + Some(Timestamp { + is_adjusted_to_utc, + unit, + }), + _, + ) => { + let timezone = if is_adjusted_to_utc { + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=true is defined as [...] elapsed since the Unix epoch + Some("+00:00".to_string()) + } else { + // PARQUET: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=false represents [...] such + // timestamps should always be displayed the same way, regardless of the local time zone in effect + // ARROW: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // If the time zone is null or equal to an empty string, the data is "time + // zone naive" and shall be displayed *as is* to the user, not localized + // to the locale of the user. + None + }; + + match unit { + ParquetTimeUnit::Milliseconds => { + DataType::Timestamp(TimeUnit::Millisecond, timezone) + }, + ParquetTimeUnit::Microseconds => { + DataType::Timestamp(TimeUnit::Microsecond, timezone) + }, + ParquetTimeUnit::Nanoseconds => DataType::Timestamp(TimeUnit::Nanosecond, timezone), + } + }, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Microseconds => DataType::Time64(TimeUnit::Microsecond), + ParquetTimeUnit::Nanoseconds => DataType::Time64(TimeUnit::Nanosecond), + // MILLIS is only possible for int32. 
Appearing in int64 is a deviation + // to parquet's spec, which we ignore + _ => DataType::Int64, + }, + (Some(Decimal(precision, scale)), _) => DataType::Decimal(precision, scale), + // handle converted types: + (_, Some(PrimitiveConvertedType::TimeMicros)) => DataType::Time64(TimeUnit::Microsecond), + (_, Some(PrimitiveConvertedType::TimestampMillis)) => { + DataType::Timestamp(TimeUnit::Millisecond, None) + }, + (_, Some(PrimitiveConvertedType::TimestampMicros)) => { + DataType::Timestamp(TimeUnit::Microsecond, None) + }, + (_, Some(PrimitiveConvertedType::Int64)) => DataType::Int64, + (_, Some(PrimitiveConvertedType::Uint64)) => DataType::UInt64, + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + + (_, _) => DataType::Int64, + } +} + +fn from_byte_array( + logical_type: &Option, + converted_type: &Option, +) -> DataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::String), _) => DataType::Utf8, + (Some(PrimitiveLogicalType::Json), _) => DataType::Binary, + (Some(PrimitiveLogicalType::Bson), _) => DataType::Binary, + (Some(PrimitiveLogicalType::Enum), _) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Json)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Bson)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Enum)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Utf8)) => DataType::Utf8, + (_, _) => DataType::Binary, + } +} + +fn from_fixed_len_byte_array( + length: usize, + logical_type: Option, + converted_type: Option, +) -> DataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::Decimal(precision, scale)), _) => { + DataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Interval)) => { + // There is currently no reliable way of determining which IntervalUnit + // to return. Thus without the original Arrow schema, the results + // would be incorrect if all 12 bytes of the interval are populated + DataType::Interval(IntervalUnit::DayTime) + }, + _ => DataType::FixedSizeBinary(length), + } +} + +/// Maps a [`PhysicalType`] with optional metadata to a [`DataType`] +fn to_primitive_type_inner( + primitive_type: &PrimitiveType, + options: &SchemaInferenceOptions, +) -> DataType { + match primitive_type.physical_type { + PhysicalType::Boolean => DataType::Boolean, + PhysicalType::Int32 => { + from_int32(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int64 => { + from_int64(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int96 => DataType::Timestamp(options.int96_coerce_to_timeunit, None), + PhysicalType::Float => DataType::Float32, + PhysicalType::Double => DataType::Float64, + PhysicalType::ByteArray => { + from_byte_array(&primitive_type.logical_type, &primitive_type.converted_type) + }, + PhysicalType::FixedLenByteArray(length) => from_fixed_len_byte_array( + length, + primitive_type.logical_type, + primitive_type.converted_type, + ), + } +} + +/// Entry point for converting parquet primitive type to arrow type. +/// +/// This function takes care of repetition. 
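As a concrete illustration of the repetition handling (a sketch mirroring the one-level list case exercised in the tests at the bottom of this file, inside a test returning Result): a repeated primitive column is inferred as a list whose item field reuses the column's own name, while required and optional primitives map to the plain type.

// Sketch, mirroring the `REPEATED INT32 name;` case in `test_parquet_lists` below.
let message = "
    message test_schema {
        REPEATED INT32 name;
    }
";
let parquet_schema = SchemaDescriptor::try_from_message(message)?;
let fields = parquet_to_arrow_schema(parquet_schema.fields());
assert_eq!(
    fields,
    vec![Field::new(
        "name",
        DataType::List(Box::new(Field::new("name", DataType::Int32, false))),
        false,
    )]
);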
+fn to_primitive_type(primitive_type: &PrimitiveType, options: &SchemaInferenceOptions) -> DataType { + let base_type = to_primitive_type_inner(primitive_type, options); + + if primitive_type.field_info.repetition == Repetition::Repeated { + DataType::List(Box::new(Field::new( + &primitive_type.field_info.name, + base_type, + is_nullable(&primitive_type.field_info), + ))) + } else { + base_type + } +} + +fn non_repeated_group( + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + match (logical_type, converted_type) { + (Some(GroupLogicalType::List), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::List)) => to_list(fields, parent_name, options), + (Some(GroupLogicalType::Map), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::Map) | Some(GroupConvertedType::MapKeyValue)) => { + to_map(fields, options) + }, + _ => to_struct(fields, options), + } +} + +/// Converts a parquet group type to an arrow [`DataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_struct(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let fields = fields + .iter() + .filter_map(|f| to_field(f, options)) + .collect::>(); + if fields.is_empty() { + None + } else { + Some(DataType::Struct(fields)) + } +} + +/// Converts a parquet group type to an arrow [`DataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_map(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let inner = to_field(&fields[0], options)?; + Some(DataType::Map(Box::new(inner), false)) +} + +/// Entry point for converting parquet group type. +/// +/// This function takes care of logical type and repetition. +fn to_group_type( + field_info: &FieldInfo, + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + if field_info.repetition == Repetition::Repeated { + Some(DataType::List(Box::new(Field::new( + &field_info.name, + to_struct(fields, options)?, + is_nullable(field_info), + )))) + } else { + non_repeated_group(logical_type, converted_type, fields, parent_name, options) + } +} + +/// Checks whether this schema is nullable. +pub(crate) fn is_nullable(field_info: &FieldInfo) -> bool { + match field_info.repetition { + Repetition::Optional => true, + Repetition::Repeated => true, + Repetition::Required => false, + } +} + +/// Converts parquet schema to arrow field. +/// Returns `None` iff the parquet type has no associated primitive types, +/// i.e. if it is a column-less group type. +fn to_field(type_: &ParquetType, options: &SchemaInferenceOptions) -> Option { + Some(Field::new( + &type_.get_field_info().name, + to_data_type(type_, options)?, + is_nullable(type_.get_field_info()), + )) +} + +/// Converts a parquet list to arrow list. +/// +/// To fully understand this algorithm, please refer to +/// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). +fn to_list( + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + let item = fields.first().unwrap(); + + let item_type = match item { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type_inner(primitive, options)), + ParquetType::GroupType { fields, .. 
} => { + if fields.len() == 1 + && item.name() != "array" + && item.name() != format!("{parent_name}_tuple") + { + // extract the repetition field + let nested_item = fields.first().unwrap(); + to_data_type(nested_item, options) + } else { + to_struct(fields, options) + } + }, + }?; + + // Check that the name of the list child is "list", in which case we + // get the child nullability and name (normally "element") from the nested + // group type. + // Without this step, the child incorrectly inherits the parent's optionality + let (list_item_name, item_is_optional) = match item { + ParquetType::GroupType { + field_info, fields, .. + } if field_info.name == "list" && fields.len() == 1 => { + let field = fields.first().unwrap(); + ( + &field.get_field_info().name, + field.get_field_info().repetition == Repetition::Optional, + ) + }, + _ => ( + &item.get_field_info().name, + item.get_field_info().repetition == Repetition::Optional, + ), + }; + + Some(DataType::List(Box::new(Field::new( + list_item_name, + item_type, + item_is_optional, + )))) +} + +/// Converts parquet schema to arrow data type. +/// +/// This function discards schema name. +/// +/// If this schema is a primitive type and not included in the leaves, the result is +/// Ok(None). +/// +/// If this schema is a group type and none of its children is reserved in the +/// conversion, the result is Ok(None). +pub(crate) fn to_data_type( + type_: &ParquetType, + options: &SchemaInferenceOptions, +) -> Option { + match type_ { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type(primitive, options)), + ParquetType::GroupType { + field_info, + logical_type, + converted_type, + fields, + } => { + if fields.is_empty() { + None + } else { + to_group_type( + field_info, + logical_type, + converted_type, + fields, + &field_info.name, + options, + ) + } + }, + } +} + +#[cfg(test)] +mod tests { + use parquet2::metadata::SchemaDescriptor; + + use super::*; + use crate::datatypes::{DataType, Field, TimeUnit}; + use crate::error::Result; + + #[test] + fn test_flat_primitives() -> Result<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64 ; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + OPTIONAL BINARY string_2 (STRING); + } + "; + let expected = &[ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("int16", DataType::Int16, false), + Field::new("uint8", DataType::UInt8, false), + Field::new("uint16", DataType::UInt16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new("string_2", DataType::Utf8, true), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_byte_array_fields() -> Result<()> { + let message = " + message test_schema { + REQUIRED BYTE_ARRAY binary; + REQUIRED FIXED_LEN_BYTE_ARRAY (20) fixed_binary; + } + "; + let expected = vec![ + Field::new("binary", DataType::Binary, false), + Field::new("fixed_binary", DataType::FixedSizeBinary(20), false), + ]; + 
+ let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_duplicate_fields() -> Result<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + } + "; + let expected = &[ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_parquet_lists() -> Result<()> { + let mut arrow_fields = Vec::new(); + + // LIST encoding example taken from parquet-format/LogicalTypes.md + let message_type = " + message test_schema { + REQUIRED GROUP my_list (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + OPTIONAL GROUP array_of_arrays (LIST) { + REPEATED GROUP list { + REQUIRED GROUP element (LIST) { + REPEATED GROUP list { + REQUIRED INT32 element; + } + } + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED INT32 element; + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + REQUIRED INT32 num; + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP array { + REQUIRED BINARY str (UTF8); + } + + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP my_list_tuple { + REQUIRED BINARY str (UTF8); + } + } + REPEATED INT32 name; + } + "; + + // // List (list non-null, elements nullable) + // required group my_list (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + true, + )); + } + + // Element types can be nested structures. 
For example, a list of lists: + // + // // List> + // optional group array_of_arrays (LIST) { + // repeated group list { + // required group element (LIST) { + // repeated group list { + // required int32 element; + // } + // } + // } + // } + { + let arrow_inner_list = + DataType::List(Box::new(Field::new("element", DataType::Int32, false))); + arrow_fields.push(Field::new( + "array_of_arrays", + DataType::List(Box::new(Field::new("element", arrow_inner_list, false))), + true, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // }; + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + true, + )); + } + + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Int32, false))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + { + let arrow_struct = DataType::Struct(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", arrow_struct, false))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // Special case: group is named array + { + let arrow_struct = DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("array", arrow_struct, false))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // Special case: group named ends in _tuple + { + let arrow_struct = DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("my_list_tuple", arrow_struct, false))), + true, + )); + } + + // One-level encoding: Only allows required lists with required cells + // repeated value_type name + { + arrow_fields.push(Field::new( + "name", + DataType::List(Box::new(Field::new("name", DataType::Int32, false))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_parquet_list_with_struct() -> Result<()> { + let mut arrow_fields = Vec::new(); + + let message_type = " + message eventlog { + REQUIRED group events (LIST) { + REPEATED group array { + REQUIRED BYTE_ARRAY event_name (STRING); + REQUIRED INT64 event_time (TIMESTAMP(MILLIS,true)); + } + } + } + "; + + { + let struct_fields = vec![ + Field::new("event_name", DataType::Utf8, false), + Field::new( + "event_time", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + false, + ), + ]; + arrow_fields.push(Field::new( + "events", + DataType::List(Box::new(Field::new( + "array", + DataType::Struct(struct_fields), + false, + ))), + false, + )); + } + + let parquet_schema = 
SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_parquet_list_nullable() -> Result<()> { + let mut arrow_fields = Vec::new(); + + let message_type = " + message test_schema { + REQUIRED GROUP my_list1 (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list2 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + REQUIRED GROUP my_list3 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + } + "; + + // // List (list non-null, elements nullable) + // required group my_list1 (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list1", + DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list2 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list2", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + true, + )); + } + + // // List (list non-null, elements non-null) + // repeated group my_list3 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list3", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_nested_schema() -> Result<()> { + let mut arrow_fields = Vec::new(); + { + let group1_fields = vec![ + Field::new("leaf1", DataType::Boolean, false), + Field::new("leaf2", DataType::Int32, false), + ]; + let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); + arrow_fields.push(group1_struct); + + let leaf3_field = Field::new("leaf3", DataType::Int64, false); + arrow_fields.push(leaf3_field); + } + + let message_type = " + message test_schema { + REQUIRED GROUP group1 { + REQUIRED BOOLEAN leaf1; + REQUIRED INT32 leaf2; + } + REQUIRED INT64 leaf3; + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_repeated_nested_schema() -> Result<()> { + let mut arrow_fields = Vec::new(); + { + arrow_fields.push(Field::new("leaf1", DataType::Int32, true)); + + let inner_group_list = Field::new( + "innerGroup", + DataType::List(Box::new(Field::new( + "innerGroup", + DataType::Struct(vec![Field::new("leaf3", DataType::Int32, true)]), + false, + ))), + false, + ); + + let outer_group_list = Field::new( + "outerGroup", + DataType::List(Box::new(Field::new( + "outerGroup", + DataType::Struct(vec![ + Field::new("leaf2", DataType::Int32, true), + inner_group_list, + ]), + false, + ))), + false, + ); + arrow_fields.push(outer_group_list); + } + + let message_type = " + message test_schema { + OPTIONAL INT32 leaf1; + REPEATED GROUP outerGroup { + OPTIONAL INT32 leaf2; + REPEATED GROUP innerGroup { + OPTIONAL INT32 leaf3; + } + } + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + 
assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_column_desc_to_field() -> Result<()> { + let message_type = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + REPEATED BOOLEAN bools; + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME_MILLIS); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 time_nano (TIME(NANOS,false)); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP_MICROS); + REQUIRED INT64 ts_nano (TIMESTAMP(NANOS,true)); + } + "; + let arrow_fields = vec![ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("uint8", DataType::UInt8, false), + Field::new("int16", DataType::Int16, false), + Field::new("uint16", DataType::UInt16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new( + "bools", + DataType::List(Box::new(Field::new("bools", DataType::Boolean, false))), + false, + ), + Field::new("date", DataType::Date32, true), + Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nano", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new( + "ts_milli", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "ts_nano", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_string())), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_field_to_column_desc() -> Result<()> { + let message_type = " + message arrow_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INTEGER(16,true)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (STRING); + OPTIONAL GROUP bools (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN element; + } + } + REQUIRED GROUP bools_non_null (LIST) { + REPEATED GROUP list { + REQUIRED BOOLEAN element; + } + } + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME(MILLIS,false)); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP(MICROS,false)); + REQUIRED GROUP struct { + REQUIRED BOOLEAN bools; + REQUIRED INT32 uint32 (INTEGER(32,false)); + REQUIRED GROUP int32 (LIST) { + REPEATED GROUP list { + OPTIONAL INT32 element; + } + } + } + REQUIRED BINARY dictionary_strings (STRING); + } + "; + + let arrow_fields = vec![ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("int16", DataType::Int16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, 
true), + Field::new("string", DataType::Utf8, true), + Field::new( + "bools", + DataType::List(Box::new(Field::new("element", DataType::Boolean, true))), + true, + ), + Field::new( + "bools_non_null", + DataType::List(Box::new(Field::new("element", DataType::Boolean, false))), + false, + ), + Field::new("date", DataType::Date32, true), + Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true), + Field::new( + "ts_milli", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "struct", + DataType::Struct(vec![ + Field::new("bools", DataType::Boolean, false), + Field::new("uint32", DataType::UInt32, false), + Field::new( + "int32", + DataType::List(Box::new(Field::new("element", DataType::Int32, true))), + false, + ), + ]), + false, + ), + Field::new("dictionary_strings", DataType::Utf8, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_int96_options() -> Result<()> { + for tu in [ + TimeUnit::Second, + TimeUnit::Microsecond, + TimeUnit::Millisecond, + TimeUnit::Nanosecond, + ] { + let message_type = " + message arrow_schema { + REQUIRED INT96 int96_field; + OPTIONAL GROUP int96_list (LIST) { + REPEATED GROUP list { + OPTIONAL INT96 element; + } + } + REQUIRED GROUP int96_struct { + REQUIRED INT96 int96_field; + } + } + "; + let coerced_to = DataType::Timestamp(tu, None); + let arrow_fields = vec![ + Field::new("int96_field", coerced_to.clone(), false), + Field::new( + "int96_list", + DataType::List(Box::new(Field::new("element", coerced_to.clone(), true))), + true, + ), + Field::new( + "int96_struct", + DataType::Struct(vec![Field::new("int96_field", coerced_to.clone(), false)]), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema_with_options( + parquet_schema.fields(), + &Some(SchemaInferenceOptions { + int96_coerce_to_timeunit: tu, + }), + ); + assert_eq!(arrow_fields, fields); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs b/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs new file mode 100644 index 000000000000..574ff08d1fd5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs @@ -0,0 +1,55 @@ +use base64::engine::general_purpose; +use base64::Engine as _; +pub use parquet2::metadata::KeyValue; + +use super::super::super::ARROW_SCHEMA_META_KEY; +use crate::datatypes::{Metadata, Schema}; +use crate::error::{Error, Result}; +use crate::io::ipc::read::deserialize_schema; + +/// Reads an arrow schema from Parquet's file metadata. Returns `None` if no schema was found. +/// # Errors +/// Errors iff the schema cannot be correctly parsed. +pub fn read_schema_from_metadata(metadata: &mut Metadata) -> Result> { + metadata + .remove(ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)) + .transpose() +} + +/// Try to convert Arrow schema metadata into a schema +fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result { + let decoded = general_purpose::STANDARD.decode(encoded_meta); + match decoded { + Ok(bytes) => { + let slice = if bytes[0..4] == [255u8; 4] { + &bytes[8..] 
+ } else { + bytes.as_slice() + }; + deserialize_schema(slice).map(|x| x.0) + }, + Err(err) => { + // The C++ implementation returns an error if the schema can't be parsed. + Err(Error::InvalidArgumentError(format!( + "Unable to decode the encoded schema stored in {ARROW_SCHEMA_META_KEY}, {err:?}" + ))) + }, + } +} + +pub(super) fn parse_key_value_metadata(key_value_metadata: &Option>) -> Metadata { + key_value_metadata + .as_ref() + .map(|key_values| { + key_values + .iter() + .filter_map(|kv| { + kv.value + .as_ref() + .map(|value| (kv.key.clone(), value.clone())) + }) + .collect() + }) + .unwrap_or_default() +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/mod.rs b/crates/nano-arrow/src/io/parquet/read/schema/mod.rs new file mode 100644 index 000000000000..8b2394684440 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/mod.rs @@ -0,0 +1,58 @@ +//! APIs to handle Parquet <-> Arrow schemas. +use crate::datatypes::{Schema, TimeUnit}; +use crate::error::Result; + +mod convert; +mod metadata; + +pub(crate) use convert::*; +pub use convert::{parquet_to_arrow_schema, parquet_to_arrow_schema_with_options}; +pub use metadata::read_schema_from_metadata; +pub use parquet2::metadata::{FileMetaData, KeyValue, SchemaDescriptor}; +pub use parquet2::schema::types::ParquetType; + +use self::metadata::parse_key_value_metadata; + +/// Options when inferring schemas from Parquet +pub struct SchemaInferenceOptions { + /// When inferring schemas from the Parquet INT96 timestamp type, this is the corresponding TimeUnit + /// in the inferred Arrow Timestamp type. + /// + /// This defaults to `TimeUnit::Nanosecond`, but INT96 timestamps outside of the range of years 1678-2262 + /// will overflow when parsed as `Timestamp(TimeUnit::Nanosecond)`. Setting this to a lower resolution + /// (e.g. `TimeUnit::Millisecond`) will result in loss of precision, but will support a larger range of dates + /// without overflowing when parsing the data. + pub int96_coerce_to_timeunit: TimeUnit, +} + +impl Default for SchemaInferenceOptions { + fn default() -> Self { + SchemaInferenceOptions { + int96_coerce_to_timeunit: TimeUnit::Nanosecond, + } + } +} + +/// Infers a [`Schema`] from parquet's [`FileMetaData`]. This first looks for the metadata key +/// `"ARROW:schema"`; if it does not exist, it converts the parquet types declared in the +/// file's parquet schema to Arrow's equivalent. +/// # Error +/// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, +/// indicating that the file's arrow metadata was incorrectly written. 
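+///
+/// A minimal usage sketch (illustrative only; it assumes this crate's
+/// `io::parquet::read::read_metadata` for obtaining the file metadata, and the file path
+/// below is a placeholder):
+///
+/// ```ignore
+/// use std::fs::File;
+///
+/// let mut reader = File::open("example.parquet")?;
+/// let file_metadata = read_metadata(&mut reader)?;
+/// let schema = infer_schema(&file_metadata)?;
+/// println!("inferred {} field(s)", schema.fields.len());
+/// ```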
+pub fn infer_schema(file_metadata: &FileMetaData) -> Result { + infer_schema_with_options(file_metadata, &None) +} + +/// Like [`infer_schema`] but with configurable options which affects the behavior of inference +pub fn infer_schema_with_options( + file_metadata: &FileMetaData, + options: &Option, +) -> Result { + let mut metadata = parse_key_value_metadata(file_metadata.key_value_metadata()); + + let schema = read_schema_from_metadata(&mut metadata)?; + Ok(schema.unwrap_or_else(|| { + let fields = parquet_to_arrow_schema_with_options(file_metadata.schema().fields(), options); + Schema { fields, metadata } + })) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs b/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs new file mode 100644 index 000000000000..aeb43a6b3e0b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs @@ -0,0 +1,24 @@ +use parquet2::statistics::{BinaryStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableBinaryArray}; +use crate::error::Result; +use crate::offset::Offset; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value.as_ref())); + max.push(from.and_then(|s| s.max_value.as_ref())); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs b/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs new file mode 100644 index 000000000000..ebb0ce3dade2 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs @@ -0,0 +1,23 @@ +use parquet2::statistics::{BooleanStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableBooleanArray}; +use crate::error::Result; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value)); + max.push(from.and_then(|s| s.max_value)); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs new file mode 100644 index 000000000000..f6e2fdddcce9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs @@ -0,0 +1,69 @@ +use super::make_mutable; +use crate::array::*; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Result; + +#[derive(Debug)] +pub struct DynMutableDictionary { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableDictionary { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = if let DataType::Dictionary(_, inner, _) = &data_type { + inner.as_ref() + } else { + unreachable!() + }; + let inner = make_mutable(inner, capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableDictionary { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + self.inner.validity() + } + + fn as_box(&mut 
self) -> Box { + let inner = self.inner.as_box(); + match self.data_type.to_physical_type() { + PhysicalType::Dictionary(key) => match_integer_type!(key, |$T| { + let keys: Vec<$T> = (0..inner.len() as $T).collect(); + let keys = PrimitiveArray::<$T>::from_vec(keys); + Box::new(DictionaryArray::<$T>::try_new(self.data_type.clone(), keys, inner).unwrap()) + }), + _ => todo!(), + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs b/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs new file mode 100644 index 000000000000..1f9db20d9c9a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs @@ -0,0 +1,146 @@ +use ethnum::I256; +use parquet2::statistics::{FixedLenStatistics, Statistics as ParquetStatistics}; + +use super::super::{convert_days_ms, convert_i128}; +use crate::array::*; +use crate::error::Result; +use crate::io::parquet::read::convert_i256; +use crate::types::{days_ms, i256}; + +pub(super) fn push_i128( + from: Option<&dyn ParquetStatistics>, + n: usize, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(|x| convert_i128(x, n)))); + max.push(from.and_then(|s| s.max_value.as_deref().map(|x| convert_i128(x, n)))); + + Ok(()) +} + +pub(super) fn push_i256_with_i128( + from: Option<&dyn ParquetStatistics>, + n: usize, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| { + s.min_value + .as_deref() + .map(|x| i256(I256::new(convert_i128(x, n)))) + })); + max.push(from.and_then(|s| { + s.max_value + .as_deref() + .map(|x| i256(I256::new(convert_i128(x, n)))) + })); + + Ok(()) +} + +pub(super) fn push_i256( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_i256))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_i256))); + + Ok(()) +} + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value.as_ref())); + max.push(from.and_then(|s| s.max_value.as_ref())); + Ok(()) +} + +fn convert_year_month(value: &[u8]) -> i32 { + i32::from_le_bytes(value[..4].try_into().unwrap()) +} + +pub(super) fn push_year_month( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: 
&mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_year_month))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_year_month))); + + Ok(()) +} + +pub(super) fn push_days_ms( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_days_ms))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_days_ms))); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/list.rs b/crates/nano-arrow/src/io/parquet/read/statistics/list.rs new file mode 100644 index 000000000000..cb22cbf7063a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/list.rs @@ -0,0 +1,85 @@ +use super::make_mutable; +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::offset::Offsets; + +#[derive(Debug)] +pub struct DynMutableListArray { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableListArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = match data_type.to_logical_type() { + DataType::List(inner) | DataType::LargeList(inner) => inner.data_type(), + _ => unreachable!(), + }; + let inner = make_mutable(inner, capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableListArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + self.inner.validity() + } + + fn as_box(&mut self) -> Box { + let inner = self.inner.as_box(); + + match self.data_type.to_logical_type() { + DataType::List(_) => { + let offsets = + Offsets::try_from_lengths(std::iter::repeat(1).take(inner.len())).unwrap(); + Box::new(ListArray::::new( + self.data_type.clone(), + offsets.into(), + inner, + None, + )) + }, + DataType::LargeList(_) => { + let offsets = + Offsets::try_from_lengths(std::iter::repeat(1).take(inner.len())).unwrap(); + Box::new(ListArray::::new( + self.data_type.clone(), + offsets.into(), + inner, + None, + )) + }, + _ => unreachable!(), + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/map.rs b/crates/nano-arrow/src/io/parquet/read/statistics/map.rs new file mode 100644 index 000000000000..d6b2a73388f5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/map.rs @@ -0,0 +1,65 @@ +use super::make_mutable; +use crate::array::{Array, MapArray, MutableArray}; +use crate::datatypes::DataType; +use crate::error::Error; + +#[derive(Debug)] +pub struct DynMutableMapArray { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableMapArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = match 
data_type.to_logical_type() { + DataType::Map(inner, _) => inner, + _ => unreachable!(), + }; + let inner = make_mutable(inner.data_type(), capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableMapArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + Box::new(MapArray::new( + self.data_type.clone(), + vec![0, self.inner.len() as i32].try_into().unwrap(), + self.inner.as_box(), + None, + )) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs b/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs new file mode 100644 index 000000000000..3048952530a6 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs @@ -0,0 +1,577 @@ +//! APIs exposing `parquet2`'s statistics as arrow's statistics. +use std::collections::VecDeque; +use std::sync::Arc; + +use ethnum::I256; +use parquet2::metadata::RowGroupMetaData; +use parquet2::schema::types::{ + PhysicalType as ParquetPhysicalType, PrimitiveType as ParquetPrimitiveType, +}; +use parquet2::statistics::{ + BinaryStatistics, BooleanStatistics, FixedLenStatistics, PrimitiveStatistics, + Statistics as ParquetStatistics, +}; +use parquet2::types::int96_to_i64_ns; + +use crate::array::*; +use crate::datatypes::{DataType, Field, IntervalUnit, PhysicalType}; +use crate::error::{Error, Result}; +use crate::types::i256; + +mod binary; +mod boolean; +mod dictionary; +mod fixlen; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod utf8; + +use self::list::DynMutableListArray; +use super::get_field_columns; + +/// Arrow-deserialized parquet Statistics of a file +#[derive(Debug, PartialEq)] +pub struct Statistics { + /// number of nulls. This is a [`UInt64Array`] for non-nested types + pub null_count: Box, + /// number of distinct values. 
This is a [`UInt64Array`] for non-nested types + pub distinct_count: Box, + /// Minimum + pub min_value: Box, + /// Maximum + pub max_value: Box, +} + +/// Arrow-deserialized parquet Statistics of a file +#[derive(Debug)] +struct MutableStatistics { + /// number of nulls + pub null_count: Box, + /// number of dictinct values + pub distinct_count: Box, + /// Minimum + pub min_value: Box, + /// Maximum + pub max_value: Box, +} + +impl From for Statistics { + fn from(mut s: MutableStatistics) -> Self { + let null_count = if let PhysicalType::Struct = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::Map = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::List = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::LargeList = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + }; + let distinct_count = if let PhysicalType::Struct = + s.distinct_count.data_type().to_physical_type() + { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::Map = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::List = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::LargeList = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + }; + Self { + null_count, + distinct_count, + min_value: s.min_value.as_box(), + max_value: s.max_value.as_box(), + } + } +} + +fn make_mutable(data_type: &DataType, capacity: usize) -> Result> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => { + Box::new(MutableBooleanArray::with_capacity(capacity)) as Box + }, + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }), + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::LargeBinary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::Utf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::LargeUtf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::FixedSizeBinary => { + Box::new(MutableFixedSizeBinaryArray::try_new(data_type.clone(), vec![], None).unwrap()) + as _ + }, + PhysicalType::LargeList | PhysicalType::List => Box::new( + DynMutableListArray::try_with_capacity(data_type.clone(), capacity)?, + ) as Box, + PhysicalType::Dictionary(_) => Box::new( + dictionary::DynMutableDictionary::try_with_capacity(data_type.clone(), capacity)?, + ), + 
PhysicalType::Struct => Box::new(struct_::DynMutableStructArray::try_with_capacity( + data_type.clone(), + capacity, + )?), + PhysicalType::Map => Box::new(map::DynMutableMapArray::try_with_capacity( + data_type.clone(), + capacity, + )?), + PhysicalType::Null => { + Box::new(MutableNullArray::new(DataType::Null, 0)) as Box + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Deserializing parquet stats from {other:?} is still not implemented" + ))) + }, + }) +} + +fn create_dt(data_type: &DataType) -> DataType { + if let DataType::Struct(fields) = data_type.to_logical_type() { + DataType::Struct( + fields + .iter() + .map(|f| Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)) + .collect(), + ) + } else if let DataType::Map(f, ordered) = data_type.to_logical_type() { + DataType::Map( + Box::new(Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)), + *ordered, + ) + } else if let DataType::List(f) = data_type.to_logical_type() { + DataType::List(Box::new(Field::new( + &f.name, + create_dt(&f.data_type), + f.is_nullable, + ))) + } else if let DataType::LargeList(f) = data_type.to_logical_type() { + DataType::LargeList(Box::new(Field::new( + &f.name, + create_dt(&f.data_type), + f.is_nullable, + ))) + } else { + DataType::UInt64 + } +} + +impl MutableStatistics { + fn try_new(field: &Field) -> Result { + let min_value = make_mutable(&field.data_type, 0)?; + let max_value = make_mutable(&field.data_type, 0)?; + + let dt = create_dt(&field.data_type); + Ok(Self { + null_count: make_mutable(&dt, 0)?, + distinct_count: make_mutable(&dt, 0)?, + min_value, + max_value, + }) + } +} + +fn push_others( + from: Option<&dyn ParquetStatistics>, + distinct_count: &mut UInt64Vec, + null_count: &mut UInt64Vec, +) { + let from = if let Some(from) = from { + from + } else { + distinct_count.push(None); + null_count.push(None); + return; + }; + let (distinct, null_count1) = match from.physical_type() { + ParquetPhysicalType::Boolean => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int32 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int64 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int96 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Float => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Double => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::ByteArray => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::FixedLenByteArray(_) => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + }; + + distinct_count.push(distinct.map(|x| x as u64)); + null_count.push(null_count1.map(|x| x as u64)); +} + +fn push( + stats: &mut VecDeque<(Option>, ParquetPrimitiveType)>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, + distinct_count: &mut dyn MutableArray, + null_count: &mut dyn MutableArray, +) -> Result<()> { + match min.data_type().to_logical_type() { + List(_) | LargeList(_) => { + let min = min + .as_mut_any() + .downcast_mut::() 
+ .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count.inner.as_mut(), + null_count.inner.as_mut(), + ); + }, + Dictionary(_, _, _) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count, + null_count, + ); + }, + Struct(_) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return min + .inner + .iter_mut() + .zip(max.inner.iter_mut()) + .zip(distinct_count.inner.iter_mut()) + .zip(null_count.inner.iter_mut()) + .try_for_each(|(((min, max), distinct_count), null_count)| { + push( + stats, + min.as_mut(), + max.as_mut(), + distinct_count.as_mut(), + null_count.as_mut(), + ) + }); + }, + Map(_, _) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count.inner.as_mut(), + null_count.inner.as_mut(), + ); + }, + _ => {}, + } + + let (from, type_) = stats.pop_front().unwrap(); + let from = from.as_deref(); + + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count.as_mut_any().downcast_mut::().unwrap(); + + push_others(from, distinct_count, null_count); + + let physical_type = &type_.physical_type; + + use DataType::*; + match min.data_type().to_logical_type() { + Boolean => boolean::push(from, min, max), + Int8 => primitive::push(from, min, max, |x: i32| Ok(x as i8)), + Int16 => primitive::push(from, min, max, |x: i32| Ok(x as i16)), + Date32 | Time32(_) => primitive::push::(from, min, max, Ok), + Interval(IntervalUnit::YearMonth) => fixlen::push_year_month(from, min, max), + Interval(IntervalUnit::DayTime) => fixlen::push_days_ms(from, min, max), + UInt8 => primitive::push(from, min, max, |x: i32| Ok(x as u8)), + UInt16 => primitive::push(from, min, max, |x: i32| Ok(x as u16)), + UInt32 => match physical_type { + // some implementations of parquet write arrow's u32 into i64. + ParquetPhysicalType::Int64 => primitive::push(from, min, max, |x: i64| Ok(x as u32)), + ParquetPhysicalType::Int32 => primitive::push(from, min, max, |x: i32| Ok(x as u32)), + other => Err(Error::NotYetImplemented(format!( + "Can't decode UInt32 type from parquet type {other:?}" + ))), + }, + Int32 => primitive::push::(from, min, max, Ok), + Date64 => match physical_type { + ParquetPhysicalType::Int64 => primitive::push::(from, min, max, Ok), + // some implementations of parquet write arrow's date64 into i32. 
+ ParquetPhysicalType::Int32 => { + primitive::push(from, min, max, |x: i32| Ok(x as i64 * 86400000)) + }, + other => Err(Error::NotYetImplemented(format!( + "Can't decode Date64 type from parquet type {other:?}" + ))), + }, + Int64 | Time64(_) | Duration(_) => primitive::push::(from, min, max, Ok), + UInt64 => primitive::push(from, min, max, |x: i64| Ok(x as u64)), + Timestamp(time_unit, _) => { + let time_unit = *time_unit; + if physical_type == &ParquetPhysicalType::Int96 { + let from = from.map(|from| { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + PrimitiveStatistics:: { + primitive_type: from.primitive_type.clone(), + null_count: from.null_count, + distinct_count: from.distinct_count, + min_value: from.min_value.map(int96_to_i64_ns), + max_value: from.max_value.map(int96_to_i64_ns), + } + }); + primitive::push( + from.as_ref().map(|x| x as &dyn ParquetStatistics), + min, + max, + |x: i64| { + Ok(primitive::timestamp( + type_.logical_type.as_ref(), + time_unit, + x, + )) + }, + ) + } else { + primitive::push(from, min, max, |x: i64| { + Ok(primitive::timestamp( + type_.logical_type.as_ref(), + time_unit, + x, + )) + }) + } + }, + Float32 => primitive::push::(from, min, max, Ok), + Float64 => primitive::push::(from, min, max, Ok), + Decimal(_, _) => match physical_type { + ParquetPhysicalType::Int32 => primitive::push(from, min, max, |x: i32| Ok(x as i128)), + ParquetPhysicalType::Int64 => primitive::push(from, min, max, |x: i64| Ok(x as i128)), + ParquetPhysicalType::FixedLenByteArray(n) if *n > 16 => Err(Error::NotYetImplemented( + format!("Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}"), + )), + ParquetPhysicalType::FixedLenByteArray(n) => fixlen::push_i128(from, *n, min, max), + _ => unreachable!(), + }, + Decimal256(_, _) => match physical_type { + ParquetPhysicalType::Int32 => { + primitive::push(from, min, max, |x: i32| Ok(i256(I256::new(x.into())))) + }, + ParquetPhysicalType::Int64 => { + primitive::push(from, min, max, |x: i64| Ok(i256(I256::new(x.into())))) + }, + ParquetPhysicalType::FixedLenByteArray(n) if *n <= 16 => { + fixlen::push_i256_with_i128(from, *n, min, max) + }, + ParquetPhysicalType::FixedLenByteArray(n) if *n > 32 => Err(Error::NotYetImplemented( + format!("Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}"), + )), + ParquetPhysicalType::FixedLenByteArray(_) => fixlen::push_i256(from, min, max), + _ => unreachable!(), + }, + Binary => binary::push::(from, min, max), + LargeBinary => binary::push::(from, min, max), + Utf8 => utf8::push::(from, min, max), + LargeUtf8 => utf8::push::(from, min, max), + FixedSizeBinary(_) => fixlen::push(from, min, max), + Null => null::push(min, max), + other => todo!("{:?}", other), + } +} + +/// Deserializes the statistics in the column chunks from all `row_groups` +/// into [`Statistics`] associated from `field`'s name. +/// +/// # Errors +/// This function errors if the deserialization of the statistics fails (e.g. 
invalid utf8) +pub fn deserialize(field: &Field, row_groups: &[RowGroupMetaData]) -> Result { + let mut statistics = MutableStatistics::try_new(field)?; + + // transpose + row_groups.iter().try_for_each(|group| { + let columns = get_field_columns(group.columns(), field.name.as_ref()); + let mut stats = columns + .into_iter() + .map(|column| { + Ok(( + column.statistics().transpose()?, + column.descriptor().descriptor.primitive_type.clone(), + )) + }) + .collect::, ParquetPrimitiveType)>>>()?; + push( + &mut stats, + statistics.min_value.as_mut(), + statistics.max_value.as_mut(), + statistics.distinct_count.as_mut(), + statistics.null_count.as_mut(), + ) + })?; + + Ok(statistics.into()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/null.rs b/crates/nano-arrow/src/io/parquet/read/statistics/null.rs new file mode 100644 index 000000000000..9102720ebc5c --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/null.rs @@ -0,0 +1,11 @@ +use crate::array::*; +use crate::error::Result; + +pub(super) fn push(min: &mut dyn MutableArray, max: &mut dyn MutableArray) -> Result<()> { + let min = min.as_mut_any().downcast_mut::().unwrap(); + let max = max.as_mut_any().downcast_mut::().unwrap(); + min.push_null(); + max.push_null(); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs b/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs new file mode 100644 index 000000000000..849363028ad1 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs @@ -0,0 +1,55 @@ +use parquet2::schema::types::{PrimitiveLogicalType, TimeUnit as ParquetTimeUnit}; +use parquet2::statistics::{PrimitiveStatistics, Statistics as ParquetStatistics}; +use parquet2::types::NativeType as ParquetNativeType; + +use crate::array::*; +use crate::datatypes::TimeUnit; +use crate::error::Result; +use crate::types::NativeType; + +pub fn timestamp(logical_type: Option<&PrimitiveLogicalType>, time_unit: TimeUnit, x: i64) -> i64 { + let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { + unit + } else { + return x; + }; + + match (unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => x / 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Second) => x / 1_000_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => x * 1_000_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => x, + (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => x / 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => x / 1_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => x * 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => x, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => x / 1_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => x * 1_000_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => x * 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => x, + } +} + +pub(super) fn push Result + Copy>( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, + map: F, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::>().unwrap()); + min.push(from.and_then(|s| s.min_value.map(map)).transpose()?); + max.push(from.and_then(|s| s.max_value.map(map)).transpose()?); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs b/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs new file mode 100644 index 000000000000..6aca0352701e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs @@ -0,0 +1,64 @@ +use super::make_mutable; +use crate::array::{Array, MutableArray, StructArray}; +use crate::datatypes::DataType; +use crate::error::Result; + +#[derive(Debug)] +pub struct DynMutableStructArray { + data_type: DataType, + pub inner: Vec>, +} + +impl DynMutableStructArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inners = match data_type.to_logical_type() { + DataType::Struct(inner) => inner, + _ => unreachable!(), + }; + let inner = inners + .iter() + .map(|f| make_mutable(f.data_type(), capacity)) + .collect::>>()?; + + Ok(Self { data_type, inner }) + } +} +impl MutableArray for DynMutableStructArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner[0].len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let inner = self.inner.iter_mut().map(|x| x.as_box()).collect(); + + Box::new(StructArray::new(self.data_type.clone(), inner, None)) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs b/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs new file mode 100644 index 000000000000..da9fcb6e1119 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs @@ -0,0 +1,31 @@ +use parquet2::statistics::{BinaryStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableUtf8Array}; +use crate::error::Result; +use crate::offset::Offset; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut 
dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push( + from.and_then(|s| s.min_value.as_deref().map(simdutf8::basic::from_utf8)) + .transpose()?, + ); + max.push( + from.and_then(|s| s.max_value.as_deref().map(simdutf8::basic::from_utf8)) + .transpose()?, + ); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/write/binary/basic.rs b/crates/nano-arrow/src/io/parquet/write/binary/basic.rs new file mode 100644 index 000000000000..de840e45fa5a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/basic.rs @@ -0,0 +1,168 @@ +use parquet2::encoding::{delta_bitpacked, Encoding}; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, BinaryStatistics, ParquetStatistics, Statistics}; + +use super::super::{utils, WriteOptions}; +use crate::array::{Array, BinaryArray}; +use crate::bitmap::Bitmap; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::Offset; + +pub(crate) fn encode_plain( + array: &BinaryArray, + is_optional: bool, + buffer: &mut Vec, +) { + // append the non-null values + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x); + } + }) + } else { + array.values_iter().for_each(|x| { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x); + }) + } +} + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> Result { + let validity = array.validity(); + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, is_optional, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta( + array.values(), + array.offsets().buffer(), + array.validity(), + is_optional, + &mut buffer, + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Datatype {:?} cannot be encoded by {:?} encoding", + array.data_type(), + encoding + ))) + }, + } + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub(crate) fn build_statistics( + array: &BinaryArray, + primitive_type: PrimitiveType, +) -> ParquetStatistics { + let statistics = &BinaryStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } as &dyn Statistics; + serialize_statistics(statistics) +} + +pub(crate) fn encode_delta( + values: &[u8], + offsets: &[O], + validity: 
Option<&Bitmap>, + is_optional: bool, + buffer: &mut Vec, +) { + if is_optional { + if let Some(validity) = validity { + let lengths = offsets + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64) + .zip(validity.iter()) + .flat_map(|(x, is_valid)| if is_valid { Some(x) } else { None }); + let length = offsets.len() - 1 - validity.unset_bits(); + let lengths = utils::ExactSizedIter::new(lengths, length); + + delta_bitpacked::encode(lengths, buffer); + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer); + } + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer); + } + + buffer.extend_from_slice( + &values[offsets.first().unwrap().to_usize()..offsets.last().unwrap().to_usize()], + ) +} + +/// Returns the ordering of two binary values. This corresponds to pyarrows' ordering +/// of statistics. +pub(crate) fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { + use std::cmp::Ordering::*; + match (a.is_empty(), b.is_empty()) { + (true, true) => return Equal, + (true, false) => return Less, + (false, true) => return Greater, + (false, false) => {}, + } + + for (v1, v2) in a.iter().zip(b.iter()) { + match v1.cmp(v2) { + Equal => continue, + other => return other, + } + } + Equal +} diff --git a/crates/nano-arrow/src/io/parquet/write/binary/mod.rs b/crates/nano-arrow/src/io/parquet/write/binary/mod.rs new file mode 100644 index 000000000000..e942b4b69103 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/mod.rs @@ -0,0 +1,7 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub(crate) use basic::{build_statistics, encode_plain}; +pub(super) use basic::{encode_delta, ord_binary}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/binary/nested.rs b/crates/nano-arrow/src/io/parquet/write/binary/nested.rs new file mode 100644 index 000000000000..11de9d9676a7 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/nested.rs @@ -0,0 +1,48 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, BinaryArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::offset::Offset; + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + O: Offset, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer); + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs b/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs new file mode 100644 index 000000000000..833bfab09e5a --- /dev/null +++ 
b/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs @@ -0,0 +1,92 @@ +use parquet2::encoding::hybrid_rle::bitpacked_encode; +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{ + serialize_statistics, BooleanStatistics, ParquetStatistics, Statistics, +}; + +use super::super::{utils, WriteOptions}; +use crate::array::*; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; + +fn encode(iterator: impl Iterator, buffer: &mut Vec) -> Result<()> { + // encode values using bitpacking + let len = buffer.len(); + let mut buffer = std::io::Cursor::new(buffer); + buffer.set_position(len as u64); + Ok(bitpacked_encode(&mut buffer, iterator)?) +} + +pub(super) fn encode_plain( + array: &BooleanArray, + is_optional: bool, + buffer: &mut Vec, +) -> Result<()> { + if is_optional { + let iter = array.iter().flatten().take( + array + .validity() + .as_ref() + .map(|x| x.len() - x.unset_bits()) + .unwrap_or_else(|| array.len()), + ); + encode(iter, buffer) + } else { + let iter = array.values().iter(); + encode(iter, buffer) + } +} + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, +) -> Result { + let is_optional = is_nullable(&type_.field_info); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, is_optional, &mut buffer)?; + + let statistics = if options.write_statistics { + Some(build_statistics(array)) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} + +pub(super) fn build_statistics(array: &BooleanArray) -> ParquetStatistics { + let statistics = &BooleanStatistics { + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array.iter().flatten().max(), + min_value: array.iter().flatten().min(), + } as &dyn Statistics; + serialize_statistics(statistics) +} diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs b/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs new file mode 100644 index 000000000000..280e2ff9efb5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs @@ -0,0 +1,5 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs b/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs new file mode 100644 index 000000000000..656019100825 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs @@ -0,0 +1,44 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, BooleanArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result { + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + 
nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer)?; + + let statistics = if options.write_statistics { + Some(build_statistics(array)) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/dictionary.rs b/crates/nano-arrow/src/io/parquet/write/dictionary.rs new file mode 100644 index 000000000000..4ee0a5c37eac --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/dictionary.rs @@ -0,0 +1,281 @@ +use parquet2::encoding::hybrid_rle::encode_u32; +use parquet2::encoding::Encoding; +use parquet2::page::{DictPage, Page}; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, ParquetStatistics}; +use parquet2::write::DynIter; + +use super::binary::{ + build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, +}; +use super::fixed_len_bytes::{ + build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain, +}; +use super::primitive::{ + build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain, +}; +use super::utf8::{build_statistics as utf8_build_statistics, encode_plain as utf8_encode_plain}; +use super::{nested, Nested, WriteOptions}; +use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::{slice_nested_leaf, utils}; + +fn serialize_def_levels_simple( + validity: Option<&Bitmap>, + length: usize, + is_optional: bool, + options: WriteOptions, + buffer: &mut Vec, +) -> Result<()> { + utils::write_def_levels(buffer, is_optional, validity, length, options.version) +} + +fn serialize_keys_values( + array: &DictionaryArray, + validity: Option<&Bitmap>, + buffer: &mut Vec, +) -> Result<()> { + let keys = array.keys_values_iter().map(|x| x as u32); + if let Some(validity) = validity { + // discard indices whose values are null. + let keys = keys + .zip(validity.iter()) + .filter(|&(_key, is_valid)| is_valid) + .map(|(key, _is_valid)| key); + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits()); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode_u32(buffer, keys, num_bits)?) + } else { + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode_u32(buffer, keys, num_bits)?) 
+ } +} + +fn serialize_levels( + validity: Option<&Bitmap>, + length: usize, + type_: &PrimitiveType, + nested: &[Nested], + options: WriteOptions, + buffer: &mut Vec, +) -> Result<(usize, usize)> { + if nested.len() == 1 { + let is_optional = is_nullable(&type_.field_info); + serialize_def_levels_simple(validity, length, is_optional, options, buffer)?; + let definition_levels_byte_length = buffer.len(); + Ok((0, definition_levels_byte_length)) + } else { + nested::write_rep_and_def(options.version, nested, buffer) + } +} + +fn normalized_validity(array: &DictionaryArray) -> Option { + match (array.keys().validity(), array.values().validity()) { + (None, None) => None, + (None, rhs) => rhs.cloned(), + (lhs, None) => lhs.cloned(), + (Some(_), Some(rhs)) => { + let projected_validity = array + .keys_iter() + .map(|x| x.map(|x| rhs.get_bit(x)).unwrap_or(false)); + MutableBitmap::from_trusted_len_iter(projected_validity).into() + }, + } +} + +fn serialize_keys( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, + options: WriteOptions, +) -> Result { + let mut buffer = vec![]; + + // parquet only accepts a single validity - we "&" the validities into a single one + // and ignore keys whole _value_ is null. + let validity = normalized_validity(array); + let (start, len) = slice_nested_leaf(nested); + + let mut nested = nested.to_vec(); + let array = array.clone().sliced(start, len); + if let Some(Nested::Primitive(_, _, c)) = nested.last_mut() { + *c = len; + } else { + unreachable!("") + } + + let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels( + validity.as_ref(), + array.len(), + &type_, + &nested, + options, + &mut buffer, + )?; + + serialize_keys_values(&array, validity.as_ref(), &mut buffer)?; + + let (num_values, num_rows) = if nested.len() == 1 { + (array.len(), array.len()) + } else { + (nested::num_values(&nested), nested[0].len()) + }; + + utils::build_plain_page( + buffer, + num_values, + num_rows, + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::RleDictionary, + ) + .map(Page::Data) +} + +macro_rules! 
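// Editorial note: parquet accepts a single validity per leaf column, so
// normalized_validity above folds the keys' validity and the dictionary values'
// validity into one bitmap over the keys - a valid key that points at a null
// dictionary slot is treated as null, and serialize_keys_values then skips that
// index entirely when bit-packing.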
dyn_prim { + ($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{ + let values = $array.values().as_any().downcast_ref().unwrap(); + + let buffer = primitive_encode_plain::<$from, $to>(values, false, vec![]); + + let stats: Option = if $options.write_statistics { + let mut stats = primitive_build_statistics::<$from, $to>(values, $type_.clone()); + stats.null_count = Some($array.null_count() as i64); + let stats = serialize_statistics(&stats); + Some(stats) + } else { + None + }; + (DictPage::new(buffer, values.len(), false), stats) + }}; +} + +pub fn array_to_pages( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result>> { + match encoding { + Encoding::PlainDictionary | Encoding::RleDictionary => { + // write DictPage + let (dict_page, statistics): (_, Option) = + match array.values().data_type().to_logical_type() { + DataType::Int8 => dyn_prim!(i8, i32, array, options, type_), + DataType::Int16 => dyn_prim!(i16, i32, array, options, type_), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + dyn_prim!(i32, i32, array, options, type_) + }, + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_), + DataType::UInt8 => dyn_prim!(u8, i32, array, options, type_), + DataType::UInt16 => dyn_prim!(u16, i32, array, options, type_), + DataType::UInt32 => dyn_prim!(u32, i32, array, options, type_), + DataType::UInt64 => dyn_prim!(u64, i64, array, options, type_), + DataType::Float32 => dyn_prim!(f32, f32, array, options, type_), + DataType::Float64 => dyn_prim!(f64, f64, array, options, type_), + DataType::Utf8 => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + utf8_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(utf8_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::LargeUtf8 => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + utf8_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(utf8_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::Binary => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(binary_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::LargeBinary => { + let values = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(values, false, &mut buffer); + let stats = if options.write_statistics { + let mut stats = binary_build_statistics(values, type_.clone()); + stats.null_count = Some(array.null_count() as i64); + Some(stats) + } else { + None + }; + (DictPage::new(buffer, values.len(), false), stats) + }, + DataType::FixedSizeBinary(_) => { + let mut buffer = vec![]; + let array = array.values().as_any().downcast_ref().unwrap(); + fixed_binary_encode_plain(array, false, &mut buffer); + let stats = if options.write_statistics { + let mut stats = fixed_binary_build_statistics(array, type_.clone()); + stats.null_count = Some(array.null_count() as i64); + 
Some(serialize_statistics(&stats)) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Writing dictionary arrays to parquet only support data type {other:?}" + ))) + }, + }; + let dict_page = Page::Dict(dict_page); + + // write DataPage pointing to DictPage + let data_page = serialize_keys(array, type_, nested, statistics, options)?; + + let iter = std::iter::once(Ok(dict_page)).chain(std::iter::once(Ok(data_page))); + Ok(DynIter::new(Box::new(iter))) + }, + _ => Err(Error::NotYetImplemented( + "Dictionary arrays only support dictionary encoding".to_string(), + )), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/file.rs b/crates/nano-arrow/src/io/parquet/write/file.rs new file mode 100644 index 000000000000..4ec37b941ad9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/file.rs @@ -0,0 +1,95 @@ +use std::io::Write; + +use parquet2::metadata::{KeyValue, SchemaDescriptor}; +use parquet2::write::{RowGroupIter, WriteOptions as FileWriteOptions}; + +use super::schema::schema_to_metadata_key; +use super::{to_parquet_schema, ThriftFileMetaData, WriteOptions}; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; + +/// Attaches [`Schema`] to `key_value_metadata` +pub fn add_arrow_schema( + schema: &Schema, + key_value_metadata: Option>, +) -> Option> { + key_value_metadata + .map(|mut x| { + x.push(schema_to_metadata_key(schema)); + x + }) + .or_else(|| Some(vec![schema_to_metadata_key(schema)])) +} + +/// An interface to write a parquet to a [`Write`] +pub struct FileWriter { + writer: parquet2::write::FileWriter, + schema: Schema, + options: WriteOptions, +} + +// Accessors +impl FileWriter { + /// The options assigned to the file + pub fn options(&self) -> WriteOptions { + self.options + } + + /// The [`SchemaDescriptor`] assigned to this file + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.writer.schema() + } + + /// The [`Schema`] assigned to this file + pub fn schema(&self) -> &Schema { + &self.schema + } +} + +impl FileWriter { + /// Returns a new [`FileWriter`]. + /// # Error + /// If it is unable to derive a parquet schema from [`Schema`]. + pub fn try_new(writer: W, schema: Schema, options: WriteOptions) -> Result { + let parquet_schema = to_parquet_schema(&schema)?; + + let created_by = Some("Arrow2 - Native Rust implementation of Arrow".to_string()); + + Ok(Self { + writer: parquet2::write::FileWriter::new( + writer, + parquet_schema, + FileWriteOptions { + version: options.version, + write_statistics: options.write_statistics, + }, + created_by, + ), + schema, + options, + }) + } + + /// Writes a row group to the file. + pub fn write(&mut self, row_group: RowGroupIter<'_, Error>) -> Result<()> { + Ok(self.writer.write(row_group)?) + } + + /// Writes the footer of the parquet file. Returns the total size of the file. + pub fn end(&mut self, key_value_metadata: Option>) -> Result { + let key_value_metadata = add_arrow_schema(&self.schema, key_value_metadata); + Ok(self.writer.end(key_value_metadata)?) 
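// Editorial note: a minimal usage sketch (identifiers such as `file`, `options` and
// `row_groups` are assumed to be built elsewhere, e.g. with a RowGroupIterator):
//
//     let schema = Schema::from(vec![Field::new("a", DataType::Int32, false)]);
//     let mut writer = FileWriter::try_new(file, schema, options)?;
//     for group in row_groups {
//         writer.write(group?)?;
//     }
//     let _file_size = writer.end(None)?;
//
// `end` must be called before `into_inner_and_metadata` below, which otherwise panics.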
+ } + + /// Consumes this writer and returns the inner writer + pub fn into_inner(self) -> W { + self.writer.into_inner() + } + + /// Returns the underlying writer and [`ThriftFileMetaData`] + /// # Panics + /// This function panics if [`Self::end`] has not yet been called + pub fn into_inner_and_metadata(self) -> (W, ThriftFileMetaData) { + self.writer.into_inner_and_metadata() + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs b/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs new file mode 100644 index 000000000000..86080ef7728f --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs @@ -0,0 +1,147 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, FixedLenStatistics}; + +use super::binary::ord_binary; +use super::{utils, WriteOptions}; +use crate::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::types::i256; + +pub(crate) fn encode_plain(array: &FixedSizeBinaryArray, is_optional: bool, buffer: &mut Vec) { + // append the non-null values + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + buffer.extend_from_slice(x); + } + }) + } else { + buffer.extend_from_slice(array.values()); + } +} + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + statistics: Option, +) -> Result { + let is_optional = is_nullable(&type_.field_info); + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, is_optional, &mut buffer); + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics.map(|x| serialize_statistics(&x)), + type_, + options, + Encoding::Plain, + ) +} + +pub(super) fn build_statistics( + array: &FixedSizeBinaryArray, + primitive_type: PrimitiveType, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } +} + +pub(super) fn build_statistics_decimal( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()), + } +} + +pub(super) fn build_statistics_decimal256_with_i128( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()), + } +} + +pub(super) fn build_statistics_decimal256( + 
array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/mod.rs b/crates/nano-arrow/src/io/parquet/write/mod.rs new file mode 100644 index 000000000000..b74daea04d7e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/mod.rs @@ -0,0 +1,876 @@ +//! APIs to write to Parquet format. +//! +//! # Arrow/Parquet Interoperability +//! As of [parquet-format v2.9](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md) +//! there are Arrow [DataTypes](crate::datatypes::DataType) which do not have a parquet +//! representation. These include but are not limited to: +//! * `DataType::Timestamp(TimeUnit::Second, _)` +//! * `DataType::Int64` +//! * `DataType::Duration` +//! * `DataType::Date64` +//! * `DataType::Time32(TimeUnit::Second)` +//! +//! The use of these arrow types will result in no logical type being stored within a parquet file. + +mod binary; +mod boolean; +mod dictionary; +mod file; +mod fixed_len_bytes; +mod nested; +mod pages; +mod primitive; +mod row_group; +mod schema; +mod sink; +mod utf8; +mod utils; + +pub use nested::{num_values, write_rep_and_def}; +pub use pages::{to_leaves, to_nested, to_parquet_leaves}; +pub use parquet2::compression::{BrotliLevel, CompressionOptions, GzipLevel, ZstdLevel}; +pub use parquet2::encoding::Encoding; +pub use parquet2::metadata::{ + Descriptor, FileMetaData, KeyValue, SchemaDescriptor, ThriftFileMetaData, +}; +pub use parquet2::page::{CompressedDataPage, CompressedPage, Page}; +use parquet2::schema::types::PrimitiveType as ParquetPrimitiveType; +pub use parquet2::schema::types::{FieldInfo, ParquetType, PhysicalType as ParquetPhysicalType}; +pub use parquet2::write::{ + compress, write_metadata_sidecar, Compressor, DynIter, DynStreamingIterator, RowGroupIter, + Version, +}; +pub use parquet2::{fallible_streaming_iterator, FallibleStreamingIterator}; +pub use utils::write_def_levels; + +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::types::{days_ms, i256, NativeType}; + +/// Currently supported options to write to parquet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WriteOptions { + /// Whether to write statistics + pub write_statistics: bool, + /// The page and file version to use + pub version: Version, + /// The compression to apply to every page + pub compression: CompressionOptions, + /// The size to flush a page, defaults to 1024 * 1024 if None + pub data_pagesize_limit: Option, +} + +pub use file::FileWriter; +pub use pages::{array_to_columns, Nested}; +pub use row_group::{row_group_iter, RowGroupIterator}; +pub use schema::to_parquet_type; +pub use sink::FileSink; + +use crate::compute::aggregate::estimated_bytes_size; + +/// returns offset and length to slice the leaf values +pub fn slice_nested_leaf(nested: &[Nested]) -> (usize, usize) { + // find the deepest recursive dremel structure as that one determines how many values we must + // take + let mut out = (0, 0); + for nested in nested.iter().rev() { + match nested { + Nested::LargeList(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start 
as usize, (end - start) as usize); + }, + Nested::List(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start as usize, (end - start) as usize); + }, + Nested::Primitive(_, _, len) => out = (0, *len), + _ => {}, + } + } + out +} + +fn decimal_length_from_precision(precision: usize) -> usize { + // digits = floor(log_10(2^(8*n - 1) - 1)) + // ceil(digits) = log10(2^(8*n - 1) - 1) + // 10^ceil(digits) = 2^(8*n - 1) - 1 + // 10^ceil(digits) + 1 = 2^(8*n - 1) + // log2(10^ceil(digits) + 1) = (8*n - 1) + // log2(10^ceil(digits) + 1) + 1 = 8*n + // (log2(10^ceil(a) + 1) + 1) / 8 = n + (((10.0_f64.powi(precision as i32) + 1.0).log2() + 1.0) / 8.0).ceil() as usize +} + +/// Creates a parquet [`SchemaDescriptor`] from a [`Schema`]. +pub fn to_parquet_schema(schema: &Schema) -> Result { + let parquet_types = schema + .fields + .iter() + .map(to_parquet_type) + .collect::>>()?; + Ok(SchemaDescriptor::new("root".to_string(), parquet_types)) +} + +/// Checks whether the `data_type` can be encoded as `encoding`. +/// Note that this is whether this implementation supports it, which is a subset of +/// what the parquet spec allows. +pub fn can_encode(data_type: &DataType, encoding: Encoding) -> bool { + if let (Encoding::DeltaBinaryPacked, DataType::Decimal(p, _)) = + (encoding, data_type.to_logical_type()) + { + return *p <= 18; + }; + + matches!( + (encoding, data_type.to_logical_type()), + (Encoding::Plain, _) + | ( + Encoding::DeltaLengthByteArray, + DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8, + ) + | (Encoding::RleDictionary, DataType::Dictionary(_, _, _)) + | (Encoding::PlainDictionary, DataType::Dictionary(_, _, _)) + | ( + Encoding::DeltaBinaryPacked, + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + ) + ) +} + +/// Slices the [`Array`] to `Box` and `Vec`. +pub fn slice_parquet_array( + primitive_array: &mut dyn Array, + nested: &mut [Nested], + mut current_offset: usize, + mut current_length: usize, +) { + for nested in nested.iter_mut() { + match nested { + Nested::LargeList(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::List(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::Struct(validity, _, length) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + }, + Nested::Primitive(validity, _, length) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + primitive_array.slice(current_offset, current_length); + }, + } + } +} + +/// Get the length of [`Array`] that should be sliced. 
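// Editorial worked examples for decimal_length_from_precision above:
//   precision 5  -> ceil((log2(1e5  + 1) + 1) / 8) = ceil(2.2)  = 3 bytes
//   precision 9  -> ceil((log2(1e9  + 1) + 1) / 8) = ceil(3.9)  = 4 bytes
//   precision 18 -> ceil((log2(1e18 + 1) + 1) / 8) = ceil(7.6)  = 8 bytes
//   precision 38 -> ceil((log2(1e38 + 1) + 1) / 8) = ceil(15.9) = 16 bytes
// i.e. the smallest signed, big-endian byte width that still holds 10^precision - 1.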
+pub fn get_max_length(nested: &[Nested]) -> usize { + let mut length = 0; + for nested in nested.iter() { + match nested { + Nested::LargeList(l_nested) => length += l_nested.offsets.range() as usize, + Nested::List(l_nested) => length += l_nested.offsets.range() as usize, + _ => {}, + } + } + length +} + +/// Returns an iterator of [`Page`]. +pub fn array_to_pages( + primitive_array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result>> { + if let DataType::Dictionary(key_type, _, _) = primitive_array.data_type().to_logical_type() { + return match_integer_type!(key_type, |$T| { + dictionary::array_to_pages::<$T>( + primitive_array.as_any().downcast_ref().unwrap(), + type_, + &nested, + options, + encoding, + ) + }); + }; + + let nested = nested.to_vec(); + let primitive_array = primitive_array.to_boxed(); + + let number_of_rows = nested[0].len(); + + // note: this is not correct if the array is sliced - the estimation should happen on the + // primitive after sliced for parquet + let byte_size = estimated_bytes_size(primitive_array.as_ref()); + + const DEFAULT_PAGE_SIZE: usize = 1024 * 1024; + let max_page_size = options.data_pagesize_limit.unwrap_or(DEFAULT_PAGE_SIZE); + let max_page_size = max_page_size.min(2usize.pow(31) - 2usize.pow(25)); // allowed maximum page size + let bytes_per_row = if number_of_rows == 0 { + 0 + } else { + ((byte_size as f64) / (number_of_rows as f64)) as usize + }; + let rows_per_page = (max_page_size / (bytes_per_row + 1)).max(1); + + let pages = (0..number_of_rows) + .step_by(rows_per_page) + .map(move |offset| { + let length = if offset + rows_per_page > number_of_rows { + number_of_rows - offset + } else { + rows_per_page + }; + + let mut right_array = primitive_array.clone(); + let mut right_nested = nested.clone(); + slice_parquet_array(right_array.as_mut(), &mut right_nested, offset, length); + + array_to_page( + right_array.as_ref(), + type_.clone(), + &right_nested, + options, + encoding, + ) + }); + + Ok(DynIter::new(pages)) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result { + if nested.len() == 1 { + // special case where validity == def levels + return array_to_page_simple(array, type_, options, encoding); + } + array_to_page_nested(array, type_, nested, options, encoding) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page_simple( + array: &dyn Array, + type_: ParquetPrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> Result { + let data_type = array.data_type(); + if !can_encode(data_type, encoding) { + return Err(Error::InvalidArgumentError(format!( + "The datatype {data_type:?} cannot be encoded by {encoding:?}" + ))); + } + + match data_type.to_logical_type() { + DataType::Boolean => { + boolean::array_to_page(array.as_any().downcast_ref().unwrap(), options, type_) + }, + // casts below MUST match the casts done at the metadata (field -> parquet type). 
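// Editorial note on array_to_pages above (example numbers assumed): pages are cut by
// an estimated bytes-per-row rather than by exact encoded size. An array of
// 1_048_576 rows estimated at 8 MiB gives bytes_per_row = 8; with the default 1 MiB
// data_pagesize_limit this yields rows_per_page = 1_048_576 / (8 + 1) = 116_508,
// i.e. nine full pages plus a small remainder page. The limit is further capped at
// 2^31 - 2^25 bytes, the allowed maximum page size noted in the code.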
+ DataType::UInt8 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt16 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt32 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt64 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int8 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int16 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ) + }, + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Float32 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + DataType::Float64 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + DataType::Utf8 => utf8::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::LargeUtf8 => utf8::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Binary => binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::LargeBinary => binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Null => { + let array = Int32Array::new_null(DataType::Int32, array.len()); + primitive::array_to_page_plain::(&array, options, type_) + }, + DataType::Interval(IntervalUnit::YearMonth) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(bytes); + values.extend_from_slice(&[0; 8]); + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(&array, type_.clone())) + } else { + None + }; + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + }, + DataType::Interval(IntervalUnit::DayTime) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(&[0; 4]); // months + values.extend_from_slice(bytes); // days and seconds + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(&array, type_.clone())) + } else { + None + }; + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + }, + DataType::FixedSizeBinary(_) => { + let array = 
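// Editorial note on the Interval arms above: parquet's INTERVAL logical type is a
// 12-byte FIXED_LEN_BYTE_ARRAY of three little-endian u32 values (months, days,
// milliseconds). YearMonth therefore writes its 4 month bytes and zero-pads the
// remaining 8, while DayTime zero-pads the 4 month bytes and appends its 8
// day/millisecond bytes.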
array.as_any().downcast_ref().unwrap(); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(array, type_.clone())) + } else { + None + }; + + fixed_len_bytes::array_to_page(array, options, type_, statistics) + }, + DataType::Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.write_statistics { + let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal256(array, type_.clone(), size); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + DataType::Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal(array, type_.clone(), size); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + 
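// Editorial note (example precisions assumed): the decimal arms pick the narrowest
// physical type that can hold the precision - e.g. Decimal(7, 2) is written as
// INT32, Decimal(15, 4) as INT64, and Decimal(30, 2) as a 13-byte
// FIXED_LEN_BYTE_ARRAY holding the big-endian two's-complement value
// (decimal_length_from_precision(30) = 13). Decimal256 follows the same buckets,
// using the low i128 half up to precision 38 and the full 32 bytes beyond that.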
fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + other => Err(Error::NotYetImplemented(format!( + "Writing parquet pages for data type {other:?}" + ))), + } + .map(Page::Data) +} + +fn array_to_page_nested( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + _encoding: Encoding, +) -> Result { + use DataType::*; + match array.data_type().to_logical_type() { + Null => { + let array = Int32Array::new_null(DataType::Int32, array.len()); + primitive::nested_array_to_page::(&array, options, type_, nested) + }, + Boolean => { + let array = array.as_any().downcast_ref().unwrap(); + boolean::nested_array_to_page(array, options, type_, nested) + }, + Utf8 => { + let array = array.as_any().downcast_ref().unwrap(); + utf8::nested_array_to_page::(array, options, type_, nested) + }, + LargeUtf8 => { + let array = array.as_any().downcast_ref().unwrap(); + utf8::nested_array_to_page::(array, options, type_, nested) + }, + Binary => { + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + LargeBinary => { + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + UInt8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int32 | Date32 | Time32(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int64 | Date64 | Time64(_) | Timestamp(_, _) | Duration(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.write_statistics { + let 
stats = + fixed_len_bytes::build_statistics_decimal(array, type_.clone(), size); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.write_statistics { + let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal256(array, type_.clone(), size); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + other => Err(Error::NotYetImplemented(format!( + "Writing nested parquet pages for data type {other:?}" + ))), + } + .map(Page::Data) +} + +fn transverse_recursive T + Clone>( + data_type: &DataType, + map: F, + encodings: &mut Vec, +) { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 => encodings.push(map(data_type)), + List | FixedSizeList | LargeList => { + let a = data_type.to_logical_type(); + if let DataType::List(inner) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else if let DataType::LargeList(inner) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else if let DataType::FixedSizeList(inner, _) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else { + unreachable!() + } + }, + Struct => { + if let DataType::Struct(fields) = data_type.to_logical_type() { + for field in fields { + 
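// Editorial note: the recursion below walks every nested level down to the parquet
// leaf columns, so `transverse` returns one entry per leaf. For example, a struct
// with fields {a: Int64, b: List<Int32>} yields two entries, as the doc example
// further down illustrates.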
transverse_recursive(&field.data_type, map.clone(), encodings) + } + } else { + unreachable!() + } + }, + Map => { + if let DataType::Map(field, _) = data_type.to_logical_type() { + if let DataType::Struct(fields) = field.data_type.to_logical_type() { + for field in fields { + transverse_recursive(&field.data_type, map.clone(), encodings) + } + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + Union => todo!(), + } +} + +/// Transverses the `data_type` up to its (parquet) columns and returns a vector of +/// items based on `map`. +/// This is used to assign an [`Encoding`] to every parquet column based on the columns' type (see example) +/// # Example +/// ``` +/// use arrow2::io::parquet::write::{transverse, Encoding}; +/// use arrow2::datatypes::{DataType, Field}; +/// +/// let dt = DataType::Struct(vec![ +/// Field::new("a", DataType::Int64, true), +/// Field::new("b", DataType::List(Box::new(Field::new("item", DataType::Int32, true))), true), +/// ]); +/// +/// let encodings = transverse(&dt, |dt| Encoding::Plain); +/// assert_eq!(encodings, vec![Encoding::Plain, Encoding::Plain]); +/// ``` +pub fn transverse T + Clone>(data_type: &DataType, map: F) -> Vec { + let mut encodings = vec![]; + transverse_recursive(data_type, map, &mut encodings); + encodings +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/def.rs b/crates/nano-arrow/src/io/parquet/write/nested/def.rs new file mode 100644 index 000000000000..02947dd5bef9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/def.rs @@ -0,0 +1,584 @@ +use super::super::pages::{ListNested, Nested}; +use super::rep::num_values; +use super::to_length; +use crate::bitmap::Bitmap; +use crate::offset::Offset; + +trait DebugIter: Iterator + std::fmt::Debug {} + +impl + std::fmt::Debug> DebugIter for A {} + +fn single_iter<'a>( + validity: &'a Option, + is_optional: bool, + length: usize, +) -> Box { + match (is_optional, validity) { + (false, _) => { + Box::new(std::iter::repeat((0u32, 1usize)).take(length)) as Box + }, + (true, None) => { + Box::new(std::iter::repeat((1u32, 1usize)).take(length)) as Box + }, + (true, Some(validity)) => { + Box::new(validity.iter().map(|v| (v as u32, 1usize)).take(length)) as Box + }, + } +} + +fn single_list_iter<'a, O: Offset>(nested: &'a ListNested) -> Box { + match (nested.is_optional, &nested.validity) { + (false, _) => Box::new( + std::iter::repeat(0u32) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, None) => Box::new( + std::iter::repeat(1u32) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, Some(validity)) => Box::new( + validity + .iter() + .map(|x| (x as u32)) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + } +} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { + nested + .iter() + .map(|nested| match nested { + Nested::Primitive(validity, is_optional, length) => { + single_iter(validity, *is_optional, *length) + }, + Nested::List(nested) => single_list_iter(nested), + Nested::LargeList(nested) => single_list_iter(nested), + Nested::Struct(validity, is_optional, length) => { + single_iter(validity, *is_optional, *length) + }, + }) + .collect() +} + +/// Iterator adapter of parquet / dremel definition levels +#[derive(Debug)] +pub struct DefLevelsIter<'a> { + // iterators of validities and lengths. E.g. 
[[[None,b,c], None], None] -> [[(true, 2), (false, 0)], [(true, 3), (false, 0)], [(false, 1), (true, 1), (true, 1)]] + iter: Vec>, + // vector containing the remaining number of values of each iterator. + // e.g. the iters [[2, 2], [3, 4, 1, 2]] after the first iteration will return [2, 3], + // and remaining will be [2, 3]. + // on the second iteration, it will be `[2, 2]` (since iterations consume the last items) + remaining: Vec, /* < remaining.len() == iter.len() */ + validity: Vec, + // cache of the first `remaining` that is non-zero. Examples: + // * `remaining = [2, 2] => current_level = 2` + // * `remaining = [2, 0] => current_level = 1` + // * `remaining = [0, 0] => current_level = 0` + current_level: usize, /* < iter.len() */ + // the total definition level at any given point during the iteration + total: u32, /* < iter.len() */ + // the total number of items that this iterator will return + remaining_values: usize, +} + +impl<'a> DefLevelsIter<'a> { + pub fn new(nested: &'a [Nested]) -> Self { + let remaining_values = num_values(nested); + + let iter = iter(nested); + let remaining = vec![0; iter.len()]; + let validity = vec![0; iter.len()]; + + Self { + iter, + remaining, + validity, + total: 0, + current_level: 0, + remaining_values, + } + } +} + +impl<'a> Iterator for DefLevelsIter<'a> { + type Item = u32; + + fn next(&mut self) -> Option { + if self.remaining_values == 0 { + return None; + } + + if self.remaining.is_empty() { + self.remaining_values -= 1; + return Some(0); + } + + let mut empty_contrib = 0u32; + for ((iter, remaining), validity) in self + .iter + .iter_mut() + .zip(self.remaining.iter_mut()) + .zip(self.validity.iter_mut()) + .skip(self.current_level) + { + let (is_valid, length): (u32, usize) = iter.next()?; + *validity = is_valid; + self.total += is_valid; + + *remaining = length; + if length == 0 { + *validity = 0; + self.total -= is_valid; + empty_contrib = is_valid; + break; + } + self.current_level += 1; + } + + // track + if let Some(x) = self.remaining.get_mut(self.current_level.saturating_sub(1)) { + *x = x.saturating_sub(1) + } + + let r = Some(self.total + empty_contrib); + + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { + self.current_level -= 1; + self.remaining[index - 1] -= 1; + self.total -= self.validity[index]; + } + } + if self.remaining[0] == 0 { + self.current_level = self.current_level.saturating_sub(1); + self.total -= self.validity[0]; + } + self.remaining_values -= 1; + r + } + + fn size_hint(&self) -> (usize, Option) { + let length = self.remaining_values; + (length, Some(length)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test(nested: Vec, expected: Vec) { + let mut iter = DefLevelsIter::new(&nested); + assert_eq!(iter.size_hint().0, expected.len()); + let result = iter.by_ref().collect::>(); + assert_eq!(result, expected); + assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_optional() { + let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(Some(b.into()), true, 10), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn nested_edge_simple() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 2), + ]; + let expected = vec![3, 3]; + + test(nested, expected) + } + + #[test] + fn struct_optional_1() { + 
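// Editorial note on the struct_optional case above: both levels are optional, so the
// maximum definition level is 2. A present primitive encodes as def level 2 and a
// null primitive inside a valid struct as 1, which is why the expected levels mirror
// the validity mask b: true -> 2, false -> 1.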
let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(Some(b.into()), true, 10), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn struct_optional_optional() { + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l1_required_required() { + let nested = vec![ + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1]; + + test(nested, expected) + } + + #[test] + fn l1_optional_optional() { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + + let v0 = [true, false, true, true, true, true, false, true]; + let v1 = [ + true, true, //[0, 1] + true, false, true, //[2, None, 3] + true, true, true, //[4, 5, 6] + true, true, true, //[7, 8, 9] + true, //[10] + ]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(v0.into()), + }), + Nested::Primitive(Some(v1.into()), true, 12), + ]; + let expected = vec![3u32, 3, 0, 3, 2, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3]; + + test(nested, expected) + } + + #[test] + fn l2_required_required_required() { + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + [ + [8], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_required_required() { + let a = [true, false, true, true]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 2, 3, 3]; + + test(nested, expected) + } + + #[test] + fn l2_optional_optional_required() { + let a = [true, false, true]; + let b = [true, true, true, true, false]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(None, false, 8), + ]; + let expected = vec![4, 4, 4, 4, 4, 4, 4, 0, 4, 3, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_optional_optional() { + let a = [true, false, true]; + let b = [true, true, true, false]; + let c = [true, true, true, true, false, true, true, true]; + /* + [ + [ + [1,2,3], + [4,None,6,7], + ], + 
None, + [ + [8], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(Some(c.into()), true, 8), + ]; + let expected = vec![5, 5, 5, 5, 4, 5, 5, 0, 5, 2]; + + test(nested, expected) + } + + /* + [{"a": "a"}, {"a": "b"}], + None, + [{"a": "b"}, None, {"a": "b"}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": "d"}, {"a": "d"}, {"a": "d"}], + None, + [{"a": "e"}], + */ + #[test] + fn nested_list_struct_nullable() { + let a = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [true, false, true, true, true, true, false, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(Some(b.into()), true, 12), + Nested::Primitive(Some(a.into()), true, 12), + ]; + let expected = vec![4, 4, 0, 4, 2, 4, 3, 3, 3, 1, 4, 4, 4, 0, 4]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_nullable1() { + let c = [true, false]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(None, true, 1), + Nested::Primitive(None, true, 1), + ]; + let expected = vec![4, 0]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable() { + let a = [true, false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(Some(b.into()), true, 12), + ]; + let expected = vec![4, 4, 1, 4, 3, 4, 4, 4, 4, 2, 4, 4, 4, 1, 4]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable1() { + let a = [true, true, false]; + let nested = vec![ + Nested::Struct(None, true, 3), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(None, true, 1), + ]; + let expected = vec![4, 2, 1]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable1() { + /* + [ + [{"a": ["b"]}, None], + ] + */ + + let a = [true]; + let b = [true, false]; + let c = [true, false]; + let d = [true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(Some(b.into()), true, 2), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(Some(d.into()), true, 1), + ]; + /* + 0 6 + 1 6 + 0 0 + 0 6 + 1 2 + */ + let expected = vec![6, 2]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let a = [true, 
false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let d = [true, true, true, true, true, false, true, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(Some(b.into()), true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(Some(d.into()), true, 8), + ]; + let expected = vec![6, 6, 0, 6, 2, 6, 3, 3, 3, 1, 6, 5, 6, 6, 0, 4]; + + test(nested, expected) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/mod.rs b/crates/nano-arrow/src/io/parquet/write/nested/mod.rs new file mode 100644 index 000000000000..042d731c57de --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/mod.rs @@ -0,0 +1,118 @@ +mod def; +mod rep; + +use parquet2::encoding::hybrid_rle::encode_u32; +use parquet2::read::levels::get_bit_width; +use parquet2::write::Version; +pub use rep::num_values; + +use super::Nested; +use crate::error::Result; +use crate::offset::Offset; + +fn write_levels_v1) -> Result<()>>( + buffer: &mut Vec, + encode: F, +) -> Result<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + + encode(buffer)?; + + let end = buffer.len(); + let length = end - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + Ok(()) +} + +/// writes the rep levels to a `Vec`. +fn write_rep_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> Result<()> { + let max_level = max_rep_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = rep::RepLevelsIter::new(nested); + + match version { + Version::V1 => { + write_levels_v1(buffer, |buffer: &mut Vec| { + encode_u32(buffer, levels, num_bits)?; + Ok(()) + })?; + }, + Version::V2 => { + encode_u32(buffer, levels, num_bits)?; + }, + } + + Ok(()) +} + +/// writes the rep levels to a `Vec`. 
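// Editorial note on the V1/V2 split above: data page V1 prefixes each level run with
// its byte length as a little-endian i32 (the four bytes reserved by write_levels_v1),
// whereas data page V2 writes the levels without a prefix and instead records the
// byte lengths returned by write_rep_and_def in the page header.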
+fn write_def_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> Result<()> { + let max_level = max_def_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = def::DefLevelsIter::new(nested); + + match version { + Version::V1 => write_levels_v1(buffer, move |buffer: &mut Vec| { + encode_u32(buffer, levels, num_bits)?; + Ok(()) + }), + Version::V2 => Ok(encode_u32(buffer, levels, num_bits)?), + } +} + +fn max_def_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::Primitive(_, is_optional, _) => *is_optional as usize, + Nested::List(nested) => 1 + (nested.is_optional as usize), + Nested::LargeList(nested) => 1 + (nested.is_optional as usize), + Nested::Struct(_, is_optional, _) => *is_optional as usize, + }) + .sum() +} + +fn max_rep_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::LargeList(_) | Nested::List(_) => 1, + Nested::Primitive(_, _, _) | Nested::Struct(_, _, _) => 0, + }) + .sum() +} + +fn to_length( + offsets: &[O], +) -> impl Iterator + std::fmt::Debug + Clone + '_ { + offsets + .windows(2) + .map(|w| w[1].to_usize() - w[0].to_usize()) +} + +/// Write `repetition_levels` and `definition_levels` to buffer. +pub fn write_rep_and_def( + page_version: Version, + nested: &[Nested], + buffer: &mut Vec, +) -> Result<(usize, usize)> { + write_rep_levels(buffer, nested, page_version)?; + let repetition_levels_byte_length = buffer.len(); + + write_def_levels(buffer, nested, page_version)?; + let definition_levels_byte_length = buffer.len() - repetition_levels_byte_length; + + Ok((repetition_levels_byte_length, definition_levels_byte_length)) +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/rep.rs b/crates/nano-arrow/src/io/parquet/write/nested/rep.rs new file mode 100644 index 000000000000..2bfbe1ce24f4 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/rep.rs @@ -0,0 +1,370 @@ +use super::super::pages::Nested; +use super::to_length; + +trait DebugIter: Iterator + std::fmt::Debug {} + +impl + std::fmt::Debug> DebugIter for A {} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { + nested + .iter() + .filter_map(|nested| match nested { + Nested::Primitive(_, _, _) => None, + Nested::List(nested) => { + Some(Box::new(to_length(&nested.offsets)) as Box) + }, + Nested::LargeList(nested) => { + Some(Box::new(to_length(&nested.offsets)) as Box) + }, + Nested::Struct(_, _, _) => None, + }) + .collect() +} + +/// return number values of the nested +pub fn num_values(nested: &[Nested]) -> usize { + let pr = match nested.last().unwrap() { + Nested::Primitive(_, _, len) => *len, + _ => todo!(), + }; + + iter(nested) + .into_iter() + .enumerate() + .map(|(_, lengths)| { + lengths + .map(|length| if length == 0 { 1 } else { 0 }) + .sum::() + }) + .sum::() + + pr +} + +/// Iterator adapter of parquet / dremel repetition levels +#[derive(Debug)] +pub struct RepLevelsIter<'a> { + // iterators of lengths. E.g. [[[a,b,c], [d,e,f,g]], [[h], [i,j]]] -> [[2, 2], [3, 4, 1, 2]] + iter: Vec>, + // vector containing the remaining number of values of each iterator. + // e.g. the iters [[2, 2], [3, 4, 1, 2]] after the first iteration will return [2, 3], + // and remaining will be [2, 3]. + // on the second iteration, it will be `[2, 2]` (since iterations consume the last items) + remaining: Vec, /* < remaining.len() == iter.len() */ + // cache of the first `remaining` that is non-zero. 
Examples: + // * `remaining = [2, 2] => current_level = 2` + // * `remaining = [2, 0] => current_level = 1` + // * `remaining = [0, 0] => current_level = 0` + current_level: usize, /* < iter.len() */ + // the number to discount due to being the first element of the iterators. + total: usize, /* < iter.len() */ + + // the total number of items that this iterator will return + remaining_values: usize, +} + +impl<'a> RepLevelsIter<'a> { + pub fn new(nested: &'a [Nested]) -> Self { + let remaining_values = num_values(nested); + + let iter = iter(nested); + let remaining = vec![0; iter.len()]; + + Self { + iter, + remaining, + total: 0, + current_level: 0, + remaining_values, + } + } +} + +impl<'a> Iterator for RepLevelsIter<'a> { + type Item = u32; + + fn next(&mut self) -> Option { + if self.remaining_values == 0 { + return None; + } + if self.remaining.is_empty() { + self.remaining_values -= 1; + return Some(0); + } + + for (iter, remaining) in self + .iter + .iter_mut() + .zip(self.remaining.iter_mut()) + .skip(self.current_level) + { + let length: usize = iter.next()?; + *remaining = length; + if length == 0 { + break; + } + self.current_level += 1; + self.total += 1; + } + + // track + if let Some(x) = self.remaining.get_mut(self.current_level.saturating_sub(1)) { + *x = x.saturating_sub(1) + } + let r = Some((self.current_level - self.total) as u32); + + // update + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { + self.current_level -= 1; + self.remaining[index - 1] -= 1; + } + } + if self.remaining[0] == 0 { + self.current_level = self.current_level.saturating_sub(1); + } + self.total = 0; + self.remaining_values -= 1; + + r + } + + fn size_hint(&self) -> (usize, Option) { + let length = self.remaining_values; + (length, Some(length)) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::pages::ListNested; + use super::*; + + fn test(nested: Vec, expected: Vec) { + let mut iter = RepLevelsIter::new(&nested); + assert_eq!(iter.size_hint().0, expected.len()); + assert_eq!(iter.by_ref().collect::>(), expected); + assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_required() { + let nested = vec![ + Nested::Struct(None, false, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn struct_optional() { + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l1() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![0u32, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l2() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 2, 0, 0, 1, 2]; + + test(nested, expected) + } + + #[test] + fn list_of_struct() { + /* + [ + [{"a": "b"}],[{"a": "c"}] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2].try_into().unwrap(), + validity: None, + }), + 
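// Editorial note: repetition levels mark where a new outer row starts (level 0)
// versus a continuation within a list (level >= 1). In the list_of_struct case built
// here, [[{"a": "b"}], [{"a": "c"}]], each value opens a new outer list, so both
// levels are 0; in the `l1` test above, [[0, 1], ...] produces [0, 1, ...] because
// the second element continues the same list.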
Nested::Struct(None, true, 2), + Nested::Primitive(None, true, 2), + ]; + let expected = vec![0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 3].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 3), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 6, 7].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 7), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 0]; + + test(nested, expected) + } + + #[test] + fn struct_list_optional() { + /* + {"f1": ["a", "b", None, "c"]} + */ + let nested = vec![ + Nested::Struct(None, true, 1), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 4].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 4), + ]; + let expected = vec![0, 1, 1, 1]; + + test(nested, expected) + } + + #[test] + fn l2_other() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 1, 1, 3, 5, 5, 8, 8, 9].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4, 5, 7, 8, 9, 10, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![0, 2, 0, 0, 2, 1, 0, 2, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_1() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + [], + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": []}, {"a": []}, {"a": []}], + [], + [{"a": ["d"]}, {"a": ["a"]}, {"a": ["c", "d"]}], + [], + [{"a": []}], + ] + // reps: [0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0, 0] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 8), + ]; + let expected = vec![0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_2() { + /* + [ + [{"a": []}], + ] + // reps: [0] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 0), + ]; + let expected = vec![0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_3() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 0), + ]; + let expected = vec![0, 0]; + // [1, 0], [0] + // pick last + + test(nested, expected) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/pages.rs b/crates/nano-arrow/src/io/parquet/write/pages.rs new file mode 100644 index 000000000000..ce51bcdcda89 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/pages.rs @@ -0,0 +1,633 @@ +use std::fmt::Debug; + +use parquet2::page::Page; +use parquet2::schema::types::{ParquetType, PrimitiveType as ParquetPrimitiveType}; +use 
parquet2::write::DynIter; + +use super::{array_to_pages, Encoding, WriteOptions}; +use crate::array::{Array, ListArray, MapArray, StructArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::{Offset, OffsetsBuffer}; + +#[derive(Debug, Clone, PartialEq)] +pub struct ListNested { + pub is_optional: bool, + pub offsets: OffsetsBuffer, + pub validity: Option, +} + +impl ListNested { + pub fn new(offsets: OffsetsBuffer, validity: Option, is_optional: bool) -> Self { + Self { + is_optional, + offsets, + validity, + } + } +} + +/// Descriptor of nested information of a field +#[derive(Debug, Clone, PartialEq)] +pub enum Nested { + /// a primitive (leaf or parquet column) + /// bitmap, _, length + Primitive(Option, bool, usize), + /// a list + List(ListNested), + /// a list + LargeList(ListNested), + /// a struct + Struct(Option, bool, usize), +} + +impl Nested { + /// Returns the length (number of rows) of the element + pub fn len(&self) -> usize { + match self { + Nested::Primitive(_, _, length) => *length, + Nested::List(nested) => nested.offsets.len_proxy(), + Nested::LargeList(nested) => nested.offsets.len_proxy(), + Nested::Struct(_, _, len) => *len, + } + } +} + +/// Constructs the necessary `Vec>` to write the rep and def levels of `array` to parquet +pub fn to_nested(array: &dyn Array, type_: &ParquetType) -> Result>> { + let mut nested = vec![]; + + to_nested_recursive(array, type_, &mut nested, vec![])?; + Ok(nested) +} + +fn to_nested_recursive( + array: &dyn Array, + type_: &ParquetType, + nested: &mut Vec>, + mut parents: Vec, +) -> Result<()> { + let is_optional = is_nullable(type_.get_field_info()); + + use PhysicalType::*; + match array.data_type().to_physical_type() { + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = if let ParquetType::GroupType { fields, .. } = type_ { + fields + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a struct array".to_string(), + )); + }; + + parents.push(Nested::Struct( + array.validity().cloned(), + is_optional, + array.len(), + )); + + for (type_, array) in fields.iter().zip(array.values()) { + to_nested_recursive(array.as_ref(), type_, nested, parents.clone())?; + } + }, + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. 
} = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + }; + + parents.push(Nested::LargeList(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a map array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a map array".to_string(), + )); + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.field().as_ref(), type_, nested, parents)?; + }, + _ => { + parents.push(Nested::Primitive( + array.validity().cloned(), + is_optional, + array.len(), + )); + nested.push(parents) + }, + } + Ok(()) +} + +/// Convert [`Array`] to `Vec<&dyn Array>` leaves in DFS order. +pub fn to_leaves(array: &dyn Array) -> Vec<&dyn Array> { + let mut leaves = vec![]; + to_leaves_recursive(array, &mut leaves); + leaves +} + +fn to_leaves_recursive<'a>(array: &'a dyn Array, leaves: &mut Vec<&'a dyn Array>) { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + array + .values() + .iter() + .for_each(|a| to_leaves_recursive(a.as_ref(), leaves)); + }, + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + to_leaves_recursive(array.values().as_ref(), leaves); + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + to_leaves_recursive(array.values().as_ref(), leaves); + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + to_leaves_recursive(array.field().as_ref(), leaves); + }, + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | LargeUtf8 | Dictionary(_) => leaves.push(array), + other => todo!("Writing {:?} to parquet not yet implemented", other), + } +} + +/// Convert `ParquetType` to `Vec` leaves in DFS order. +pub fn to_parquet_leaves(type_: ParquetType) -> Vec { + let mut leaves = vec![]; + to_parquet_leaves_recursive(type_, &mut leaves); + leaves +} + +fn to_parquet_leaves_recursive(type_: ParquetType, leaves: &mut Vec) { + match type_ { + ParquetType::PrimitiveType(primitive) => leaves.push(primitive), + ParquetType::GroupType { fields, .. 
} => { + fields + .into_iter() + .for_each(|type_| to_parquet_leaves_recursive(type_, leaves)); + }, + } +} + +/// Returns a vector of iterators of [`Page`], one per leaf column in the array +pub fn array_to_columns + Send + Sync>( + array: A, + type_: ParquetType, + options: WriteOptions, + encoding: &[Encoding], +) -> Result>>> { + let array = array.as_ref(); + let nested = to_nested(array, &type_)?; + + let types = to_parquet_leaves(type_); + + let values = to_leaves(array); + + assert_eq!(encoding.len(), types.len()); + + values + .iter() + .zip(nested) + .zip(types) + .zip(encoding.iter()) + .map(|(((values, nested), type_), encoding)| { + array_to_pages(*values, type_, &nested, options, *encoding) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use parquet2::schema::types::{GroupLogicalType, PrimitiveConvertedType, PrimitiveLogicalType}; + use parquet2::schema::Repetition; + + use super::super::{FieldInfo, ParquetPhysicalType, ParquetPrimitiveType}; + use super::*; + use crate::array::*; + use crate::bitmap::Bitmap; + use crate::datatypes::*; + + #[test] + fn test_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + vec![ + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_struct_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let fields = vec![ + Field::new("b", array.data_type().clone(), true), + Field::new("c", array.data_type().clone(), true), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![Box::new(array.clone()), Box::new(array)], + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: 
FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_.clone(), type_], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + // a.b.b + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.b.c + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.c.b + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.c.c + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_list_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let array = ListArray::new( + DataType::List(Box::new(Field::new("l", array.data_type().clone(), true))), + vec![0i32, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "l".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "list".to_string(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + 
vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_map() { + let kv_type = DataType::Struct(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ]); + let kv_field = Field::new("kv", kv_type.clone(), false); + let map_type = DataType::Map(Box::new(kv_field), false); + + let key_array = Utf8Array::::from_slice(["k1", "k2", "k3", "k4", "k5", "k6"]).boxed(); + let val_array = Int32Array::from_slice([42, 28, 19, 31, 21, 17]).boxed(); + let kv_array = StructArray::try_new(kv_type, vec![key_array, val_array], None) + .unwrap() + .boxed(); + let offsets = OffsetsBuffer::try_from(vec![0, 2, 3, 4, 6]).unwrap(); + + let array = MapArray::try_new(map_type, offsets, kv_array, None).unwrap(); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "kv".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "k".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(PrimitiveLogicalType::String), + converted_type: Some(PrimitiveConvertedType::Utf8), + physical_type: ParquetPhysicalType::ByteArray, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "v".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "m".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(GroupLogicalType::Map), + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "map".to_string(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 6), + Nested::Primitive(None, false, 6), + ], + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 6), + Nested::Primitive(None, false, 6), + ], + ] + ); + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs b/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs new file mode 100644 index 000000000000..14d5f9077b49 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs @@ -0,0 +1,192 @@ +use parquet2::encoding::delta_bitpacked::encode; +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, PrimitiveStatistics}; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::{utils, WriteOptions}; +use crate::array::{Array, PrimitiveArray}; +use crate::error::Error; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::utils::ExactSizedIter; +use crate::types::NativeType; + +pub(crate) fn encode_plain( + 
array: &PrimitiveArray<T>,
+    is_optional: bool,
+    mut buffer: Vec<u8>,
+) -> Vec<u8>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    T: num_traits::AsPrimitive<P>,
+{
+    if is_optional {
+        buffer.reserve(std::mem::size_of::<P>() * (array.len() - array.null_count()));
+        // append the non-null values
+        array.iter().for_each(|x| {
+            if let Some(x) = x {
+                let parquet_native: P = x.as_();
+                buffer.extend_from_slice(parquet_native.to_le_bytes().as_ref())
+            }
+        });
+    } else {
+        buffer.reserve(std::mem::size_of::<P>() * array.len());
+        // append all values
+        array.values().iter().for_each(|x| {
+            let parquet_native: P = x.as_();
+            buffer.extend_from_slice(parquet_native.to_le_bytes().as_ref())
+        });
+    }
+    buffer
+}
+
+pub(crate) fn encode_delta<T, P>(
+    array: &PrimitiveArray<T>,
+    is_optional: bool,
+    mut buffer: Vec<u8>,
+) -> Vec<u8>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    T: num_traits::AsPrimitive<P>,
+    P: num_traits::AsPrimitive<i64>,
+{
+    if is_optional {
+        // append the non-null values
+        let iterator = array.iter().flatten().map(|x| {
+            let parquet_native: P = x.as_();
+            let integer: i64 = parquet_native.as_();
+            integer
+        });
+        let iterator = ExactSizedIter::new(iterator, array.len() - array.null_count());
+        encode(iterator, &mut buffer)
+    } else {
+        // append all values
+        let iterator = array.values().iter().map(|x| {
+            let parquet_native: P = x.as_();
+            let integer: i64 = parquet_native.as_();
+            integer
+        });
+        encode(iterator, &mut buffer)
+    }
+    buffer
+}
+
+pub fn array_to_page_plain<T, P>(
+    array: &PrimitiveArray<T>,
+    options: WriteOptions,
+    type_: PrimitiveType,
+) -> Result<DataPage, Error>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    T: num_traits::AsPrimitive<P>,
+{
+    array_to_page(array, options, type_, Encoding::Plain, encode_plain)
+}
+
+pub fn array_to_page_integer<T, P>(
+    array: &PrimitiveArray<T>,
+    options: WriteOptions,
+    type_: PrimitiveType,
+    encoding: Encoding,
+) -> Result<DataPage, Error>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    T: num_traits::AsPrimitive<P>,
+    P: num_traits::AsPrimitive<i64>,
+{
+    match encoding {
+        Encoding::DeltaBinaryPacked => array_to_page(array, options, type_, encoding, encode_delta),
+        Encoding::Plain => array_to_page(array, options, type_, encoding, encode_plain),
+        other => Err(Error::nyi(format!("Encoding integer as {other:?}"))),
+    }
+}
+
+pub fn array_to_page<T, P, F: Fn(&PrimitiveArray<T>, bool, Vec<u8>) -> Vec<u8>>(
+    array: &PrimitiveArray<T>,
+    options: WriteOptions,
+    type_: PrimitiveType,
+    encoding: Encoding,
+    encode: F,
+) -> Result<DataPage, Error>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    // constraint required to build statistics
+    T: num_traits::AsPrimitive<P>,
+{
+    let is_optional = is_nullable(&type_.field_info);
+
+    let validity = array.validity();
+
+    let mut buffer = vec![];
+    utils::write_def_levels(
+        &mut buffer,
+        is_optional,
+        validity,
+        array.len(),
+        options.version,
+    )?;
+
+    let definition_levels_byte_length = buffer.len();
+
+    let buffer = encode(array, is_optional, buffer);
+
+    let statistics = if options.write_statistics {
+        Some(serialize_statistics(&build_statistics(
+            array,
+            type_.clone(),
+        )))
+    } else {
+        None
+    };
+
+    utils::build_plain_page(
+        buffer,
+        array.len(),
+        array.len(),
+        array.null_count(),
+        0,
+        definition_levels_byte_length,
+        statistics,
+        type_,
+        options,
+        encoding,
+    )
+}
+
+pub fn build_statistics<T, P>(
+    array: &PrimitiveArray<T>,
+    primitive_type: PrimitiveType,
+) -> PrimitiveStatistics<P>
+where
+    T: NativeType,
+    P: ParquetNativeType,
+    T: num_traits::AsPrimitive<P>,
+{
+    PrimitiveStatistics::<P>
{ + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .map(|x| { + let x: P = x.as_(); + x + }) + .max_by(|x, y| x.ord(y)), + min_value: array + .iter() + .flatten() + .map(|x| { + let x: P = x.as_(); + x + }) + .min_by(|x, y| x.ord(y)), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs b/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs new file mode 100644 index 000000000000..96318ab0a89b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use basic::{array_to_page_integer, array_to_page_plain}; +pub(crate) use basic::{build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs b/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs new file mode 100644 index 000000000000..fe859013c96b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs @@ -0,0 +1,56 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::serialize_statistics; +use parquet2::types::NativeType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, PrimitiveArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::types::NativeType as ArrowNativeType; + +pub fn array_to_page( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + T: ArrowNativeType, + R: NativeType, + T: num_traits::AsPrimitive, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + let buffer = encode_plain(array, is_optional, buffer); + + let statistics = if options.write_statistics { + Some(serialize_statistics(&build_statistics( + array, + type_.clone(), + ))) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/row_group.rs b/crates/nano-arrow/src/io/parquet/write/row_group.rs new file mode 100644 index 000000000000..d281b63cebda --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/row_group.rs @@ -0,0 +1,126 @@ +use parquet2::error::Error as ParquetError; +use parquet2::schema::types::ParquetType; +use parquet2::write::Compressor; +use parquet2::FallibleStreamingIterator; + +use super::{ + array_to_columns, to_parquet_schema, DynIter, DynStreamingIterator, Encoding, RowGroupIter, + SchemaDescriptor, WriteOptions, +}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; + +/// Maps a [`Chunk`] and parquet-specific options to an [`RowGroupIter`] used to +/// write to parquet +/// # Panics +/// Iff +/// * `encodings.len() != fields.len()` or +/// * `encodings.len() != chunk.arrays().len()` +pub fn row_group_iter + 'static + Send + Sync>( + chunk: Chunk, + encodings: Vec>, + fields: Vec, + options: WriteOptions, +) -> RowGroupIter<'static, Error> 
{ + assert_eq!(encodings.len(), fields.len()); + assert_eq!(encodings.len(), chunk.arrays().len()); + DynIter::new( + chunk + .into_arrays() + .into_iter() + .zip(fields) + .zip(encodings) + .flat_map(move |((array, type_), encoding)| { + let encoded_columns = array_to_columns(array, type_, options, &encoding).unwrap(); + encoded_columns + .into_iter() + .map(|encoded_pages| { + let pages = encoded_pages; + + let pages = DynIter::new( + pages + .into_iter() + .map(|x| x.map_err(|e| ParquetError::OutOfSpec(e.to_string()))), + ); + + let compressed_pages = Compressor::new(pages, options.compression, vec![]) + .map_err(Error::from); + Ok(DynStreamingIterator::new(compressed_pages)) + }) + .collect::>() + }), + ) +} + +/// An iterator adapter that converts an iterator over [`Chunk`] into an iterator +/// of row groups. +/// Use it to create an iterator consumable by the parquet's API. +pub struct RowGroupIterator + 'static, I: Iterator>>> { + iter: I, + options: WriteOptions, + parquet_schema: SchemaDescriptor, + encodings: Vec>, +} + +impl + 'static, I: Iterator>>> RowGroupIterator { + /// Creates a new [`RowGroupIterator`] from an iterator over [`Chunk`]. + /// + /// # Errors + /// Iff + /// * the Arrow schema can't be converted to a valid Parquet schema. + /// * the length of the encodings is different from the number of fields in schema + pub fn try_new( + iter: I, + schema: &Schema, + options: WriteOptions, + encodings: Vec>, + ) -> Result { + if encodings.len() != schema.fields.len() { + return Err(Error::InvalidArgumentError( + "The number of encodings must equal the number of fields".to_string(), + )); + } + let parquet_schema = to_parquet_schema(schema)?; + + Ok(Self { + iter, + options, + parquet_schema, + encodings, + }) + } + + /// Returns the [`SchemaDescriptor`] of the [`RowGroupIterator`]. 
+ pub fn parquet_schema(&self) -> &SchemaDescriptor { + &self.parquet_schema + } +} + +impl + 'static + Send + Sync, I: Iterator>>> Iterator + for RowGroupIterator +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let options = self.options; + + self.iter.next().map(|maybe_chunk| { + let chunk = maybe_chunk?; + if self.encodings.len() != chunk.arrays().len() { + return Err(Error::InvalidArgumentError( + "The number of arrays in the chunk must equal the number of fields in the schema" + .to_string(), + )); + }; + let encodings = self.encodings.clone(); + Ok(row_group_iter( + chunk, + encodings, + self.parquet_schema.fields().to_vec(), + options, + )) + }) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/schema.rs b/crates/nano-arrow/src/io/parquet/write/schema.rs new file mode 100644 index 000000000000..6f3ade5d46b3 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/schema.rs @@ -0,0 +1,379 @@ +use base64::engine::general_purpose; +use base64::Engine as _; +use parquet2::metadata::KeyValue; +use parquet2::schema::types::{ + GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, TimeUnit as ParquetTimeUnit, +}; +use parquet2::schema::Repetition; + +use super::super::ARROW_SCHEMA_META_KEY; +use crate::datatypes::{DataType, Field, Schema, TimeUnit}; +use crate::error::{Error, Result}; +use crate::io::ipc::write::{default_ipc_fields, schema_to_bytes}; +use crate::io::parquet::write::decimal_length_from_precision; + +pub fn schema_to_metadata_key(schema: &Schema) -> KeyValue { + let serialized_schema = schema_to_bytes(schema, &default_ipc_fields(&schema.fields)); + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.extend_from_slice(&[255u8, 255, 255, 255]); + len_prefix_schema.extend_from_slice(&(schema_len as u32).to_le_bytes()); + len_prefix_schema.extend_from_slice(&serialized_schema); + + let encoded = general_purpose::STANDARD.encode(&len_prefix_schema); + + KeyValue { + key: ARROW_SCHEMA_META_KEY.to_string(), + value: Some(encoded), + } +} + +/// Creates a [`ParquetType`] from a [`Field`]. +pub fn to_parquet_type(field: &Field) -> Result { + let name = field.name.clone(); + let repetition = if field.is_nullable { + Repetition::Optional + } else { + Repetition::Required + }; + // create type from field + match field.data_type().to_logical_type() { + DataType::Null => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + Some(PrimitiveLogicalType::Unknown), + None, + )?), + DataType::Boolean => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Boolean, + repetition, + None, + None, + None, + )?), + DataType::Int32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + // DataType::Duration(_) has no parquet representation => do not apply any logical type + DataType::Int64 | DataType::Duration(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. 
+ DataType::Date64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + DataType::Float32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Float, + repetition, + None, + None, + None, + )?), + DataType::Float64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Double, + repetition, + None, + None, + None, + )?), + DataType::Binary | DataType::LargeBinary => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + None, + None, + None, + )?), + DataType::Utf8 | DataType::LargeUtf8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + Some(PrimitiveConvertedType::Utf8), + Some(PrimitiveLogicalType::String), + None, + )?), + DataType::Date32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Date), + Some(PrimitiveLogicalType::Date), + None, + )?), + DataType::Int8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int8), + Some(PrimitiveLogicalType::Integer(IntegerType::Int8)), + None, + )?), + DataType::Int16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int16), + Some(PrimitiveLogicalType::Integer(IntegerType::Int16)), + None, + )?), + DataType::UInt8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint8), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt8)), + None, + )?), + DataType::UInt16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint16), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt16)), + None, + )?), + DataType::UInt32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint32), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt32)), + None, + )?), + DataType::UInt64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Uint64), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt64)), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. + DataType::Timestamp(TimeUnit::Second, _) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + DataType::Timestamp(time_unit, zone) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + Some(PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc: matches!(zone, Some(z) if !z.as_str().is_empty()), + unit: match time_unit { + TimeUnit::Second => unreachable!(), + TimeUnit::Millisecond => ParquetTimeUnit::Milliseconds, + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + }, + }), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. 
+ DataType::Time32(TimeUnit::Second) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + DataType::Time32(TimeUnit::Millisecond) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::TimeMillis), + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: ParquetTimeUnit::Milliseconds, + }), + None, + )?), + DataType::Time64(time_unit) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + match time_unit { + TimeUnit::Microsecond => Some(PrimitiveConvertedType::TimeMicros), + TimeUnit::Nanosecond => None, + _ => unreachable!(), + }, + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: match time_unit { + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + _ => unreachable!(), + }, + }), + None, + )?), + DataType::Struct(fields) => { + if fields.is_empty() { + return Err(Error::InvalidArgumentError( + "Parquet does not support writing empty structs".to_string(), + )); + } + // recursively convert children to types/nodes + let fields = fields + .iter() + .map(to_parquet_type) + .collect::>>()?; + Ok(ParquetType::from_group( + name, repetition, None, None, fields, None, + )) + }, + DataType::Dictionary(_, value, _) => { + let dict_field = Field::new(name.as_str(), value.as_ref().clone(), field.is_nullable); + to_parquet_type(&dict_field) + }, + DataType::FixedSizeBinary(size) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(*size), + repetition, + None, + None, + None, + )?), + DataType::Decimal(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + let physical_type = if precision <= 9 { + PhysicalType::Int32 + } else if precision <= 18 { + PhysicalType::Int64 + } else { + let len = decimal_length_from_precision(precision); + PhysicalType::FixedLenByteArray(len) + }; + Ok(ParquetType::try_from_primitive( + name, + physical_type, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + }, + DataType::Decimal256(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + if precision <= 9 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 18 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 38 { + let len = decimal_length_from_precision(precision); + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(len), + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(32), + repetition, + None, + None, + None, + )?) 
+ } + }, + DataType::Interval(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(12), + repetition, + Some(PrimitiveConvertedType::Interval), + None, + None, + )?), + DataType::List(f) | DataType::FixedSizeList(f, _) | DataType::LargeList(f) => { + Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::List), + Some(GroupLogicalType::List), + vec![ParquetType::from_group( + "list".to_string(), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(f)?], + None, + )], + None, + )) + }, + DataType::Map(f, _) => Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::Map), + Some(GroupLogicalType::Map), + vec![ParquetType::from_group( + "map".to_string(), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(f)?], + None, + )], + None, + )), + other => Err(Error::NotYetImplemented(format!( + "Writing the data type {other:?} is not yet implemented" + ))), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/sink.rs b/crates/nano-arrow/src/io/parquet/write/sink.rs new file mode 100644 index 000000000000..d357d7b89c2d --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/sink.rs @@ -0,0 +1,236 @@ +use std::pin::Pin; +use std::task::Poll; + +use ahash::AHashMap; +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink, TryFutureExt}; +use parquet2::metadata::KeyValue; +use parquet2::write::{FileStreamer, WriteOptions as ParquetWriteOptions}; + +use super::file::add_arrow_schema; +use super::{Encoding, SchemaDescriptor, WriteOptions}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Error; + +/// Sink that writes array [`chunks`](Chunk) as a Parquet file. +/// +/// Any values in the sink's `metadata` field will be written to the file's footer +/// when the sink is closed. +/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::parquet::write::{Encoding, WriteOptions, CompressionOptions, Version}; +/// # use arrow2::io::parquet::write::FileSink; +/// # futures::executor::block_on(async move { +/// +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// let encoding = vec![vec![Encoding::Plain]]; +/// let options = WriteOptions { +/// write_statistics: true, +/// compression: CompressionOptions::Uncompressed, +/// version: Version::V2, +/// data_pagesize_limit: None, +/// }; +/// +/// let mut buffer = vec![]; +/// let mut sink = FileSink::try_new( +/// &mut buffer, +/// schema, +/// encoding, +/// options, +/// )?; +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk).await?; +/// } +/// sink.metadata.insert(String::from("key"), Some(String::from("value"))); +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> { + writer: Option>, + task: Option>, Error>>>, + options: WriteOptions, + encodings: Vec>, + schema: Schema, + parquet_schema: SchemaDescriptor, + /// Key-value metadata that will be written to the file on close. + pub metadata: AHashMap>, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + /// Create a new sink that writes arrays to the provided `writer`. 
+ /// + /// # Error + /// Iff + /// * the Arrow schema can't be converted to a valid Parquet schema. + /// * the length of the encodings is different from the number of fields in schema + pub fn try_new( + writer: W, + schema: Schema, + encodings: Vec>, + options: WriteOptions, + ) -> Result { + if encodings.len() != schema.fields.len() { + return Err(Error::InvalidArgumentError( + "The number of encodings must equal the number of fields".to_string(), + )); + } + + let parquet_schema = crate::io::parquet::write::to_parquet_schema(&schema)?; + let created_by = Some("Arrow2 - Native Rust implementation of Arrow".to_string()); + let writer = FileStreamer::new( + writer, + parquet_schema.clone(), + ParquetWriteOptions { + version: options.version, + write_statistics: options.write_statistics, + }, + created_by, + ); + Ok(Self { + writer: Some(writer), + task: None, + options, + schema, + encodings, + parquet_schema, + metadata: AHashMap::default(), + }) + } + + /// The Arrow [`Schema`] for the file. + pub fn schema(&self) -> &Schema { + &self.schema + } + + /// The Parquet [`SchemaDescriptor`] for the file. + pub fn parquet_schema(&self) -> &SchemaDescriptor { + &self.parquet_schema + } + + /// The write options for the file. + pub fn options(&self) -> &WriteOptions { + &self.options + } + + fn poll_complete( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(writer) => { + self.task = None; + self.writer = writer; + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink>> for FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + type Error = Error; + + fn start_send(self: Pin<&mut Self>, item: Chunk>) -> Result<(), Self::Error> { + if self.schema.fields.len() != item.arrays().len() { + return Err(Error::InvalidArgumentError( + "The number of arrays in the chunk must equal the number of fields in the schema" + .to_string(), + )); + } + let this = self.get_mut(); + if let Some(mut writer) = this.writer.take() { + let rows = crate::io::parquet::write::row_group_iter( + item, + this.encodings.clone(), + this.parquet_schema.fields().to_vec(), + this.options, + ); + this.task = Some(Box::pin(async move { + writer.write(rows).await?; + Ok(Some(writer)) + })); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_complete(cx)) { + Ok(()) => { + let writer = this.writer.take(); + if let Some(mut writer) = writer { + let meta = std::mem::take(&mut this.metadata); + let metadata = if meta.is_empty() { + None + } else { + Some( + meta.into_iter() + .map(|(k, v)| KeyValue::new(k, v)) + .collect::>(), + ) + }; + let kv_meta = add_arrow_schema(&this.schema, metadata); + + this.task = Some(Box::pin(async move { + writer.end(kv_meta).map_err(Error::from).await?; + writer.into_inner().close().map_err(Error::from).await?; + Ok(None) + })); + 
this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + }, + Err(error) => Poll::Ready(Err(error)), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs b/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs new file mode 100644 index 000000000000..39f9c157c988 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs @@ -0,0 +1,117 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, BinaryStatistics, ParquetStatistics, Statistics}; + +use super::super::binary::{encode_delta, ord_binary}; +use super::super::{utils, WriteOptions}; +use crate::array::{Array, Utf8Array}; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::Offset; + +pub(crate) fn encode_plain( + array: &Utf8Array, + is_optional: bool, + buffer: &mut Vec, +) { + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x.as_bytes()); + } + }) + } else { + array.values_iter().for_each(|x| { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x.as_bytes()); + }) + } +} + +pub fn array_to_page( + array: &Utf8Array, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> Result { + let validity = array.validity(); + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, is_optional, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta( + array.values(), + array.offsets().buffer(), + array.validity(), + is_optional, + &mut buffer, + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Datatype {:?} cannot be encoded by {:?} encoding", + array.data_type(), + encoding + ))) + }, + } + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub(crate) fn build_statistics( + array: &Utf8Array, + primitive_type: PrimitiveType, +) -> ParquetStatistics { + let statistics = &BinaryStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .map(|x| x.as_bytes()) + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .map(|x| x.as_bytes()) + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } as &dyn Statistics; + serialize_statistics(statistics) +} diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs b/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs new file mode 100644 index 000000000000..e4ef46599e2c --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub(crate) use basic::{build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; 
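For reference only (not part of this diff): `encode_plain` in `utf8/basic.rs` above writes each value in the Parquet BYTE_ARRAY plain layout, i.e. a 4-byte little-endian length prefix followed by the raw bytes. A minimal standalone sketch of that layout — the helper name and the `main` driver are illustrative, not code from this PR:

fn plain_encode_byte_array(values: &[&str], buffer: &mut Vec<u8>) {
    for value in values {
        // BYTE_ARRAY plain encoding: 4-byte little-endian length, then the raw bytes
        buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
        buffer.extend_from_slice(value.as_bytes());
    }
}

fn main() {
    let mut buffer = Vec::new();
    plain_encode_byte_array(&["ab", "c"], &mut buffer);
    // "ab" -> [2, 0, 0, 0, 'a', 'b']; "c" -> [1, 0, 0, 0, 'c']
    assert_eq!(buffer, vec![2, 0, 0, 0, 97, 98, 1, 0, 0, 0, 99]);
}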
diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs b/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs new file mode 100644 index 000000000000..43767246d194 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs @@ -0,0 +1,48 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, Utf8Array}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::offset::Offset; + +pub fn array_to_page( + array: &Utf8Array, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + O: Offset, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer); + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/utils.rs b/crates/nano-arrow/src/io/parquet/write/utils.rs new file mode 100644 index 000000000000..caaba98a07fe --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utils.rs @@ -0,0 +1,146 @@ +use parquet2::compression::CompressionOptions; +use parquet2::encoding::hybrid_rle::encode_bool; +use parquet2::encoding::Encoding; +use parquet2::metadata::Descriptor; +use parquet2::page::{DataPage, DataPageHeader, DataPageHeaderV1, DataPageHeaderV2}; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::ParquetStatistics; + +use super::{Version, WriteOptions}; +use crate::bitmap::Bitmap; +use crate::error::Result; + +fn encode_iter_v1>(buffer: &mut Vec, iter: I) -> Result<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + encode_bool(buffer, iter)?; + let end = buffer.len(); + let length = end - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + Ok(()) +} + +fn encode_iter_v2>(writer: &mut Vec, iter: I) -> Result<()> { + Ok(encode_bool(writer, iter)?) +} + +fn encode_iter>( + writer: &mut Vec, + iter: I, + version: Version, +) -> Result<()> { + match version { + Version::V1 => encode_iter_v1(writer, iter), + Version::V2 => encode_iter_v2(writer, iter), + } +} + +/// writes the def levels to a `Vec` and returns it. 
+pub fn write_def_levels( + writer: &mut Vec, + is_optional: bool, + validity: Option<&Bitmap>, + len: usize, + version: Version, +) -> Result<()> { + // encode def levels + match (is_optional, validity) { + (true, Some(validity)) => encode_iter(writer, validity.iter(), version), + (true, None) => encode_iter(writer, std::iter::repeat(true).take(len), version), + _ => Ok(()), // is required => no def levels + } +} + +#[allow(clippy::too_many_arguments)] +pub fn build_plain_page( + buffer: Vec, + num_values: usize, + num_rows: usize, + null_count: usize, + repetition_levels_byte_length: usize, + definition_levels_byte_length: usize, + statistics: Option, + type_: PrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> Result { + let header = match options.version { + Version::V1 => DataPageHeader::V1(DataPageHeaderV1 { + num_values: num_values as i32, + encoding: encoding.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }), + Version::V2 => DataPageHeader::V2(DataPageHeaderV2 { + num_values: num_values as i32, + encoding: encoding.into(), + num_nulls: null_count as i32, + num_rows: num_rows as i32, + definition_levels_byte_length: definition_levels_byte_length as i32, + repetition_levels_byte_length: repetition_levels_byte_length as i32, + is_compressed: Some(options.compression != CompressionOptions::Uncompressed), + statistics, + }), + }; + Ok(DataPage::new( + header, + buffer, + Descriptor { + primitive_type: type_, + max_def_level: 0, + max_rep_level: 0, + }, + Some(num_rows), + )) +} + +/// Auxiliary iterator adapter to declare the size hint of an iterator. +pub(super) struct ExactSizedIter> { + iter: I, + remaining: usize, +} + +impl + Clone> Clone for ExactSizedIter { + fn clone(&self) -> Self { + Self { + iter: self.iter.clone(), + remaining: self.remaining, + } + } +} + +impl> ExactSizedIter { + pub fn new(iter: I, length: usize) -> Self { + Self { + iter, + remaining: length, + } + } +} + +impl> Iterator for ExactSizedIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|x| { + self.remaining -= 1; + x + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +/// Returns the number of bits needed to bitpack `max` +#[inline] +pub fn get_bit_width(max: u64) -> u32 { + 64 - max.leading_zeros() +} diff --git a/crates/nano-arrow/src/lib.rs b/crates/nano-arrow/src/lib.rs new file mode 100644 index 000000000000..c26b3e1a0b28 --- /dev/null +++ b/crates/nano-arrow/src/lib.rs @@ -0,0 +1,42 @@ +// So that we have more control over what is `unsafe` inside an `unsafe` block +#![allow(unused_unsafe)] +// +#![allow(clippy::len_without_is_empty)] +// this landed on 1.60. Let's not force everyone to bump just yet +#![allow(clippy::unnecessary_lazy_evaluations)] +// Trait objects must be returned as a &Box so that they can be cloned +#![allow(clippy::borrowed_box)] +// Allow type complexity warning to avoid API break. 
+#![allow(clippy::type_complexity)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "nightly_build", feature(build_hasher_simple_hash_one))] + +#[macro_use] +pub mod array; +pub mod bitmap; +pub mod buffer; +pub mod chunk; +pub mod error; +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod mmap; + +pub mod offset; +pub mod scalar; +pub mod trusted_len; +pub mod types; + +pub mod compute; +pub mod io; +pub mod temporal_conversions; + +pub mod datatypes; + +pub mod ffi; +pub mod util; + +// re-exported because we return `Either` in our public API +// re-exported to construct dictionaries +pub use ahash::AHashMap; +pub use either::Either; diff --git a/crates/nano-arrow/src/mmap/array.rs b/crates/nano-arrow/src/mmap/array.rs new file mode 100644 index 000000000000..8efd6afcd671 --- /dev/null +++ b/crates/nano-arrow/src/mmap/array.rs @@ -0,0 +1,568 @@ +use std::collections::VecDeque; +use std::sync::Arc; + +use crate::array::{Array, DictionaryKey, FixedSizeListArray, ListArray, StructArray}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::ffi::mmap::create_array; +use crate::ffi::{export_array_to_c, try_from, ArrowArray, InternalArrowArray}; +use crate::io::ipc::read::{Dictionaries, IpcBuffer, Node, OutOfSpecKind}; +use crate::io::ipc::IpcField; +use crate::offset::Offset; +use crate::types::NativeType; + +fn get_buffer_bounds(buffers: &mut VecDeque) -> Result<(usize, usize), Error> { + let buffer = buffers + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: usize = buffer + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: usize = buffer + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok((offset, length)) +} + +fn get_buffer<'a, T: NativeType>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + num_rows: usize, +) -> Result<&'a [u8], Error> { + let (offset, length) = get_buffer_bounds(buffers)?; + + // verify that they are in-bounds + let values = data + .get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| Error::OutOfSpec("buffer out of bounds".to_string()))?; + + // validate alignment + let v: &[T] = bytemuck::try_cast_slice(values) + .map_err(|_| Error::OutOfSpec("buffer not aligned for mmap".to_string()))?; + + if v.len() < num_rows { + return Err(Error::OutOfSpec( + "buffer's length is too small in mmap".to_string(), + )); + } + + Ok(values) +} + +fn get_validity<'a>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + null_count: usize, +) -> Result, Error> { + let validity = get_buffer_bounds(buffers)?; + let (offset, length) = validity; + + Ok(if null_count > 0 { + // verify that they are in-bounds and get its pointer + Some( + data.get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| Error::OutOfSpec("buffer out of bounds".to_string()))?, + ) + } else { + None + }) +} + +fn mmap_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, 
null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + let values = get_buffer::(data_ref, block_offset, buffers, 0)?.as_ptr(); + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets), Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_fixed_size_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, + data_type: &DataType, +) -> Result { + let bytes_per_row = if let DataType::FixedSizeBinary(bytes_per_row) = data_type { + bytes_per_row + } else { + return Err(Error::from(OutOfSpecKind::InvalidDataType)); + }; + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + let values = + get_buffer::(data_ref, block_offset, buffers, num_rows * bytes_per_row)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_null>( + data: Arc, + node: &Node, + _block_offset: usize, + _buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_boolean>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer_bounds(buffers)?; + let (offset, length) = values; + + // verify that they are in-bounds and get its pointer + let values = data_ref[block_offset + offset..block_offset + offset + length].as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_primitive>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let data_ref = data.as_ref().as_ref(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer::
<T>
(data_ref, block_offset, buffers, num_rows)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_list>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let child = ListArray::::try_get_child(data_type)?.data_type(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + buffers, + )?; + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets)].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_fixed_size_list>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let child = FixedSizeListArray::try_child_and_size(data_type)? + .0 + .data_type(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + buffers, + )?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_struct>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let children = StructArray::try_get_fields(data_type)?; + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = children + .iter() + .map(|f| &f.data_type) + .zip(ipc_field.fields.iter()) + .map(|(child, ipc)| { + get_array( + data.clone(), + block_offset, + child, + ipc, + dictionaries, + field_nodes, + buffers, + ) + }) + .collect::, Error>>()?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + values.into_iter(), + None, + None, + ) + }) +} + 
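For readers tracing the helpers above: the safety argument for every mmapped buffer is that `get_buffer`/`get_validity` bounds-check the slice against the block and alignment-check it for the native type before the raw pointer is handed to `create_array`. A minimal standalone sketch of that check (the `BufferMeta` type and `view_buffer` name are illustrative only; the real helpers take `IpcBuffer`s and return `&[u8]`):

use std::collections::VecDeque;

// Illustrative stand-in for the IPC buffer metadata (offset/length pairs).
struct BufferMeta {
    offset: usize,
    length: usize,
}

// Sketch of the bounds + alignment validation performed by `get_buffer`:
// pop the next buffer description, slice the mmapped bytes, and verify the
// slice can be reinterpreted as `&[T]` with at least `num_rows` elements.
fn view_buffer<'a, T: bytemuck::Pod>(
    data: &'a [u8],
    block_offset: usize,
    buffers: &mut VecDeque<BufferMeta>,
    num_rows: usize,
) -> Result<&'a [T], String> {
    let meta = buffers.pop_front().ok_or("expected another buffer")?;
    let bytes = data
        .get(block_offset + meta.offset..block_offset + meta.offset + meta.length)
        .ok_or("buffer out of bounds")?;
    // Fails if the slice is misaligned for `T` or its length is not a multiple of size_of::<T>().
    let values: &[T] = bytemuck::try_cast_slice(bytes).map_err(|e| e.to_string())?;
    if values.len() < num_rows {
        return Err("buffer too small for the declared row count".to_string());
    }
    Ok(values)
}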
+#[allow(clippy::too_many_arguments)] +fn mmap_dict>( + data: Arc, + node: &Node, + block_offset: usize, + _: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + _: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let dictionary = dictionaries + .get(&ipc_field.dictionary_id.unwrap()) + .ok_or_else(|| Error::oos("Missing dictionary"))? + .clone(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer::(data_ref, block_offset, buffers, num_rows)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + Some(export_array_to_c(dictionary)), + None, + ) + }) +} + +fn get_array>( + data: Arc, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + use crate::datatypes::PhysicalType::*; + let node = field_nodes + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + match data_type.to_physical_type() { + Null => mmap_null(data, &node, block_offset, buffers), + Boolean => mmap_boolean(data, &node, block_offset, buffers), + Primitive(p) => with_match_primitive_type!(p, |$T| { + mmap_primitive::<$T, _>(data, &node, block_offset, buffers) + }), + Utf8 | Binary => mmap_binary::(data, &node, block_offset, buffers), + FixedSizeBinary => mmap_fixed_size_binary(data, &node, block_offset, buffers, data_type), + LargeBinary | LargeUtf8 => mmap_binary::(data, &node, block_offset, buffers), + List => mmap_list::( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + LargeList => mmap_list::( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + FixedSizeList => mmap_fixed_size_list( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + Struct => mmap_struct( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + mmap_dict::<$T, _>( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ) + }), + _ => todo!(), + } +} + +/// Maps a memory region to an [`Array`]. +pub(crate) unsafe fn mmap>( + data: Arc, + block_offset: usize, + data_type: DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result, Error> { + let array = get_array( + data, + block_offset, + &data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + )?; + // The unsafety comes from the fact that `array` is not necessarily valid - + // the IPC file may be corrupted (e.g. invalid offsets or non-utf8 data) + unsafe { try_from(InternalArrowArray::new(array, data_type)) } +} diff --git a/crates/nano-arrow/src/mmap/mod.rs b/crates/nano-arrow/src/mmap/mod.rs new file mode 100644 index 000000000000..58265892ea57 --- /dev/null +++ b/crates/nano-arrow/src/mmap/mod.rs @@ -0,0 +1,227 @@ +//! 
Memory maps regions defined on the IPC format into [`Array`]. +use std::collections::VecDeque; +use std::sync::Arc; + +mod array; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, MessageRef, RecordBatchRef}; + +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::io::ipc::read::file::{get_dictionary_batch, get_record_batch}; +use crate::io::ipc::read::{ + first_dict_field, Dictionaries, FileMetadata, IpcBuffer, Node, OutOfSpecKind, +}; +use crate::io::ipc::{IpcField, CONTINUATION_MARKER}; + +fn read_message( + mut bytes: &[u8], + block: arrow_format::ipc::Block, +) -> Result<(MessageRef, usize), Error> { + let offset: usize = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let block_length: usize = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + bytes = &bytes[offset..]; + let mut message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + + if message_length == CONTINUATION_MARKER { + // continuation marker encountered, read message next + message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + }; + + let message_length: usize = i32::from_le_bytes(message_length) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let message = arrow_format::ipc::MessageRef::read_as_root(&bytes[..message_length]) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + Ok((message, offset + block_length)) +} + +fn get_buffers_nodes( + batch: RecordBatchRef, +) -> Result<(VecDeque, VecDeque), Error> { + let compression = batch.compression()?; + if compression.is_some() { + return Err(Error::nyi( + "mmap can only be done on uncompressed IPC files", + )); + } + + let buffers = batch + .buffers() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageBuffers))?; + let buffers = buffers.iter().collect::>(); + + let field_nodes = batch + .nodes() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageNodes))?; + let field_nodes = field_nodes.iter().collect::>(); + + Ok((buffers, field_nodes)) +} + +unsafe fn _mmap_record>( + fields: &[Field], + ipc_fields: &[IpcField], + data: Arc, + batch: RecordBatchRef, + offset: usize, + dictionaries: &Dictionaries, +) -> Result>, Error> { + let (mut buffers, mut field_nodes) = get_buffers_nodes(batch)?; + + fields + .iter() + .map(|f| &f.data_type) + .cloned() + .zip(ipc_fields) + .map(|(data_type, ipc_field)| { + array::mmap( + data.clone(), + offset, + data_type, + ipc_field, + dictionaries, + &mut field_nodes, + &mut buffers, + ) + }) + .collect::>() + .and_then(Chunk::try_new) +} + +unsafe fn _mmap_unchecked>( + fields: &[Field], + ipc_fields: &[IpcField], + data: Arc, + block: Block, + dictionaries: &Dictionaries, +) -> Result>, Error> { + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_record_batch(message)?; + _mmap_record( + fields, + ipc_fields, + data.clone(), + batch, + offset, + dictionaries, + ) +} + +/// Memory maps an record batch from an IPC file into a [`Chunk`]. +/// # Errors +/// This function errors when: +/// * The IPC file is not valid +/// * the buffers on the file are un-aligned with their corresponding data. 
This can happen when: +/// * the file was written with 8-bit alignment +/// * the file contains type decimal 128 or 256 +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_unchecked>( + metadata: &FileMetadata, + dictionaries: &Dictionaries, + data: Arc, + chunk: usize, +) -> Result>, Error> { + let block = metadata.blocks[chunk]; + + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_record_batch(message)?; + _mmap_record( + &metadata.schema.fields, + &metadata.ipc_schema.fields, + data.clone(), + batch, + offset, + dictionaries, + ) +} + +unsafe fn mmap_dictionary>( + metadata: &FileMetadata, + data: Arc, + block: Block, + dictionaries: &mut Dictionaries, +) -> Result<(), Error> { + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_dictionary_batch(&message)?; + + let id = batch + .id() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = + first_dict_field(id, &metadata.schema.fields, &metadata.ipc_schema.fields)?; + + let batch = batch + .data() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingData))?; + + let value_type = + if let DataType::Dictionary(_, value_type, _) = first_field.data_type.to_logical_type() { + value_type.as_ref() + } else { + return Err(Error::from(OutOfSpecKind::InvalidIdDataType { + requested_id: id, + })); + }; + + // Make a fake schema for the dictionary batch. + let field = Field::new("", value_type.clone(), false); + + let chunk = _mmap_record( + &[field], + &[first_ipc_field.clone()], + data.clone(), + batch, + offset, + dictionaries, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +/// Memory maps dictionaries from an IPC file into +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_dictionaries_unchecked>( + metadata: &FileMetadata, + data: Arc, +) -> Result { + let blocks = if let Some(blocks) = &metadata.dictionaries { + blocks + } else { + return Ok(Default::default()); + }; + + let mut dictionaries = Default::default(); + + blocks + .iter() + .cloned() + .try_for_each(|block| mmap_dictionary(metadata, data.clone(), block, &mut dictionaries))?; + Ok(dictionaries) +} diff --git a/crates/nano-arrow/src/offset.rs b/crates/nano-arrow/src/offset.rs new file mode 100644 index 000000000000..5bd06aeb7e57 --- /dev/null +++ b/crates/nano-arrow/src/offset.rs @@ -0,0 +1,537 @@ +//! Contains the declaration of [`Offset`] +use std::hint::unreachable_unchecked; + +use crate::buffer::Buffer; +use crate::error::Error; +pub use crate::types::Offset; + +/// A wrapper type of [`Vec`] representing the invariants of Arrow's offsets. +/// It is guaranteed to (sound to assume that): +/// * every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Offsets(Vec); + +impl Default for Offsets { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl TryFrom> for Offsets { + type Error = Error; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = Error; + + #[inline] + fn try_from(offsets: Buffer) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = Error; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets.into())) + } +} + +impl From> for OffsetsBuffer { + #[inline] + fn from(offsets: Offsets) -> Self { + Self(offsets.0.into()) + } +} + +impl Offsets { + /// Returns an empty [`Offsets`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()]) + } + + /// Returns an [`Offsets`] whose all lengths are zero. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + Self(vec![O::zero(); length + 1]) + } + + /// Creates a new [`Offsets`] from an iterator of lengths + #[inline] + pub fn try_from_iter>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut offsets = Self::with_capacity(lower); + for item in iterator { + offsets.try_push(item)? + } + Ok(offsets) + } + + /// Returns a new [`Offsets`] with a capacity, allocating at least `capacity + 1` entries. + pub fn with_capacity(capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(capacity + 1); + offsets.push(O::zero()); + Self(offsets) + } + + /// Returns the capacity of [`Offsets`]. + pub fn capacity(&self) -> usize { + self.0.capacity() - 1 + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.0.reserve(additional); + } + + /// Shrinks the capacity of self to fit. + pub fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit(); + } + + /// Pushes a new element with a given length. + /// # Error + /// This function errors iff the new last item is larger than what `O` supports. + /// # Implementation + /// This function: + /// * checks that this length does not overflow + #[inline] + pub fn try_push(&mut self, length: usize) -> Result<(), Error> { + if O::IS_LARGE { + let length = O::from_as_usize(length); + let old_length = self.last(); + let new_length = *old_length + length; + self.0.push(new_length); + Ok(()) + } else { + let length = O::from_usize(length).ok_or(Error::Overflow)?; + + let old_length = self.last(); + let new_length = old_length.checked_add(&length).ok_or(Error::Overflow)?; + self.0.push(new_length); + Ok(()) + } + } + + /// Returns [`Offsets`] assuming that `offsets` fulfills its invariants + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Vec) -> Self { + Self(offsets) + } + + /// Returns the last offset of this container. 
+ #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Returns the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + #[inline] + /// Returns the number of offsets in this container. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Pops the last element + #[inline] + pub fn pop(&mut self) -> Option { + if self.len_proxy() == 0 { + None + } else { + self.0.pop() + } + } + + /// Extends itself with `additional` elements equal to the last offset. + /// This is useful to extend offsets with empty values, e.g. for null slots. + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + let offset = *self.last(); + if additional == 1 { + self.0.push(offset) + } else { + self.0.resize(self.len() + additional, offset) + } + } + + /// Try to create a new [`Offsets`] from a sequence of `lengths` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_from_lengths>(lengths: I) -> Result { + let mut self_ = Self::with_capacity(lengths.size_hint().0); + self_.try_extend_from_lengths(lengths)?; + Ok(self_) + } + + /// Try extend from an iterator of lengths + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_extend_from_lengths>( + &mut self, + lengths: I, + ) -> Result<(), Error> { + let mut total_length = 0; + let mut offset = *self.last(); + let original_offset = offset.to_usize(); + + let lengths = lengths.map(|length| { + total_length += length; + O::from_as_usize(length) + }); + + let offsets = lengths.map(|length| { + offset += length; // this may overflow, checked below + offset + }); + self.0.extend(offsets); + + let last_offset = original_offset + .checked_add(total_length) + .ok_or(Error::Overflow)?; + O::from_usize(last_offset).ok_or(Error::Overflow)?; + Ok(()) + } + + /// Extends itself from another [`Offsets`] + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. 
+ pub fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + let mut length = *self.last(); + let other_length = *other.last(); + // check if the operation would overflow + length.checked_add(&other_length).ok_or(Error::Overflow)?; + + let lengths = other.as_slice().windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Extends itself from another [`Offsets`] sliced by `start, length` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + pub fn try_extend_from_slice( + &mut self, + other: &OffsetsBuffer, + start: usize, + length: usize, + ) -> Result<(), Error> { + if length == 0 { + return Ok(()); + } + let other = &other.0[start..start + length + 1]; + let other_length = other.last().expect("Length to be non-zero"); + let mut length = *self.last(); + // check if the operation would overflow + length.checked_add(other_length).ok_or(Error::Overflow)?; + + let lengths = other.windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Returns the inner [`Vec`]. + #[inline] + pub fn into_inner(self) -> Vec { + self.0 + } +} + +/// Checks that `offsets` is monotonically increasing. +fn try_check_offsets(offsets: &[O]) -> Result<(), Error> { + // this code is carefully constructed to auto-vectorize, don't change naively! + match offsets.first() { + None => Err(Error::oos("offsets must have at least one element")), + Some(first) => { + if *first < O::zero() { + return Err(Error::oos("offsets must be larger than 0")); + } + let mut previous = *first; + let mut any_invalid = false; + + // This loop will auto-vectorize because there is not any break, + // an invalid value will be returned once the whole offsets buffer is processed. + for offset in offsets { + if previous > *offset { + any_invalid = true + } + previous = *offset; + } + + if any_invalid { + Err(Error::oos("offsets must be monotonically increasing")) + } else { + Ok(()) + } + }, + } +} + +/// A wrapper type of [`Buffer`] that is guaranteed to: +/// * Always contain an element +/// * Every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. +#[derive(Clone, PartialEq, Debug)] +pub struct OffsetsBuffer(Buffer); + +impl Default for OffsetsBuffer { + #[inline] + fn default() -> Self { + Self(vec![O::zero()].into()) + } +} + +impl OffsetsBuffer { + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Buffer) -> Self { + Self(offsets) + } + + /// Returns an empty [`OffsetsBuffer`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()].into()) + } + + /// Copy-on-write API to convert [`OffsetsBuffer`] into [`Offsets`]. + #[inline] + pub fn into_mut(self) -> either::Either> { + self.0 + .into_mut() + // Safety: Offsets and OffsetsBuffer share invariants + .map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) }) + .map_left(Self) + } + + /// Returns a reference to its internal [`Buffer`]. + #[inline] + pub fn buffer(&self) -> &Buffer { + &self.0 + } + + /// Returns the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + /// Returns the number of offsets in this container. 
+ #[inline] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Returns the range of the offsets. + #[inline] + pub fn range(&self) -> O { + *self.last() - *self.first() + } + + /// Returns the first offset. + #[inline] + pub fn first(&self) -> &O { + match self.0.first() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns the last offset. + #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Slices this [`OffsetsBuffer`]. + /// # Panics + /// Panics if `offset + length` is larger than `len` + /// or `length == 0`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(length > 0); + self.0.slice(offset, length); + } + + /// Slices this [`OffsetsBuffer`] starting at `offset`. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.0.slice_unchecked(offset, length); + } + + /// Returns an iterator with the lengths of the offsets + #[inline] + pub fn lengths(&self) -> impl Iterator + '_ { + self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) + } + + /// Returns the inner [`Buffer`]. 
+ #[inline] + pub fn into_inner(self) -> Buffer { + self.0 + } +} + +impl From<&OffsetsBuffer> for OffsetsBuffer { + fn from(offsets: &OffsetsBuffer) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .buffer() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(), + ) + } +} + +impl TryFrom<&OffsetsBuffer> for OffsetsBuffer { + type Error = Error; + + fn try_from(offsets: &OffsetsBuffer) -> Result { + i32::try_from(*offsets.last()).map_err(|_| Error::Overflow)?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .buffer() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(), + )) + } +} + +impl From> for Offsets { + fn from(offsets: Offsets) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i64) + .collect::>(), + ) + } +} + +impl TryFrom> for Offsets { + type Error = Error; + + fn try_from(offsets: Offsets) -> Result { + i32::try_from(*offsets.last()).map_err(|_| Error::Overflow)?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i32) + .collect::>(), + )) + } +} + +impl std::ops::Deref for OffsetsBuffer { + type Target = [O]; + + #[inline] + fn deref(&self) -> &[O] { + self.0.as_slice() + } +} diff --git a/crates/nano-arrow/src/scalar/README.md b/crates/nano-arrow/src/scalar/README.md new file mode 100644 index 000000000000..b780081b6131 --- /dev/null +++ b/crates/nano-arrow/src/scalar/README.md @@ -0,0 +1,28 @@ +# Scalar API + +Design choices: + +### `Scalar` is trait object + +There are three reasons: + +- a scalar should have a small memory footprint, which an enum would not ensure given the different physical types available. +- forward-compatibility: a new entry on an `enum` is backward-incompatible +- do not expose implementation details to users (reduce the surface of the public API) + +### `Scalar` MUST contain nullability information + +This is to be aligned with the general notion of arrow's `Array`. + +This API is a companion to the `Array`, and follows the same design as `Array`. +Specifically, a `Scalar` is a trait object that can be downcasted to concrete implementations. + +Like `Array`, `Scalar` implements + +- `data_type`, which is used to perform the correct downcast +- `is_valid`, to tell whether the scalar is null or not + +### There is one implementation per arrows' physical type + +- Reduces the number of `match` that users need to write +- Allows casting of logical types without changing the underlying physical type diff --git a/crates/nano-arrow/src/scalar/binary.rs b/crates/nano-arrow/src/scalar/binary.rs new file mode 100644 index 000000000000..0d33f6f8f7e4 --- /dev/null +++ b/crates/nano-arrow/src/scalar/binary.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The [`Scalar`] implementation of binary ([`Option>`]). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BinaryScalar { + value: Option>, + phantom: std::marker::PhantomData, +} + +impl BinaryScalar { + /// Returns a new [`BinaryScalar`]. + #[inline] + pub fn new>>(value: Option
<P>
) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl<O: Offset, P: Into<Vec<u8>>> From<Option<P>> for BinaryScalar<O> { + #[inline] + fn from(v: Option
<P>
) -> Self { + Self::new(v) + } +} + +impl Scalar for BinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeBinary + } else { + &DataType::Binary + } + } +} diff --git a/crates/nano-arrow/src/scalar/boolean.rs b/crates/nano-arrow/src/scalar/boolean.rs new file mode 100644 index 000000000000..aa7d435859af --- /dev/null +++ b/crates/nano-arrow/src/scalar/boolean.rs @@ -0,0 +1,46 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// The [`Scalar`] implementation of a boolean. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BooleanScalar { + value: Option, +} + +impl BooleanScalar { + /// Returns a new [`BooleanScalar`] + #[inline] + pub fn new(value: Option) -> Self { + Self { value } + } + + /// The value + #[inline] + pub fn value(&self) -> Option { + self.value + } +} + +impl Scalar for BooleanScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Boolean + } +} + +impl From> for BooleanScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(v) + } +} diff --git a/crates/nano-arrow/src/scalar/dictionary.rs b/crates/nano-arrow/src/scalar/dictionary.rs new file mode 100644 index 000000000000..97e3e5916f52 --- /dev/null +++ b/crates/nano-arrow/src/scalar/dictionary.rs @@ -0,0 +1,54 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The [`DictionaryArray`] equivalent of [`Array`] for [`Scalar`]. +#[derive(Debug, Clone)] +pub struct DictionaryScalar { + value: Option>, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for DictionaryScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) && (self.value.as_ref() == other.value.as_ref()) + } +} + +impl DictionaryScalar { + /// returns a new [`DictionaryScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, value: Option>) -> Self { + Self { + value, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`DictionaryScalar`] + pub fn value(&self) -> Option<&Box> { + self.value.as_ref() + } +} + +impl Scalar for DictionaryScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.value.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/equal.rs b/crates/nano-arrow/src/scalar/equal.rs new file mode 100644 index 000000000000..34f98d23640d --- /dev/null +++ b/crates/nano-arrow/src/scalar/equal.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use super::*; +use crate::datatypes::PhysicalType; + +impl PartialEq for dyn Scalar + '_ { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(self, that) + } +} + +impl PartialEq for Arc { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +macro_rules! 
dyn_eq { + ($ty:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); + let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); + lhs == rhs + }}; +} + +fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { + if lhs.data_type() != rhs.data_type() { + return false; + } + + use PhysicalType::*; + match lhs.data_type().to_physical_type() { + Null => dyn_eq!(NullScalar, lhs, rhs), + Boolean => dyn_eq!(BooleanScalar, lhs, rhs), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + dyn_eq!(PrimitiveScalar<$T>, lhs, rhs) + }), + LargeUtf8 => dyn_eq!(Utf8Scalar, lhs, rhs), + LargeBinary => dyn_eq!(BinaryScalar, lhs, rhs), + LargeList => dyn_eq!(ListScalar, lhs, rhs), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + dyn_eq!(DictionaryScalar<$T>, lhs, rhs) + }), + Struct => dyn_eq!(StructScalar, lhs, rhs), + FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), + FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), + Union => dyn_eq!(UnionScalar, lhs, rhs), + Map => dyn_eq!(MapScalar, lhs, rhs), + _ => unimplemented!(), + } +} diff --git a/crates/nano-arrow/src/scalar/fixed_size_binary.rs b/crates/nano-arrow/src/scalar/fixed_size_binary.rs new file mode 100644 index 000000000000..d8fbb96bac2c --- /dev/null +++ b/crates/nano-arrow/src/scalar/fixed_size_binary.rs @@ -0,0 +1,58 @@ +use super::Scalar; +use crate::datatypes::DataType; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// The [`Scalar`] implementation of fixed size binary ([`Option>`]). +pub struct FixedSizeBinaryScalar { + value: Option>, + data_type: DataType, +} + +impl FixedSizeBinaryScalar { + /// Returns a new [`FixedSizeBinaryScalar`]. + /// # Panics + /// iff + /// * the `data_type` is not `FixedSizeBinary` + /// * the size of child binary is not equal + #[inline] + pub fn new>>(data_type: DataType, value: Option
<P>
) -> Self { + assert_eq!( + data_type.to_physical_type(), + crate::datatypes::PhysicalType::FixedSizeBinary + ); + Self { + value: value.map(|x| { + let x: Vec = x.into(); + assert_eq!( + data_type.to_logical_type(), + &DataType::FixedSizeBinary(x.len()) + ); + x.into_boxed_slice() + }), + data_type, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl Scalar for FixedSizeBinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/fixed_size_list.rs b/crates/nano-arrow/src/scalar/fixed_size_list.rs new file mode 100644 index 000000000000..b8333c02c347 --- /dev/null +++ b/crates/nano-arrow/src/scalar/fixed_size_list.rs @@ -0,0 +1,60 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The scalar equivalent of [`FixedSizeListArray`]. Like [`FixedSizeListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct FixedSizeListScalar { + values: Option>, + data_type: DataType, +} + +impl PartialEq for FixedSizeListScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.values.is_some() == other.values.is_some()) + && ((self.values.is_none()) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl FixedSizeListScalar { + /// returns a new [`FixedSizeListScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `FixedSizeList` + /// * the child of the `data_type` is not equal to the `values` + /// * the size of child array is not equal + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); + let inner_data_type = field.data_type(); + let values = values.map(|x| { + assert_eq!(inner_data_type, x.data_type()); + assert_eq!(size, x.len()); + x + }); + Self { values, data_type } + } + + /// The values of the [`FixedSizeListScalar`] + pub fn values(&self) -> Option<&Box> { + self.values.as_ref() + } +} + +impl Scalar for FixedSizeListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.values.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/list.rs b/crates/nano-arrow/src/scalar/list.rs new file mode 100644 index 000000000000..d82bf02768bf --- /dev/null +++ b/crates/nano-arrow/src/scalar/list.rs @@ -0,0 +1,68 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The scalar equivalent of [`ListArray`]. Like [`ListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. 
+#[derive(Debug, Clone)] +pub struct ListScalar { + values: Box, + is_valid: bool, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for ListScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl ListScalar { + /// returns a new [`ListScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_data_type = ListArray::::get_child_type(&data_type); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + }, + None => (false, new_empty_array(inner_data_type.clone())), + }; + Self { + values, + is_valid, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`ListScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for ListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/map.rs b/crates/nano-arrow/src/scalar/map.rs new file mode 100644 index 000000000000..90145fb6a30f --- /dev/null +++ b/crates/nano-arrow/src/scalar/map.rs @@ -0,0 +1,66 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The scalar equivalent of [`MapArray`]. Like [`MapArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct MapScalar { + values: Box, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for MapScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl MapScalar { + /// returns a new [`MapScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `Map` + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_field = MapArray::try_get_field(&data_type).unwrap(); + let inner_data_type = inner_field.data_type(); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + }, + None => (false, new_empty_array(inner_data_type.clone())), + }; + Self { + values, + is_valid, + data_type, + } + } + + /// The values of the [`MapScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for MapScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/mod.rs b/crates/nano-arrow/src/scalar/mod.rs new file mode 100644 index 000000000000..7b78b93a44f2 --- /dev/null +++ b/crates/nano-arrow/src/scalar/mod.rs @@ -0,0 +1,187 @@ +//! contains the [`Scalar`] trait object representing individual items of [`Array`](crate::array::Array)s, +//! as well as concrete implementations such as [`BooleanScalar`]. 
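As a usage illustration of the trait-object design described in the module doc above (and in src/scalar/README.md): a consumer inspects `data_type()` to pick the concrete type and then downcasts through `as_any()`. A hedged sketch, assuming the crate is imported under the name `nano_arrow`:

use nano_arrow::datatypes::DataType;
use nano_arrow::scalar::{BooleanScalar, Scalar};

// Match on the logical type, then downcast to the concrete scalar.
fn describe(scalar: &dyn Scalar) -> String {
    match scalar.data_type() {
        DataType::Boolean => {
            let s = scalar.as_any().downcast_ref::<BooleanScalar>().unwrap();
            format!("boolean, valid: {}, value: {:?}", s.is_valid(), s.value())
        },
        other => format!("unhandled type: {other:?}"),
    }
}

fn main() {
    let s: Box<dyn Scalar> = Box::new(BooleanScalar::new(Some(true)));
    assert_eq!(describe(s.as_ref()), "boolean, valid: true, value: Some(true)");
}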
+use std::any::Any; + +use crate::array::*; +use crate::datatypes::*; + +mod dictionary; +pub use dictionary::*; +mod equal; +mod primitive; +pub use primitive::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod boolean; +pub use boolean::*; +mod list; +pub use list::*; +mod map; +pub use map::*; +mod null; +pub use null::*; +mod struct_; +pub use struct_::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod fixed_size_binary; +pub use fixed_size_binary::*; +mod union; +pub use union::UnionScalar; + +/// Trait object declaring an optional value with a [`DataType`]. +/// This strait is often used in APIs that accept multiple scalar types. +pub trait Scalar: std::fmt::Debug + Send + Sync + dyn_clone::DynClone + 'static { + /// convert itself to + fn as_any(&self) -> &dyn Any; + + /// whether it is valid + fn is_valid(&self) -> bool; + + /// the logical type. + fn data_type(&self) -> &DataType; +} + +dyn_clone::clone_trait_object!(Scalar); + +macro_rules! dyn_new_utf8 { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(Utf8Scalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_binary { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(BinaryScalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_list { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index).into()) + } else { + None + }; + Box::new(ListScalar::<$type>::new(array.data_type().clone(), value)) + }}; +} + +/// creates a new [`Scalar`] from an [`Array`]. 
+pub fn new_scalar(array: &dyn Array, index: usize) -> Box { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => Box::new(NullScalar::new()), + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(BooleanScalar::new(value)) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(PrimitiveScalar::new(array.data_type().clone(), value)) + }), + Utf8 => dyn_new_utf8!(array, index, i32), + LargeUtf8 => dyn_new_utf8!(array, index, i64), + Binary => dyn_new_binary!(array, index, i32), + LargeBinary => dyn_new_binary!(array, index, i64), + List => dyn_new_list!(array, index, i32), + LargeList => dyn_new_list!(array, index, i64), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_valid(index) { + let values = array + .values() + .iter() + .map(|x| new_scalar(x.as_ref(), index)) + .collect(); + Box::new(StructScalar::new(array.data_type().clone(), Some(values))) + } else { + Box::new(StructScalar::new(array.data_type().clone(), None)) + } + }, + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeBinaryScalar::new(array.data_type().clone(), value)) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeListScalar::new(array.data_type().clone(), value)) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + Box::new(UnionScalar::new( + array.data_type().clone(), + array.types()[index], + array.value(index), + )) + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(MapScalar::new(array.data_type().clone(), value)) + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index).into()) + } else { + None + }; + Box::new(DictionaryScalar::<$T>::new( + array.data_type().clone(), + value, + )) + }), + } +} diff --git a/crates/nano-arrow/src/scalar/null.rs b/crates/nano-arrow/src/scalar/null.rs new file mode 100644 index 000000000000..2de7d7cde55b --- /dev/null +++ b/crates/nano-arrow/src/scalar/null.rs @@ -0,0 +1,37 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// The representation of a single entry of a [`crate::array::NullArray`]. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NullScalar {} + +impl NullScalar { + /// A new [`NullScalar`] + #[inline] + pub fn new() -> Self { + Self {} + } +} + +impl Default for NullScalar { + fn default() -> Self { + Self::new() + } +} + +impl Scalar for NullScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + false + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Null + } +} diff --git a/crates/nano-arrow/src/scalar/primitive.rs b/crates/nano-arrow/src/scalar/primitive.rs new file mode 100644 index 000000000000..3288708f6755 --- /dev/null +++ b/crates/nano-arrow/src/scalar/primitive.rs @@ -0,0 +1,67 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::types::NativeType; + +/// The implementation of [`Scalar`] for primitive, semantically equivalent to [`Option`] +/// with [`DataType`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrimitiveScalar { + value: Option, + data_type: DataType, +} + +impl PrimitiveScalar { + /// Returns a new [`PrimitiveScalar`]. + #[inline] + pub fn new(data_type: DataType, value: Option) -> Self { + if !data_type.to_physical_type().eq_primitive(T::PRIMITIVE) { + panic!( + "{:?}", + Error::InvalidArgumentError(format!( + "Type {} does not support logical type {:?}", + std::any::type_name::(), + data_type + )) + ) + } + Self { value, data_type } + } + + /// Returns the optional value. + #[inline] + pub fn value(&self) -> &Option { + &self.value + } + + /// Returns a new `PrimitiveScalar` with the same value but different [`DataType`] + /// # Panic + /// This function panics if the `data_type` is not valid for self's physical type `T`. + pub fn to(self, data_type: DataType) -> Self { + Self::new(data_type, self.value) + } +} + +impl From> for PrimitiveScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(T::PRIMITIVE.into(), v) + } +} + +impl Scalar for PrimitiveScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/struct_.rs b/crates/nano-arrow/src/scalar/struct_.rs new file mode 100644 index 000000000000..29c2c33ba295 --- /dev/null +++ b/crates/nano-arrow/src/scalar/struct_.rs @@ -0,0 +1,54 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// A single entry of a [`crate::array::StructArray`]. +#[derive(Debug, Clone)] +pub struct StructScalar { + values: Vec>, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for StructScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values == other.values)) + } +} + +impl StructScalar { + /// Returns a new [`StructScalar`] + #[inline] + pub fn new(data_type: DataType, values: Option>>) -> Self { + let is_valid = values.is_some(); + Self { + values: values.unwrap_or_default(), + is_valid, + data_type, + } + } + + /// Returns the values irrespectively of the validity. 
+ #[inline] + pub fn values(&self) -> &[Box] { + &self.values + } +} + +impl Scalar for StructScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/union.rs b/crates/nano-arrow/src/scalar/union.rs new file mode 100644 index 000000000000..987e9f4e6044 --- /dev/null +++ b/crates/nano-arrow/src/scalar/union.rs @@ -0,0 +1,51 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// A single entry of a [`crate::array::UnionArray`]. +#[derive(Debug, Clone, PartialEq)] +pub struct UnionScalar { + value: Box, + type_: i8, + data_type: DataType, +} + +impl UnionScalar { + /// Returns a new [`UnionScalar`] + #[inline] + pub fn new(data_type: DataType, type_: i8, value: Box) -> Self { + Self { + value, + type_, + data_type, + } + } + + /// Returns the inner value + #[inline] + pub fn value(&self) -> &Box { + &self.value + } + + /// Returns the type of the union scalar + #[inline] + pub fn type_(&self) -> i8 { + self.type_ + } +} + +impl Scalar for UnionScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + true + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/utf8.rs b/crates/nano-arrow/src/scalar/utf8.rs new file mode 100644 index 000000000000..ea08d30af578 --- /dev/null +++ b/crates/nano-arrow/src/scalar/utf8.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The implementation of [`Scalar`] for utf8, semantically equivalent to [`Option`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Utf8Scalar { + value: Option, + phantom: std::marker::PhantomData, +} + +impl Utf8Scalar { + /// Returns a new [`Utf8Scalar`] + #[inline] + pub fn new>(value: Option
<P>
) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Returns the value irrespectively of the validity. + #[inline] + pub fn value(&self) -> Option<&str> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl<O: Offset, P: Into<String>> From<Option<P>> for Utf8Scalar<O> { + #[inline] + fn from(v: Option
<P>
) -> Self { + Self::new(v) + } +} + +impl Scalar for Utf8Scalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeUtf8 + } else { + &DataType::Utf8 + } + } +} diff --git a/crates/nano-arrow/src/temporal_conversions.rs b/crates/nano-arrow/src/temporal_conversions.rs new file mode 100644 index 000000000000..5058d1d887bd --- /dev/null +++ b/crates/nano-arrow/src/temporal_conversions.rs @@ -0,0 +1,543 @@ +//! Conversion methods for dates and times. + +use chrono::format::{parse, Parsed, StrftimeItems}; +use chrono::{Datelike, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; + +use crate::array::{PrimitiveArray, Utf8Array}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::types::months_days_ns; + +/// Number of seconds in a day +pub const SECONDS_IN_DAY: i64 = 86_400; +/// Number of milliseconds in a second +pub const MILLISECONDS: i64 = 1_000; +/// Number of microseconds in a second +pub const MICROSECONDS: i64 = 1_000_000; +/// Number of nanoseconds in a second +pub const NANOSECONDS: i64 = 1_000_000_000; +/// Number of milliseconds in a day +pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime(v: i32) -> NaiveDateTime { + date32_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime_opt(v: i32) -> Option { + NaiveDateTime::from_timestamp_opt(v as i64 * SECONDS_IN_DAY, 0) +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date(days: i32) -> NaiveDate { + date32_to_date_opt(days).expect("out-of-range date") +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date_opt(days: i32) -> Option { + NaiveDate::from_num_days_from_ce_opt(EPOCH_DAYS_FROM_CE + days) +} + +/// converts a `i64` representing a `date64` to [`NaiveDateTime`] +#[inline] +pub fn date64_to_datetime(v: i64) -> NaiveDateTime { + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + v / MILLISECONDS, + // discard extracted seconds and convert milliseconds to nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) + .expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `date64` to [`NaiveDate`] +#[inline] +pub fn date64_to_date(milliseconds: i64) -> NaiveDate { + date64_to_datetime(milliseconds).date() +} + +/// converts a `i32` representing a `time32(s)` to [`NaiveTime`] +#[inline] +pub fn time32s_to_time(v: i32) -> NaiveTime { + NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0).expect("invalid time") +} + +/// converts a `i64` representing a `duration(s)` to [`Duration`] +#[inline] +pub fn duration_s_to_duration(v: i64) -> Duration { + Duration::seconds(v) +} + +/// converts a `i64` representing a `duration(ms)` to [`Duration`] +#[inline] +pub fn duration_ms_to_duration(v: i64) -> Duration { + Duration::milliseconds(v) +} + +/// converts a `i64` representing a `duration(us)` to [`Duration`] +#[inline] +pub fn duration_us_to_duration(v: i64) -> Duration { + Duration::microseconds(v) +} + 
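A quick self-contained check of the `date32` arithmetic above, using plain chrono only (the concrete dates are worked examples, not values taken from this diff):

use chrono::NaiveDate;

fn main() {
    // Same constant as above: days between 0001-01-01 and 1970-01-01.
    const EPOCH_DAYS_FROM_CE: i32 = 719_163;

    // date32 value 0 is the UNIX epoch itself.
    assert_eq!(
        NaiveDate::from_num_days_from_ce_opt(EPOCH_DAYS_FROM_CE).unwrap(),
        NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()
    );

    // 19_000 days after the epoch lands on 2022-01-08.
    let days: i32 = 19_000;
    assert_eq!(
        NaiveDate::from_num_days_from_ce_opt(EPOCH_DAYS_FROM_CE + days).unwrap(),
        NaiveDate::from_ymd_opt(2022, 1, 8).unwrap()
    );
}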
+/// converts a `i64` representing a `duration(ns)` to [`Duration`] +#[inline] +pub fn duration_ns_to_duration(v: i64) -> Duration { + Duration::nanoseconds(v) +} + +/// converts a `i32` representing a `time32(ms)` to [`NaiveTime`] +#[inline] +pub fn time32ms_to_time(v: i32) -> NaiveTime { + let v = v as i64; + let seconds = v / MILLISECONDS; + + let milli_to_nano = 1_000_000; + let nano = (v - seconds * MILLISECONDS) * milli_to_nano; + NaiveTime::from_num_seconds_from_midnight_opt(seconds as u32, nano as u32) + .expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time(v: i64) -> NaiveTime { + time64us_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from microseconds + (v / MICROSECONDS) as u32, + // discard extracted seconds and convert microseconds to + // nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time(v: i64) -> NaiveTime { + time64ns_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from nanoseconds + (v / NANOSECONDS) as u32, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime(seconds: i64) -> NaiveDateTime { + timestamp_s_to_datetime_opt(seconds).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime_opt(seconds: i64) -> Option { + NaiveDateTime::from_timestamp_opt(seconds, 0) +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ms_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + v / MILLISECONDS, + // discard extracted seconds and convert milliseconds to nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) + } else { + let secs_rem = (v / MILLISECONDS, v % MILLISECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. 
+ // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % MILLISECONDS * MICROSECONDS)) as u32, + ) + } + } +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime { + timestamp_us_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from microseconds + v / MICROSECONDS, + // discard extracted seconds and convert microseconds to nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) + } else { + let secs_rem = (v / MICROSECONDS, v % MICROSECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. + // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % MICROSECONDS * MILLISECONDS)) as u32, + ) + } + } +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ns_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from nanoseconds + v / NANOSECONDS, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) + } else { + let secs_rem = (v / NANOSECONDS, v % NANOSECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. + // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % NANOSECONDS)) as u32, + ) + } + } +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub fn timestamp_to_naive_datetime(timestamp: i64, time_unit: TimeUnit) -> chrono::NaiveDateTime { + match time_unit { + TimeUnit::Second => timestamp_s_to_datetime(timestamp), + TimeUnit::Millisecond => timestamp_ms_to_datetime(timestamp), + TimeUnit::Microsecond => timestamp_us_to_datetime(timestamp), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(timestamp), + } +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub fn timestamp_to_datetime( + timestamp: i64, + time_unit: TimeUnit, + timezone: &T, +) -> chrono::DateTime { + timezone.from_utc_datetime(×tamp_to_naive_datetime(timestamp, time_unit)) +} + +/// Calculates the scale factor between two TimeUnits. The function returns the +/// scale that should multiply the TimeUnit "b" to have the same time scale as +/// the TimeUnit "a". 
+pub fn timeunit_scale(a: TimeUnit, b: TimeUnit) -> f64 { + match (a, b) { + (TimeUnit::Second, TimeUnit::Second) => 1.0, + (TimeUnit::Second, TimeUnit::Millisecond) => 0.001, + (TimeUnit::Second, TimeUnit::Microsecond) => 0.000_001, + (TimeUnit::Second, TimeUnit::Nanosecond) => 0.000_000_001, + (TimeUnit::Millisecond, TimeUnit::Second) => 1_000.0, + (TimeUnit::Millisecond, TimeUnit::Millisecond) => 1.0, + (TimeUnit::Millisecond, TimeUnit::Microsecond) => 0.001, + (TimeUnit::Millisecond, TimeUnit::Nanosecond) => 0.000_001, + (TimeUnit::Microsecond, TimeUnit::Second) => 1_000_000.0, + (TimeUnit::Microsecond, TimeUnit::Millisecond) => 1_000.0, + (TimeUnit::Microsecond, TimeUnit::Microsecond) => 1.0, + (TimeUnit::Microsecond, TimeUnit::Nanosecond) => 0.001, + (TimeUnit::Nanosecond, TimeUnit::Second) => 1_000_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Millisecond) => 1_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Microsecond) => 1_000.0, + (TimeUnit::Nanosecond, TimeUnit::Nanosecond) => 1.0, + } +} + +/// Parses an offset of the form `"+WX:YZ"` or `"UTC"` into [`FixedOffset`]. +/// # Errors +/// If the offset is not in any of the allowed forms. +pub fn parse_offset(offset: &str) -> Result { + if offset == "UTC" { + return Ok(FixedOffset::east_opt(0).expect("FixedOffset::east out of bounds")); + } + let error = "timezone offset must be of the form [-]00:00"; + + let mut a = offset.split(':'); + let first = a + .next() + .map(Ok) + .unwrap_or_else(|| Err(Error::InvalidArgumentError(error.to_string())))?; + let last = a + .next() + .map(Ok) + .unwrap_or_else(|| Err(Error::InvalidArgumentError(error.to_string())))?; + let hours: i32 = first + .parse() + .map_err(|_| Error::InvalidArgumentError(error.to_string()))?; + let minutes: i32 = last + .parse() + .map_err(|_| Error::InvalidArgumentError(error.to_string()))?; + + Ok(FixedOffset::east_opt(hours * 60 * 60 + minutes * 60) + .expect("FixedOffset::east out of bounds")) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +#[inline] +pub fn utf8_to_timestamp_ns_scalar( + value: &str, + fmt: &str, + tz: &T, +) -> Option { + utf8_to_timestamp_scalar(value, fmt, tz, &TimeUnit::Nanosecond) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +/// Returns in scale `tz` of `TimeUnit`. +#[inline] +pub fn utf8_to_timestamp_scalar( + value: &str, + fmt: &str, + tz: &T, + tu: &TimeUnit, +) -> Option { + let mut parsed = Parsed::new(); + let fmt = StrftimeItems::new(fmt); + let r = parse(&mut parsed, value, fmt).ok(); + if r.is_some() { + parsed + .to_datetime() + .map(|x| x.naive_utc()) + .map(|x| tz.from_utc_datetime(&x)) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() + } else { + None + } +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. +#[inline] +pub fn utf8_to_naive_timestamp_ns_scalar(value: &str, fmt: &str) -> Option { + utf8_to_naive_timestamp_scalar(value, fmt, &TimeUnit::Nanosecond) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. +/// Returns in scale `tz` of `TimeUnit`. 
+#[inline] +pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> Option { + let fmt = StrftimeItems::new(fmt); + let mut parsed = Parsed::new(); + parse(&mut parsed, value, fmt.clone()).ok(); + parsed + .to_naive_datetime_with_offset(0) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() +} + +fn utf8_to_timestamp_ns_impl( + array: &Utf8Array, + fmt: &str, + timezone: String, + tz: T, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_timestamp_ns_scalar(x, fmt, &tz))); + + PrimitiveArray::from_trusted_len_iter(iter) + .to(DataType::Timestamp(TimeUnit::Nanosecond, Some(timezone))) +} + +/// Parses `value` to a [`chrono_tz::Tz`] with the Arrow's definition of timestamp with a timezone. +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +pub fn parse_offset_tz(timezone: &str) -> Result { + timezone.parse::().map_err(|_| { + Error::InvalidArgumentError(format!("timezone \"{timezone}\" cannot be parsed")) + }) +} + +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +fn chrono_tz_utf_to_timestamp_ns( + array: &Utf8Array, + fmt: &str, + timezone: String, +) -> Result> { + let tz = parse_offset_tz(&timezone)?; + Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz_utf_to_timestamp_ns( + _: &Utf8Array, + _: &str, + timezone: String, +) -> Result> { + Err(Error::InvalidArgumentError(format!( + "timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)", + ))) +} + +/// Parses a [`Utf8Array`] to a timeozone-aware timestamp, i.e. [`PrimitiveArray`] with type `Timestamp(Nanosecond, Some(timezone))`. +/// # Implementation +/// * parsed values with timezone other than `timezone` are converted to `timezone`. +/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp_ns`] to parse naive timezones. +/// * Null elements remain null; non-parsable elements are null. +/// The feature `"chrono-tz"` enables IANA and zoneinfo formats for `timezone`. +/// # Error +/// This function errors iff `timezone` is not parsable to an offset. +pub fn utf8_to_timestamp_ns( + array: &Utf8Array, + fmt: &str, + timezone: String, +) -> Result> { + let tz = parse_offset(timezone.as_str()); + + if let Ok(tz) = tz { + Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) + } else { + chrono_tz_utf_to_timestamp_ns(array, fmt, timezone) + } +} + +/// Parses a [`Utf8Array`] to naive timestamp, i.e. +/// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. +/// Timezones are ignored. +/// Null elements remain null; non-parsable elements are set to null. 
+pub fn utf8_to_naive_timestamp_ns( + array: &Utf8Array, + fmt: &str, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_naive_timestamp_ns_scalar(x, fmt))); + + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(TimeUnit::Nanosecond, None)) +} + +fn add_month(year: i32, month: u32, months: i32) -> chrono::NaiveDate { + let new_year = (year * 12 + (month - 1) as i32 + months) / 12; + let new_month = (year * 12 + (month - 1) as i32 + months) % 12 + 1; + chrono::NaiveDate::from_ymd_opt(new_year, new_month as u32, 1) + .expect("invalid or out-of-range date") +} + +fn get_days_between_months(year: i32, month: u32, months: i32) -> i64 { + add_month(year, month, months) + .signed_duration_since( + chrono::NaiveDate::from_ymd_opt(year, month, 1).expect("invalid or out-of-range date"), + ) + .num_days() +} + +/// Adds an `interval` to a `timestamp` in `time_unit` units without timezone. +#[inline] +pub fn add_naive_interval(timestamp: i64, time_unit: TimeUnit, interval: months_days_ns) -> i64 { + // convert seconds to a DateTime of a given offset. + let datetime = match time_unit { + TimeUnit::Second => timestamp_s_to_datetime(timestamp), + TimeUnit::Millisecond => timestamp_ms_to_datetime(timestamp), + TimeUnit::Microsecond => timestamp_us_to_datetime(timestamp), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(timestamp), + }; + + // compute the number of days in the interval, which depends on the particular year and month (leap days) + let delta_days = get_days_between_months(datetime.year(), datetime.month(), interval.months()) + + interval.days() as i64; + + // add; no leap hours are considered + let new_datetime_tz = datetime + + chrono::Duration::nanoseconds(delta_days * 24 * 60 * 60 * 1_000_000_000 + interval.ns()); + + // convert back to the target unit + match time_unit { + TimeUnit::Second => new_datetime_tz.timestamp_millis() / 1000, + TimeUnit::Millisecond => new_datetime_tz.timestamp_millis(), + TimeUnit::Microsecond => new_datetime_tz.timestamp_nanos_opt().unwrap() / 1000, + TimeUnit::Nanosecond => new_datetime_tz.timestamp_nanos_opt().unwrap(), + } +} + +/// Adds an `interval` to a `timestamp` in `time_unit` units and timezone `timezone`. +#[inline] +pub fn add_interval( + timestamp: i64, + time_unit: TimeUnit, + interval: months_days_ns, + timezone: &T, +) -> i64 { + // convert seconds to a DateTime of a given offset. + let datetime_tz = timestamp_to_datetime(timestamp, time_unit, timezone); + + // compute the number of days in the interval, which depends on the particular year and month (leap days) + let delta_days = + get_days_between_months(datetime_tz.year(), datetime_tz.month(), interval.months()) + + interval.days() as i64; + + // add; tz will take care of leap hours + let new_datetime_tz = datetime_tz + + chrono::Duration::nanoseconds(delta_days * 24 * 60 * 60 * 1_000_000_000 + interval.ns()); + + // convert back to the target unit + match time_unit { + TimeUnit::Second => new_datetime_tz.timestamp_millis() / 1000, + TimeUnit::Millisecond => new_datetime_tz.timestamp_millis(), + TimeUnit::Microsecond => new_datetime_tz.timestamp_nanos_opt().unwrap() / 1000, + TimeUnit::Nanosecond => new_datetime_tz.timestamp_nanos_opt().unwrap(), + } +} diff --git a/crates/nano-arrow/src/trusted_len.rs b/crates/nano-arrow/src/trusted_len.rs new file mode 100644 index 000000000000..a1c38bd51c71 --- /dev/null +++ b/crates/nano-arrow/src/trusted_len.rs @@ -0,0 +1,57 @@ +//! Declares [`TrustedLen`]. 
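As a brief illustration of the contract this trait encodes, the following sketch (the `Countdown` iterator is made up for the example, and `TrustedLen` is assumed to be in scope) shows what an implementor must guarantee: `size_hint()` always reports the exact remaining length.

```rust
/// A toy iterator with an exactly known length.
struct Countdown {
    remaining: usize,
}

impl Iterator for Countdown {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        Some(self.remaining)
    }

    // The upper bound is always `Some` and always exact; this is the invariant
    // that TrustedLen consumers rely on for unchecked pre-allocation.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

// Safety: `size_hint()` returns the exact number of remaining items.
unsafe impl TrustedLen for Countdown {}
```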
+use std::slice::Iter; + +/// An iterator of known, fixed size. +/// A trait denoting Rusts' unstable [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). +/// This is re-defined here and implemented for some iterators until `std::iter::TrustedLen` +/// is stabilized. +/// +/// # Safety +/// This trait must only be implemented when the contract is upheld. +/// Consumers of this trait must inspect Iterator::size_hint()’s upper bound. +pub unsafe trait TrustedLen: Iterator {} + +unsafe impl TrustedLen for Iter<'_, T> {} + +unsafe impl B> TrustedLen for std::iter::Map {} + +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Copied +where + I: TrustedLen, + T: Copy, +{ +} +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Cloned +where + I: TrustedLen, + T: Clone, +{ +} + +unsafe impl TrustedLen for std::iter::Enumerate where I: TrustedLen {} + +unsafe impl TrustedLen for std::iter::Zip +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::slice::ChunksExact<'_, T> {} + +unsafe impl TrustedLen for std::slice::Windows<'_, T> {} + +unsafe impl TrustedLen for std::iter::Chain +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::iter::Once {} + +unsafe impl TrustedLen for std::vec::IntoIter {} + +unsafe impl TrustedLen for std::iter::Repeat {} +unsafe impl A> TrustedLen for std::iter::RepeatWith {} +unsafe impl TrustedLen for std::iter::Take {} diff --git a/crates/nano-arrow/src/types/bit_chunk.rs b/crates/nano-arrow/src/types/bit_chunk.rs new file mode 100644 index 000000000000..ef4b25fd28a2 --- /dev/null +++ b/crates/nano-arrow/src/types/bit_chunk.rs @@ -0,0 +1,161 @@ +use std::fmt::Binary; +use std::ops::{BitAndAssign, Not, Shl, ShlAssign, ShrAssign}; + +use num_traits::PrimInt; + +use super::NativeType; + +/// A chunk of bits. This is used to create masks of a given length +/// whose width is `1` bit. In `portable_simd` notation, this corresponds to `m1xY`. +/// +/// This (sealed) trait is implemented for [`u8`], [`u16`], [`u32`] and [`u64`]. +pub trait BitChunk: + super::private::Sealed + + PrimInt + + NativeType + + Binary + + ShlAssign + + Not + + ShrAssign + + ShlAssign + + Shl + + BitAndAssign +{ + /// convert itself into bytes. + fn to_ne_bytes(self) -> Self::Bytes; + /// convert itself from bytes. + fn from_ne_bytes(v: Self::Bytes) -> Self; +} + +macro_rules! bit_chunk { + ($ty:ty) => { + impl BitChunk for $ty { + #[inline(always)] + fn to_ne_bytes(self) -> Self::Bytes { + self.to_ne_bytes() + } + + #[inline(always)] + fn from_ne_bytes(v: Self::Bytes) -> Self { + Self::from_ne_bytes(v) + } + } + }; +} + +bit_chunk!(u8); +bit_chunk!(u16); +bit_chunk!(u32); +bit_chunk!(u64); + +/// An [`Iterator`] over a [`BitChunk`]. This iterator is often +/// compiled to SIMD. +/// The [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit) corresponds +/// to the first slot, as defined by the arrow specification. +/// # Example +/// ``` +/// use arrow2::types::BitChunkIter; +/// let a = 0b00010000u8; +/// let iter = BitChunkIter::new(a, 7); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![false, false, false, false, true, false, false]); +/// ``` +pub struct BitChunkIter { + value: T, + mask: T, + remaining: usize, +} + +impl BitChunkIter { + /// Creates a new [`BitChunkIter`] with `len` bits. 
+ #[inline] + pub fn new(value: T, len: usize) -> Self { + assert!(len <= std::mem::size_of::() * 8); + Self { + value, + remaining: len, + mask: T::one(), + } + } +} + +impl Iterator for BitChunkIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + }; + let result = Some(self.value & self.mask != T::zero()); + self.remaining -= 1; + self.mask <<= 1; + result + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkIter {} + +/// An [`Iterator`] over a [`BitChunk`] returning the index of each bit set in the chunk +/// See for details +/// # Example +/// ``` +/// use arrow2::types::BitChunkOnes; +/// let a = 0b00010000u8; +/// let iter = BitChunkOnes::new(a); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![4]); +/// ``` +pub struct BitChunkOnes { + value: T, + remaining: usize, +} + +impl BitChunkOnes { + /// Creates a new [`BitChunkOnes`] with `len` bits. + #[inline] + pub fn new(value: T) -> Self { + Self { + value, + remaining: value.count_ones() as usize, + } + } + + #[inline] + #[cfg(feature = "compute_filter")] + pub(crate) fn from_known_count(value: T, remaining: usize) -> Self { + Self { value, remaining } + } +} + +impl Iterator for BitChunkOnes { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let v = self.value.trailing_zeros() as usize; + self.value &= self.value - T::one(); + + self.remaining -= 1; + Some(v) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkOnes {} diff --git a/crates/nano-arrow/src/types/index.rs b/crates/nano-arrow/src/types/index.rs new file mode 100644 index 000000000000..0aedea008fa3 --- /dev/null +++ b/crates/nano-arrow/src/types/index.rs @@ -0,0 +1,103 @@ +use std::convert::TryFrom; + +use super::NativeType; +use crate::trusted_len::TrustedLen; + +/// Sealed trait describing the subset of [`NativeType`] (`i32`, `i64`, `u32` and `u64`) +/// that can be used to index a slot of an array. +pub trait Index: + NativeType + + std::ops::AddAssign + + std::ops::Sub + + num_traits::One + + num_traits::Num + + num_traits::CheckedAdd + + PartialOrd + + Ord +{ + /// Convert itself to [`usize`]. + fn to_usize(&self) -> usize; + /// Convert itself from [`usize`]. + fn from_usize(index: usize) -> Option; + + /// Convert itself from [`usize`]. + fn from_as_usize(index: usize) -> Self; + + /// An iterator from (inclusive) `start` to (exclusive) `end`. + fn range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => Some(IndexRange::new(start, end)), + _ => None, + } + } +} + +macro_rules! index { + ($t:ty) => { + impl Index for $t { + #[inline] + fn to_usize(&self) -> usize { + *self as usize + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } + + #[inline] + fn from_as_usize(value: usize) -> Self { + value as $t + } + } + }; +} + +index!(i8); +index!(i16); +index!(i32); +index!(i64); +index!(u8); +index!(u16); +index!(u32); +index!(u64); + +/// Range of [`Index`], equivalent to `(a..b)`. 
+/// `Step` is unstable in Rust, which does not allow us to implement (a..b) for [`Index`]. +pub struct IndexRange { + start: I, + end: I, +} + +impl IndexRange { + /// Returns a new [`IndexRange`]. + pub fn new(start: I, end: I) -> Self { + assert!(end >= start); + Self { start, end } + } +} + +impl Iterator for IndexRange { + type Item = I; + + #[inline] + fn next(&mut self) -> Option { + if self.start == self.end { + return None; + } + let old = self.start; + self.start += I::one(); + Some(old) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = (self.end - self.start).to_usize(); + (len, Some(len)) + } +} + +/// Safety: a range is always of known length +unsafe impl TrustedLen for IndexRange {} diff --git a/crates/nano-arrow/src/types/mod.rs b/crates/nano-arrow/src/types/mod.rs new file mode 100644 index 000000000000..2ba57b4d784a --- /dev/null +++ b/crates/nano-arrow/src/types/mod.rs @@ -0,0 +1,89 @@ +//! Sealed traits and implementations to handle all _physical types_ used in this crate. +//! +//! Most physical types used in this crate are native Rust types, such as `i32`. +//! The trait [`NativeType`] describes the interfaces required by this crate to be conformant +//! with Arrow. +//! +//! Every implementation of [`NativeType`] has an associated variant in [`PrimitiveType`], +//! available via [`NativeType::PRIMITIVE`]. +//! Combined, these allow structs generic over [`NativeType`] to be trait objects downcastable +//! to concrete implementations based on the matched [`NativeType::PRIMITIVE`] variant. +//! +//! Another important trait in this module is [`Offset`], the subset of [`NativeType`] that can +//! be used in Arrow offsets (`i32` and `i64`). +//! +//! Another important trait in this module is [`BitChunk`], describing types that can be used to +//! represent chunks of bits (e.g. 8 bits via `u8`, 16 via `u16`), and [`BitChunkIter`], +//! that can be used to iterate over bitmaps in [`BitChunk`]s according to +//! Arrow's definition of bitmaps. +//! +//! Finally, this module contains traits used to compile code based on [`NativeType`] optimized +//! for SIMD, at [`mod@simd`]. + +mod bit_chunk; +pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes}; +mod index; +pub mod simd; +pub use index::*; +mod native; +pub use native::*; +mod offset; +pub use offset::*; +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +/// The set of all implementations of the sealed trait [`NativeType`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum PrimitiveType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// A signed 128-bit integer. + Int128, + /// A signed 256-bit integer. + Int256, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, + /// A 16-bit floating point number. + Float16, + /// A 32-bit floating point number. + Float32, + /// A 64-bit floating point number. 
+ Float64, + /// Two i32 representing days and ms + DaysMs, + /// months_days_ns(i32, i32, i64) + MonthDayNano, +} + +mod private { + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for i128 {} + impl Sealed for super::i256 {} + impl Sealed for super::f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for super::days_ms {} + impl Sealed for super::months_days_ns {} +} diff --git a/crates/nano-arrow/src/types/native.rs b/crates/nano-arrow/src/types/native.rs new file mode 100644 index 000000000000..6e50a1454ead --- /dev/null +++ b/crates/nano-arrow/src/types/native.rs @@ -0,0 +1,639 @@ +use std::convert::TryFrom; +use std::ops::Neg; +use std::panic::RefUnwindSafe; + +use bytemuck::{Pod, Zeroable}; + +use super::PrimitiveType; + +/// Sealed trait implemented by all physical types that can be allocated, +/// serialized and deserialized by this crate. +/// All O(N) allocations in this crate are done for this trait alone. +pub trait NativeType: + super::private::Sealed + + Pod + + Send + + Sync + + Sized + + RefUnwindSafe + + std::fmt::Debug + + std::fmt::Display + + PartialEq + + Default +{ + /// The corresponding variant of [`PrimitiveType`]. + const PRIMITIVE: PrimitiveType; + + /// Type denoting its representation as bytes. + /// This is `[u8; N]` where `N = size_of::`. + type Bytes: AsRef<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default; + + /// To bytes in little endian + fn to_le_bytes(&self) -> Self::Bytes; + + /// To bytes in big endian + fn to_be_bytes(&self) -> Self::Bytes; + + /// From bytes in little endian + fn from_le_bytes(bytes: Self::Bytes) -> Self; + + /// From bytes in big endian + fn from_be_bytes(bytes: Self::Bytes) -> Self; +} + +macro_rules! native_type { + ($type:ty, $primitive_type:expr) => { + impl NativeType for $type { + const PRIMITIVE: PrimitiveType = $primitive_type; + + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + Self::to_be_bytes(*self) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from_le_bytes(bytes) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self::from_be_bytes(bytes) + } + } + }; +} + +native_type!(u8, PrimitiveType::UInt8); +native_type!(u16, PrimitiveType::UInt16); +native_type!(u32, PrimitiveType::UInt32); +native_type!(u64, PrimitiveType::UInt64); +native_type!(i8, PrimitiveType::Int8); +native_type!(i16, PrimitiveType::Int16); +native_type!(i32, PrimitiveType::Int32); +native_type!(i64, PrimitiveType::Int64); +native_type!(f32, PrimitiveType::Float32); +native_type!(f64, PrimitiveType::Float64); +native_type!(i128, PrimitiveType::Int128); + +/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct days_ms(pub i32, pub i32); + +impl days_ms { + /// A new [`days_ms`]. 
+ #[inline] + pub fn new(days: i32, milliseconds: i32) -> Self { + Self(days, milliseconds) + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.0 + } + + /// The number of milliseconds + #[inline] + pub fn milliseconds(&self) -> i32 { + self.1 + } +} + +impl NativeType for days_ms { + const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs; + type Bytes = [u8; 8]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let days = self.0.to_le_bytes(); + let ms = self.1.to_le_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let days = self.0.to_be_bytes(); + let ms = self.1.to_be_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_le_bytes(days), i32::from_le_bytes(ms)) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_be_bytes(days), i32::from_be_bytes(ms)) + } +} + +/// The in-memory representation of the MonthDayNano variant of the "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct months_days_ns(pub i32, pub i32, pub i64); + +impl months_days_ns { + /// A new [`months_days_ns`]. 
+ #[inline] + pub fn new(months: i32, days: i32, nanoseconds: i64) -> Self { + Self(months, days, nanoseconds) + } + + /// The number of months + #[inline] + pub fn months(&self) -> i32 { + self.0 + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.1 + } + + /// The number of nanoseconds + #[inline] + pub fn ns(&self) -> i64 { + self.2 + } +} + +impl NativeType for months_days_ns { + const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano; + type Bytes = [u8; 16]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let months = self.months().to_le_bytes(); + let days = self.days().to_le_bytes(); + let ns = self.ns().to_le_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let months = self.months().to_be_bytes(); + let days = self.days().to_be_bytes(); + let ns = self.ns().to_be_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_le_bytes(months), + i32::from_le_bytes(days), + i64::from_le_bytes(ns), + ) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_be_bytes(months), + i32::from_be_bytes(days), + i64::from_be_bytes(ns), + ) + } +} + +impl std::fmt::Display for days_ms { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}d {}ms", self.days(), self.milliseconds()) + } +} + +impl std::fmt::Display for months_days_ns { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}m {}d {}ns", self.months(), self.days(), self.ns()) + } +} + +impl Neg for days_ms { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.days(), -self.milliseconds()) + } +} + +impl Neg for months_days_ns { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.months(), -self.days(), -self.ns()) + } +} + +/// Type representation of the Float16 physical type +#[derive(Copy, Clone, Default, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct f16(pub u16); + +impl PartialEq for f16 { + #[inline] + fn eq(&self, other: &f16) -> bool { + if self.is_nan() || other.is_nan() { + false + } else { + (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) + } + } +} + +// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs +impl f16 { + /// 
The difference between 1.0 and the next largest representable number. + pub const EPSILON: f16 = f16(0x1400u16); + + #[inline] + #[must_use] + pub(crate) const fn is_nan(self) -> bool { + self.0 & 0x7FFFu16 > 0x7C00u16 + } + + /// Casts from u16. + #[inline] + pub const fn from_bits(bits: u16) -> f16 { + f16(bits) + } + + /// Casts to u16. + #[inline] + pub const fn to_bits(self) -> u16 { + self.0 + } + + /// Casts this `f16` to `f32` + pub fn to_f32(self) -> f32 { + let i = self.0; + // Check for signed zero + if i & 0x7FFFu16 == 0 { + return f32::from_bits((i as u32) << 16); + } + + let half_sign = (i & 0x8000u16) as u32; + let half_exp = (i & 0x7C00u16) as u32; + let half_man = (i & 0x03FFu16) as u32; + + // Check for an infinity or NaN when all exponent bits set + if half_exp == 0x7C00u32 { + // Check for signed infinity if mantissa is zero + if half_man == 0 { + let number = (half_sign << 16) | 0x7F80_0000u32; + return f32::from_bits(number); + } else { + // NaN, keep current mantissa but also set most significiant mantissa bit + let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13); + return f32::from_bits(number); + } + } + + // Calculate single-precision components with adjusted exponent + let sign = half_sign << 16; + // Unbias exponent + let unbiased_exp = ((half_exp as i32) >> 10) - 15; + + // Check for subnormals, which will be normalized by adjusting exponent + if half_exp == 0 { + // Calculate how much to adjust the exponent by + let e = (half_man as u16).leading_zeros() - 6; + + // Rebias and adjust exponent + let exp = (127 - 15 - e) << 23; + let man = (half_man << (14 + e)) & 0x7F_FF_FFu32; + return f32::from_bits(sign | exp | man); + } + + // Rebias exponent for a normalized normal + let exp = ((unbiased_exp + 127) as u32) << 23; + let man = (half_man & 0x03FFu32) << 13; + f32::from_bits(sign | exp | man) + } + + /// Casts an `f32` into `f16` + pub fn from_f32(value: f32) -> Self { + let x: u32 = value.to_bits(); + + // Extract IEEE754 components + let sign = x & 0x8000_0000u32; + let exp = x & 0x7F80_0000u32; + let man = x & 0x007F_FFFFu32; + + // Check for all exponent bits being set, which is Infinity or NaN + if exp == 0x7F80_0000u32 { + // Set mantissa MSB for NaN (and also keep shifted mantissa bits) + let nan_bit = if man == 0 { 0 } else { 0x0200u32 }; + return f16(((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16); + } + + // The number is normalized, start assembling half precision version + let half_sign = sign >> 16; + // Unbias the exponent, then bias for half precision + let unbiased_exp = ((exp >> 23) as i32) - 127; + let half_exp = unbiased_exp + 15; + + // Check for exponent overflow, return +infinity + if half_exp >= 0x1F { + return f16((half_sign | 0x7C00u32) as u16); + } + + // Check for underflow + if half_exp <= 0 { + // Check mantissa for what we can do + if 14 - half_exp > 24 { + // No rounding possibility, so this is a full underflow, return signed zero + return f16(half_sign as u16); + } + // Don't forget about hidden leading mantissa bit when assembling mantissa + let man = man | 0x0080_0000u32; + let mut half_man = man >> (14 - half_exp); + // Check for rounding (see comment above functions) + let round_bit = 1 << (13 - half_exp); + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + half_man += 1; + } + // No exponent for subnormals + return f16((half_sign | half_man) as u16); + } + + // Rebias the exponent + let half_exp = (half_exp as u32) << 10; + let half_man = man >> 13; + // Check for rounding (see 
comment above functions) + let round_bit = 0x0000_1000u32; + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + // Round it + f16(((half_sign | half_exp | half_man) + 1) as u16) + } else { + f16((half_sign | half_exp | half_man) as u16) + } + } +} + +impl std::fmt::Debug for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.to_f32()) + } +} + +impl std::fmt::Display for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl NativeType for f16 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + type Bytes = [u8; 2]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + self.0.to_le_bytes() + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + self.0.to_be_bytes() + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_be_bytes(bytes)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_le_bytes(bytes)) + } +} + +/// Physical representation of a decimal +#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct i256(pub ethnum::I256); + +impl i256 { + /// Returns a new [`i256`] from two `i128`. + pub fn from_words(hi: i128, lo: i128) -> Self { + Self(ethnum::I256::from_words(hi, lo)) + } +} + +impl Neg for i256 { + type Output = Self; + + #[inline] + fn neg(self) -> Self::Output { + let (a, b) = self.0.into_words(); + Self(ethnum::I256::from_words(-a, b)) + } +} + +impl std::fmt::Debug for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl std::fmt::Display for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +unsafe impl Pod for i256 {} +unsafe impl Zeroable for i256 {} + +impl NativeType for i256 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Int256; + + type Bytes = [u8; 32]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + let a = a.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + + let a = a.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let (a, b) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_be_bytes(a); + let b = i128::from_be_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let (b, a) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_le_bytes(a); + let b = i128::from_le_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_f16_to_f32() { + let f = f16::from_f32(7.0); + assert_eq!(f.to_f32(), 7.0f32); + + // 7.1 is NOT exactly representable in 16-bit, it's rounded + let f = f16::from_f32(7.1); + let diff = (f.to_f32() - 7.1f32).abs(); + // diff must be <= 4 * EPSILON, as 7 has two more significant 
bits than 1 + assert!(diff <= 4.0 * f16::EPSILON.to_f32()); + + assert_eq!(f16(0x0000_0001).to_f32(), 2.0f32.powi(-24)); + assert_eq!(f16(0x0000_0005).to_f32(), 5.0 * 2.0f32.powi(-24)); + + assert_eq!(f16(0x0000_0001), f16::from_f32(2.0f32.powi(-24))); + assert_eq!(f16(0x0000_0005), f16::from_f32(5.0 * 2.0f32.powi(-24))); + + assert_eq!(format!("{}", f16::from_f32(7.0)), "7".to_string()); + assert_eq!(format!("{:?}", f16::from_f32(7.0)), "7.0".to_string()); + } +} diff --git a/crates/nano-arrow/src/types/offset.rs b/crates/nano-arrow/src/types/offset.rs new file mode 100644 index 000000000000..e68bb7ceb6bd --- /dev/null +++ b/crates/nano-arrow/src/types/offset.rs @@ -0,0 +1,16 @@ +use super::Index; + +/// Sealed trait describing the subset (`i32` and `i64`) of [`Index`] that can be used +/// as offsets of variable-length Arrow arrays. +pub trait Offset: super::private::Sealed + Index { + /// Whether it is `i32` (false) or `i64` (true). + const IS_LARGE: bool; +} + +impl Offset for i32 { + const IS_LARGE: bool = false; +} + +impl Offset for i64 { + const IS_LARGE: bool = true; +} diff --git a/crates/nano-arrow/src/types/simd/mod.rs b/crates/nano-arrow/src/types/simd/mod.rs new file mode 100644 index 000000000000..d906c9d25e95 --- /dev/null +++ b/crates/nano-arrow/src/types/simd/mod.rs @@ -0,0 +1,167 @@ +//! Contains traits and implementations of multi-data used in SIMD. +//! The actual representation is driven by the feature flag `"simd"`, which, if set, +//! uses [`std::simd`]. +use super::{days_ms, f16, i256, months_days_ns, BitChunk, BitChunkIter, NativeType}; + +/// Describes the ability to convert itself from a [`BitChunk`]. +pub trait FromMaskChunk { + /// Convert itself from a slice. + fn from_chunk(v: T) -> Self; +} + +/// A struct that lends itself well to be compiled leveraging SIMD +/// # Safety +/// The `NativeType` and the `NativeSimd` must have possible a matching alignment. +/// e.g. slicing `&[NativeType]` by `align_of()` must be properly aligned/safe. +pub unsafe trait NativeSimd: Sized + Default + Copy { + /// Number of lanes + const LANES: usize; + /// The [`NativeType`] of this struct. E.g. `f32` for a `NativeSimd = f32x16`. + type Native: NativeType; + /// The type holding bits for masks. + type Chunk: BitChunk; + /// Type used for masking. + type Mask: FromMaskChunk; + + /// Sets values to `default` based on `mask`. + fn select(self, mask: Self::Mask, default: Self) -> Self; + + /// Convert itself from a slice. + /// # Panics + /// * iff `v.len()` != `T::LANES` + fn from_chunk(v: &[Self::Native]) -> Self; + + /// creates a new Self from `v` by populating items from `v` up to its length. + /// Items from `v` at positions larger than the number of lanes are ignored; + /// remaining items are populated with `remaining`. + fn from_incomplete_chunk(v: &[Self::Native], remaining: Self::Native) -> Self; + + /// Returns a tuple of 3 items whose middle item is itself, and the remaining + /// are the head and tail of the un-aligned parts. + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]); +} + +/// Trait implemented by some [`NativeType`] that have a SIMD representation. +pub trait Simd: NativeType { + /// The SIMD type associated with this trait. + /// This type supports SIMD operations + type Simd: NativeSimd; +} + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +pub use packed::*; + +macro_rules! 
native_simd { + ($name:tt, $type:ty, $lanes:expr, $mask:ty) => { + /// Multi-Data correspondence of the native type + #[allow(non_camel_case_types)] + #[derive(Copy, Clone)] + pub struct $name(pub [$type; $lanes]); + + unsafe impl NativeSimd for $name { + const LANES: usize = $lanes; + type Native = $type; + type Chunk = $mask; + type Mask = $mask; + + #[inline] + fn select(self, mask: $mask, default: Self) -> Self { + let mut reduced = default; + let iter = BitChunkIter::new(mask, Self::LANES); + for (i, b) in (0..Self::LANES).zip(iter) { + reduced[i] = if b { self[i] } else { reduced[i] }; + } + reduced + } + + #[inline] + fn from_chunk(v: &[$type]) -> Self { + ($name)(v.try_into().unwrap()) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; $lanes]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + Self(a) + } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } + } + + impl std::ops::Index for $name { + type Output = $type; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } + } + + impl std::ops::IndexMut for $name { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } + } + + impl Default for $name { + #[inline] + fn default() -> Self { + ($name)([<$type>::default(); $lanes]) + } + } + }; +} + +pub(super) use native_simd; + +// Types do not have specific intrinsics and thus SIMD can't be specialized. +// Therefore, we can declare their MD representation as `[$t; 8]` irrespectively +// of how they are represented in the different channels. +native_simd!(f16x32, f16, 32, u32); +native_simd!(days_msx8, days_ms, 8, u8); +native_simd!(months_days_nsx8, months_days_ns, 8, u8); +native_simd!(i128x8, i128, 8, u8); +native_simd!(i256x8, i256, 8, u8); + +// In the native implementation, a mask is 1 bit wide, as per AVX512. +impl FromMaskChunk for T { + #[inline] + fn from_chunk(v: T) -> Self { + v + } +} + +macro_rules! 
native { + ($type:ty, $simd:ty) => { + impl Simd for $type { + type Simd = $simd; + } + }; +} + +native!(u8, u8x64); +native!(u16, u16x32); +native!(u32, u32x16); +native!(u64, u64x8); +native!(i8, i8x64); +native!(i16, i16x32); +native!(i32, i32x16); +native!(i64, i64x8); +native!(f16, f16x32); +native!(f32, f32x16); +native!(f64, f64x8); +native!(i128, i128x8); +native!(i256, i256x8); +native!(days_ms, days_msx8); +native!(months_days_ns, months_days_nsx8); diff --git a/crates/nano-arrow/src/types/simd/native.rs b/crates/nano-arrow/src/types/simd/native.rs new file mode 100644 index 000000000000..af31b8b26bc0 --- /dev/null +++ b/crates/nano-arrow/src/types/simd/native.rs @@ -0,0 +1,16 @@ +use std::convert::TryInto; + +use super::*; +use crate::types::BitChunkIter; + +native_simd!(u8x64, u8, 64, u64); +native_simd!(u16x32, u16, 32, u32); +native_simd!(u32x16, u32, 16, u16); +native_simd!(u64x8, u64, 8, u8); +native_simd!(i8x64, i8, 64, u64); +native_simd!(i16x32, i16, 32, u32); +native_simd!(i32x16, i32, 16, u16); +native_simd!(i64x8, i64, 8, u8); +native_simd!(f16x32, f16, 32, u32); +native_simd!(f32x16, f32, 16, u16); +native_simd!(f64x8, f64, 8, u8); diff --git a/crates/nano-arrow/src/types/simd/packed.rs b/crates/nano-arrow/src/types/simd/packed.rs new file mode 100644 index 000000000000..0d95b68882aa --- /dev/null +++ b/crates/nano-arrow/src/types/simd/packed.rs @@ -0,0 +1,197 @@ +pub use std::simd::{ + f32x16, f32x8, f64x8, i16x32, i16x8, i32x16, i32x8, i64x8, i8x64, i8x8, mask32x16 as m32x16, + mask64x8 as m64x8, mask8x64 as m8x64, u16x32, u16x8, u32x16, u32x8, u64x8, u8x64, u8x8, + SimdPartialEq, +}; + +/// Vector of 32 16-bit masks +#[allow(non_camel_case_types)] +pub type m16x32 = std::simd::Mask; + +use super::*; + +macro_rules! simd { + ($name:tt, $type:ty, $lanes:expr, $chunk:ty, $mask:tt) => { + unsafe impl NativeSimd for $name { + const LANES: usize = $lanes; + type Native = $type; + type Chunk = $chunk; + type Mask = $mask; + + #[inline] + fn select(self, mask: $mask, default: Self) -> Self { + mask.select(self, default) + } + + #[inline] + fn from_chunk(v: &[$type]) -> Self { + <$name>::from_slice(v) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; $lanes]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + <$name>::from_chunk(a.as_ref()) + } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } + } + }; +} + +simd!(u8x64, u8, 64, u64, m8x64); +simd!(u16x32, u16, 32, u32, m16x32); +simd!(u32x16, u32, 16, u16, m32x16); +simd!(u64x8, u64, 8, u8, m64x8); +simd!(i8x64, i8, 64, u64, m8x64); +simd!(i16x32, i16, 32, u32, m16x32); +simd!(i32x16, i32, 16, u16, m32x16); +simd!(i64x8, i64, 8, u8, m64x8); +simd!(f32x16, f32, 16, u16, m32x16); +simd!(f64x8, f64, 8, u8, m64x8); + +macro_rules! 
chunk_macro { + ($type:ty, $chunk:ty, $simd:ty, $mask:tt, $m:expr) => { + impl FromMaskChunk<$chunk> for $mask { + #[inline] + fn from_chunk(chunk: $chunk) -> Self { + ($m)(chunk) + } + } + }; +} + +chunk_macro!(u8, u64, u8x64, m8x64, from_chunk_u64); +chunk_macro!(u16, u32, u16x32, m16x32, from_chunk_u32); +chunk_macro!(u32, u16, u32x16, m32x16, from_chunk_u16); +chunk_macro!(u64, u8, u64x8, m64x8, from_chunk_u8); + +#[inline] +fn from_chunk_u8(chunk: u8) -> m64x8 { + let idx = u64x8::from_array([1, 2, 4, 8, 16, 32, 64, 128]); + let vecmask = u64x8::splat(chunk as u64); + + (idx & vecmask).simd_eq(idx) +} + +#[inline] +fn from_chunk_u16(chunk: u16) -> m32x16 { + let idx = u32x16::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, + ]); + let vecmask = u32x16::splat(chunk as u32); + + (idx & vecmask).simd_eq(idx) +} + +#[inline] +fn from_chunk_u32(chunk: u32) -> m16x32 { + let idx = u16x32::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 1, 2, 4, 8, + 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, + ]); + let left = u16x32::from_chunk(&[ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]); + let right = u16x32::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, + 1024, 2048, 4096, 8192, 16384, 32768, + ]); + + let a = chunk.to_ne_bytes(); + let a1 = u16::from_ne_bytes([a[2], a[3]]); + let a2 = u16::from_ne_bytes([a[0], a[1]]); + + let vecmask1 = u16x32::splat(a1); + let vecmask2 = u16x32::splat(a2); + + (idx & left & vecmask1).simd_eq(idx) | (idx & right & vecmask2).simd_eq(idx) +} + +#[inline] +fn from_chunk_u64(chunk: u64) -> m8x64 { + let idx = u8x64::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, + 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, + 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, + ]); + let idxs = [ + u8x64::from_chunk(&[ + 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, + 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, + 4, 8, 16, 32, 64, 128, + ]), + ]; + + let a = chunk.to_ne_bytes(); + + let mut result = m8x64::default(); + for i in 0..8 { + result |= (idxs[i] & u8x64::splat(a[i])).simd_eq(idx) + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic1() { + let a = 0b00000001000000010000000100000001u32; + let a = from_chunk_u32(a); + for i in 0..32 { + assert_eq!(a.test(i), i % 8 == 0) + } + } + + #[test] + fn test_basic2() { + let a = 0b0000000100000001000000010000000100000001000000010000000100000001u64; + let a = from_chunk_u64(a); + for i in 0..64 { + assert_eq!(a.test(i), i % 8 == 0) + } + } +} diff --git a/crates/nano-arrow/src/util/bench_util.rs b/crates/nano-arrow/src/util/bench_util.rs new file mode 100644 index 000000000000..59fb88b198fc --- /dev/null +++ b/crates/nano-arrow/src/util/bench_util.rs @@ -0,0 +1,99 @@ +//! Utilities for benchmarking + +use rand::distributions::{Alphanumeric, Distribution, Standard}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +use crate::array::*; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +/// Creates an random (but fixed-seeded) array of a given size and null density +pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray +where + T: NativeType, + Standard: Distribution, +{ + let mut rng = seedable_rng(); + + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.gen()) + } + }) + .collect::>() +} + +/// Creates a new [`PrimitiveArray`] from random values with a pre-set seed. +pub fn create_primitive_array_with_seed( + size: usize, + null_density: f32, + seed: u64, +) -> PrimitiveArray +where + T: NativeType, + Standard: Distribution, +{ + let mut rng = StdRng::seed_from_u64(seed); + + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.gen()) + } + }) + .collect::>() +} + +/// Creates an random (but fixed-seeded) array of a given size and null density +pub fn create_boolean_array(size: usize, null_density: f32, true_density: f32) -> BooleanArray +where + Standard: Distribution, +{ + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + let value = rng.gen::() < true_density; + Some(value) + } + }) + .collect() +} + +/// Creates an random (but fixed-seeded) [`Utf8Array`] of a given length, number of characters and null density. 
+pub fn create_string_array( + length: usize, + size: usize, + null_density: f32, + seed: u64, +) -> Utf8Array { + let mut rng = StdRng::seed_from_u64(seed); + + (0..length) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + let value = (&mut rng) + .sample_iter(&Alphanumeric) + .take(size) + .map(char::from) + .collect::(); + Some(value) + } + }) + .collect() +} diff --git a/crates/nano-arrow/src/util/lexical.rs b/crates/nano-arrow/src/util/lexical.rs new file mode 100644 index 000000000000..047986cbbedd --- /dev/null +++ b/crates/nano-arrow/src/util/lexical.rs @@ -0,0 +1,42 @@ +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_bytes(n: N) -> Vec { + let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); + lexical_to_bytes_mut(n, &mut buf); + buf +} + +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_bytes_mut(n: N, buf: &mut Vec) { + buf.clear(); + buf.reserve(N::FORMATTED_SIZE_DECIMAL); + unsafe { + // JUSTIFICATION + // Benefit + // Allows using the faster serializer lexical core and convert to string + // Soundness + // Length of buf is set as written length afterwards. lexical_core + // creates a valid string, so doesn't need to be checked. + let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); + + // Safety: + // Omits an unneeded bound check as we just ensured that we reserved `N::FORMATTED_SIZE_DECIMAL` + #[cfg(debug_assertions)] + { + let len = lexical_core::write(n, slice).len(); + buf.set_len(len); + } + #[cfg(not(debug_assertions))] + { + let len = lexical_core::write_unchecked(n, slice).len(); + buf.set_len(len); + } + } +} + +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_string(n: N) -> String { + unsafe { String::from_utf8_unchecked(lexical_to_bytes(n)) } +} diff --git a/crates/nano-arrow/src/util/mod.rs b/crates/nano-arrow/src/util/mod.rs new file mode 100644 index 000000000000..11e14360e35c --- /dev/null +++ b/crates/nano-arrow/src/util/mod.rs @@ -0,0 +1,26 @@ +//! Misc utilities used in different places in the crate. + +#[cfg(any( + feature = "compute", + feature = "io_csv_write", + feature = "io_csv_read", + feature = "io_json", + feature = "io_json_write", + feature = "compute_cast" +))] +mod lexical; +#[cfg(any( + feature = "compute", + feature = "io_csv_write", + feature = "io_csv_read", + feature = "io_json", + feature = "io_json_write", + feature = "compute_cast" +))] +pub use lexical::*; + +#[cfg(feature = "benchmarks")] +#[cfg_attr(docsrs, doc(cfg(feature = "benchmarks")))] +pub mod bench_util; + +pub mod total_ord; diff --git a/crates/nano-arrow/src/util/total_ord.rs b/crates/nano-arrow/src/util/total_ord.rs new file mode 100644 index 000000000000..f6fe19bfbb5a --- /dev/null +++ b/crates/nano-arrow/src/util/total_ord.rs @@ -0,0 +1,419 @@ +use std::cmp::Ordering; +use std::hash::{Hash, Hasher}; + +use bytemuck::TransparentWrapper; + +use crate::array::Array; + +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +pub fn canonical_f32(x: f32) -> f32 { + // -0.0 + 0.0 becomes 0.0. + let convert_zero = x + 0.0; + if convert_zero.is_nan() { + f32::from_bits(0x7fc00000) // Canonical quiet NaN. + } else { + x + } +} + +/// Converts an f64 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +pub fn canonical_f64(x: f64) -> f64 { + // -0.0 + 0.0 becomes 0.0. + let convert_zero = x + 0.0; + if convert_zero.is_nan() { + f64::from_bits(0x7ff8000000000000) // Canonical quiet NaN. 
+ } else { + x + } +} + +/// Alternative trait for Eq. By consistently using this we can still be +/// generic w.r.t Eq while getting a total ordering for floats. +pub trait TotalEq { + fn tot_eq(&self, other: &Self) -> bool; + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + !(self.tot_eq(other)) + } +} + +/// Alternative trait for Ord. By consistently using this we can still be +/// generic w.r.t Ord while getting a total ordering for floats. +pub trait TotalOrd: TotalEq { + fn tot_cmp(&self, other: &Self) -> Ordering; + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + self.tot_cmp(other) == Ordering::Less + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + self.tot_cmp(other) == Ordering::Greater + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + self.tot_cmp(other) != Ordering::Greater + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + self.tot_cmp(other) != Ordering::Less + } +} + +/// Alternative trait for Hash. By consistently using this we can still be +/// generic w.r.t Hash while being able to hash floats. +pub trait TotalHash { + fn tot_hash(&self, state: &mut H) + where + H: Hasher; + + fn tot_hash_slice(data: &[Self], state: &mut H) + where + H: Hasher, + Self: Sized, + { + for piece in data { + piece.tot_hash(state) + } + } +} + +#[repr(transparent)] +pub struct TotalOrdWrap(pub T); +unsafe impl TransparentWrapper for TotalOrdWrap {} + +impl PartialOrd for TotalOrdWrap { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + + #[inline(always)] + fn lt(&self, other: &Self) -> bool { + self.0.tot_lt(&other.0) + } + + #[inline(always)] + fn le(&self, other: &Self) -> bool { + self.0.tot_le(&other.0) + } + + #[inline(always)] + fn gt(&self, other: &Self) -> bool { + self.0.tot_gt(&other.0) + } + + #[inline(always)] + fn ge(&self, other: &Self) -> bool { + self.0.tot_ge(&other.0) + } +} + +impl Ord for TotalOrdWrap { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.tot_cmp(&other.0) + } +} + +impl PartialEq for TotalOrdWrap { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0.tot_eq(&other.0) + } + + #[inline(always)] + #[allow(clippy::partialeq_ne_impl)] + fn ne(&self, other: &Self) -> bool { + self.0.tot_ne(&other.0) + } +} + +impl Eq for TotalOrdWrap {} + +impl Hash for TotalOrdWrap { + fn hash(&self, state: &mut H) { + self.0.tot_hash(state); + } +} + +macro_rules! impl_trivial_eq { + ($T: ty) => { + impl TotalEq for $T { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + self != other + } + } + }; +} + +macro_rules! impl_trivial_eq_ord_hash { + ($T: ty) => { + impl_trivial_eq!($T); + + impl TotalOrd for $T { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + self < other + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + self > other + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + self <= other + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + self >= other + } + } + + impl TotalHash for $T { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } + } + }; +} + +// We can't do a blanket impl because Rust complains f32 might implement +// Ord / Eq someday. 
+impl_trivial_eq_ord_hash!(bool); +impl_trivial_eq_ord_hash!(u8); +impl_trivial_eq_ord_hash!(u16); +impl_trivial_eq_ord_hash!(u32); +impl_trivial_eq_ord_hash!(u64); +impl_trivial_eq_ord_hash!(u128); +impl_trivial_eq_ord_hash!(usize); +impl_trivial_eq_ord_hash!(i8); +impl_trivial_eq_ord_hash!(i16); +impl_trivial_eq_ord_hash!(i32); +impl_trivial_eq_ord_hash!(i64); +impl_trivial_eq_ord_hash!(i128); +impl_trivial_eq_ord_hash!(isize); +impl_trivial_eq_ord_hash!(char); +impl_trivial_eq_ord_hash!(&str); +impl_trivial_eq_ord_hash!(&[u8]); +impl_trivial_eq_ord_hash!(String); +impl_trivial_eq!(&dyn Array); +impl_trivial_eq!(Box); + +macro_rules! impl_eq_ord_float { + ($f:ty) => { + impl TotalEq for $f { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + if self.is_nan() { + other.is_nan() + } else { + self == other + } + } + } + + impl TotalOrd for $f { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + if self.tot_lt(other) { + Ordering::Less + } else if self.tot_gt(other) { + Ordering::Greater + } else { + Ordering::Equal + } + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + !self.tot_ge(other) + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + other.tot_lt(self) + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + other.tot_ge(self) + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + // We consider all NaNs equal, and NaN is the largest possible + // value. Thus if self is NaN we always return true. Otherwise + // self >= other is correct. If other is not NaN it is trivially + // correct, and if it is we note that nothing can be greater or + // equal to NaN except NaN itself, which we already handled earlier. + self.is_nan() | (self >= other) + } + } + }; +} + +impl_eq_ord_float!(f32); +impl_eq_ord_float!(f64); + +impl TotalHash for f32 { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f32(*self).to_bits().hash(state) + } +} + +impl TotalHash for f64 { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f64(*self).to_bits().hash(state) + } +} + +// Blanket implementations. 
+impl TotalEq for Option { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + match (self, other) { + (None, None) => true, + (Some(a), Some(b)) => a.tot_eq(b), + _ => false, + } + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + match (self, other) { + (None, None) => false, + (Some(a), Some(b)) => a.tot_ne(b), + _ => true, + } + } +} + +impl TotalOrd for Option { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(a), Some(b)) => a.tot_cmp(b), + } + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + match (self, other) { + (None, Some(_)) => true, + (Some(a), Some(b)) => a.tot_lt(b), + _ => false, + } + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + other.tot_lt(self) + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + match (self, other) { + (Some(_), None) => false, + (Some(a), Some(b)) => a.tot_lt(b), + _ => true, + } + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + other.tot_le(self) + } +} + +impl TotalHash for Option { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.is_some().tot_hash(state); + if let Some(slf) = self { + slf.tot_hash(state) + } + } +} + +impl TotalEq for &T { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + (*self).tot_eq(*other) + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + (*self).tot_ne(*other) + } +} + +impl TotalOrd for &T { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + (*self).tot_cmp(*other) + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + (*self).tot_lt(*other) + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + (*self).tot_gt(*other) + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + (*self).tot_le(*other) + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + (*self).tot_ge(*other) + } +} + +impl TotalHash for &T { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + (*self).tot_hash(state) + } +} diff --git a/crates/polars-algo/Cargo.toml b/crates/polars-algo/Cargo.toml index 569cccec807e..ac589731c96f 100644 --- a/crates/polars-algo/Cargo.toml +++ b/crates/polars-algo/Cargo.toml @@ -9,9 +9,9 @@ repository = { workspace = true } description = "Algorithms built upon Polars primitives" [dependencies] -polars-core = { version = "0.32.0", path = "../polars-core", features = ["dtype-categorical", "asof_join"], default-features = false } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = ["asof_join", "concat_str", "strings"] } -polars-ops = { version = "0.32.0", path = "../polars-ops", features = ["dtype-categorical", "asof_join"], default-features = false } +polars-core = { workspace = true, features = ["dtype-categorical", "asof_join"] } +polars-lazy = { workspace = true, features = ["asof_join", "concat_str", "strings"], default-features = true } +polars-ops = { workspace = true, features = ["dtype-categorical", "asof_join"] } [package.metadata.docs.rs] all-features = true diff --git a/crates/polars-algo/README.md b/crates/polars-algo/README.md index 7d9d546c74b3..62caa8983752 100644 --- a/crates/polars-algo/README.md +++ b/crates/polars-algo/README.md @@ -1,5 +1,5 @@ # polars-algo -`polars-algo` is a sub-crate of Polars that contains algorithms built upon Polars primitives. 
+`polars-algo` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, designed to house algorithms built on Polars primitives. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. If you're looking to use Polars, please refer to the main [Polars crate](https://crates.io/crates/polars) instead. diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 6b47d0cecf96..3ca0bfb8a77c 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -9,13 +9,13 @@ repository = { workspace = true } description = "Arrow interfaces for Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } +polars-error = { workspace = true } arrow = { workspace = true } atoi = { workspace = true, optional = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } -ethnum = { version = "1.3.2", optional = true } +ethnum = { workspace = true, optional = true } hashbrown = { workspace = true } multiversion = { workspace = true } num-traits = { workspace = true } @@ -34,6 +34,5 @@ compute = ["arrow/compute_cast"] temporal = ["arrow/compute_temporal"] bigidx = [] performant = [] -like = ["arrow/compute_like"] timezones = ["chrono-tz", "chrono"] simd = [] diff --git a/crates/polars-arrow/README.md b/crates/polars-arrow/README.md index a6d84224681b..e2676f74c1ae 100644 --- a/crates/polars-arrow/README.md +++ b/crates/polars-arrow/README.md @@ -1,5 +1,5 @@ # polars-arrow -`polars-arrow` is a sub-crate that provides Arrow interfaces for the Polars dataframe library. +`polars-arrow` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, offering Arrow interfaces. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-arrow/src/array/get.rs b/crates/polars-arrow/src/array/get.rs deleted file mode 100644 index 331cba3abdf8..000000000000 --- a/crates/polars-arrow/src/array/get.rs +++ /dev/null @@ -1,154 +0,0 @@ -use arrow::array::{ - Array, BinaryArray, BooleanArray, FixedSizeListArray, ListArray, PrimitiveArray, Utf8Array, -}; -use arrow::types::NativeType; - -pub trait ArrowGetItem { - type Item; - - fn get(&self, item: usize) -> Option; - - /// # Safety - /// Get item. 
It is the callers responsibility that the `item < self.len()` - unsafe fn get_unchecked(&self, item: usize) -> Option; -} - -impl ArrowGetItem for PrimitiveArray { - type Item = T; - - #[inline] - fn get(&self, item: usize) -> Option { - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} - -impl ArrowGetItem for BooleanArray { - type Item = bool; - - #[inline] - fn get(&self, item: usize) -> Option { - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} - -impl<'a> ArrowGetItem for &'a Utf8Array { - type Item = &'a str; - - #[inline] - fn get(&self, item: usize) -> Option { - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} - -impl<'a> ArrowGetItem for &'a BinaryArray { - type Item = &'a [u8]; - - #[inline] - fn get(&self, item: usize) -> Option { - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} - -impl ArrowGetItem for ListArray { - type Item = Box; - - #[inline] - fn get(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} - -impl ArrowGetItem for FixedSizeListArray { - type Item = Box; - - #[inline] - fn get(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if item >= self.len() { - None - } else { - unsafe { self.get_unchecked(item) } - } - } - - #[inline] - unsafe fn get_unchecked(&self, item: usize) -> Option { - debug_assert!(item < self.len()); - if self.is_null_unchecked(item) { - None - } else { - Some(self.value_unchecked(item)) - } - } -} diff --git a/crates/polars-arrow/src/array/list.rs b/crates/polars-arrow/src/array/list.rs index 8428d38e503b..26e89a28d8a4 100644 --- a/crates/polars-arrow/src/array/list.rs +++ b/crates/polars-arrow/src/array/list.rs @@ -32,7 +32,7 @@ impl<'a> AnonymousBuilder<'a> { } pub fn is_empty(&self) -> bool { - self.arrays.is_empty() + self.offsets.len() == 1 } pub fn offsets(&self) -> &[i64] { diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index 51f813440185..6529d73f4cc1 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -12,13 +12,11 @@ use crate::utils::CustomIterTools; pub mod default_arrays; #[cfg(feature = "dtype-array")] pub mod fixed_size_list; -mod get; pub mod list; pub mod null; pub mod slice; pub mod utf8; -pub use get::ArrowGetItem; pub use slice::*; pub 
trait ValueSize { diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index bd97af2d9bf0..c7b2d84fe8eb 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -46,3 +46,9 @@ impl MutableArray for MutableNullArray { // no-op } } + +impl MutableNullArray { + pub fn extend_nulls(&mut self, null_count: usize) { + self.len += null_count; + } +} diff --git a/crates/polars-arrow/src/bit_util.rs b/crates/polars-arrow/src/bit_util.rs index 13a3c72c3da8..51554533e4eb 100644 --- a/crates/polars-arrow/src/bit_util.rs +++ b/crates/polars-arrow/src/bit_util.rs @@ -2,7 +2,8 @@ /// /// Note that the bound checks are optimized away. /// - +use arrow::bitmap::utils::{BitChunkIterExact, BitChunks, BitChunksExact}; +use arrow::bitmap::Bitmap; const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; /// Returns the nearest number that is `>=` than `num` and is a multiple of 64 @@ -79,3 +80,160 @@ pub fn ceil(value: usize, divisor: usize) -> usize { quot } } + +fn first_set_bit_impl(mut mask_chunks: I) -> usize +where + I: BitChunkIterExact, +{ + let mut total = 0usize; + const SIZE: u32 = 64; + for chunk in &mut mask_chunks { + let pos = chunk.trailing_zeros(); + if pos != SIZE { + return total + pos as usize; + } else { + total += SIZE as usize + } + } + if let Some(pos) = mask_chunks.remainder_iter().position(|v| v) { + total += pos; + return total; + } + // all null, return the first + 0 +} + +pub fn first_set_bit(mask: &Bitmap) -> usize { + if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { + return 0; + } + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + first_set_bit_impl(mask_chunks) + } else { + let mask_chunks = mask.chunks::(); + first_set_bit_impl(mask_chunks) + } +} + +fn first_unset_bit_impl(mut mask_chunks: I) -> usize +where + I: BitChunkIterExact, +{ + let mut total = 0usize; + const SIZE: u32 = 64; + for chunk in &mut mask_chunks { + let pos = chunk.trailing_ones(); + if pos != SIZE { + return total + pos as usize; + } else { + total += SIZE as usize + } + } + if let Some(pos) = mask_chunks.remainder_iter().position(|v| !v) { + total += pos; + return total; + } + // all null, return the first + 0 +} + +pub fn first_unset_bit(mask: &Bitmap) -> usize { + if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { + return 0; + } + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + first_unset_bit_impl(mask_chunks) + } else { + let mask_chunks = mask.chunks::(); + first_unset_bit_impl(mask_chunks) + } +} + +pub fn find_first_true_false_null( + mut bit_chunks: BitChunks, + mut validity_chunks: BitChunks, +) -> (Option, Option, Option) { + let (mut true_index, mut false_index, mut null_index) = (None, None, None); + let (mut true_not_found_mask, mut false_not_found_mask, mut null_not_found_mask) = + (!0u64, !0u64, !0u64); // All ones while not found. 
+ let mut offset: usize = 0; + let mut all_found = false; + for (truth_mask, null_mask) in (&mut bit_chunks).zip(&mut validity_chunks) { + let mask = null_mask & truth_mask & true_not_found_mask; + if mask > 0 { + true_index = Some(offset + mask.trailing_zeros() as usize); + true_not_found_mask = 0; + } + let mask = null_mask & !truth_mask & false_not_found_mask; + if mask > 0 { + false_index = Some(offset + mask.trailing_zeros() as usize); + false_not_found_mask = 0; + } + if !null_mask & null_not_found_mask > 0 { + null_index = Some(offset + null_mask.trailing_ones() as usize); + null_not_found_mask = 0; + } + if null_not_found_mask | true_not_found_mask | false_not_found_mask == 0 { + all_found = true; + break; + } + offset += 64; + } + if !all_found { + for (val, not_null) in bit_chunks + .remainder_iter() + .zip(validity_chunks.remainder_iter()) + { + if true_index.is_none() && not_null && val { + true_index = Some(offset); + } else if false_index.is_none() && not_null && !val { + false_index = Some(offset); + } else if null_index.is_none() && !not_null { + null_index = Some(offset); + } + offset += 1; + } + } + (true_index, false_index, null_index) +} + +pub fn find_first_true_false_no_null( + mut bit_chunks: BitChunks, +) -> (Option, Option) { + let (mut true_index, mut false_index) = (None, None); + let (mut true_not_found_mask, mut false_not_found_mask) = (!0u64, !0u64); // All ones while not found. + let mut offset: usize = 0; + let mut all_found = false; + for truth_mask in &mut bit_chunks { + let mask = truth_mask & true_not_found_mask; + if mask > 0 { + true_index = Some(offset + mask.trailing_zeros() as usize); + true_not_found_mask = 0; + } + let mask = !truth_mask & false_not_found_mask; + if mask > 0 { + false_index = Some(offset + mask.trailing_zeros() as usize); + false_not_found_mask = 0; + } + if true_not_found_mask | false_not_found_mask == 0 { + all_found = true; + break; + } + offset += 64; + } + if !all_found { + for val in bit_chunks.remainder_iter() { + if true_index.is_none() && val { + true_index = Some(offset); + } else if false_index.is_none() && !val { + false_index = Some(offset); + } + offset += 1; + } + } + (true_index, false_index) +} diff --git a/crates/polars-arrow/src/compute/decimal.rs b/crates/polars-arrow/src/compute/decimal.rs index 04066f1d9629..4c17422889f8 100644 --- a/crates/polars-arrow/src/compute/decimal.rs +++ b/crates/polars-arrow/src/compute/decimal.rs @@ -28,12 +28,20 @@ pub fn infer_scale(bytes: &[u8]) -> Option { /// Deserializes bytes to a single i128 representing a decimal /// The decimal precision and scale are not checked. 
#[inline] -pub(super) fn deserialize_decimal(bytes: &[u8], precision: Option, scale: u8) -> Option { +pub(super) fn deserialize_decimal( + mut bytes: &[u8], + precision: Option, + scale: u8, +) -> Option { + let negative = bytes.first() == Some(&b'-'); + if negative { + bytes = &bytes[1..]; + }; let (lhs, rhs) = split_decimal_bytes(bytes); let precision = precision.unwrap_or(u8::MAX); let lhs_b = lhs?; - parse_integer_checked(lhs_b).and_then(|x| { + let abs = parse_integer_checked(lhs_b).and_then(|x| { match rhs { Some(rhs) => { parse_integer_checked(rhs) @@ -77,9 +85,7 @@ pub(super) fn deserialize_decimal(bytes: &[u8], precision: Option, scale: u8 Some((lhs, rhs)) } }) - .map(|(lhs, rhs)| { - lhs * 10i128.pow(scale as u32) + (if lhs < 0 { -rhs } else { rhs }) - }) + .map(|(lhs, rhs)| lhs * 10i128.pow(scale as u32) + rhs) }, None => { if lhs_b.len() > precision as usize || scale != 0 { @@ -88,7 +94,12 @@ pub(super) fn deserialize_decimal(bytes: &[u8], precision: Option, scale: u8 parse_integer_checked(lhs_b) }, } - }) + }); + if negative { + Some(-abs?) + } else { + abs + } } #[cfg(test)] @@ -117,6 +128,12 @@ mod test { Some(14390) ); + let val = "-0.5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(-50) + ); + let val = "-1.5"; assert_eq!( deserialize_decimal(val.as_bytes(), precision, scale), diff --git a/crates/polars-arrow/src/floats/ord.rs b/crates/polars-arrow/src/floats/ord.rs index 52d71bc3382e..c4f38807b34d 100644 --- a/crates/polars-arrow/src/floats/ord.rs +++ b/crates/polars-arrow/src/floats/ord.rs @@ -14,7 +14,7 @@ pub struct OrdFloat(T); impl PartialOrd for OrdFloat { fn partial_cmp(&self, other: &Self) -> Option { - Some(compare_fn_nan_max(&self.0, &other.0)) + Some(self.cmp(other)) } } diff --git a/crates/polars-arrow/src/kernels/string.rs b/crates/polars-arrow/src/kernels/string.rs index 5d4770b2b13e..e348ac1f9548 100644 --- a/crates/polars-arrow/src/kernels/string.rs +++ b/crates/polars-arrow/src/kernels/string.rs @@ -5,7 +5,7 @@ use arrow::datatypes::DataType; use crate::prelude::*; use crate::trusted_len::TrustedLenPush; -pub fn string_lengths(array: &Utf8Array) -> ArrayRef { +pub fn string_len_bytes(array: &Utf8Array) -> ArrayRef { let values = array .offsets() .as_slice() @@ -16,7 +16,7 @@ pub fn string_lengths(array: &Utf8Array) -> ArrayRef { Box::new(array) } -pub fn string_nchars(array: &Utf8Array) -> ArrayRef { +pub fn string_len_chars(array: &Utf8Array) -> ArrayRef { let values = array.values_iter().map(|x| x.chars().count() as u32); let values: Buffer<_> = Vec::from_trusted_len_iter(values).into(); let array = UInt32Array::new(DataType::UInt32, values, array.validity().cloned()); diff --git a/crates/polars-arrow/src/kernels/take_agg/var.rs b/crates/polars-arrow/src/kernels/take_agg/var.rs index 67b86a81bc93..b5000f0eca4d 100644 --- a/crates/polars-arrow/src/kernels/take_agg/var.rs +++ b/crates/polars-arrow/src/kernels/take_agg/var.rs @@ -31,11 +31,12 @@ where mean = new_mean; m2 = new_m2; } - match count { - 0 => None, - 1 => Some(0.0), - _ => Some(m2 / (count as f64 - ddof as f64)), + + if count <= ddof as u64 { + return None; } + + Some(m2 / (count as f64 - ddof as f64)) } /// Take kernel for single chunk and an iterator as index. 
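An illustrative aside on the variance change above (not part of the patch): the new guard returns `None` whenever the observation count does not exceed `ddof`, where the old match special-cased a single observation as `Some(0.0)`. A minimal sketch of that finalization step, using a hypothetical `finalize_var` helper name:

// Sketch only; `finalize_var` is a hypothetical name, not part of this diff.
fn finalize_var(m2: f64, count: u64, ddof: u8) -> Option<f64> {
    // Mirrors the new guard: too few observations for the requested ddof -> None.
    if count <= ddof as u64 {
        return None;
    }
    Some(m2 / (count as f64 - ddof as f64))
}

// With ddof = 1, a single observation now yields None (previously Some(0.0));
// three observations with m2 = 2.0 yield Some(1.0).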
diff --git a/crates/polars-arrow/src/lib.rs b/crates/polars-arrow/src/lib.rs index 6baca9f93817..0d04de03fc23 100644 --- a/crates/polars-arrow/src/lib.rs +++ b/crates/polars-arrow/src/lib.rs @@ -1,8 +1,5 @@ #![cfg_attr(feature = "simd", feature(portable_simd))] -#![cfg_attr( - feature = "nightly", - allow(clippy::incorrect_partial_ord_impl_on_ord_type) -)] // Remove once stable. +#![cfg_attr(feature = "nightly", allow(clippy::non_canonical_partial_ord_impl))] // Remove once stable. pub mod array; pub mod bit_util; pub mod bitmap; diff --git a/crates/polars-arrow/src/trusted_len/boolean.rs b/crates/polars-arrow/src/trusted_len/boolean.rs index c60eb0949b95..45fb335c45af 100644 --- a/crates/polars-arrow/src/trusted_len/boolean.rs +++ b/crates/polars-arrow/src/trusted_len/boolean.rs @@ -1,10 +1,11 @@ use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; +use arrow::datatypes::DataType; use crate::array::default_arrays::FromData; -use crate::trusted_len::TrustedLen; +use crate::bit_util::{set_bit_raw, unset_bit_raw}; +use crate::trusted_len::{FromIteratorReversed, TrustedLen}; use crate::utils::FromTrustedLenIterator; - impl FromTrustedLenIterator> for BooleanArray { fn from_iter_trusted_length>>(iter: I) -> Self where @@ -30,3 +31,54 @@ impl FromTrustedLenIterator for BooleanArray { } } } + +impl FromIteratorReversed for BooleanArray { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let vals = MutableBitmap::from_len_zeroed(size); + let vals_ptr = vals.as_slice().as_ptr() as *mut u8; + unsafe { + let mut offset = size; + iter.for_each(|item| { + offset -= 1; + if item { + set_bit_raw(vals_ptr, offset); + } + }); + } + BooleanArray::new(DataType::Boolean, vals.into(), None) + } +} + +impl FromIteratorReversed> for BooleanArray { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let vals = MutableBitmap::from_len_zeroed(size); + let mut validity = MutableBitmap::with_capacity(size); + validity.extend_constant(size, true); + let validity_ptr = validity.as_slice().as_ptr() as *mut u8; + let vals_ptr = vals.as_slice().as_ptr() as *mut u8; + unsafe { + let mut offset = size; + + iter.for_each(|opt_item| { + offset -= 1; + match opt_item { + Some(item) => { + if item { + // Set value (validity bit is already true). + set_bit_raw(vals_ptr, offset); + } + }, + None => { + // Unset validity bit. + unset_bit_raw(validity_ptr, offset) + }, + } + }); + } + BooleanArray::new(DataType::Boolean, vals.into(), Some(validity.into())) + } +} diff --git a/crates/polars-arrow/src/trusted_len/push_unchecked.rs b/crates/polars-arrow/src/trusted_len/push_unchecked.rs index 5d268d070777..f3d830f76fa1 100644 --- a/crates/polars-arrow/src/trusted_len/push_unchecked.rs +++ b/crates/polars-arrow/src/trusted_len/push_unchecked.rs @@ -1,13 +1,13 @@ use super::*; pub trait TrustedLenPush { - /// Will push an item and not check if there is enough capacity + /// Will push an item and not check if there is enough capacity. /// /// # Safety /// Caller must ensure the array has enough capacity to hold `T`. unsafe fn push_unchecked(&mut self, value: T); - /// Extend the array with an iterator who's length can be trusted + /// Extend the array with an iterator who's length can be trusted. 
fn extend_trusted_len, J: TrustedLen>( &mut self, iter: I, @@ -16,9 +16,16 @@ pub trait TrustedLenPush { } /// # Safety - /// Caller must ensure the iterators reported length is correct + /// Caller must ensure the iterators reported length is correct. unsafe fn extend_trusted_len_unchecked>(&mut self, iter: I); + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E>; + fn from_trusted_len_iter, J: TrustedLen>( iter: I, ) -> Self @@ -28,8 +35,28 @@ pub trait TrustedLenPush { unsafe { Self::from_trusted_len_iter_unchecked(iter) } } /// # Safety - /// Caller must ensure the iterators reported length is correct + /// Caller must ensure the iterators reported length is correct. unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self; + + fn try_from_trusted_len_iter< + E, + I: IntoIterator, IntoIter = J>, + J: TrustedLen, + >( + iter: I, + ) -> Result + where + Self: Sized, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized; } impl TrustedLenPush for Vec { @@ -55,10 +82,38 @@ impl TrustedLenPush for Vec { self.set_len(self.len() + upper) } + unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E> { + let iter = iter.into_iter(); + let upper = iter.size_hint().1.expect("must have an upper bound"); + self.reserve(upper); + + let mut dst = self.as_mut_ptr().add(self.len()); + for value in iter { + std::ptr::write(dst, value?); + dst = dst.add(1) + } + self.set_len(self.len() + upper); + Ok(()) + } + #[inline] unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self { let mut v = vec![]; v.extend_trusted_len_unchecked(iter); v } + + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized, + { + let mut v = vec![]; + v.try_extend_trusted_len_unchecked(iter)?; + Ok(v) + } } diff --git a/crates/polars-arrow/src/utils.rs b/crates/polars-arrow/src/utils.rs index b5201c9847eb..b498046674e8 100644 --- a/crates/polars-arrow/src/utils.rs +++ b/crates/polars-arrow/src/utils.rs @@ -1,9 +1,11 @@ use std::ops::{BitAnd, BitOr}; use arrow::array::PrimitiveArray; -use arrow::bitmap::Bitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::DataType; use arrow::types::NativeType; +use crate::bit_util::unset_bit_raw; use crate::trusted_len::{FromIteratorReversed, TrustedLen, TrustedLenPush}; #[derive(Clone)] @@ -67,15 +69,6 @@ pub fn combine_validities_or(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> unsafe impl arrow::trusted_len::TrustedLen for TrustMyLength where I: Iterator {} pub trait CustomIterTools: Iterator { - fn fold_first_(mut self, f: F) -> Option - where - Self: Sized, - F: FnMut(Self::Item, Self::Item) -> Self::Item, - { - let first = self.next()?; - Some(self.fold(first, f)) - } - /// Turn any iterator in a trusted length iterator /// /// # Safety @@ -170,6 +163,61 @@ impl FromTrustedLenIterator for PrimitiveArray { } } +impl FromIteratorReversed for PrimitiveArray { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let mut vals: Vec = Vec::with_capacity(size); + unsafe { + // Set to end of buffer. 
+ let mut ptr = vals.as_mut_ptr().add(size); + + iter.for_each(|item| { + ptr = ptr.sub(1); + std::ptr::write(ptr, item); + }); + vals.set_len(size) + } + PrimitiveArray::new(DataType::from(T::PRIMITIVE), vals.into(), None) + } +} + +impl FromIteratorReversed> for PrimitiveArray { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let mut vals: Vec = Vec::with_capacity(size); + let mut validity = MutableBitmap::with_capacity(size); + validity.extend_constant(size, true); + let validity_ptr = validity.as_slice().as_ptr() as *mut u8; + unsafe { + // Set to end of buffer. + let mut ptr = vals.as_mut_ptr().add(size); + let mut offset = size; + + iter.for_each(|opt_item| { + offset -= 1; + ptr = ptr.sub(1); + match opt_item { + Some(item) => { + std::ptr::write(ptr, item); + }, + None => { + std::ptr::write(ptr, T::default()); + unset_bit_raw(validity_ptr, offset) + }, + } + }); + vals.set_len(size) + } + PrimitiveArray::new( + DataType::from(T::PRIMITIVE), + vals.into(), + Some(validity.into()), + ) + } +} + macro_rules! with_match_primitive_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* ) => ({ diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index adb530002559..5579749de621 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -9,14 +9,16 @@ repository = { workspace = true } description = "Core of the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["compute"] } -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-row = { version = "0.32.0", path = "../polars-row" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true, features = ["compute"] } +polars-error = { workspace = true } +polars-row = { workspace = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } +arrow-array = { workspace = true, optional = true } bitflags = { workspace = true } +bytemuck = { workspace = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } comfy-table = { version = "7.0.1", default_features = false, optional = true } @@ -26,10 +28,9 @@ indexmap = { workspace = true } itoap = { version = "1", optional = true, features = ["simd"] } ndarray = { version = "0.15", optional = true, default_features = false } num-traits = { workspace = true } -object_store = { workspace = true, optional = true } once_cell = { workspace = true } rand = { workspace = true, optional = true, features = ["small_rng", "std"] } -rand_distr = { version = "0.4", optional = true } +rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true, optional = true } # activate if you want serde support for Series and DataFrames @@ -37,7 +38,6 @@ serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } smartstring = { workspace = true } thiserror = { workspace = true } -url = { workspace = true, optional = true } xxhash-rust = { workspace = true } [dev-dependencies] @@ -54,7 +54,8 @@ avx512 = [] docs = [] temporal = ["regex", "chrono", "polars-error/regex"] random = ["rand", "rand_distr"] -default = ["docs", "temporal"] +algorithm_group_by = [] +default = ["algorithm_group_by"] lazy = [] # ~40% faster collect, needed until trustedlength iter stabilizes @@ -62,7 +63,7 @@ lazy = [] performant = 
["polars-arrow/performant", "reinterpret"] # extra utilities for Utf8Chunked -strings = ["regex", "polars-arrow/strings", "arrow/compute_substring", "polars-error/regex"] +strings = ["regex", "polars-arrow/strings", "polars-error/regex"] # support for ObjectChunked (downcastable Series of any type) object = ["serde_json"] @@ -80,24 +81,19 @@ rows = [] zip_with = [] round_series = [] checked_arithmetic = [] -repeat_by = [] -is_first = [] -is_last = [] +is_first_distinct = [] +is_last_distinct = [] asof_join = [] -cross_join = [] dot_product = [] -concat_str = [] row_hash = [] reinterpret = [] take_opt_iter = [] -mode = [] # allow group_by operation on list type group_by_list = [] # cumsum, cummin, etc. cum_agg = [] # rolling window functions rolling_window = [] -rank = [] diff = [] pct_change = ["diff"] moment = [] @@ -109,11 +105,11 @@ dataframe_arithmetic = [] product = [] unique_counts = [] partition_by = [] -semi_anti_join = [] chunked_ids = [] describe = [] timezones = ["chrono-tz", "arrow/chrono-tz", "polars-arrow/timezones"] dynamic_group_by = ["dtype-datetime", "dtype-date"] +arrow_rs = ["arrow-array", "arrow/arrow_rs"] # opt-in datatypes for Series dtype-date = ["temporal"] @@ -132,7 +128,7 @@ dtype-struct = [] parquet = ["arrow/io_parquet"] # scale to terabytes? -bigidx = ["polars-arrow/bigidx"] +bigidx = ["polars-arrow/bigidx", "polars-utils/bigidx"] python = [] serde = ["dep:serde", "smartstring/serde", "bitflags/serde"] @@ -150,22 +146,17 @@ docs-selection = [ "zip_with", "round_series", "checked_arithmetic", - "repeat_by", - "is_first", - "is_last", + "is_first_distinct", + "is_last_distinct", "asof_join", - "cross_join", "dot_product", - "concat_str", "row_hash", - "mode", "cum_agg", "rolling_window", "diff", "moment", "dtype-categorical", "dtype-decimal", - "rank", "diagonal_concat", "horizontal_concat", "abs", @@ -174,16 +165,10 @@ docs-selection = [ "unique_counts", "describe", "chunked_ids", - "semi_anti_join", "partition_by", + "algorithm_group_by", ] -# Cloud support. -"async" = ["url"] -"aws" = ["async", "object_store/aws"] -"azure" = ["async", "object_store/azure"] -"gcp" = ["async", "object_store/gcp"] - [package.metadata.docs.rs] # not all because arrow 4.3 does not compile with simd # all-features = true diff --git a/crates/polars-core/README.md b/crates/polars-core/README.md index cf77fc95019d..684c5f33832c 100644 --- a/crates/polars-core/README.md +++ b/crates/polars-core/README.md @@ -1,5 +1,5 @@ # polars-core -`polars-core` is a sub-crate that provides core functionality for the Polars dataframe library. +`polars-core` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, providing its core functionalities. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. 
diff --git a/crates/polars-core/src/chunked_array/arithmetic/numeric.rs b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs index 6efa9a3ffa13..940563f50c23 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/numeric.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -286,7 +286,7 @@ where let mut out = self .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); - if rhs < T::Native::zero() { + if rhs.tot_lt(&T::Native::zero()) { out.set_sorted_flag(self.is_sorted_flag().reverse()); } else { out.set_sorted_flag(self.is_sorted_flag()); diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index c28ec0b68b1a..4c2d637c835f 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -5,13 +5,13 @@ use crate::chunked_array::list::iterator::AmortizedListIter; use crate::series::unstable::{ArrayBox, UnstableSeries}; impl ArrayChunked { - /// This is an iterator over a ListChunked that save allocations. + /// This is an iterator over a [`ListChunked`] that save allocations. /// A Series is: /// 1. [`Arc`] /// ChunkedArray is: /// 2. Vec< 3. ArrayRef> /// - /// The ArrayRef we indicated with 3. will be updated during iteration. + /// The [`ArrayRef`] we indicated with 3. will be updated during iteration. /// The Series will be pinned in memory, saving an allocation for /// 1. Arc<..> /// 2. Vec<...> @@ -47,7 +47,7 @@ impl ArrayChunked { // Safety: // inner type passed as physical type let series_container = unsafe { - Box::new(Series::from_chunks_and_dtype_unchecked( + Box::pin(Series::from_chunks_and_dtype_unchecked( name, vec![inner_values.clone()], &iter_dtype, diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs index f2dba184f6ba..0b4e6d5c99ca 100644 --- a/crates/polars-core/src/chunked_array/array/mod.rs +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -32,7 +32,7 @@ impl ArrayChunked { let inner_dtype = self.inner_dtype().to_arrow(); let arr = ca.downcast_iter().next().unwrap(); unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( self.name(), vec![(arr.values()).clone()], &inner_dtype, @@ -41,7 +41,7 @@ impl ArrayChunked { } } - /// Ignore the list indices and apply `func` to the inner type as `Series`. + /// Ignore the list indices and apply `func` to the inner type as [`Series`]. 
pub fn apply_to_inner( &self, func: &dyn Fn(Series) -> PolarsResult, @@ -52,7 +52,7 @@ impl ArrayChunked { let chunks = ca.downcast_iter().map(|arr| { let elements = unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( self.name(), vec![(*arr.values()).clone()], &inner_dtype, diff --git a/crates/polars-core/src/chunked_array/builder/list/null.rs b/crates/polars-core/src/chunked_array/builder/list/null.rs index 70346ed32071..cca037cda105 100644 --- a/crates/polars-core/src/chunked_array/builder/list/null.rs +++ b/crates/polars-core/src/chunked_array/builder/list/null.rs @@ -12,12 +12,18 @@ impl ListNullChunkedBuilder { name: name.into(), } } + + pub(crate) fn append(&mut self, s: &Series) { + let value_builder = self.builder.mut_values(); + value_builder.extend_nulls(s.len()); + self.builder.try_push_valid().unwrap(); + } } impl ListBuilderTrait for ListNullChunkedBuilder { #[inline] - fn append_series(&mut self, _s: &Series) -> PolarsResult<()> { - self.builder.push_null(); + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + self.append(s); Ok(()) } diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index c00521f125b8..1db996fe618f 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -241,8 +241,8 @@ mod test { // Test list collect. let out = [&s1, &s2].iter().copied().collect::(); - assert_eq!(out.get(0).unwrap().len(), 6); - assert_eq!(out.get(1).unwrap().len(), 3); + assert_eq!(out.get_as_series(0).unwrap().len(), 6); + assert_eq!(out.get_as_series(1).unwrap().len(), 3); let mut builder = ListPrimitiveChunkedBuilder::::new("a", 10, 5, DataType::Int32); diff --git a/crates/polars-core/src/chunked_array/collect.rs b/crates/polars-core/src/chunked_array/collect.rs new file mode 100644 index 000000000000..739cc0c6f5c8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/collect.rs @@ -0,0 +1,171 @@ +//! Methods for collecting into a ChunkedArray. +//! +//! For types that don't have dtype parameters: +//! iter.(try_)collect_ca(_trusted) (name) +//! +//! For all types: +//! iter.(try_)collect_ca(_trusted)_like (other_df) Copies name/dtype from other_df +//! iter.(try_)collect_ca(_trusted)_with_dtype (name, df) +//! +//! The try variants work on iterators of Results, the trusted variants do not +//! check the length of the iterator. 
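Editorial sketch (not part of the new `collect.rs` file): how the naming scheme in the module docs above is intended to read at a call site, assuming the extension traits are in scope inside the crate. `UInt32Type` and `Int64Type` carry no dtype parameter, so the bare `collect_ca`/`try_collect_ca` variants apply; dtype-parameterized targets would go through `collect_ca_with_dtype`.

// Sketch under the assumptions stated above; not part of this diff.
fn squares(n: u32) -> ChunkedArray<UInt32Type> {
    // Infallible, length not trusted: plain `collect_ca`.
    (0..n).map(|x| x * x).collect_ca("squares")
}

fn parsed(strings: &[&str]) -> Result<ChunkedArray<Int64Type>, std::num::ParseIntError> {
    // Fallible variant described in the docs above: `try_collect_ca`.
    strings.iter().map(|s| s.parse::<i64>()).try_collect_ca("parsed")
}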
+ +use std::sync::Arc; + +use polars_arrow::trusted_len::TrustedLen; + +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{ + ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, DataType, Field, PolarsDataType, +}; + +pub trait ChunkedCollectIterExt: Iterator + Sized { + #[inline] + fn collect_ca_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_with_dtype(dtype); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_with_dtype(field.dtype.clone()); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_trusted_with_dtype(dtype); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.clone()); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca_with_dtype( + self, + name: &str, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.try_collect_arr_with_dtype(dtype)?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.try_collect_arr_with_dtype(field.dtype.clone())?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_with_dtype( + self, + name: &str, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.try_collect_arr_trusted_with_dtype(dtype)?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.try_collect_arr_trusted_with_dtype(field.dtype.clone())?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectIterExt for I {} + +pub trait ChunkedCollectInferIterExt: Iterator + Sized { + #[inline] + fn collect_ca(self, name: &str) -> ChunkedArray + where + T::Array: ArrayFromIter, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.collect_arr(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted(self, name: &str) -> ChunkedArray + where + T::Array: ArrayFromIter, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = 
self.collect_arr_trusted(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca(self, name: &str) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted(self, name: &str) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr_trusted()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectInferIterExt for I {} diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 31540fbba65a..a0ee4d753b36 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -9,7 +9,6 @@ use arrow::compute::comparison; use arrow::scalar::{BinaryScalar, PrimitiveScalar, Scalar, Utf8Scalar}; use either::Either; use num_traits::{NumCast, ToPrimitive}; -use polars_arrow::kernels::rolling::compare_fn_nan_max; use polars_arrow::prelude::FromData; use crate::prelude::*; @@ -680,48 +679,84 @@ impl ChunkCompare<&str> for Utf8Chunked { } } +#[doc(hidden)] +fn _list_comparison_helper(lhs: &ListChunked, rhs: &ListChunked, op: F) -> BooleanChunked +where + F: Fn(Option<&Series>, Option<&Series>) -> Option, +{ + match (lhs.len(), rhs.len()) { + (_, 1) => { + let right = rhs.get_as_series(0).map(|s| s.with_name("")); + // SAFETY: values within iterator do not outlive the iterator itself + unsafe { + lhs.amortized_iter() + .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) + .collect_trusted() + } + }, + (1, _) => { + let left = lhs.get_as_series(0).map(|s| s.with_name("")); + // SAFETY: values within iterator do not outlive the iterator itself + unsafe { + rhs.amortized_iter() + .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) + .collect_trusted() + } + }, + // SAFETY: values within iterator do not outlive the iterator itself + _ => unsafe { + lhs.amortized_iter() + .zip(rhs.amortized_iter()) + .map(|(left, right)| { + op( + left.as_ref().map(|us| us.as_ref()), + right.as_ref().map(|us| us.as_ref()), + ) + }) + .collect_trusted() + }, + } +} + impl ChunkCompare<&ListChunked> for ListChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ListChunked) -> BooleanChunked { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => Some(l.as_ref().series_equal_missing(r.as_ref())), - _ => None, - }) - .collect_trusted() + let _series_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(l.series_equal(r)), + _ => None, + }; + + _list_comparison_helper(self, rhs, _series_equal) } fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => l.as_ref().series_equal_missing(r.as_ref()), - (None, None) => true, - _ => false, - }) - .collect_trusted() + let _series_equal_missing = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(l.series_equal_missing(r)), + (None, None) => Some(true), + _ => Some(false), + }; + + _list_comparison_helper(self, rhs, 
_series_equal_missing) } fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => Some(!l.as_ref().series_equal_missing(r.as_ref())), - _ => None, - }) - .collect_trusted() + let _series_not_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(!l.series_equal(r)), + _ => None, + }; + + _list_comparison_helper(self, rhs, _series_not_equal) } fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => !l.as_ref().series_equal_missing(r.as_ref()), - (None, None) => false, - _ => true, - }) - .collect_trusted() + let _series_not_equal_missing = + |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(!l.series_equal_missing(r)), + (None, None) => Some(false), + _ => Some(true), + }; + + _list_comparison_helper(self, rhs, _series_not_equal_missing) } // The following are not implemented because gt, lt comparison of series don't make sense. @@ -1216,6 +1251,17 @@ mod test { assert_eq!(Vec::from(&c), &[Some(true), Some(false), None]) } + #[test] + fn list_broadcasting_lists() { + let s_el = Series::new("", &[1, 2, 3]); + let s_lhs = Series::new("", &[s_el.clone(), s_el.clone()]); + let s_rhs = Series::new("", &[s_el.clone()]); + + let result = s_lhs.list().unwrap().equal(s_rhs.list().unwrap()); + assert_eq!(result.len(), 2); + assert!(result.all()); + } + #[test] fn test_broadcasting_bools() { let a = BooleanChunked::from_slice("", &[true, false, true]); diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index 22779f05c3a1..0391aa1eea3c 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -1,5 +1,3 @@ -use std::cmp::Ordering; - use super::*; impl ChunkedArray @@ -17,54 +15,36 @@ where } } -fn binary_search( +/// Splits the ChunkedArray into a lower part, where is_lower returns true, and +/// an upper part where it returns false, and returns a mask where the lower part +/// has value lower_part, and the upper part !lower_part. +/// The ChunkedArray is assumed to be sorted w.r.t. is_lower, that is, is_lower +/// first always returns true, and then always returns false. 
+fn partition_mask( ca: &ChunkedArray, - // lhs part of mask will be set to boolean - // rhs part of mask will be set to !boolean lower_part: bool, - cmp_fn: F, + is_lower: F, ) -> BooleanChunked where - F: Fn(&T::Native) -> Ordering + Copy, + F: Fn(&T::Native) -> bool, { let chunks = ca.downcast_iter().map(|arr| { let values = arr.values(); - let mask = match values.binary_search_by(cmp_fn) { - Err(mut idx) => { - if idx == 0 || idx == arr.len() { - let mut mask = MutableBitmap::with_capacity(arr.len()); - let fill_value = if idx == 0 { !lower_part } else { lower_part }; - mask.extend_constant(arr.len(), fill_value); - BooleanArray::from_data_default(mask.into(), None) - } else { - let found_ordering = cmp_fn(&values[idx]); - - idx = idx.saturating_sub(1); - loop { - let current_value = unsafe { values.get_unchecked(idx) }; - let current_output = cmp_fn(current_value); - - if current_output != found_ordering || idx == 0 { - break; - } - - idx = idx.saturating_sub(1); - } - idx += 1; - let mut mask = MutableBitmap::with_capacity(arr.len()); - mask.extend_constant(idx, lower_part); - mask.extend_constant(arr.len() - idx, !lower_part); - BooleanArray::from_data_default(mask.into(), None) - } - }, - Ok(_) => { - unreachable!() - }, - }; - mask + let lower_len = values.partition_point(&is_lower); + let mut mask = MutableBitmap::with_capacity(arr.len()); + mask.extend_constant(lower_len, lower_part); + mask.extend_constant(arr.len() - lower_len, !lower_part); + BooleanArray::from_data_default(mask.into(), None) }); - BooleanChunked::from_chunk_iter(ca.name(), chunks) + let output_order = if lower_part { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + let mut ca = BooleanChunked::from_chunk_iter(ca.name(), chunks); + ca.set_sorted_flag(output_order); + ca } impl ChunkCompare for ChunkedArray @@ -91,16 +71,13 @@ where fn gt(&self, rhs: Rhs) -> BooleanChunked { match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) if self.len() > 1 => { + (IsSorted::Ascending, 0) => { let rhs: T::Native = NumCast::from(rhs).unwrap(); - - let cmp_fn = |a: &T::Native| match compare_fn_nan_max(a, &rhs) { - Ordering::Equal | Ordering::Less => Ordering::Less, - _ => Ordering::Greater, - }; - let mut ca = binary_search(self, false, cmp_fn); - ca.set_sorted_flag(IsSorted::Ascending); - ca + partition_mask(self, false, |x| x.tot_le(&rhs)) + }, + (IsSorted::Descending, 0) => { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + partition_mask(self, true, |x| x.tot_gt(&rhs)) }, _ => self.primitive_compare_scalar(rhs, |l, rhs| comparison::gt_scalar(l, rhs)), } @@ -108,16 +85,13 @@ where fn gt_eq(&self, rhs: Rhs) -> BooleanChunked { match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) if self.len() > 1 => { + (IsSorted::Ascending, 0) => { let rhs: T::Native = NumCast::from(rhs).unwrap(); - - let cmp_fn = |a: &T::Native| match compare_fn_nan_max(a, &rhs) { - Ordering::Equal | Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, - }; - let mut ca = binary_search(self, false, cmp_fn); - ca.set_sorted_flag(IsSorted::Ascending); - ca + partition_mask(self, false, |x| x.tot_lt(&rhs)) + }, + (IsSorted::Descending, 0) => { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + partition_mask(self, true, |x| x.tot_ge(&rhs)) }, _ => self.primitive_compare_scalar(rhs, |l, rhs| comparison::gt_eq_scalar(l, rhs)), } @@ -127,14 +101,11 @@ where match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { let rhs: T::Native = 
NumCast::from(rhs).unwrap(); - - let cmp_fn = |a: &T::Native| match compare_fn_nan_max(a, &rhs) { - Ordering::Equal | Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, - }; - let mut ca = binary_search(self, true, cmp_fn); - ca.set_sorted_flag(IsSorted::Ascending); - ca + partition_mask(self, true, |x| x.tot_lt(&rhs)) + }, + (IsSorted::Descending, 0) => { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + partition_mask(self, false, |x| x.tot_ge(&rhs)) }, _ => self.primitive_compare_scalar(rhs, |l, rhs| comparison::lt_scalar(l, rhs)), } @@ -144,14 +115,11 @@ where match (self.is_sorted_flag(), self.null_count()) { (IsSorted::Ascending, 0) => { let rhs: T::Native = NumCast::from(rhs).unwrap(); - - let cmp_fn = |a: &T::Native| match compare_fn_nan_max(a, &rhs) { - Ordering::Greater => Ordering::Greater, - Ordering::Equal | Ordering::Less => Ordering::Less, - }; - let mut ca = binary_search(self, true, cmp_fn); - ca.set_sorted_flag(IsSorted::Ascending); - ca + partition_mask(self, true, |x| x.tot_le(&rhs)) + }, + (IsSorted::Descending, 0) => { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + partition_mask(self, false, |x| x.tot_gt(&rhs)) }, _ => self.primitive_compare_scalar(rhs, |l, rhs| comparison::lt_eq_scalar(l, rhs)), } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 7e55c2d04355..b20ea1cde3ca 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -1,6 +1,6 @@ use super::*; -#[allow(clippy::ptr_arg)] +#[allow(clippy::all)] fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataType { // ensure we don't get List let dtype = if let Some(arr) = chunks.get(0) { @@ -19,7 +19,7 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let list_arr = array.as_any().downcast_ref::>().unwrap(); let values_arr = list_arr.values(); let cat = unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( "", vec![values_arr.clone()], values_arr.data_type(), @@ -46,7 +46,7 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let list_arr = array.as_any().downcast_ref::().unwrap(); let values_arr = list_arr.values(); let cat = unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( "", vec![values_arr.clone()], values_arr.data_type(), @@ -72,8 +72,8 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy impl From for ChunkedArray where - T: PolarsDataType, - A: StaticallyMatchesPolarsType + Array, + T: PolarsDataType, + A: Array, { fn from(arr: A) -> Self { Self::with_chunk("", arr) @@ -86,7 +86,8 @@ where { pub fn with_chunk(name: &str, arr: A) -> Self where - A: StaticallyMatchesPolarsType + Array, + A: Array, + T: PolarsDataType, { unsafe { Self::from_chunks(name, vec![Box::new(arr)]) } } @@ -94,7 +95,8 @@ where pub fn from_chunk_iter(name: &str, iter: I) -> Self where I: IntoIterator, - ::Item: StaticallyMatchesPolarsType + Array, + T: PolarsDataType::Item>, + ::Item: Array, { let chunks = iter .into_iter() @@ -103,10 +105,24 @@ where unsafe { Self::from_chunks(name, chunks) } } + pub fn from_chunk_iter_like(ca: &Self, iter: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + let chunks = iter + .into_iter() + .map(|x| Box::new(x) as Box) + .collect(); + unsafe { Self::from_chunks_and_dtype_unchecked(ca.name(), chunks, ca.dtype().clone()) } + } + pub fn try_from_chunk_iter(name: &str, iter: I) -> 
Result where I: IntoIterator>, - A: StaticallyMatchesPolarsType + Array, + T: PolarsDataType, + A: Array, { let chunks: Result<_, _> = iter .into_iter() @@ -115,7 +131,36 @@ where unsafe { Ok(Self::from_chunks(name, chunks?)) } } - /// Create a new ChunkedArray from existing chunks. + pub(crate) fn from_chunk_iter_and_field(field: Arc, chunks: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + assert_eq!( + std::mem::discriminant(&T::get_dtype()), + std::mem::discriminant(&field.dtype) + ); + + let mut length = 0; + let chunks = chunks + .into_iter() + .map(|x| { + length += x.len(); + Box::new(x) as Box + }) + .collect(); + + ChunkedArray { + field, + chunks, + phantom: PhantomData, + bit_settings: Default::default(), + length: length.try_into().unwrap(), + } + } + + /// Create a new [`ChunkedArray`] from existing chunks. /// /// # Safety /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. @@ -126,15 +171,13 @@ where dtype @ DataType::Array(_, _) => from_chunks_list_dtype(&mut chunks, dtype), dt => dt, }; - // assertions in debug mode - // that check if the data types in the arrays are as expected - #[cfg(debug_assertions)] - { - if !chunks.is_empty() && dtype.is_primitive() { - assert_eq!(chunks[0].data_type(), &dtype.to_physical().to_arrow()) - } - } - let field = Arc::new(Field::new(name, dtype)); + Self::from_chunks_and_dtype(name, chunks, dtype) + } + + /// # Safety + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. + pub unsafe fn with_chunks(&self, chunks: Vec) -> Self { + let field = self.field.clone(); let mut out = ChunkedArray { field, chunks, @@ -146,10 +189,24 @@ where out } + /// Create a new [`ChunkedArray`] from existing chunks. + /// /// # Safety /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. - pub unsafe fn with_chunks(&self, chunks: Vec) -> Self { - let field = self.field.clone(); + pub unsafe fn from_chunks_and_dtype( + name: &str, + chunks: Vec, + dtype: DataType, + ) -> Self { + // assertions in debug mode + // that check if the data types in the arrays are as expected + #[cfg(debug_assertions)] + { + if !chunks.is_empty() && dtype.is_primitive() { + assert_eq!(chunks[0].data_type(), &dtype.to_physical().to_arrow()) + } + } + let field = Arc::new(Field::new(name, dtype)); let mut out = ChunkedArray { field, chunks, @@ -188,29 +245,7 @@ where } out } -} - -impl ListChunked { - pub(crate) unsafe fn from_chunks_and_dtype_unchecked( - name: &str, - chunks: Vec, - dtype: DataType, - ) -> Self { - let field = Arc::new(Field::new(name, dtype)); - let mut out = ChunkedArray { - field, - chunks, - phantom: PhantomData, - bit_settings: Default::default(), - length: 0, - }; - out.compute_len(); - out - } -} -#[cfg(feature = "dtype-array")] -impl ArrayChunked { pub(crate) unsafe fn from_chunks_and_dtype_unchecked( name: &str, chunks: Vec, diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index 3e9d9ce45545..ef46c738f888 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -12,15 +12,15 @@ type LargeBinaryArray = BinaryArray; type LargeListArray = ListArray; pub mod par; -/// A `PolarsIterator` is an iterator over a `ChunkedArray` which contains polars types. A `PolarsIterator` -/// must implement `ExactSizeIterator` and `DoubleEndedIterator`. 
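`from_chunk_iter_and_field` asserts that the compile-time dtype and the supplied field agree on the dtype variant only, ignoring any payload (such as a list's inner type), by comparing `std::mem::discriminant` values. A small standalone illustration of that check, using a toy enum rather than the polars `DataType`:

```rust
use std::mem::discriminant;

// Toy stand-in for a dtype enum with a payload-carrying variant.
enum DType {
    Int32,
    List(Box<DType>),
}

fn main() {
    // Same variant, different payload: the discriminants still compare equal.
    let a = DType::List(Box::new(DType::Int32));
    let b = DType::List(Box::new(DType::List(Box::new(DType::Int32))));
    assert_eq!(discriminant(&a), discriminant(&b));

    // Different variants do not.
    assert_ne!(discriminant(&a), discriminant(&DType::Int32));
}
```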
+/// A [`PolarsIterator`] is an iterator over a [`ChunkedArray`] which contains polars types. A [`PolarsIterator`] +/// must implement [`ExactSizeIterator`] and [`DoubleEndedIterator`]. pub trait PolarsIterator: ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen { } unsafe impl<'a, I> TrustedLen for Box + 'a> {} -/// Implement PolarsIterator for every iterator that implements the needed traits. +/// Implement [`PolarsIterator`] for every iterator that implements the needed traits. impl PolarsIterator for T where T: ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen { @@ -54,7 +54,7 @@ impl<'a> IntoIterator for &'a BooleanChunked { } } -/// The no null iterator for a BooleanArray +/// The no null iterator for a [`BooleanArray`] pub struct BoolIterNoNull<'a> { array: &'a BooleanArray, current: usize, @@ -619,8 +619,8 @@ impl<'a> Iterator for StructIter<'a> { } } -/// Wrapper struct to convert an iterator of type `T` into one of type `Option`. It is useful to make the -/// `IntoIterator` trait, in which every iterator shall return an `Option`. +/// Wrapper struct to convert an iterator of type `T` into one of type [`Option`]. It is useful to make the +/// [`IntoIterator`] trait, in which every iterator shall return an [`Option`]. pub struct SomeIterator(I) where I: Iterator; @@ -668,14 +668,14 @@ mod test { ) } - /// Generate test for `IntoIterator` trait for chunked arrays with just one chunk and no null values. - /// The expected return value of the iterator generated by `IntoIterator` trait is `Option`, where + /// Generate test for [`IntoIterator`] trait for chunked arrays with just one chunk and no null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... + /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... /// first_val: The first value contained in the chunked array. /// second_val: The second value contained in the chunked array. /// third_val: The third value contained in the chunked array. @@ -729,17 +729,17 @@ mod test { impl_test_iter_single_chunk!(utf8_iter_single_chunk, Utf8Chunked, "a", "b", "c"); impl_test_iter_single_chunk!(bool_iter_single_chunk, BooleanChunked, true, true, false); - /// Generate test for `IntoIterator` trait for chunked arrays with just one chunk and null values. - /// The expected return value of the iterator generated by `IntoIterator` trait is `Option`, where + /// Generate test for [`IntoIterator`] trait for chunked arrays with just one chunk and null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... - /// first_val: The first value contained in the chunked array. Must be an `Option`. - /// second_val: The second value contained in the chunked array. Must be an `Option`. - /// third_val: The third value contained in the chunked array. Must be an `Option`. + /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. Must be an [`Option`]. 
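The `SomeIterator` wrapper documented here adapts a no-null iterator over `T` into one over `Option<T>`, so null-aware and non-null code paths can meet the same `IntoIterator` contract. A minimal generic sketch of the same idea (a standalone type, not the polars one):

```rust
// Wrap an iterator over `T` so it yields `Option<T>`, mirroring the idea
// behind `SomeIterator`.
struct SomeIter<I>(I);

impl<I: Iterator> Iterator for SomeIter<I> {
    type Item = Option<I::Item>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(Some)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

fn main() {
    let wrapped: Vec<Option<i32>> = SomeIter([1, 2, 3].into_iter()).collect();
    assert_eq!(wrapped, vec![Some(1), Some(2), Some(3)]);
}
```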
+ /// second_val: The second value contained in the chunked array. Must be an [`Option`]. + /// third_val: The third value contained in the chunked array. Must be an [`Option`]. macro_rules! impl_test_iter_single_chunk_null_check { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] @@ -808,14 +808,14 @@ mod test { Some(false) ); - /// Generate test for `IntoIterator` trait for chunked arrays with many chunks and no null values. - /// The expected return value of the iterator generated by `IntoIterator` trait is `Option`, where + /// Generate test for [`IntoIterator`] trait for chunked arrays with many chunks and no null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... + /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... /// first_val: The first value contained in the chunked array. /// second_val: The second value contained in the chunked array. /// third_val: The third value contained in the chunked array. @@ -871,17 +871,17 @@ mod test { impl_test_iter_many_chunk!(utf8_iter_many_chunk, Utf8Chunked, "a", "b", "c"); impl_test_iter_many_chunk!(bool_iter_many_chunk, BooleanChunked, true, true, false); - /// Generate test for `IntoIterator` trait for chunked arrays with many chunk and null values. - /// The expected return value of the iterator generated by `IntoIterator` trait is `Option`, where + /// Generate test for [`IntoIterator`] trait for chunked arrays with many chunk and null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... - /// first_val: The first value contained in the chunked array. Must be an `Option`. - /// second_val: The second value contained in the chunked array. Must be an `Option`. - /// third_val: The third value contained in the chunked array. Must be an `Option`. + /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. Must be an [`Option`]. + /// second_val: The second value contained in the chunked array. Must be an [`Option`]. + /// third_val: The third value contained in the chunked array. Must be an [`Option`]. macro_rules! impl_test_iter_many_chunk_null_check { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] @@ -952,14 +952,14 @@ mod test { Some(false) ); - /// Generate test for `IntoNoNullIterator` trait for chunked arrays with just one chunk and no null values. - /// The expected return value of the iterator generated by `IntoNoNullIterator` trait is `T`, where + /// Generate test for [`IntoNoNullIterator`] trait for chunked arrays with just one chunk and no null values. + /// The expected return value of the iterator generated by [`IntoNoNullIterator`] trait is `T`, where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... 
+ /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... /// first_val: The first value contained in the chunked array. /// second_val: The second value contained in the chunked array. /// third_val: The third value contained in the chunked array. @@ -1025,14 +1025,14 @@ mod test { false ); - /// Generate test for `IntoNoNullIterator` trait for chunked arrays with many chunks and no null values. - /// The expected return value of the iterator generated by `IntoNoNullIterator` trait is `T`, where + /// Generate test for [`IntoNoNullIterator`] trait for chunked arrays with many chunks and no null values. + /// The expected return value of the iterator generated by [`IntoNoNullIterator`] trait is `T`, where /// `T` is the chunked array type. /// /// # Input /// /// test_name: The name of the test to generate. - /// ca_type: The chunked array to use for this test. Ex: `Utf8Chunked`, `UInt32Chunked` ... + /// ca_type: The chunked array to use for this test. Ex: [`Utf8Chunked`], [`UInt32Chunked`] ... /// first_val: The first value contained in the chunked array. /// second_val: The second value contained in the chunked array. /// third_val: The third value contained in the chunked array. @@ -1172,12 +1172,12 @@ mod test { a }); - /// Generates a `Vec` of `bool`, with even indexes are true, and odd indexes are false. + /// Generates a [`Vec`] of [`bool`], with even indexes are true, and odd indexes are false. fn generate_boolean_vec(size: usize) -> Vec { (0..size).map(|n| n % 2 == 0).collect() } - /// Generate a `Vec` of `Option`, where: + /// Generate a [`Vec`] of [`Option`], where: /// - If the index is divisible by 3, then, the value is `None`. /// - If the index is not divisible by 3 and it is even, then, the value is `Some(true)`. /// - Otherwise, the value is `Some(false)`. diff --git a/crates/polars-core/src/chunked_array/kernels/mod.rs b/crates/polars-core/src/chunked_array/kernels/mod.rs deleted file mode 100644 index 66d56923fa2b..000000000000 --- a/crates/polars-core/src/chunked_array/kernels/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) mod take; diff --git a/crates/polars-core/src/chunked_array/kernels/take.rs b/crates/polars-core/src/chunked_array/kernels/take.rs deleted file mode 100644 index 30c9e189483c..000000000000 --- a/crates/polars-core/src/chunked_array/kernels/take.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::convert::TryFrom; - -use polars_arrow::compute::take::bitmap::take_bitmap_unchecked; -use polars_arrow::compute::take::take_value_indices_from_list; -use polars_arrow::utils::combine_validities_and; - -use crate::prelude::*; - -/// Take kernel for multiple chunks. We directly return a ChunkedArray because that path chooses the fastest collection path. -pub(crate) fn take_primitive_iter_n_chunks>( - ca: &ChunkedArray, - indices: I, -) -> ChunkedArray { - let taker = ca.take_rand(); - indices.into_iter().map(|idx| taker.get(idx)).collect() -} - -/// Take kernel for multiple chunks where an iterator can produce None values. -/// Used in join operations. We directly return a ChunkedArray because that path chooses the fastest collection path. -pub(crate) fn take_primitive_opt_iter_n_chunks< - T: PolarsNumericType, - I: IntoIterator>, ->( - ca: &ChunkedArray, - indices: I, -) -> ChunkedArray { - let taker = ca.take_rand(); - indices - .into_iter() - .map(|opt_idx| opt_idx.and_then(|idx| taker.get(idx))) - .collect() -} - -/// This is faster because it does no bounds checks and allocates directly into aligned memory. 
-/// -/// # Safety -/// No bounds checks -pub(crate) unsafe fn take_list_unchecked( - values: &ListArray, - indices: &IdxArr, -) -> ListArray { - // Taking the whole list or a contiguous sublist. - let (list_indices, offsets) = take_value_indices_from_list(values, indices); - - // Temporary series so that we can take primitives from it. - let s = Series::try_from(("", values.values().clone() as ArrayRef)).unwrap(); - let taken = s.take_unchecked(&list_indices.into()).unwrap(); - - let taken = taken.array_ref(0).clone(); - let validity = if let Some(validity) = values.validity() { - let validity = take_bitmap_unchecked(validity, indices.values().as_slice()); - combine_validities_and(Some(&validity), indices.validity()) - } else { - indices.validity().cloned() - }; - - let dtype = ListArray::::default_datatype(taken.data_type().clone()); - // SAFETY: offsets are monotonically increasing. - ListArray::new(dtype, offsets.into(), taken, validity) -} diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 253727531269..2dc2c7eb8559 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::pin::Pin; use std::ptr::NonNull; use crate::prelude::*; @@ -7,7 +8,7 @@ use crate::utils::CustomIterTools; pub struct AmortizedListIter<'a, I: Iterator>> { len: usize, - series_container: Box, + series_container: Pin>, inner: NonNull, lifetime: PhantomData<&'a ArrayRef>, iter: I, @@ -19,7 +20,7 @@ pub struct AmortizedListIter<'a, I: Iterator>> { impl<'a, I: Iterator>> AmortizedListIter<'a, I> { pub(crate) fn new( len: usize, - series_container: Box, + series_container: Pin>, inner: NonNull, iter: I, inner_dtype: DataType, @@ -95,7 +96,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a unsafe impl<'a, I: Iterator>> TrustedLen for AmortizedListIter<'a, I> {} impl ListChunked { - /// This is an iterator over a ListChunked that save allocations. + /// This is an iterator over a [`ListChunked`] that save allocations. /// A Series is: /// 1. [`Arc`] /// ChunkedArray is: @@ -111,11 +112,20 @@ impl ListChunked { /// this function still needs precautions. The returned should never be cloned or taken longer /// than a single iteration, as every call on `next` of the iterator will change the contents of /// that Series. - pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { + /// + /// # Safety + /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive + /// longer than the iterator is UB. + pub unsafe fn amortized_iter( + &self, + ) -> AmortizedListIter> + '_> { self.amortized_iter_with_name("") } - pub fn amortized_iter_with_name( + /// # Safety + /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive + /// longer than the iterator is UB. + pub unsafe fn amortized_iter_with_name( &self, name: &str, ) -> AmortizedListIter> + '_> { @@ -143,7 +153,7 @@ impl ListChunked { &iter_dtype, ); s.clear_settings(); - Box::new(s) + Box::pin(s) }; let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef; @@ -157,6 +167,68 @@ impl ListChunked { ) } + /// Apply a closure `F` elementwise. + #[must_use] + pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray + where + V: PolarsDataType, + F: FnMut(Option>) -> Option + Copy, + V::Array: ArrayFromIter>, + { + // TODO! 
make an amortized iter that does not flatten + // SAFETY: unstable series never lives longer than the iterator. + unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } + } + + pub fn for_each_amortized<'a, F>(&'a self, f: F) + where + F: FnMut(Option>), + { + // SAFETY: unstable series never lives longer than the iterator. + unsafe { self.amortized_iter().for_each(f) } + } + + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + #[must_use] + pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray, mut f: F) -> Self + where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen>>, + F: FnMut(Option>, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + // SAFETY: unstable series never lives longer than the iterator. + let mut out: ListChunked = unsafe { + self.amortized_iter() + .zip(ca) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + match out { + Some(out) if out.is_empty() => { + fast_explode = false; + Some(out) + }, + None => { + fast_explode = false; + out + }, + _ => out, + } + }) + .collect_trusted() + }; + + out.rename(self.name()); + if fast_explode { + out.set_fast_explode(); + } + out + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self @@ -167,18 +239,20 @@ impl ListChunked { return self.clone(); } let mut fast_explode = self.null_count() == 0; - let mut ca: ListChunked = self - .amortized_iter() - .map(|opt_v| { - opt_v.map(|v| { - let out = f(v); - if out.is_empty() { - fast_explode = false; - } - out + // SAFETY: unstable series never lives longer than the iterator. + let mut ca: ListChunked = unsafe { + self.amortized_iter() + .map(|opt_v| { + opt_v.map(|v| { + let out = f(v); + if out.is_empty() { + fast_explode = false; + } + out + }) }) - }) - .collect_trusted(); + .collect_trusted() + }; ca.rename(self.name()); if fast_explode { @@ -195,22 +269,24 @@ impl ListChunked { return Ok(self.clone()); } let mut fast_explode = self.null_count() == 0; - let mut ca: ListChunked = self - .amortized_iter() - .map(|opt_v| { - opt_v - .map(|v| { - let out = f(v); - if let Ok(out) = &out { - if out.is_empty() { - fast_explode = false - } - }; - out - }) - .transpose() - }) - .collect::>()?; + // SAFETY: unstable series never lives longer than the iterator. + let mut ca: ListChunked = unsafe { + self.amortized_iter() + .map(|opt_v| { + opt_v + .map(|v| { + let out = f(v); + if let Ok(out) = &out { + if out.is_empty() { + fast_explode = false + } + }; + out + }) + .transpose() + }) + .collect::>()? + }; ca.rename(self.name()); if fast_explode { ca.set_fast_explode(); @@ -232,8 +308,11 @@ mod test { builder.append_series(&Series::new("", &[1, 1])).unwrap(); let ca = builder.finish(); - ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { - assert!(s1.unwrap().as_ref().series_equal(&s2.unwrap())); - }); + // SAFETY: unstable series never lives longer than the iterator. 
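Because `amortized_iter` now hands out views into a single reused (pinned) series container, it is marked `unsafe`: the caller must not keep an item alive past its own iteration. A hedged sketch of the intended call pattern, mirroring the test above; the closure only touches the unstable series through `as_ref()` inside its own iteration:

```rust
use polars_core::prelude::*;

fn print_list_lengths(ca: &ListChunked) {
    // SAFETY: each unstable series is only used inside its own iteration
    // and is never stored beyond the closure body.
    unsafe {
        ca.amortized_iter().for_each(|opt_s| {
            if let Some(s) = opt_s {
                println!("{}", s.as_ref().len());
            }
        })
    }
}
```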
+ unsafe { + ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { + assert!(s1.unwrap().as_ref().series_equal(&s2.unwrap())); + }) + }; } } diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index d679c45bbeb8..ec2a21d14d15 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -29,27 +29,20 @@ impl ListChunked { self.bit_settings.contains(Settings::FAST_EXPLODE_LIST) } - pub(crate) fn is_nested(&self) -> bool { - match self.dtype() { - DataType::List(inner) => matches!(&**inner, DataType::List(_)), - _ => unreachable!(), - } - } - - /// Set the logical type of the ListChunked. + /// Set the logical type of the [`ListChunked`]. pub fn to_logical(&mut self, inner_dtype: DataType) { debug_assert_eq!(inner_dtype.to_physical(), self.inner_dtype()); let fld = Arc::make_mut(&mut self.field); fld.coerce(DataType::List(Box::new(inner_dtype))) } - /// Get the inner values as `Series`, ignoring the list offsets. + /// Get the inner values as [`Series`], ignoring the list offsets. pub fn get_inner(&self) -> Series { let ca = self.rechunk(); let inner_dtype = self.inner_dtype().to_arrow(); let arr = ca.downcast_iter().next().unwrap(); unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( self.name(), vec![(*arr.values()).clone()], &inner_dtype, @@ -58,7 +51,7 @@ impl ListChunked { } } - /// Ignore the list indices and apply `func` to the inner type as `Series`. + /// Ignore the list indices and apply `func` to the inner type as [`Series`]. pub fn apply_to_inner( &self, func: &dyn Fn(Series) -> PolarsResult, @@ -69,7 +62,7 @@ impl ListChunked { let chunks = ca.downcast_iter().map(|arr| { let elements = unsafe { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( self.name(), vec![(*arr.values()).clone()], &inner_dtype, diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index d8ad21b12f88..4d922b72c9f5 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -6,7 +6,7 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; use polars_arrow::trusted_len::TrustedLenPush; use crate::datatypes::PlHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::{using_string_cache, StringCache, POOL}; @@ -92,7 +92,7 @@ impl RevMapping { !self.is_global() } - /// Get the categories in this RevMapping + /// Get the categories in this [`RevMapping`] pub fn get_categories(&self) -> &Utf8Array { match self { Self::Global(_, a, _) => a, @@ -105,7 +105,9 @@ impl RevMapping { self.get_categories().len() } - /// Categorical to str + /// [`Categorical`] to [`str`] + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical pub fn get(&self, idx: u32) -> &str { match self { Self::Global(map, a, _) => { @@ -126,7 +128,9 @@ impl RevMapping { } } - /// Categorical to str + /// [`Categorical`] to [`str`] + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical /// /// # Safety /// This doesn't do any bound checking @@ -150,7 +154,10 @@ impl RevMapping { } } - /// str to Categorical + /// [`str`] to [`Categorical`] + /// + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical pub fn find(&self, value: &str) -> Option { match self { Self::Global(rev_map, 
a, id) => { @@ -270,7 +277,7 @@ impl<'a> CategoricalChunkedBuilder<'a> { } /// `store_hashes` is not needed by the local builder, only for the global builder under contention - /// The hashes have the same order as the `Utf8Array` values. + /// The hashes have the same order as the [`Utf8Array`] values. fn build_local_map(&mut self, i: I, store_hashes: bool) -> Vec where I: IntoIterator>, @@ -280,8 +287,10 @@ impl<'a> CategoricalChunkedBuilder<'a> { self.hashes = Vec::with_capacity(iter.size_hint().0 / 10) } // It is important that we use the same hash builder as the global `StringCache` does. - self.local_mapping = - PlHashMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, StringCache::get_hash_builder()); + self.local_mapping = PlHashMap::with_capacity_and_hasher( + _HASHMAP_INIT_SIZE, + StringCache::get_hash_builder(), + ); for opt_s in &mut iter { match opt_s { Some(s) => self.push_impl(s, store_hashes), @@ -297,7 +306,7 @@ impl<'a> CategoricalChunkedBuilder<'a> { std::mem::take(&mut self.hashes) } - /// Build a global string cached `CategoricalChunked` from a local `Dictionary`. + /// Build a global string cached [`CategoricalChunked`] from a local [`Dictionary`]. pub(super) fn global_map_from_local(&mut self, keys: &UInt32Array, values: Utf8Array) { // locally we don't need a hashmap because we all categories are 1 integer apart // so the index is local, and the values is global @@ -352,7 +361,7 @@ impl<'a> CategoricalChunkedBuilder<'a> { where I: IntoIterator>, { - // first build the values: `Utf8Array` + // first build the values: [`Utf8Array`] // we can use a local hashmap for that // `hashes.len()` is equal to to the number of unique values. let hashes = self.build_local_map(i, true); @@ -477,7 +486,7 @@ impl CategoricalChunked { pub unsafe fn from_global_indices_unchecked(cats: UInt32Chunked) -> CategoricalChunked { let cache = crate::STRING_CACHE.read_map(); - let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), HASHMAP_INIT_SIZE); + let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE); let mut rev_map = PlHashMap::with_capacity(cap); let mut str_values = MutableUtf8Array::with_capacities(cap, cap * 24); @@ -503,12 +512,12 @@ impl CategoricalChunked { mod test { use crate::chunked_array::categorical::CategoricalChunkedBuilder; use crate::prelude::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; #[test] fn test_categorical_rev() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let slice = &[ Some("foo"), None, @@ -523,7 +532,7 @@ mod test { assert_eq!(out.get_rev_map().len(), 2); // test the global branch - enable_string_cache(true); + enable_string_cache(); // empty global cache let out = ca.cast(&DataType::Categorical(None))?; let out = out.categorical().unwrap().clone(); @@ -547,11 +556,13 @@ mod test { #[test] fn test_categorical_builder() { - use crate::{enable_string_cache, reset_string_cache}; + use crate::{disable_string_cache, enable_string_cache}; let _lock = crate::SINGLE_LOCK.lock(); - for b in &[false, true] { - reset_string_cache(); - enable_string_cache(*b); + for use_string_cache in [false, true] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } // Use 2 builders to check if the global string cache // does not interfere with the index mapping diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs 
b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index f5e6a66cb1f1..ab7fc5afc51c 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -38,7 +38,7 @@ struct State { #[derive(Default)] pub(crate) struct RevMapMerger { - id: u32, + id: Option, original: Arc, // only initiate state when // we encounter a rev-map from a different source, @@ -48,12 +48,14 @@ pub(crate) struct RevMapMerger { impl RevMapMerger { pub(crate) fn new(rev_map: Arc) -> Self { - let RevMapping::Global(_, _, id) = rev_map.as_ref() else { - panic!("impl error") + let id = if let RevMapping::Global(_, _, id) = rev_map.as_ref() { + Some(*id) + } else { + None }; RevMapMerger { state: None, - id: *id, + id, original: rev_map, } } @@ -74,11 +76,12 @@ impl RevMapMerger { if Arc::ptr_eq(&self.original, rev_map) { return Ok(()); } + let msg = "categoricals don't originate from the same string cache\n\ + try setting a global string cache or increase the scope of the local string cache"; let RevMapping::Global(map, slots, id) = rev_map.as_ref() else { - polars_bail!(ComputeError: "expected global rev-map") + polars_bail!(ComputeError: msg) }; - polars_ensure!(*id == self.id, ComputeError: "categoricals don't originate from the same string cache\n\ - try setting a global string cache or increase the scope of the local string cache"); + polars_ensure!(Some(*id) == self.id, ComputeError: msg); if self.state.is_none() { self.init_state() @@ -103,7 +106,7 @@ impl RevMapMerger { match self.state { None => self.original, Some(state) => { - let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id); + let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id.unwrap()); Arc::new(new_rev) }, } @@ -143,7 +146,7 @@ pub(crate) fn merge_rev_map( } impl CategoricalChunked { - pub(crate) fn merge_categorical_map(&self, other: &Self) -> PolarsResult> { + pub fn _merge_categorical_map(&self, other: &Self) -> PolarsResult> { merge_rev_map(self.get_rev_map(), other.get_rev_map()) } } @@ -153,13 +156,13 @@ impl CategoricalChunked { mod test { use super::*; use crate::chunked_array::categorical::CategoricalChunkedBuilder; - use crate::{enable_string_cache, reset_string_cache, IUseStringCache}; + use crate::{disable_string_cache, enable_string_cache, StringCacheHolder}; #[test] fn test_merge_rev_map() { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - let _sc = IUseStringCache::hold(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); let mut builder1 = CategoricalChunkedBuilder::new("foo", 10); let mut builder2 = CategoricalChunkedBuilder::new("foo", 10); @@ -167,7 +170,7 @@ mod test { builder2.drain_iter(vec![Some("hello"), None, Some("world"), Some("bar")].into_iter()); let ca1 = builder1.finish(); let ca2 = builder2.finish(); - let rev_map = ca1.merge_categorical_map(&ca2).unwrap(); + let rev_map = ca1._merge_categorical_map(&ca2).unwrap(); let mut ca = UInt32Chunked::new("", &[0, 1, 2, 3]); ca.categorical_map = Some(rev_map); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 6841d6da5ba1..d66e6318b5ef 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -2,12 +2,11 @@ mod builder; mod from; mod merge; mod ops; -pub mod stringcache; +pub mod string_cache; use bitflags::bitflags; pub use 
builder::*; pub(crate) use merge::*; -pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal}; use polars_utils::sync::SyncPtr; use super::*; @@ -266,12 +265,12 @@ mod test { use std::convert::TryFrom; use super::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; #[test] fn test_categorical_round_trip() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let slice = &[ Some("foo"), None, @@ -296,8 +295,8 @@ mod test { #[test] fn test_append_categorical() { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - enable_string_cache(true); + disable_string_cache(); + enable_string_cache(); let mut s1 = Series::new("1", vec!["a", "b", "c"]) .cast(&DataType::Categorical(None)) @@ -330,8 +329,7 @@ mod test { #[test] fn test_categorical_flow() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - enable_string_cache(false); + disable_string_cache(); // tests several things that may lose the dtype information let s = Series::new("a", vec!["a", "b", "c"]).cast(&DataType::Categorical(None))?; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs index 44587e704983..6385cfba3a1b 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -20,7 +20,7 @@ impl CategoricalChunked { polars_bail!(string_cache_mismatch); } else { let len = self.len(); - let new_rev_map = self.merge_categorical_map(other)?; + let new_rev_map = self._merge_categorical_map(other)?; unsafe { self.set_rev_map(new_rev_map, false) }; self.logical_mut().length += other.len() as IdxSize; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs index 91f3e293e202..759628b322cb 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs @@ -1,10 +1,8 @@ mod append; mod full; -mod take_random; +#[cfg(feature = "algorithm_group_by")] mod unique; #[cfg(feature = "zip_with")] mod zip; -pub(crate) use take_random::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal}; - use super::*; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs deleted file mode 100644 index b258ffc000e5..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::cmp::Ordering; - -use arrow::array::Utf8Array; - -use crate::prelude::compare_inner::PartialOrdInner; -use crate::prelude::{ - CategoricalChunked, IntoTakeRandom, NumTakeRandomChunked, NumTakeRandomCont, - NumTakeRandomSingleChunk, PlHashMap, RevMapping, TakeRandBranch3, TakeRandom, -}; - -type TakeCats<'a> = TakeRandBranch3< - NumTakeRandomCont<'a, u32>, - NumTakeRandomSingleChunk<'a, u32>, - NumTakeRandomChunked<'a, u32>, ->; - -pub(crate) struct CategoricalTakeRandomLocal<'a> { - rev_map: &'a Utf8Array, - cats: TakeCats<'a>, -} - -impl<'a> CategoricalTakeRandomLocal<'a> { - pub(crate) fn new(ca: &'a CategoricalChunked) -> Self { - // should be rechunked upstream - assert_eq!(ca.logical.chunks.len(), 1, "implementation error"); - if 
let RevMapping::Local(rev_map) = &**ca.get_rev_map() { - let cats = ca.logical().take_rand(); - Self { rev_map, cats } - } else { - unreachable!() - } - } -} - -impl PartialOrdInner for CategoricalTakeRandomLocal<'_> { - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - let a = self - .cats - .get_unchecked(idx_a) - .map(|cat| self.rev_map.value_unchecked(cat as usize)); - let b = self - .cats - .get_unchecked(idx_b) - .map(|cat| self.rev_map.value_unchecked(cat as usize)); - a.partial_cmp(&b).unwrap() - } -} - -pub(crate) struct CategoricalTakeRandomGlobal<'a> { - rev_map_part_1: &'a PlHashMap, - rev_map_part_2: &'a Utf8Array, - cats: TakeCats<'a>, -} -impl<'a> CategoricalTakeRandomGlobal<'a> { - pub(crate) fn new(ca: &'a CategoricalChunked) -> Self { - // should be rechunked upstream - assert_eq!(ca.logical.chunks.len(), 1, "implementation error"); - if let RevMapping::Global(rev_map_part_1, rev_map_part_2, _) = &**ca.get_rev_map() { - let cats = ca.logical().take_rand(); - Self { - rev_map_part_1, - rev_map_part_2, - cats, - } - } else { - unreachable!() - } - } -} - -impl PartialOrdInner for CategoricalTakeRandomGlobal<'_> { - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - let a = self.cats.get_unchecked(idx_a).map(|cat| { - let idx = self.rev_map_part_1.get(&cat).unwrap(); - self.rev_map_part_2.value_unchecked(*idx as usize) - }); - let b = self.cats.get_unchecked(idx_b).map(|cat| { - let idx = self.rev_map_part_1.get(&cat).unwrap(); - self.rev_map_part_2.value_unchecked(*idx as usize) - }); - a.partial_cmp(&b).unwrap() - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs index 732266ee78d4..7fcad0c73cbe 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs @@ -15,7 +15,7 @@ impl CategoricalChunked { }, _ => self.logical().zip_with(mask, other.logical())?, }; - let new_state = self.merge_categorical_map(other)?; + let new_state = self._merge_categorical_map(other)?; // Safety: // we checked the rev_maps. diff --git a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs similarity index 60% rename from crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs rename to crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs index 195579e1392b..f39a1523446c 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -1,6 +1,6 @@ use std::hash::{Hash, Hasher}; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; @@ -8,83 +8,106 @@ use once_cell::sync::Lazy; use smartstring::{LazyCompact, SmartString}; use crate::datatypes::PlIdHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::InitHashMaps; -/// We use atomic reference counting -/// to determine how many threads use the string cache -/// if the refcount is zero, we may clear the string cache. 
-pub(crate) static USE_STRING_CACHE: AtomicU32 = AtomicU32::new(0); +/// We use atomic reference counting to determine how many threads use the +/// string cache. If the refcount is zero, we may clear the string cache. +static STRING_CACHE_REFCOUNT: Mutex = Mutex::new(0); +static STRING_CACHE_ENABLED_GLOBALLY: AtomicBool = AtomicBool::new(false); static STRING_CACHE_UUID_CTR: AtomicU32 = AtomicU32::new(0); -/// RAII for the string cache -/// If an operation creates categoricals and uses them in a join -/// or comparison that operation must hold this cache via -/// `let handle = IUseStringCache::hold()` -/// The cache is valid until `handle` is dropped. +/// Enable the global string cache as long as the object is alive ([RAII]). +/// +/// # Examples +/// +/// Enable the string cache by initializing the object: +/// +/// ``` +/// use polars_core::StringCacheHolder; +/// +/// let _sc = StringCacheHolder::hold(); +/// ``` +/// +/// The string cache is enabled until `handle` is dropped. /// /// # De-allocation +/// /// Multiple threads can hold the string cache at the same time. -/// The contents of the cache will only get dropped when no -/// thread holds it. -pub struct IUseStringCache { +/// The contents of the cache will only get dropped when no thread holds it. +/// +/// [RAII]: https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization +pub struct StringCacheHolder { // only added so that it will never be constructed directly #[allow(dead_code)] private_zst: (), } -impl Default for IUseStringCache { +impl Default for StringCacheHolder { fn default() -> Self { Self::hold() } } -impl IUseStringCache { +impl StringCacheHolder { /// Hold the StringCache - pub fn hold() -> IUseStringCache { - enable_string_cache(true); - IUseStringCache { private_zst: () } + pub fn hold() -> StringCacheHolder { + increment_string_cache_refcount(); + StringCacheHolder { private_zst: () } } } -impl Drop for IUseStringCache { +impl Drop for StringCacheHolder { fn drop(&mut self) { - enable_string_cache(false) + decrement_string_cache_refcount(); } } -pub fn with_string_cache T, T>(func: F) -> T { - enable_string_cache(true); - let out = func(); - enable_string_cache(false); - out +fn increment_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount += 1; +} +fn decrement_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount -= 1; + if *refcount == 0 { + STRING_CACHE.clear() + } } -/// Use a global string cache for the Categorical Types. +/// Enable the global string cache. /// -/// This is used to cache the string categories locally. -/// This allows join operations on categorical types. -pub fn enable_string_cache(toggle: bool) { - if toggle { - USE_STRING_CACHE.fetch_add(1, Ordering::Release); - } else { - let previous = USE_STRING_CACHE.fetch_sub(1, Ordering::Release); - if previous == 0 || previous == 1 { - USE_STRING_CACHE.store(0, Ordering::Release); - STRING_CACHE.clear() - } +/// [`Categorical`] columns created under the same global string cache have the +/// same underlying physical value when string values are equal. This allows the +/// columns to be concatenated or used in a join operation, for example. +/// +/// Note that enabling the global string cache introduces some overhead. +/// The amount of overhead depends on the number of categories in your data. +/// It is advised to enable the global string cache only when strictly necessary. 
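The rewritten bookkeeping is a plain refcount behind a `Mutex`: every holder increments on construction, decrements on drop, and the last drop clears the shared state. A minimal standalone sketch of that RAII pattern (generic illustration, not the polars implementation):

```rust
use std::sync::Mutex;

static REFCOUNT: Mutex<u32> = Mutex::new(0);

/// Keeps the shared resource "enabled" for as long as it is alive.
struct Holder;

impl Holder {
    fn hold() -> Holder {
        *REFCOUNT.lock().unwrap() += 1;
        Holder
    }
}

impl Drop for Holder {
    fn drop(&mut self) {
        let mut rc = REFCOUNT.lock().unwrap();
        *rc -= 1;
        if *rc == 0 {
            // Last holder gone: clear the shared state here.
        }
    }
}

fn main() {
    let outer = Holder::hold();
    {
        let _inner = Holder::hold();
        assert_eq!(*REFCOUNT.lock().unwrap(), 2);
    }
    assert_eq!(*REFCOUNT.lock().unwrap(), 1);
    drop(outer);
    assert_eq!(*REFCOUNT.lock().unwrap(), 0);
}
```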
+/// +/// [`Categorical`]: crate::datatypes::DataType::Categorical +pub fn enable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(true, Ordering::AcqRel); + if !was_enabled { + increment_string_cache_refcount(); } } -/// Reset the global string cache used for the Categorical Types. -pub fn reset_string_cache() { - USE_STRING_CACHE.store(0, Ordering::Release); - STRING_CACHE.clear() +/// Disable and clear the global string cache. +/// +/// Note: Consider using [`StringCacheHolder`] for a more reliable way of +/// enabling and disabling the string cache. +pub fn disable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(false, Ordering::AcqRel); + if was_enabled { + decrement_string_cache_refcount(); + } } -/// Check if string cache is set. +/// Check whether the global string cache is enabled. pub fn using_string_cache() -> bool { - USE_STRING_CACHE.load(Ordering::Acquire) > 0 + let refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount > 0 } // This is the hash and the Index offset in the linear buffer @@ -179,9 +202,9 @@ impl SCacheInner { impl Default for SCacheInner { fn default() -> Self { Self { - map: PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE), + map: PlIdHashMap::with_capacity(_HASHMAP_INIT_SIZE), uuid: STRING_CACHE_UUID_CTR.fetch_add(1, Ordering::AcqRel), - payloads: Vec::with_capacity(HASHMAP_INIT_SIZE), + payloads: Vec::with_capacity(_HASHMAP_INIT_SIZE), } } } diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 0151361e13ca..38b5593a92c2 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -43,6 +43,7 @@ impl LogicalType for DateChunked { .into_datetime(*tu, tz.clone()) .into_series()) }, + #[cfg(feature = "dtype-time")] (Date, Time) => Ok(Int64Chunked::full(self.name(), 0i64, self.len()) .into_time() .into_series()), diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index 0c53224f044f..4e3cdcb3a602 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -70,10 +70,10 @@ impl Logical { } pub trait LogicalType { - /// Get data type of ChunkedArray. + /// Get data type of [`ChunkedArray`]. 
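With the new API, callers either enable the cache globally or hold it for a scope; Categorical columns created under the same cache can then be appended or compared. A hedged usage sketch along the lines of the tests above, assuming the `dtype-categorical` feature is enabled:

```rust
use polars_core::prelude::*;
use polars_core::StringCacheHolder;

fn append_categoricals() -> PolarsResult<()> {
    // Hold the global string cache for the duration of this scope.
    let _sc = StringCacheHolder::hold();

    let mut a = Series::new("a", &["x", "y"]).cast(&DataType::Categorical(None))?;
    let b = Series::new("b", &["y", "z"]).cast(&DataType::Categorical(None))?;

    // Both columns were built under the same cache, so appending keeps a
    // consistent category mapping.
    a.append(&b)?;
    assert_eq!(a.len(), 4);
    Ok(())
}
```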
fn dtype(&self) -> &DataType; - /// Gets AnyValue from LogicalType + /// Gets [`AnyValue`] from [`LogicalType`] fn get_any_value(&self, _i: usize) -> PolarsResult> { unimplemented!() } diff --git a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs index 2a523ecf856d..40e9eacbff36 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -11,7 +11,7 @@ use smartstring::alias::String as SmartString; use super::*; use crate::datatypes::*; -use crate::utils::index_to_chunked_index2; +use crate::utils::index_to_chunked_index; /// This is logical type [`StructChunked`] that /// dispatches most logic to the `fields` implementations @@ -112,7 +112,7 @@ impl StructChunked { } Ok(Self::new_unchecked(name, &new_fields)) } else if fields.is_empty() { - let fields = &[Series::full_null("", 1, &DataType::Null)]; + let fields = &[Series::full_null("", 0, &DataType::Null)]; Ok(Self::new_unchecked(name, fields)) } else { Ok(Self::new_unchecked(name, fields)) @@ -425,7 +425,7 @@ impl LogicalType for StructChunked { } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - let (chunk_idx, idx) = index_to_chunked_index2(&self.chunks, i); + let (chunk_idx, idx) = index_to_chunked_index(self.chunks.iter().map(|c| c.len()), i); if let DataType::Struct(flds) = self.dtype() { // safety: we already have a single chunk and we are // guarded by the type system. diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 19f4600b7510..f93ba54d5c46 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -16,10 +16,10 @@ pub mod ops; pub mod arithmetic; pub mod builder; pub mod cast; +pub mod collect; pub mod comparison; pub mod float; pub mod iterator; -pub mod kernels; #[cfg(feature = "ndarray")] pub(crate) mod ndarray; @@ -62,9 +62,9 @@ pub type ChunkIdIter<'a> = std::iter::Map, fn(&Ar /// # ChunkedArray /// -/// Every Series contains a `ChunkedArray`. Unlike Series, ChunkedArray's are typed. This allows -/// us to apply closures to the data and collect the results to a `ChunkedArray` of the same type `T`. -/// Below we use an apply to use the cosine function to the values of a `ChunkedArray`. +/// Every Series contains a [`ChunkedArray`]. Unlike [`Series`], [`ChunkedArray`]s are typed. This allows +/// us to apply closures to the data and collect the results to a [`ChunkedArray`] of the same type `T`. +/// Below we use an apply to use the cosine function to the values of a [`ChunkedArray`]. /// /// ```rust /// # use polars_core::prelude::*; @@ -74,7 +74,7 @@ pub type ChunkIdIter<'a> = std::iter::Map, fn(&Ar /// ``` /// /// ## Conversion between Series and ChunkedArray's -/// Conversion from a `Series` to a `ChunkedArray` is effortless. +/// Conversion from a [`Series`] to a [`ChunkedArray`] is effortless. 
/// /// ```rust /// # use polars_core::prelude::*; @@ -89,7 +89,7 @@ pub type ChunkIdIter<'a> = std::iter::Map, fn(&Ar /// /// # Iterators /// -/// `ChunkedArrays` fully support Rust native [Iterator](https://doc.rust-lang.org/std/iter/trait.Iterator.html) +/// [`ChunkedArray`]s fully support Rust native [Iterator](https://doc.rust-lang.org/std/iter/trait.Iterator.html) /// and [DoubleEndedIterator](https://doc.rust-lang.org/std/iter/trait.DoubleEndedIterator.html) traits, thereby /// giving access to all the excellent methods available for [Iterators](https://doc.rust-lang.org/std/iter/trait.Iterator.html). /// @@ -110,28 +110,30 @@ pub type ChunkIdIter<'a> = std::iter::Map, fn(&Ar /// /// # Memory layout /// -/// `ChunkedArray`'s use [Apache Arrow](https://github.com/apache/arrow) as backend for the memory layout. +/// [`ChunkedArray`]s use [Apache Arrow](https://github.com/apache/arrow) as backend for the memory layout. /// Arrows memory is immutable which makes it possible to make multiple zero copy (sub)-views from a single array. /// -/// To be able to append data, Polars uses chunks to append new memory locations, hence the `ChunkedArray` data structure. +/// To be able to append data, Polars uses chunks to append new memory locations, hence the [`ChunkedArray`] data structure. /// Appends are cheap, because it will not lead to a full reallocation of the whole array (as could be the case with a Rust Vec). /// -/// However, multiple chunks in a `ChunkArray` will slow down many operations that need random access because we have an extra indirection +/// However, multiple chunks in a [`ChunkedArray`] will slow down many operations that need random access because we have an extra indirection /// and indexes need to be mapped to the proper chunk. Arithmetic may also be slowed down by this. -/// When multiplying two `ChunkArray'`s with different chunk sizes they cannot utilize [SIMD](https://en.wikipedia.org/wiki/SIMD) for instance. +/// When multiplying two [`ChunkedArray`]s with different chunk sizes they cannot utilize [SIMD](https://en.wikipedia.org/wiki/SIMD) for instance. /// /// If you want to have predictable performance -/// (no unexpected re-allocation of memory), it is advised to call the [ChunkedArray::rechunk] after +/// (no unexpected re-allocation of memory), it is advised to call the [`ChunkedArray::rechunk`] after /// multiple append operations. /// /// See also [`ChunkedArray::extend`] for appends within a chunk. /// /// # Invariants -/// - A `ChunkedArray` should always have at least a single `ArrayRef`. -/// - The [`PolarsDataType`] `T` should always map to the correct [`ArrowDataType`] in the `ArrayRef` +/// - A [`ChunkedArray`] should always have at least a single [`ArrayRef`]. +/// - The [`PolarsDataType`] `T` should always map to the correct [`ArrowDataType`] in the [`ArrayRef`] /// chunks. -/// - Nested datatypes such as `List` and `Array` store the physical types instead of the +/// - Nested datatypes such as [`List`] and [`Array`] store the physical types instead of the /// logical type given by the datatype. 
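As the documentation above describes, a typed `ChunkedArray` supports the standard Rust iterator traits; iterating a borrowed array yields `Option<T::Native>`, with `None` marking a null. A small sketch (the constructor mirrors the ones used in the crate's own tests):

```rust
use polars_core::prelude::*;

fn sum_ignoring_nulls(ca: &Float64Chunked) -> f64 {
    // Iterating yields Option<f64>; `flatten` skips the nulls.
    ca.into_iter().flatten().sum()
}

fn main() {
    let ca = Float64Chunked::new("values", &[Some(1.0), None, Some(2.5)]);
    assert_eq!(sum_ignoring_nulls(&ca), 3.5);
}
```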
+/// +/// [`List`]: crate::datatypes::DataType::List pub struct ChunkedArray { pub(crate) field: Arc, pub(crate) chunks: Vec, @@ -195,7 +197,7 @@ impl ChunkedArray { self.bit_settings } - /// Set flags for the Chunked Array + /// Set flags for the [`ChunkedArray`] pub(crate) fn set_flags(&mut self, flags: Settings) { self.bit_settings = flags; } @@ -209,7 +211,7 @@ impl ChunkedArray { self.bit_settings.set_sorted_flag(sorted) } - /// Get the index of the first non null value in this ChunkedArray. + /// Get the index of the first non null value in this [`ChunkedArray`]. pub fn first_non_null(&self) -> Option { if self.is_empty() { None @@ -218,7 +220,7 @@ impl ChunkedArray { } } - /// Get the index of the last non null value in this ChunkedArray. + /// Get the index of the last non null value in this [`ChunkedArray`]. pub fn last_non_null(&self) -> Option { last_non_null(self.iter_validities(), self.length as usize) } @@ -234,7 +236,7 @@ impl ChunkedArray { } #[inline] - /// Return if any the chunks in this `[ChunkedArray]` have a validity bitmap. + /// Return if any the chunks in this [`ChunkedArray`] have a validity bitmap. /// no bitmap means no null values. pub fn has_validity(&self) -> bool { self.iter_validities().any(|valid| valid.is_some()) @@ -245,7 +247,7 @@ impl ChunkedArray { self.chunks = vec![concatenate_owned_unchecked(self.chunks.as_slice()).unwrap()]; } - /// Unpack a Series to the same physical type. + /// Unpack a [`Series`] to the same physical type. /// /// # Safety /// @@ -300,7 +302,7 @@ impl ChunkedArray { /// A mutable reference to the chunks /// /// # Safety - /// The caller must ensure to not change the `DataType` or `length` of any of the chunks. + /// The caller must ensure to not change the [`DataType`] or `length` of any of the chunks. #[inline] pub unsafe fn chunks_mut(&mut self) -> &mut Vec { &mut self.chunks @@ -317,7 +319,7 @@ impl ChunkedArray { self.chunks.iter().map(|arr| arr.null_count()).sum() } - /// Create a new ChunkedArray from self, where the chunks are replaced. + /// Create a new [`ChunkedArray`] from self, where the chunks are replaced. /// /// # Safety /// The caller must ensure the dtypes of the chunks are correct @@ -336,7 +338,7 @@ impl ChunkedArray { ) } - /// Get data type of ChunkedArray. + /// Get data type of [`ChunkedArray`]. pub fn dtype(&self) -> &DataType { self.field.data_type() } @@ -346,7 +348,7 @@ impl ChunkedArray { self.field = Arc::new(Field::new(self.name(), dtype)) } - /// Name of the ChunkedArray. + /// Name of the [`ChunkedArray`]. pub fn name(&self) -> &str { self.field.name() } @@ -356,19 +358,101 @@ impl ChunkedArray { &self.field } - /// Rename this ChunkedArray. + /// Rename this [`ChunkedArray`]. pub fn rename(&mut self, name: &str) { self.field = Arc::new(Field::new(name, self.field.data_type().clone())) } + + /// Return this [`ChunkedArray`] with a new name. + pub fn with_name(mut self, name: &str) -> Self { + self.rename(name); + self + } +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + #[inline] + pub fn get(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + let arr = self.downcast_get(chunk_idx)?; + + // SAFETY: if index_to_chunked_index returns a valid chunk_idx, we know + // that arr_idx < arr.len(). + unsafe { arr.get_unchecked(arr_idx) } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. 
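The new `get` first translates the flat index into a (chunk, offset) pair; that translation just walks the chunk lengths. A standalone sketch of the idea behind `index_to_chunked_index` (the real helper lives in polars-core's utils; out-of-bounds handling is simplified here):

```rust
// Translate a flat index into (chunk index, index within that chunk),
// given the chunk lengths.
fn index_to_chunked_index(chunk_lens: impl IntoIterator<Item = usize>, mut idx: usize) -> (usize, usize) {
    let mut chunk_idx = 0;
    for len in chunk_lens {
        if idx < len {
            return (chunk_idx, idx);
        }
        idx -= len;
        chunk_idx += 1;
    }
    (chunk_idx, idx) // past the end; callers are expected to bound-check
}

fn main() {
    // Chunks of lengths 4, 2 and 3: flat index 5 is offset 1 in chunk 1.
    assert_eq!(index_to_chunked_index([4, 2, 3], 5), (1, 1));
}
```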
+ #[inline] + pub unsafe fn get_unchecked(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .get_unchecked(arr_idx) + } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + pub unsafe fn value_unchecked(&self, idx: usize) -> T::Physical<'_> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .value_unchecked(arr_idx) + } + } + + #[inline] + pub fn last(&self) -> Option> { + unsafe { + let arr = self.downcast_get_unchecked(self.chunks.len().checked_sub(1)?); + arr.get_unchecked(arr.len().checked_sub(1)?) + } + } +} + +impl ListChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } } impl ChunkedArray where T: PolarsDataType, { - /// Should be used to match the chunk_id of another ChunkedArray. + /// Should be used to match the chunk_id of another [`ChunkedArray`]. /// # Panics - /// It is the callers responsibility to ensure that this ChunkedArray has a single chunk. + /// It is the callers responsibility to ensure that this [`ChunkedArray`] has a single chunk. pub(crate) fn match_chunks(&self, chunk_id: I) -> Self where I: Iterator, @@ -436,6 +520,35 @@ impl AsSinglePtr for BinaryChunked {} #[cfg(feature = "object")] impl AsSinglePtr for ObjectChunked {} +pub enum ChunkedArrayLayout<'a, T: PolarsDataType> { + SingleNoNull(&'a T::Array), + Single(&'a T::Array), + MultiNoNull(&'a ChunkedArray), + Multi(&'a ChunkedArray), +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn layout(&self) -> ChunkedArrayLayout<'_, T> { + if self.chunks.len() == 1 { + let arr = self.downcast_iter().next().unwrap(); + return if arr.null_count() == 0 { + ChunkedArrayLayout::SingleNoNull(arr) + } else { + ChunkedArrayLayout::Single(arr) + }; + } + + if self.downcast_iter().all(|a| a.null_count() == 0) { + ChunkedArrayLayout::MultiNoNull(self) + } else { + ChunkedArrayLayout::Multi(self) + } + } +} + impl ChunkedArray where T: PolarsNumericType, @@ -626,7 +739,7 @@ pub(crate) mod test { #[test] fn take() { let a = get_chunked_array(); - let new = a.take([0usize, 1].iter().copied().into()).unwrap(); + let new = a.take(&[0 as IdxSize, 1]).unwrap(); assert_eq!(new.len(), 2) } @@ -714,9 +827,9 @@ pub(crate) mod test { #[test] #[cfg(feature = "dtype-categorical")] fn test_iter_categorical() { - use crate::{reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, SINGLE_LOCK}; let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let ca = Utf8Chunked::new("", &[Some("foo"), None, Some("bar"), Some("ham")]); let ca = ca.cast(&DataType::Categorical(None)).unwrap(); let ca = ca.categorical().unwrap(); diff --git a/crates/polars-core/src/chunked_array/ndarray.rs b/crates/polars-core/src/chunked_array/ndarray.rs index f59f81b644ad..bc043da81e02 100644 --- 
a/crates/polars-core/src/chunked_array/ndarray.rs +++ b/crates/polars-core/src/chunked_array/ndarray.rs @@ -19,7 +19,7 @@ where T: PolarsNumericType, { /// If data is aligned in a single chunk and has no Null values a zero copy view is returned - /// as an `ndarray` + /// as an [ndarray] pub fn to_ndarray(&self) -> PolarsResult> { let slice = self.cont_slice()?; Ok(aview1(slice)) @@ -27,7 +27,7 @@ where } impl ListChunked { - /// If all nested `Series` have the same length, a 2 dimensional `ndarray::Array` is returned. + /// If all nested [`Series`] have the same length, a 2 dimensional [`ndarray::Array`] is returned. pub fn to_ndarray(&self) -> PolarsResult> where N: PolarsNumericType, @@ -75,8 +75,8 @@ impl ListChunked { } impl DataFrame { - /// Create a 2D `ndarray::Array` from this `DataFrame`. This requires all columns in the - /// `DataFrame` to be non-null and numeric. They will be casted to the same data type + /// Create a 2D [`ndarray::Array`] from this [`DataFrame`]. This requires all columns in the + /// [`DataFrame`] to be non-null and numeric. They will be casted to the same data type /// (if they aren't already). /// /// For floating point data we implicitly convert `None` to `NaN` without failure. diff --git a/crates/polars-core/src/chunked_array/object/extension/drop.rs b/crates/polars-core/src/chunked_array/object/extension/drop.rs index 1f678f0a946f..d16311120a85 100644 --- a/crates/polars-core/src/chunked_array/object/extension/drop.rs +++ b/crates/polars-core/src/chunked_array/object/extension/drop.rs @@ -1,7 +1,7 @@ use crate::chunked_array::object::extension::PolarsExtension; use crate::prelude::*; -/// This will dereference a raw ptr when dropping the PolarsExtension, make sure that it's valid. +/// This will dereference a raw ptr when dropping the [`PolarsExtension`], make sure that it's valid. pub(crate) unsafe fn drop_list(ca: &ListChunked) { let mut inner = ca.inner_dtype(); let mut nested_count = 0; diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index c5377ee2c144..ad6651ba1971 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -143,6 +143,12 @@ mod test { pub other_heap: String, } + impl TotalEq for Foo { + fn tot_eq(&self, other: &Self) -> bool { + self == other + } + } + impl Display for Foo { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index b3fe806fd31e..f51eb63c3fca 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::fmt::{Debug, Display}; use std::hash::Hash; +use arrow::bitmap::utils::{BitmapIter, ZipValidity}; use arrow::bitmap::Bitmap; use crate::prelude::*; @@ -35,7 +36,7 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display { /// Values need to implement this so that they can be stored into a Series and DataFrame pub trait PolarsObject: - Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + TotalEq { /// This should be used as type information. Consider this a part of the type system. 
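A hedged sketch of the `to_ndarray` contract described above (a zero-copy view exists only for a single chunk without nulls); this assumes the `ndarray` feature is enabled:

```rust
use polars_core::prelude::*;

fn demo_ndarray() -> PolarsResult<()> {
    // Single chunk, no nulls: a zero-copy view can be handed out.
    let ca = Float64Chunked::new("a", &[1.0, 2.0, 3.0]);
    let view = ca.to_ndarray()?;
    assert_eq!(view.len(), 3);

    // With nulls (or several chunks) the contiguous-slice requirement fails,
    // so an error is returned instead of a copy.
    let with_nulls = Float64Chunked::new("a", &[Some(1.0), None]);
    assert!(with_nulls.to_ndarray().is_err());
    Ok(())
}
```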
fn type_name() -> &'static str; @@ -55,6 +56,8 @@ impl PolarsObjectSafe for T { } } +pub type ObjectValueIter<'a, T> = std::slice::Iter<'a, T>; + impl ObjectArray where T: PolarsObject, @@ -64,6 +67,15 @@ where &self.values } + pub fn values_iter(&self) -> ObjectValueIter<'_, T> { + self.values.iter() + } + + /// Returns an iterator of `Option<&T>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&T, ObjectValueIter<'_, T>, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.null_bitmap.as_ref()) + } + /// Get a value at a certain index location pub fn value(&self, index: usize) -> &T { &self.values[self.offset + index] @@ -117,6 +129,27 @@ where Some(self.value_unchecked(item)) } } + + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.null_bitmap = validity; + } } impl Array for ObjectArray @@ -153,11 +186,11 @@ where fn validity(&self) -> Option<&Bitmap> { self.null_bitmap.as_ref() } + fn with_validity(&self, validity: Option) -> Box { - let mut arr = self.clone(); - arr.null_bitmap = validity; - Box::new(arr) + Box::new(self.clone().with_validity(validity)) } + fn to_boxed(&self) -> Box { Box::new(self.clone()) } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/float_sum.rs b/crates/polars-core/src/chunked_array/ops/aggregate/float_sum.rs index 3cd0107f1c1c..b2073ef1e3c2 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/float_sum.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/float_sum.rs @@ -1,113 +1,13 @@ use std::ops::{Add, IndexMut}; #[cfg(feature = "simd")] -use std::simd::{Mask, Simd, SimdElement, ToBitMask}; +use std::simd::{Mask, Simd, SimdElement}; +use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; -#[cfg(feature = "simd")] -use num_traits::AsPrimitive; const STRIPE: usize = 16; const PAIRWISE_RECURSION_LIMIT: usize = 128; -// Load 8 bytes as little-endian into a u64, padding with zeros if it's too short. -#[cfg(feature = "simd")] -pub fn load_padded_le_u64(bytes: &[u8]) -> u64 { - let len = bytes.len(); - if len >= 8 { - return u64::from_le_bytes(bytes[0..8].try_into().unwrap()); - } - - if len >= 4 { - let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); - let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap()); - return (lo as u64) | ((hi as u64) << (8 * (len - 4))); - } - - if len == 0 { - return 0; - } - - let lo = bytes[0] as u64; - let mid = (bytes[len / 2] as u64) << (8 * (len / 2)); - let hi = (bytes[len - 1] as u64) << (8 * (len - 1)); - lo | mid | hi -} - -struct BitMask<'a> { - bytes: &'a [u8], - offset: usize, - len: usize, -} - -impl<'a> BitMask<'a> { - pub fn new(bitmap: &'a Bitmap) -> Self { - let (bytes, offset, len) = bitmap.as_slice(); - // Check length so we can use unsafe access in our get. 
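For context on the `float_sum` code touched here: it sums fixed-size blocks and combines them pairwise (see `PAIRWISE_RECURSION_LIMIT`), which bounds floating-point error growth better than a plain left-to-right fold. A generic illustration of the idea only, not the Polars kernel:

```rust
/// Pairwise (cascade) summation: recurse on halves and only fall back to a
/// sequential loop for small blocks, bounding rounding-error growth.
fn pairwise_sum(xs: &[f64]) -> f64 {
    const BLOCK: usize = 128; // plays the role of PAIRWISE_RECURSION_LIMIT
    if xs.len() <= BLOCK {
        return xs.iter().sum();
    }
    let (lo, hi) = xs.split_at(xs.len() / 2);
    pairwise_sum(lo) + pairwise_sum(hi)
}

fn main() {
    let data: Vec<f64> = (0..1_000).map(|i| i as f64 * 0.1).collect();
    let naive: f64 = data.iter().sum();
    // Both agree here; the pairwise version has better worst-case error behaviour.
    assert!((pairwise_sum(&data) - naive).abs() < 1e-6);
}
```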
- assert!(bytes.len() * 8 >= len + offset); - Self { bytes, offset, len } - } - - fn split_at(&self, idx: usize) -> (Self, Self) { - assert!(idx <= self.len); - unsafe { self.split_at_unchecked(idx) } - } - - unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) { - debug_assert!(idx <= self.len); - let left = Self { len: idx, ..*self }; - let right = Self { - len: self.len - idx, - offset: self.offset + idx, - ..*self - }; - (left, right) - } - - #[cfg(feature = "simd")] - pub fn get_simd(&self, idx: usize) -> T - where - T: ToBitMask, - ::BitMask: Copy + 'static, - u64: AsPrimitive<::BitMask>, - { - // We don't support 64-lane masks because then we couldn't load our - // bitwise mask as a u64 and then do the byteshift on it. - - let lanes = std::mem::size_of::() * 8; - assert!(lanes < 64); - - let start_byte_idx = (self.offset + idx) / 8; - let byte_shift = (self.offset + idx) % 8; - if idx + lanes <= self.len { - // SAFETY: fast path, we know this is completely in-bounds. - let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); - T::from_bitmask((mask >> byte_shift).as_()) - } else if idx < self.len { - // SAFETY: we know that at least the first byte is in-bounds. - // This is partially out of bounds, we have to do extra masking. - let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); - let num_out_of_bounds = idx + lanes - self.len; - let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift); - T::from_bitmask(shifted.as_()) - } else { - T::from_bitmask((0u64).as_()) - } - } - - pub fn get(&self, idx: usize) -> bool { - let byte_idx = (self.offset + idx) / 8; - let byte_shift = (self.offset + idx) % 8; - - if idx < self.len { - // SAFETY: we know this is in-bounds. - let byte = unsafe { *self.bytes.get_unchecked(byte_idx) }; - (byte >> byte_shift) & 1 == 1 - } else { - false - } - } -} - fn vector_horizontal_sum(mut v: V) -> T where V: IndexMut, @@ -222,7 +122,7 @@ macro_rules! def_sum { /// Also, f.len() == mask.len(). unsafe fn pairwise_sum_with_mask(f: &[$T], mask: BitMask<'_>) -> f64 { debug_assert!(f.len() > 0 && f.len() % PAIRWISE_RECURSION_LIMIT == 0); - debug_assert!(f.len() == mask.len); + debug_assert!(f.len() == mask.len()); if let Ok(block) = f.try_into() { return sum_block_vectorized_with_mask(block, mask) as f64; @@ -253,8 +153,8 @@ macro_rules! 
def_sum { } pub fn sum_with_validity(f: &[$T], validity: &Bitmap) -> f64 { - let mask = BitMask::new(validity); - assert!(f.len() == mask.len); + let mask = BitMask::from_bitmap(validity); + assert!(f.len() == mask.len()); let remainder = f.len() % PAIRWISE_RECURSION_LIMIT; let (rest, main) = f.split_at(remainder); diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index b6be0e22ad77..6c8af6b53332 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -8,7 +8,7 @@ use std::ops::Add; use arrow::compute; use arrow::types::simd::Simd; use arrow::types::NativeType; -use num_traits::{Float, ToPrimitive, Zero}; +use num_traits::{Float, One, ToPrimitive, Zero}; use polars_arrow::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min}; pub use quantile::*; pub use var::*; @@ -16,26 +16,26 @@ pub use var::*; use crate::chunked_array::ChunkedArray; use crate::datatypes::{BooleanChunked, PolarsNumericType}; use crate::prelude::*; +use crate::series::implementations::SeriesWrap; use crate::series::IsSorted; -use crate::utils::CustomIterTools; mod float_sum; -/// Aggregations that return Series of unit length. Those can be used in broadcasting operations. +/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations. pub trait ChunkAggSeries { - /// Get the sum of the ChunkedArray as a new Series of length 1. + /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1. fn sum_as_series(&self) -> Series { unimplemented!() } - /// Get the max of the ChunkedArray as a new Series of length 1. + /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1. fn max_as_series(&self) -> Series { unimplemented!() } - /// Get the min of the ChunkedArray as a new Series of length 1. + /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1. fn min_as_series(&self) -> Series { unimplemented!() } - /// Get the product of the ChunkedArray as a new Series of length 1. + /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1. fn prod_as_series(&self) -> Series { unimplemented!() } @@ -116,7 +116,7 @@ where IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::min_primitive) - .fold_first_(|acc, v| { + .reduce(|acc, v| { if matches!(compare_fn_nan_max(&acc, &v), Ordering::Less) { acc } else { @@ -148,7 +148,7 @@ where IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::max_primitive) - .fold_first_(|acc, v| { + .reduce(|acc, v| { if matches!(compare_fn_nan_min(&acc, &v), Ordering::Greater) { acc } else { @@ -275,7 +275,7 @@ impl BooleanChunked { } } -// Needs the same trait bounds as the implementation of ChunkedArray of dyn Series +// Needs the same trait bounds as the implementation of ChunkedArray of dyn Series. 
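`ChunkAggSeries` (documented above) returns every aggregate as a length-1 `Series` so it can be broadcast against the source column. A minimal usage sketch; note that the rewritten `prod_as_series` below skips nulls instead of returning a null result:

```rust
use polars_core::prelude::*;

fn demo_agg_series() {
    let ca = Int32Chunked::new("x", &[Some(2), None, Some(4)]);

    // Every aggregate comes back as a Series of length 1.
    let sum = ca.sum_as_series();
    assert_eq!(sum.len(), 1);
    assert_eq!(sum.i32().unwrap().get(0), Some(6));

    // Nulls are skipped by the rewritten product: 2 * 4 = 8.
    let prod = ca.prod_as_series();
    assert_eq!(prod.i32().unwrap().get(0), Some(8));
}
```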
impl ChunkAggSeries for ChunkedArray where T: PolarsNumericType, @@ -290,45 +290,38 @@ where ca.rename(self.name()); ca.into_series() } + fn max_as_series(&self) -> Series { - let v = self.max(); + let v = ChunkAgg::max(self); let mut ca: ChunkedArray = [v].iter().copied().collect(); ca.rename(self.name()); ca.into_series() } + fn min_as_series(&self) -> Series { - let v = self.min(); + let v = ChunkAgg::min(self); let mut ca: ChunkedArray = [v].iter().copied().collect(); ca.rename(self.name()); ca.into_series() } fn prod_as_series(&self) -> Series { - let mut prod = None; - for opt_v in self.into_iter() { - match (prod, opt_v) { - (_, None) => return Self::full_null(self.name(), 1).into_series(), - (None, Some(v)) => prod = Some(v), - (Some(p), Some(v)) => prod = Some(p * v), - } + let mut prod = T::Native::one(); + for opt_v in self.into_iter().flatten() { + prod = prod * opt_v; } - Self::from_slice_options(self.name(), &[prod]).into_series() + Self::from_slice_options(self.name(), &[Some(prod)]).into_series() } } -macro_rules! impl_as_series { - ($self:expr, $agg:ident, $ty: ty) => {{ - let v = $self.$agg(); - let mut ca: $ty = [v].iter().copied().collect(); - ca.rename($self.name()); - ca.into_series() - }}; - ($self:expr, $agg:ident, $arg:expr, $ty: ty) => {{ - let v = $self.$agg($arg); - let mut ca: $ty = [v].iter().copied().collect(); - ca.rename($self.name()); - ca.into_series() - }}; +fn as_series(name: &str, v: Option) -> Series +where + T: PolarsNumericType, + SeriesWrap>: SeriesTrait, +{ + let mut ca: ChunkedArray = [v].into_iter().collect(); + ca.rename(name); + ca.into_series() } impl VarAggSeries for ChunkedArray @@ -339,43 +332,34 @@ where + compute::aggregate::SimdOrd, { fn var_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, var, ddof, Float64Chunked) + as_series::(self.name(), self.var(ddof)) } fn std_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, std, ddof, Float64Chunked) + as_series::(self.name(), self.std(ddof)) } } impl VarAggSeries for Float32Chunked { fn var_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, var, ddof, Float32Chunked) + as_series::(self.name(), self.var(ddof).map(|x| x as f32)) } fn std_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, std, ddof, Float32Chunked) + as_series::(self.name(), self.std(ddof).map(|x| x as f32)) } } impl VarAggSeries for Float64Chunked { fn var_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, var, ddof, Float64Chunked) + as_series::(self.name(), self.var(ddof)) } fn std_as_series(&self, ddof: u8) -> Series { - impl_as_series!(self, std, ddof, Float64Chunked) + as_series::(self.name(), self.std(ddof)) } } -macro_rules! 
impl_quantile_as_series { - ($self:expr, $agg:ident, $ty: ty, $qtl:expr, $opt:expr) => {{ - let v = $self.$agg($qtl, $opt)?; - let mut ca: $ty = [v].iter().copied().collect(); - ca.rename($self.name()); - Ok(ca.into_series()) - }}; -} - impl QuantileAggSeries for ChunkedArray where T: PolarsIntegerType, @@ -389,11 +373,14 @@ where quantile: f64, interpol: QuantileInterpolOptions, ) -> PolarsResult { - impl_quantile_as_series!(self, quantile, Float64Chunked, quantile, interpol) + Ok(as_series::( + self.name(), + self.quantile(quantile, interpol)?, + )) } fn median_as_series(&self) -> Series { - impl_as_series!(self, median, Float64Chunked) + as_series::(self.name(), self.median()) } } @@ -403,11 +390,14 @@ impl QuantileAggSeries for Float32Chunked { quantile: f64, interpol: QuantileInterpolOptions, ) -> PolarsResult { - impl_quantile_as_series!(self, quantile, Float32Chunked, quantile, interpol) + Ok(as_series::( + self.name(), + self.quantile(quantile, interpol)?, + )) } fn median_as_series(&self) -> Series { - impl_as_series!(self, median, Float32Chunked) + as_series::(self.name(), self.median()) } } @@ -417,11 +407,14 @@ impl QuantileAggSeries for Float64Chunked { quantile: f64, interpol: QuantileInterpolOptions, ) -> PolarsResult { - impl_quantile_as_series!(self, quantile, Float64Chunked, quantile, interpol) + Ok(as_series::( + self.name(), + self.quantile(quantile, interpol)?, + )) } fn median_as_series(&self) -> Series { - impl_as_series!(self, median, Float64Chunked) + as_series::(self.name(), self.median()) } } @@ -463,7 +456,7 @@ impl Utf8Chunked { IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::max_string) - .fold_first_(|acc, v| if acc > v { acc } else { v }), + .reduce(|acc, v| if acc > v { acc } else { v }), } } pub(crate) fn min_str(&self) -> Option<&str> { @@ -488,7 +481,7 @@ impl Utf8Chunked { IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::min_string) - .fold_first_(|acc, v| if acc < v { acc } else { v }), + .reduce(|acc, v| if acc < v { acc } else { v }), } } } @@ -513,22 +506,20 @@ impl BinaryChunked { match self.is_sorted_flag() { IsSorted::Ascending => { self.last_non_null().and_then(|idx| { - // Safety: - // last_non_null returns in bound index + // SAFETY: last_non_null returns in bound index. unsafe { self.get_unchecked(idx) } }) }, IsSorted::Descending => { self.first_non_null().and_then(|idx| { - // Safety: - // first_non_null returns in bound index + // SAFETY: first_non_null returns in bound index. unsafe { self.get_unchecked(idx) } }) }, IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::max_binary) - .fold_first_(|acc, v| if acc > v { acc } else { v }), + .reduce(|acc, v| if acc > v { acc } else { v }), } } @@ -539,22 +530,20 @@ impl BinaryChunked { match self.is_sorted_flag() { IsSorted::Ascending => { self.first_non_null().and_then(|idx| { - // Safety: - // first_non_null returns in bound index + // SAFETY: first_non_null returns in bound index. unsafe { self.get_unchecked(idx) } }) }, IsSorted::Descending => { self.last_non_null().and_then(|idx| { - // Safety: - // last_non_null returns in bound index + // SAFETY: last_non_null returns in bound index. 
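The variance helpers refactored above forward to `var(ddof)`/`std(ddof)`; as the `test_var` comment further down notes, `ddof = 1` gives the sample variance. A small check mirroring that test data:

```rust
use polars_core::prelude::*;

fn demo_var() {
    let ca = Int32Chunked::new("", &[5, 8, 9, 5, 0]);

    // Sample variance (ddof = 1): sum of squared deviations / (n - 1) = 49.2 / 4.
    let var = ca.var(1).unwrap();
    assert!((var - 12.3).abs() < 1e-9);

    // Standard deviation is its square root.
    let std = ca.std(1).unwrap();
    assert!((std - var.sqrt()).abs() < 1e-12);
}
```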
unsafe { self.get_unchecked(idx) } }) }, IsSorted::Not => self .downcast_iter() .filter_map(compute::aggregate::min_binary) - .fold_first_(|acc, v| if acc < v { acc } else { v }), + .reduce(|acc, v| if acc < v { acc } else { v }), } } } @@ -610,9 +599,9 @@ mod test { #[test] fn test_var() { - // validated with numpy - // Note that numpy as an argument ddof which influences results. The default is ddof=0 - // we chose ddof=1, which is standard in statistics + // Validated with numpy. Note that numpy uses ddof as an argument which + // influences results. The default ddof=0, we chose ddof=1, which is + // standard in statistics. let ca1 = Int32Chunked::new("", &[5, 8, 9, 5, 0]); let ca2 = Int32Chunked::new( "", diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index 30691c51bc6c..202fc2173aa7 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -5,9 +5,9 @@ use polars_utils::slice::Extrema; use super::*; pub trait QuantileAggSeries { - /// Get the median of the ChunkedArray as a new Series of length 1. + /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. fn median_as_series(&self) -> Series; - /// Get the quantile of the ChunkedArray as a new Series of length 1. + /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. fn quantile_as_series( &self, _quantile: f64, diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs index d7f4a4828591..cb5479dc7346 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs @@ -1,94 +1,45 @@ use super::*; pub trait VarAggSeries { - /// Get the variance of the ChunkedArray as a new Series of length 1. + /// Get the variance of the [`ChunkedArray`] as a new [`Series`] of length 1. fn var_as_series(&self, ddof: u8) -> Series; - /// Get the standard deviation of the ChunkedArray as a new Series of length 1. + /// Get the standard deviation of the [`ChunkedArray`] as a new [`Series`] of length 1. fn std_as_series(&self, ddof: u8) -> Series; } -impl ChunkVar for ChunkedArray +impl ChunkVar for ChunkedArray where - T: PolarsIntegerType, + T: PolarsNumericType, ::Simd: Add::Simd> + compute::aggregate::Sum + compute::aggregate::SimdOrd, { fn var(&self, ddof: u8) -> Option { let n_values = self.len() - self.null_count(); - - if ddof as usize > n_values { + if n_values <= ddof as usize { return None; } - let n_values = n_values as f64; let mean = self.mean()?; let squared: Float64Chunked = ChunkedArray::apply_values_generic(self, |value| { let tmp = value.to_f64().unwrap() - mean; tmp * tmp }); - // Note, this is similar behavior to numpy if DDOF=1. - // in statistics DDOF often = 1. - // this last step is similar to mean, only now instead of 1/n it is 1/(n-1) - squared.sum().map(|sum| sum / (n_values - ddof as f64)) - } - fn std(&self, ddof: u8) -> Option { - self.var(ddof).map(|var| var.sqrt()) - } -} - -impl ChunkVar for Float32Chunked { - fn var(&self, ddof: u8) -> Option { - if self.len() == 1 { - return Some(0.0); - } - let n_values = self.len() - self.null_count(); - if ddof as usize > n_values { - return None; - } - let n_values = n_values as f32; - - let mean = self.mean()? 
as f32; - let squared = self.apply_values(|value| { - let tmp = value - mean; - tmp * tmp - }); - squared.sum().map(|sum| sum / (n_values - ddof as f32)) + squared + .sum() + .map(|sum| sum / (n_values as f64 - ddof as f64)) } - fn std(&self, ddof: u8) -> Option { - self.var(ddof).map(|var| var.sqrt()) - } -} - -impl ChunkVar for Float64Chunked { - fn var(&self, ddof: u8) -> Option { - if self.len() == 1 { - return Some(0.0); - } - let n_values = self.len() - self.null_count(); - if ddof as usize > n_values { - return None; - } - let n_values = n_values as f64; - - let mean = self.mean()?; - let squared = self.apply_values(|value| { - let tmp = value - mean; - tmp * tmp - }); - squared.sum().map(|sum| sum / (n_values - ddof as f64)) - } fn std(&self, ddof: u8) -> Option { self.var(ddof).map(|var| var.sqrt()) } } -impl ChunkVar for Utf8Chunked {} -impl ChunkVar for ListChunked {} +impl ChunkVar for Utf8Chunked {} +impl ChunkVar for ListChunked {} #[cfg(feature = "dtype-array")] -impl ChunkVar for ArrayChunked {} +impl ChunkVar for ArrayChunked {} #[cfg(feature = "object")] -impl ChunkVar for ObjectChunked {} -impl ChunkVar for BooleanChunked {} +impl ChunkVar for ObjectChunked {} +impl ChunkVar for BooleanChunked {} diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs index f08b1b3e530d..7b01f66c8015 100644 --- a/crates/polars-core/src/chunked_array/ops/any_value.rs +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -4,6 +4,7 @@ use polars_utils::sync::SyncPtr; #[cfg(feature = "object")] use crate::chunked_array::object::extension::polars_extension::PolarsExtension; use crate::prelude::*; +use crate::series::implementations::null::NullChunked; #[inline] #[allow(unused_variables)] @@ -282,3 +283,14 @@ impl ChunkAnyValue for ObjectChunked { } } } + +impl ChunkAnyValue for NullChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, _index: usize) -> AnyValue { + AnyValue::Null + } + + fn get_any_value(&self, _index: usize) -> PolarsResult { + Ok(AnyValue::Null) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 4a1503826743..c14405ed377d 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -2,106 +2,68 @@ use crate::prelude::*; use crate::series::IsSorted; pub(crate) fn new_chunks(chunks: &mut Vec, other: &[ArrayRef], len: usize) { - // replace an empty array + // Replace an empty array. 
if chunks.len() == 1 && len == 0 { *chunks = other.to_owned(); } else { - chunks.extend_from_slice(other); + for chunk in other { + if chunk.len() > 0 { + chunks.push(chunk.clone()); + } + } } } -pub(super) fn update_sorted_flag_before_append<'a, T>( - ca: &mut ChunkedArray, - other: &'a ChunkedArray, -) where +pub(super) fn update_sorted_flag_before_append(ca: &mut ChunkedArray, other: &ChunkedArray) +where T: PolarsDataType, - &'a ChunkedArray: TakeRandom, - <&'a ChunkedArray as TakeRandom>::Item: PartialOrd, + for<'a> T::Physical<'a>: TotalOrd, { - let get_start_end = || { - let end = { - unsafe { - // reborrow and - // inform bchk that we still have lifetime 'a - // this is safe as we go from &mut borrow to & - // because the trait is only implemented for &ChunkedArray - let borrow = std::mem::transmute::<&ChunkedArray, &'a ChunkedArray>(ca); - // ensure we don't access with `len() - 1` this will have O(n^2) complexity - // if we append many chunks that are sorted - borrow.last() - } - }; - let start = unsafe { other.get_unchecked(0) }; + // If either is empty (or completely null), copy the sorted flag from the other. + if ca.len() == ca.null_count() { + ca.set_sorted_flag(other.is_sorted_flag()); + return; + } + if other.len() == other.null_count() { + return; + } - (start, end) - }; + // Both need to be sorted, in the same order. + let ls = ca.is_sorted_flag(); + let rs = other.is_sorted_flag(); + if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not { + ca.set_sorted_flag(IsSorted::Not); + return; + } - if !ca.is_empty() && !other.is_empty() { - match (ca.is_sorted_flag(), other.is_sorted_flag()) { - (IsSorted::Ascending, IsSorted::Ascending) => { - let (start, end) = get_start_end(); - if end > start { - ca.set_sorted_flag(IsSorted::Not) - } - }, - (IsSorted::Descending, IsSorted::Descending) => { - let (start, end) = get_start_end(); - if end < start { - ca.set_sorted_flag(IsSorted::Not) - } - }, - _ => ca.set_sorted_flag(IsSorted::Not), + // Check the order is maintained. + let still_sorted = { + let left = ca.get(ca.last_non_null().unwrap()).unwrap(); + let right = other.get(other.first_non_null().unwrap()).unwrap(); + if ca.is_sorted_ascending_flag() { + left.tot_le(&right) + } else { + left.tot_ge(&right) } - } else if ca.is_empty() { - ca.set_sorted_flag(other.is_sorted_flag()) + }; + if !still_sorted { + ca.set_sorted_flag(IsSorted::Not); } } impl ChunkedArray where - T: PolarsNumericType, + T: PolarsDataType, + for<'a> T::Physical<'a>: TotalOrd, { /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. 
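A hedged sketch of the behaviour implemented by `update_sorted_flag_before_append` above: the sorted flag survives an append only when both sides carry the same flag and the boundary values preserve that order; otherwise it is cleared:

```rust
use polars_core::prelude::*;
use polars_core::series::IsSorted;

fn demo_sorted_append() {
    let mut ca = Int32Chunked::new("a", &[1, 2, 3]);
    ca.set_sorted_flag(IsSorted::Ascending);

    let mut tail = Int32Chunked::new("a", &[3, 4]);
    tail.set_sorted_flag(IsSorted::Ascending);

    // Boundary values 3 <= 3: the ascending flag survives the append.
    ca.append(&tail);
    assert!(matches!(ca.is_sorted_flag(), IsSorted::Ascending));

    // Appending data that breaks the order clears the flag.
    let mut smaller = Int32Chunked::new("a", &[0]);
    smaller.set_sorted_flag(IsSorted::Ascending);
    ca.append(&smaller);
    assert!(matches!(ca.is_sorted_flag(), IsSorted::Not));
}
```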
/// /// See also [`extend`](Self::extend) for appends to the underlying memory pub fn append(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); - - let len = self.len(); - self.length += other.length; - new_chunks(&mut self.chunks, &other.chunks, len); - } -} - -#[doc(hidden)] -impl BooleanChunked { - pub fn append(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); + update_sorted_flag_before_append::(self, other); let len = self.len(); self.length += other.length; new_chunks(&mut self.chunks, &other.chunks, len); - self.set_sorted_flag(IsSorted::Not); - } -} -#[doc(hidden)] -impl Utf8Chunked { - pub fn append(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); - let len = self.len(); - self.length += other.length; - new_chunks(&mut self.chunks, &other.chunks, len); - self.set_sorted_flag(IsSorted::Not); - } -} - -#[doc(hidden)] -impl BinaryChunked { - pub fn append(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); - let len = self.len(); - self.length += other.length; - new_chunks(&mut self.chunks, &other.chunks, len); - self.set_sorted_flag(IsSorted::Not); } } diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index 18a34d5566da..3ef019176f3a 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -1,13 +1,9 @@ //! Implementations of the ChunkApply Trait. use std::borrow::Cow; use std::convert::TryFrom; -use std::error::Error; use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::utils::{get_bit_unchecked, set_bit_unchecked}; -use arrow::bitmap::Bitmap; -use arrow::trusted_len::TrustedLen; -use arrow::types::NativeType; use polars_arrow::bitmap::unary_mut; use crate::prelude::*; @@ -17,112 +13,134 @@ use crate::utils::CustomIterTools; impl ChunkedArray where T: PolarsDataType, - Self: HasUnderlyingArray, { - pub fn apply_values_generic<'a, U, K, F>(&'a self, op: F) -> ChunkedArray + // Applies a function to all elements , regardless of whether they + // are null or not, after which the null mask is copied from the + // original array. + pub fn apply_values_generic<'a, U, K, F>(&'a self, mut op: F) -> ChunkedArray where U: PolarsDataType, - F: FnMut(<::ArrayT as StaticArray>::ValueT<'a>) -> K + Copy, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, + F: FnMut(T::Physical<'a>) -> K, + U::Array: ArrayFromIter, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.values_iter().map(op); - let array = K::array_from_values_iter(element_iter); - array.with_validity_typed(arr.validity().cloned()) + let out: U::Array = arr.values_iter().map(&mut op).collect_arr(); + out.with_validity_typed(arr.validity().cloned()) }); ChunkedArray::from_chunk_iter(self.name(), iter) } - pub fn try_apply_values_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + /// Applies a function to all elements, regardless of whether they + /// are null or not, after which the null mask is copied from the + /// original array. 
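As documented just above, the `apply_values*` helpers map over every value and then copy the original validity, so nulls stay null. The same contract is visible through the plain `apply_values` from `ChunkApply`, sketched here:

```rust
use polars_core::prelude::*;

fn demo_apply_values() {
    let ca = Float64Chunked::new("x", &[Some(1.0), None, Some(3.0)]);

    // Values are mapped; the validity (null mask) is carried over unchanged.
    let doubled = ca.apply_values(|v| v * 2.0);
    assert_eq!(doubled.get(0), Some(2.0));
    assert_eq!(doubled.get(1), None);
    assert_eq!(doubled.null_count(), 1);
}
```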
+ pub fn try_apply_values_generic<'a, U, K, F, E>( + &'a self, + mut op: F, + ) -> Result, E> where U: PolarsDataType, - F: FnMut(<::ArrayT as StaticArray>::ValueT<'a>) -> Result - + Copy, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, - E: Error, + F: FnMut(T::Physical<'a>) -> Result, + U::Array: ArrayFromIter, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.values_iter().map(op); - let array = K::try_array_from_values_iter(element_iter)?; + let element_iter = arr.values_iter().map(&mut op); + let array: U::Array = element_iter.try_collect_arr()?; Ok(array.with_validity_typed(arr.validity().cloned())) }); ChunkedArray::try_from_chunk_iter(self.name(), iter) } - pub fn try_apply_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + /// Applies a function only to the non-null elements, propagating nulls. + pub fn apply_nonnull_values_generic<'a, U, K, F>( + &'a self, + dtype: DataType, + mut op: F, + ) -> ChunkedArray where U: PolarsDataType, - F: FnMut( - Option<<::ArrayT as StaticArray>::ValueT<'a>>, - ) -> Result, E> - + Copy, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, - E: Error, + F: FnMut(T::Physical<'a>) -> K, + U::Array: ArrayFromIterDtype + ArrayFromIterDtype>, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.iter().map(op); - let array = K::try_array_from_iter(element_iter)?; - Ok(array.with_validity_typed(arr.validity().cloned())) + if arr.null_count() == 0 { + let out: U::Array = arr + .values_iter() + .map(&mut op) + .collect_arr_with_dtype(dtype.clone()); + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op)) + .collect_arr_with_dtype(dtype.clone()); + out.with_validity_typed(arr.validity().cloned()) + } }); - ChunkedArray::try_from_chunk_iter(self.name(), iter) + ChunkedArray::from_chunk_iter(self.name(), iter) } - pub fn apply_generic<'a, U, K, F>(&'a self, op: F) -> ChunkedArray + /// Applies a function only to the non-null elements, propagating nulls. + pub fn try_apply_nonnull_values_generic<'a, U, K, F, E>( + &'a self, + mut op: F, + ) -> Result, E> where U: PolarsDataType, - F: FnMut( - Option<<::ArrayT as StaticArray>::ValueT<'a>>, - ) -> Option - + Copy, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, + F: FnMut(T::Physical<'a>) -> Result, + U::Array: ArrayFromIter + ArrayFromIter>, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.iter().map(op); - K::array_from_iter(element_iter) + let arr = if arr.null_count() == 0 { + let out: U::Array = arr.values_iter().map(&mut op).try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op).transpose()) + .try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + }; + Ok(arr) }); - ChunkedArray::from_chunk_iter(self.name(), iter) + ChunkedArray::try_from_chunk_iter(self.name(), iter) } -} - -fn collect_array>( - iter: I, - validity: Option, -) -> PrimitiveArray { - PrimitiveArray::from_trusted_len_values_iter(iter).with_validity(validity) -} -macro_rules! 
try_apply { - ($self:expr, $f:expr) => {{ - if !$self.has_validity() { - $self.into_no_null_iter().map($f).collect() + pub fn apply_generic<'a, U, K, F>(&'a self, mut op: F) -> ChunkedArray + where + U: PolarsDataType, + F: FnMut(Option>) -> Option, + U::Array: ArrayFromIter>, + { + if self.null_count() == 0 { + let iter = self + .downcast_iter() + .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr()); + ChunkedArray::from_chunk_iter(self.name(), iter) } else { - $self - .into_iter() - .map(|opt_v| opt_v.map($f).transpose()) - .collect() + let iter = self + .downcast_iter() + .map(|arr| arr.iter().map(&mut op).collect_arr()); + ChunkedArray::from_chunk_iter(self.name(), iter) } - }}; -} + } -macro_rules! apply { - ($self:expr, $f:expr) => {{ - if !$self.has_validity() { - $self.into_no_null_iter().map($f).collect_trusted() - } else { - $self - .into_iter() - .map(|opt_v| opt_v.map($f)) - .collect_trusted() - } - }}; + pub fn try_apply_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + where + U: PolarsDataType, + F: FnMut(Option>) -> Result, E> + Copy, + U::Array: ArrayFromIter>, + { + let iter = self.downcast_iter().map(|arr| { + let array: U::Array = arr.iter().map(op).try_collect_arr()?; + Ok(array.with_validity_typed(arr.validity().cloned())) + }); + + ChunkedArray::try_from_chunk_iter(self.name(), iter) + } } fn apply_in_place_impl(name: &str, chunks: Vec, f: F) -> ChunkedArray @@ -220,7 +238,8 @@ where .data_views() .zip(self.iter_validities()) .map(|(slice, validity)| { - collect_array(slice.iter().copied().map(f), validity.cloned()) + let arr: T::Array = slice.iter().copied().map(f).collect_arr(); + arr.with_validity(validity.cloned()) }); ChunkedArray::from_chunk_iter(self.name(), chunks) } @@ -375,6 +394,21 @@ impl Utf8Chunked { }); Utf8Chunked::from_chunk_iter(self.name(), chunks) } + + /// Utility that reuses an string buffer to amortize allocations. + /// Prefer this over an `apply` that returns an owned `String`. 
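Usage sketch for the `apply_to_buffer` utility introduced below: the closure writes into one reused `String` buffer instead of allocating an owned `String` per row:

```rust
use polars_core::prelude::*;

fn demo_apply_to_buffer() {
    let ca = Utf8Chunked::new("s", &["foo", "bar"]);

    // The closure appends into a reused buffer; no per-row String allocation.
    let shouting = ca.apply_to_buffer(|s, buf| {
        buf.push_str(s);
        buf.push('!');
    });
    assert_eq!(shouting.get(0), Some("foo!"));
    assert_eq!(shouting.get(1), Some("bar!"));
}
```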
+ pub fn apply_to_buffer<'a, F>(&'a self, mut f: F) -> Self + where + F: FnMut(&'a str, &mut String), + { + let mut buf = String::new(); + let outer = |s: &'a str| { + buf.clear(); + f(s, &mut buf); + unsafe { std::mem::transmute::<&str, &'a str>(buf.as_str()) } + }; + self.apply_mut(outer) + } } impl BinaryChunked { @@ -560,7 +594,17 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { } out }; - let mut ca: ListChunked = apply!(self, &mut function); + let mut ca: ListChunked = { + if !self.has_validity() { + self.into_no_null_iter() + .map(&mut function) + .collect_trusted() + } else { + self.into_iter() + .map(|opt_v| opt_v.map(&mut function)) + .collect_trusted() + } + }; if fast_explode { ca.set_fast_explode() } @@ -585,7 +629,15 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { } out }; - let ca: PolarsResult = try_apply!(self, &mut function); + let ca: PolarsResult = { + if !self.has_validity() { + self.into_no_null_iter().map(&mut function).collect() + } else { + self.into_iter() + .map(|opt_v| opt_v.map(&mut function).transpose()) + .collect() + } + }; let mut ca = ca?; if fast_explode { ca.set_fast_explode() diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index 4214c41deccc..8c14a17ec567 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -3,15 +3,32 @@ use std::error::Error; use arrow::array::Array; use polars_arrow::utils::combine_validities_and; -use crate::datatypes::{ - ArrayFromElementIter, HasUnderlyingArray, PolarsNumericType, StaticArray, - StaticallyMatchesPolarsType, -}; +use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray}; use crate::prelude::{ChunkedArray, PolarsDataType}; -use crate::utils::align_chunks_binary; +use crate::utils::{align_chunks_binary, align_chunks_ternary}; + +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. +pub trait TernaryFnMut: FnMut(A1, A2, A3) -> Self::Ret { + type Ret; +} + +impl R> TernaryFnMut for T { + type Ret = R; +} + +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. 
+pub trait BinaryFnMut: FnMut(A1, A2) -> Self::Ret { + type Ret; +} + +impl R> BinaryFnMut for T { + type Ret = R; +} #[inline] -pub fn binary_elementwise( +pub fn binary_elementwise( lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F, @@ -20,14 +37,10 @@ where T: PolarsDataType, U: PolarsDataType, V: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: for<'a> FnMut( - Option<< as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>>, - Option<< as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>>, - ) -> Option, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, + F: for<'a> BinaryFnMut>, Option>>, + V::Array: for<'a> ArrayFromIter< + >, Option>>>::Ret, + >, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -38,11 +51,61 @@ where .iter() .zip(rhs_arr.iter()) .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); - K::array_from_iter(element_iter) + element_iter.collect_arr() }); ChunkedArray::from_chunk_iter(lhs.name(), iter) } +#[inline] +pub fn binary_elementwise_for_each<'a, 'b, T, U, F>( + lhs: &'a ChunkedArray, + rhs: &'b ChunkedArray, + mut op: F, +) where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(Option>, Option>), +{ + let mut lhs_arr_iter = lhs.downcast_iter(); + let mut rhs_arr_iter = rhs.downcast_iter(); + + let lhs_arr = lhs_arr_iter.next().unwrap(); + let rhs_arr = rhs_arr_iter.next().unwrap(); + + let mut lhs_remaining = lhs_arr.len(); + let mut rhs_remaining = rhs_arr.len(); + let mut lhs_iter = lhs_arr.iter(); + let mut rhs_iter = rhs_arr.iter(); + + loop { + let range = std::cmp::min(lhs_remaining, rhs_remaining); + + for _ in 0..range { + // SAFETY: we loop until the smaller iter is exhausted. + let lhs_opt_val = unsafe { lhs_iter.next().unwrap_unchecked() }; + let rhs_opt_val = unsafe { rhs_iter.next().unwrap_unchecked() }; + op(lhs_opt_val, rhs_opt_val) + } + lhs_remaining -= range; + rhs_remaining -= range; + + if lhs_remaining == 0 { + let Some(new_arr) = lhs_arr_iter.next() else { + return; + }; + lhs_remaining = new_arr.len(); + lhs_iter = new_arr.iter(); + } + if rhs_remaining == 0 { + let Some(new_arr) = rhs_arr_iter.next() else { + return; + }; + rhs_remaining = new_arr.len(); + rhs_iter = new_arr.iter(); + } + } +} + #[inline] pub fn try_binary_elementwise( lhs: &ChunkedArray, @@ -53,15 +116,8 @@ where T: PolarsDataType, U: PolarsDataType, V: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: for<'a> FnMut( - Option<< as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>>, - Option<< as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>>, - ) -> Result, E>, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, - E: Error, + F: for<'a> FnMut(Option>, Option>) -> Result, E>, + V::Array: ArrayFromIter>, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -72,7 +128,7 @@ where .iter() .zip(rhs_arr.iter()) .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); - K::try_array_from_iter(element_iter) + element_iter.try_collect_arr() }); ChunkedArray::try_from_chunk_iter(lhs.name(), iter) } @@ -86,17 +142,12 @@ pub fn binary_elementwise_values( where T: PolarsDataType, U: PolarsDataType, - V: PolarsNumericType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: for<'a> FnMut( - < as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>, - < as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>, - ) -> K, - K: ArrayFromElementIter, - 
K::ArrayType: StaticallyMatchesPolarsType, + V: PolarsDataType, + F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K, + V::Array: ArrayFromIter, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs .downcast_iter() .zip(rhs.downcast_iter()) @@ -108,7 +159,7 @@ where .zip(rhs_arr.values_iter()) .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val)); - let array = K::array_from_values_iter(element_iter); + let array: V::Array = element_iter.collect_arr(); array.with_validity_typed(validity) }); ChunkedArray::from_chunk_iter(lhs.name(), iter) @@ -123,16 +174,9 @@ pub fn try_binary_elementwise_values( where T: PolarsDataType, U: PolarsDataType, - V: PolarsNumericType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: for<'a> FnMut( - < as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>, - < as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>, - ) -> Result, - K: ArrayFromElementIter, - K::ArrayType: StaticallyMatchesPolarsType, - E: Error, + V: PolarsDataType, + F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result, + V::Array: ArrayFromIter, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -146,7 +190,7 @@ where .zip(rhs_arr.values_iter()) .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val)); - let array = K::try_array_from_values_iter(element_iter)?; + let array: V::Array = element_iter.try_collect_arr()?; Ok(array.with_validity_typed(validity)) }); ChunkedArray::try_from_chunk_iter(lhs.name(), iter) @@ -163,14 +207,9 @@ pub fn binary_mut_with_options( where T: PolarsDataType, U: PolarsDataType, - V: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - Arr: Array + StaticallyMatchesPolarsType, - F: FnMut( - & as HasUnderlyingArray>::ArrayT, - & as HasUnderlyingArray>::ArrayT, - ) -> Arr, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Arr, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -189,14 +228,9 @@ pub fn binary( where T: PolarsDataType, U: PolarsDataType, - V: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - Arr: Array + StaticallyMatchesPolarsType, - F: FnMut( - & as HasUnderlyingArray>::ArrayT, - & as HasUnderlyingArray>::ArrayT, - ) -> Arr, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Arr, { binary_mut_with_options(lhs, rhs, op, lhs.name()) } @@ -210,14 +244,9 @@ pub fn try_binary( where T: PolarsDataType, U: PolarsDataType, - V: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - Arr: Array + StaticallyMatchesPolarsType, - F: FnMut( - & as HasUnderlyingArray>::ArrayT, - & as HasUnderlyingArray>::ArrayT, - ) -> Result, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Result, E: Error, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); @@ -243,12 +272,7 @@ pub unsafe fn binary_unchecked_same_type( where T: PolarsDataType, U: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: FnMut( - & as HasUnderlyingArray>::ArrayT, - & as HasUnderlyingArray>::ArrayT, - ) -> Box, + F: FnMut(&T::Array, &U::Array) -> Box, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let chunks = lhs @@ -274,12 +298,7 @@ pub unsafe fn try_binary_unchecked_same_type( where T: PolarsDataType, U: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - ChunkedArray: HasUnderlyingArray, - F: FnMut( - & as HasUnderlyingArray>::ArrayT, - & as HasUnderlyingArray>::ArrayT, - ) -> Result, E>, + F: 
FnMut(&T::Array, &U::Array) -> Result, E>, E: Error, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); @@ -290,3 +309,79 @@ where .collect::, E>>()?; Ok(lhs.copy_with_chunks(chunks, keep_sorted, keep_fast_explode)) } + +#[inline] +pub fn try_ternary_elementwise( + ca1: &ChunkedArray, + ca2: &ChunkedArray, + ca3: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + G: PolarsDataType, + F: for<'a> FnMut( + Option>, + Option>, + Option>, + ) -> Result, E>, + V::Array: ArrayFromIter>, +{ + let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3); + let iter = ca1 + .downcast_iter() + .zip(ca2.downcast_iter()) + .zip(ca3.downcast_iter()) + .map(|((ca1_arr, ca2_arr), ca3_arr)| { + let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map( + |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| { + op(ca1_opt_val, ca2_opt_val, ca3_opt_val) + }, + ); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(ca1.name(), iter) +} + +#[inline] +pub fn ternary_elementwise( + ca1: &ChunkedArray, + ca2: &ChunkedArray, + ca3: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + G: PolarsDataType, + V: PolarsDataType, + F: for<'a> TernaryFnMut< + Option>, + Option>, + Option>, + >, + V::Array: for<'a> ArrayFromIter< + >, + Option>, + Option>, + >>::Ret, + >, +{ + let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3); + let iter = ca1 + .downcast_iter() + .zip(ca2.downcast_iter()) + .zip(ca3.downcast_iter()) + .map(|((ca1_arr, ca2_arr), ca3_arr)| { + let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map( + |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| { + op(ca1_opt_val, ca2_opt_val, ca3_opt_val) + }, + ); + element_iter.collect_arr() + }); + ChunkedArray::from_chunk_iter(ca1.name(), iter) +} diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index a8379ce9a566..5dbd3f0c30cf 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -2,7 +2,7 @@ use arrow::buffer::Buffer; use crate::prelude::*; -/// Reinterprets the type of a ChunkedArray. T and U must have the same size +/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size /// and alignment. fn reinterpret_chunked_array( ca: &ChunkedArray, @@ -22,8 +22,9 @@ fn reinterpret_chunked_array( ChunkedArray::from_chunk_iter(ca.name(), chunks) } -/// Reinterprets the type of a ListChunked. T and U must have the same size +/// Reinterprets the type of a [`ListChunked`]. T and U must have the same size /// and alignment. +#[cfg(feature = "reinterpret")] fn reinterpret_list_chunked( ca: &ListChunked, ) -> ListChunked { @@ -245,7 +246,7 @@ impl UInt32Chunked { /// Used to save compilation paths. Use carefully. Although this is safe, /// if misused it can lead to incorrect results. 
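The reworked arity helpers earlier in this diff (`binary_elementwise` and friends) build the output from aligned chunks, handing the closure `Option`s of the physical values. A usage sketch; the import path is an assumption and the helper may also be re-exported elsewhere:

```rust
use polars_core::prelude::*;
// Assumed module path for the helper.
use polars_core::chunked_array::ops::arity::binary_elementwise;

fn demo_binary() {
    let a = Int32Chunked::new("a", &[Some(1), None, Some(3)]);
    let b = Int32Chunked::new("b", &[Some(10), Some(20), None]);

    // Null-aware elementwise kernel: None on either side yields None.
    let sum: Int32Chunked = binary_elementwise(&a, &b, |l, r| match (l, r) {
        (Some(l), Some(r)) => Some(l + r),
        _ => None,
    });
    assert_eq!(sum.get(0), Some(11));
    assert_eq!(sum.get(1), None);
    assert_eq!(sum.get(2), None);
}
```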
impl Float32Chunked { - pub(crate) fn apply_as_ints(&self, f: F) -> Series + pub fn apply_as_ints(&self, f: F) -> Series where F: Fn(&Series) -> Series, { @@ -256,7 +257,7 @@ impl Float32Chunked { } } impl Float64Chunked { - pub(crate) fn apply_as_ints(&self, f: F) -> Series + pub fn apply_as_ints(&self, f: F) -> Series where F: Fn(&Series) -> Series, { diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index 085b041d8b5b..1aa9d4dafaf3 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -1,6 +1,7 @@ #[cfg(feature = "object")] use arrow::array::Array; use polars_arrow::kernels::concatenate::concatenate_owned_unchecked; +use polars_error::constants::LENGTH_LIMIT_MSG; use super::*; #[cfg(feature = "object")] @@ -72,22 +73,11 @@ impl ChunkedArray { _ => chunks.iter().fold(0, |acc, arr| acc + arr.len()), } } - self.length = inner(&self.chunks) as IdxSize; + self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); if self.length <= 1 { self.set_sorted_flag(IsSorted::Ascending) } - - #[cfg(feature = "python")] - assert!( - self.length < IdxSize::MAX, - "Polars' maximum length reached. Consider installing 'polars-u64-idx'." - ); - #[cfg(not(feature = "python"))] - assert!( - self.length < IdxSize::MAX, - "Polars' maximum length reached. Consider compiling with 'bigidx' feature." - ); } pub fn rechunk(&self) -> Self { @@ -133,7 +123,7 @@ impl ChunkedArray { self.slice(0, num_elements) } - /// Get the head of the ChunkedArray + /// Get the head of the [`ChunkedArray`] #[must_use] pub fn head(&self, length: Option) -> Self where @@ -145,7 +135,7 @@ impl ChunkedArray { } } - /// Get the tail of the ChunkedArray + /// Get the tail of the [`ChunkedArray`] #[must_use] pub fn tail(&self, length: Option) -> Self where diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index a637d423f705..c62d291e8de2 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -1,106 +1,70 @@ -//! -//! Used to speed up PartialEq and PartialOrd of elements within an array -//! +//! Used to speed up TotalEq and TotalOrd of elements within an array. -use std::cmp::{Ordering, PartialEq}; +use std::cmp::Ordering; -use crate::chunked_array::ops::take::take_random::{ - BinaryTakeRandom, BinaryTakeRandomSingleChunk, BoolTakeRandom, BoolTakeRandomSingleChunk, - NumTakeRandomChunked, NumTakeRandomCont, NumTakeRandomSingleChunk, Utf8TakeRandom, - Utf8TakeRandomSingleChunk, -}; -#[cfg(feature = "object")] -use crate::chunked_array::ops::take::take_random::{ObjectTakeRandom, ObjectTakeRandomSingleChunk}; +use crate::chunked_array::ChunkedArrayLayout; use crate::prelude::*; -use crate::utils::Wrap; -pub trait PartialEqInner: Send + Sync { - /// Safety: - /// Does not do any bound checks - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool; -} +#[repr(transparent)] +struct NonNull(T); -pub trait PartialOrdInner: Send + Sync { - /// Safety: - /// Does not do any bound checks - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering; +trait GetInner { + type Item; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item; } -macro_rules! 
impl_traits { - ($struct:ty) => { - impl PartialEqInner for $struct { - #[inline] - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get(idx_a) == self.get(idx_b) - } - } - impl PartialOrdInner for $struct { - #[inline] - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - let a = self.get(idx_a); - let b = self.get(idx_b); - a.partial_cmp(&b).unwrap_or_else(|| fallback(a)) - } - } - }; - ($struct:ty, $T:tt) => { - impl<$T> PartialEqInner for $struct - where - $T: NumericNative + Sync, - { - #[inline] - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get(idx_a) == self.get(idx_b) - } - } - - impl<$T> PartialOrdInner for $struct - where - $T: NumericNative + Sync, - { - #[inline] - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - // nulls so we can not do unchecked - let a = self.get(idx_a); - let b = self.get(idx_b); - a.partial_cmp(&b).unwrap_or_else(|| fallback(a)) - } - } - }; +impl<'a, T: PolarsDataType> GetInner for &'a ChunkedArray { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ChunkedArray::get_unchecked(self, idx) + } } -impl_traits!(Utf8TakeRandom<'_>); -impl_traits!(Utf8TakeRandomSingleChunk<'_>); -impl_traits!(BinaryTakeRandom<'_>); -impl_traits!(BinaryTakeRandomSingleChunk<'_>); -impl_traits!(BoolTakeRandom<'_>); -impl_traits!(BoolTakeRandomSingleChunk<'_>); -impl_traits!(NumTakeRandomSingleChunk<'_, T>, T); -impl_traits!(NumTakeRandomChunked<'_, T>, T); +impl<'a, T: StaticArray> GetInner for &'a T { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ::get_unchecked(self, idx) + } +} -impl<'a> PartialEqInner for ListTakeRandomSingleChunk<'a> { - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get_unchecked(idx_a).map(Wrap) == self.get_unchecked(idx_b).map(Wrap) +impl<'a, T: PolarsDataType> GetInner for NonNull<&'a ChunkedArray> { + type Item = T::Physical<'a>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + self.0.value_unchecked(idx) } } -impl<'a> PartialEqInner for ListTakeRandom<'a> { - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get_unchecked(idx_a).map(Wrap) == self.get_unchecked(idx_b).map(Wrap) +impl<'a, T: StaticArray> GetInner for NonNull<&'a T> { + type Item = T::ValueT<'a>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + self.0.value_unchecked(idx) } } -impl PartialEqInner for NumTakeRandomCont<'_, T> +pub trait PartialEqInner: Send + Sync { + /// # Safety + /// Does not do any bound checks. + unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool; +} + +pub trait PartialOrdInner: Send + Sync { + /// # Safety + /// Does not do any bound checks. + unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering; +} + +impl PartialEqInner for T where - T: Copy + PartialEq + Sync, + T: GetInner + Send + Sync, + T::Item: TotalEq, { + #[inline] unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - // no nulls so we can do unchecked - self.get_unchecked(idx_a) == self.get_unchecked(idx_b) + self.get_unchecked(idx_a).tot_eq(&self.get_unchecked(idx_b)) } } -/// Create a type that implements PartialEqInner +/// Create a type that implements PartialEqInner. pub(crate) trait IntoPartialEqInner<'a> { /// Create a type that implements `TakeRandom`. 
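The comparison helpers below dispatch on `ChunkedArray::layout()` (added earlier in this diff) so that the single-chunk, no-null case can skip validity handling entirely. A sketch of that dispatch pattern with a hypothetical helper; the enum path is assumed:

```rust
use polars_core::chunked_array::ChunkedArrayLayout;
use polars_core::prelude::*;

// Hypothetical helper: count valid (non-null) values, taking the cheapest path
// the layout allows.
fn valid_count<T: PolarsDataType>(ca: &ChunkedArray<T>) -> usize {
    match ca.layout() {
        // No validity bitmap anywhere: every value is valid.
        ChunkedArrayLayout::SingleNoNull(_) | ChunkedArrayLayout::MultiNoNull(_) => ca.len(),
        // Otherwise fall back to the tracked null count.
        ChunkedArrayLayout::Single(_) | ChunkedArrayLayout::Multi(_) => {
            ca.len() - ca.null_count()
        },
    }
}
```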
fn into_partial_eq_inner(self) -> Box; @@ -109,288 +73,92 @@ pub(crate) trait IntoPartialEqInner<'a> { /// We use a trait object because we want to call this from Series and cannot use a typed enum. impl<'a, T> IntoPartialEqInner<'a> for &'a ChunkedArray where - T: PolarsNumericType, + T: PolarsDataType, + T::Physical<'a>: TotalEq, { fn into_partial_eq_inner(self) -> Box { - let mut chunks = self.downcast_iter(); - - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - let t = NumTakeRandomCont { - slice: arr.values(), - }; - Box::new(t) - } else { - let t = NumTakeRandomSingleChunk::<'_, T::Native>::new(arr); - Box::new(t) - } - } else { - let t = NumTakeRandomChunked::<'_, T::Native> { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - } - } -} - -impl<'a> IntoPartialEqInner<'a> for &'a ListChunked { - fn into_partial_eq_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = ListTakeRandomSingleChunk { - arr, - name: self.name(), - }; - Box::new(t) - }, - _ => { - let name = self.name(); - let inner_type = self.inner_dtype().to_physical(); - let t = ListTakeRandom { - inner_type, - name, - chunks: self.downcast_iter().collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } - } -} - -impl<'a> IntoPartialEqInner<'a> for &'a Utf8Chunked { - fn into_partial_eq_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = Utf8TakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = Utf8TakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } - } -} - -impl<'a> IntoPartialEqInner<'a> for &'a BinaryChunked { - fn into_partial_eq_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BinaryTakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BinaryTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } - } -} - -impl<'a> IntoPartialEqInner<'a> for &'a BooleanChunked { - fn into_partial_eq_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BoolTakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BoolTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), } } } -// Partial ordering implementations - -fn fallback(a: T) -> Ordering { - // nan != nan - // this is a simple way to check if it is nan - // without convincing the compiler we deal with floats - #[allow(clippy::eq_op)] - if a != a { - Ordering::Less - } else { - Ordering::Greater - } -} - -impl PartialOrdInner for NumTakeRandomCont<'_, T> +impl PartialOrdInner for T where - T: Copy + PartialOrd + Sync, + T: GetInner + Send + Sync, + T::Item: TotalOrd, { + #[inline] unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) 
-> Ordering { - // no nulls so we can do unchecked let a = self.get_unchecked(idx_a); let b = self.get_unchecked(idx_b); - a.partial_cmp(&b).unwrap_or_else(|| fallback(a)) + a.tot_cmp(&b) } } -/// Create a type that implements PartialOrdInner + +/// Create a type that implements PartialOrdInner. pub(crate) trait IntoPartialOrdInner<'a> { /// Create a type that implements `TakeRandom`. fn into_partial_ord_inner(self) -> Box; } -/// We use a trait object because we want to call this from Series and cannot use a typed enum. + impl<'a, T> IntoPartialOrdInner<'a> for &'a ChunkedArray where - T: PolarsNumericType, + T: PolarsDataType, + T::Physical<'a>: TotalOrd, { fn into_partial_ord_inner(self) -> Box { - let mut chunks = self.downcast_iter(); - - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - let t = NumTakeRandomCont { - slice: arr.values(), - }; - Box::new(t) - } else { - let t = NumTakeRandomSingleChunk::<'_, T::Native>::new(arr); - Box::new(t) - } - } else { - let t = NumTakeRandomChunked::<'_, T::Native> { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), } } } -impl<'a> IntoPartialOrdInner<'a> for &'a Utf8Chunked { - fn into_partial_ord_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = Utf8TakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = Utf8TakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } - } +#[cfg(feature = "dtype-categorical")] +struct LocalCategorical<'a> { + rev_map: &'a Utf8Array, + cats: &'a UInt32Chunked, } -impl<'a> IntoPartialOrdInner<'a> for &'a BinaryChunked { - fn into_partial_ord_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BinaryTakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BinaryTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for LocalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + Some(self.rev_map.value_unchecked(cat as usize)) } } -impl<'a> IntoPartialOrdInner<'a> for &'a BooleanChunked { - fn into_partial_ord_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BoolTakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BoolTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, - } +#[cfg(feature = "dtype-categorical")] +struct GlobalCategorical<'a> { + p1: &'a PlHashMap, + p2: &'a Utf8Array, + cats: &'a UInt32Chunked, +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for GlobalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + let idx = 
self.p1.get(&cat).unwrap(); + Some(self.p2.value_unchecked(*idx as usize)) } } #[cfg(feature = "dtype-categorical")] impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked { fn into_partial_ord_inner(self) -> Box { + let cats = self.logical(); match &**self.get_rev_map() { - RevMapping::Local(_) => Box::new(CategoricalTakeRandomLocal::new(self)), - RevMapping::Global(_, _, _) => Box::new(CategoricalTakeRandomGlobal::new(self)), - } - } -} - -#[cfg(feature = "object")] -impl<'a, T> PartialEqInner for ObjectTakeRandom<'a, T> -where - T: PolarsObject, -{ - #[inline] - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get(idx_a) == self.get(idx_b) - } -} - -#[cfg(feature = "object")] -impl<'a, T> PartialEqInner for ObjectTakeRandomSingleChunk<'a, T> -where - T: PolarsObject, -{ - #[inline] - unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { - self.get(idx_a) == self.get(idx_b) - } -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> IntoPartialEqInner<'a> for &'a ObjectChunked { - fn into_partial_eq_inner(self) -> Box { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = ObjectTakeRandomSingleChunk { arr }; - Box::new(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = ObjectTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) - }, + RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }), + RevMapping::Local(rev_map) => Box::new(LocalCategorical { rev_map, cats }), } } } diff --git a/crates/polars-core/src/chunked_array/ops/concat_str.rs b/crates/polars-core/src/chunked_array/ops/concat_str.rs deleted file mode 100644 index a07321e677b0..000000000000 --- a/crates/polars-core/src/chunked_array/ops/concat_str.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::fmt::{Display, Write}; - -use polars_arrow::array::default_arrays::FromDataUtf8; - -use super::StrConcat; -use crate::prelude::*; - -fn fmt_and_write(value: Option, buf: &mut String) { - match value { - None => buf.push_str("null"), - Some(v) => { - write!(buf, "{v}").unwrap(); - }, - } -} - -fn str_concat_impl(mut iter: I, delimiter: &str, name: &str) -> Utf8Chunked -where - I: Iterator>, - T: Display, -{ - let mut buf = String::with_capacity(iter.size_hint().0 * 5); - - if let Some(first) = iter.next() { - fmt_and_write(first, &mut buf); - - for val in iter { - buf.push_str(delimiter); - fmt_and_write(val, &mut buf); - } - } - buf.shrink_to_fit(); - let buf = buf.into_bytes(); - let offsets = vec![0, buf.len() as i64]; - let arr = unsafe { Utf8Array::from_data_unchecked_default(offsets.into(), buf.into(), None) }; - Utf8Chunked::with_chunk(name, arr) -} - -impl StrConcat for ChunkedArray -where - T: PolarsNumericType, - T::Native: Display, -{ - fn str_concat(&self, delimiter: &str) -> Utf8Chunked { - let iter = self.into_iter(); - str_concat_impl(iter, delimiter, self.name()) - } -} - -impl StrConcat for Utf8Chunked { - fn str_concat(&self, delimiter: &str) -> Utf8Chunked { - let iter = self.into_iter(); - str_concat_impl(iter, delimiter, self.name()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_str_concat() { - let ca = Int32Chunked::new("foo", &[Some(1), None, Some(3)]); - let out = ca.str_concat("-"); - - let out = out.get(0); - assert_eq!(out, Some("1-null-3")); - } -} diff --git a/crates/polars-core/src/chunked_array/ops/decimal.rs b/crates/polars-core/src/chunked_array/ops/decimal.rs index 
199b53d703af..07171c7d16be 100644 --- a/crates/polars-core/src/chunked_array/ops/decimal.rs +++ b/crates/polars-core/src/chunked_array/ops/decimal.rs @@ -1,7 +1,7 @@ use crate::prelude::*; impl Utf8Chunked { - /// Convert an [`Utf8Chunked`] to a `Series` of [`DataType::Decimal`]. + /// Convert an [`Utf8Chunked`] to a [`Series`] of [`DataType::Decimal`]. /// The parameters needed for the decimal type are inferred. /// /// If the decimal `precision` and `scale` are already known, consider diff --git a/crates/polars-core/src/chunked_array/ops/downcast.rs b/crates/polars-core/src/chunked_array/ops/downcast.rs index 66197f51efb7..d31ee0b545f4 100644 --- a/crates/polars-core/src/chunked_array/ops/downcast.rs +++ b/crates/polars-core/src/chunked_array/ops/downcast.rs @@ -47,18 +47,13 @@ impl<'a, T> Chunks<'a, T> { } #[doc(hidden)] -impl ChunkedArray -where - Self: HasUnderlyingArray, -{ +impl ChunkedArray { #[inline] - pub fn downcast_iter( - &self, - ) -> impl Iterator::ArrayT> + DoubleEndedIterator { + pub fn downcast_iter(&self) -> impl Iterator + DoubleEndedIterator { self.chunks.iter().map(|arr| { - // SAFETY: HasUnderlyingArray guarantees this is correct. + // SAFETY: T::Array guarantees this is correct. let arr = &**arr; - unsafe { &*(arr as *const dyn Array as *const ::ArrayT) } + unsafe { &*(arr as *const dyn Array as *const T::Array) } }) } @@ -69,19 +64,37 @@ where #[inline] pub unsafe fn downcast_iter_mut( &mut self, - ) -> impl Iterator::ArrayT> + DoubleEndedIterator { + ) -> impl Iterator + DoubleEndedIterator { self.chunks.iter_mut().map(|arr| { - // SAFETY: HasUnderlyingArray guarantees this is correct. + // SAFETY: T::Array guarantees this is correct. let arr = &mut **arr; - &mut *(arr as *mut dyn Array as *mut ::ArrayT) + &mut *(arr as *mut dyn Array as *mut T::Array) }) } #[inline] - pub fn downcast_chunks(&self) -> Chunks<'_, ::ArrayT> { + pub fn downcast_chunks(&self) -> Chunks<'_, T::Array> { Chunks::new(&self.chunks) } + #[inline] + pub fn downcast_get(&self, idx: usize) -> Option<&T::Array> { + let arr = self.chunks.get(idx)?; + // SAFETY: T::Array guarantees this is correct. + let arr = &**arr; + unsafe { Some(&*(arr as *const dyn Array as *const T::Array)) } + } + + #[inline] + /// # Safety + /// It is up to the caller to ensure the chunk idx is in-bounds + pub unsafe fn downcast_get_unchecked(&self, idx: usize) -> &T::Array { + let arr = self.chunks.get_unchecked(idx); + // SAFETY: T::Array guarantees this is correct. + let arr = &**arr; + unsafe { &*(arr as *const dyn Array as *const T::Array) } + } + /// Get the index of the chunk and the index of the value in that chunk. #[inline] pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) { diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 9dd53ca94189..8edd6b19d898 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -176,7 +176,16 @@ impl ExplodeByOffsets for Float64Chunked { impl ExplodeByOffsets for NullChunked { fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - NullChunked::new(self.name.clone(), offsets.len() - 1).into_series() + let mut last_offset = offsets[0]; + + let mut len = 0; + for &offset in &offsets[1..] { + // If offset == last_offset we have an empty list and a new row is inserted, + // therefore we always increase at least 1. 
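// Standalone restatement of the length rule implemented by the lines that
// follow, for illustration; `exploded_len` is a hypothetical helper, not part
// of this diff. An empty list still explodes into one (null) row, hence the
// `max(.., 1)`.
fn exploded_len(offsets: &[i64]) -> usize {
    let mut last_offset = offsets[0];
    let mut len = 0usize;
    for &offset in &offsets[1..] {
        len += std::cmp::max(offset - last_offset, 1) as usize;
        last_offset = offset;
    }
    len
}

fn main() {
    // Offsets [0, 2, 2, 5] describe lists with 2, 0 and 3 elements; the empty
    // list still contributes one exploded row, so the total is 2 + 1 + 3 = 6.
    assert_eq!(exploded_len(&[0, 2, 2, 5]), 6);
}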
+ len += std::cmp::max(offset - last_offset, 1) as usize; + last_offset = offset; + } + NullChunked::new(self.name.clone(), len).into_series() } } diff --git a/crates/polars-core/src/chunked_array/ops/extend.rs b/crates/polars-core/src/chunked_array/ops/extend.rs index 0dc9e1fa8651..5a7743d251e3 100644 --- a/crates/polars-core/src/chunked_array/ops/extend.rs +++ b/crates/polars-core/src/chunked_array/ops/extend.rs @@ -37,7 +37,7 @@ where /// when you read in multiple files and when to store them in a single `DataFrame`. /// In the latter case finish the sequence of `append` operations with a [`rechunk`](Self::rechunk). pub fn extend(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); + update_sorted_flag_before_append::(self, other); // all to a single chunk if self.chunks.len() > 1 { self.append(other); @@ -90,7 +90,7 @@ where #[doc(hidden)] impl Utf8Chunked { pub fn extend(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); + update_sorted_flag_before_append::(self, other); if self.chunks.len() > 1 { self.append(other); *self = self.rechunk(); @@ -129,7 +129,7 @@ impl Utf8Chunked { #[doc(hidden)] impl BinaryChunked { pub fn extend(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); + update_sorted_flag_before_append::(self, other); if self.chunks.len() > 1 { self.append(other); *self = self.rechunk(); @@ -167,7 +167,7 @@ impl BinaryChunked { #[doc(hidden)] impl BooleanChunked { pub fn extend(&mut self, other: &Self) { - update_sorted_flag_before_append(self, other); + update_sorted_flag_before_append::(self, other); // make sure that we are a single chunk already if self.chunks.len() > 1 { self.append(other); diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 440fb6591ea8..97bb000a282c 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -278,8 +278,8 @@ where + compute::aggregate::Sum + compute::aggregate::SimdOrd, { - // nothing to fill - if !ca.has_validity() { + // Nothing to fill. + if ca.null_count() == 0 { return Ok(ca.clone()); } let mut out = match strategy { @@ -287,8 +287,12 @@ where FillNullStrategy::Forward(Some(limit)) => fill_forward_limit(ca, limit), FillNullStrategy::Backward(None) => fill_backward(ca), FillNullStrategy::Backward(Some(limit)) => fill_backward_limit(ca, limit), - FillNullStrategy::Min => ca.fill_null_with_values(ca.min().ok_or_else(err_fill_null)?)?, - FillNullStrategy::Max => ca.fill_null_with_values(ca.max().ok_or_else(err_fill_null)?)?, + FillNullStrategy::Min => { + ca.fill_null_with_values(ChunkAgg::min(ca).ok_or_else(err_fill_null)?)? + }, + FillNullStrategy::Max => { + ca.fill_null_with_values(ChunkAgg::max(ca).ok_or_else(err_fill_null)?)? + }, FillNullStrategy::Mean => ca.fill_null_with_values( ca.mean() .map(|v| NumCast::from(v).unwrap()) @@ -304,8 +308,8 @@ where } fn fill_null_bool(ca: &BooleanChunked, strategy: FillNullStrategy) -> PolarsResult { - // nothing to fill - if !ca.has_validity() { + // Nothing to fill. + if ca.null_count() == 0 { return Ok(ca.clone().into_series()); } match strategy { @@ -342,8 +346,8 @@ fn fill_null_bool(ca: &BooleanChunked, strategy: FillNullStrategy) -> PolarsResu } fn fill_null_binary(ca: &BinaryChunked, strategy: FillNullStrategy) -> PolarsResult { - // nothing to fill - if !ca.has_validity() { + // Nothing to fill. 
+ if ca.null_count() == 0 { return Ok(ca.clone()); } match strategy { @@ -374,8 +378,8 @@ fn fill_null_binary(ca: &BinaryChunked, strategy: FillNullStrategy) -> PolarsRes } fn fill_null_list(ca: &ListChunked, strategy: FillNullStrategy) -> PolarsResult { - // nothing to fill - if !ca.has_validity() { + // Nothing to fill. + if ca.null_count() == 0 { return Ok(ca.clone()); } match strategy { diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index 7543cff66583..8a50021147a1 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -104,6 +104,7 @@ impl ChunkFilter for ListChunked { )), }; } + check_filter_len!(self, filter); Ok(unsafe { arity::binary_unchecked_same_type( self, @@ -129,6 +130,7 @@ impl ChunkFilter for ArrayChunked { )), }; } + check_filter_len!(self, filter); Ok(unsafe { arity::binary_unchecked_same_type( self, @@ -157,7 +159,7 @@ where _ => Ok(ObjectChunked::new_empty(self.name())), }; } - polars_ensure!(!self.is_empty(), NoData: "cannot filter empty object array"); + check_filter_len!(self, filter); let chunks = self.downcast_iter().collect::>(); let mut builder = ObjectChunkedBuilder::::new(self.name(), self.len()); for (idx, mask) in filter.into_iter().enumerate() { diff --git a/crates/polars-core/src/chunked_array/ops/for_each.rs b/crates/polars-core/src/chunked_array/ops/for_each.rs new file mode 100644 index 000000000000..42713e0cdff2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/for_each.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn for_each<'a, F>(&'a self, mut op: F) + where + F: FnMut(Option>), + { + self.downcast_iter().for_each(|arr| { + arr.iter().for_each(&mut op); + }) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index 63a490199f90..82f92b231a72 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -147,15 +147,22 @@ impl ChunkFullNull for ArrayChunked { impl ListChunked { pub fn full_null_with_dtype(name: &str, length: usize, inner_dtype: &DataType) -> ListChunked { - let arr = ListArray::new_null( + let arr: ListArray = ListArray::new_null( ArrowDataType::LargeList(Box::new(ArrowField::new( "item", - inner_dtype.to_arrow(), + inner_dtype.to_physical().to_arrow(), true, ))), length, ); - ChunkedArray::with_chunk(name, arr) + // SAFETY: physical type matches the logical. 
+ unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::List(Box::new(inner_dtype.clone())), + ) + } } } #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs new file mode 100644 index 000000000000..3115b7b3121c --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -0,0 +1,224 @@ +use arrow::array::Array; +use arrow::bitmap::bitmask::BitMask; +use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_utils::index::check_bounds; + +use crate::chunked_array::ops::{ChunkTake, ChunkTakeUnchecked}; +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{IdxCa, PolarsDataType, StaticArray}; +use crate::prelude::*; + +const BINARY_SEARCH_LIMIT: usize = 8; + +pub fn check_bounds_nulls(idx: &PrimitiveArray, len: IdxSize) -> PolarsResult<()> { + let mask = BitMask::from_bitmap(idx.validity().unwrap()); + + // We iterate in chunks to make the inner loop branch-free. + for (block_idx, block) in idx.values().chunks(32).enumerate() { + let mut in_bounds = 0; + for (i, x) in block.iter().enumerate() { + in_bounds |= ((*x < len) as u32) << i; + } + let m = mask.get_u32(32 * block_idx); + polars_ensure!(m == m & in_bounds, ComputeError: "take indices are out of bounds"); + } + Ok(()) +} + +pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> { + let all_valid = indices.downcast_iter().all(|a| { + if a.null_count() == 0 { + check_bounds(a.values(), len).is_ok() + } else { + check_bounds_nulls(a, len).is_ok() + } + }); + polars_ensure!(all_valid, ComputeError: "take indices are out of bounds"); + Ok(()) +} + +impl + ?Sized> ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &I) -> PolarsResult { + check_bounds(indices.as_ref(), self.len() as IdxSize)?; + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +impl ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &IdxCa) -> PolarsResult { + check_bounds_ca(indices, self.len() as IdxSize)?; + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +/// Computes cumulative lengths for efficient branchless binary search +/// lookup. The first element is always 0, and the last length of arrs +/// is always ignored (as we already checked that all indices are +/// in-bounds we don't need to check against the last length). +fn cumulative_lengths(arrs: &[&A]) -> [IdxSize; BINARY_SEARCH_LIMIT] { + assert!(arrs.len() <= BINARY_SEARCH_LIMIT); + let mut ret = [IdxSize::MAX; BINARY_SEARCH_LIMIT]; + ret[0] = 0; + for i in 1..arrs.len() { + ret[i] = ret[i - 1] + arrs[i - 1].len() as IdxSize; + } + ret +} + +#[rustfmt::skip] +#[inline] +fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize; BINARY_SEARCH_LIMIT]) -> (usize, usize) { + // Branchless bitwise binary search. 
+ let mut chunk_idx = 0; + chunk_idx += if idx >= cumlens[chunk_idx + 0b100] { 0b0100 } else { 0 }; + chunk_idx += if idx >= cumlens[chunk_idx + 0b010] { 0b0010 } else { 0 }; + chunk_idx += if idx >= cumlens[chunk_idx + 0b001] { 0b0001 } else { 0 }; + (chunk_idx, (idx - cumlens[chunk_idx]) as usize) +} + +#[inline] +unsafe fn target_value_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + cumlens: &[IdxSize; BINARY_SEARCH_LIMIT], + idx: IdxSize, +) -> A::ValueT<'a> { + let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens); + let arr = targets.get_unchecked(chunk_idx); + arr.value_unchecked(arr_idx) +} + +#[inline] +unsafe fn target_get_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + cumlens: &[IdxSize; BINARY_SEARCH_LIMIT], + idx: IdxSize, +) -> Option> { + let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens); + let arr = targets.get_unchecked(chunk_idx); + arr.get_unchecked(arr_idx) +} + +unsafe fn gather_idx_array_unchecked( + dtype: DataType, + targets: &[&A], + has_nulls: bool, + indices: &[IdxSize], +) -> A { + let it = indices.iter().copied(); + if targets.len() == 1 { + let target = targets.first().unwrap(); + if has_nulls { + it.map(|i| target.get_unchecked(i as usize)) + .collect_arr_trusted_with_dtype(dtype) + } else if let Some(sl) = target.as_slice() { + // Avoid the Arc overhead from value_unchecked. + it.map(|i| sl.get_unchecked(i as usize).clone()) + .collect_arr_trusted_with_dtype(dtype) + } else { + it.map(|i| target.value_unchecked(i as usize)) + .collect_arr_trusted_with_dtype(dtype) + } + } else { + let cumlens = cumulative_lengths(targets); + if has_nulls { + it.map(|i| target_get_unchecked(targets, &cumlens, i)) + .collect_arr_trusted_with_dtype(dtype) + } else { + it.map(|i| target_value_unchecked(targets, &cumlens, i)) + .collect_arr_trusted_with_dtype(dtype) + } + } +} + +impl + ?Sized> ChunkTakeUnchecked for ChunkedArray { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let rechunked; + let mut ca = self; + if self.chunks().len() > BINARY_SEARCH_LIMIT { + rechunked = self.rechunk(); + ca = &rechunked; + } + let targets: Vec<_> = ca.downcast_iter().collect(); + let arr = gather_idx_array_unchecked( + ca.dtype().clone(), + &targets, + ca.null_count() > 0, + indices.as_ref(), + ); + ChunkedArray::from_chunk_iter_like(ca, [arr]) + } +} + +impl ChunkTakeUnchecked for ChunkedArray { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let rechunked; + let mut ca = self; + if self.chunks().len() > BINARY_SEARCH_LIMIT { + rechunked = self.rechunk(); + ca = &rechunked; + } + let targets_have_nulls = ca.null_count() > 0; + let targets: Vec<_> = ca.downcast_iter().collect(); + + let chunks = indices.downcast_iter().map(|idx_arr| { + let dtype = ca.dtype().clone(); + if idx_arr.null_count() == 0 { + gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values()) + } else if targets.len() == 1 { + let target = targets.first().unwrap(); + if targets_have_nulls { + idx_arr + .iter() + .map(|i| target.get_unchecked(*i? as usize)) + .collect_arr_trusted_with_dtype(dtype) + } else { + idx_arr + .iter() + .map(|i| Some(target.value_unchecked(*i? 
as usize))) + .collect_arr_trusted_with_dtype(dtype) + } + } else { + let cumlens = cumulative_lengths(&targets); + if targets_have_nulls { + idx_arr + .iter() + .map(|i| target_get_unchecked(&targets, &cumlens, *i?)) + .collect_arr_trusted_with_dtype(dtype) + } else { + idx_arr + .iter() + .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?))) + .collect_arr_trusted_with_dtype(dtype) + } + } + }); + + let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks); + + use crate::series::IsSorted::*; + let sorted_flag = match (ca.is_sorted_flag(), indices.is_sorted_flag()) { + (_, Not) => Not, + (Not, _) => Not, + (Ascending, Ascending) => Ascending, + (Ascending, Descending) => Descending, + (Descending, Ascending) => Descending, + (Descending, Descending) => Ascending, + }; + out.set_sorted_flag(sorted_flag); + out + } +} diff --git a/crates/polars-core/src/chunked_array/ops/len.rs b/crates/polars-core/src/chunked_array/ops/len.rs deleted file mode 100644 index 8b137891791f..000000000000 --- a/crates/polars-core/src/chunked_array/ops/len.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index d4b716f42928..4d402de9372e 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -2,7 +2,6 @@ use arrow::offset::OffsetsBuffer; use polars_arrow::prelude::QuantileInterpolOptions; -pub use self::take::*; #[cfg(feature = "object")] use crate::datatypes::ObjectType; use crate::prelude::*; @@ -17,8 +16,6 @@ pub mod arity; mod bit_repr; pub(crate) mod chunkops; pub(crate) mod compare_inner; -#[cfg(feature = "concat_str")] -mod concat_str; #[cfg(feature = "cum_agg")] mod cum_agg; #[cfg(feature = "dtype-decimal")] @@ -27,18 +24,16 @@ pub(crate) mod downcast; pub(crate) mod explode; mod explode_and_offsets; mod extend; -mod fill_null; +pub mod fill_null; mod filter; +mod for_each; pub mod full; +pub mod gather; #[cfg(feature = "interpolate")] mod interpolate; -mod len; #[cfg(feature = "zip_with")] pub(crate) mod min_max_binary; mod nulls; -mod peaks; -#[cfg(feature = "repeat_by")] -mod repeat_by; mod reverse; pub(crate) mod rolling_window; mod set; @@ -46,6 +41,7 @@ mod shift; pub mod sort; pub(crate) mod take; mod tile; +#[cfg(feature = "algorithm_group_by")] pub(crate) mod unique; #[cfg(feature = "zip_with")] pub mod zip; @@ -144,83 +140,19 @@ pub trait ChunkRollApply: AsRefDataType { } } -/// Random access -pub trait TakeRandom { - type Item; - - /// Get a nullable value by index. - /// - /// # Panics - /// Panics if `index >= self.len()` - fn get(&self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - /// - /// # Safety - /// - /// Does not do bound checks. - unsafe fn get_unchecked(&self, index: usize) -> Option +pub trait ChunkTake: ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &Idx) -> PolarsResult where - Self: Sized, - { - self.get(index) - } - - /// This is much faster if we have many chunks as we don't have to compute the index - /// # Panics - /// Panics if `index >= self.len()` - fn last(&self) -> Option; -} -// Utility trait because associated type needs a lifetime -pub trait TakeRandomUtf8 { - type Item; - - /// Get a nullable value by index. - /// - /// # Panics - /// Panics if `index >= self.len()` - fn get(self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - /// - /// # Safety - /// - /// Does not do bound checks. 
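// Worked, self-contained trace of the branchless chunk-index resolution added
// in gather.rs above. It mirrors `cumulative_lengths`/`resolve_chunked_idx`
// but takes plain chunk lengths so it runs on its own; `IdxSize` is assumed to
// be u32 here (the default, non-bigidx build).
type IdxSize = u32;
const BINARY_SEARCH_LIMIT: usize = 8;

fn cumulative_lengths(chunk_lens: &[IdxSize]) -> [IdxSize; BINARY_SEARCH_LIMIT] {
    assert!(chunk_lens.len() <= BINARY_SEARCH_LIMIT);
    // Unused slots stay at MAX so the `>=` probes below never select them.
    let mut ret = [IdxSize::MAX; BINARY_SEARCH_LIMIT];
    ret[0] = 0;
    for i in 1..chunk_lens.len() {
        ret[i] = ret[i - 1] + chunk_lens[i - 1];
    }
    ret
}

fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize; BINARY_SEARCH_LIMIT]) -> (usize, usize) {
    // Three branch-free probes are enough to search 8 cumulative offsets.
    let mut chunk_idx = 0;
    chunk_idx += if idx >= cumlens[chunk_idx + 0b100] { 0b0100 } else { 0 };
    chunk_idx += if idx >= cumlens[chunk_idx + 0b010] { 0b0010 } else { 0 };
    chunk_idx += if idx >= cumlens[chunk_idx + 0b001] { 0b0001 } else { 0 };
    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
}

fn main() {
    // Three chunks of lengths 3, 4 and 2 give cumulative starts [0, 3, 7, MAX, ...].
    let cumlens = cumulative_lengths(&[3, 4, 2]);
    // Global index 5 lands in chunk 1 at local offset 2 (5 - 3).
    assert_eq!(resolve_chunked_idx(5, &cumlens), (1, 2));
    // Global index 8 lands in chunk 2 at local offset 1 (8 - 7).
    assert_eq!(resolve_chunked_idx(8, &cumlens), (2, 1));
}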
- unsafe fn get_unchecked(self, index: usize) -> Option - where - Self: Sized, - { - self.get(index) - } - - /// This is much faster if we have many chunks - /// # Panics - /// Panics if `index >= self.len()` - fn last(&self) -> Option; + Self: Sized; } -/// Fast access by index. -pub trait ChunkTake { - /// Take values from ChunkedArray by index. +pub trait ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. /// /// # Safety - /// - /// Doesn't do any bound checking. - #[must_use] - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: Sized, - I: TakeIterator, - INulls: TakeIteratorNulls; - - /// Take values from ChunkedArray by index. - /// Note that the iterator will be cloned, so prefer an iterator that takes the owned memory - /// by reference. - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: Sized, - I: TakeIterator, - INulls: TakeIteratorNulls; + /// The non-null indices must be valid. + unsafe fn take_unchecked(&self, indices: &Idx) -> Self; } /// Create a `ChunkedArray` with new values by index or by boolean mask. @@ -380,14 +312,14 @@ pub trait ChunkQuantile { } /// Variance and standard deviation aggregation. -pub trait ChunkVar { +pub trait ChunkVar { /// Compute the variance of this ChunkedArray/Series. - fn var(&self, _ddof: u8) -> Option { + fn var(&self, _ddof: u8) -> Option { None } /// Compute the standard deviation of this ChunkedArray/Series. - fn std(&self, _ddof: u8) -> Option { + fn std(&self, _ddof: u8) -> Option { None } } @@ -449,12 +381,6 @@ pub trait ChunkUnique { fn n_unique(&self) -> PolarsResult { self.arg_unique().map(|v| v.len()) } - - /// The most occurring value(s). Can return multiple Values - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult> { - polars_bail!(opq = mode, T::get_dtype()); - } } #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)] @@ -589,10 +515,9 @@ macro_rules! impl_chunk_expand { }}; } -impl ChunkExpandAtIndex for ChunkedArray +impl ChunkExpandAtIndex for ChunkedArray where - ChunkedArray: ChunkFull + TakeRandom, - T: PolarsNumericType, + ChunkedArray: ChunkFull, { fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray { let mut out = impl_chunk_expand!(self, length, index); @@ -627,7 +552,7 @@ impl ChunkExpandAtIndex for BinaryChunked { impl ChunkExpandAtIndex for ListChunked { fn new_from_index(&self, index: usize, length: usize) -> ListChunked { - let opt_val = self.get(index); + let opt_val = self.get_as_series(index); match opt_val { Some(val) => { let mut ca = ListChunked::full(self.name(), &val, length); @@ -642,7 +567,7 @@ impl ChunkExpandAtIndex for ListChunked { #[cfg(feature = "dtype-array")] impl ChunkExpandAtIndex for ArrayChunked { fn new_from_index(&self, index: usize, length: usize) -> ArrayChunked { - let opt_val = self.get(index); + let opt_val = self.get_as_series(index); match opt_val { Some(val) => { let mut ca = ArrayChunked::full(self.name(), &val, length); @@ -699,50 +624,18 @@ pub trait ChunkApplyKernel { S: PolarsDataType; } -/// Find local minima/ maxima -pub trait ChunkPeaks { - /// Get a boolean mask of the local maximum peaks. - fn peak_max(&self) -> BooleanChunked { - unimplemented!() - } - - /// Get a boolean mask of the local minimum peaks. - fn peak_min(&self) -> BooleanChunked { - unimplemented!() - } -} - -/// Repeat the values `n` times. -#[cfg(feature = "repeat_by")] -pub trait RepeatBy { - /// Repeat the values `n` times, where `n` is determined by the values in `by`. 
- fn repeat_by(&self, _by: &IdxCa) -> PolarsResult { - unimplemented!() - } -} - -#[cfg(feature = "is_first")] +#[cfg(feature = "is_first_distinct")] /// Mask the first unique values as `true` -pub trait IsFirst { - fn is_first(&self) -> PolarsResult { - polars_bail!(opq = is_first, T::get_dtype()); +pub trait IsFirstDistinct { + fn is_first_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_first_distinct, T::get_dtype()); } } -#[cfg(feature = "is_last")] +#[cfg(feature = "is_last_distinct")] /// Mask the last unique values as `true` -pub trait IsLast { - fn is_last(&self) -> PolarsResult { - polars_bail!(opq = is_last, T::get_dtype()); +pub trait IsLastDistinct { + fn is_last_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_last_distinct, T::get_dtype()); } } - -#[cfg(feature = "concat_str")] -/// Concat the values into a string array. -pub trait StrConcat { - /// Concat the values into a string array. - /// # Arguments - /// - /// * `delimiter` - A string that will act as delimiter between values. - fn str_concat(&self, delimiter: &str) -> Utf8Chunked; -} diff --git a/crates/polars-core/src/chunked_array/ops/peaks.rs b/crates/polars-core/src/chunked_array/ops/peaks.rs deleted file mode 100644 index 24a9cae4a08a..000000000000 --- a/crates/polars-core/src/chunked_array/ops/peaks.rs +++ /dev/null @@ -1,20 +0,0 @@ -use num_traits::Zero; - -use crate::prelude::*; - -impl ChunkPeaks for ChunkedArray -where - T: PolarsNumericType, -{ - /// Get a boolean mask of the local maximum peaks. - fn peak_max(&self) -> BooleanChunked { - (self.shift_and_fill(1, Some(Zero::zero())).lt(self)) - & (self.shift_and_fill(-1, Some(Zero::zero())).lt(self)) - } - - /// Get a boolean mask of the local minimum peaks. - fn peak_min(&self) -> BooleanChunked { - (self.shift_and_fill(1, Some(Zero::zero())).gt(self)) - & (self.shift_and_fill(-1, Some(Zero::zero())).gt(self)) - } -} diff --git a/crates/polars-core/src/chunked_array/ops/repeat_by.rs b/crates/polars-core/src/chunked_array/ops/repeat_by.rs deleted file mode 100644 index 69e83a042879..000000000000 --- a/crates/polars-core/src/chunked_array/ops/repeat_by.rs +++ /dev/null @@ -1,148 +0,0 @@ -use arrow::array::ListArray; -use polars_arrow::array::ListFromIter; - -use super::RepeatBy; -use crate::prelude::*; - -type LargeListArray = ListArray; - -fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> { - polars_ensure!( - (length_srs == length_by) | (length_by == 1) | (length_srs == 1), - ComputeError: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}", - length_srs, length_by - ); - Ok(()) -} - -impl RepeatBy for ChunkedArray -where - T: PolarsNumericType, -{ - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - check_lengths(self.len(), by.len())?; - - match (self.len(), by.len()) { - (left_len, right_len) if left_len == right_len => { - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. 
- unsafe { - LargeListArray::from_iter_primitive_trusted_len( - iter, - T::get_dtype().to_arrow(), - ) - } - })) - }, - (_, 1) => self.repeat_by(&IdxCa::new( - self.name(), - std::iter::repeat(by.get(0).unwrap()) - .take(self.len()) - .collect::>(), - )), - (1, _) => { - let new_array = self.new_from_index(0, by.len()); - new_array.repeat_by(by) - }, - // we have already checked the length - _ => unreachable!(), - } - } -} - -impl RepeatBy for BooleanChunked { - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - check_lengths(self.len(), by.len())?; - - match (self.len(), by.len()) { - (left_len, right_len) if left_len == right_len => { - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_bool_trusted_len(iter) } - })) - }, - (_, 1) => self.repeat_by(&IdxCa::new( - self.name(), - std::iter::repeat(by.get(0).unwrap()) - .take(self.len()) - .collect::>(), - )), - (1, _) => { - let new_array = self.new_from_index(0, by.len()); - new_array.repeat_by(by) - }, - // we have already checked the length - _ => unreachable!(), - } - } -} -impl RepeatBy for Utf8Chunked { - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - // TODO! dispatch via binary. - check_lengths(self.len(), by.len())?; - - match (self.len(), by.len()) { - (left_len, right_len) if left_len == right_len => { - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) } - })) - }, - (_, 1) => self.repeat_by(&IdxCa::new( - self.name(), - std::iter::repeat(by.get(0).unwrap()) - .take(self.len()) - .collect::>(), - )), - (1, _) => { - let new_array = self.new_from_index(0, by.len()); - new_array.repeat_by(by) - }, - // we have already checked the length - _ => unreachable!(), - } - } -} - -impl RepeatBy for BinaryChunked { - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - check_lengths(self.len(), by.len())?; - - match (self.len(), by.len()) { - (left_len, right_len) if left_len == right_len => { - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) } - })) - }, - (_, 1) => self.repeat_by(&IdxCa::new( - self.name(), - std::iter::repeat(by.get(0).unwrap()) - .take(self.len()) - .collect::>(), - )), - (1, _) => { - let new_array = self.new_from_index(0, by.len()); - new_array.repeat_by(by) - }, - // we have already checked the length - _ => unreachable!(), - } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index 4f950a8b78d0..d658991058fa 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -82,8 +82,7 @@ impl ChunkReverse for ArrayChunked { #[cfg(feature = "object")] impl ChunkReverse for ObjectChunked { fn reverse(&self) -> Self { - // Safety - // we we know we don't get out of bounds - unsafe { self.take_unchecked((0..self.len()).rev().into()) } + // SAFETY: we know we don't go out of bounds. 
+ unsafe { self.take_unchecked(&(0..self.len() as IdxSize).rev().collect_ca("")) } } } diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index c87d7fce9130..337dee580b2a 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -120,7 +120,7 @@ impl CategoricalChunked { #[cfg(test)] mod test { use crate::prelude::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; fn assert_order(ca: &CategoricalChunked, cmp: &[&str]) { let s = ca.cast(&DataType::Utf8).unwrap(); @@ -133,9 +133,12 @@ mod test { let init = &["c", "b", "a", "d"]; let _lock = SINGLE_LOCK.lock(); - for toggle in [true, false] { - reset_string_cache(); - enable_string_cache(toggle); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + let s = Series::new("", init).cast(&DataType::Categorical(None))?; let ca = s.categorical()?; let mut ca_lexical = ca.clone(); @@ -157,13 +160,16 @@ mod test { } #[test] - fn test_cat_lexical_sort_multiple() -> PolarsResult<()> { let init = &["c", "b", "a", "a"]; let _lock = SINGLE_LOCK.lock(); - for enable in [true, false] { - enable_string_cache(enable); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + let s = Series::new("", init).cast(&DataType::Categorical(None))?; let ca = s.categorical()?; let mut ca_lexical: CategoricalChunked = ca.clone(); diff --git a/crates/polars-core/src/chunked_array/ops/take/mod.rs b/crates/polars-core/src/chunked_array/ops/take/mod.rs index 202d38236473..ccb11d118ba3 100644 --- a/crates/polars-core/src/chunked_array/ops/take/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/take/mod.rs @@ -1,594 +1,9 @@ //! Traits to provide fast Random access to ChunkedArrays data. //! This prevents downcasting every iteration. -//! IntoTakeRandom provides structs that implement the TakeRandom trait. -//! There are several structs that implement the fastest path for random access. -//! -use std::borrow::Cow; -use polars_arrow::compute::take::*; -pub use take_random::*; -pub use traits::*; - -use crate::chunked_array::kernels::take::*; use crate::prelude::*; use crate::utils::NoNull; mod take_chunked; -mod take_every; -pub(crate) mod take_random; -pub(crate) mod take_single; -mod traits; #[cfg(feature = "chunked_ids")] pub(crate) use take_chunked::*; - -macro_rules! take_iter_n_chunks { - ($ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices.into_iter().map(|idx| taker.get(idx)).collect() - }}; -} - -macro_rules! take_opt_iter_n_chunks { - ($ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices - .into_iter() - .map(|opt_idx| opt_idx.and_then(|idx| taker.get(idx))) - .collect() - }}; -} - -macro_rules! take_iter_n_chunks_unchecked { - ($ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices - .into_iter() - .map(|idx| taker.get_unchecked(idx)) - .collect() - }}; -} - -macro_rules! 
take_opt_iter_n_chunks_unchecked { - ($ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices - .into_iter() - .map(|opt_idx| opt_idx.and_then(|idx| taker.get_unchecked(idx))) - .collect() - }}; -} - -impl ChunkedArray -where - T: PolarsDataType, -{ - fn finish_from_array(&self, array: Box) -> Self { - let keep_fast_explode = array.null_count() == 0; - unsafe { self.copy_with_chunks(vec![array], false, keep_fast_explode) } - } -} - -impl ChunkTake for ChunkedArray -where - T: PolarsNumericType, -{ - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - let array = match (self.null_count(), self.chunks.len()) { - (0, 1) => { - take_no_null_primitive_unchecked::(chunks.next().unwrap(), array) - as ArrayRef - }, - (_, 1) => take_primitive_unchecked::(chunks.next().unwrap(), array) - as ArrayRef, - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca = take_primitive_iter_n_chunks(self, iter); - ca.rename(self.name()); - ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca = take_primitive_opt_iter_n_chunks(self, iter); - ca.rename(self.name()); - ca - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => take_no_null_primitive_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - (_, 1) => { - take_primitive_iter_unchecked::(chunks.next().unwrap(), iter) - as ArrayRef - }, - _ => { - let mut ca = take_primitive_iter_n_chunks(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => take_no_null_primitive_opt_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - (_, 1) => take_primitive_opt_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - _ => { - let mut ca = take_primitive_opt_iter_n_chunks(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -impl ChunkTake for BooleanChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - let array = match self.chunks.len() { - 1 => take::take_unchecked(chunks.next().unwrap(), array), - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca: BooleanChunked = take_iter_n_chunks!(self, iter); - ca.rename(self.name()); - 
ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca: BooleanChunked = take_opt_iter_n_chunks!(self, iter); - ca.rename(self.name()); - ca - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_bool_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - (_, 1) => take_bool_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef, - _ => { - let mut ca: BooleanChunked = take_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_bool_opt_iter_unchecked(chunks.next().unwrap(), iter) - as ArrayRef - }, - (_, 1) => { - take_bool_opt_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - _ => { - let mut ca: BooleanChunked = take_opt_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -impl ChunkTake for Utf8Chunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - self.as_binary().take_unchecked(indices).to_utf8() - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let out = self.as_binary().take(indices)?; - Ok(unsafe { out.to_utf8() }) - } -} - -impl ChunkTake for BinaryChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - let array = match self.chunks.len() { - 1 => take_binary_unchecked(chunks.next().unwrap(), array) as ArrayRef, - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca: BinaryChunked = take_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca: BinaryChunked = - take_opt_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - ca - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_binary_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - (_, 1) => take_binary_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef, - _ => { - let mut ca: BinaryChunked = take_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - let array = match (self.has_validity(), self.chunks.len()) { 
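// Hedged usage sketch of the gather API introduced in gather.rs above, which
// replaces the TakeIdx-based ChunkTake being removed in this hunk. The take
// signatures come from this diff; the prelude exports and exact import path
// are assumptions and may differ per polars-core version.
use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let ca = Int32Chunked::from_slice("a", &[10, 20, 30, 40]);

    // Indices as a plain Vec<IdxSize>: take() bounds-checks, then defers to
    // the unchecked gather.
    let idx: Vec<IdxSize> = vec![3, 0, 0];
    let taken = ca.take(&idx)?;
    let vals: Vec<Option<i32>> = taken.into_iter().collect();
    assert_eq!(vals, vec![Some(40), Some(10), Some(10)]);

    // Indices as an IdxCa: a null index produces a null row.
    let idx = IdxCa::new("idx", &[Some(1 as IdxSize), None]);
    let taken = ca.take(&idx)?;
    let vals: Vec<Option<i32>> = taken.into_iter().collect();
    assert_eq!(vals, vec![Some(20), None]);
    Ok(())
}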
- (false, 1) => { - take_no_null_binary_opt_iter_unchecked(chunks.next().unwrap(), iter) - as ArrayRef - }, - (_, 1) => { - take_binary_opt_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - _ => { - let mut ca: BinaryChunked = take_opt_iter_n_chunks_unchecked!(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -impl ChunkTake for ListChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let ca_self = if self.is_nested() { - Cow::Owned(self.rechunk()) - } else { - Cow::Borrowed(self) - }; - let mut chunks = ca_self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null_with_dtype( - self.name(), - array.len(), - &self.inner_dtype(), - ); - } - let array = match ca_self.chunks.len() { - 1 => Box::new(take_list_unchecked(chunks.next().unwrap(), array)) as ArrayRef, - _ => { - if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca: ListChunked = - take_iter_n_chunks_unchecked!(ca_self.as_ref(), iter); - ca.chunks.pop().unwrap() - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca: ListChunked = - take_opt_iter_n_chunks_unchecked!(ca_self.as_ref(), iter); - ca.chunks.pop().unwrap() - } - }, - }; - self.finish_from_array(array) - }, - // todo! fast path for single chunk - TakeIdx::Iter(iter) => { - if ca_self.chunks.len() == 1 { - let idx: NoNull = iter.map(|v| v as IdxSize).collect(); - ca_self.take_unchecked((&idx.into_inner()).into()) - } else { - let mut ca: ListChunked = take_iter_n_chunks_unchecked!(ca_self.as_ref(), iter); - self.finish_from_array(ca.chunks.pop().unwrap()) - } - }, - TakeIdx::IterNulls(iter) => { - if ca_self.chunks.len() == 1 { - let idx: IdxCa = iter.map(|v| v.map(|v| v as IdxSize)).collect(); - ca_self.take_unchecked((&idx).into()) - } else { - let mut ca: ListChunked = - take_opt_iter_n_chunks_unchecked!(ca_self.as_ref(), iter); - self.finish_from_array(ca.chunks.pop().unwrap()) - } - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -#[cfg(feature = "dtype-array")] -impl ChunkTake for ArrayChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let ca_self = self.rechunk(); - match indices { - TakeIdx::Array(idx_array) => { - if idx_array.null_count() == idx_array.len() { - return Self::full_null_with_dtype( - self.name(), - idx_array.len(), - &self.inner_dtype(), - ca_self.width(), - ); - } - let arr = self.chunks[0].as_ref(); - let arr = take_unchecked(arr, idx_array); - self.finish_from_array(arr) - }, - TakeIdx::Iter(iter) => { - let idx: NoNull = iter.map(|v| v as IdxSize).collect(); - ca_self.take_unchecked((&idx.into_inner()).into()) - }, - TakeIdx::IterNulls(iter) => { - let idx: IdxCa = iter.map(|v| v.map(|v| v as 
IdxSize)).collect(); - ca_self.take_unchecked((&idx).into()) - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -#[cfg(feature = "object")] -impl ChunkTake for ObjectChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - // current implementation is suboptimal, every iterator is allocated to UInt32Array - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - - match self.chunks.len() { - 1 => { - let values = self.downcast_chunks().get(0).unwrap().values(); - - let mut ca: Self = array - .into_iter() - .map(|opt_idx| { - opt_idx.map(|idx| values.get_unchecked(*idx as usize).clone()) - }) - .collect(); - ca.rename(self.name()); - ca - }, - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - - let taker = self.take_rand(); - let mut ca: ObjectChunked = - iter.map(|idx| taker.get_unchecked(idx).cloned()).collect(); - ca.rename(self.name()); - ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let taker = self.take_rand(); - - let mut ca: ObjectChunked = iter - .map(|opt_idx| { - opt_idx.and_then(|idx| taker.get_unchecked(idx).cloned()) - }) - .collect(); - - ca.rename(self.name()); - ca - } - }, - } - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - - let taker = self.take_rand(); - let mut ca: ObjectChunked = - iter.map(|idx| taker.get_unchecked(idx).cloned()).collect(); - ca.rename(self.name()); - ca - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let taker = self.take_rand(); - - let mut ca: ObjectChunked = iter - .map(|opt_idx| opt_idx.and_then(|idx| taker.get(idx).cloned())) - .collect(); - - ca.rename(self.name()); - ca - }, - } - } - - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // Safety: - // just checked bounds - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_take_random() { - let ca = Int32Chunked::from_slice("a", &[1, 2, 3]); - assert_eq!(ca.get(0), Some(1)); - assert_eq!(ca.get(1), Some(2)); - assert_eq!(ca.get(2), Some(3)); - - let ca = Utf8Chunked::from_slice("a", &["a", "b", "c"]); - assert_eq!(ca.get(0), Some("a")); - assert_eq!(ca.get(1), Some("b")); - assert_eq!(ca.get(2), Some("c")); - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_every.rs b/crates/polars-core/src/chunked_array/ops/take/take_every.rs deleted file mode 100644 index 6401eb2f133d..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_every.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::prelude::*; - -impl Series { - /// Traverse and collect every nth element in a new array. 
- pub fn take_every(&self, n: usize) -> Series { - let mut idx = (0..self.len()).step_by(n); - - // safety: we are in bounds - unsafe { self.take_iter_unchecked(&mut idx) } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_random.rs b/crates/polars-core/src/chunked_array/ops/take/take_random.rs deleted file mode 100644 index 43feaca04576..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_random.rs +++ /dev/null @@ -1,686 +0,0 @@ -use arrow::array::{Array, BooleanArray, ListArray, PrimitiveArray, Utf8Array}; -use arrow::bitmap::utils::get_bit_unchecked; -use arrow::bitmap::Bitmap; -use polars_arrow::is_valid::*; - -#[cfg(feature = "object")] -use crate::chunked_array::object::ObjectArray; -use crate::prelude::downcast::Chunks; -use crate::prelude::*; - -macro_rules! take_random_get { - ($self:ident, $index:ident) => {{ - let (chunk_idx, arr_idx) = crate::utils::index_to_chunked_index( - $self.chunk_lens.iter().copied(), - $index as IdxSize, - ); - - // Safety: - // bounds are checked above - let arr = unsafe { $self.chunks.get_unchecked(chunk_idx as usize) }; - - if arr.is_null(arr_idx as usize) { - None - } else { - // SAFETY: - // bounds checked above - unsafe { Some(arr.value_unchecked(arr_idx as usize)) } - } - }}; -} - -macro_rules! take_random_get_unchecked { - ($self:ident, $index:ident) => {{ - let (chunk_idx, arr_idx) = crate::utils::index_to_chunked_index( - $self.chunk_lens.iter().copied(), - $index as IdxSize, - ); - - // Safety: - // bounds are checked above - let arr = $self.chunks.get_unchecked(chunk_idx as usize); - - if arr.is_null_unchecked(arr_idx as usize) { - None - } else { - // SAFETY: - // bounds checked above - Some(arr.value_unchecked(arr_idx as usize)) - } - }}; -} - -macro_rules! take_random_get_single { - ($self:ident, $index:ident) => {{ - if $self.arr.is_null($index) { - None - } else { - // Safety: - // bound checked above - unsafe { Some($self.arr.value_unchecked($index)) } - } - }}; -} - -/// Create a type that implements a faster `TakeRandom`. -pub trait IntoTakeRandom<'a> { - type Item; - type TakeRandom; - /// Create a type that implements `TakeRandom`. 
- fn take_rand(&self) -> Self::TakeRandom; -} - -pub enum TakeRandBranch3 { - SingleNoNull(N), - Single(S), - Multi(M), -} - -impl TakeRandom for TakeRandBranch3 -where - N: TakeRandom, - S: TakeRandom, - M: TakeRandom, -{ - type Item = I; - - #[inline] - fn get(&self, index: usize) -> Option { - match self { - Self::SingleNoNull(s) => s.get(index), - Self::Single(s) => s.get(index), - Self::Multi(m) => m.get(index), - } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - match self { - Self::SingleNoNull(s) => s.get_unchecked(index), - Self::Single(s) => s.get_unchecked(index), - Self::Multi(m) => m.get_unchecked(index), - } - } - - fn last(&self) -> Option { - match self { - Self::SingleNoNull(s) => s.last(), - Self::Single(s) => s.last(), - Self::Multi(m) => m.last(), - } - } -} - -pub enum TakeRandBranch2 { - Single(S), - Multi(M), -} - -impl TakeRandom for TakeRandBranch2 -where - S: TakeRandom, - M: TakeRandom, -{ - type Item = I; - - fn get(&self, index: usize) -> Option { - match self { - Self::Single(s) => s.get(index), - Self::Multi(m) => m.get(index), - } - } - - unsafe fn get_unchecked(&self, index: usize) -> Option { - match self { - Self::Single(s) => s.get_unchecked(index), - Self::Multi(m) => m.get_unchecked(index), - } - } - fn last(&self) -> Option { - match self { - Self::Single(s) => s.last(), - Self::Multi(m) => m.last(), - } - } -} - -#[allow(clippy::type_complexity)] -impl<'a, T> IntoTakeRandom<'a> for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - type TakeRandom = TakeRandBranch3< - NumTakeRandomCont<'a, T::Native>, - NumTakeRandomSingleChunk<'a, T::Native>, - NumTakeRandomChunked<'a, T::Native>, - >; - - #[inline] - fn take_rand(&self) -> Self::TakeRandom { - let mut chunks = self.downcast_iter(); - - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - let t = NumTakeRandomCont { - slice: arr.values(), - }; - TakeRandBranch3::SingleNoNull(t) - } else { - let t = NumTakeRandomSingleChunk::new(arr); - TakeRandBranch3::Single(t) - } - } else { - let t = NumTakeRandomChunked { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch3::Multi(t) - } - } -} - -pub struct Utf8TakeRandom<'a> { - pub(crate) chunks: Chunks<'a, Utf8Array>, - pub(crate) chunk_lens: Vec, -} - -impl<'a> TakeRandom for Utf8TakeRandom<'a> { - type Item = &'a str; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - take_random_get_unchecked!(self, index) - } - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -pub struct Utf8TakeRandomSingleChunk<'a> { - pub(crate) arr: &'a Utf8Array, -} - -impl<'a> TakeRandom for Utf8TakeRandomSingleChunk<'a> { - type Item = &'a str; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.arr.is_valid_unchecked(index) { - Some(self.arr.value_unchecked(index)) - } else { - None - } - } - fn last(&self) -> Option { - self.get(self.arr.len().saturating_sub(1)) - } -} - -impl<'a> IntoTakeRandom<'a> for &'a Utf8Chunked { - type Item = &'a str; - type TakeRandom = TakeRandBranch2, Utf8TakeRandom<'a>>; - - fn take_rand(&self) -> Self::TakeRandom { - match self.chunks.len() { - 1 => { - let arr = 
self.downcast_iter().next().unwrap(); - let t = Utf8TakeRandomSingleChunk { arr }; - TakeRandBranch2::Single(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = Utf8TakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch2::Multi(t) - }, - } - } -} - -pub struct BinaryTakeRandom<'a> { - pub(crate) chunks: Chunks<'a, BinaryArray>, - pub(crate) chunk_lens: Vec, -} - -impl<'a> TakeRandom for BinaryTakeRandom<'a> { - type Item = &'a [u8]; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - take_random_get_unchecked!(self, index) - } - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -pub struct BinaryTakeRandomSingleChunk<'a> { - pub(crate) arr: &'a BinaryArray, -} - -impl<'a> TakeRandom for BinaryTakeRandomSingleChunk<'a> { - type Item = &'a [u8]; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.arr.is_valid_unchecked(index) { - Some(self.arr.value_unchecked(index)) - } else { - None - } - } - fn last(&self) -> Option { - self.get(self.arr.len().saturating_sub(1)) - } -} - -impl<'a> IntoTakeRandom<'a> for &'a BinaryChunked { - type Item = &'a [u8]; - type TakeRandom = TakeRandBranch2, BinaryTakeRandom<'a>>; - - fn take_rand(&self) -> Self::TakeRandom { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BinaryTakeRandomSingleChunk { arr }; - TakeRandBranch2::Single(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BinaryTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch2::Multi(t) - }, - } - } -} - -impl<'a> IntoTakeRandom<'a> for &'a BooleanChunked { - type Item = bool; - type TakeRandom = TakeRandBranch2, BoolTakeRandom<'a>>; - - fn take_rand(&self) -> Self::TakeRandom { - match self.chunks.len() { - 1 => { - let arr = self.downcast_iter().next().unwrap(); - let t = BoolTakeRandomSingleChunk { arr }; - TakeRandBranch2::Single(t) - }, - _ => { - let chunks = self.downcast_chunks(); - let t = BoolTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch2::Multi(t) - }, - } - } -} - -impl<'a> IntoTakeRandom<'a> for &'a ListChunked { - type Item = Series; - type TakeRandom = TakeRandBranch2, ListTakeRandom<'a>>; - - fn take_rand(&self) -> Self::TakeRandom { - let mut chunks = self.downcast_iter(); - if self.chunks.len() == 1 { - let t = ListTakeRandomSingleChunk { - arr: chunks.next().unwrap(), - name: self.name(), - }; - TakeRandBranch2::Single(t) - } else { - let name = self.name(); - let inner_type = self.inner_dtype().to_physical(); - let t = ListTakeRandom { - name, - inner_type, - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch2::Multi(t) - } - } -} - -pub struct NumTakeRandomChunked<'a, T> -where - T: NumericNative, -{ - pub(crate) chunks: Vec<&'a PrimitiveArray>, - pub(crate) chunk_lens: Vec, -} - -impl<'a, T> TakeRandom for NumTakeRandomChunked<'a, T> -where - T: NumericNative, -{ - type Item = T; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> 
Option { - take_random_get_unchecked!(self, index) - } - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -pub struct NumTakeRandomCont<'a, T> { - pub(crate) slice: &'a [T], -} - -impl<'a, T> TakeRandom for NumTakeRandomCont<'a, T> -where - T: Copy, -{ - type Item = T; - - #[inline] - fn get(&self, index: usize) -> Option { - self.slice.get(index).copied() - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - Some(*self.slice.get_unchecked(index)) - } - fn last(&self) -> Option { - self.slice.last().copied() - } -} - -pub struct TakeRandomBitmap<'a> { - bytes: &'a [u8], - offset: usize, -} - -impl<'a> TakeRandomBitmap<'a> { - pub(crate) fn new(bitmap: &'a Bitmap) -> Self { - let (bytes, offset, _) = bitmap.as_slice(); - Self { bytes, offset } - } - - unsafe fn get_unchecked(&self, index: usize) -> bool { - get_bit_unchecked(self.bytes, self.offset + index) - } -} - -pub struct NumTakeRandomSingleChunk<'a, T> -where - T: NumericNative, -{ - pub(crate) vals: &'a [T], - pub(crate) validity: TakeRandomBitmap<'a>, -} - -impl<'a, T: NumericNative> NumTakeRandomSingleChunk<'a, T> { - pub(crate) fn new(arr: &'a PrimitiveArray) -> Self { - let validity = TakeRandomBitmap::new(arr.validity().unwrap()); - let vals = arr.values(); - NumTakeRandomSingleChunk { vals, validity } - } -} - -impl<'a, T> TakeRandom for NumTakeRandomSingleChunk<'a, T> -where - T: NumericNative, -{ - type Item = T; - - #[inline] - fn get(&self, index: usize) -> Option { - if index < self.vals.len() { - unsafe { self.get_unchecked(index) } - } else { - None - } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.validity.get_unchecked(index) { - Some(*self.vals.get_unchecked(index)) - } else { - None - } - } - fn last(&self) -> Option { - self.get(self.vals.len().saturating_sub(1)) - } -} - -pub struct BoolTakeRandom<'a> { - pub(crate) chunks: Chunks<'a, BooleanArray>, - pub(crate) chunk_lens: Vec, -} - -impl<'a> TakeRandom for BoolTakeRandom<'a> { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - take_random_get_unchecked!(self, index) - } - - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -pub struct BoolTakeRandomSingleChunk<'a> { - pub(crate) arr: &'a BooleanArray, -} - -impl<'a> TakeRandom for BoolTakeRandomSingleChunk<'a> { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.arr.is_valid_unchecked(index) { - Some(self.arr.value_unchecked(index)) - } else { - None - } - } - fn last(&self) -> Option { - self.arr.get(self.arr.len().saturating_sub(1)) - } -} - -pub struct ListTakeRandom<'a> { - pub(crate) inner_type: DataType, - pub(crate) name: &'a str, - pub(crate) chunks: Vec<&'a ListArray>, - pub(crate) chunk_lens: Vec, -} - -impl<'a> TakeRandom for ListTakeRandom<'a> { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - let v = take_random_get!(self, index); - v.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked(self.name, vec![arr], &self.inner_type) - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let v = take_random_get_unchecked!(self, index); - v.map(|arr| unsafe { - 
Series::from_chunks_and_dtype_unchecked(self.name, vec![arr], &self.inner_type) - }) - } - fn last(&self) -> Option { - self.chunks.last().and_then(|arr| { - let arr = arr.get(arr.len().saturating_sub(1)); - arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name, - vec![arr.to_boxed()], - &self.inner_type, - ) - }) - }) - } -} - -pub struct ListTakeRandomSingleChunk<'a> { - pub(crate) arr: &'a ListArray, - pub(crate) name: &'a str, -} - -impl<'a> TakeRandom for ListTakeRandomSingleChunk<'a> { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - let v = take_random_get_single!(self, index); - v.map(|v| { - let s = Series::try_from((self.name, v)); - s.unwrap() - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.arr.is_valid_unchecked(index) { - let v = self.arr.value_unchecked(index); - let s = Series::try_from((self.name, v)); - s.ok() - } else { - None - } - } - fn last(&self) -> Option { - self.get(self.arr.len().saturating_sub(1)) - } -} - -#[cfg(feature = "object")] -pub struct ObjectTakeRandom<'a, T: PolarsObject> { - pub(crate) chunks: Chunks<'a, ObjectArray>, - pub(crate) chunk_lens: Vec, -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> TakeRandom for ObjectTakeRandom<'a, T> { - type Item = &'a T; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - take_random_get_unchecked!(self, index) - } - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -#[cfg(feature = "object")] -pub struct ObjectTakeRandomSingleChunk<'a, T: PolarsObject> { - pub(crate) arr: &'a ObjectArray, -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> TakeRandom for ObjectTakeRandomSingleChunk<'a, T> { - type Item = &'a T; - - #[inline] - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - if self.arr.is_valid_unchecked(index) { - Some(self.arr.value_unchecked(index)) - } else { - None - } - } - fn last(&self) -> Option { - self.arr.get(self.arr.len().saturating_sub(1)) - } -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> IntoTakeRandom<'a> for &'a ObjectChunked { - type Item = &'a T; - type TakeRandom = TakeRandBranch2, ObjectTakeRandom<'a, T>>; - - fn take_rand(&self) -> Self::TakeRandom { - let chunks = self.downcast_chunks(); - if self.chunks.len() == 1 { - let t = ObjectTakeRandomSingleChunk { - arr: chunks.get(0).unwrap(), - }; - TakeRandBranch2::Single(t) - } else { - let t = ObjectTakeRandom { - chunks, - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch2::Multi(t) - } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_single.rs b/crates/polars-core/src/chunked_array/ops/take/take_single.rs deleted file mode 100644 index b5494d337471..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_single.rs +++ /dev/null @@ -1,338 +0,0 @@ -use arrow::array::*; -use polars_arrow::is_valid::IsValid; - -#[cfg(feature = "object")] -use crate::chunked_array::object::ObjectArray; -use crate::prelude::*; - -macro_rules! 
impl_take_random_get { - ($self:ident, $index:ident, $array_type:ty) => {{ - assert!($index < $self.len()); - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - // Safety: - // bounds are checked above - let arr = $self.chunks.get_unchecked(chunk_idx); - - // Safety: - // caller should give right array type - let arr = &*(arr as *const ArrayRef as *const Box<$array_type>); - - // Safety: - // index should be in bounds - if arr.is_valid(idx) { - Some(arr.value_unchecked(idx)) - } else { - None - } - }}; -} - -macro_rules! impl_take_random_get_unchecked { - ($self:ident, $index:ident, $array_type:ty) => {{ - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - debug_assert!(chunk_idx < $self.chunks.len()); - // Safety: - // bounds are checked above - let arr = $self.chunks.get_unchecked(chunk_idx); - - // Safety: - // caller should give right array type - let arr = &*(&**arr as *const dyn Array as *const $array_type); - - // Safety: - // index should be in bounds - debug_assert!(idx < arr.len()); - if arr.is_valid_unchecked(idx) { - Some(arr.value_unchecked(idx)) - } else { - None - } - }}; -} - -impl TakeRandom for ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - - #[inline] - fn get(&self, index: usize) -> Option { - unsafe { impl_take_random_get!(self, index, PrimitiveArray) } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, PrimitiveArray) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a, T> TakeRandom for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - - #[inline] - fn get(&self, index: usize) -> Option { - (*self).get(index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - (*self).get_unchecked(index) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl TakeRandom for BooleanChunked { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, BooleanArray) } - } - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, BooleanArray) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a BooleanChunked { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - (*self).get(index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - (*self).get_unchecked(index) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a Utf8Chunked { - type Item = &'a str; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeStringArray) } - } - fn last(&self) 
-> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a BinaryChunked { - type Item = &'a [u8]; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeBinaryArray) } - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -// extra trait such that it also works without extra reference. -// Autoref will insert the reference and -impl<'a> TakeRandomUtf8 for &'a Utf8Chunked { - type Item = &'a str; - - #[inline] - fn get(self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeStringArray) } - } - - #[inline] - unsafe fn get_unchecked(self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, LargeStringArray) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> TakeRandom for &'a ObjectChunked { - type Item = &'a T; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, ObjectArray) } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, ObjectArray) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl TakeRandom for ListChunked { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - let opt_arr = unsafe { impl_take_random_get!(self, index, LargeListArray) }; - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let opt_arr = impl_take_random_get_unchecked!(self, index, LargeListArray); - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1).map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } else { - None - } - } -} - -#[cfg(feature = "dtype-array")] -impl TakeRandom for ArrayChunked { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - let opt_arr = unsafe { impl_take_random_get!(self, index, FixedSizeListArray) }; - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - 
&self.inner_dtype().to_physical(), - ) - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let opt_arr = impl_take_random_get_unchecked!(self, index, FixedSizeListArray); - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1).map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } else { - None - } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/traits.rs b/crates/polars-core/src/chunked_array/ops/take/traits.rs deleted file mode 100644 index 818681f831e6..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/traits.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Traits that indicate the allowed arguments in a ChunkedArray::take operation. -use crate::frame::group_by::GroupsProxyIter; -use crate::prelude::*; - -// Utility traits -pub trait TakeIterator: Iterator + TrustedLen { - fn check_bounds(&self, bound: usize) -> PolarsResult<()>; - // a sort of clone - fn boxed_clone(&self) -> Box; -} -pub trait TakeIteratorNulls: Iterator> + TrustedLen { - fn check_bounds(&self, bound: usize) -> PolarsResult<()>; - - fn boxed_clone(&self) -> Box; -} - -unsafe impl TrustedLen for &mut dyn TakeIterator {} -unsafe impl TrustedLen for &mut dyn TakeIteratorNulls {} -unsafe impl TrustedLen for GroupsProxyIter<'_> {} - -// Implement for the ref as well -impl TakeIterator for &mut dyn TakeIterator { - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - (**self).check_bounds(bound) - } - - fn boxed_clone(&self) -> Box { - (**self).boxed_clone() - } -} -impl TakeIteratorNulls for &mut dyn TakeIteratorNulls { - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - (**self).check_bounds(bound) - } - - fn boxed_clone(&self) -> Box { - (**self).boxed_clone() - } -} - -// Clonable iterators may implement the traits above -impl TakeIterator for I -where - I: Iterator + Clone + Sized + TrustedLen, -{ - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - // clone so that the iterator can be used again. - let iter = self.clone(); - let mut inbounds = true; - - for i in iter { - if i >= bound { - // we will not break here as that prevents SIMD - inbounds = false; - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - } - - fn boxed_clone(&self) -> Box { - Box::new(self.clone()) - } -} -impl TakeIteratorNulls for I -where - I: Iterator> + Clone + Sized + TrustedLen, -{ - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - // clone so that the iterator can be used again. 
- let iter = self.clone(); - let mut inbounds = true; - - for i in iter.flatten() { - if i >= bound { - // we will not break here as that prevents SIMD - inbounds = false; - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - } - - fn boxed_clone(&self) -> Box { - Box::new(self.clone()) - } -} - -/// One of the three arguments allowed in unchecked_take -pub enum TakeIdx<'a, I, INulls> -where - I: TakeIterator, - INulls: TakeIteratorNulls, -{ - Array(&'a IdxArr), - Iter(I), - // will return a null where None - IterNulls(INulls), -} - -impl<'a, I, INulls> TakeIdx<'a, I, INulls> -where - I: TakeIterator, - INulls: TakeIteratorNulls, -{ - pub(crate) fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - match self { - TakeIdx::Iter(i) => i.check_bounds(bound), - TakeIdx::IterNulls(i) => i.check_bounds(bound), - TakeIdx::Array(arr) => { - let values = arr.values().as_slice(); - let mut inbounds = true; - let len = bound as IdxSize; - if arr.null_count() == 0 { - for &i in values { - // we will not break here as that prevents SIMD - if i >= len { - inbounds = false; - } - } - } else { - for opt_v in *arr { - match opt_v { - Some(&v) if v >= len => { - inbounds = false; - }, - _ => {}, - } - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - }, - } - } -} - -/// Dummy type, we need to instantiate all generic types, so we fill one with a dummy. -pub type Dummy = std::iter::Once; - -// Below the conversions from -// * UInt32Chunked -// * Iterator -// * Iterator> -// -// To the checked and unchecked TakeIdx enums - -// Unchecked conversions - -/// Conversion from UInt32Chunked to Unchecked TakeIdx -impl<'a> From<&'a IdxCa> for TakeIdx<'a, Dummy, Dummy>> { - fn from(ca: &'a IdxCa) -> Self { - if ca.chunks.len() == 1 { - TakeIdx::Array(ca.downcast_iter().next().unwrap()) - } else { - panic!("implementation error, should be transformed to an iterator by the caller") - } - } -} - -/// Conversion from Iterator to Unchecked TakeIdx -impl<'a, I> From for TakeIdx<'a, I, Dummy>> -where - I: TakeIterator, -{ - fn from(iter: I) -> Self { - TakeIdx::Iter(iter) - } -} - -/// Conversion from [`Iterator>`] to Unchecked [`TakeIdx`] -impl<'a, I> From for TakeIdx<'a, Dummy, I> -where - I: TakeIteratorNulls, -{ - fn from(iter: I) -> Self { - TakeIdx::IterNulls(iter) - } -} - -#[inline] -fn to_usize(idx: &IdxSize) -> usize { - *idx as usize -} - -/// Conversion from `&[IdxSize]` to Unchecked TakeIdx -impl<'a> From<&'a [IdxSize]> - for TakeIdx< - 'a, - std::iter::Map, fn(&IdxSize) -> usize>, - Dummy>, - > -{ - fn from(slice: &'a [IdxSize]) -> Self { - TakeIdx::Iter(slice.iter().map(to_usize)) - } -} - -/// Conversion from `&[IdxSize]` to Unchecked TakeIdx -impl<'a> From<&'a Vec> - for TakeIdx< - 'a, - std::iter::Map, fn(&IdxSize) -> usize>, - Dummy>, - > -{ - fn from(slice: &'a Vec) -> Self { - (&**slice).into() - } -} diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index 217e7a5494b0..baed88993f5c 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "rank")] -pub(crate) mod rank; - use std::hash::Hash; use arrow::bitmap::MutableBitmap; @@ -8,10 +5,8 @@ use arrow::bitmap::MutableBitmap; #[cfg(feature = "object")] use crate::datatypes::ObjectType; use crate::datatypes::PlHashSet; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; use 
crate::frame::group_by::GroupsProxy; -#[cfg(feature = "mode")] -use crate::frame::group_by::IntoGroupsProxy; +use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::series::IsSorted; @@ -77,54 +72,6 @@ where unique } -#[cfg(feature = "mode")] -fn mode_indices(groups: GroupsProxy) -> Vec { - match groups { - GroupsProxy::Idx(groups) => { - let mut groups = groups.into_iter().collect_trusted::>(); - groups.sort_unstable_by_key(|k| k.1.len()); - let last = &groups.last().unwrap(); - let max_occur = last.1.len(); - groups - .iter() - .rev() - .take_while(|v| v.1.len() == max_occur) - .map(|v| v.0) - .collect() - }, - GroupsProxy::Slice { groups, .. } => { - let last = groups.last().unwrap(); - let max_occur = last[1]; - - groups - .iter() - .rev() - .take_while(|v| { - let len = v[1]; - len == max_occur - }) - .map(|v| v[0]) - .collect() - }, - } -} - -#[cfg(feature = "mode")] -fn mode(ca: &ChunkedArray) -> ChunkedArray -where - ChunkedArray: IntoGroupsProxy + ChunkTake, -{ - if ca.is_empty() { - return ca.clone(); - } - let groups = ca.group_tuples(true, false).unwrap(); - let idx = mode_indices(groups); - - // Safety: - // group indices are in bounds - unsafe { ca.take_unchecked(idx.as_slice().into()) } -} - macro_rules! arg_unique_ca { ($ca:expr) => {{ match $ca.has_validity() { @@ -222,11 +169,6 @@ where }, } } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(mode(self)) - } } impl ChunkUnique for Utf8Chunked { @@ -242,12 +184,6 @@ impl ChunkUnique for Utf8Chunked { fn n_unique(&self) -> PolarsResult { self.as_binary().n_unique() } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - let out = self.as_binary().mode()?; - Ok(unsafe { out.to_utf8() }) - } } impl ChunkUnique for BinaryChunked { @@ -255,7 +191,7 @@ impl ChunkUnique for BinaryChunked { match self.null_count() { 0 => { let mut set = - PlHashSet::with_capacity(std::cmp::min(HASHMAP_INIT_SIZE, self.len())); + PlHashSet::with_capacity(std::cmp::min(_HASHMAP_INIT_SIZE, self.len())); for arr in self.downcast_iter() { set.extend(arr.values_iter()) } @@ -266,7 +202,7 @@ impl ChunkUnique for BinaryChunked { }, _ => { let mut set = - PlHashSet::with_capacity(std::cmp::min(HASHMAP_INIT_SIZE, self.len())); + PlHashSet::with_capacity(std::cmp::min(_HASHMAP_INIT_SIZE, self.len())); for arr in self.downcast_iter() { set.extend(arr.iter()) } @@ -296,11 +232,6 @@ impl ChunkUnique for BinaryChunked { Ok(set.len()) } } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(mode(self)) - } } impl ChunkUnique for BooleanChunked { @@ -333,12 +264,6 @@ impl ChunkUnique for Float32Chunked { fn arg_unique(&self) -> PolarsResult { self.bit_repr_small().arg_unique() } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult> { - let s = self.apply_as_ints(|v| v.mode().unwrap()); - let ca = s.f32().unwrap().clone(); - Ok(ca) - } } impl ChunkUnique for Float64Chunked { @@ -351,12 +276,6 @@ impl ChunkUnique for Float64Chunked { fn arg_unique(&self) -> PolarsResult { self.bit_repr_large().arg_unique() } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult> { - let s = self.apply_as_ints(|v| v.mode().unwrap()); - let ca = s.f64().unwrap().clone(); - Ok(ca) - } } #[cfg(test)] @@ -395,22 +314,4 @@ mod test { vec![Some(0), Some(1), Some(4)] ); } - - #[test] - #[cfg(feature = "mode")] - fn mode() { - let ca = Int32Chunked::from_slice("a", &[0, 1, 2, 3, 4, 4, 5, 6, 5, 0]); - let mut result = Vec::from(&ca.mode().unwrap()); - result.sort_by_key(|a| a.unwrap()); - assert_eq!(&result, 
&[Some(0), Some(4), Some(5)]); - - let ca2 = Int32Chunked::from_slice("b", &[1, 1]); - let mut result2 = Vec::from(&ca2.mode().unwrap()); - result2.sort_by_key(|a| a.unwrap()); - assert_eq!(&result2, &[Some(1)]); - - let ca3 = Int32Chunked::from_slice("c", &[]); - let result3 = Vec::from(&ca3.mode().unwrap()); - assert_eq!(result3, &[]); - } } diff --git a/crates/polars-core/src/chunked_array/ops/unique/rank.rs b/crates/polars-core/src/chunked_array/ops/unique/rank.rs deleted file mode 100644 index 3e532404506f..000000000000 --- a/crates/polars-core/src/chunked_array/ops/unique/rank.rs +++ /dev/null @@ -1,454 +0,0 @@ -use polars_arrow::prelude::FromData; -#[cfg(feature = "random")] -use rand::prelude::SliceRandom; -use rand::prelude::*; -#[cfg(feature = "random")] -use rand::{rngs::SmallRng, SeedableRng}; - -use crate::prelude::*; - -#[derive(Copy, Clone)] -pub enum RankMethod { - Average, - Min, - Max, - Dense, - Ordinal, - #[cfg(feature = "random")] - Random, -} - -// We might want to add a `nulls_last` or `null_behavior` field. -#[derive(Copy, Clone)] -pub struct RankOptions { - pub method: RankMethod, - pub descending: bool, -} - -impl Default for RankOptions { - fn default() -> Self { - Self { - method: RankMethod::Dense, - descending: false, - } - } -} - -#[cfg(feature = "random")] -fn get_random_seed() -> u64 { - let mut rng = SmallRng::from_entropy(); - - rng.next_u64() -} - -pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> Series { - match s.len() { - 1 => { - return match method { - Average => Series::new(s.name(), &[1.0f64]), - _ => Series::new(s.name(), &[1 as IdxSize]), - }; - }, - 0 => { - return match method { - Average => Float64Chunked::from_slice(s.name(), &[]).into_series(), - _ => IdxCa::from_slice(s.name(), &[]).into_series(), - }; - }, - _ => {}, - } - - if s.null_count() > 0 { - let nulls = s.is_not_null().rechunk(); - let arr = nulls.downcast_iter().next().unwrap(); - let validity = arr.values(); - // Currently, nulls tie with the minimum or maximum bound for a type, depending on descending. - // TODO: Need to expose nulls_last in arg_sort to prevent this. - // Fill using MaxBound/MinBound to give nulls last rank. - // we will replace them later. 
- let null_strategy = if descending { - FillNullStrategy::MinBound - } else { - FillNullStrategy::MaxBound - }; - let s = s.fill_null(null_strategy).unwrap(); - - let mut out = rank(&s, method, descending, seed); - unsafe { - let arr = &mut out.chunks_mut()[0]; - *arr = arr.with_validity(Some(validity.clone())) - } - return out; - } - - // See: https://github.com/scipy/scipy/blob/v1.7.1/scipy/stats/stats.py#L8631-L8737 - - let len = s.len(); - let null_count = s.null_count(); - let sort_idx_ca = s.arg_sort(SortOptions { - descending, - ..Default::default() - }); - let sort_idx = sort_idx_ca.downcast_iter().next().unwrap().values(); - - let mut inv: Vec = Vec::with_capacity(len); - // Safety: - // Values will be filled next and there is only primitive data - #[allow(clippy::uninit_vec)] - unsafe { - inv.set_len(len) - } - let inv_values = inv.as_mut_slice(); - - #[cfg(feature = "random")] - let mut count = if let RankMethod::Ordinal | RankMethod::Random = method { - 1 as IdxSize - } else { - 0 - }; - - #[cfg(not(feature = "random"))] - let mut count = if let RankMethod::Ordinal = method { - 1 as IdxSize - } else { - 0 - }; - - // Safety: - // we are in bounds - unsafe { - sort_idx.iter().for_each(|&i| { - *inv_values.get_unchecked_mut(i as usize) = count; - count += 1; - }); - } - - use RankMethod::*; - match method { - Ordinal => { - let inv_ca = IdxCa::from_vec(s.name(), inv); - inv_ca.into_series() - }, - #[cfg(feature = "random")] - Random => { - // Safety: - // in bounds - let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() }; - let not_consecutive_same = arr - .slice(1, len - 1) - .not_equal(&arr.slice(0, len - 1)) - .unwrap() - .rechunk(); - let obs = not_consecutive_same.downcast_iter().next().unwrap(); - - // Collect slice indices for sort_idx which point to ties in the original series. - let mut ties_indices = Vec::with_capacity(len + 1); - let mut ties_index: usize = 0; - - ties_indices.push(ties_index); - obs.iter().for_each(|b| { - if let Some(b) = b { - ties_index += 1; - if b { - ties_indices.push(ties_index) - } - } - }); - // Close last slice (if there where nulls in the original series, they will always be in the last slice). - ties_indices.push(len); - - let mut sort_idx = sort_idx.to_vec(); - - let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed)); - - // Shuffle sort_idx positions which point to ties in the original series. - for i in 0..(ties_indices.len() - 1) { - let ties_index_start = ties_indices[i]; - let ties_index_end = ties_indices[i + 1]; - if ties_index_end - ties_index_start > 1 { - sort_idx[ties_index_start..ties_index_end].shuffle(rng); - } - } - - // Recreate inv_ca (where ties are randomly shuffled compared with Ordinal). - let mut count = 1 as IdxSize; - unsafe { - sort_idx.iter().for_each(|&i| { - *inv_values.get_unchecked_mut(i as usize) = count; - count += 1; - }); - } - - let inv_ca = IdxCa::from_vec(s.name(), inv); - inv_ca.into_series() - }, - _ => { - let inv_ca = IdxCa::from_vec(s.name(), inv); - // SAFETY: in bounds. - let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() }; - let validity = arr.chunks()[0].validity().cloned(); - let not_consecutive_same = arr - .slice(1, len - 1) - .not_equal(&arr.slice(0, len - 1)) - .unwrap() - .rechunk(); - // This obs is shorter than that of scipy stats, because we can just start the cumsum by 1 - // instead of 0. 
- let obs = not_consecutive_same.downcast_iter().next().unwrap(); - let mut dense = Vec::with_capacity(len); - - // This offset save an offset on the whole column, what scipy does in: - // - // ```python - // if method == 'min': - // return count[dense - 1] + 1 - // ``` - // INVALID LINT REMOVE LATER - #[allow(clippy::bool_to_int_with_if)] - let mut cumsum: IdxSize = if let RankMethod::Min = method { - 0 - } else { - // Nulls will be first, rank, but we will replace them (with null), - // this ensures the second rank will be 1. - if matches!(method, RankMethod::Dense) && s.null_count() > 0 { - 0 - } else { - 1 - } - }; - - dense.push(cumsum); - obs.values_iter().for_each(|b| { - if b { - cumsum += 1; - } - dense.push(cumsum) - }); - let arr = IdxArr::from_data_default(dense.into(), validity); - let dense = IdxCa::with_chunk(s.name(), arr); - - // SAFETY: in bounds. - let dense = unsafe { dense.take_unchecked((&inv_ca).into()) }; - - if let RankMethod::Dense = method { - return if s.null_count() == 0 { - dense.into_series() - } else { - // Null will be the first rank. We restore original nulls and shift all ranks by one. - let validity = s.is_null().rechunk(); - let validity = validity.downcast_iter().next().unwrap(); - let validity = validity.values().clone(); - - let arr = dense.downcast_iter().next().unwrap(); - let arr = arr.with_validity(Some(validity)); - let dtype = arr.data_type().clone(); - - // SAFETY: given dtype is correct. - unsafe { - Series::try_from_arrow_unchecked(s.name(), vec![arr], &dtype).unwrap() - } - }; - } - - let bitmap = obs.values(); - let cap = bitmap.len() - bitmap.unset_bits(); - let mut count = Vec::with_capacity(cap + 1); - let mut cnt: IdxSize = 0; - count.push(cnt); - - if null_count > 0 { - obs.iter().for_each(|b| { - if let Some(b) = b { - cnt += 1; - if b { - count.push(cnt) - } - } - }); - } else { - obs.values_iter().for_each(|b| { - cnt += 1; - if b { - count.push(cnt) - } - }); - } - - count.push((len - null_count) as IdxSize); - let count = IdxCa::from_vec(s.name(), count); - - match method { - Max => { - // SAFETY: in bounds. - unsafe { count.take_unchecked((&dense).into()).into_series() } - }, - Min => { - // SAFETY: in bounds. - unsafe { (count.take_unchecked((&dense).into()) + 1).into_series() } - }, - Average => { - // SAFETY: in bounds. - let a = unsafe { count.take_unchecked((&dense).into()) } - .cast(&DataType::Float64) - .unwrap(); - let b = unsafe { count.take_unchecked((&(dense - 1)).into()) } - .cast(&DataType::Float64) - .unwrap() - + 1.0; - (&a + &b) * 0.5 - }, - #[cfg(feature = "random")] - Dense | Ordinal | Random => unimplemented!(), - #[cfg(not(feature = "random"))] - Dense | Ordinal => unimplemented!(), - } - }, - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_rank() -> PolarsResult<()> { - let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]); - - let out = rank(&s, RankMethod::Ordinal, false, None) - .idx()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]); - - #[cfg(feature = "random")] - { - let out = rank(&s, RankMethod::Random, false, None) - .idx()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out[0], 2); - assert_eq!(out[6], 1); - assert_eq!(out[1] + out[3] + out[4], 12); - assert_eq!(out[2] + out[5], 13); - assert_ne!(out[1], out[3]); - assert_ne!(out[1], out[4]); - assert_ne!(out[3], out[4]); - } - - let out = rank(&s, RankMethod::Dense, false, None) - .idx()? 
- .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]); - - let out = rank(&s, RankMethod::Max, false, None) - .idx()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]); - - let out = rank(&s, RankMethod::Min, false, None) - .idx()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]); - - let out = rank(&s, RankMethod::Average, false, None) - .f64()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]); - - let s = Series::new( - "a", - &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)], - ); - - let out = rank(&s, RankMethod::Average, false, None) - .f64()? - .into_iter() - .collect::>(); - - assert_eq!( - out, - &[ - Some(2.0f64), - Some(3.5), - Some(5.0), - Some(3.5), - None, - None, - Some(1.0) - ] - ); - let s = Series::new( - "a", - &[ - Some(5), - Some(6), - Some(4), - None, - Some(78), - Some(4), - Some(2), - Some(8), - ], - ); - let out = rank(&s, RankMethod::Max, false, None) - .idx()? - .into_iter() - .collect::>(); - assert_eq!( - out, - &[ - Some(4), - Some(5), - Some(3), - None, - Some(7), - Some(3), - Some(1), - Some(6) - ] - ); - - Ok(()) - } - - #[test] - fn test_rank_all_null() -> PolarsResult<()> { - let s = UInt32Chunked::new("", &[None, None, None]).into_series(); - let out = rank(&s, RankMethod::Average, false, None) - .f64()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[2.0f64, 2.0, 2.0]); - let out = rank(&s, RankMethod::Dense, false, None) - .idx()? - .into_no_null_iter() - .collect::>(); - assert_eq!(out, &[1, 1, 1]); - Ok(()) - } - - #[test] - fn test_rank_empty() { - let s = UInt32Chunked::from_slice("", &[]).into_series(); - let out = rank(&s, RankMethod::Average, false, None); - assert_eq!(out.dtype(), &DataType::Float64); - let out = rank(&s, RankMethod::Max, false, None); - assert_eq!(out.dtype(), &IDX_DTYPE); - } - - #[test] - fn test_rank_reverse() -> PolarsResult<()> { - let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]); - let out = rank(&s, RankMethod::Dense, true, None) - .idx()? - .into_iter() - .collect::>(); - assert_eq!(out, &[None, Some(2 as IdxSize), Some(2), Some(1), None]); - - Ok(()) - } -} diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 74ece90d2c3e..446f3518e5ee 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -2,6 +2,7 @@ use num_traits::{Float, NumCast}; use polars_error::to_compute_err; use rand::distributions::Bernoulli; use rand::prelude::*; +use rand::seq::index::IndexVec; use rand_distr::{Distribution, Normal, Standard, StandardNormal, Uniform}; use crate::prelude::*; @@ -27,15 +28,21 @@ fn create_rand_index_no_replacement( shuffle: bool, ) -> IdxCa { let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); - let mut buf; + let mut buf: Vec; if n == len { buf = (0..len as IdxSize).collect(); + if shuffle { + buf.shuffle(&mut rng) + } } else { - buf = vec![0; n]; - (0..len as IdxSize).choose_multiple_fill(&mut rng, &mut buf); - } - if shuffle { - buf.shuffle(&mut rng) + // TODO: avoid extra potential copy by vendoring rand::seq::index::sample, + // or genericize take over slices over any unsigned type. The optimizer + // should get rid of the extra copy already if IdxSize matches the IndexVec + // size returned. 
+ buf = match rand::seq::index::sample(&mut rng, len, n) { + IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(), + IndexVec::USize(v) => v.into_iter().map(|x| x as IdxSize).collect(), + }; } IdxCa::new_vec("", buf) } @@ -85,20 +92,20 @@ impl Series { match with_replacement { true => { let idx = create_rand_index_with_replacement(n, len, seed); - // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx) } + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } }, false => { let idx = create_rand_index_no_replacement(n, len, seed, shuffle); - // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx) } + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } }, } } - /// Sample a fraction between 0.0-1.0 of this ChunkedArray. + /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`]. pub fn sample_frac( &self, frac: f64, @@ -114,18 +121,18 @@ impl Series { let len = self.len(); let n = len; let idx = create_rand_index_no_replacement(n, len, seed, true); - // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx).unwrap() } + // SAFETY: we know that we never go out of bounds. + unsafe { self.take_unchecked(&idx) } } } impl ChunkedArray where T: PolarsDataType, - ChunkedArray: ChunkTake, + ChunkedArray: ChunkTake, { - /// Sample n datapoints from this ChunkedArray. + /// Sample n datapoints from this [`ChunkedArray`]. pub fn sample_n( &self, n: usize, @@ -139,20 +146,20 @@ where match with_replacement { true => { let idx = create_rand_index_with_replacement(n, len, seed); - // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { Ok(self.take_unchecked((&idx).into())) } + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } }, false => { let idx = create_rand_index_no_replacement(n, len, seed, shuffle); - // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { Ok(self.take_unchecked((&idx).into())) } + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } }, } } - /// Sample a fraction between 0.0-1.0 of this ChunkedArray. + /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`]. pub fn sample_frac( &self, frac: f64, @@ -166,8 +173,36 @@ where } impl DataFrame { - /// Sample n datapoints from this DataFrame. + /// Sample n datapoints from this [`DataFrame`]. pub fn sample_n( + &self, + n: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + polars_ensure!( + n.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n = n.cast(&IDX_DTYPE)?; + let n = n.idx()?; + + match n.get(0) { + Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), + None => { + let new_cols = self + .columns + .iter() + .map(|c| Series::new_empty(c.name(), c.dtype())) + .collect_trusted(); + Ok(DataFrame::new_no_checks(new_cols)) + }, + } + } + + pub fn sample_n_literal( &self, n: usize, with_replacement: bool, @@ -175,17 +210,16 @@ impl DataFrame { seed: Option, ) -> PolarsResult { ensure_shape(n, self.height(), with_replacement)?; - // all columns should used the same indices. So we first create the indices. + // All columns should used the same indices. So we first create the indices. 
let idx = match with_replacement { true => create_rand_index_with_replacement(n, self.height(), seed), false => create_rand_index_no_replacement(n, self.height(), seed, shuffle), }; - // Safety: - // indices are within bounds + // SAFETY: the indices are within bounds. Ok(unsafe { self.take_unchecked(&idx) }) } - /// Sample a fraction between 0.0-1.0 of this DataFrame. + /// Sample a fraction between 0.0-1.0 of this [`DataFrame`]. pub fn sample_frac( &self, frac: f64, @@ -194,7 +228,7 @@ impl DataFrame { seed: Option, ) -> PolarsResult { let n = (self.height() as f64 * frac) as usize; - self.sample_n(n, with_replacement, shuffle, seed) + self.sample_n_literal(n, with_replacement, shuffle, seed) } } @@ -203,7 +237,7 @@ where T: PolarsNumericType, T::Native: Float, { - /// Create `ChunkedArray` with samples from a Normal distribution. + /// Create [`ChunkedArray`] with samples from a Normal distribution. pub fn rand_normal(name: &str, length: usize, mean: f64, std_dev: f64) -> PolarsResult { let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?; let mut builder = PrimitiveChunkedBuilder::::new(name, length); @@ -216,7 +250,7 @@ where Ok(builder.finish()) } - /// Create `ChunkedArray` with samples from a Standard Normal distribution. + /// Create [`ChunkedArray`] with samples from a Standard Normal distribution. pub fn rand_standard_normal(name: &str, length: usize) -> Self { let mut builder = PrimitiveChunkedBuilder::::new(name, length); let mut rng = rand::thread_rng(); @@ -228,7 +262,7 @@ where builder.finish() } - /// Create `ChunkedArray` with samples from a Uniform distribution. + /// Create [`ChunkedArray`] with samples from a Uniform distribution. pub fn rand_uniform(name: &str, length: usize, low: f64, high: f64) -> Self { let uniform = Uniform::new(low, high); let mut builder = PrimitiveChunkedBuilder::::new(name, length); @@ -243,7 +277,7 @@ where } impl BooleanChunked { - /// Create `ChunkedArray` with samples from a Bernoulli distribution. + /// Create [`ChunkedArray`] with samples from a Bernoulli distribution. pub fn rand_bernoulli(name: &str, length: usize, p: f64) -> PolarsResult { let dist = Bernoulli::new(p).map_err(to_compute_err)?; let mut rng = rand::thread_rng(); @@ -267,17 +301,23 @@ mod test { ] .unwrap(); - // default samples are random and don't require seeds - assert!(df.sample_n(3, false, false, None).is_ok()); + // Default samples are random and don't require seeds. + assert!(df + .sample_n(&Series::new("s", &[3]), false, false, None) + .is_ok()); assert!(df.sample_frac(0.4, false, false, None).is_ok()); - // with seeding - assert!(df.sample_n(3, false, false, Some(0)).is_ok()); + // With seeding. + assert!(df + .sample_n(&Series::new("s", &[3]), false, false, Some(0)) + .is_ok()); assert!(df.sample_frac(0.4, false, false, Some(0)).is_ok()); - // without replacement can not sample more than 100% + // Without replacement can not sample more than 100%. assert!(df.sample_frac(2.0, false, false, Some(0)).is_err()); - assert!(df.sample_n(3, true, false, Some(0)).is_ok()); + assert!(df + .sample_n(&Series::new("s", &[3]), true, false, Some(0)) + .is_ok()); assert!(df.sample_frac(0.4, true, false, Some(0)).is_ok()); - // with replacement can sample more than 100% + // With replacement can sample more than 100%. 
assert!(df.sample_frac(2.0, true, false, Some(0)).is_ok()); } } diff --git a/crates/polars-core/src/chunked_array/temporal/conversion.rs b/crates/polars-core/src/chunked_array/temporal/conversion.rs index 1657e6a754f3..34baa7c7533e 100644 --- a/crates/polars-core/src/chunked_array/temporal/conversion.rs +++ b/crates/polars-core/src/chunked_array/temporal/conversion.rs @@ -36,7 +36,7 @@ impl From<&AnyValue<'_>> for NaiveTime { // Used by lazy for literal conversion pub fn datetime_to_timestamp_ns(v: NaiveDateTime) -> i64 { - v.timestamp_nanos() + v.timestamp_nanos_opt().unwrap() } // Used by lazy for literal conversion diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index cca5597e4cf7..8f7dde50e0fa 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -4,6 +4,7 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, }; use chrono::format::{DelayedFormat, StrftimeItems}; +use chrono::NaiveDate; #[cfg(feature = "timezones")] use chrono::TimeZone as TimeZoneTrait; #[cfg(feature = "timezones")] diff --git a/crates/polars-core/src/chunked_array/to_vec.rs b/crates/polars-core/src/chunked_array/to_vec.rs index 8367432b542f..1a8ed2798e65 100644 --- a/crates/polars-core/src/chunked_array/to_vec.rs +++ b/crates/polars-core/src/chunked_array/to_vec.rs @@ -3,7 +3,7 @@ use either::Either; use crate::prelude::*; impl ChunkedArray { - /// Convert to a `Vec` of `Option`. + /// Convert to a [`Vec`] of [`Option`]. pub fn to_vec(&self) -> Vec> { let mut buf = Vec::with_capacity(self.len()); for arr in self.downcast_iter() { @@ -12,7 +12,7 @@ impl ChunkedArray { buf } - /// Convert to a `Vec` but don't return `Option` if there are no null values + /// Convert to a [`Vec`] but don't return [`Option`] if there are no null values pub fn to_vec_null_aware(&self) -> Either, Vec>> { if self.null_count() == 0 { let mut buf = Vec::with_capacity(self.len()); diff --git a/crates/polars-core/src/chunked_array/trusted_len.rs b/crates/polars-core/src/chunked_array/trusted_len.rs index 4cca38c2b6ff..cce5785c278c 100644 --- a/crates/polars-core/src/chunked_array/trusted_len.rs +++ b/crates/polars-core/src/chunked_array/trusted_len.rs @@ -1,7 +1,5 @@ use std::borrow::Borrow; -use arrow::bitmap::MutableBitmap; -use polars_arrow::bit_util::{set_bit_raw, unset_bit_raw}; use polars_arrow::trusted_len::{FromIteratorReversed, TrustedLenPush}; use crate::chunked_array::upstream_traits::PolarsAsRef; @@ -49,70 +47,7 @@ where T: PolarsNumericType, { fn from_trusted_len_iter_rev>>(iter: I) -> Self { - let size = iter.size_hint().1.unwrap(); - - let mut vals: Vec = Vec::with_capacity(size); - let mut validity = MutableBitmap::with_capacity(size); - validity.extend_constant(size, true); - let validity_ptr = validity.as_slice().as_ptr() as *mut u8; - unsafe { - // Set to end of buffer. 
- let mut ptr = vals.as_mut_ptr().add(size); - let mut offset = size; - - iter.for_each(|opt_item| { - offset -= 1; - ptr = ptr.sub(1); - match opt_item { - Some(item) => { - std::ptr::write(ptr, item); - }, - None => { - std::ptr::write(ptr, T::Native::default()); - unset_bit_raw(validity_ptr, offset) - }, - } - }); - vals.set_len(size) - } - let arr = PrimitiveArray::new( - T::get_dtype().to_arrow(), - vals.into(), - Some(validity.into()), - ); - arr.into() - } -} - -impl FromIteratorReversed> for BooleanChunked { - fn from_trusted_len_iter_rev>>(iter: I) -> Self { - let size = iter.size_hint().1.unwrap(); - - let vals = MutableBitmap::from_len_zeroed(size); - let mut validity = MutableBitmap::with_capacity(size); - validity.extend_constant(size, true); - let validity_ptr = validity.as_slice().as_ptr() as *mut u8; - let vals_ptr = vals.as_slice().as_ptr() as *mut u8; - unsafe { - let mut offset = size; - - iter.for_each(|opt_item| { - offset -= 1; - match opt_item { - Some(item) => { - if item { - // Set value (validity bit is already true). - set_bit_raw(vals_ptr, offset); - } - }, - None => { - // Unset validity bit. - unset_bit_raw(validity_ptr, offset) - }, - } - }); - } - let arr = BooleanArray::new(ArrowDataType::Boolean, vals.into(), Some(validity.into())); + let arr: PrimitiveArray = iter.collect_reversed(); arr.into() } } @@ -122,20 +57,21 @@ where T: PolarsNumericType, { fn from_trusted_len_iter_rev>(iter: I) -> Self { - let size = iter.size_hint().1.unwrap(); + let arr: PrimitiveArray = iter.collect_reversed(); + NoNull::new(arr.into()) + } +} - let mut vals: Vec = Vec::with_capacity(size); - unsafe { - // Set to end of buffer. - let mut ptr = vals.as_mut_ptr().add(size); +impl FromIteratorReversed> for BooleanChunked { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let arr: BooleanArray = iter.collect_reversed(); + arr.into() + } +} - iter.for_each(|item| { - ptr = ptr.sub(1); - std::ptr::write(ptr, item); - }); - vals.set_len(size) - } - let arr = PrimitiveArray::new(T::get_dtype().to_arrow(), vals.into(), None); +impl FromIteratorReversed for NoNull { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let arr: BooleanArray = iter.collect_reversed(); NoNull::new(arr.into()) } } diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index 5aeeb9117b7d..af24444fdf14 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -35,37 +35,13 @@ impl Default for ChunkedArray { } /// FromIterator trait - impl FromIterator> for ChunkedArray where T: PolarsNumericType, { fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - - let arr: PrimitiveArray = match iter.size_hint() { - (a, Some(b)) if a == b => { - // 2021-02-07: ~40% faster than builder. - // It is unsafe because we cannot be certain that the iterators length can be trusted. - // For most iterators that report the same upper bound as lower bound it is, but still - // somebody can create an iterator that incorrectly gives those bounds. - // This will not lead to UB, but will panic. 
- #[cfg(feature = "performant")] - unsafe { - let arr = PrimitiveArray::from_trusted_len_iter_unchecked(iter) - .to(T::get_dtype().to_arrow()); - assert_eq!(arr.len(), a); - arr - } - #[cfg(not(feature = "performant"))] - iter.collect::>() - .to(T::get_dtype().to_arrow()) - }, - _ => iter - .collect::>() - .to(T::get_dtype().to_arrow()), - }; - arr.into() + // TODO: eliminate this FromIterator implementation entirely. + iter.into_iter().collect_ca("") } } diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs index 1d7e930ddc44..87cb707da2c4 100644 --- a/crates/polars-core/src/datatypes/aliases.rs +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -1,12 +1,15 @@ +pub use polars_arrow::index::{IdxArr, IdxSize}; + use super::*; +use crate::hashing::IdBuildHasher; + +/// [ChunkIdx, DfIdx] +pub type ChunkId = [IdxSize; 2]; #[cfg(not(feature = "bigidx"))] pub type IdxCa = UInt32Chunked; #[cfg(feature = "bigidx")] pub type IdxCa = UInt64Chunked; -pub use polars_arrow::index::{IdxArr, IdxSize}; - -use crate::hashing::IdBuildHasher; #[cfg(not(feature = "bigidx"))] pub const IDX_DTYPE: DataType = DataType::UInt32; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index d9d1451ba6cf..7778475049d1 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -399,7 +399,14 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-duration")] Duration(v, _) => NumCast::from(*v), #[cfg(feature = "dtype-decimal")] - Decimal(v, _) => NumCast::from(*v), + Decimal(v, scale) => { + if *scale == 0 { + NumCast::from(*v) + } else { + let f: Option = NumCast::from(*v); + NumCast::from(f? / 10f64.powi(*scale as _)) + } + }, Boolean(v) => { if *v { NumCast::from(1) diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 6043fa31504e..165cfcbda847 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -161,39 +161,29 @@ impl DataType { self.is_numeric() | matches!(self, DataType::Boolean | DataType::Utf8 | DataType::Binary) } - /// Check if this [`DataType`] is a numeric type. + /// Check if this [`DataType`] is a basic numeric type (excludes Decimal). pub fn is_numeric(&self) -> bool { - // allow because it cannot be replaced when object feature is activated - #[allow(clippy::match_like_matches_macro)] - match self { - DataType::Utf8 - | DataType::List(_) - | DataType::Date - | DataType::Datetime(_, _) - | DataType::Duration(_) - | DataType::Time - | DataType::Boolean - | DataType::Unknown - | DataType::Null => false, - DataType::Binary => false, - #[cfg(feature = "object")] - DataType::Object(_) => false, - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_) => false, - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => false, - #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => false, - _ => true, - } + self.is_float() || self.is_integer() } + /// Check if this [`DataType`] is a basic floating point type (excludes Decimal). pub fn is_float(&self) -> bool { matches!(self, DataType::Float32 | DataType::Float64) } + /// Check if this [`DataType`] is an integer. 
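// A minimal sketch of the scale handling added to the `Decimal` cast above: a decimal
// stores an integer mantissa plus a scale, so its numeric value is mantissa / 10^scale.
// The function name and the plain-f64 signature are illustrative; the real code goes
// through `NumCast` and only divides when the scale is non-zero.
fn decimal_to_f64(mantissa: i128, scale: u32) -> f64 {
    // mantissa = 12345 with scale = 2 represents 123.45.
    mantissa as f64 / 10f64.powi(scale as i32)
}
// decimal_to_f64(12345, 2) == 123.45
// decimal_to_f64(12345, 0) == 12345.0 (scale 0 needs no division)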
pub fn is_integer(&self) -> bool { - self.is_numeric() && !matches!(self, DataType::Float32 | DataType::Float64) + matches!( + self, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) } pub fn is_signed(&self) -> bool { diff --git a/crates/polars-core/src/datatypes/from_values.rs b/crates/polars-core/src/datatypes/from_values.rs deleted file mode 100644 index 07341355caa9..000000000000 --- a/crates/polars-core/src/datatypes/from_values.rs +++ /dev/null @@ -1,185 +0,0 @@ -use std::borrow::Cow; -use std::error::Error; - -use arrow::array::{ - BinaryArray, BooleanArray, MutableBinaryArray, MutableBinaryValuesArray, MutablePrimitiveArray, - MutableUtf8Array, MutableUtf8ValuesArray, PrimitiveArray, Utf8Array, -}; -use arrow::bitmap::Bitmap; -use polars_arrow::array::utf8::{BinaryFromIter, Utf8FromIter}; -use polars_arrow::prelude::FromData; -use polars_arrow::trusted_len::TrustedLen; - -use crate::datatypes::NumericNative; -use crate::prelude::StaticArray; - -pub trait ArrayFromElementIter -where - Self: Sized, -{ - type ArrayType: StaticArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType; - - fn array_from_values_iter>(iter: I) -> Self::ArrayType; - - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result; - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result; -} - -impl ArrayFromElementIter for bool { - type ArrayType = BooleanArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) } - } - - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::try_from_trusted_len_iter_unchecked(iter) } - } - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - let values = unsafe { Bitmap::try_from_trusted_len_iter_unchecked(iter) }?; - Ok(BooleanArray::from_data_default(values, None)) - } -} - -impl ArrayFromElementIter for T -where - T: NumericNative, -{ - type ArrayType = PrimitiveArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { PrimitiveArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { PrimitiveArray::from_trusted_len_values_iter_unchecked(iter) } - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Ok(MutablePrimitiveArray::try_from_trusted_len_iter_unchecked(iter)?.into()) } - } - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let values: Vec<_> = iter.collect::, _>>()?; - Ok(PrimitiveArray::from_vec(values)) - } -} - -impl ArrayFromElementIter for &str { - type ArrayType = Utf8Array; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Utf8Array::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - let len = iter.size_hint().0; - Utf8Array::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = 
MutableUtf8Array::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8ValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} - -impl ArrayFromElementIter for Cow<'_, str> { - type ArrayType = Utf8Array; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Utf8Array::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - let len = iter.size_hint().0; - Utf8Array::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8Array::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8ValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} - -impl ArrayFromElementIter for Cow<'_, [u8]> { - type ArrayType = BinaryArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BinaryArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - let len = iter.size_hint().0; - BinaryArray::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableBinaryArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableBinaryValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 120820365207..a0940454d4ea 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -12,8 +12,8 @@ mod aliases; mod any_value; mod dtype; mod field; -mod from_values; mod static_array; +mod static_array_collect; mod time_unit; use std::cmp::Ordering; @@ -30,10 +30,10 @@ use arrow::datatypes::IntegerType; pub use arrow::datatypes::{DataType as ArrowDataType, TimeUnit as ArrowTimeUnit}; use arrow::types::simd::Simd; use arrow::types::NativeType; +use bytemuck::Zeroable; pub use dtype::*; pub use field::*; -pub use from_values::ArrayFromElementIter; -use num_traits::{Bounded, FromPrimitive, Num, NumCast, Zero}; +use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Zero}; use polars_arrow::data_types::IsFloat; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; @@ -42,129 +42,180 @@ use serde::{Deserialize, Serialize}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] use serde::{Deserializer, Serializer}; pub use static_array::StaticArray; +pub use static_array_collect::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype}; pub use time_unit::*; use crate::chunked_array::arithmetic::ArrayArithmetics; pub use crate::chunked_array::logical::*; #[cfg(feature = "object")] +use 
crate::chunked_array::object::ObjectArray; +#[cfg(feature = "object")] use crate::chunked_array::object::PolarsObjectSafe; use crate::prelude::*; use crate::utils::Wrap; -pub struct Utf8Type {} - -pub struct BinaryType {} +pub struct Nested; +pub struct Flat; -#[cfg(feature = "dtype-array")] -pub struct FixedSizeListType {} - -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct ListType {} +/// # Safety +/// +/// The StaticArray and dtype return must be correct. +pub unsafe trait PolarsDataType: Send + Sync + Sized { + type Physical<'a>; + type ZeroablePhysical<'a>: Zeroable + From>; + type Array: for<'a> StaticArray< + ValueT<'a> = Self::Physical<'a>, + ZeroableValueT<'a> = Self::ZeroablePhysical<'a>, + >; + type Structure; -pub trait PolarsDataType: Send + Sync { fn get_dtype() -> DataType where Self: Sized; } -macro_rules! impl_polars_datatype { - ($ca:ident, $variant:ident, $physical:ty) => { +pub trait PolarsNumericType: 'static +where + Self: for<'a> PolarsDataType< + Physical<'a> = Self::Native, + ZeroablePhysical<'a> = Self::Native, + Array = PrimitiveArray, + Structure = Flat, + >, +{ + type Native: NumericNative; +} + +pub trait PolarsIntegerType: PolarsNumericType {} +pub trait PolarsFloatType: PolarsNumericType {} + +macro_rules! impl_polars_num_datatype { + ($trait: ident, $ca:ident, $variant:ident, $physical:ty) => { #[derive(Clone, Copy)] pub struct $ca {} - impl PolarsDataType for $ca { + unsafe impl PolarsDataType for $ca { + type Physical<'a> = $physical; + type ZeroablePhysical<'a> = $physical; + type Array = PrimitiveArray<$physical>; + type Structure = Flat; + #[inline] fn get_dtype() -> DataType { DataType::$variant } } + + impl PolarsNumericType for $ca { + type Native = $physical; + } + + impl $trait for $ca {} }; } -impl_polars_datatype!(UInt8Type, UInt8, u8); -impl_polars_datatype!(UInt16Type, UInt16, u16); -impl_polars_datatype!(UInt32Type, UInt32, u32); -impl_polars_datatype!(UInt64Type, UInt64, u64); -impl_polars_datatype!(Int8Type, Int8, i8); -impl_polars_datatype!(Int16Type, Int16, i16); -impl_polars_datatype!(Int32Type, Int32, i32); -impl_polars_datatype!(Int64Type, Int64, i64); -impl_polars_datatype!(Float32Type, Float32, f32); -impl_polars_datatype!(Float64Type, Float64, f64); -impl_polars_datatype!(DateType, Date, i32); -#[cfg(feature = "dtype-decimal")] -impl_polars_datatype!(DecimalType, Unknown, i128); -impl_polars_datatype!(DatetimeType, Unknown, i64); -impl_polars_datatype!(DurationType, Unknown, i64); -impl_polars_datatype!(CategoricalType, Unknown, u32); -impl_polars_datatype!(TimeType, Time, i64); +macro_rules! 
impl_polars_datatype { + ($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => { + #[derive(Clone, Copy)] + pub struct $ca {} -impl PolarsDataType for Utf8Type { - fn get_dtype() -> DataType { - DataType::Utf8 - } -} + unsafe impl PolarsDataType for $ca { + type Physical<$lt> = $phys; + type ZeroablePhysical<$lt> = $zerophys; + type Array = $arr; + type Structure = Flat; -impl PolarsDataType for BinaryType { - fn get_dtype() -> DataType { - DataType::Binary - } + #[inline] + fn get_dtype() -> DataType { + DataType::$variant + } + } + }; } -pub struct BooleanType {} +impl_polars_num_datatype!(PolarsIntegerType, UInt8Type, UInt8, u8); +impl_polars_num_datatype!(PolarsIntegerType, UInt16Type, UInt16, u16); +impl_polars_num_datatype!(PolarsIntegerType, UInt32Type, UInt32, u32); +impl_polars_num_datatype!(PolarsIntegerType, UInt64Type, UInt64, u64); +impl_polars_num_datatype!(PolarsIntegerType, Int8Type, Int8, i8); +impl_polars_num_datatype!(PolarsIntegerType, Int16Type, Int16, i16); +impl_polars_num_datatype!(PolarsIntegerType, Int32Type, Int32, i32); +impl_polars_num_datatype!(PolarsIntegerType, Int64Type, Int64, i64); +impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32); +impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64); +impl_polars_datatype!(DateType, Date, PrimitiveArray, 'a, i32, i32); +#[cfg(feature = "dtype-decimal")] +impl_polars_datatype!(DecimalType, Unknown, PrimitiveArray, 'a, i128, i128); +impl_polars_datatype!(DatetimeType, Unknown, PrimitiveArray, 'a, i64, i64); +impl_polars_datatype!(DurationType, Unknown, PrimitiveArray, 'a, i64, i64); +impl_polars_datatype!(CategoricalType, Unknown, PrimitiveArray, 'a, u32, u32); +impl_polars_datatype!(TimeType, Time, PrimitiveArray, 'a, i64, i64); +impl_polars_datatype!(Utf8Type, Utf8, Utf8Array, 'a, &'a str, Option<&'a str>); +impl_polars_datatype!(BinaryType, Binary, BinaryArray, 'a, &'a [u8], Option<&'a [u8]>); +impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool); -impl PolarsDataType for BooleanType { - fn get_dtype() -> DataType { - DataType::Boolean - } -} +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ListType {} +unsafe impl PolarsDataType for ListType { + type Physical<'a> = Box; + type ZeroablePhysical<'a> = Option>; + type Array = ListArray; + type Structure = Nested; -impl PolarsDataType for ListType { fn get_dtype() -> DataType { - // null as we cannot know anything without self. + // Null as we cannot know anything without self. DataType::List(Box::new(DataType::Null)) } } #[cfg(feature = "dtype-array")] -impl PolarsDataType for FixedSizeListType { +pub struct FixedSizeListType {} +#[cfg(feature = "dtype-array")] +unsafe impl PolarsDataType for FixedSizeListType { + type Physical<'a> = Box; + type ZeroablePhysical<'a> = Option>; + type Array = FixedSizeListArray; + type Structure = Nested; + fn get_dtype() -> DataType { - // null as we cannot know anything without self. + // Null as we cannot know anything without self. 
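// A standalone sketch of the lifetime-generic associated type (GAT) pattern the reworked
// PolarsDataType trait is built on: the physical value a column hands out may borrow from
// the backing array, so it carries a lifetime parameter. The Toy* names below are
// hypothetical and only illustrate the shape of the trait.
trait ToyDataType {
    type Physical<'a>;
}

struct ToyUtf8;
impl ToyDataType for ToyUtf8 {
    // String columns hand out borrowed string slices.
    type Physical<'a> = &'a str;
}

struct ToyInt32;
impl ToyDataType for ToyInt32 {
    // Numeric columns hand out owned copies, so no borrow is needed.
    type Physical<'a> = i32;
}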
DataType::Array(Box::new(DataType::Null), 0) } } - #[cfg(feature = "dtype-decimal")] pub struct Int128Type {} - #[cfg(feature = "dtype-decimal")] -impl PolarsDataType for Int128Type { +unsafe impl PolarsDataType for Int128Type { + type Physical<'a> = i128; + type ZeroablePhysical<'a> = i128; + type Array = PrimitiveArray; + type Structure = Flat; + fn get_dtype() -> DataType { - DataType::Decimal(None, Some(0)) // scale is not None to allow for get_any_value() to work + // Scale is not None to allow for get_any_value() to work. + DataType::Decimal(None, Some(0)) } } - +#[cfg(feature = "dtype-decimal")] +impl PolarsNumericType for Int128Type { + type Native = i128; +} +#[cfg(feature = "dtype-decimal")] +impl PolarsIntegerType for Int128Type {} #[cfg(feature = "object")] pub struct ObjectType(T); #[cfg(feature = "object")] -pub type ObjectChunked = ChunkedArray>; +unsafe impl PolarsDataType for ObjectType { + type Physical<'a> = &'a T; + type ZeroablePhysical<'a> = Option<&'a T>; + type Array = ObjectArray; + type Structure = Nested; -#[cfg(feature = "object")] -impl PolarsDataType for ObjectType { fn get_dtype() -> DataType { DataType::Object(T::type_name()) } } -/// Any type that is not nested -pub trait PolarsSingleType: PolarsDataType {} - -impl PolarsSingleType for T where T: NativeType + PolarsDataType {} - -impl PolarsSingleType for Utf8Type {} - -impl PolarsSingleType for BinaryType {} - #[cfg(feature = "dtype-array")] pub type ArrayChunked = ChunkedArray; pub type ListChunked = ChunkedArray; @@ -183,13 +234,17 @@ pub type Float32Chunked = ChunkedArray; pub type Float64Chunked = ChunkedArray; pub type Utf8Chunked = ChunkedArray; pub type BinaryChunked = ChunkedArray; +#[cfg(feature = "object")] +pub type ObjectChunked = ChunkedArray>; pub trait NumericNative: - PartialOrd + TotalOrd + + PartialOrd + NativeType + Num + NumCast + Zero + + One + Simd + Simd8 + std::iter::Sum @@ -205,147 +260,40 @@ pub trait NumericNative: + IsFloat + ArrayArithmetics { - type POLARSTYPE: PolarsNumericType; + type PolarsType: PolarsNumericType; } impl NumericNative for i8 { - type POLARSTYPE = Int8Type; + type PolarsType = Int8Type; } impl NumericNative for i16 { - type POLARSTYPE = Int16Type; + type PolarsType = Int16Type; } impl NumericNative for i32 { - type POLARSTYPE = Int32Type; + type PolarsType = Int32Type; } impl NumericNative for i64 { - type POLARSTYPE = Int64Type; + type PolarsType = Int64Type; } impl NumericNative for u8 { - type POLARSTYPE = UInt8Type; + type PolarsType = UInt8Type; } impl NumericNative for u16 { - type POLARSTYPE = UInt16Type; + type PolarsType = UInt16Type; } impl NumericNative for u32 { - type POLARSTYPE = UInt32Type; + type PolarsType = UInt32Type; } impl NumericNative for u64 { - type POLARSTYPE = UInt64Type; + type PolarsType = UInt64Type; } #[cfg(feature = "dtype-decimal")] impl NumericNative for i128 { - type POLARSTYPE = Int128Type; + type PolarsType = Int128Type; } impl NumericNative for f32 { - type POLARSTYPE = Float32Type; + type PolarsType = Float32Type; } impl NumericNative for f64 { - type POLARSTYPE = Float64Type; -} - -pub trait PolarsNumericType: Send + Sync + PolarsDataType + 'static { - type Native: NumericNative; -} -impl PolarsNumericType for UInt8Type { - type Native = u8; -} -impl PolarsNumericType for UInt16Type { - type Native = u16; -} -impl PolarsNumericType for UInt32Type { - type Native = u32; -} -impl PolarsNumericType for UInt64Type { - type Native = u64; -} -impl PolarsNumericType for Int8Type { - type Native = i8; -} -impl 
PolarsNumericType for Int16Type { - type Native = i16; -} -impl PolarsNumericType for Int32Type { - type Native = i32; -} -impl PolarsNumericType for Int64Type { - type Native = i64; -} -#[cfg(feature = "dtype-decimal")] -impl PolarsNumericType for Int128Type { - type Native = i128; -} -impl PolarsNumericType for Float32Type { - type Native = f32; -} -impl PolarsNumericType for Float64Type { - type Native = f64; -} - -pub trait PolarsIntegerType: PolarsNumericType {} -impl PolarsIntegerType for UInt8Type {} -impl PolarsIntegerType for UInt16Type {} -impl PolarsIntegerType for UInt32Type {} -impl PolarsIntegerType for UInt64Type {} -impl PolarsIntegerType for Int8Type {} -impl PolarsIntegerType for Int16Type {} -impl PolarsIntegerType for Int32Type {} -impl PolarsIntegerType for Int64Type {} - -pub trait PolarsFloatType: PolarsNumericType {} -impl PolarsFloatType for Float32Type {} -impl PolarsFloatType for Float64Type {} - -// Provide options to cloud providers (credentials, region). -pub type CloudOptions = PlHashMap; - -/// Used to safely match the underlying type of Polars data structures. -/// -/// # Safety -/// -/// The underlying physical type of the data structure on which this -/// is implemented must always match the given PolarsDataType. -pub unsafe trait StaticallyMatchesPolarsType {} - -unsafe impl StaticallyMatchesPolarsType for PrimitiveArray {} -unsafe impl StaticallyMatchesPolarsType for PrimitiveArray {} -unsafe impl StaticallyMatchesPolarsType for Utf8Array {} -unsafe impl StaticallyMatchesPolarsType for BinaryArray {} -unsafe impl StaticallyMatchesPolarsType for BooleanArray {} -unsafe impl StaticallyMatchesPolarsType for ListArray {} -#[cfg(feature = "dtype-array")] -unsafe impl StaticallyMatchesPolarsType for FixedSizeListArray {} - -#[doc(hidden)] -pub unsafe trait HasUnderlyingArray { - type ArrayT: StaticArray; -} - -unsafe impl HasUnderlyingArray for ChunkedArray { - type ArrayT = PrimitiveArray; -} - -unsafe impl HasUnderlyingArray for BooleanChunked { - type ArrayT = BooleanArray; -} - -unsafe impl HasUnderlyingArray for Utf8Chunked { - type ArrayT = Utf8Array; -} - -unsafe impl HasUnderlyingArray for BinaryChunked { - type ArrayT = BinaryArray; -} - -unsafe impl HasUnderlyingArray for ListChunked { - type ArrayT = ListArray; -} - -#[cfg(feature = "dtype-array")] -unsafe impl HasUnderlyingArray for ArrayChunked { - type ArrayT = FixedSizeListArray; -} - -#[cfg(feature = "object")] -unsafe impl HasUnderlyingArray for ObjectChunked { - type ArrayT = crate::chunked_array::object::ObjectArray; + type PolarsType = Float64Type; } diff --git a/crates/polars-core/src/datatypes/static_array.rs b/crates/polars-core/src/datatypes/static_array.rs index ecf54b7179e7..8ec8a9f7cf25 100644 --- a/crates/polars-core/src/datatypes/static_array.rs +++ b/crates/polars-core/src/datatypes/static_array.rs @@ -1,12 +1,22 @@ use arrow::bitmap::utils::{BitmapIter, ZipValidity}; use arrow::bitmap::Bitmap; +use bytemuck::Zeroable; #[cfg(feature = "object")] -use crate::chunked_array::object::ObjectArray; +use crate::chunked_array::object::{ObjectArray, ObjectValueIter}; +use crate::datatypes::static_array_collect::ArrayFromIterDtype; use crate::prelude::*; -pub trait StaticArray: Array { - type ValueT<'a> +pub trait StaticArray: + Array + + for<'a> ArrayFromIterDtype> + + for<'a> ArrayFromIterDtype> + + for<'a> ArrayFromIterDtype>> +{ + type ValueT<'a>: Clone + where + Self: 'a; + type ZeroableValueT<'a>: Zeroable + From> where Self: 'a; type ValueIterT<'a>: Iterator> @@ -15,31 +25,115 @@ 
pub trait StaticArray: Array { where Self: 'a; + #[inline] + fn get(&self, idx: usize) -> Option> { + if idx >= self.len() { + None + } else { + unsafe { self.get_unchecked(idx) } + } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + unsafe fn get_unchecked(&self, idx: usize) -> Option> { + if self.is_null_unchecked(idx) { + None + } else { + Some(self.value_unchecked(idx)) + } + } + + #[inline] + fn last(&self) -> Option> { + unsafe { self.get_unchecked(self.len().checked_sub(1)?) } + } + + #[inline] + fn value(&self, idx: usize) -> Self::ValueT<'_> { + assert!(idx < self.len()); + unsafe { self.value_unchecked(idx) } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_>; + + #[inline(always)] + fn as_slice(&self) -> Option<&[Self::ValueT<'_>]> { + None + } + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter>; fn values_iter(&self) -> Self::ValueIterT<'_>; fn with_validity_typed(self, validity: Option) -> Self; + + fn from_vec(v: Vec>, dtype: DataType) -> Self { + v.into_iter().collect_arr_with_dtype(dtype) + } + + fn from_zeroable_vec(v: Vec>, dtype: DataType) -> Self { + v.into_iter().collect_arr_with_dtype(dtype) + } +} + +pub trait ParameterFreeDtypeStaticArray: StaticArray { + fn get_dtype() -> DataType; } impl StaticArray for PrimitiveArray { type ValueT<'a> = T; + type ZeroableValueT<'a> = T; type ValueIterT<'a> = std::iter::Copied>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { self.values_iter().copied() } + #[inline(always)] + fn as_slice(&self) -> Option<&[Self::ValueT<'_>]> { + Some(self.values().as_slice()) + } + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { ZipValidity::new_with_validity(self.values().iter().copied(), self.validity()) } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } + + fn from_vec(v: Vec>, _dtype: DataType) -> Self { + PrimitiveArray::from_vec(v) + } + + fn from_zeroable_vec(v: Vec>, _dtype: DataType) -> Self { + PrimitiveArray::from_vec(v) + } +} + +impl ParameterFreeDtypeStaticArray for PrimitiveArray { + fn get_dtype() -> DataType { + T::PolarsType::get_dtype() + } } impl StaticArray for BooleanArray { type ValueT<'a> = bool; + type ZeroableValueT<'a> = bool; type ValueIterT<'a> = BitmapIter<'a>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { self.values_iter() } @@ -47,15 +141,36 @@ impl StaticArray for BooleanArray { fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { self.iter() } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } + + fn from_vec(v: Vec>, _dtype: DataType) -> Self { + BooleanArray::from_slice(v) + } + + fn from_zeroable_vec(v: Vec>, _dtype: DataType) -> Self { + BooleanArray::from_slice(v) + } +} + +impl ParameterFreeDtypeStaticArray for BooleanArray { + fn get_dtype() -> DataType { + DataType::Boolean + } } impl StaticArray for Utf8Array { type ValueT<'a> = &'a str; + type ZeroableValueT<'a> = Option<&'a str>; type ValueIterT<'a> = Utf8ValuesIter<'a, i64>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { 
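// A minimal sketch of the checked-over-unchecked accessor layering used by
// StaticArray::get / value above, shown on a plain slice; names are illustrative.
fn get_checked<T: Copy>(values: &[T], idx: usize) -> Option<T> {
    if idx >= values.len() {
        None
    } else {
        // SAFETY: idx was bounds-checked against values.len() just above.
        Some(unsafe { *values.get_unchecked(idx) })
    }
}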
self.values_iter() } @@ -63,15 +178,28 @@ impl StaticArray for Utf8Array { fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { self.iter() } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } } +impl ParameterFreeDtypeStaticArray for Utf8Array { + fn get_dtype() -> DataType { + DataType::Utf8 + } +} + impl StaticArray for BinaryArray { type ValueT<'a> = &'a [u8]; + type ZeroableValueT<'a> = Option<&'a [u8]>; type ValueIterT<'a> = BinaryValueIter<'a, i64>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { self.values_iter() } @@ -79,15 +207,28 @@ impl StaticArray for BinaryArray { fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { self.iter() } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } } +impl ParameterFreeDtypeStaticArray for BinaryArray { + fn get_dtype() -> DataType { + DataType::Binary + } +} + impl StaticArray for ListArray { type ValueT<'a> = Box; + type ZeroableValueT<'a> = Option>; type ValueIterT<'a> = ListValuesIter<'a, i64>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { self.values_iter() } @@ -95,6 +236,7 @@ impl StaticArray for ListArray { fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { self.iter() } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } @@ -103,8 +245,14 @@ impl StaticArray for ListArray { #[cfg(feature = "dtype-array")] impl StaticArray for FixedSizeListArray { type ValueT<'a> = Box; + type ZeroableValueT<'a> = Option>; type ValueIterT<'a> = ArrayValuesIter<'a, FixedSizeListArray>; + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + fn values_iter(&self) -> Self::ValueIterT<'_> { self.values_iter() } @@ -112,6 +260,7 @@ impl StaticArray for FixedSizeListArray { fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { self.iter() } + fn with_validity_typed(self, validity: Option) -> Self { self.with_validity(validity) } @@ -119,17 +268,24 @@ impl StaticArray for FixedSizeListArray { #[cfg(feature = "object")] impl StaticArray for ObjectArray { - type ValueT<'a> = &'a (); - type ValueIterT<'a> = std::slice::Iter<'a, ()>; + type ValueT<'a> = &'a T; + type ZeroableValueT<'a> = Option<&'a T>; + type ValueIterT<'a> = ObjectValueIter<'a, T>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } fn values_iter(&self) -> Self::ValueIterT<'_> { - todo!() + self.values_iter() } fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { - todo!() + self.iter() } - fn with_validity_typed(self, _validity: Option) -> Self { - todo!() + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) } } diff --git a/crates/polars-core/src/datatypes/static_array_collect.rs b/crates/polars-core/src/datatypes/static_array_collect.rs new file mode 100644 index 000000000000..cfb24c68a41e --- /dev/null +++ b/crates/polars-core/src/datatypes/static_array_collect.rs @@ -0,0 +1,885 @@ +use std::borrow::Cow; +use std::sync::Arc; + +#[cfg(feature = "dtype-array")] +use arrow::array::FixedSizeListArray; +use arrow::array::{ + Array, BinaryArray, BooleanArray, ListArray, MutableBinaryArray, MutableBinaryValuesArray, + 
PrimitiveArray, Utf8Array, +}; +use arrow::bitmap::Bitmap; +#[cfg(feature = "dtype-array")] +use polars_arrow::prelude::fixed_size_list::AnonymousBuilder as AnonymousFixedSizeListArrayBuilder; +use polars_arrow::prelude::list::AnonymousBuilder as AnonymousListArrayBuilder; +use polars_arrow::trusted_len::{TrustedLen, TrustedLenPush}; + +#[cfg(feature = "object")] +use crate::chunked_array::object::{ObjectArray, PolarsObject}; +use crate::datatypes::static_array::ParameterFreeDtypeStaticArray; +use crate::datatypes::{DataType, NumericNative, PolarsDataType, StaticArray}; + +pub trait ArrayFromIterDtype: Sized { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self; + + #[inline(always)] + fn arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter_with_dtype(dtype, iter) + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter_with_dtype(dtype, iter) + } +} + +pub trait ArrayFromIter: Sized { + fn arr_from_iter>(iter: I) -> Self; + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter>>(iter: I) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter(iter) + } +} + +impl> ArrayFromIterDtype for A { + #[inline(always)] + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::arr_from_iter(iter) + } + + #[inline(always)] + fn arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::arr_from_iter_trusted(iter) + } + + #[inline(always)] + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::try_arr_from_iter(iter) + } + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::try_arr_from_iter_trusted(iter) + } +} + +pub trait ArrayCollectIterExt: Iterator + Sized { + #[inline(always)] + fn collect_arr(self) -> A + where + A: ArrayFromIter, + { + A::arr_from_iter(self) + } + + #[inline(always)] + fn collect_arr_trusted(self) -> A + where + A: ArrayFromIter, + Self: TrustedLen, + { + A::arr_from_iter_trusted(self) + } + + #[inline(always)] + fn try_collect_arr(self) -> Result + where + A: ArrayFromIter, + Self: Iterator>, + { + A::try_arr_from_iter(self) + } + + #[inline(always)] + fn try_collect_arr_trusted(self) -> Result + where + A: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted(self) + } + + #[inline(always)] + fn collect_arr_with_dtype(self, dtype: DataType) -> A + where + A: ArrayFromIterDtype, + { + A::arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn collect_arr_trusted_with_dtype(self, dtype: 
DataType) -> A + where + A: ArrayFromIterDtype, + Self: TrustedLen, + { + A::arr_from_iter_trusted_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_with_dtype(self, dtype: DataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator>, + { + A::try_arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_trusted_with_dtype(self, dtype: DataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted_with_dtype(dtype, self) + } +} + +impl ArrayCollectIterExt for I {} + +// --------------- +// Implementations +// --------------- +macro_rules! impl_collect_vec_validity { + ($iter: ident, $x:ident, $unpack:expr) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut bitmap: Vec = Vec::new(); + let lo = iter.size_hint().0; + buf.reserve(8 + lo); + bitmap.reserve(8 + 8 * (lo / 64)); + + let mut nonnull_count = 0; + let mut mask = 0u8; + 'exhausted: loop { + unsafe { + // SAFETY: when we enter this loop we always have at least one + // capacity in bitmap, and at least 8 in buf. + for i in 0..8 { + let Some($x) = iter.next() else { + break 'exhausted; + }; + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + + bitmap.push_unchecked(mask); + mask = 0; + } + + buf.reserve(8); + if bitmap.len() == bitmap.capacity() { + bitmap.reserve(8); // Waste some space to make branch more predictable. + } + } + + unsafe { + // SAFETY: when we broke to 'exhausted we had capacity by the loop invariant. + // It's also no problem if we make the mask bigger than strictly necessary. + bitmap.push_unchecked(mask); + } + + let null_count = buf.len() - nonnull_count; + let arrow_bitmap = if null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. + Some(Bitmap::from_inner(Arc::new(bitmap.into()), 0, buf.len(), null_count).unwrap()) + } + } else { + None + }; + + (buf, arrow_bitmap) + }}; +} + +macro_rules! impl_trusted_collect_vec_validity { + ($iter: ident, $x:ident, $unpack:expr) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut bitmap: Vec = Vec::new(); + let n = iter.size_hint().1.expect("must have an upper bound"); + buf.reserve(n); + bitmap.reserve(8 + 8 * (n / 64)); + + let mut nonnull_count = 0; + while buf.len() + 8 <= n { + unsafe { + let mut mask = 0u8; + for i in 0..8 { + let $x = iter.next().unwrap_unchecked(); + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + bitmap.push_unchecked(mask); + } + } + + if buf.len() < n { + unsafe { + let mut mask = 0u8; + for i in 0..n - buf.len() { + let $x = iter.next().unwrap_unchecked(); + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + bitmap.push_unchecked(mask); + } + } + + let null_count = buf.len() - nonnull_count; + let arrow_bitmap = if null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. 
+ Some(Bitmap::from_inner(Arc::new(bitmap.into()), 0, buf.len(), null_count).unwrap()) + } + } else { + None + }; + + (buf, arrow_bitmap) + }}; +} + +impl ArrayFromIter for PrimitiveArray { + #[inline] + fn arr_from_iter>(iter: I) -> Self { + PrimitiveArray::from_vec(iter.into_iter().collect()) + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + PrimitiveArray::from_vec(Vec::from_trusted_len_iter(iter)) + } + + #[inline] + fn try_arr_from_iter>>(iter: I) -> Result { + let v: Result, E> = iter.into_iter().collect(); + Ok(PrimitiveArray::from_vec(v?)) + } + + #[inline] + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let v = Vec::try_from_trusted_len_iter(iter); + Ok(PrimitiveArray::from_vec(v?)) + } +} + +impl ArrayFromIter> for PrimitiveArray { + fn arr_from_iter>>(iter: I) -> Self { + let (buf, validity) = impl_collect_vec_validity!(iter, x, x); + PrimitiveArray::new(T::PolarsType::get_dtype().to_arrow(), buf.into(), validity) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let (buf, validity) = impl_trusted_collect_vec_validity!(iter, x, x); + PrimitiveArray::new(T::PolarsType::get_dtype().to_arrow(), buf.into(), validity) + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let (buf, validity) = impl_collect_vec_validity!(iter, x, x?); + Ok(PrimitiveArray::new( + T::PolarsType::get_dtype().to_arrow(), + buf.into(), + validity, + )) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + let (buf, validity) = impl_trusted_collect_vec_validity!(iter, x, x?); + Ok(PrimitiveArray::new( + T::PolarsType::get_dtype().to_arrow(), + buf.into(), + validity, + )) + } +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef<[u8]> for Option<&[u8]> could be implemented. +trait IntoBytes { + type AsRefT: AsRef<[u8]>; + fn into_bytes(self) -> Self::AsRefT; +} +trait TrivialIntoBytes: AsRef<[u8]> {} +impl IntoBytes for T { + type AsRefT = Self; + fn into_bytes(self) -> Self { + self + } +} +impl TrivialIntoBytes for Vec {} +impl<'a> TrivialIntoBytes for Cow<'a, [u8]> {} +impl<'a> TrivialIntoBytes for &'a [u8] {} +impl TrivialIntoBytes for String {} +impl<'a> TrivialIntoBytes for &'a str {} +impl<'a> IntoBytes for Cow<'a, str> { + type AsRefT = Cow<'a, [u8]>; + fn into_bytes(self) -> Cow<'a, [u8]> { + match self { + Cow::Borrowed(a) => Cow::Borrowed(a.as_bytes()), + Cow::Owned(s) => Cow::Owned(s.into_bytes()), + } + } +} + +impl ArrayFromIter for BinaryArray { + fn arr_from_iter>(iter: I) -> Self { + BinaryArray::from_iter_values(iter.into_iter().map(|s| s.into_bytes())) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: our iterator is TrustedLen. + MutableBinaryArray::from_trusted_len_values_iter_unchecked( + iter.into_iter().map(|s| s.into_bytes()), + ) + .into() + } + } + + fn try_arr_from_iter>>(iter: I) -> Result { + // No built-in for this? + let mut arr = MutableBinaryValuesArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.into_bytes()); + Ok(()) + })?; + Ok(arr.into()) + } + + // No faster implementation than this available, fall back to default. 
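// A safe, simplified sketch of what the impl_collect_vec_validity! macro above produces:
// a values buffer (None filled with the default value) plus an LSB-first validity mask,
// one bit per element packed into bytes. The real macro reserves capacity up front, uses
// unchecked pushes, and only keeps the bitmap when at least one null was seen; the
// function name here is illustrative only.
fn collect_with_validity<T: Default>(
    iter: impl IntoIterator<Item = Option<T>>,
) -> (Vec<T>, Vec<u8>) {
    let mut values = Vec::new();
    let mut validity_bytes = Vec::new();
    for (i, opt) in iter.into_iter().enumerate() {
        if i % 8 == 0 {
            validity_bytes.push(0u8);
        }
        // Set bit i % 8 of the current byte when the element is non-null.
        let is_valid = opt.is_some() as u8;
        *validity_bytes.last_mut().unwrap() |= is_valid << (i % 8);
        values.push(opt.unwrap_or_default());
    }
    (values, validity_bytes)
}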
+ // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter> for BinaryArray { + fn arr_from_iter>>(iter: I) -> Self { + BinaryArray::from_iter(iter.into_iter().map(|s| Some(s?.into_bytes()))) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| Some(s?.into_bytes())), + ) + } + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + // No built-in for this? + let mut arr = MutableBinaryArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.map(|s| s.into_bytes())); + Ok(()) + })?; + Ok(arr.into()) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::try_from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| s.map(|s| Some(s?.into_bytes()))), + ) + } + } +} + +/// We use this to re-use the binary collect implementation for strings. +/// # Safety +/// The array must be valid UTF-8. +unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { + unsafe { + let (_dt, offsets, values, validity) = arr.into_inner(); + let dt = arrow::datatypes::DataType::LargeUtf8; + Utf8Array::try_new_unchecked(dt, offsets, values, validity).unwrap_unchecked() + } +} + +trait StrIntoBytes: IntoBytes {} +impl StrIntoBytes for String {} +impl<'a> StrIntoBytes for &'a str {} +impl<'a> StrIntoBytes for Cow<'a, str> {} + +impl ArrayFromIter for Utf8Array { + #[inline(always)] + fn arr_from_iter>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter>>(iter: I) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +impl ArrayFromIter> for Utf8Array { + #[inline(always)] + fn arr_from_iter>>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +macro_rules! 
impl_collect_bool_validity { + ($iter: ident, $x:ident, $unpack:expr, $truth:expr, $nullity:expr, $with_valid:literal) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut validity: Vec = Vec::new(); + let lo = iter.size_hint().0; + buf.reserve(8 + 8 * (lo / 64)); + if $with_valid { + validity.reserve(8 + 8 * (lo / 64)); + } + + let mut len = 0; + let mut buf_mask = 0u8; + let mut true_count = 0; + let mut valid_mask = 0u8; + let mut nonnull_count = 0; + 'exhausted: loop { + unsafe { + for i in 0..8 { + let Some($x) = iter.next() else { + break 'exhausted; + }; + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let $x = $unpack; + let is_true: bool = $truth; + buf_mask |= (is_true as u8) << i; + true_count += is_true as usize; + if $with_valid { + let nonnull: bool = $nullity; + valid_mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + } + len += 1; + } + + buf.push_unchecked(buf_mask); + buf_mask = 0; + if $with_valid { + validity.push_unchecked(valid_mask); + valid_mask = 0; + } + } + + if buf.len() == buf.capacity() { + buf.reserve(8); // Waste some space to make branch more predictable. + if $with_valid { + validity.reserve(8); + } + } + } + + unsafe { + // SAFETY: when we broke to 'exhausted we had capacity by the loop invariant. + // It's also no problem if we make the mask bigger than strictly necessary. + buf.push_unchecked(buf_mask); + if $with_valid { + validity.push_unchecked(valid_mask); + } + } + + let false_count = len - true_count; + let values = + unsafe { Bitmap::from_inner(Arc::new(buf.into()), 0, len, false_count).unwrap() }; + + let null_count = len - nonnull_count; + let validity_bitmap = if $with_valid && null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. + Some(Bitmap::from_inner(Arc::new(validity.into()), 0, len, null_count).unwrap()) + } + } else { + None + }; + + (values, validity_bitmap) + }}; +} + +impl ArrayFromIter for BooleanArray { + fn arr_from_iter>(iter: I) -> Self { + let dt = arrow::datatypes::DataType::Boolean; + let (values, _valid) = impl_collect_bool_validity!(iter, x, x, x, false, false); + BooleanArray::new(dt, values, None) + } + + // TODO: are efficient trusted collects for booleans worth it? + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter>>(iter: I) -> Result { + let dt = arrow::datatypes::DataType::Boolean; + let (values, _valid) = impl_collect_bool_validity!(iter, x, x?, x, false, false); + Ok(BooleanArray::new(dt, values, None)) + } + + // fn try_arr_from_iter_trusted>>( +} + +impl ArrayFromIter> for BooleanArray { + fn arr_from_iter>>(iter: I) -> Self { + let dt = arrow::datatypes::DataType::Boolean; + let (values, valid) = + impl_collect_bool_validity!(iter, x, x, x.unwrap_or(false), x.is_some(), true); + BooleanArray::new(dt, values, valid) + } + + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let dt = arrow::datatypes::DataType::Boolean; + let (values, valid) = + impl_collect_bool_validity!(iter, x, x?, x.unwrap_or(false), x.is_some(), true); + Ok(BooleanArray::new(dt, values, valid)) + } + + // fn try_arr_from_iter_trusted, E>>>( +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef for Option<&dyn Array> could be implemented. +trait AsArray { + fn as_array(&self) -> &dyn Array; + fn into_boxed_array(self) -> Box; // Prevents unnecessary re-boxing. 
+} +impl AsArray for Box { + fn as_array(&self) -> &dyn Array { + self.as_ref() + } + fn into_boxed_array(self) -> Box { + self + } +} +impl<'a> AsArray for &'a dyn Array { + fn as_array(&self) -> &'a dyn Array { + *self + } + fn into_boxed_array(self) -> Box { + self.to_boxed() + } +} + +// TODO: more efficient (fixed size) list collect routines. +impl ArrayFromIterDtype for ListArray { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + let iter_values: Vec = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push(arr.as_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +impl ArrayFromIterDtype> for ListArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + let iter_values: Vec> = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayFromIterDtype> for FixedSizeListArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + let DataType::Array(_, width) = &dtype else { + panic!("FixedSizeListArray::arr_from_iter_with_dtype called with non-Array dtype"); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + builder.push(arr.into_boxed_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayFromIterDtype>> for FixedSizeListArray { + fn arr_from_iter_with_dtype>>>( + dtype: DataType, + iter: I, + ) -> Self { + let DataType::Array(_, width) = &dtype else { + panic!("FixedSizeListArray::arr_from_iter_with_dtype called with non-Array dtype"); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + match arr { + Some(a) => builder.push(a.into_boxed_array()), + None => builder.push_null(), + } + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype< + E, + I: IntoIterator>, E>>, + >( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + 
Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +// TODO: more efficient implementations, I really took the short path here. +#[cfg(feature = "object")] +impl<'a, T: PolarsObject> ArrayFromIterDtype<&'a T> for ObjectArray { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + Self::try_arr_from_iter_with_dtype( + dtype, + iter.into_iter().map(|o| -> Result<_, ()> { Ok(Some(o)) }), + ) + .unwrap() + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + Self::try_arr_from_iter_with_dtype(dtype, iter.into_iter().map(|o| Ok(Some(o?)))) + } +} + +#[cfg(feature = "object")] +impl<'a, T: PolarsObject> ArrayFromIterDtype> for ObjectArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + Self::try_arr_from_iter_with_dtype( + dtype, + iter.into_iter().map(|o| -> Result<_, ()> { Ok(o) }), + ) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + _dtype: DataType, + iter: I, + ) -> Result { + let iter = iter.into_iter(); + let size = iter.size_hint().0; + + let mut null_mask_builder = arrow::bitmap::MutableBitmap::with_capacity(size); + let values: Vec = iter + .map(|value| match value? { + Some(value) => { + null_mask_builder.push(true); + Ok(value.clone()) + }, + None => { + null_mask_builder.push(false); + Ok(T::default()) + }, + }) + .collect::, E>>()?; + + let null_bit_buffer: Option = null_mask_builder.into(); + let null_bitmap = null_bit_buffer; + let len = values.len(); + Ok(ObjectArray { + values: Arc::new(values), + null_bitmap, + offset: 0, + len, + }) + } +} diff --git a/crates/polars-core/src/datatypes/time_unit.rs b/crates/polars-core/src/datatypes/time_unit.rs index 83305320e208..996b9be0e8c8 100644 --- a/crates/polars-core/src/datatypes/time_unit.rs +++ b/crates/polars-core/src/datatypes/time_unit.rs @@ -1,6 +1,6 @@ use super::*; -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] #[cfg_attr( any(feature = "serde-lazy", feature = "serde"), derive(Serialize, Deserialize) diff --git a/crates/polars-core/src/doc/changelog/mod.rs b/crates/polars-core/src/doc/changelog/mod.rs deleted file mode 100644 index 40f167264afc..000000000000 --- a/crates/polars-core/src/doc/changelog/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod v0_10_0_11; -pub mod v0_3; -pub mod v0_4; -pub mod v0_5; -pub mod v0_6; -pub mod v0_7; -pub mod v0_8; -pub mod v0_9; diff --git a/crates/polars-core/src/doc/changelog/v0_10_0_11.rs b/crates/polars-core/src/doc/changelog/v0_10_0_11.rs deleted file mode 100644 index 8136f24f8f80..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_10_0_11.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! # Changelog v0.10 / v0.11 -//! -//! * CSV Read IO -//! - Parallel csv reader -//! * Sample DataFrames/ Series -//! * Performance increase in take kernel -//! * Performance increase in ChunkedArray builders -//! * Join operation on multiple columns. -//! * ~3.5 x performance increase in group_by operations (measured on db-benchmark), -//! due to embarrassingly parallel grouping and better branch prediction (tight loops). -//! * Performance increase on join operation due to better branch prediction. -//! * Categorical datatype and global string cache (BETA). -//! -//! * Lazy -//! - Lot's of bug fixes in optimizer. -//! - Parallel execution of Physical plan -//! - Partition window function -//! - More simplify expression optimizations. -//! - Caching -//! - Alpha release of Aggregate pushdown optimization. -//! 
* Start of general Object type in ChunkedArray/DataFrames/Series diff --git a/crates/polars-core/src/doc/changelog/v0_3.rs b/crates/polars-core/src/doc/changelog/v0_3.rs deleted file mode 100644 index 738021313cdc..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_3.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.3 -//! -//! * Utf8 type is nullable [#37](https://github.com/pola-rs/polars/issues/37) -//! * Support all ARROW numeric types [#40](https://github.com/pola-rs/polars/issues/40) -//! * Support all ARROW temporal types [#46](https://github.com/pola-rs/polars/issues/46) -//! * ARROW IPC Reader/ Writer [#50](https://github.com/pola-rs/polars/issues/50) -//! * Implement DoubleEndedIterator trait for ChunkedArray's [#34](https://github.com/pola-rs/polars/issues/34) -//! diff --git a/crates/polars-core/src/doc/changelog/v0_4.rs b/crates/polars-core/src/doc/changelog/v0_4.rs deleted file mode 100644 index d357526134ef..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_4.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! # Changelog v0.4 -//! -//! * median aggregation added to `ChunkedArray` -//! * Arrow LargeList datatype support (and group_by aggregation into LargeList). -//! * Shift operation. -//! * Fill None operation. -//! * Buffered serialization (less memory requirements) -//! * Temporal utilities -//! diff --git a/crates/polars-core/src/doc/changelog/v0_5.rs b/crates/polars-core/src/doc/changelog/v0_5.rs deleted file mode 100644 index 7d82f3271cc0..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_5.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! # Changelog v0.5 -//! -//! * `DataFrame.column` returns `Result<_>` **breaking change**. -//! * Define idiomatic way to do inplace operations on a `DataFrame` with `apply`, `try_apply` and `ChunkSet` -//! * `ChunkSet` Trait. -//! * `Groupby` aggregations can be done on a selection of multiple columns. -//! * `Groupby` operation can be done on multiple keys. -//! * `Groupby` `first` operation. -//! * `Pivot` operation. -//! * Random access to `ChunkedArray` types via `.get` and `.get_unchecked`. -//! diff --git a/crates/polars-core/src/doc/changelog/v0_6.rs b/crates/polars-core/src/doc/changelog/v0_6.rs deleted file mode 100644 index 23e38f3d2369..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_6.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.6 -//! -//! * Add more distributions for random sampling. -//! * Fix float aggregations with NaNs. -//! * Comparisons are more performant. -//! * Outer join is more performant. -//! * Start with parallel iterator support for ChunkedArrays. -//! * Remove crossbeam dependency. diff --git a/crates/polars-core/src/doc/changelog/v0_7.rs b/crates/polars-core/src/doc/changelog/v0_7.rs deleted file mode 100644 index 55996f2fcaa5..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_7.rs +++ /dev/null @@ -1,32 +0,0 @@ -//! # Changelog v0.7 -//! -//! * More group by aggregations: -//! - n_unique -//! - quantile -//! - median -//! - last -//! - group indexes -//! - agg (combined aggregations) -//! * explode operation -//! * melt operation -//! * df! macro -//! * Rem trait implemented for Series and ChunkedArrays -//! * ChunkedArrays broadcasting arithmetic -//! * ChunkedArray/Series `zip_with` operation -//! * ChunkedArray/Series `new_from_index` operation -//! * laziness api initiated. -//! - Predicate pushdown optimizer -//! - Projection pushdown optimizer -//! - Type coercion optimizer -//! - Selection (filter, where clause) -//! - Projection (select foo from bar) -//! 
- Aggregation (group_by) -//! - all eager aggregations supported -//! - Joins -//! - WithColumn operation -//! - DSL -//! * (col, lit, lt, lt_eq, alias, etc.) -//! * arithmetic -//! * when / then /otherwise -//! * 1.3-1.7 performance increase of filter -//! * ChunkedArray/ Series creation speedup: No nulls: 10X speedup, Nulls: 1.1-2.2x speedup. diff --git a/crates/polars-core/src/doc/changelog/v0_8.rs b/crates/polars-core/src/doc/changelog/v0_8.rs deleted file mode 100644 index 3d7c6fdabb8f..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_8.rs +++ /dev/null @@ -1,24 +0,0 @@ -//! # Changelog v0.8 -//! -//! * Upgrade to Arrow 2.0 -//! * Add quantile aggregation to `ChunkedArray` -//! * Option to stop reading CSV after n rows. -//! * Read parquet file in a single batch reducing reading time. -//! * Faster kernel for zip_with and set_with operation -//! * String utilities -//! - Utf8Chunked::str_lengths method -//! - Utf8Chunked::contains method -//! - Utf8Chunked::replace method -//! - Utf8Chunked::replace_all method -//! * Temporal utilities -//! - Utf8Chunked to dat32 / datetime -//! * Lazy -//! - fill_null expression -//! - shift expression -//! - Series aggregations -//! - aggregations on DataFrame level -//! - aggregate to largelist -//! - a lot of bugs fixed in optimizers -//! - UDF's / closures in lazy dsl -//! - DataFrame reverse operation -//! diff --git a/crates/polars-core/src/doc/changelog/v0_9.rs b/crates/polars-core/src/doc/changelog/v0_9.rs deleted file mode 100644 index f0ece2b79bcf..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_9.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! # Changelog v0.9 -//! -//! * CSV Read IO -//! - large performance increase -//! - skip_rows -//! - ignore parser errors -//! * Overall performance increase by using aHash in favor of FNV. -//! * Groupby floating point keys -//! * DataFrame operations -//! - drop_nulls -//! - drop duplicate rows -//! * Temporal handling -//! * Lazy -//! - a lot of bug fixes in the optimizer -//! - start of optimizer framework -//! - start of simplify expression optimizer -//! - csv scan -//! - various operations -//! * Start of general Object type in ChunkedArray/DataFrames/Series diff --git a/crates/polars-core/src/doc/mod.rs b/crates/polars-core/src/doc/mod.rs deleted file mode 100644 index 18169f152474..000000000000 --- a/crates/polars-core/src/doc/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! 
Other documentation -pub mod changelog; diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index a309cc249a87..59c69f8cdc95 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -197,7 +197,6 @@ where { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let limit = std::cmp::min(LIMIT, self.len()); - let taker = self.take_rand(); let inner_type = T::type_name(); write!( f, @@ -208,21 +207,21 @@ where if limit < self.len() { for i in 0..limit / 2 { - match taker.get(i) { + match self.get(i) { None => writeln!(f, "\tnull")?, Some(val) => writeln!(f, "\t{val}")?, }; } writeln!(f, "\t…")?; for i in (0..limit / 2).rev() { - match taker.get(self.len() - i - 1) { + match self.get(self.len() - i - 1) { None => writeln!(f, "\tnull")?, Some(val) => writeln!(f, "\t{val}")?, }; } } else { for i in 0..limit { - match taker.get(i) { + match self.get(i) { None => writeln!(f, "\tnull")?, Some(val) => writeln!(f, "\t{val}")?, }; @@ -314,7 +313,7 @@ impl Debug for Series { "Series" ), DataType::Null => { - writeln!(f, "nullarray") + format_array!(f, self.null().unwrap(), "null", self.name(), "Series") }, DataType::Binary => { format_array!(f, self.binary().unwrap(), "binary", self.name(), "Series") diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs index 4488640a0b45..be60fb04346f 100644 --- a/crates/polars-core/src/frame/arithmetic.rs +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -4,8 +4,9 @@ use rayon::prelude::*; use crate::prelude::*; use crate::utils::try_get_supertype; +use crate::POOL; -/// Get the supertype that is valid for all columns in the DataFrame. +/// Get the supertype that is valid for all columns in the [`DataFrame`]. /// This reduces casting of the rhs in arithmetic. fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult { df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| { @@ -17,9 +18,9 @@ macro_rules! impl_arithmetic { ($self:expr, $rhs:expr, $operand: tt) => {{ let st = get_supertype_all($self, $rhs)?; let rhs = $rhs.cast(&st)?; - let cols = $self.columns.par_iter().map(|s| { + let cols = POOL.install(|| {$self.columns.par_iter().map(|s| { Ok(&s.cast(&st)? 
$operand &rhs) - }).collect::>()?; + }).collect::>()})?; Ok(DataFrame::new_no_checks(cols)) }} } diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index e14fbff23cc0..ed986f7aaefc 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -17,7 +17,7 @@ impl TryFrom for DataFrame { .map(|(fld, arr)| { // Safety // reported data type is correct - unsafe { Series::try_from_arrow_unchecked(&fld.name, vec![arr], fld.data_type()) } + unsafe { Series::_try_from_arrow_unchecked(&fld.name, vec![arr], fld.data_type()) } }) .collect::>>()?; DataFrame::new(columns) diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs index f26bf095d30f..caeea68b8fef 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -157,7 +157,7 @@ impl AggList for BooleanChunked { let mut builder = ListBooleanChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -183,7 +183,7 @@ impl AggList for Utf8Chunked { let mut builder = ListUtf8ChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -208,7 +208,7 @@ impl AggList for BinaryChunked { let mut builder = ListBinaryChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -226,8 +226,8 @@ impl AggList for BinaryChunked { } } -/// This aggregates into a `ListChunked` by slicing the array that is aggregated. -/// Used for `List` and `Array` data types. +/// This aggregates into a [`ListChunked`] by slicing the array that is aggregated. +/// Used for [`List`] and [`Array`] data types. 
fn agg_list_by_slicing< A: PolarsDataType, F: Fn(&ChunkedArray, bool, &mut Vec, &mut i64, &mut Vec) -> bool, @@ -292,7 +292,7 @@ impl AggList for ListChunked { // SAFETY: // group tuples are in bounds { - let mut s = ca.take_unchecked(idx.into()); + let mut s = ca.take_unchecked(idx); let arr = s.chunks.pop().unwrap_unchecked_release(); list_values.push_unchecked(arr); @@ -362,7 +362,7 @@ impl AggList for ArrayChunked { // SAFETY: group tuples are in bounds { - let mut s = ca.take_unchecked(idx.into()); + let mut s = ca.take_unchecked(idx); let arr = s.chunks.pop().unwrap_unchecked_release(); list_values.push_unchecked(arr); } @@ -419,7 +419,7 @@ impl AggList for ObjectChunked { GroupsIndicator::Idx((_first, idx)) => { // SAFETY: // group tuples always in bounds - let group_vals = self.take_unchecked(idx.into()); + let group_vals = self.take_unchecked(idx); (group_vals, idx.len() as IdxSize) }, @@ -481,7 +481,7 @@ impl AggList for StructChunked { Some(self.dtype().clone()), ); for idx in groups.all().iter() { - let taken = s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)); + let taken = s.take_slice_unchecked(idx); builder.append_series(&taken).unwrap(); } builder.finish().into_series() diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index 0e353fe12be2..abee55a44fd0 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -24,8 +24,7 @@ impl Series { } else if !self.has_validity() { Some(idx.len() as IdxSize) } else { - let take = - unsafe { self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) }; + let take = unsafe { self.take_slice_unchecked(idx) }; Some((take.len() - take.null_count()) as IdxSize) } }), @@ -49,31 +48,28 @@ impl Series { pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series { let mut out = match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.iter().map(|(first, idx)| { - if idx.is_empty() { - None - } else { - Some(first as usize) - } - }); - // Safety: - // groups are always in bounds - self.take_opt_iter_unchecked(&mut iter) - }, - GroupsProxy::Slice { groups, .. } => { - let mut iter = - groups.iter().map( - |&[first, len]| { - if len == 0 { + let indices = groups + .iter() + .map( + |(first, idx)| { + if idx.is_empty() { None } else { - Some(first as usize) + Some(first) } }, - ); - // Safety: - // groups are always in bounds - self.take_opt_iter_unchecked(&mut iter) + ) + .collect_ca(""); + // SAFETY: groups are always in bounds. + self.take_unchecked(&indices) + }, + GroupsProxy::Slice { groups, .. } => { + let indices = groups + .iter() + .map(|&[first, len]| if len == 0 { None } else { Some(first) }) + .collect_ca(""); + // SAFETY: groups are always in bounds. 
+ self.take_unchecked(&indices) }, }; if groups.is_sorted_flag() { @@ -90,7 +86,7 @@ impl Series { if idx.is_empty() { None } else { - let take = self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)); + let take = self.take_slice_unchecked(idx); take.n_unique().ok().map(|v| v as IdxSize) } }), @@ -186,24 +182,31 @@ impl Series { pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series { let out = match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.all().iter().map(|idx| { - if idx.is_empty() { - None - } else { - Some(idx[idx.len() - 1] as usize) - } - }); - self.take_opt_iter_unchecked(&mut iter) + let indices = groups + .all() + .iter() + .map(|idx| { + if idx.is_empty() { + None + } else { + Some(idx[idx.len() - 1]) + } + }) + .collect_ca(""); + self.take_unchecked(&indices) }, GroupsProxy::Slice { groups, .. } => { - let mut iter = groups.iter().map(|&[first, len]| { - if len == 0 { - None - } else { - Some((first + len - 1) as usize) - } - }); - self.take_opt_iter_unchecked(&mut iter) + let indices = groups + .iter() + .map(|&[first, len]| { + if len == 0 { + None + } else { + Some(first + len - 1) + } + }) + .collect_ca(""); + self.take_unchecked(&indices) }, }; self.restore_logical(out) diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index b40f9137c554..51b1f24d3e9d 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -355,7 +355,7 @@ where if idx.is_empty() { return None; } - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check take._quantile(quantile, interpol).unwrap_unchecked() }) @@ -429,7 +429,7 @@ where if idx.is_empty() { return None; } - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; take._median() }) }, @@ -527,7 +527,7 @@ where 1 => self.get(first as usize), _ => { let arr_group = _slice_from_offsets(self, first, len); - arr_group.min() + ChunkAgg::min(&arr_group) }, } }) @@ -610,7 +610,7 @@ where 1 => self.get(first as usize), _ => { let arr_group = _slice_from_offsets(self, first, len); - arr_group.max() + ChunkAgg::max(&arr_group) }, } }) @@ -691,11 +691,8 @@ where impl SeriesWrap> where T: PolarsFloatType, - ChunkedArray: IntoSeries - + ChunkVar - + VarAggSeries - + ChunkQuantile - + QuantileAggSeries, + ChunkedArray: + IntoSeries + ChunkVar + VarAggSeries + ChunkQuantile + QuantileAggSeries, T::Native: Simd + NumericNative + Pow, ::Simd: std::ops::Add::Simd> + arrow::compute::aggregate::Sum @@ -987,7 +984,7 @@ where }) }, _ => { - let take = { self.take_unchecked(idx.into()) }; + let take = { self.take_unchecked(idx) }; take.mean() }, } @@ -1114,5 +1111,3 @@ where agg_median_generic::<_, Float64Type>(self, groups) } } - -impl ChunkedArray where ChunkedArray: ChunkTake + IntoSeries {} diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index 3e24df8817b3..462e10fbe5cf 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -10,24 +10,20 @@ use super::GroupsProxy; use crate::datatypes::PlHashMap; use crate::frame::group_by::{GroupsIdx, IdxItem}; use crate::hashing::{ - df_rows_to_hashes_threaded_vertical, series_to_hashes, this_partition, AsU64, IdBuildHasher, - IdxHash, + _df_rows_to_hashes_threaded_vertical, series_to_hashes, 
this_partition, AsU64, IdBuildHasher, + IdxHash, *, }; use crate::prelude::compare_inner::PartialEqInner; use crate::prelude::*; use crate::utils::{flatten, split_df, CustomIterTools}; use crate::POOL; -// We must strike a balance between cache coherence and resizing costs. -// Overallocation seems a lot more expensive than resizing so we start reasonable small. -pub(crate) const HASHMAP_INIT_SIZE: usize = 512; - fn get_init_size() -> usize { // we check if this is executed from the main thread // we don't want to pre-allocate this much if executed // group_tuples in a parallel iterator as that explodes allocation if POOL.current_thread_index().is_none() { - HASHMAP_INIT_SIZE + _HASHMAP_INIT_SIZE } else { 0 } @@ -82,9 +78,9 @@ fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsProxy { } } -// The inner vecs should be sorted by IdxSize +// The inner vecs should be sorted by [`IdxSize`] // the group_by multiple keys variants suffice -// this requirements as they use an IdxMap strategy +// this requirements as they use an [`IdxMap`] strategy fn finish_group_order_vecs( mut vecs: Vec<(Vec, Vec>)>, sorted: bool, @@ -319,75 +315,6 @@ where finish_group_order(out, sorted) } -/// Utility function used as comparison function in the hashmap. -/// The rationale is that equality is an AND operation and therefore its probability of success -/// declines rapidly with the number of keys. Instead of first copying an entire row from both -/// sides and then do the comparison, we do the comparison value by value catching early failures -/// eagerly. -/// -/// # Safety -/// Doesn't check any bounds -#[inline] -pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool { - for s in keys.get_columns() { - if !s.equal_element(idx_a, idx_b, s) { - return false; - } - } - true -} - -/// Populate a multiple key hashmap with row indexes. -/// Instead of the keys (which could be very large), the row indexes are stored. -/// To check if a row is equal the original DataFrame is also passed as ref. -/// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared -/// on equality. -pub(crate) fn populate_multiple_key_hashmap( - hash_tbl: &mut HashMap, - // row index - idx: IdxSize, - // hash - original_h: u64, - // keys of the hash table (will not be inserted, the indexes will be used) - // the keys are needed for the equality check - keys: &DataFrame, - // value to insert - vacant_fn: G, - // function that gets a mutable ref to the occupied value in the hash table - mut occupied_fn: F, -) where - G: Fn() -> V, - F: FnMut(&mut V), - H: BuildHasher, -{ - let entry = hash_tbl - .raw_entry_mut() - // uses the idx to probe rows in the original DataFrame with keys - // to check equality to find an entry - // this does not invalidate the hashmap as this equality function is not used - // during rehashing/resize (then the keys are already known to be unique). - // Only during insertion and probing an equality function is needed - .from_hash(original_h, |idx_hash| { - // first check the hash values - // before we incur a cache miss - idx_hash.hash == original_h && { - let key_idx = idx_hash.idx; - // Safety: - // indices in a group_by operation are always in bounds. 
- unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } - } - }); - match entry { - RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); - }, - RawEntryMut::Occupied(mut entry) => { - let (_k, v) = entry.get_key_value_mut(); - occupied_fn(v); - }, - } -} - #[inline] pub(crate) unsafe fn compare_keys<'a>( keys_cmp: &'a [Box], @@ -456,7 +383,7 @@ pub(crate) fn group_by_threaded_multiple_keys_flat( sorted: bool, ) -> PolarsResult { let dfs = split_df(&mut keys, n_partitions).unwrap(); - let (hashes, _random_state) = df_rows_to_hashes_threaded_vertical(&dfs, None)?; + let (hashes, _random_state) = _df_rows_to_hashes_threaded_vertical(&dfs, None)?; let n_partitions = n_partitions as u64; let init_size = get_init_size(); diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 2ff670ac248f..fc457c8bfdba 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -60,16 +60,21 @@ impl DataFrame { !by.is_empty(), ComputeError: "at least one key is required in a group_by operation" ); - let by_len = by[0].len(); + let minimal_by_len = by.iter().map(|s| s.len()).min().expect("at least 1 key"); + let df_height = self.height(); // we only throw this error if self.width > 0 // so that we can still call this on a dummy dataframe where we provide the keys - if (by_len != self.height()) && (self.width() > 0) { + if (minimal_by_len != df_height) && (self.width() > 0) { polars_ensure!( - by_len == 1, + minimal_by_len == 1, ShapeMismatch: "series used as keys should have the same length as the dataframe" ); - by[0] = by[0].new_from_index(0, self.height()) + for by_key in by.iter_mut() { + if by_key.len() == minimal_by_len { + *by_key = by_key.new_from_index(0, df_height) + } + } }; let n_partitions = _set_partition_size(); @@ -265,10 +270,8 @@ impl<'df> GroupBy<'df> { .map(|s| { match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.first().iter().map(|first| *first as usize); - // Safety: - // groups are always in bounds - let mut out = unsafe { s.take_iter_unchecked(&mut iter) }; + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_slice_unchecked(groups.first()) }; if groups.sorted { out.set_sorted_flag(s.is_sorted_flag()); }; @@ -276,7 +279,7 @@ impl<'df> GroupBy<'df> { }, GroupsProxy::Slice { groups, rolling } => { if *rolling && !groups.is_empty() { - // groups can be sliced + // Groups can be sliced. let offset = groups[0][0]; let [upper_offset, upper_len] = groups[groups.len() - 1]; return s.slice( @@ -285,11 +288,10 @@ impl<'df> GroupBy<'df> { ); } - let mut iter = groups.iter().map(|&[first, _len]| first as usize); - // Safety: - // groups are always in bounds - let mut out = unsafe { s.take_iter_unchecked(&mut iter) }; - // sliced groups are always in order of discovery + let indices = groups.iter().map(|&[first, _len]| first).collect_ca(""); + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_unchecked(&indices) }; + // Sliced groups are always in order of discovery. out.set_sorted_flag(s.is_sorted_flag()); out }, @@ -583,7 +585,7 @@ impl<'df> GroupBy<'df> { DataFrame::new(cols) } - /// Aggregate grouped `Series` and determine the quantile per group. + /// Aggregate grouped [`Series`] and determine the quantile per group. 
/// /// # Example /// @@ -616,7 +618,7 @@ impl<'df> GroupBy<'df> { DataFrame::new(cols) } - /// Aggregate grouped `Series` and determine the median per group. + /// Aggregate grouped [`Series`] and determine the median per group. /// /// # Example /// @@ -638,7 +640,7 @@ impl<'df> GroupBy<'df> { DataFrame::new(cols) } - /// Aggregate grouped `Series` and determine the variance per group. + /// Aggregate grouped [`Series`] and determine the variance per group. #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] pub fn var(&self, ddof: u8) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; @@ -651,7 +653,7 @@ impl<'df> GroupBy<'df> { DataFrame::new(cols) } - /// Aggregate grouped `Series` and determine the standard deviation per group. + /// Aggregate grouped [`Series`] and determine the standard deviation per group. #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] pub fn std(&self, ddof: u8) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; @@ -790,7 +792,7 @@ impl<'df> GroupBy<'df> { } } - /// Apply a closure over the groups as a new DataFrame in parallel. + /// Apply a closure over the groups as a new [`DataFrame`] in parallel. #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] pub fn par_apply(&self, f: F) -> PolarsResult where @@ -813,7 +815,7 @@ impl<'df> GroupBy<'df> { Ok(df) } - /// Apply a closure over the groups as a new DataFrame. + /// Apply a closure over the groups as a new [`DataFrame`]. pub fn apply(&self, mut f: F) -> PolarsResult where F: FnMut(DataFrame) -> PolarsResult + Send + Sync, @@ -838,7 +840,7 @@ impl<'df> GroupBy<'df> { unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame { match g { - GroupsIndicator::Idx(idx) => df.take_iter_unchecked(idx.1.iter().map(|i| *i as usize)), + GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1), GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize), } } diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 43bb126dfd04..49f9afc0c9bd 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -229,7 +229,7 @@ impl CategoricalChunked { #[repr(C, align(64))] struct AlignTo64([u8; 64]); -/// There are no guarantees that the Vec will remain aligned if you reallocate the data. +/// There are no guarantees that the [`Vec`] will remain aligned if you reallocate the data. /// This means that you cannot reallocate so you will need to know how big to allocate up front. unsafe fn aligned_vec(n: usize) -> Vec { assert!(std::mem::align_of::() <= 64); diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index ebd33232772d..d8bc3b4c60e3 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -367,6 +367,23 @@ impl GroupsProxy { } } + /// # Safety + /// This will not do any bounds checks. The caller must ensure + /// all groups have members. + pub unsafe fn take_group_lasts(self) -> Vec { + match self { + GroupsProxy::Idx(groups) => groups + .all + .iter() + .map(|idx| *idx.get_unchecked(idx.len() - 1)) + .collect(), + GroupsProxy::Slice { groups, .. 
} => groups + .into_iter() + .map(|[first, len]| first + len - 1) + .collect(), + } + } + pub fn par_iter(&self) -> GroupsProxyParIter { GroupsProxyParIter::new(self) } diff --git a/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs b/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs deleted file mode 100644 index 6bcf87273fea..000000000000 --- a/crates/polars-core/src/frame/hash_join/single_keys_dispatch.rs +++ /dev/null @@ -1,543 +0,0 @@ -use num_traits::NumCast; - -use super::single_keys_inner::hash_join_tuples_inner; -use super::*; -#[cfg(feature = "chunked_ids")] -use crate::utils::create_chunked_index_mapping; - -impl Series { - #[doc(hidden)] - pub fn hash_join_left( - &self, - other: &Series, - validate: JoinValidation, - ) -> PolarsResult { - let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); - validate.validate_probe(&lhs, &rhs, false)?; - - use DataType::*; - match lhs.dtype() { - Utf8 => { - let lhs = lhs.utf8().unwrap(); - let rhs = rhs.utf8().unwrap(); - - let lhs = lhs.as_binary(); - let rhs = rhs.as_binary(); - lhs.hash_join_left(&rhs, validate) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - lhs.hash_join_left(rhs, validate) - }, - _ => { - if self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_left(&lhs, &rhs, validate) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_left(&lhs, &rhs, validate) - } - }, - } - } - - #[cfg(feature = "semi_anti_join")] - pub(super) fn hash_join_semi_anti(&self, other: &Series, anti: bool) -> Vec { - let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); - - use DataType::*; - match lhs.dtype() { - Utf8 => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); - - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - lhs.hash_join_semi_anti(rhs, anti) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - lhs.hash_join_semi_anti(rhs, anti) - }, - _ => { - if self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_anti_semi(&lhs, &rhs, anti) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_anti_semi(&lhs, &rhs, anti) - } - }, - } - } - - // returns the join tuples and whether or not the lhs tuples are sorted - pub(super) fn hash_join_inner( - &self, - other: &Series, - validate: JoinValidation, - ) -> PolarsResult<(InnerJoinIds, bool)> { - let (lhs, rhs) = (self.to_physical_repr(), other.to_physical_repr()); - validate.validate_probe(&lhs, &rhs, true)?; - - use DataType::*; - match lhs.dtype() { - Utf8 => { - let lhs = lhs.utf8().unwrap(); - let rhs = rhs.utf8().unwrap(); - - let lhs = lhs.as_binary(); - let rhs = rhs.as_binary(); - lhs.hash_join_inner(&rhs, validate) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - lhs.hash_join_inner(rhs, validate) - }, - _ => { - if self.bit_repr_is_large() { - let lhs = self.bit_repr_large(); - let rhs = other.bit_repr_large(); - num_group_join_inner(&lhs, &rhs, validate) - } else { - let lhs = self.bit_repr_small(); - let rhs = other.bit_repr_small(); - num_group_join_inner(&lhs, &rhs, validate) - } - }, - } - } - - pub(super) fn hash_join_outer( - &self, - other: &Series, - validate: JoinValidation, - ) -> PolarsResult, Option)>> { - let (lhs, rhs) = (self.to_physical_repr(), 
other.to_physical_repr()); - validate.validate_probe(&lhs, &rhs, true)?; - - use DataType::*; - match lhs.dtype() { - Utf8 => { - let lhs = lhs.utf8().unwrap(); - let rhs = rhs.utf8().unwrap(); - - let lhs = lhs.as_binary(); - let rhs = rhs.as_binary(); - lhs.hash_join_outer(&rhs, validate) - }, - Binary => { - let lhs = lhs.binary().unwrap(); - let rhs = rhs.binary().unwrap(); - lhs.hash_join_outer(rhs, validate) - }, - _ => { - if self.bit_repr_is_large() { - let lhs = self.bit_repr_large(); - let rhs = other.bit_repr_large(); - lhs.hash_join_outer(&rhs, validate) - } else { - let lhs = self.bit_repr_small(); - let rhs = other.bit_repr_small(); - lhs.hash_join_outer(&rhs, validate) - } - }, - } - } -} - -fn splitted_to_slice(splitted: &[ChunkedArray]) -> Vec<&[T::Native]> -where - T: PolarsNumericType, -{ - splitted.iter().map(|ca| ca.cont_slice().unwrap()).collect() -} - -fn splitted_by_chunks(splitted: &[ChunkedArray]) -> Vec<&[T::Native]> -where - T: PolarsNumericType, -{ - splitted - .iter() - .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) - .collect() -} - -fn splitted_to_opt_vec(splitted: &[ChunkedArray]) -> Vec>> -where - T: PolarsNumericType, -{ - POOL.install(|| { - splitted - .par_iter() - .map(|ca| ca.into_iter().collect_trusted::>()) - .collect() - }) -} - -// returns the join tuples and whether or not the lhs tuples are sorted -fn num_group_join_inner( - left: &ChunkedArray, - right: &ChunkedArray, - validate: JoinValidation, -) -> PolarsResult<(InnerJoinIds, bool)> -where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + AsU64 + Copy, - Option: AsU64, -{ - let n_threads = POOL.current_num_threads(); - let (a, b, swapped) = det_hash_prone_order!(left, right); - let splitted_a = split_ca(a, n_threads).unwrap(); - let splitted_b = split_ca(b, n_threads).unwrap(); - match ( - left.null_count() == 0, - right.null_count() == 0, - left.chunks.len(), - right.chunks.len(), - ) { - (true, true, 1, 1) => { - let keys_a = splitted_to_slice(&splitted_a); - let keys_b = splitted_to_slice(&splitted_b); - Ok(( - hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, - !swapped, - )) - }, - (true, true, _, _) => { - let keys_a = splitted_by_chunks(&splitted_a); - let keys_b = splitted_by_chunks(&splitted_b); - Ok(( - hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, - !swapped, - )) - }, - _ => { - let keys_a = splitted_to_opt_vec(&splitted_a); - let keys_b = splitted_to_opt_vec(&splitted_b); - Ok(( - hash_join_tuples_inner(keys_a, keys_b, swapped, validate)?, - !swapped, - )) - }, - } -} - -#[cfg(feature = "chunked_ids")] -fn create_mappings( - chunks_left: &[ArrayRef], - chunks_right: &[ArrayRef], - left_len: usize, - right_len: usize, -) -> (Option>, Option>) { - let mapping_left = || { - if chunks_left.len() > 1 { - Some(create_chunked_index_mapping(chunks_left, left_len)) - } else { - None - } - }; - - let mapping_right = || { - if chunks_right.len() > 1 { - Some(create_chunked_index_mapping(chunks_right, right_len)) - } else { - None - } - }; - - POOL.join(mapping_left, mapping_right) -} - -#[cfg(not(feature = "chunked_ids"))] -fn create_mappings( - _chunks_left: &[ArrayRef], - _chunks_right: &[ArrayRef], - _left_len: usize, - _right_len: usize, -) -> (Option>, Option>) { - (None, None) -} - -fn num_group_join_left( - left: &ChunkedArray, - right: &ChunkedArray, - validate: JoinValidation, -) -> PolarsResult -where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + AsU64, - Option: AsU64, -{ - let n_threads = POOL.current_num_threads(); 
- let splitted_a = split_ca(left, n_threads).unwrap(); - let splitted_b = split_ca(right, n_threads).unwrap(); - match ( - left.null_count(), - right.null_count(), - left.chunks.len(), - right.chunks.len(), - ) { - (0, 0, 1, 1) => { - let keys_a = splitted_to_slice(&splitted_a); - let keys_b = splitted_to_slice(&splitted_b); - hash_join_tuples_left(keys_a, keys_b, None, None, validate) - }, - (0, 0, _, _) => { - let keys_a = splitted_by_chunks(&splitted_a); - let keys_b = splitted_by_chunks(&splitted_b); - - let (mapping_left, mapping_right) = - create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); - hash_join_tuples_left( - keys_a, - keys_b, - mapping_left.as_deref(), - mapping_right.as_deref(), - validate, - ) - }, - _ => { - let keys_a = splitted_to_opt_vec(&splitted_a); - let keys_b = splitted_to_opt_vec(&splitted_b); - let (mapping_left, mapping_right) = - create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); - hash_join_tuples_left( - keys_a, - keys_b, - mapping_left.as_deref(), - mapping_right.as_deref(), - validate, - ) - }, - } -} - -impl ChunkedArray -where - T: PolarsIntegerType + Sync, - T::Native: Eq + Hash + NumCast, -{ - fn hash_join_outer( - &self, - other: &ChunkedArray, - validate: JoinValidation, - ) -> PolarsResult, Option)>> { - let (a, b, swapped) = det_hash_prone_order!(self, other); - - let n_partitions = _set_partition_size(); - let splitted_a = split_ca(a, n_partitions).unwrap(); - let splitted_b = split_ca(b, n_partitions).unwrap(); - - match (a.null_count(), b.null_count()) { - (0, 0) => { - let iters_a = splitted_a - .iter() - .map(|ca| ca.into_no_null_iter()) - .collect::>(); - let iters_b = splitted_b - .iter() - .map(|ca| ca.into_no_null_iter()) - .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swapped, validate) - }, - _ => { - let iters_a = splitted_a - .iter() - .map(|ca| ca.into_iter()) - .collect::>(); - let iters_b = splitted_b - .iter() - .map(|ca| ca.into_iter()) - .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swapped, validate) - }, - } - } -} - -pub(crate) fn prepare_bytes<'a>( - been_split: &'a [BinaryChunked], - hb: &RandomState, -) -> Vec>> { - POOL.install(|| { - been_split - .par_iter() - .map(|ca| { - ca.into_iter() - .map(|opt_b| { - let mut state = hb.build_hasher(); - opt_b.hash(&mut state); - let hash = state.finish(); - BytesHash::new(opt_b, hash) - }) - .collect::>() - }) - .collect() - }) -} - -impl BinaryChunked { - fn prepare( - &self, - other: &BinaryChunked, - // In inner join and outer join, the shortest relation will be used to create a hash table. - // In left join, always use the right side to create. 
- build_shortest_table: bool, - ) -> (Vec, Vec, bool, RandomState) { - let n_threads = POOL.current_num_threads(); - - let (a, b, swapped) = if build_shortest_table { - det_hash_prone_order!(self, other) - } else { - (self, other, false) - }; - - let hb = RandomState::default(); - let splitted_a = split_ca(a, n_threads).unwrap(); - let splitted_b = split_ca(b, n_threads).unwrap(); - - (splitted_a, splitted_b, swapped, hb) - } - - // returns the join tuples and whether or not the lhs tuples are sorted - fn hash_join_inner( - &self, - other: &BinaryChunked, - validate: JoinValidation, - ) -> PolarsResult<(InnerJoinIds, bool)> { - let (splitted_a, splitted_b, swapped, hb) = self.prepare(other, true); - let str_hashes_a = prepare_bytes(&splitted_a, &hb); - let str_hashes_b = prepare_bytes(&splitted_b, &hb); - Ok(( - hash_join_tuples_inner(str_hashes_a, str_hashes_b, swapped, validate)?, - !swapped, - )) - } - - fn hash_join_left( - &self, - other: &BinaryChunked, - validate: JoinValidation, - ) -> PolarsResult { - let (splitted_a, splitted_b, _, hb) = self.prepare(other, false); - let str_hashes_a = prepare_bytes(&splitted_a, &hb); - let str_hashes_b = prepare_bytes(&splitted_b, &hb); - - let (mapping_left, mapping_right) = - create_mappings(self.chunks(), other.chunks(), self.len(), other.len()); - hash_join_tuples_left( - str_hashes_a, - str_hashes_b, - mapping_left.as_deref(), - mapping_right.as_deref(), - validate, - ) - } - - #[cfg(feature = "semi_anti_join")] - fn hash_join_semi_anti(&self, other: &BinaryChunked, anti: bool) -> Vec { - let (splitted_a, splitted_b, _, hb) = self.prepare(other, false); - let str_hashes_a = prepare_bytes(&splitted_a, &hb); - let str_hashes_b = prepare_bytes(&splitted_b, &hb); - if anti { - hash_join_tuples_left_anti(str_hashes_a, str_hashes_b) - } else { - hash_join_tuples_left_semi(str_hashes_a, str_hashes_b) - } - } - - fn hash_join_outer( - &self, - other: &BinaryChunked, - validate: JoinValidation, - ) -> PolarsResult, Option)>> { - let (a, b, swapped) = det_hash_prone_order!(self, other); - - let n_partitions = _set_partition_size(); - let splitted_a = split_ca(a, n_partitions).unwrap(); - let splitted_b = split_ca(b, n_partitions).unwrap(); - - match (a.has_validity(), b.has_validity()) { - (false, false) => { - let iters_a = splitted_a - .iter() - .map(|ca| ca.into_no_null_iter()) - .collect::>(); - let iters_b = splitted_b - .iter() - .map(|ca| ca.into_no_null_iter()) - .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swapped, validate) - }, - _ => { - let iters_a = splitted_a - .iter() - .map(|ca| ca.into_iter()) - .collect::>(); - let iters_b = splitted_b - .iter() - .map(|ca| ca.into_iter()) - .collect::>(); - hash_join_tuples_outer(iters_a, iters_b, swapped, validate) - }, - } - } -} - -#[cfg(feature = "semi_anti_join")] -fn num_group_join_anti_semi( - left: &ChunkedArray, - right: &ChunkedArray, - anti: bool, -) -> Vec -where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + AsU64, - Option: AsU64, -{ - let n_threads = POOL.current_num_threads(); - let splitted_a = split_ca(left, n_threads).unwrap(); - let splitted_b = split_ca(right, n_threads).unwrap(); - match ( - left.null_count(), - right.null_count(), - left.chunks.len(), - right.chunks.len(), - ) { - (0, 0, 1, 1) => { - let keys_a = splitted_to_slice(&splitted_a); - let keys_b = splitted_to_slice(&splitted_b); - if anti { - hash_join_tuples_left_anti(keys_a, keys_b) - } else { - hash_join_tuples_left_semi(keys_a, keys_b) - } - }, - (0, 0, _, _) => { - let keys_a = 
splitted_by_chunks(&splitted_a); - let keys_b = splitted_by_chunks(&splitted_b); - if anti { - hash_join_tuples_left_anti(keys_a, keys_b) - } else { - hash_join_tuples_left_semi(keys_a, keys_b) - } - }, - _ => { - let keys_a = splitted_to_opt_vec(&splitted_a); - let keys_b = splitted_to_opt_vec(&splitted_b); - if anti { - hash_join_tuples_left_anti(keys_a, keys_b) - } else { - hash_join_tuples_left_semi(keys_a, keys_b) - } - }, - } -} diff --git a/crates/polars-core/src/frame/hash_join/zip_outer.rs b/crates/polars-core/src/frame/hash_join/zip_outer.rs deleted file mode 100644 index 1da5ed1635f9..000000000000 --- a/crates/polars-core/src/frame/hash_join/zip_outer.rs +++ /dev/null @@ -1,123 +0,0 @@ -use super::*; - -pub trait ZipOuterJoinColumn { - fn zip_outer_join_column( - &self, - _right_column: &Series, - _opt_join_tuples: &[(Option, Option)], - ) -> Series { - unimplemented!() - } -} - -impl ZipOuterJoinColumn for ChunkedArray -where - T: PolarsIntegerType, - ChunkedArray: IntoSeries, -{ - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx as usize) } - } else { - unsafe { - let right_idx = opt_right_idx.unwrap_unchecked(); - right_rand_access.get_unchecked(right_idx as usize) - } - } - }) - .collect_trusted::>() - .into_series() - } -} - -macro_rules! impl_zip_outer_join { - ($chunkedtype:ident) => { - impl ZipOuterJoinColumn for $chunkedtype { - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx as usize) } - } else { - unsafe { - let right_idx = opt_right_idx.unwrap_unchecked(); - right_rand_access.get_unchecked(right_idx as usize) - } - } - }) - .collect::<$chunkedtype>() - .into_series() - } - } - }; -} -impl_zip_outer_join!(BooleanChunked); -impl_zip_outer_join!(BinaryChunked); - -impl ZipOuterJoinColumn for Utf8Chunked { - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - unsafe { - let out = self.as_binary().zip_outer_join_column( - &right_column.cast_unchecked(&DataType::Binary).unwrap(), - opt_join_tuples, - ); - out.cast_unchecked(&DataType::Utf8).unwrap_unchecked() - } - } -} - -impl ZipOuterJoinColumn for Float32Chunked { - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - self.apply_as_ints(|s| { - s.zip_outer_join_column( - &right_column.bit_repr_small().into_series(), - opt_join_tuples, - ) - }) - } -} - -impl ZipOuterJoinColumn for Float64Chunked { - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - self.apply_as_ints(|s| { - s.zip_outer_join_column( - &right_column.bit_repr_large().into_series(), - opt_join_tuples, - ) - }) - } -} diff --git a/crates/polars-core/src/frame/mod.rs 
b/crates/polars-core/src/frame/mod.rs index d239a716cb93..6e7f7a858c49 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -7,6 +7,7 @@ use ahash::AHashSet; use polars_arrow::prelude::QuantileInterpolOptions; use rayon::prelude::*; +#[cfg(feature = "algorithm_group_by")] use crate::chunked_array::ops::unique::is_unique_helper; use crate::prelude::*; #[cfg(feature = "describe")] @@ -15,15 +16,11 @@ use crate::utils::{slice_offsets, split_ca, split_df, try_get_supertype, NoNull} #[cfg(feature = "dataframe_arithmetic")] mod arithmetic; -#[cfg(feature = "asof_join")] -pub(crate) mod asof_join; mod chunks; -#[cfg(feature = "cross_join")] -pub(crate) mod cross_join; pub mod explode; mod from; +#[cfg(feature = "algorithm_group_by")] pub mod group_by; -pub mod hash_join; #[cfg(feature = "rows")] pub mod row; mod top_k; @@ -34,9 +31,10 @@ pub use chunks::*; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::GroupsIndicator; #[cfg(feature = "row_hash")] -use crate::hashing::df_rows_to_hashes_threaded_vertical; +use crate::hashing::_df_rows_to_hashes_threaded_vertical; #[cfg(feature = "zip_with")] use crate::prelude::min_max_binary::min_max_binary_series; use crate::prelude::sort::{argsort_multiple_row_fmt, prepare_arg_sort}; @@ -477,7 +475,7 @@ impl DataFrame { } } - /// Ensure all the chunks in the DataFrame are aligned. + /// Ensure all the chunks in the [`DataFrame`] are aligned. pub fn align_chunks(&mut self) -> &mut Self { if self.should_rechunk() { self.as_single_chunk_par() @@ -486,7 +484,7 @@ impl DataFrame { } } - /// Get the `DataFrame` schema. + /// Get the [`DataFrame`] schema. /// /// # Example /// @@ -503,10 +501,10 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn schema(&self) -> Schema { - self.iter().map(|s| s.field().into_owned()).collect() + self.columns.as_slice().into() } - /// Get a reference to the `DataFrame` columns. + /// Get a reference to the [`DataFrame`] columns. /// /// # Example /// @@ -533,7 +531,7 @@ impl DataFrame { &mut self.columns } - /// Iterator over the columns as `Series`. + /// Iterator over the columns as [`Series`]. /// /// # Example /// @@ -568,7 +566,7 @@ impl DataFrame { self.columns.iter().map(|s| s.name()).collect() } - /// Get the `Vec` representing the column names. + /// Get the [`Vec`] representing the column names. pub fn get_column_names_owned(&self) -> Vec { self.columns.iter().map(|s| s.name().into()).collect() } @@ -610,7 +608,7 @@ impl DataFrame { Ok(()) } - /// Get the data types of the columns in the DataFrame. + /// Get the data types of the columns in the [`DataFrame`]. /// /// # Example /// @@ -634,7 +632,7 @@ impl DataFrame { } } - /// Get a reference to the schema fields of the `DataFrame`. + /// Get a reference to the schema fields of the [`DataFrame`]. /// /// # Example /// @@ -656,7 +654,7 @@ impl DataFrame { .collect() } - /// Get (height, width) of the `DataFrame`. + /// Get (height, width) of the [`DataFrame`]. /// /// # Example /// @@ -679,7 +677,7 @@ impl DataFrame { } } - /// Get the width of the `DataFrame` which is the number of columns. + /// Get the width of the [`DataFrame`] which is the number of columns. /// /// # Example /// @@ -699,7 +697,7 @@ impl DataFrame { self.columns.len() } - /// Get the height of the `DataFrame` which is the number of rows. + /// Get the height of the [`DataFrame`] which is the number of rows. 
/// /// # Example /// @@ -718,7 +716,7 @@ impl DataFrame { self.shape().0 } - /// Check if the `DataFrame` is empty. + /// Check if the [`DataFrame`] is empty. /// /// # Example /// @@ -747,7 +745,7 @@ impl DataFrame { self } - /// Add multiple `Series` to a `DataFrame`. + /// Add multiple [`Series`] to a [`DataFrame`]. /// The added `Series` are required to have the same length. /// /// # Example @@ -781,7 +779,7 @@ impl DataFrame { Ok(unsafe { self.hstack_mut_unchecked(columns) }) } - /// Add multiple `Series` to a `DataFrame`. + /// Add multiple [`Series`] to a [`DataFrame`]. /// The added `Series` are required to have the same length. /// /// # Example @@ -820,7 +818,7 @@ impl DataFrame { DataFrame::new(new_cols) } - /// Concatenate a `DataFrame` to this `DataFrame` and return as newly allocated `DataFrame`. + /// Concatenate a [`DataFrame`] to this [`DataFrame`] and return as newly allocated [`DataFrame`]. /// /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks`]. /// @@ -866,7 +864,7 @@ impl DataFrame { Ok(df) } - /// Concatenate a DataFrame to this DataFrame + /// Concatenate a [`DataFrame`] to this [`DataFrame`] /// /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks`]. /// @@ -991,7 +989,7 @@ impl DataFrame { Ok(self.columns.remove(idx)) } - /// Return a new `DataFrame` where all null values are dropped. + /// Return a new [`DataFrame`] where all null values are dropped. /// /// # Example /// @@ -1047,7 +1045,7 @@ impl DataFrame { } /// Drop a column by name. - /// This is a pure method and will return a new `DataFrame` instead of modifying + /// This is a pure method and will return a new [`DataFrame`] instead of modifying /// the current one in place. /// /// # Example @@ -1079,7 +1077,7 @@ impl DataFrame { self.drop_many_amortized(&names) } - /// Drop columns that are in `names` without allocating a `HashSet`. + /// Drop columns that are in `names` without allocating a [`HashSet`](std::collections::HashSet). pub fn drop_many_amortized(&self, names: &PlHashSet<&str>) -> DataFrame { let mut new_cols = Vec::with_capacity(self.columns.len().saturating_sub(names.len())); self.columns.iter().for_each(|s| { @@ -1092,7 +1090,7 @@ impl DataFrame { } /// Insert a new column at a given index without checking for duplicates. - /// This can leave the DataFrame at an invalid state + /// This can leave the [`DataFrame`] at an invalid state fn insert_at_idx_no_name_check( &mut self, index: usize, @@ -1127,7 +1125,7 @@ impl DataFrame { Ok(()) } - /// Add a new column to this `DataFrame` or replace an existing one. + /// Add a new column to this [`DataFrame`] or replace an existing one. pub fn with_column(&mut self, column: S) -> PolarsResult<&mut Self> { fn inner(df: &mut DataFrame, mut series: Series) -> PolarsResult<&mut DataFrame> { let height = df.height(); @@ -1155,7 +1153,7 @@ impl DataFrame { inner(self, series) } - /// Adds a column to the `DataFrame` without doing any checks + /// Adds a column to the [`DataFrame`] without doing any checks /// on length or duplicates. /// /// # Safety @@ -1193,7 +1191,7 @@ impl DataFrame { Ok(()) } - /// Add a new column to this `DataFrame` or replace an existing one. + /// Add a new column to this [`DataFrame`] or replace an existing one. /// Uses an existing schema to amortize lookups. /// If the schema is incorrect, we will fallback to linear search. pub fn with_column_and_schema( @@ -1225,7 +1223,7 @@ impl DataFrame { } } - /// Get a row in the `DataFrame`. 
Beware this is slow. + /// Get a row in the [`DataFrame`]. Beware this is slow. /// /// # Example /// @@ -1248,7 +1246,7 @@ impl DataFrame { unsafe { Some(self.columns.iter().map(|s| s.get_unchecked(idx)).collect()) } } - /// Select a `Series` by index. + /// Select a [`Series`] by index. /// /// # Example /// @@ -1275,7 +1273,7 @@ impl DataFrame { self.columns.get_mut(idx) } - /// Select column(s) from this `DataFrame` by range and return a new DataFrame + /// Select column(s) from this [`DataFrame`] by range and return a new [`DataFrame`] /// /// # Examples /// @@ -1334,10 +1332,10 @@ impl DataFrame { let colnames = self.get_column_names_owned(); let range = get_range(range, ..colnames.len()); - self.select_impl(&colnames[range]) + self._select_impl(&colnames[range]) } - /// Get column index of a `Series` by name. + /// Get column index of a [`Series`] by name. /// # Example /// /// ```rust @@ -1358,7 +1356,7 @@ impl DataFrame { self.columns.iter().position(|s| s.name() == name) } - /// Get column index of a `Series` by name. + /// Get column index of a [`Series`] by name. pub fn try_find_idx_by_name(&self, name: &str) -> PolarsResult { self.find_idx_by_name(name) .ok_or_else(|| polars_err!(ColumnNotFound: "{}", name)) @@ -1409,7 +1407,7 @@ impl DataFrame { .collect() } - /// Select column(s) from this `DataFrame` and return a new `DataFrame`. + /// Select column(s) from this [`DataFrame`] and return a new [`DataFrame`]. /// /// # Examples /// @@ -1428,11 +1426,15 @@ impl DataFrame { .into_iter() .map(|s| SmartString::from(s.as_ref())) .collect::>(); - self.select_impl(&cols) + self._select_impl(&cols) } - fn select_impl(&self, cols: &[SmartString]) -> PolarsResult { + pub fn _select_impl(&self, cols: &[SmartString]) -> PolarsResult { self.select_check_duplicates(cols)?; + self._select_impl_unchecked(cols) + } + + pub fn _select_impl_unchecked(&self, cols: &[SmartString]) -> PolarsResult { let selected = self.select_series_impl(cols)?; Ok(DataFrame::new_no_checks(selected)) } @@ -1522,7 +1524,7 @@ impl DataFrame { Ok(()) } - /// Select column(s) from this `DataFrame` and return them into a `Vec`. + /// Select column(s) from this [`DataFrame`] and return them into a [`Vec`]. /// /// # Example /// @@ -1638,7 +1640,7 @@ impl DataFrame { })) } - /// Take the `DataFrame` rows by a boolean mask. + /// Take the [`DataFrame`] rows by a boolean mask. /// /// # Example /// @@ -1654,7 +1656,14 @@ impl DataFrame { return self.clone().filter_vertical(mask); } let new_col = self.try_apply_columns_par(&|s| match s.dtype() { - DataType::Utf8 => s.filter_threaded(mask, true), + DataType::Utf8 => { + let ca = s.utf8().unwrap(); + if ca.get_values_size() / 24 <= ca.len() { + s.filter(mask) + } else { + s.filter_threaded(mask, true) + } + }, _ => s.filter(mask), })?; Ok(DataFrame::new_no_checks(new_col)) @@ -1666,104 +1675,7 @@ impl DataFrame { Ok(DataFrame::new_no_checks(new_col)) } - /// Take `DataFrame` value by indexes from an iterator. - /// - /// # Example - /// - /// ``` - /// # use polars_core::prelude::*; - /// fn example(df: &DataFrame) -> PolarsResult { - /// let iterator = (0..9).into_iter(); - /// df.take_iter(iterator) - /// } - /// ``` - pub fn take_iter(&self, iter: I) -> PolarsResult - where - I: Iterator + Clone + Sync + TrustedLen, - { - let new_col = self.try_apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_iter(&mut i) - })?; - - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Take `DataFrame` values by indexes from an iterator. 
- /// - /// # Safety - /// - /// This doesn't do any bound checking but checks null validity. - #[must_use] - pub unsafe fn take_iter_unchecked(&self, mut iter: I) -> Self - where - I: Iterator + Clone + Sync + TrustedLen, - { - let n_chunks = self.n_chunks(); - let has_utf8 = self - .columns - .iter() - .any(|s| matches!(s.dtype(), DataType::Utf8)); - - if (n_chunks == 1 && self.width() > 1) || has_utf8 { - let idx_ca: NoNull = iter.map(|idx| idx as IdxSize).collect(); - let idx_ca = idx_ca.into_inner(); - return self.take_unchecked(&idx_ca); - } - - let new_col = if self.width() == 1 { - self.columns - .iter() - .map(|s| s.take_iter_unchecked(&mut iter)) - .collect::>() - } else { - self.apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_iter_unchecked(&mut i) - }) - }; - DataFrame::new_no_checks(new_col) - } - - /// Take `DataFrame` values by indexes from an iterator that may contain None values. - /// - /// # Safety - /// - /// This doesn't do any bound checking. Out of bounds may access uninitialized memory. - /// Null validity is checked - #[must_use] - pub unsafe fn take_opt_iter_unchecked(&self, mut iter: I) -> Self - where - I: Iterator> + Clone + Sync + TrustedLen, - { - let n_chunks = self.n_chunks(); - - let has_utf8 = self - .columns - .iter() - .any(|s| matches!(s.dtype(), DataType::Utf8)); - - if (n_chunks == 1 && self.width() > 1) || has_utf8 { - let idx_ca: IdxCa = iter.map(|opt| opt.map(|v| v as IdxSize)).collect(); - return self.take_unchecked(&idx_ca); - } - - let new_col = if self.width() == 1 { - self.columns - .iter() - .map(|s| s.take_opt_iter_unchecked(&mut iter)) - .collect::>() - } else { - self.apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_opt_iter_unchecked(&mut i) - }) - }; - - DataFrame::new_no_checks(new_col) - } - - /// Take `DataFrame` rows by index values. + /// Take [`DataFrame`] rows by index values. /// /// # Example /// @@ -1775,22 +1687,26 @@ impl DataFrame { /// } /// ``` pub fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; let new_col = POOL.install(|| { self.try_apply_columns_par(&|s| match s.dtype() { - DataType::Utf8 => s.take_threaded(&indices, true), - _ => s.take(&indices), + DataType::Utf8 => { + let ca = s.utf8().unwrap(); + if ca.get_values_size() / 24 <= ca.len() { + s.take(indices) + } else { + s.take_threaded(indices, true) + } + }, + _ => s.take(indices), }) })?; Ok(DataFrame::new_no_checks(new_col)) } - pub(crate) unsafe fn take_unchecked(&self, idx: &IdxCa) -> Self { + /// # Safety + /// The indices must be in-bounds. 
+ pub unsafe fn take_unchecked(&self, idx: &IdxCa) -> Self { self.take_unchecked_impl(idx, true) } @@ -1798,20 +1714,38 @@ impl DataFrame { let cols = if allow_threads { POOL.install(|| { self.apply_columns_par(&|s| match s.dtype() { - DataType::Utf8 => s.take_unchecked_threaded(idx, true).unwrap(), - _ => s.take_unchecked(idx).unwrap(), + DataType::Utf8 => s.take_unchecked_threaded(idx, true), + _ => s.take_unchecked(idx), + }) + }) + } else { + self.columns.iter().map(|s| s.take_unchecked(idx)).collect() + }; + DataFrame::new_no_checks(cols) + } + + pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { + self.take_slice_unchecked_impl(idx, true) + } + + unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { + let cols = if allow_threads { + POOL.install(|| { + self.apply_columns_par(&|s| match s.dtype() { + DataType::Utf8 => s.take_slice_unchecked_threaded(idx, true), + _ => s.take_slice_unchecked(idx), }) }) } else { self.columns .iter() - .map(|s| s.take_unchecked(idx).unwrap()) + .map(|s| s.take_slice_unchecked(idx)) .collect() }; DataFrame::new_no_checks(cols) } - /// Rename a column in the `DataFrame`. + /// Rename a column in the [`DataFrame`]. /// /// # Example /// @@ -1836,7 +1770,7 @@ impl DataFrame { Ok(self) } - /// Sort `DataFrame` in place by a column. + /// Sort [`DataFrame`] in place by a column. pub fn sort_in_place( &mut self, by_column: impl IntoVec, @@ -1955,7 +1889,7 @@ impl DataFrame { Ok(df) } - /// Return a sorted clone of this `DataFrame`. + /// Return a sorted clone of this [`DataFrame`]. /// /// # Example /// @@ -1980,7 +1914,7 @@ impl DataFrame { Ok(df) } - /// Sort the `DataFrame` by a single column with extra options. + /// Sort the [`DataFrame`] by a single column with extra options. pub fn sort_with_options(&self, by_column: &str, options: SortOptions) -> PolarsResult { let mut df = self.clone(); let by_column = vec![df.column(by_column)?.clone()]; @@ -1998,7 +1932,7 @@ impl DataFrame { Ok(df) } - /// Replace a column with a `Series`. + /// Replace a column with a [`Series`]. /// /// # Example /// @@ -2029,7 +1963,7 @@ impl DataFrame { self.with_column(new_col) } - /// Replace column at index `idx` with a `Series`. + /// Replace column at index `idx` with a [`Series`]. /// /// # Example /// @@ -2300,7 +2234,7 @@ impl DataFrame { self.try_apply_at_idx(idx, f) } - /// Slice the `DataFrame` along the rows. + /// Slice the [`DataFrame`] along the rows. /// /// # Example /// @@ -2367,7 +2301,7 @@ impl DataFrame { })) } - /// Get the head of the `DataFrame`. + /// Get the head of the [`DataFrame`]. /// /// # Example /// @@ -2410,7 +2344,7 @@ impl DataFrame { DataFrame::new_no_checks(col) } - /// Get the tail of the `DataFrame`. + /// Get the tail of the [`DataFrame`]. /// /// # Example /// @@ -2450,11 +2384,11 @@ impl DataFrame { DataFrame::new_no_checks(col) } - /// Iterator over the rows in this `DataFrame` as Arrow RecordBatches. + /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches. /// /// # Panics /// - /// Panics if the `DataFrame` that is passed is not rechunked. + /// Panics if the [`DataFrame`] that is passed is not rechunked. /// /// This responsibility is left to the caller as we don't want to take mutable references here, /// but we also don't want to rechunk here, as this operation is costly and would benefit the caller @@ -2467,11 +2401,11 @@ impl DataFrame { } } - /// Iterator over the rows in this `DataFrame` as Arrow RecordBatches as physical values. 
+ /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches as physical values. /// /// # Panics /// - /// Panics if the `DataFrame` that is passed is not rechunked. + /// Panics if the [`DataFrame`] that is passed is not rechunked. /// /// This responsibility is left to the caller as we don't want to take mutable references here, /// but we also don't want to rechunk here, as this operation is costly and would benefit the caller @@ -2482,7 +2416,7 @@ impl DataFrame { } } - /// Get a `DataFrame` with all the columns in reversed order. + /// Get a [`DataFrame`] with all the columns in reversed order. #[must_use] pub fn reverse(&self) -> Self { let col = self.columns.iter().map(|s| s.reverse()).collect::>(); @@ -2964,11 +2898,11 @@ impl DataFrame { let columns = self .columns .iter() - .cloned() .filter(|s| { let dtype = s.dtype(); dtype.is_numeric() || matches!(dtype, DataType::Boolean) }) + .cloned() .collect(); let numeric_df = DataFrame::new_no_checks(columns); @@ -3028,7 +2962,7 @@ impl DataFrame { f(self, args) } - /// Drop duplicate rows from a `DataFrame`. + /// Drop duplicate rows from a [`DataFrame`]. /// *This fails when there is a column of type List in DataFrame* /// /// Stable means that the order is maintained. This has a higher cost than an unstable distinct. @@ -3061,6 +2995,7 @@ impl DataFrame { /// | 3 | 3 | "c" | /// +-----+-----+-----+ /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn unique_stable( &self, subset: Option<&[String]>, @@ -3071,6 +3006,7 @@ impl DataFrame { } /// Unstable distinct. See [`DataFrame::unique_stable`]. + #[cfg(feature = "algorithm_group_by")] pub fn unique( &self, subset: Option<&[String]>, @@ -3080,6 +3016,7 @@ impl DataFrame { self.unique_impl(false, subset, keep, slice) } + #[cfg(feature = "algorithm_group_by")] pub fn unique_impl( &self, maintain_order: bool, @@ -3152,7 +3089,7 @@ impl DataFrame { Ok(DataFrame::new_no_checks(columns)) } - /// Get a mask of all the unique rows in the `DataFrame`. + /// Get a mask of all the unique rows in the [`DataFrame`]. /// /// # Example /// @@ -3165,6 +3102,7 @@ impl DataFrame { /// assert!(ca.all()); /// # Ok::<(), PolarsError>(()) /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn is_unique(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names())?; let groups = gb.take_groups(); @@ -3176,7 +3114,7 @@ impl DataFrame { )) } - /// Get a mask of all the duplicated rows in the `DataFrame`. + /// Get a mask of all the duplicated rows in the [`DataFrame`]. /// /// # Example /// @@ -3189,6 +3127,7 @@ impl DataFrame { /// assert!(!ca.all()); /// # Ok::<(), PolarsError>(()) /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn is_duplicated(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names())?; let groups = gb.take_groups(); @@ -3200,7 +3139,7 @@ impl DataFrame { )) } - /// Create a new `DataFrame` that shows the null counts per column. + /// Create a new [`DataFrame`] that shows the null counts per column. 
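`unique_stable`, `unique`, `is_unique` and `is_duplicated` are now gated behind the new `algorithm_group_by` feature because they group rows internally. A hedged sketch of the two mask helpers, with that feature enabled:

```rust
use polars_core::prelude::*;

fn example() -> PolarsResult<()> {
    let df = df![
        "a" => [1, 1, 2],
        "b" => ["x", "x", "y"]
    ]?;

    // Mask of rows that occur exactly once.
    let unique_mask = df.is_unique()?;
    // Mask of rows that occur more than once.
    let dup_mask = df.is_duplicated()?;

    assert!(!unique_mask.all()); // not every row is unique
    assert!(dup_mask.any());     // at least one row is duplicated
    Ok(())
}
```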
#[must_use] pub fn null_count(&self) -> Self { let cols = self @@ -3218,7 +3157,7 @@ impl DataFrame { hasher_builder: Option, ) -> PolarsResult { let dfs = split_df(self, POOL.current_num_threads())?; - let (cas, _) = df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?; + let (cas, _) = _df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?; let mut iter = cas.into_iter(); let mut acc_ca = iter.next().unwrap(); @@ -3263,7 +3202,9 @@ impl DataFrame { } #[cfg(feature = "chunked_ids")] - pub(crate) unsafe fn take_chunked_unchecked(&self, idx: &[ChunkId], sorted: IsSorted) -> Self { + /// # Safety + /// Doesn't perform any bound checks + pub unsafe fn _take_chunked_unchecked(&self, idx: &[ChunkId], sorted: IsSorted) -> Self { let cols = self.apply_columns_par(&|s| match s.dtype() { DataType::Utf8 => s._take_chunked_unchecked_threaded(idx, sorted, true), _ => s._take_chunked_unchecked(idx, sorted), @@ -3273,7 +3214,9 @@ impl DataFrame { } #[cfg(feature = "chunked_ids")] - pub(crate) unsafe fn take_opt_chunked_unchecked(&self, idx: &[Option]) -> Self { + /// # Safety + /// Doesn't perform any bound checks + pub unsafe fn _take_opt_chunked_unchecked(&self, idx: &[Option]) -> Self { let cols = self.apply_columns_par(&|s| match s.dtype() { DataType::Utf8 => s._take_opt_chunked_unchecked_threaded(idx, true), _ => s._take_opt_chunked_unchecked(idx), @@ -3323,7 +3266,7 @@ impl DataFrame { self.take_unchecked_impl(&ca, allow_threads) } - #[cfg(feature = "partition_by")] + #[cfg(all(feature = "partition_by", feature = "algorithm_group_by"))] #[doc(hidden)] pub fn _partition_by_impl( &self, diff --git a/crates/polars-core/src/frame/row/av_buffer.rs b/crates/polars-core/src/frame/row/av_buffer.rs index 50205348db69..dd3d76cd5484 100644 --- a/crates/polars-core/src/frame/row/av_buffer.rs +++ b/crates/polars-core/src/frame/row/av_buffer.rs @@ -127,7 +127,7 @@ impl<'a> AnyValueBuffer<'a> { self.add(val.clone()).ok_or_else(|| { polars_err!( ComputeError: "could not append value: {} of type: {} to the builder; make sure that all rows \ - have the same schema or consider increasing `schema_inference_length`\n\ + have the same schema or consider increasing `infer_schema_length`\n\ \n\ it might also be that a value overflows the data-type's capacity", val, val.dtype() ) @@ -293,7 +293,7 @@ impl From<(&DataType, usize)> for AnyValueBuffer<'_> { } } -/// An `AnyValueBuffer` that should be used when we trust the builder +/// An [`AnyValueBuffer`] that should be used when we trust the builder #[derive(Clone)] pub enum AnyValueBufferTrusted<'a> { Boolean(BooleanChunkedBuilder), @@ -433,13 +433,13 @@ impl<'a> AnyValueBufferTrusted<'a> { } } - /// Will add the AnyValue into `Self` and unpack as the physical type - /// belonging to `Self`. This should only be used with physical buffers + /// Will add the [`AnyValue`] into [`Self`] and unpack as the physical type + /// belonging to [`Self`]. This should only be used with physical buffers /// /// If a type is not primitive or utf8, the anyvalue will be converted to static /// /// # Safety - /// The caller must ensure that the `AnyValue` type exactly matches the `Buffer` type and is owned. + /// The caller must ensure that the [`AnyValue`] type exactly matches the `Buffer` type and is owned. 
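A hedged sketch of `null_count`, which returns a single-row `DataFrame` holding the per-column null counts:

```rust
use polars_core::prelude::*;

fn example() -> PolarsResult<()> {
    let df = df![
        "a" => [Some(1), None, Some(3)],
        "b" => [Some("x"), Some("y"), None]
    ]?;

    let counts = df.null_count();
    // One row, one count per column.
    assert_eq!(counts.shape(), (1, 2));
    Ok(())
}
```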
#[inline] pub unsafe fn add_unchecked_owned_physical(&mut self, val: &AnyValue<'_>) { use AnyValueBufferTrusted::*; @@ -478,7 +478,7 @@ impl<'a> AnyValueBufferTrusted<'a> { } /// # Safety - /// The caller must ensure that the `AnyValue` type exactly matches the `Buffer` type and is borrowed. + /// The caller must ensure that the [`AnyValue`] type exactly matches the `Buffer` type and is borrowed. #[inline] pub unsafe fn add_unchecked_borrowed_physical(&mut self, val: &AnyValue<'_>) { use AnyValueBufferTrusted::*; diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs index f2faf909a189..1aa2197d1ac5 100644 --- a/crates/polars-core/src/frame/row/dataframe.rs +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -2,7 +2,7 @@ use super::*; use crate::frame::row::av_buffer::AnyValueBuffer; impl DataFrame { - /// Get a row from a DataFrame. Use of this is discouraged as it will likely be slow. + /// Get a row from a [`DataFrame`]. Use of this is discouraged as it will likely be slow. pub fn get_row(&self, idx: usize) -> PolarsResult { let values = self .columns @@ -14,7 +14,7 @@ impl DataFrame { /// Amortize allocations by reusing a row. /// The caller is responsible to make sure that the row has at least the capacity for the number - /// of columns in the DataFrame + /// of columns in the [`DataFrame`] pub fn get_row_amortized<'a>(&'a self, idx: usize, row: &mut Row<'a>) -> PolarsResult<()> { for (s, any_val) in self.columns.iter().zip(&mut row.0) { *any_val = s.get(idx)?; @@ -24,7 +24,7 @@ impl DataFrame { /// Amortize allocations by reusing a row. /// The caller is responsible to make sure that the row has at least the capacity for the number - /// of columns in the DataFrame + /// of columns in the [`DataFrame`] /// /// # Safety /// Does not do any bounds checking. @@ -38,14 +38,14 @@ impl DataFrame { }); } - /// Create a new DataFrame from rows. This should only be used when you have row wise data, - /// as this is a lot slower than creating the `Series` in a columnar fashion + /// Create a new [`DataFrame`] from rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion pub fn from_rows_and_schema(rows: &[Row], schema: &Schema) -> PolarsResult { Self::from_rows_iter_and_schema(rows.iter(), schema) } - /// Create a new DataFrame from an iterator over rows. This should only be used when you have row wise data, - /// as this is a lot slower than creating the `Series` in a columnar fashion + /// Create a new [`DataFrame`] from an iterator over rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion pub fn from_rows_iter_and_schema<'a, I>(mut rows: I, schema: &Schema) -> PolarsResult where I: Iterator>, @@ -86,8 +86,8 @@ impl DataFrame { DataFrame::new(v) } - /// Create a new DataFrame from an iterator over rows. This should only be used when you have row wise data, - /// as this is a lot slower than creating the `Series` in a columnar fashion + /// Create a new [`DataFrame`] from an iterator over rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion pub fn try_from_rows_iter_and_schema<'a, I>(mut rows: I, schema: &Schema) -> PolarsResult where I: Iterator>>, @@ -128,8 +128,8 @@ impl DataFrame { DataFrame::new(v) } - /// Create a new DataFrame from rows. 
This should only be used when you have row wise data, - /// as this is a lot slower than creating the `Series` in a columnar fashion + /// Create a new [`DataFrame`] from rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion pub fn from_rows(rows: &[Row]) -> PolarsResult { let schema = rows_to_schema_first_non_null(rows, Some(50)); let has_nulls = schema diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index a0a1b83436a9..909252cf1eca 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -108,7 +108,7 @@ fn types_set_to_dtype(types_set: PlIndexSet) -> PolarsResult types_set .into_iter() .map(Ok) - .fold_first_(|a, b| try_get_supertype(&a?, &b?)) + .reduce(|a, b| try_get_supertype(&a?, &b?)) .unwrap() } diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs index 515913b43aad..ba447b8ca54c 100644 --- a/crates/polars-core/src/functions.rs +++ b/crates/polars-core/src/functions.rs @@ -8,9 +8,7 @@ use std::ops::Add; use ahash::AHashSet; use arrow::compute; use arrow::types::simd::Simd; -use num_traits::{Float, NumCast, ToPrimitive}; -#[cfg(feature = "concat_str")] -use polars_arrow::prelude::ValueSize; +use num_traits::ToPrimitive; use crate::prelude::*; use crate::utils::coalesce_nulls; @@ -18,27 +16,9 @@ use crate::utils::coalesce_nulls; use crate::utils::concat_df; /// Compute the covariance between two columns. -pub fn cov_f(a: &ChunkedArray, b: &ChunkedArray) -> Option +pub fn cov(a: &ChunkedArray, b: &ChunkedArray) -> Option where - T: PolarsFloatType, - T::Native: Float, - ::Simd: Add::Simd> - + compute::aggregate::Sum - + compute::aggregate::SimdOrd, -{ - if a.len() != b.len() { - None - } else { - let tmp = (a - a.mean()?) * (b - b.mean()?); - let n = tmp.len() - tmp.null_count(); - Some(tmp.sum()? / NumCast::from(n - 1).unwrap()) - } -} - -/// Compute the covariance between two columns. -pub fn cov_i(a: &ChunkedArray, b: &ChunkedArray) -> Option -where - T: PolarsIntegerType, + T: PolarsNumericType, T::Native: ToPrimitive, ::Simd: Add::Simd> + compute::aggregate::Sum @@ -59,133 +39,26 @@ where } /// Compute the pearson correlation between two columns. -pub fn pearson_corr_i(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option +pub fn pearson_corr(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option where - T: PolarsIntegerType, + T: PolarsNumericType, T::Native: ToPrimitive, ::Simd: Add::Simd> + compute::aggregate::Sum + compute::aggregate::SimdOrd, - ChunkedArray: ChunkVar, + ChunkedArray: ChunkVar, { let (a, b) = coalesce_nulls(a, b); let a = a.as_ref(); let b = b.as_ref(); - Some(cov_i(a, b)? / (a.std(ddof)? * b.std(ddof)?)) -} - -/// Compute the pearson correlation between two columns. -pub fn pearson_corr_f(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option -where - T: PolarsFloatType, - T::Native: Float, - ::Simd: Add::Simd> - + compute::aggregate::Sum - + compute::aggregate::SimdOrd, - ChunkedArray: ChunkVar, -{ - let (a, b) = coalesce_nulls(a, b); - let a = a.as_ref(); - let b = b.as_ref(); - - Some(cov_f(a, b)? / (a.std(ddof)? 
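A hedged sketch of the row-wise constructors documented above, assuming polars-core's `rows` feature (which gates them) is enabled; `Row` wraps a `Vec<AnyValue>`, and the columnar constructors remain the faster path:

```rust
use polars_core::frame::row::Row;
use polars_core::prelude::*;

fn example() -> PolarsResult<()> {
    let rows = vec![
        Row::new(vec![AnyValue::Int64(1), AnyValue::Utf8("a")]),
        Row::new(vec![AnyValue::Int64(2), AnyValue::Utf8("b")]),
    ];

    // Infer the schema from the first non-null values...
    let df = DataFrame::from_rows(&rows)?;
    assert_eq!(df.shape(), (2, 2));

    // ...or supply one explicitly to skip inference.
    let schema = Schema::from_iter([
        Field::new("x", DataType::Int64),
        Field::new("y", DataType::Utf8),
    ]);
    let df2 = DataFrame::from_rows_and_schema(&rows, &schema)?;
    assert_eq!(df2.get_column_names(), &["x", "y"]);
    Ok(())
}
```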
* b.std(ddof)?)) -} - -// utility to be able to also add literals to concat_str function -#[cfg(feature = "concat_str")] -enum IterBroadCast<'a> { - Column(Box> + 'a>), - Value(Option<&'a str>), -} - -#[cfg(feature = "concat_str")] -impl<'a> IterBroadCast<'a> { - fn next(&mut self) -> Option> { - use IterBroadCast::*; - match self { - Column(iter) => iter.next(), - Value(val) => Some(*val), - } - } + Some(cov(a, b)? / (a.std(ddof)? * b.std(ddof)?)) } -/// Casts all series to string data and will concat them in linear time. -/// The concatenated strings are separated by a `delimiter`. -/// If no `delimiter` is needed, an empty &str should be passed as argument. -#[cfg(feature = "concat_str")] -pub fn concat_str(s: &[Series], delimiter: &str) -> PolarsResult { - polars_ensure!(!s.is_empty(), NoData: "expected multiple series in `concat_str`"); - if s.iter().any(|s| s.is_empty()) { - return Ok(Utf8Chunked::full_null(s[0].name(), 0)); - } - - let len = s.iter().map(|s| s.len()).max().unwrap(); - - let cas = s - .iter() - .map(|s| { - let s = s.cast(&DataType::Utf8)?; - let mut ca = s.utf8()?.clone(); - // broadcast - if ca.len() == 1 && len > 1 { - ca = ca.new_from_index(0, len) - } - - Ok(ca) - }) - .collect::>>()?; - - polars_ensure!( - s.iter().all(|s| s.len() == 1 || s.len() == len), - ComputeError: "all series in `concat_str` should have equal or unit length" - ); - let mut iters = cas - .iter() - .map(|ca| match ca.len() { - 1 => IterBroadCast::Value(ca.get(0)), - _ => IterBroadCast::Column(ca.into_iter()), - }) - .collect::>(); - - let bytes_cap = cas.iter().map(|ca| ca.get_values_size()).sum(); - let mut builder = Utf8ChunkedBuilder::new(s[0].name(), len, bytes_cap); - - // use a string buffer, to amortize alloc - let mut buf = String::with_capacity(128); - - for _ in 0..len { - let mut has_null = false; - - iters.iter_mut().enumerate().for_each(|(i, it)| { - if i > 0 { - buf.push_str(delimiter); - } - - match it.next() { - Some(Some(s)) => buf.push_str(s), - Some(None) => has_null = true, - None => { - // should not happen as the out loop counts to length - unreachable!() - }, - } - }); - - if has_null { - builder.append_null(); - } else { - builder.append_value(&buf) - } - buf.truncate(0) - } - Ok(builder.finish()) -} - -/// Concat `[DataFrame]`s horizontally. +/// Concat [`DataFrame`]s horizontally. #[cfg(feature = "horizontal_concat")] /// Concat horizontally and extend with null values if lengths don't match -pub fn hor_concat_df(dfs: &[DataFrame]) -> PolarsResult { +pub fn concat_df_horizontal(dfs: &[DataFrame]) -> PolarsResult { let max_len = dfs .iter() .map(|df| df.height()) @@ -222,10 +95,10 @@ pub fn hor_concat_df(dfs: &[DataFrame]) -> PolarsResult { Ok(first_df) } -/// Concat `[DataFrame]`s diagonally. +/// Concat [`DataFrame`]s diagonally. #[cfg(feature = "diagonal_concat")] /// Concat diagonally thereby combining different schemas. -pub fn diag_concat_df(dfs: &[DataFrame]) -> PolarsResult { +pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { // TODO! replace with lazy only? 
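The float/integer pairs `cov_f`/`cov_i` and `pearson_corr_f`/`pearson_corr_i` collapse into single generic `cov` and `pearson_corr` functions over any numeric chunked array. A sketch that mirrors the updated tests further down:

```rust
use polars_core::functions::{cov, pearson_corr};
use polars_core::prelude::*;

fn example() {
    let a = Series::new("a", &[1.0f32, 2.0, 5.0]);
    let b = Series::new("b", &[1.0f32, 2.0, -3.0]);

    // One generic entry point now covers float and integer inputs alike.
    assert_eq!(cov(a.f32().unwrap(), b.f32().unwrap()), Some(-5.0));

    // ddof = 1, as in the updated test suite.
    let r = pearson_corr(a.f32().unwrap(), b.f32().unwrap(), 1).unwrap();
    assert!(r < 0.0 && r >= -1.0);
}
```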
let upper_bound_width = dfs.iter().map(|df| df.width()).sum(); let mut column_names = AHashSet::with_capacity(upper_bound_width); @@ -267,11 +140,11 @@ mod test { fn test_cov() { let a = Series::new("a", &[1.0f32, 2.0, 5.0]); let b = Series::new("b", &[1.0f32, 2.0, -3.0]); - let out = cov_f(a.f32().unwrap(), b.f32().unwrap()); + let out = cov(a.f32().unwrap(), b.f32().unwrap()); assert_eq!(out, Some(-5.0)); let a = a.cast(&DataType::Int32).unwrap(); let b = b.cast(&DataType::Int32).unwrap(); - let out = cov_i(a.i32().unwrap(), b.i32().unwrap()); + let out = cov(a.i32().unwrap(), b.i32().unwrap()); assert_eq!(out, Some(-5.0)); } @@ -279,27 +152,8 @@ mod test { fn test_pearson_corr() { let a = Series::new("a", &[1.0f32, 2.0]); let b = Series::new("b", &[1.0f32, 2.0]); - assert!((cov_f(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 0.5).abs() < 0.001); - assert!( - (pearson_corr_f(a.f32().unwrap(), b.f32().unwrap(), 1).unwrap() - 1.0).abs() < 0.001 - ); - } - - #[test] - #[cfg(feature = "concat_str")] - fn test_concat_str() { - let a = Series::new("a", &["foo", "bar"]); - let b = Series::new("b", &["spam", "ham"]); - - let out = concat_str(&[a.clone(), b.clone()], "_").unwrap(); - assert_eq!(Vec::from(&out), &[Some("foo_spam"), Some("bar_ham")]); - - let c = Series::new("b", &["literal"]); - let out = concat_str(&[a, b, c], "_").unwrap(); - assert_eq!( - Vec::from(&out), - &[Some("foo_spam_literal"), Some("bar_ham_literal")] - ); + assert!((cov(a.f32().unwrap(), b.f32().unwrap()).unwrap() - 0.5).abs() < 0.001); + assert!((pearson_corr(a.f32().unwrap(), b.f32().unwrap(), 1).unwrap() - 1.0).abs() < 0.001); } #[test] @@ -321,7 +175,7 @@ mod test { "d" => [1, 2] ]?; - let out = diag_concat_df(&[a, b, c])?; + let out = concat_df_diagonal(&[a, b, c])?; let expected = df![ "a" => [Some(1), Some(2), None, None, Some(5), Some(7)], diff --git a/crates/polars-core/src/hashing/identity.rs b/crates/polars-core/src/hashing/identity.rs index 40e27d8ab0d6..2b09b27c76ad 100644 --- a/crates/polars-core/src/hashing/identity.rs +++ b/crates/polars-core/src/hashing/identity.rs @@ -42,11 +42,11 @@ pub type IdBuildHasher = BuildHasherDefault; /// Contains an idx of a row in a DataFrame and the precomputed hash of that row. /// That hash still needs to be used to create another hash to be able to resize hashmaps without /// accidental quadratic behavior. So do not use an Identity function! -pub(crate) struct IdxHash { +pub struct IdxHash { // idx in row of Series, DataFrame - pub(crate) idx: IdxSize, + pub idx: IdxSize, // precomputed hash of T - pub(crate) hash: u64, + pub hash: u64, } impl Hash for IdxHash { diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs index 110ead6db68e..ab9ba14788a7 100644 --- a/crates/polars-core/src/hashing/mod.rs +++ b/crates/polars-core/src/hashing/mod.rs @@ -7,6 +7,8 @@ use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; use ahash::RandomState; pub use fx::*; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; pub use identity::*; pub(crate) use partition::*; pub use vector_hasher::*; @@ -18,3 +20,76 @@ use crate::prelude::*; pub fn _boost_hash_combine(l: u64, r: u64) -> u64 { l ^ r.wrapping_add(0x9e3779b9u64.wrapping_add(l << 6).wrapping_add(r >> 2)) } + +// We must strike a balance between cache +// Overallocation seems a lot more expensive than resizing so we start reasonable small. +pub const _HASHMAP_INIT_SIZE: usize = 512; + +/// Utility function used as comparison function in the hashmap. 
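`hor_concat_df` and `diag_concat_df` are renamed to `concat_df_horizontal` and `concat_df_diagonal` without behaviour changes. A hedged sketch of the diagonal variant (behind the `diagonal_concat` feature):

```rust
use polars_core::functions::concat_df_diagonal;
use polars_core::prelude::*;

fn example() -> PolarsResult<()> {
    let a = df!["a" => [1, 2], "b" => ["x", "y"]]?;
    let b = df!["a" => [5, 7], "d" => [1, 2]]?;

    // Diagonal concat unions the schemas and fills the gaps with nulls.
    let out = concat_df_diagonal(&[a, b])?;
    assert_eq!(out.get_column_names(), &["a", "b", "d"]);
    assert_eq!(out.height(), 4);
    Ok(())
}
```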
+/// The rationale is that equality is an AND operation and therefore its probability of success +/// declines rapidly with the number of keys. Instead of first copying an entire row from both +/// sides and then do the comparison, we do the comparison value by value catching early failures +/// eagerly. +/// +/// # Safety +/// Doesn't check any bounds +#[inline] +pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool { + for s in keys.get_columns() { + if !s.equal_element(idx_a, idx_b, s) { + return false; + } + } + true +} + +/// Populate a multiple key hashmap with row indexes. +/// Instead of the keys (which could be very large), the row indexes are stored. +/// To check if a row is equal the original DataFrame is also passed as ref. +/// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared +/// on equality. +pub fn populate_multiple_key_hashmap( + hash_tbl: &mut HashMap, + // row index + idx: IdxSize, + // hash + original_h: u64, + // keys of the hash table (will not be inserted, the indexes will be used) + // the keys are needed for the equality check + keys: &DataFrame, + // value to insert + vacant_fn: G, + // function that gets a mutable ref to the occupied value in the hash table + mut occupied_fn: F, +) where + G: Fn() -> V, + F: FnMut(&mut V), + H: BuildHasher, +{ + let entry = hash_tbl + .raw_entry_mut() + // uses the idx to probe rows in the original DataFrame with keys + // to check equality to find an entry + // this does not invalidate the hashmap as this equality function is not used + // during rehashing/resize (then the keys are already known to be unique). + // Only during insertion and probing an equality function is needed + .from_hash(original_h, |idx_hash| { + // first check the hash values + // before we incur a cache miss + idx_hash.hash == original_h && { + let key_idx = idx_hash.idx; + // Safety: + // indices in a group_by operation are always in bounds. 
+ unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } + } + }); + match entry { + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); + }, + RawEntryMut::Occupied(mut entry) => { + let (_k, v) = entry.get_key_value_mut(); + occupied_fn(v); + }, + } +} diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 4aa7b08a527c..0d5c5b64ec9a 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -1,9 +1,6 @@ use arrow::bitmap::utils::get_bit_unchecked; -use hashbrown::hash_map::RawEntryMut; -use hashbrown::HashMap; #[cfg(feature = "group_by_list")] use polars_arrow::kernels::list_bytes_iter::numeric_list_bytes_iter; -use polars_arrow::utils::CustomIterTools; use rayon::prelude::*; use xxhash_rust::xxh3::xxh3_64_with_seed; @@ -45,12 +42,8 @@ pub(crate) const fn folded_multiply(s: u64, by: u64) -> u64 { pub(crate) fn get_null_hash_value(random_state: RandomState) -> u64 { // we just start with a large prime number and hash that twice // to get a constant hash value for null/None - let mut hasher = random_state.build_hasher(); - 3188347919usize.hash(&mut hasher); - let first = hasher.finish(); - let mut hasher = random_state.build_hasher(); - first.hash(&mut hasher); - hasher.finish() + let first = random_state.hash_one(3188347919usize); + random_state.hash_one(first) } fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Vec) { @@ -395,13 +388,8 @@ where buf.clear(); buf.reserve(self.len()); - self.downcast_iter().for_each(|arr| { - buf.extend(arr.into_iter().map(|opt_v| { - let mut hasher = random_state.build_hasher(); - opt_v.hash(&mut hasher); - hasher.finish() - })) - }); + self.downcast_iter() + .for_each(|arr| buf.extend(arr.into_iter().map(|opt_v| random_state.hash_one(opt_v)))); Ok(()) } @@ -409,9 +397,8 @@ where fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { self.apply_to_slice( |opt_v, h| { - let mut hasher = random_state.build_hasher(); - opt_v.hash(&mut hasher); - _boost_hash_combine(hasher.finish(), *h) + let hashed = random_state.hash_one(opt_v); + _boost_hash_combine(hashed, *h) }, hashes, ); @@ -423,7 +410,7 @@ where /// During rehashes, we will rehash the hash instead of the string, that makes rehashing /// cheap and allows cache coherent small hash tables. #[derive(Eq, Copy, Clone, Debug)] -pub(crate) struct BytesHash<'a> { +pub struct BytesHash<'a> { payload: Option<&'a [u8]>, pub(super) hash: u64, } @@ -436,7 +423,7 @@ impl<'a> Hash for BytesHash<'a> { impl<'a> BytesHash<'a> { #[inline] - pub(crate) fn new(s: Option<&'a [u8]>, hash: u64) -> Self { + pub fn new(s: Option<&'a [u8]>, hash: u64) -> Self { Self { payload: s, hash } } } @@ -448,98 +435,7 @@ impl<'a> PartialEq for BytesHash<'a> { } } -pub(crate) fn prepare_hashed_relation_threaded( - iters: Vec, -) -> Vec), RandomState>> -where - I: Iterator + Send + TrustedLen, - T: Send + Hash + Eq + Sync + Copy, -{ - let n_partitions = iters.len(); - assert!(n_partitions.is_power_of_two()); - - let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None); - - // We will create a hashtable in every thread. - // We use the hash to partition the keys to the matching hashtable. - // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. 
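The vector-hasher changes below replace the manual `build_hasher()` / `hash()` / `finish()` sequence with `BuildHasher::hash_one`. A standalone sketch of the same pattern using std's `RandomState` (the diff applies the identical idea to ahash's):

```rust
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};

fn old_style<T: Hash>(state: &RandomState, value: &T) -> u64 {
    // Pre-1.71 pattern: build a hasher, feed it, then finish.
    let mut hasher = state.build_hasher();
    value.hash(&mut hasher);
    hasher.finish()
}

fn new_style<T: Hash>(state: &RandomState, value: &T) -> u64 {
    // `BuildHasher::hash_one` does the same in one call.
    state.hash_one(value)
}

fn main() {
    let state = RandomState::new();
    assert_eq!(old_style(&state, &42u64), new_style(&state, &42u64));
}
```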
- POOL.install(|| { - (0..n_partitions) - .into_par_iter() - .map(|partition_no| { - let build_hasher = build_hasher.clone(); - let hashes_and_keys = &hashes_and_keys; - let partition_no = partition_no as u64; - let mut hash_tbl: HashMap), RandomState> = - HashMap::with_hasher(build_hasher); - - let n_threads = n_partitions as u64; - let mut offset = 0; - for hashes_and_keys in hashes_and_keys { - let len = hashes_and_keys.len(); - hashes_and_keys - .iter() - .enumerate() - .for_each(|(idx, (h, k))| { - let idx = idx as IdxSize; - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if this_partition(*h, partition_no, n_threads) { - let idx = idx + offset; - let entry = hash_tbl - .raw_entry_mut() - // uses the key to check equality to find and entry - .from_key_hashed_nocheck(*h, k); - - match entry { - RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(*h, *k, (false, vec![idx])); - }, - RawEntryMut::Occupied(mut entry) => { - let (_k, v) = entry.get_key_value_mut(); - v.1.push(idx); - }, - } - } - }); - - offset += len as IdxSize; - } - hash_tbl - }) - .collect() - }) -} - -pub(crate) fn create_hash_and_keys_threaded_vectorized( - iters: Vec, - build_hasher: Option, -) -> (Vec>, RandomState) -where - I: IntoIterator + Send, - I::IntoIter: TrustedLen, - T: Send + Hash + Eq, -{ - let build_hasher = build_hasher.unwrap_or_default(); - let hashes = POOL.install(|| { - iters - .into_par_iter() - .map(|iter| { - // create hashes and keys - iter.into_iter() - .map(|val| { - let mut hasher = build_hasher.build_hasher(); - val.hash(&mut hasher); - (hasher.finish(), val) - }) - .collect_trusted::>() - }) - .collect() - }); - (hashes, build_hasher) -} - -pub(crate) fn df_rows_to_hashes_threaded_vertical( +pub fn _df_rows_to_hashes_threaded_vertical( keys: &[DataFrame], hasher_builder: Option, ) -> PolarsResult<(Vec, RandomState)> { diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index 17ead65b8daa..954d9fd09ea1 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -1,20 +1,14 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(feature = "simd", feature(portable_simd))] #![allow(ambiguous_glob_reexports)] -#![cfg_attr( - feature = "nightly", - allow(clippy::incorrect_partial_ord_impl_on_ord_type) -)] // remove once stable +#![cfg_attr(feature = "nightly", allow(clippy::non_canonical_partial_ord_impl))] // remove once stable extern crate core; #[macro_use] pub mod utils; pub mod chunked_array; -pub mod cloud; pub mod config; pub mod datatypes; -#[cfg(feature = "docs")] -pub mod doc; pub mod error; pub mod export; pub mod fmt; @@ -41,7 +35,7 @@ use once_cell::sync::Lazy; use rayon::{ThreadPool, ThreadPoolBuilder}; #[cfg(feature = "dtype-categorical")] -pub use crate::chunked_array::logical::categorical::stringcache::*; +pub use crate::chunked_array::logical::categorical::string_cache::*; pub static PROCESS_ID: Lazy = Lazy::new(|| { SystemTime::now() diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index e80e37899f85..39021328d290 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -3,6 +3,7 @@ pub use std::sync::Arc; pub(crate) use arrow::array::*; pub use arrow::datatypes::{Field as ArrowField, Schema as ArrowSchema}; +pub(crate) use arrow::util::total_ord::{TotalEq, TotalOrd}; pub(crate) use polars_arrow::export::*; #[cfg(feature = "ewma")] pub use polars_arrow::kernels::ewm::EWMOptions; @@ -14,6 +15,7 @@ pub use 
crate::chunked_array::builder::{ ListBooleanChunkedBuilder, ListBuilderTrait, ListPrimitiveChunkedBuilder, ListUtf8ChunkedBuilder, NewChunkedArray, PrimitiveChunkedBuilder, Utf8ChunkedBuilder, }; +pub use crate::chunked_array::collect::{ChunkedCollectInferIterExt, ChunkedCollectIterExt}; pub use crate::chunked_array::iterator::PolarsIterator; #[cfg(feature = "dtype-categorical")] pub use crate::chunked_array::logical::categorical::*; @@ -24,24 +26,20 @@ pub use crate::chunked_array::object::PolarsObject; pub use crate::chunked_array::ops::aggregate::*; #[cfg(feature = "rolling_window")] pub use crate::chunked_array::ops::rolling_window::RollingOptionsFixedWindow; -#[cfg(feature = "rank")] -pub use crate::chunked_array::ops::unique::rank::{RankMethod, RankOptions}; pub use crate::chunked_array::ops::*; #[cfg(feature = "temporal")] pub use crate::chunked_array::temporal::conversion::*; pub(crate) use crate::chunked_array::ChunkIdIter; pub use crate::chunked_array::ChunkedArray; -pub use crate::datatypes::*; +pub use crate::datatypes::{ArrayCollectIterExt, *}; pub use crate::error::{ polars_bail, polars_ensure, polars_err, polars_warn, PolarsError, PolarsResult, }; -#[cfg(feature = "asof_join")] -pub use crate::frame::asof_join::*; pub use crate::frame::explode::MeltArgs; +#[cfg(feature = "algorithm_group_by")] pub(crate) use crate::frame::group_by::aggregations::*; +#[cfg(feature = "algorithm_group_by")] pub use crate::frame::group_by::{GroupsIdx, GroupsProxy, GroupsSlice, IntoGroupsProxy}; -pub(crate) use crate::frame::hash_join::*; -pub use crate::frame::hash_join::{JoinArgs, JoinType}; pub use crate::frame::{DataFrame, UniqueKeepStrategy}; pub use crate::hashing::{FxHash, VecHash}; pub use crate::named_from::{NamedFrom, NamedFromOwned}; @@ -53,4 +51,4 @@ pub use crate::series::{IntoSeries, Series, SeriesTrait}; pub use crate::testing::*; pub(crate) use crate::utils::CustomIterTools; pub use crate::utils::IntoVec; -pub use crate::{cloud, datatypes, df}; +pub use crate::{datatypes, df}; diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index 266e36935af6..350cd10944a1 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -9,7 +9,7 @@ use smartstring::alias::String as SmartString; use crate::prelude::*; use crate::utils::try_get_supertype; -/// A map from field/column name (`String`) to the type of that field/column (`DataType`) +/// A map from field/column name ([`String`](smartstring::alias::String)) to the type of that field/column ([`DataType`]) #[derive(Eq, Clone, Default)] #[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] pub struct Schema { @@ -17,7 +17,7 @@ pub struct Schema { } // Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner == -// other.inner` because IndexMap ignores order when checking equality, but we don't want to ignore it. +// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it. 
impl PartialEq for Schema { fn eq(&self, other: &Self) -> bool { self.len() == other.len() && self.iter().zip(other.iter()).all(|(a, b)| a == b) @@ -34,6 +34,12 @@ impl Debug for Schema { } } +impl From<&[Series]> for Schema { + fn from(value: &[Series]) -> Self { + value.iter().map(|s| s.field().into_owned()).collect() + } +} + impl FromIterator for Schema where F: Into, @@ -129,7 +135,7 @@ impl Schema { ) -> PolarsResult { polars_ensure!( index <= self.len(), - ComputeError: + OutOfBounds: "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", index, self.len() @@ -167,7 +173,7 @@ impl Schema { ) -> PolarsResult> { polars_ensure!( index <= self.len(), - ComputeError: + OutOfBounds: "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", index, self.len() @@ -360,9 +366,9 @@ impl Schema { ArrowSchema::from(fields) } - /// Iterates the `Field`s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair + /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair /// - /// Note that this clones each name and dtype in order to form an owned `Field`. For a clone-free version, use + /// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use /// [`iter`][Self::iter], which returns `(&name, &dtype)`. pub fn iter_fields(&self) -> impl Iterator + ExactSizeIterator + '_ { self.inner diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 23e56cdcd9ba..fb69d543661a 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -4,6 +4,7 @@ use std::fmt::Formatter; use serde::de::{MapAccess, Visitor}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use crate::chunked_array::builder::AnonymousListBuilder; use crate::chunked_array::Settings; use crate::prelude::*; @@ -202,13 +203,15 @@ impl<'de> Deserialize<'de> for Series { let values: Vec>> = map.next_value()?; Ok(Series::new(&name, values)) }, - DataType::List(_) => { + DataType::List(inner) => { let values: Vec> = map.next_value()?; - if values.is_empty() { - Ok(Series::new_empty(&name, &dtype)) - } else { - Ok(Series::new(&name, values)) + let mut lb = AnonymousListBuilder::new(&name, values.len(), Some(*inner)); + for value in &values { + lb.append_opt_series(value.as_ref()).map_err(|e| { + de::Error::custom(format!("could not append series to list: {e}")) + })?; } + Ok(lb.finish().into_series()) }, DataType::Binary => { let values: Vec>> = map.next_value()?; diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 9bad277aa970..48b0f41e34c0 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -76,33 +76,33 @@ fn any_values_to_decimal( ComputeError: "unable to losslessly convert any-value of scale {s_max} to scale {}", scale, ); - } else if s_min == s_max && s_max == scale { - // no conversions needed; will potentially check values for precision though - any_values_to_primitive::(avs).into_decimal(precision, scale) - } else { - // rescaling is needed - let mut builder = PrimitiveChunkedBuilder::::new("", avs.len()); - for av in avs { - let (v, s_av) = if av.is_signed() || av.is_unsigned() { - ( - av.try_extract::().unwrap_or_else(|_| unreachable!()), - 0, - ) - } else if let AnyValue::Decimal(v, scale) = av { - (*v, *scale) - } else { - // it has to be a 
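A hedged sketch of the new `From<&[Series]>` impl for `Schema`, which collects each series' field (name and dtype) in order:

```rust
use polars_core::prelude::*;

fn example() {
    let series = vec![
        Series::new("a", &[1i64, 2]),
        Series::new("b", &["x", "y"]),
    ];

    // Build a schema straight from a slice of series.
    let schema = Schema::from(series.as_slice());
    assert_eq!(schema.get("a"), Some(&DataType::Int64));
    assert_eq!(schema.get("b"), Some(&DataType::Utf8));
}
```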
null because we've already checked it - builder.append_null(); - continue; - }; + } + let mut builder = PrimitiveChunkedBuilder::::new("", avs.len()); + let is_equally_scaled = s_min == s_max && s_max == scale; + for av in avs { + let (v, s_av) = if av.is_signed() || av.is_unsigned() { + ( + av.try_extract::().unwrap_or_else(|_| unreachable!()), + 0, + ) + } else if let AnyValue::Decimal(v, scale) = av { + (*v, *scale) + } else { + // it has to be a null because we've already checked it + builder.append_null(); + continue; + }; + if is_equally_scaled { + builder.append_value(v); + } else { let factor = 10_i128.pow((scale - s_av) as _); // this cast is safe builder.append_value(v.checked_mul(factor).ok_or_else(|| { polars_err!(ComputeError: "overflow while converting to decimal scale {}", scale) })?); } - // build the array and do a precision check if needed - builder.finish().into_decimal(precision, scale) } + // build the array and do a precision check if needed + builder.finish().into_decimal(precision, scale) } fn any_values_to_binary(avs: &[AnyValue]) -> BinaryChunked { diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index bb04ddc9c976..6bcda0bfec29 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -177,10 +177,8 @@ pub mod checked { // see check_div for chunkedarray let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; - Ok(arity::binary_elementwise::<_, _, Float32Type, _, _>( - lhs, - rhs, - |opt_l, opt_r| match (opt_l, opt_r) { + let ca: Float32Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { (Some(l), Some(r)) => { if r.is_zero() { None @@ -189,9 +187,8 @@ pub mod checked { } }, _ => None, - }, - ) - .into_series()) + }); + Ok(ca.into_series()) } } @@ -201,10 +198,8 @@ pub mod checked { // see check_div let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; - Ok(arity::binary_elementwise::<_, _, Float64Type, _, _>( - lhs, - rhs, - |opt_l, opt_r| match (opt_l, opt_r) { + let ca: Float64Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { (Some(l), Some(r)) => { if r.is_zero() { None @@ -213,9 +208,8 @@ pub mod checked { } }, _ => None, - }, - ) - .into_series()) + }); + Ok(ca.into_series()) } } diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 3c4876d75e80..5b639e309b26 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -4,7 +4,8 @@ use arrow::compute::cast::utf8_to_large_utf8; #[cfg(any( feature = "dtype-date", feature = "dtype-datetime", - feature = "dtype-time" + feature = "dtype-time", + feature = "dtype-duration" ))] use arrow::temporal_conversions::*; use polars_arrow::compute::cast::cast; @@ -97,7 +98,9 @@ impl Series { Float32 => Float32Chunked::from_chunks(name, chunks).into_series(), Float64 => Float64Chunked::from_chunks(name, chunks).into_series(), #[cfg(feature = "dtype-struct")] - Struct(_) => Series::try_from_arrow_unchecked(name, chunks, &dtype.to_arrow()).unwrap(), + Struct(_) => { + Series::_try_from_arrow_unchecked(name, chunks, &dtype.to_arrow()).unwrap() + }, #[cfg(feature = "object")] Object(_) => { assert_eq!(chunks.len(), 1); @@ -123,10 +126,11 @@ impl Series { } } - // Create a new Series without checking if the inner dtype of the chunks is correct - // # Safety - // The caller must ensure that the given `dtype` matches 
all the `ArrayRef` dtypes. - pub(crate) unsafe fn try_from_arrow_unchecked( + /// Create a new Series without checking if the inner dtype of the chunks is correct + /// + /// # Safety + /// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes. + pub unsafe fn _try_from_arrow_unchecked( name: &str, chunks: Vec, dtype: &ArrowDataType, @@ -383,7 +387,7 @@ impl Series { .iter() .zip(dtype_fields) .map(|(arr, field)| { - Series::try_from_arrow_unchecked( + Series::_try_from_arrow_unchecked( &field.name, vec![arr.clone()], &field.data_type, @@ -485,7 +489,7 @@ fn convert ArrayRef>(arr: &[ArrayRef], f: F) -> Vec) -> (Vec, DataType) { +unsafe fn to_physical_and_dtype(arrays: Vec) -> (Vec, DataType) { match arrays[0].data_type() { ArrowDataType::Utf8 => ( convert(&arrays, |arr| { @@ -499,7 +503,7 @@ fn to_physical_and_dtype(arrays: Vec) -> (Vec, DataType) { feature_gated!("dtype-categorical", { let s = unsafe { let dt = dt.clone(); - Series::try_from_arrow_unchecked("", arrays, &dt) + Series::_try_from_arrow_unchecked("", arrays, &dt) } .unwrap(); (s.chunks().clone(), s.dtype().clone()) @@ -609,6 +613,19 @@ fn to_physical_and_dtype(arrays: Vec) -> (Vec, DataType) { (vec![arrow_array], DataType::Struct(polars_fields)) }) }, + // Use Series architecture to convert nested logical types to physical. + dt @ (ArrowDataType::Duration(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32 + | ArrowDataType::Decimal(_, _) + | ArrowDataType::Date64) => { + let dt = dt.clone(); + let mut s = Series::_try_from_arrow_unchecked("", arrays, &dt).unwrap(); + let dtype = s.dtype().clone(); + (std::mem::take(s.chunks_mut()), dtype) + }, dt => { let dtype = dt.into(); (arrays, dtype) @@ -638,7 +655,7 @@ impl TryFrom<(&str, Vec)> for Series { } // Safety: // dtype is checked - unsafe { Series::try_from_arrow_unchecked(name, chunks, &data_type) } + unsafe { Series::_try_from_arrow_unchecked(name, chunks, &data_type) } } } diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index c142a9a8fbe5..698003b873ea 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -4,7 +4,10 @@ use std::borrow::Cow; use super::{private, IntoSeries, SeriesTrait}; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::{AsSinglePtr, Settings}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -43,10 +46,12 @@ impl private::PrivateSeries for SeriesWrap { ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -104,38 +109,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) - } - - fn take_iter(&self, 
iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(ChunkTake::take_unchecked(&self.0, (&*idx).into()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 53388a0f0843..0aca6335deee 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -9,8 +9,8 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -59,17 +59,11 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) - } fn subtract(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::subtract(&self.0, rhs) } @@ -85,6 +79,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -148,47 +143,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = 
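Across the `SeriesTrait` implementations, the iterator-based takes are replaced by `take`/`take_unchecked` on an `IdxCa` plus `take_slice`/`take_slice_unchecked` on a plain `&[IdxSize]`. A hedged sketch of the new surface at the `Series` level:

```rust
use polars_core::prelude::*;

fn example() -> PolarsResult<()> {
    let s = Series::new("vals", &["a", "b", "c", "d"]);

    // Checked gather from an index ChunkedArray.
    let idx = IdxCa::from_vec("idx", vec![3, 1]);
    let taken = s.take(&idx)?;
    assert_eq!(taken.len(), 2);

    // Checked gather straight from a slice of indices.
    let taken2 = s.take_slice(&[0, 2])?;
    assert_eq!(taken2.len(), 2);

    // The unchecked variant skips bounds checks; indices must be in-bounds.
    let taken3 = unsafe { s.take_unchecked(&idx) };
    assert_eq!(taken3.len(), 2);
    Ok(())
}
```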
ChunkTake::take_unchecked(&self.0, (&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -232,14 +199,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } @@ -276,14 +246,4 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - RepeatBy::repeat_by(&self.0, by) - } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(self.0.mode()?.into_series()) - } } diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 9b713062ca98..4aff9bca60d5 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -10,8 +10,8 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::{AsSinglePtr, ChunkIdIter}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -60,27 +60,33 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0.agg_sum(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 .cast(&DataType::Float64) .unwrap() .agg_std(groups, _ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 .cast(&DataType::Float64) @@ -88,13 +94,7 @@ impl private::PrivateSeries for SeriesWrap { .agg_var(groups, _ddof) } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - 
ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) - } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -179,38 +179,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(ChunkTake::take_unchecked(&self.0, (&*idx).into()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -254,14 +235,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } @@ -330,14 +314,4 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - RepeatBy::repeat_by(&self.0, by) - } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(self.0.mode()?.into_series()) - } } diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index b386bced5131..9c8ecd6236a4 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -8,8 +8,8 @@ use crate::chunked_array::comparison::*; use crate::chunked_array::ops::compare_inner::{IntoPartialOrdInner, PartialOrdInner}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -105,6 +105,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot 
cast and dispatch as the inner type of the list would be incorrect let list = self.0.logical().agg_list(groups); @@ -113,30 +114,7 @@ impl private::PrivateSeries for SeriesWrap { list.into_series() } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let new_rev_map = self - .0 - .merge_categorical_map(right_column.categorical().unwrap()) - .unwrap(); - let left = self.0.logical(); - let right = right_column - .categorical() - .unwrap() - .logical() - .clone() - .into_series(); - - let cats = left.zip_outer_join_column(&right, opt_join_tuples); - let cats = cats.u32().unwrap().clone(); - - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked(cats, new_rev_map).into_series() - } - } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { #[cfg(feature = "performant")] { @@ -189,7 +167,7 @@ impl SeriesTrait for SeriesWrap { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.categorical()?; self.0.logical_mut().extend(other.logical()); - let new_rev_map = self.0.merge_categorical_map(other)?; + let new_rev_map = self.0._merge_categorical_map(other)?; // SAFETY // rev_maps are merged unsafe { self.0.set_rev_map(new_rev_map, false) }; @@ -214,45 +192,23 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - self.try_with_state(false, |cats| cats.take((&*indices).into())) + self.try_with_state(false, |cats| cats.take(indices)) .map(|ca| ca.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - let cats = self.0.logical().take(iter.into())?; - Ok(self.finish_with_state(false, cats).into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - let cats = self.0.logical().take_unchecked(iter.into()); - self.finish_with_state(false, cats).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self - .with_state(false, |cats| cats.take_unchecked((&*idx).into())) - .into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - let cats = self.0.logical().take_unchecked(iter.into()); - self.finish_with_state(false, cats).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + self.try_with_state(false, |cats| cats.take(indices)) + .map(|ca| ca.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - let cats = self.0.logical().take(iter.into())?; - Ok(self.finish_with_state(false, cats).into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() } fn len(&self) -> usize { @@ -297,14 +253,17 @@ impl SeriesTrait for SeriesWrap { self.0.logical().has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = 
"algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.logical().arg_unique() } @@ -358,21 +317,6 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - let out = self.0.logical().repeat_by(by)?; - let casted = out - .cast(&DataType::List(Box::new(self.dtype().clone()))) - .unwrap(); - Ok(casted.list().unwrap().clone()) - } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - let cats = self.0.logical().mode()?; - Ok(self.finish_with_state(false, cats).into_series()) - } } impl private::PrivateSeriesNumeric for SeriesWrap { diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index cb9f33ea1c29..00c14b265ff1 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -17,8 +17,8 @@ use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::ops::ToBitRepr; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::*; use crate::prelude::*; macro_rules! impl_dyn_series { @@ -90,14 +90,17 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups).$into_logical().into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups).$into_logical().into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -106,18 +109,6 @@ macro_rules! impl_dyn_series { .unwrap() } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_column = right_column.to_physical_repr().into_owned(); - self.0 - .zip_outer_join_column(&right_column, opt_join_tuples) - .$into_logical() - .into_series() - } - fn subtract(&self, rhs: &Series) -> PolarsResult { match (self.dtype(), rhs.dtype()) { (DataType::Date, DataType::Date) => { @@ -153,6 +144,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -238,43 +230,19 @@ macro_rules! 
impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - ChunkTake::take(self.0.deref(), indices.into()) - .map(|ca| ca.$into_logical().into_series()) + Ok(self.0.take(indices)?.$into_logical().into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()) - .map(|ca| ca.$into_logical().into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).$into_logical().into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) - .$into_logical() - .into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = ChunkTake::take_unchecked(self.0.deref(), idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.$into_logical().into_series()) + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.$into_logical().into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) - .$into_logical() - .into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()) - .map(|ca| ca.$into_logical().into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).$into_logical().into_series() } fn len(&self) -> usize { @@ -294,6 +262,7 @@ macro_rules! impl_dyn_series { fn cast(&self, data_type: &DataType) -> PolarsResult { match (self.dtype(), data_type) { + #[cfg(feature="dtype-date")] (DataType::Date, DataType::Utf8) => Ok(self .0 .clone() @@ -302,6 +271,7 @@ macro_rules! impl_dyn_series { .unwrap() .to_string("%Y-%m-%d") .into_series()), + #[cfg(feature="dtype-time")] (DataType::Time, DataType::Utf8) => Ok(self .0 .clone() @@ -352,14 +322,17 @@ macro_rules! impl_dyn_series { self.0.has_validity() } +#[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| ca.$into_logical().into_series()) } +#[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } +#[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } @@ -428,40 +401,6 @@ macro_rules! impl_dyn_series { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - fn peak_max(&self) -> BooleanChunked { - self.0.peak_max() - } - - fn peak_min(&self) -> BooleanChunked { - self.0.peak_min() - } - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - match self.0.dtype() { - DataType::Date => Ok(self - .0 - .repeat_by(by)? - .cast(&DataType::List(Box::new(DataType::Date))) - .unwrap() - .list() - .unwrap() - .clone()), - DataType::Time => Ok(self - .0 - .repeat_by(by)? 
- .cast(&DataType::List(Box::new(DataType::Time))) - .unwrap() - .list() - .unwrap() - .clone()), - _ => unreachable!(), - } - } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - self.0.mode().map(|ca| ca.$into_logical().into_series()) - } } }; } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 1f4095a8b003..59b3bce8a5e2 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -6,8 +6,8 @@ use ahash::RandomState; use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::*; use crate::prelude::*; unsafe impl IntoSeries for DatetimeChunked { @@ -84,6 +84,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0 .agg_min(groups) @@ -91,12 +92,14 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0 .agg_max(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -105,17 +108,6 @@ impl private::PrivateSeries for SeriesWrap { .unwrap() } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_column = right_column.to_physical_repr().into_owned(); - self.0 - .zip_outer_join_column(&right_column, opt_join_tuples) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } fn subtract(&self, rhs: &Series) -> PolarsResult { match (self.dtype(), rhs.dtype()) { (DataType::Datetime(tu, tz), DataType::Datetime(tur, tzr)) => { @@ -160,6 +152,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { polars_bail!(opq = rem, self.dtype(), rhs.dtype()); } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -243,53 +236,31 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - ChunkTake::take(self.0.deref(), indices.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) + let ca = self.0.take(indices)?; + Ok(ca + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + let ca = self.0.take_unchecked(indices); + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> 
PolarsResult { - let mut out = ChunkTake::take_unchecked(self.0.deref(), idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + let ca = self.0.take(indices)?; + Ok(ca .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + let ca = self.0.take_unchecked(indices); + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) - } - fn len(&self) -> usize { self.0.len() } @@ -351,6 +322,7 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| { ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) @@ -358,10 +330,12 @@ impl SeriesTrait for SeriesWrap { }) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } @@ -435,33 +409,4 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - fn peak_max(&self) -> BooleanChunked { - self.0.peak_max() - } - - fn peak_min(&self) -> BooleanChunked { - self.0.peak_min() - } - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - Ok(self - .0 - .repeat_by(by)? 
- .cast(&DataType::List(Box::new(DataType::Datetime( - self.0.time_unit(), - self.0.time_zone().clone(), - )))) - .unwrap() - .list() - .unwrap() - .clone()) - } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - self.0.mode().map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) - } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 5c2127b86d91..0b1d6954886b 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -67,18 +67,22 @@ impl private::PrivateSeries for SeriesWrap { .into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_sum(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_min(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_max(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } @@ -158,46 +162,36 @@ impl SeriesTrait for SeriesWrap { self.apply_physical(|ca| ca.take_opt_chunked_unchecked(by)) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()).map(|ca| { - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - }) + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = ChunkTake::take_unchecked(self.0.deref(), idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? 
.into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .take_unchecked(indices) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } - fn take(&self, indices: &IdxCa) -> PolarsResult { - ChunkTake::take(self.0.deref(), indices.into()).map(|ca| { - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - }) - } - fn len(&self) -> usize { self.0.len() } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 25f07c2def40..834c6c57c181 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -7,8 +7,8 @@ use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::*; use crate::prelude::*; unsafe impl IntoSeries for DurationChunked { @@ -89,6 +89,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0 .agg_min(groups) @@ -96,6 +97,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0 .agg_max(groups) @@ -103,6 +105,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0 .agg_sum(groups) @@ -110,6 +113,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0 .agg_std(groups, ddof) @@ -120,6 +124,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0 .agg_var(groups, ddof) @@ -130,6 +135,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -138,17 +144,6 @@ impl private::PrivateSeries for SeriesWrap { .unwrap() } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_column = right_column.to_physical_repr().into_owned(); - self.0 - .zip_outer_join_column(&right_column, opt_join_tuples) - .into_duration(self.0.time_unit()) - .into_series() - } fn subtract(&self, rhs: &Series) -> PolarsResult { match (self.dtype(), rhs.dtype()) { (DataType::Duration(tu), DataType::Duration(tur)) => { @@ -168,6 +163,24 @@ impl private::PrivateSeries for SeriesWrap { let rhs = rhs.cast(&DataType::Int64).unwrap(); Ok(lhs.add_to(&rhs)?.into_duration(*tu).into_series()) }, + (DataType::Duration(tu), DataType::Date) => { + let one_day_in_tu: i64 = match tu { + TimeUnit::Milliseconds => 86_400_000, + TimeUnit::Microseconds => 86_400_000_000, + 
TimeUnit::Nanoseconds => 86_400_000_000_000, + }; + let lhs = self.cast(&DataType::Int64).unwrap() / one_day_in_tu; + let rhs = rhs + .cast(&DataType::Int32) + .unwrap() + .cast(&DataType::Int64) + .unwrap(); + Ok(lhs + .add_to(&rhs)? + .cast(&DataType::Int32)? + .into_date() + .into_series()) + }, (DataType::Duration(tu), DataType::Datetime(tur, tz)) => { polars_ensure!(tu == tur, InvalidOperation: "units are different"); let lhs = self.cast(&DataType::Int64).unwrap(); @@ -197,6 +210,7 @@ impl private::PrivateSeries for SeriesWrap { .into_duration(self.0.time_unit()) .into_series()) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -277,45 +291,35 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - ChunkTake::take(self.0.deref(), indices.into()) - .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()) - .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) .into_duration(self.0.time_unit()) .into_series() } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = ChunkTake::take_unchecked(self.0.deref(), idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_duration(self.0.time_unit()).into_series()) + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(self.0.deref(), iter.into()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .take_unchecked(indices) .into_duration(self.0.time_unit()) .into_series() } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - ChunkTake::take(self.0.deref(), iter.into()) - .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) - } - fn len(&self) -> usize { self.0.len() } @@ -366,16 +370,19 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0 .unique() .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } @@ -439,31 +446,4 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - - fn peak_max(&self) -> BooleanChunked { - self.0.peak_max() - } - - fn peak_min(&self) -> BooleanChunked { - self.0.peak_min() - } - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - Ok(self - .0 - .repeat_by(by)? 
- .cast(&DataType::List(Box::new(DataType::Duration( - self.0.time_unit(), - )))) - .unwrap() - .list() - .unwrap() - .clone()) - } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - self.0 - .mode() - .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) - } } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index f1ffd4475373..5c09473c7c26 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -11,8 +11,8 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; #[cfg(feature = "checked_arithmetic")] use crate::series::arithmetic::checked::NumOpsDispatchChecked; @@ -89,37 +89,36 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0.agg_sum(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.agg_std(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.agg_var(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) - } fn subtract(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::subtract(&self.0, rhs) } @@ -135,6 +134,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -214,47 +214,19 @@ macro_rules! 
impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = ChunkTake::take_unchecked(&self.0, (&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -298,14 +270,17 @@ macro_rules! impl_dyn_series { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } @@ -360,28 +335,10 @@ macro_rules! 
impl_dyn_series { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - fn peak_max(&self) -> BooleanChunked { - self.0.peak_max() - } - - fn peak_min(&self) -> BooleanChunked { - self.0.peak_min() - } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - RepeatBy::repeat_by(&self.0, by) - } - #[cfg(feature = "checked_arithmetic")] fn checked_div(&self, rhs: &Series) -> PolarsResult { self.0.checked_div(rhs) } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(self.0.mode()?.into_series()) - } } }; } diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index c4361b284e33..0ae8b9ced37b 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -9,6 +9,7 @@ use crate::chunked_array::comparison::*; use crate::chunked_array::ops::compare_inner::{IntoPartialEqInner, PartialEqInner}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::{AsSinglePtr, Settings}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -45,10 +46,12 @@ impl private::PrivateSeries for SeriesWrap { ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -125,38 +128,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(ChunkTake::take_unchecked(&self.0, (&*idx).into()).into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -193,6 +177,7 @@ impl SeriesTrait for SeriesWrap { } #[cfg(feature = "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = unique, self.dtype()); @@ -209,6 +194,7 @@ impl SeriesTrait for SeriesWrap { } #[cfg(feature 
= "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = n_unique, self.dtype()); @@ -226,6 +212,7 @@ impl SeriesTrait for SeriesWrap { } #[cfg(feature = "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = arg_unique, self.dtype()); diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 7d7c8c2f3f52..63ad25829894 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -40,9 +40,11 @@ use crate::chunked_array::ops::compare_inner::{ IntoPartialEqInner, IntoPartialOrdInner, PartialEqInner, PartialOrdInner, }; use crate::chunked_array::ops::explode::ExplodeByOffsets; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; #[cfg(feature = "checked_arithmetic")] use crate::series::arithmetic::checked::NumOpsDispatchChecked; @@ -152,14 +154,17 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { use DataType::*; match self.dtype() { @@ -168,25 +173,21 @@ macro_rules! impl_dyn_series { } } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0.agg_std(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0.agg_var(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) - } fn subtract(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::subtract(&self.0, rhs) } @@ -202,6 +203,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -311,44 +313,19 @@ macro_rules! 
impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - let mut out = ChunkTake::take_unchecked(&self.0, (&*idx).into()); - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - Ok(out.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -392,14 +369,17 @@ macro_rules! impl_dyn_series { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } @@ -454,19 +434,6 @@ macro_rules! impl_dyn_series { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - fn peak_max(&self) -> BooleanChunked { - self.0.peak_max() - } - - fn peak_min(&self) -> BooleanChunked { - self.0.peak_min() - } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - RepeatBy::repeat_by(&self.0, by) - } - #[cfg(feature = "checked_arithmetic")] fn checked_div(&self, rhs: &Series) -> PolarsResult { self.0.checked_div(rhs) @@ -476,15 +443,6 @@ macro_rules! 
impl_dyn_series { fn as_any(&self) -> &dyn Any { &self.0 } - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(self.0.mode()?.into_series()) - } - - #[cfg(feature = "concat_str")] - fn str_concat(&self, delimiter: &str) -> Utf8Chunked { - self.0.str_concat(delimiter) - } fn tile(&self, n: usize) -> Series { self.0.tile(n).into_series() diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index a52e3472d66d..aeeeced2aab2 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::sync::Arc; use polars_arrow::prelude::ArrayRef; +use polars_error::constants::LENGTH_LIMIT_MSG; use polars_utils::IdxSize; use crate::datatypes::IdxCa; @@ -43,7 +44,14 @@ impl PrivateSeriesNumeric for NullChunked {} impl PrivateSeries for NullChunked { fn compute_len(&mut self) { - // no-op + fn inner(chunks: &[ArrayRef]) -> usize { + match chunks.len() { + // fast path + 1 => chunks[0].len(), + _ => chunks.iter().fold(0, |acc, arr| acc + arr.len()), + } + } + self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); } fn _field(&self) -> Cow { Cow::Owned(Field::new(self.name(), DataType::Null)) @@ -99,24 +107,20 @@ impl SeriesTrait for NullChunked { NullChunked::new(self.name.clone(), by.len()).into_series() } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), iter.size_hint().0).into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - NullChunked::new(self.name.clone(), iter.size_hint().0).into_series() + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), idx.len()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - NullChunked::new(self.name.clone(), iter.size_hint().0).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) } - fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 8015d7693e57..c58f37159d90 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -5,7 +5,10 @@ use ahash::RandomState; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoPartialEqInner, PartialEqInner}; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::Settings; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::{GroupsProxy, IntoGroupsProxy}; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -64,6 +67,7 @@ where Ok(()) } + #[cfg(feature = "algorithm_group_by")] fn 
group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -132,33 +136,19 @@ where } fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(ChunkTake::take(&self.0, indices.into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(ChunkTake::take_unchecked(&self.0, (&*idx).into()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, _iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - todo!() + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index da183572d7b7..16d758fd5257 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -58,10 +58,12 @@ impl private::PrivateSeries for SeriesWrap { Ok(StructChunked::new_unchecked(self.0.name(), &fields).into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { let df = DataFrame::new_no_checks(vec![]); let gb = df @@ -176,16 +178,6 @@ impl SeriesTrait for SeriesWrap { .map(|ca| ca.into_series()) } - /// Take by index from an iterator. This operation clones the data. - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - self.0 - .try_apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_iter(&mut *iter) - }) - .map(|ca| ca.into_series()) - } - #[cfg(feature = "chunked_ids")] unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { self.0 @@ -200,53 +192,30 @@ impl SeriesTrait for SeriesWrap { .into_series() } - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0 - .apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_iter_unchecked(&mut *iter) - }) - .into_series() - } - - /// Take by index if ChunkedArray contains a single chunk. - /// - /// # Safety - /// This doesn't check any bounds. 
- unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { + fn take(&self, indices: &IdxCa) -> PolarsResult { self.0 - .try_apply_fields(|s| s.take_unchecked(idx)) + .try_apply_fields(|s| s.take(indices)) .map(|ca| ca.into_series()) } - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { self.0 - .apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_opt_iter_unchecked(&mut *iter) - }) + .apply_fields(|s| s.take_unchecked(indices)) .into_series() } - /// Take by index. This operation is clone. - fn take(&self, indices: &IdxCa) -> PolarsResult { + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { self.0 - .try_apply_fields(|s| s.take(indices)) + .try_apply_fields(|s| s.take_slice(indices)) .map(|ca| ca.into_series()) } + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .apply_fields(|s| s.take_slice_unchecked(indices)) + .into_series() + } + /// Get length of series. fn len(&self) -> usize { self.0.len() @@ -283,6 +252,7 @@ impl SeriesTrait for SeriesWrap { } /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot if self.len() < 2 { @@ -296,6 +266,7 @@ impl SeriesTrait for SeriesWrap { } /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot match self.len() { @@ -311,6 +282,7 @@ impl SeriesTrait for SeriesWrap { } /// Get first indexes of unique values. 
+ #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot if self.len() == 1 { diff --git a/crates/polars-core/src/series/implementations/utf8.rs b/crates/polars-core/src/series/implementations/utf8.rs index 3e2f8f96a68b..b28722975ffc 100644 --- a/crates/polars-core/src/series/implementations/utf8.rs +++ b/crates/polars-core/src/series/implementations/utf8.rs @@ -9,8 +9,8 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; -use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -60,25 +60,21 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) - } fn subtract(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::subtract(&self.0, rhs) } @@ -94,6 +90,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -163,47 +160,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(ChunkTake::take(&self.0, (&*indices).into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = ChunkTake::take_unchecked(&self.0, (&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - ChunkTake::take_unchecked(&self.0, iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(ChunkTake::take(&self.0, iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) 
-> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -247,14 +216,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } @@ -292,16 +264,6 @@ impl SeriesTrait for SeriesWrap { Arc::new(SeriesWrap(Clone::clone(&self.0))) } - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, by: &IdxCa) -> PolarsResult { - RepeatBy::repeat_by(&self.0, by) - } - - #[cfg(feature = "mode")] - fn mode(&self) -> PolarsResult { - Ok(self.0.mode()?.into_series()) - } - #[cfg(feature = "concat_str")] fn str_concat(&self, delimiter: &str) -> Utf8Chunked { self.0.str_concat(delimiter) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index d8de5cd028ac..2ec370820c67 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -28,8 +28,6 @@ use rayon::prelude::*; pub use series_trait::{IsSorted, *}; use crate::chunked_array::Settings; -#[cfg(feature = "rank")] -use crate::prelude::unique::rank::rank; #[cfg(feature = "zip_with")] use crate::series::arithmetic::coerce_lhs_rhs; use crate::utils::{_split_offsets, get_casting_failures, split_ca, split_series, Wrap}; @@ -154,14 +152,13 @@ impl Hash for Wrap { } impl Series { - /// Create a new empty Series + /// Create a new empty Series. pub fn new_empty(name: &str, dtype: &DataType) -> Series { Series::full_null(name, 0, dtype) } pub fn clear(&self) -> Series { - // only the inner of objects know their type - // so use this hack + // Only the inner of objects know their type, so use this hack. #[cfg(feature = "object")] if matches!(self.dtype(), DataType::Object(_)) { return if self.is_empty() { @@ -229,6 +226,21 @@ impl Series { self } + /// Return this Series with a new name. + pub fn with_name(mut self, name: &str) -> Series { + self.rename(name); + self + } + + pub fn from_arrow(name: &str, array: ArrayRef) -> PolarsResult { + Self::try_from((name, array)) + } + + #[cfg(feature = "arrow_rs")] + pub fn from_arrow_rs(name: &str, array: &dyn arrow_array::Array) -> PolarsResult { + Self::from_arrow(name, array.into()) + } + /// Shrink the capacity of this array to fit its length. pub fn shrink_to_fit(&mut self) { self._get_inner_mut().shrink_to_fit() @@ -262,23 +274,18 @@ impl Series { self._get_inner_mut().as_single_ptr() } - /// Cast `[Series]` to another `[DataType]` + /// Cast `[Series]` to another `[DataType]`. pub fn cast(&self, dtype: &DataType) -> PolarsResult { - // best leave as is. + // Best leave as is. if matches!(dtype, DataType::Unknown) { return Ok(self.clone()); } - match self.0.cast(dtype) { - Ok(out) => Ok(out), - Err(err) => { - let len = self.len(); - if self.null_count() == len { - Ok(Series::full_null(self.name(), len, dtype)) - } else { - Err(err) - } - }, + let ret = self.0.cast(dtype); + let len = self.len(); + if ret.is_err() && self.null_count() == len { + return Ok(Series::full_null(self.name(), len, dtype)); } + ret } /// Cast from physical to logical types without any checks on the validity of the cast. 
@@ -288,28 +295,27 @@ impl Series { pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { match self.dtype() { #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - let ca = self.struct_().unwrap(); - ca.cast_unchecked(dtype) - }, - DataType::List(_) => { - let ca = self.list().unwrap(); - ca.cast_unchecked(dtype) - }, + DataType::Struct(_) => self.struct_().unwrap().cast_unchecked(dtype), + DataType::List(_) => self.list().unwrap().cast_unchecked(dtype), dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(dt, |$T| { let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); ca.cast_unchecked(dtype) }) }, - DataType::Binary => { - let ca = self.binary().unwrap(); - ca.cast_unchecked(dtype) - }, + DataType::Binary => self.binary().unwrap().cast_unchecked(dtype), _ => self.cast(dtype), } } + /// Cast numerical types to f64, and keep floats as is. + pub fn to_float(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 | DataType::Float64 => Ok(self.clone()), + _ => self.cast(&DataType::Float64), + } + } + /// Compute the sum of all values in this Series. /// Returns `Some(0)` if the array is empty, and `None` if the array only /// contains null values. @@ -326,10 +332,8 @@ impl Series { where T: NumCast, { - self.sum_as_series() - .cast(&DataType::Float64) - .ok() - .and_then(|s| s.f64().unwrap().get(0).and_then(T::from)) + let sum = self.sum_as_series().cast(&DataType::Float64).ok()?; + T::from(sum.f64().unwrap().get(0)?) } /// Returns the minimum value in the array, according to the natural order. @@ -343,10 +347,8 @@ impl Series { where T: NumCast, { - self.min_as_series() - .cast(&DataType::Float64) - .ok() - .and_then(|s| s.f64().unwrap().get(0).and_then(T::from)) + let min = self.min_as_series().cast(&DataType::Float64).ok()?; + T::from(min.f64().unwrap().get(0)?) } /// Returns the maximum value in the array, according to the natural order. @@ -360,10 +362,8 @@ impl Series { where T: NumCast, { - self.max_as_series() - .cast(&DataType::Float64) - .ok() - .and_then(|s| s.f64().unwrap().get(0).and_then(T::from)) + let max = self.max_as_series().cast(&DataType::Float64).ok()?; + T::from(max.f64().unwrap().get(0)?) } /// Explode a list Series. This expands every item to a new row.. @@ -444,16 +444,13 @@ impl Series { #[cfg(feature = "dtype-struct")] Struct(_) => { let arr = self.struct_().unwrap(); - let fields = arr + let fields: Vec<_> = arr .fields() .iter() .map(|s| s.to_physical_repr().into_owned()) - .collect::>(); - Cow::Owned( - StructChunked::new(self.name(), &fields) - .unwrap() - .into_series(), - ) + .collect(); + let ca = StructChunked::new(self.name(), &fields).unwrap(); + Cow::Owned(ca.into_series()) }, _ => Cow::Borrowed(self), } @@ -474,7 +471,7 @@ impl Series { } } - // take a function pointer to reduce bloat + // Take a function pointer to reduce bloat. fn threaded_op( &self, rechunk: bool, @@ -498,24 +495,31 @@ impl Series { /// /// # Safety /// This doesn't check any bounds. Null validity is checked. - pub unsafe fn take_unchecked_from_slice(&self, idx: &[IdxSize]) -> PolarsResult { - let idx = IdxCa::mmap_slice("", idx); - self.take_unchecked(&idx) + pub unsafe fn take_unchecked_from_slice(&self, idx: &[IdxSize]) -> Series { + self.take_slice_unchecked(idx) } /// Take by index if ChunkedArray contains a single chunk. /// /// # Safety /// This doesn't check any bounds. Null validity is checked. 
- pub unsafe fn take_unchecked_threaded( - &self, - idx: &IdxCa, - rechunk: bool, - ) -> PolarsResult { + pub unsafe fn take_unchecked_threaded(&self, idx: &IdxCa, rechunk: bool) -> Series { self.threaded_op(rechunk, idx.len(), &|offset, len| { let idx = idx.slice(offset as i64, len); - self.take_unchecked(&idx) + Ok(self.take_unchecked(&idx)) }) + .unwrap() + } + + /// Take by index if ChunkedArray contains a single chunk. + /// + /// # Safety + /// This doesn't check any bounds. Null validity is checked. + pub unsafe fn take_slice_unchecked_threaded(&self, idx: &[IdxSize], rechunk: bool) -> Series { + self.threaded_op(rechunk, idx.len(), &|offset, len| { + Ok(self.take_slice_unchecked(&idx[offset..offset + len])) + }) + .unwrap() } /// # Safety @@ -560,11 +564,18 @@ impl Series { }) } + /// Traverse and collect every nth element in a new array. + pub fn take_every(&self, n: usize) -> Series { + let idx = (0..self.len() as IdxSize).step_by(n).collect_ca(""); + // SAFETY: we stay in-bounds. + unsafe { self.take_unchecked(&idx) } + } + /// Filter by boolean mask. This operation clones data. pub fn filter_threaded(&self, filter: &BooleanChunked, rechunk: bool) -> PolarsResult { - // this would fail if there is a broadcasting filter. - // because we cannot split that filter over threads - // besides they are a no-op, so we do the standard filter. + // This would fail if there is a broadcasting filter, because we cannot + // split that filter over threads besides they are a no-op, so we do the + // standard filter. if filter.len() == 1 { return self.filter(filter); } @@ -598,10 +609,8 @@ impl Series { if self.is_empty() && (self.dtype().is_numeric() || matches!(self.dtype(), DataType::Boolean)) { - return Series::new(self.name(), [0]) - .cast(self.dtype()) - .unwrap() - .sum_as_series(); + let zero = Series::new(self.name(), [0]); + return zero.cast(self.dtype()).unwrap().sum_as_series(); } match self.dtype() { Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().sum_as_series(), @@ -609,7 +618,7 @@ impl Series { } } - /// Get an array with the cumulative max computed at every element + /// Get an array with the cumulative max computed at every element. pub fn cummax(&self, _reverse: bool) -> Series { #[cfg(feature = "cum_agg")] { @@ -621,7 +630,7 @@ impl Series { } } - /// Get an array with the cumulative min computed at every element + /// Get an array with the cumulative min computed at every element. 
pub fn cummin(&self, _reverse: bool) -> Series { #[cfg(feature = "cum_agg")] { @@ -648,30 +657,12 @@ impl Series { let s = self.cast(&Int64).unwrap(); s.cumsum(reverse) }, - Int32 => { - let ca = self.i32().unwrap(); - ca.cumsum(reverse).into_series() - }, - UInt32 => { - let ca = self.u32().unwrap(); - ca.cumsum(reverse).into_series() - }, - UInt64 => { - let ca = self.u64().unwrap(); - ca.cumsum(reverse).into_series() - }, - Int64 => { - let ca = self.i64().unwrap(); - ca.cumsum(reverse).into_series() - }, - Float32 => { - let ca = self.f32().unwrap(); - ca.cumsum(reverse).into_series() - }, - Float64 => { - let ca = self.f64().unwrap(); - ca.cumsum(reverse).into_series() - }, + Int32 => self.i32().unwrap().cumsum(reverse).into_series(), + UInt32 => self.u32().unwrap().cumsum(reverse).into_series(), + UInt64 => self.u64().unwrap().cumsum(reverse).into_series(), + Int64 => self.i64().unwrap().cumsum(reverse).into_series(), + Float32 => self.f32().unwrap().cumsum(reverse).into_series(), + Float64 => self.f64().unwrap().cumsum(reverse).into_series(), #[cfg(feature = "dtype-duration")] Duration(tu) => { let ca = self.to_physical_repr(); @@ -687,7 +678,7 @@ impl Series { } } - /// Get an array with the cumulative product computed at every element + /// Get an array with the cumulative product computed at every element. /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is /// first cast to `Int64` to prevent overflow issues. @@ -702,22 +693,10 @@ impl Series { let s = self.cast(&Int64).unwrap(); s.cumprod(reverse) }, - Int64 => { - let ca = self.i64().unwrap(); - ca.cumprod(reverse).into_series() - }, - UInt64 => { - let ca = self.u64().unwrap(); - ca.cumprod(reverse).into_series() - }, - Float32 => { - let ca = self.f32().unwrap(); - ca.cumprod(reverse).into_series() - }, - Float64 => { - let ca = self.f64().unwrap(); - ca.cumprod(reverse).into_series() - }, + Int64 => self.i64().unwrap().cumprod(reverse).into_series(), + UInt64 => self.u64().unwrap().cumprod(reverse).into_series(), + Float32 => self.f32().unwrap().cumprod(reverse).into_series(), + Float64 => self.f64().unwrap().cumprod(reverse).into_series(), dt => panic!("cumprod not supported for dtype: {dt:?}"), } } @@ -741,23 +720,11 @@ impl Series { let s = self.cast(&Int64).unwrap(); s.product() }, - Int64 => { - let ca = self.i64().unwrap(); - ca.prod_as_series() - }, - UInt64 => { - let ca = self.u64().unwrap(); - ca.prod_as_series() - }, - Float32 => { - let ca = self.f32().unwrap(); - ca.prod_as_series() - }, - Float64 => { - let ca = self.f64().unwrap(); - ca.prod_as_series() - }, - dt => panic!("cumprod not supported for dtype: {dt:?}"), + Int64 => self.i64().unwrap().prod_as_series(), + UInt64 => self.u64().unwrap().prod_as_series(), + Float32 => self.f32().unwrap().prod_as_series(), + Float64 => self.f64().unwrap().prod_as_series(), + dt => panic!("product not supported for dtype: {dt:?}"), } } #[cfg(not(feature = "product"))] @@ -766,11 +733,6 @@ impl Series { } } - #[cfg(feature = "rank")] - pub fn rank(&self, options: RankOptions, seed: Option) -> Series { - rank(self, options.method, options.descending, seed) - } - /// Cast throws an error if conversion had overflows pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { let null_count = self.null_count(); @@ -962,9 +924,8 @@ impl Series { /// than a naive [`Series::unique`](SeriesTrait::unique). pub fn unique_stable(&self) -> PolarsResult { let idx = self.arg_unique()?; - // Safety: - // Indices are in bounds. 
- unsafe { self.take_unchecked(&idx) } + // SAFETY: Indices are in bounds. + unsafe { Ok(self.take_unchecked(&idx)) } } pub fn idx(&self) -> PolarsResult<&IdxCa> { @@ -1059,7 +1020,7 @@ where DataType::Decimal(None, None) => panic!("impl error"), _ => { if &T::get_dtype() == self.dtype() || - // needed because we want to get ref of List no matter what the inner type is. + // Needed because we want to get ref of List no matter what the inner type is. (matches!(T::get_dtype(), DataType::List(_)) && matches!(self.dtype(), DataType::List(_))) { unsafe { &*(self as *const dyn SeriesTrait as *const ChunkedArray) } @@ -1081,7 +1042,7 @@ where { fn as_mut(&mut self) -> &mut ChunkedArray { if &T::get_dtype() == self.dtype() || - // needed because we want to get ref of List no matter what the inner type is. + // Needed because we want to get ref of List no matter what the inner type is. (matches!(T::get_dtype(), DataType::List(_)) && matches!(self.dtype(), DataType::List(_))) { unsafe { &mut *(self as *mut dyn SeriesTrait as *mut ChunkedArray) } diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index 3f50fb4c6dbd..e632ecf78ea8 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use crate::series::implementations::null::NullChunked; macro_rules! unpack_chunked { ($series:expr, $expected:pat => $ca:ty, $name:expr) => { @@ -14,17 +15,17 @@ macro_rules! unpack_chunked { } impl Series { - /// Unpack to ChunkedArray of dtype `[DataType::Int8]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int8]` pub fn i8(&self) -> PolarsResult<&Int8Chunked> { unpack_chunked!(self, DataType::Int8 => Int8Chunked, "Int8") } - /// Unpack to ChunkedArray of dtype `[DataType::Int16]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int16]` pub fn i16(&self) -> PolarsResult<&Int16Chunked> { unpack_chunked!(self, DataType::Int16 => Int16Chunked, "Int16") } - /// Unpack to ChunkedArray + /// Unpack to [`ChunkedArray`] /// ``` /// # use polars_core::prelude::*; /// let s = Series::new("foo", [1i32 ,2, 3]); @@ -38,109 +39,109 @@ impl Series { /// } /// }).collect(); /// ``` - /// Unpack to ChunkedArray of dtype `[DataType::Int32]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int32]` pub fn i32(&self) -> PolarsResult<&Int32Chunked> { unpack_chunked!(self, DataType::Int32 => Int32Chunked, "Int32") } - /// Unpack to ChunkedArray of dtype `[DataType::Int64]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int64]` pub fn i64(&self) -> PolarsResult<&Int64Chunked> { unpack_chunked!(self, DataType::Int64 => Int64Chunked, "Int64") } - /// Unpack to ChunkedArray of dtype `[DataType::Float32]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float32]` pub fn f32(&self) -> PolarsResult<&Float32Chunked> { unpack_chunked!(self, DataType::Float32 => Float32Chunked, "Float32") } - /// Unpack to ChunkedArray of dtype `[DataType::Float64]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float64]` pub fn f64(&self) -> PolarsResult<&Float64Chunked> { unpack_chunked!(self, DataType::Float64 => Float64Chunked, "Float64") } - /// Unpack to ChunkedArray of dtype `[DataType::UInt8]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt8]` pub fn u8(&self) -> PolarsResult<&UInt8Chunked> { unpack_chunked!(self, DataType::UInt8 => UInt8Chunked, "UInt8") } - /// Unpack to ChunkedArray of dtype `[DataType::UInt16]` + /// Unpack to [`ChunkedArray`] of dtype 
`[DataType::UInt16]` pub fn u16(&self) -> PolarsResult<&UInt16Chunked> { unpack_chunked!(self, DataType::UInt16 => UInt16Chunked, "UInt16") } - /// Unpack to ChunkedArray of dtype `[DataType::UInt32]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt32]` pub fn u32(&self) -> PolarsResult<&UInt32Chunked> { unpack_chunked!(self, DataType::UInt32 => UInt32Chunked, "UInt32") } - /// Unpack to ChunkedArray of dtype `[DataType::UInt64]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt64]` pub fn u64(&self) -> PolarsResult<&UInt64Chunked> { unpack_chunked!(self, DataType::UInt64 => UInt64Chunked, "UInt64") } - /// Unpack to ChunkedArray of dtype `[DataType::Boolean]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Boolean]` pub fn bool(&self) -> PolarsResult<&BooleanChunked> { unpack_chunked!(self, DataType::Boolean => BooleanChunked, "Boolean") } - /// Unpack to ChunkedArray of dtype `[DataType::Utf8]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Utf8]` pub fn utf8(&self) -> PolarsResult<&Utf8Chunked> { unpack_chunked!(self, DataType::Utf8 => Utf8Chunked, "Utf8") } - /// Unpack to ChunkedArray of dtype `[DataType::Binary]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` pub fn binary(&self) -> PolarsResult<&BinaryChunked> { unpack_chunked!(self, DataType::Binary => BinaryChunked, "Binary") } - /// Unpack to ChunkedArray of dtype `[DataType::Time]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Time]` #[cfg(feature = "dtype-time")] pub fn time(&self) -> PolarsResult<&TimeChunked> { unpack_chunked!(self, DataType::Time => TimeChunked, "Time") } - /// Unpack to ChunkedArray of dtype `[DataType::Date]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Date]` #[cfg(feature = "dtype-date")] pub fn date(&self) -> PolarsResult<&DateChunked> { unpack_chunked!(self, DataType::Date => DateChunked, "Date") } - /// Unpack to ChunkedArray of dtype `[DataType::Datetime]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Datetime]` #[cfg(feature = "dtype-datetime")] pub fn datetime(&self) -> PolarsResult<&DatetimeChunked> { unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked, "Datetime") } - /// Unpack to ChunkedArray of dtype `[DataType::Duration]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Duration]` #[cfg(feature = "dtype-duration")] pub fn duration(&self) -> PolarsResult<&DurationChunked> { unpack_chunked!(self, DataType::Duration(_) => DurationChunked, "Duration") } - /// Unpack to ChunkedArray of dtype `[DataType::Decimal]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Decimal]` #[cfg(feature = "dtype-decimal")] pub fn decimal(&self) -> PolarsResult<&DecimalChunked> { unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked, "Decimal") } - /// Unpack to ChunkedArray of dtype list + /// Unpack to [`ChunkedArray`] of dtype list pub fn list(&self) -> PolarsResult<&ListChunked> { unpack_chunked!(self, DataType::List(_) => ListChunked, "List") } - /// Unpack to ChunkedArray of dtype `[DataType::Array]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Array]` #[cfg(feature = "dtype-array")] pub fn array(&self) -> PolarsResult<&ArrayChunked> { unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked, "FixedSizeList") } - /// Unpack to ChunkedArray of dtype `[DataType::Categorical]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Categorical]` #[cfg(feature = "dtype-categorical")] pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { unpack_chunked!(self, DataType::Categorical(_) => 
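Since this hunk mostly rewrites the unpack doc comments, one usage sketch of the accessors themselves may help; the key point is that they check the dtype rather than cast:

```rust
use polars_core::prelude::*;

fn demo_unpack() -> PolarsResult<()> {
    let s = Series::new("foo", [1i32, 2, 3]);
    let ca: &Int32Chunked = s.i32()?; // dtype matches, cheap reference
    assert_eq!(ca.sum(), Some(6));
    assert!(s.f64().is_err()); // wrong dtype is an error, not a conversion
    Ok(())
}
```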
CategoricalChunked, "Categorical") } - /// Unpack to ChunkedArray of dtype `[DataType::Struct]` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Struct]` #[cfg(feature = "dtype-struct")] pub fn struct_(&self) -> PolarsResult<&StructChunked> { #[cfg(debug_assertions)] @@ -152,4 +153,9 @@ impl Series { } unpack_chunked!(self, DataType::Struct(_) => StructChunked, "Struct") } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Null]` + pub fn null(&self) -> PolarsResult<&NullChunked> { + unpack_chunked!(self, DataType::Null => NullChunked, "Null") + } } diff --git a/crates/polars-core/src/series/ops/round.rs b/crates/polars-core/src/series/ops/round.rs index edcd3f31cbca..37abe7797941 100644 --- a/crates/polars-core/src/series/ops/round.rs +++ b/crates/polars-core/src/series/ops/round.rs @@ -1,5 +1,4 @@ use num_traits::pow::Pow; -use num_traits::{clamp_max, clamp_min}; use crate::prelude::*; @@ -60,67 +59,4 @@ impl Series { } polars_bail!(opq = ceil, self.dtype()); } - - /// Clamp underlying values to the `min` and `max` values. - pub fn clip(mut self, min: AnyValue<'_>, max: AnyValue<'_>) -> PolarsResult { - if self.dtype().is_numeric() { - macro_rules! apply_clip { - ($pl_type:ty, $ca:expr) => {{ - let min = min - .extract::<<$pl_type as PolarsNumericType>::Native>() - .unwrap(); - let max = max - .extract::<<$pl_type as PolarsNumericType>::Native>() - .unwrap(); - - $ca.apply_mut(|val| val.clamp(min, max)); - }}; - } - let mutable = self._get_inner_mut(); - downcast_as_macro_arg_physical_mut!(mutable, apply_clip); - Ok(self) - } else { - polars_bail!(opq = clip, self.dtype()); - } - } - - /// Clamp underlying values to the `max` value. - pub fn clip_max(mut self, max: AnyValue<'_>) -> PolarsResult { - if self.dtype().is_numeric() { - macro_rules! apply_clip { - ($pl_type:ty, $ca:expr) => {{ - let max = max - .extract::<<$pl_type as PolarsNumericType>::Native>() - .unwrap(); - - $ca.apply_mut(|val| clamp_max(val, max)); - }}; - } - let mutable = self._get_inner_mut(); - downcast_as_macro_arg_physical_mut!(mutable, apply_clip); - Ok(self) - } else { - polars_bail!(opq = clip_max, self.dtype()); - } - } - - /// Clamp underlying values to the `min` value. - pub fn clip_min(mut self, min: AnyValue<'_>) -> PolarsResult { - if self.dtype().is_numeric() { - macro_rules! 
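The new `Series::null` accessor follows the same pattern for the `Null` dtype. A small sketch, assuming `Series::full_null` as used elsewhere in this patch:

```rust
use polars_core::prelude::*;

fn demo_null_unpack() -> PolarsResult<()> {
    let s = Series::full_null("empty", 3, &DataType::Null);
    assert_eq!(s.dtype(), &DataType::Null);
    let _nulls = s.null()?; // would be an error for any other dtype
    Ok(())
}
```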
apply_clip { - ($pl_type:ty, $ca:expr) => {{ - let min = min - .extract::<<$pl_type as PolarsNumericType>::Native>() - .unwrap(); - - $ca.apply_mut(|val| clamp_min(val, min)); - }}; - } - let mutable = self._get_inner_mut(); - downcast_as_macro_arg_physical_mut!(mutable, apply_clip); - Ok(self) - } else { - polars_bail!(opq = clip_min, self.dtype()); - } - } } diff --git a/crates/polars-core/src/series/ops/unique.rs b/crates/polars-core/src/series/ops/unique.rs index cfae77d687e7..9daee89710d3 100644 --- a/crates/polars-core/src/series/ops/unique.rs +++ b/crates/polars-core/src/series/ops/unique.rs @@ -2,7 +2,7 @@ use std::hash::Hash; #[cfg(feature = "unique_counts")] -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; #[cfg(feature = "unique_counts")] use crate::utils::NoNull; @@ -13,7 +13,7 @@ where I: Iterator, J: Hash + Eq, { - let mut map = PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + let mut map = PlIndexMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); for item in items { map.entry(item) .and_modify(|cnt| { diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 83d93a76b11d..660413e488b9 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -47,7 +47,7 @@ pub(crate) mod private { use super::*; use crate::chunked_array::ops::compare_inner::{PartialEqInner, PartialOrdInner}; use crate::chunked_array::Settings; - #[cfg(feature = "rows")] + #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::GroupsProxy; pub trait PrivateSeriesNumeric { @@ -126,35 +126,33 @@ pub(crate) mod private { ) -> PolarsResult<()> { polars_bail!(opq = vec_hash_combine, self._dtype()); } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. 
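The newly feature-gated default `agg_*` implementations above all fall back to an all-null series with one entry per group. That fallback is just `Series::full_null`, e.g.:

```rust
use polars_core::prelude::*;

fn demo_full_null_fallback() {
    // One null per group, in the dtype of the original column.
    let s = Series::full_null("a", 4, &DataType::Int32);
    assert_eq!(s.len(), 4);
    assert_eq!(s.null_count(), 4);
}
```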
+ #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } - fn zip_outer_join_column( - &self, - _right_column: &Series, - _opt_join_tuples: &[(Option, Option)], - ) -> Series { - invalid_operation_panic!(zip_outer_join_column, self) - } - fn subtract(&self, _rhs: &Series) -> PolarsResult { invalid_operation_panic!(sub, self) } @@ -170,6 +168,7 @@ pub(crate) mod private { fn remainder(&self, _rhs: &Series) -> PolarsResult { invalid_operation_panic!(rem, self) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { invalid_operation_panic!(group_tuples, self) } @@ -268,40 +267,23 @@ pub trait SeriesTrait: #[cfg(feature = "chunked_ids")] unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series; - /// Take by index from an iterator. This operation clones the data. - fn take_iter(&self, _iter: &mut dyn TakeIterator) -> PolarsResult; - - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_iter_unchecked(&self, _iter: &mut dyn TakeIterator) -> Series; + /// Take by index. This operation is clone. + fn take(&self, _indices: &IdxCa) -> PolarsResult; - /// Take by index if ChunkedArray contains a single chunk. + /// Take by index. /// /// # Safety /// This doesn't check any bounds. - unsafe fn take_unchecked(&self, _idx: &IdxCa) -> PolarsResult; + unsafe fn take_unchecked(&self, _idx: &IdxCa) -> Series; + + /// Take by index. This operation is clone. + fn take_slice(&self, _indices: &[IdxSize]) -> PolarsResult; - /// Take by index from an iterator. This operation clones the data. + /// Take by index. /// /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_opt_iter_unchecked(&self, _iter: &mut dyn TakeIteratorNulls) -> Series; - - /// Take by index from an iterator. This operation clones the data. - /// todo! remove? - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, _iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - invalid_operation_panic!(take_opt_iter, self) - } - - /// Take by index. This operation is clone. - fn take(&self, _indices: &IdxCa) -> PolarsResult; + /// This doesn't check any bounds. + unsafe fn take_slice_unchecked(&self, _idx: &[IdxSize]) -> Series; /// Get length of series. fn len(&self) -> usize; @@ -496,31 +478,11 @@ pub trait SeriesTrait: invalid_operation_panic!(as_any_mut, self) } - /// Get a boolean mask of the local maximum peaks. - fn peak_max(&self) -> BooleanChunked { - invalid_operation_panic!(peak_max, self) - } - - /// Get a boolean mask of the local minimum peaks. 
- fn peak_min(&self) -> BooleanChunked { - invalid_operation_panic!(peak_min, self) - } - - #[cfg(feature = "repeat_by")] - fn repeat_by(&self, _by: &IdxCa) -> PolarsResult { - polars_bail!(opq = repeat_by, self._dtype()); - } #[cfg(feature = "checked_arithmetic")] fn checked_div(&self, _rhs: &Series) -> PolarsResult { polars_bail!(opq = checked_div, self._dtype()); } - #[cfg(feature = "mode")] - /// Compute the most occurring element in the array. - fn mode(&self) -> PolarsResult { - polars_bail!(opq = mode, self._dtype()); - } - #[cfg(feature = "rolling_window")] /// Apply a custom function over a rolling/ moving window of the array. /// This has quite some dynamic dispatch, so prefer rolling_min, max, mean, sum over this. @@ -531,14 +493,6 @@ pub trait SeriesTrait: ) -> PolarsResult { polars_bail!(opq = rolling_map, self._dtype()); } - #[cfg(feature = "concat_str")] - /// Concat the values into a string array. - /// # Arguments - /// - /// * `delimiter` - A string that will act as delimiter between values. - fn str_concat(&self, _delimiter: &str) -> Utf8Chunked { - invalid_operation_panic!(str_concat, self); - } fn tile(&self, _n: usize) -> Series { invalid_operation_panic!(tile, self); diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index 4a2e14be242c..2f18cd5eae6c 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -14,7 +14,7 @@ impl Series { } /// Check if all values in series are equal where `None == None` evaluates to `true`. - /// Two `Datetime` series are *not* equal if their timezones are different, regardless + /// Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones are different, regardless /// if they represent the same UTC time or not. pub fn series_equal_missing(&self, other: &Series) -> bool { match (self.dtype(), other.dtype()) { @@ -40,7 +40,7 @@ impl Series { } } - /// Get a pointer to the underlying data of this Series. + /// Get a pointer to the underlying data of this [`Series`]. /// Can be useful for fast comparisons. pub fn get_data_ptr(&self) -> usize { let object = self.0.deref(); @@ -63,7 +63,7 @@ impl PartialEq for Series { } impl DataFrame { - /// Check if `DataFrames` schemas are equal. + /// Check if [`DataFrame`]' schemas are equal. pub fn frame_equal_schema(&self, other: &DataFrame) -> PolarsResult<()> { for (lhs, rhs) in self.iter().zip(other.iter()) { polars_ensure!( @@ -80,7 +80,7 @@ impl DataFrame { Ok(()) } - /// Check if `DataFrames` are equal. Note that `None == None` evaluates to `false` + /// Check if [`DataFrame`]s are equal. Note that `None == None` evaluates to `false` /// /// # Example /// @@ -106,7 +106,7 @@ impl DataFrame { true } - /// Check if all values in `DataFrames` are equal where `None == None` evaluates to `true`. + /// Check if all values in [`DataFrame`]s are equal where `None == None` evaluates to `true`. 
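On the `testing.rs` doc changes above: the practical difference between the equality helpers is how missing values compare. A short sketch:

```rust
use polars_core::prelude::*;

fn demo_equality() {
    let a = Series::new("x", [Some(1i32), None]);
    let b = Series::new("x", [Some(1i32), None]);
    assert!(a.series_equal_missing(&b)); // None == None evaluates to true here
    assert!(!a.series_equal(&b));        // ...but not in the plain check
}
```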
/// /// # Example /// @@ -132,7 +132,7 @@ impl DataFrame { true } - /// Checks if the Arc ptrs of the Series are equal + /// Checks if the Arc ptrs of the [`Series`] are equal /// /// # Example /// diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs index afc8167a31e6..7b5b56bde98b 100644 --- a/crates/polars-core/src/utils/flatten.rs +++ b/crates/polars-core/src/utils/flatten.rs @@ -37,7 +37,7 @@ pub fn flatten_series(s: &Series) -> Vec { } } -pub(crate) fn cap_and_offsets(v: &[Vec]) -> (usize, Vec) { +pub fn cap_and_offsets(v: &[Vec]) -> (usize, Vec) { let cap = v.iter().map(|v| v.len()).sum::(); let offsets = v .iter() diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index e521aa274d3b..347627c437c0 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -4,6 +4,7 @@ mod supertype; use std::borrow::Cow; use std::ops::{Deref, DerefMut}; +use arrow::bitmap::bitmask::BitMask; use arrow::bitmap::Bitmap; use flatten::*; use num_traits::{One, Zero}; @@ -145,7 +146,7 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult> { pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult> { let total_len = df.height(); - let chunk_size = std::cmp::max(total_len / n, 3); + let chunk_size = std::cmp::max(total_len / n, 1); if df.n_chunks() == n && df.get_columns()[0] @@ -335,6 +336,19 @@ macro_rules! with_match_physical_integer_type {( } })} +#[macro_export] +macro_rules! with_match_physical_float_polars_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $key_type { + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + _ => unimplemented!() + } +})} + #[macro_export] macro_rules! 
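The new `with_match_physical_float_polars_type!` macro mirrors the existing numeric dispatch macros: it binds `$T` to `Float32Type` or `Float64Type` and hits `unimplemented!()` for anything else. A minimal sketch of how such a dispatch might be used:

```rust
use polars_core::prelude::*;
use polars_core::with_match_physical_float_polars_type;

fn float_width(dtype: &DataType) -> usize {
    // `$T` expands to the concrete PolarsNumericType for the matched float dtype.
    with_match_physical_float_polars_type!(dtype, |$T| {
        std::mem::size_of::<<$T as PolarsNumericType>::Native>()
    })
}
```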
with_match_physical_numeric_polars_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* @@ -673,7 +687,6 @@ where } #[allow(clippy::type_complexity)] -#[cfg(feature = "zip_with")] pub fn align_chunks_ternary<'a, A, B, C>( a: &'a ChunkedArray, b: &'a ChunkedArray, @@ -809,33 +822,6 @@ pub(crate) fn index_to_chunked_index< (current_chunk_idx, index_remainder) } -#[cfg(feature = "dtype-struct")] -pub(crate) fn index_to_chunked_index2(chunks: &[ArrayRef], index: usize) -> (usize, usize) { - let mut index_remainder = index; - let mut current_chunk_idx = 0; - - for chunk in chunks { - if chunk.len() > index_remainder { - break; - } else { - index_remainder -= chunk.len(); - current_chunk_idx += 1; - } - } - (current_chunk_idx, index_remainder) -} - -#[cfg(feature = "chunked_ids")] -pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> Vec { - let mut vals = Vec::with_capacity(len); - - for (chunk_i, chunk) in chunks.iter().enumerate() { - vals.extend((0..chunk.len()).map(|array_i| [chunk_i as IdxSize, array_i as IdxSize])) - } - - vals -} - pub(crate) fn first_non_null<'a, I>(iter: I) -> Option where I: Iterator>, @@ -843,10 +829,9 @@ where let mut offset = 0; for validity in iter { if let Some(validity) = validity { - for (idx, is_valid) in validity.iter().enumerate() { - if is_valid { - return Some(offset + idx); - } + let mask = BitMask::from_bitmap(validity); + if let Some(n) = mask.nth_set_bit_idx(0, 0) { + return Some(offset + n); } offset += validity.len() } else { @@ -864,17 +849,16 @@ where return None; } let mut offset = 0; - let len = len - 1; for validity in iter.rev() { if let Some(validity) = validity { - for (idx, is_valid) in validity.iter().rev().enumerate() { - if is_valid { - return Some(len - (offset + idx)); - } + let mask = BitMask::from_bitmap(validity); + if let Some(n) = mask.nth_set_bit_idx_rev(0, mask.len()) { + let mask_start = len - offset - mask.len(); + return Some(mask_start + n); } offset += validity.len() } else { - return Some(len - offset); + return Some(len - 1 - offset); } } None diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs index b6c87b2cff33..115dc2843dad 100644 --- a/crates/polars-core/src/utils/series.rs +++ b/crates/polars-core/src/utils/series.rs @@ -19,7 +19,7 @@ pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec { .collect() } -/// A utility that allocates an `UnstableSeries`. The applied function can then use that +/// A utility that allocates an [`UnstableSeries`]. The applied function can then use that /// series container to save heap allocations and swap arrow arrays. pub fn with_unstable_series(dtype: &DataType, f: F) -> T where diff --git a/crates/polars-error/Cargo.toml b/crates/polars-error/Cargo.toml index ce622dcccbe9..a6ed9123f199 100644 --- a/crates/polars-error/Cargo.toml +++ b/crates/polars-error/Cargo.toml @@ -10,6 +10,7 @@ description = "Error definitions for the Polars DataFrame library" [dependencies] arrow = { workspace = true } +object_store = { version = "0.7", default-features = false, optional = true } regex = { workspace = true, optional = true } thiserror = { workspace = true } diff --git a/crates/polars-error/README.md b/crates/polars-error/README.md index 0b52af09b05a..9aaac05c5795 100644 --- a/crates/polars-error/README.md +++ b/crates/polars-error/README.md @@ -1,5 +1,5 @@ # polars-error -`polars-error` is a sub-crate that provides error definitions for the Polars dataframe library. 
+`polars-error` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, defining its error types. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-error/src/constants.rs b/crates/polars-error/src/constants.rs index 473e9edfe55b..910c1e62a499 100644 --- a/crates/polars-error/src/constants.rs +++ b/crates/polars-error/src/constants.rs @@ -8,3 +8,10 @@ pub static FALSE: &str = "False"; pub static TRUE: &str = "true"; #[cfg(not(feature = "python"))] pub static FALSE: &str = "false"; + +#[cfg(not(feature = "python"))] +pub static LENGTH_LIMIT_MSG: &str = + "polars' maximum length reached. Consider compiling with 'bigidx' feature."; +#[cfg(feature = "python")] +pub static LENGTH_LIMIT_MSG: &str = + "polars' maximum length reached. Consider installing 'polars-u64-idx'."; diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index 6cf86706ab44..db89125643ec 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -55,6 +55,8 @@ pub enum PolarsError { Io(#[from] io::Error), #[error("no data: {0}")] NoData(ErrString), + #[error("{0}")] + OutOfBounds(ErrString), #[error("field not found: {0}")] SchemaFieldNotFound(ErrString), #[error("data types don't match: {0}")] @@ -80,6 +82,16 @@ impl From for PolarsError { } } +#[cfg(feature = "object_store")] +impl From for PolarsError { + fn from(err: object_store::Error) -> Self { + PolarsError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + format!("object store error {err:?}"), + )) + } +} + pub type PolarsResult = Result; pub use arrow::error::Error as ArrowError; @@ -95,6 +107,7 @@ impl PolarsError { InvalidOperation(msg) => InvalidOperation(func(msg).into()), Io(err) => ComputeError(func(&format!("IO: {err}")).into()), NoData(msg) => NoData(func(msg).into()), + OutOfBounds(msg) => OutOfBounds(func(msg).into()), SchemaFieldNotFound(msg) => SchemaFieldNotFound(func(msg).into()), SchemaMismatch(msg) => SchemaMismatch(func(msg).into()), ShapeMismatch(msg) => ShapeMismatch(func(msg).into()), @@ -110,14 +123,14 @@ pub fn map_err(error: E) -> PolarsError { #[macro_export] macro_rules! polars_err { - ($variant:ident: $err:expr $(,)?) => { + ($variant:ident: $fmt:literal $(, $arg:expr)* $(,)?) => { $crate::__private::must_use( - $crate::PolarsError::$variant($err.into()) + $crate::PolarsError::$variant(format!($fmt, $($arg),*).into()) ) }; - ($variant:ident: $fmt:literal, $($arg:tt)+) => { + ($variant:ident: $err:expr $(,)?) => { $crate::__private::must_use( - $crate::PolarsError::$variant(format!($fmt, $($arg)+).into()) + $crate::PolarsError::$variant($err.into()) ) }; (expr = $expr:expr, $variant:ident: $err:expr $(,)?) 
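The macro reorder above lets `polars_err!` accept a format literal with arguments as its first arm, and the new `OutOfBounds` variant is what the `oob = ...` shorthand now produces. A sketch of both forms:

```rust
use polars_error::{polars_bail, polars_err, PolarsResult};

fn get_checked(len: usize, idx: usize) -> PolarsResult<usize> {
    if idx >= len {
        // Expands to PolarsError::OutOfBounds with the formatted message.
        return Err(polars_err!(oob = idx, len));
    }
    Ok(idx)
}

fn require_supported(dtype: &str) -> PolarsResult<()> {
    if dtype != "Int64" {
        // Format-literal form; plain expressions still work via the second arm.
        polars_bail!(ComputeError: "unsupported dtype: {}", dtype);
    }
    Ok(())
}
```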
=> { @@ -187,7 +200,7 @@ Help: if you're using Python, this may look something like: Alternatively, if the performance cost is acceptable, you could just set: import polars as pl - pl.enable_string_cache(True) + pl.enable_string_cache() on startup."#.trim_start()) }; @@ -195,7 +208,7 @@ on startup."#.trim_start()) polars_err!(Duplicate: "column with name '{}' has more than one occurrences", $name) }; (oob = $idx:expr, $len:expr) => { - polars_err!(ComputeError: "index {} is out of bounds for sequence of size {}", $idx, $len) + polars_err!(OutOfBounds: "index {} is out of bounds for sequence of length {}", $idx, $len) }; (agg_len = $agg_len:expr, $groups_len:expr) => { polars_err!( diff --git a/crates/polars-ffi/Cargo.toml b/crates/polars-ffi/Cargo.toml new file mode 100644 index 000000000000..2414c80aae70 --- /dev/null +++ b/crates/polars-ffi/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "polars-ffi" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "FFI utils for the Polars project." + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +polars-core = { workspace = true } diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs new file mode 100644 index 000000000000..699d5e7a7fd5 --- /dev/null +++ b/crates/polars-ffi/src/lib.rs @@ -0,0 +1,128 @@ +use std::mem::ManuallyDrop; + +use arrow::ffi; +use arrow::ffi::{ArrowArray, ArrowSchema}; +use polars_core::error::PolarsResult; +use polars_core::prelude::{ArrayRef, ArrowField, Series}; + +// A utility that helps releasing/owning memory. +#[allow(dead_code)] +struct PrivateData { + schema: Box, + arrays: Box<[*mut ArrowArray]>, +} + +/// An FFI exported `Series`. +#[repr(C)] +pub struct SeriesExport { + field: *mut ArrowSchema, + // A double ptr, so we can easily release the buffer + // without dropping the arrays. + arrays: *mut *mut ArrowArray, + len: usize, + release: Option, + private_data: *mut std::os::raw::c_void, +} + +impl Drop for SeriesExport { + fn drop(&mut self) { + if let Some(release) = self.release { + unsafe { release(self) } + } + } +} + +// callback used to drop [SeriesExport] when it is exported. 
+unsafe extern "C" fn c_release_series_export(e: *mut SeriesExport) { + if e.is_null() { + return; + } + let e = &mut *e; + let private = Box::from_raw(e.private_data as *mut PrivateData); + for ptr in private.arrays.iter() { + // drop the box, not the array + let _ = Box::from_raw(*ptr as *mut ManuallyDrop); + } + + e.release = None; +} + +pub fn export_series(s: &Series) -> SeriesExport { + let field = ArrowField::new(s.name(), s.dtype().to_arrow(), true); + let schema = Box::new(ffi::export_field_to_c(&field)); + let mut arrays = s + .chunks() + .iter() + .map(|arr| Box::into_raw(Box::new(ffi::export_array_to_c(arr.clone())))) + .collect::>(); + let len = arrays.len(); + let ptr = arrays.as_mut_ptr(); + SeriesExport { + field: schema.as_ref() as *const ArrowSchema as *mut ArrowSchema, + arrays: ptr, + len, + release: Some(c_release_series_export), + private_data: Box::into_raw(Box::new(PrivateData { arrays, schema })) + as *mut std::os::raw::c_void, + } +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series(e: SeriesExport) -> PolarsResult { + let field = ffi::import_field_from_c(&(*e.field))?; + + let pointers = std::slice::from_raw_parts_mut(e.arrays, e.len); + let chunks = pointers + .iter() + .map(|ptr| { + let arr = std::ptr::read(*ptr); + import_array(arr, &(*e.field)) + }) + .collect::>>()?; + + Ok(Series::from_chunks_and_dtype_unchecked( + &field.name, + chunks, + &(&field.data_type).into(), + )) +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series_buffer(e: *mut SeriesExport, len: usize) -> PolarsResult> { + let mut out = Vec::with_capacity(len); + for i in 0..len { + let e = std::ptr::read(e.add(i)); + out.push(import_series(e)?) + } + Ok(out) +} + +/// # Safety +/// `ArrowArray` and `ArrowSchema` must be valid +unsafe fn import_array( + array: ffi::ArrowArray, + schema: &ffi::ArrowSchema, +) -> PolarsResult { + let field = ffi::import_field_from_c(schema)?; + let out = ffi::import_array_from_c(array, field.data_type)?; + Ok(out) +} + +#[cfg(test)] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_ffi() { + let s = Series::new("a", [1, 2]); + let e = export_series(&s); + + unsafe { + assert_eq!(import_series(e).unwrap(), s); + }; + } +} diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 003b36abef42..e299003fab50 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -9,12 +9,12 @@ repository = { workspace = true } description = "IO related logic for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = [], default-features = false } -polars-error = { version = "0.32.0", path = "../polars-error", default-features = false } -polars-json = { version = "0.32.0", optional = true, path = "../polars-json" } -polars-time = { version = "0.32.0", path = "../polars-time", features = [], default-features = false, optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-core = { workspace = true } +polars-error = { workspace = true, default-features = false } +polars-json = { workspace = true, optional = true } +polars-time = { workspace = true, features = [], optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } @@ -25,20 +25,26 @@ chrono-tz = { workspace = true, optional = true } fast-float = { 
version = "0.2", optional = true } flate2 = { version = "1", optional = true, default-features = false } futures = { workspace = true, optional = true } +itoa = { workspace = true, optional = true } lexical = { version = "6", optional = true, default-features = false, features = ["std", "parse-integers"] } -lexical-core = { version = "0.8", optional = true } +lexical-core = { workspace = true, optional = true } memchr = { workspace = true } -memmap = { package = "memmap2", version = "0.7", optional = true } +memmap = { package = "memmap2", version = "0.7" } num-traits = { workspace = true } object_store = { workspace = true, optional = true } once_cell = { workspace = true } +percent-encoding = { workspace = true } rayon = { workspace = true } regex = { workspace = true } +reqwest = { workspace = true, optional = true } +ryu = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"], optional = true } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value"], optional = true } simd-json = { workspace = true, optional = true } -simdutf8 = { version = "0.1", optional = true } -tokio = { version = "1.26", features = ["net"], optional = true } +simdutf8 = { workspace = true, optional = true } +smartstring = { workspace = true } +tokio = { workspace = true, features = ["net", "rt-multi-thread", "time"], optional = true } +tokio-util = { workspace = true, features = ["io", "io-util"], optional = true } url = { workspace = true, optional = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] @@ -51,22 +57,21 @@ tempdir = "0.3.7" default = ["decompress"] # support for arrows json parsing json = [ - "arrow/io_json_write", "polars-json", "simd-json", - "memmap", "lexical", "lexical-core", "serde_json", "dtype-struct", + "csv", ] # support for arrows ipc file parsing -ipc = ["arrow/io_ipc", "arrow/io_ipc_compression", "memmap"] +ipc = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrows streaming ipc file parsing ipc_streaming = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrow avro parsing avro = ["arrow/io_avro", "arrow/io_avro_compression"] -csv = ["memmap", "lexical", "polars-core/rows", "lexical-core", "fast-float", "simdutf8"] +csv = ["lexical", "polars-core/rows", "itoa", "ryu", "fast-float", "simdutf8"] decompress = ["flate2/rust_backend"] decompress-fast = ["flate2/zlib-ng"] dtype-categorical = ["polars-core/dtype-categorical"] @@ -87,12 +92,12 @@ dtype-struct = ["polars-core/dtype-struct"] dtype-decimal = ["polars-core/dtype-decimal"] fmt = ["polars-core/fmt"] lazy = [] -parquet = ["polars-core/parquet", "arrow/io_parquet", "arrow/io_parquet_compression", "memmap"] -async = ["async-trait", "futures", "tokio", "arrow/io_ipc_write_async", "polars-error/regex"] -cloud = ["object_store", "async", "url"] -aws = ["object_store/aws", "cloud", "polars-core/aws"] -azure = ["object_store/azure", "cloud", "polars-core/azure"] -gcp = ["object_store/gcp", "cloud", "polars-core/gcp"] +parquet = ["polars-core/parquet", "arrow/io_parquet", "arrow/io_parquet_compression"] +async = ["async-trait", "futures", "tokio", "tokio-util", "arrow/io_ipc_write_async", "polars-error/regex"] +cloud = ["object_store", "async", "polars-error/object_store", "url"] +aws = ["object_store/aws", "cloud", "reqwest"] +azure = ["object_store/azure", "cloud"] +gcp = ["object_store/gcp", "cloud"] partition = ["polars-core/partition_by"] temporal = ["dtype-datetime", "dtype-date", "dtype-time"] simd = [] diff --git 
a/crates/polars-io/README.md b/crates/polars-io/README.md index ef01d9c1d33b..963a1a86e25d 100644 --- a/crates/polars-io/README.md +++ b/crates/polars-io/README.md @@ -1,5 +1,5 @@ # polars-io -`polars-io` is a sub-crate that provides IO functionality for the Polars dataframe library. +`polars-io` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, that provides IO functionality for the Polars dataframe library. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-io/src/avro/read.rs b/crates/polars-io/src/avro/read.rs index 98477a3119c8..5ed7c46f1f47 100644 --- a/crates/polars-io/src/avro/read.rs +++ b/crates/polars-io/src/avro/read.rs @@ -7,7 +7,9 @@ use polars_core::prelude::*; use super::{finish_reader, ArrowChunk, ArrowReader, ArrowResult}; use crate::prelude::*; -/// Read Apache Avro format into a DataFrame +/// Read [Apache Avro] format into a [`DataFrame`] +/// +/// [Apache Avro]: https://avro.apache.org /// /// # Example /// ``` diff --git a/crates/polars-io/src/avro/write.rs b/crates/polars-io/src/avro/write.rs index 4a2658e04726..7d3de8469ce3 100644 --- a/crates/polars-io/src/avro/write.rs +++ b/crates/polars-io/src/avro/write.rs @@ -6,7 +6,9 @@ pub use Compression as AvroCompression; use super::*; -/// Write a DataFrame to Apache Avro format +/// Write a [`DataFrame`] to [Apache Avro] format +/// +/// [Apache Avro]: https://avro.apache.org /// /// # Example /// diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index 2060c876f4a8..e0652ffb88fb 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -1,18 +1,25 @@ //! Interface with the object_store crate and define AsyncSeek, AsyncRead. -//! This is used, for example, by the parquet2 crate. +//! This is used, for example, by the [parquet2] crate. +//! +//! [parquet2]: https://crates.io/crates/parquet2 use std::io::{self}; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; +use bytes::Bytes; use futures::executor::block_on; use futures::future::BoxFuture; -use futures::lock::Mutex; use futures::{AsyncRead, AsyncSeek, Future, TryFutureExt}; use object_store::path::Path; -use object_store::ObjectStore; +use object_store::{MultipartId, ObjectStore}; +use polars_error::{to_compute_err, PolarsError, PolarsResult}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; -type OptionalFuture = Arc>>>>>; +use super::*; +use crate::pl_async::get_runtime; + +type OptionalFuture = Option>>; /// Adaptor to translate from AsyncSeek and AsyncRead to the object_store get_range API. pub struct CloudReader { @@ -21,7 +28,7 @@ pub struct CloudReader { // The total size of the object is required when seeking from the end of the file. length: Option, // Hold an reference to the store in a thread safe way. - object_store: Arc>>, + object_store: Arc, // The path in the object_store of the current object being read. path: Path, // If a read is pending then `active` will point to its future. 
@@ -29,17 +36,13 @@ pub struct CloudReader { } impl CloudReader { - pub fn new( - length: Option, - object_store: Arc>>, - path: Path, - ) -> Self { + pub fn new(length: Option, object_store: Arc, path: Path) -> Self { Self { pos: 0, length, object_store, path, - active: Arc::new(Mutex::new(None)), + active: None, } } @@ -48,24 +51,22 @@ impl CloudReader { mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, length: usize, - ) -> std::task::Poll>> { + ) -> std::task::Poll> { let start = self.pos as usize; // If we already have a future just poll it. - if let Some(fut) = self.active.lock().await.as_mut() { + if let Some(fut) = self.active.as_mut() { return Future::poll(fut.as_mut(), cx); } // Create the future. let future = { let path = self.path.clone(); - let arc = self.object_store.clone(); + let object_store = self.object_store.clone(); // Use an async move block to get our owned objects. async move { - let object_store = arc.lock().await; object_store .get_range(&path, start..start + length) - .map_ok(|r| r.to_vec()) .map_err(|e| { std::io::Error::new( std::io::ErrorKind::Other, @@ -84,8 +85,7 @@ impl CloudReader { let polled = Future::poll(future.as_mut(), cx); // Save for next time. - let mut state = self.active.lock().await; - *state = Some(future); + self.active = Some(future); polled } } @@ -100,7 +100,7 @@ impl AsyncRead for CloudReader { // With this approach we keep ownership of the buffer and we don't have to pass it to the future runtime. match block_on(self.read_operation(cx, buf.len())) { Poll::Ready(Ok(bytes)) => { - buf.copy_from_slice(&bytes); + buf.copy_from_slice(bytes.as_ref()); Poll::Ready(Ok(bytes.len())) }, Poll::Ready(Err(e)) => Poll::Ready(Err(e)), @@ -126,6 +126,160 @@ impl AsyncSeek for CloudReader { }, io::SeekFrom::Current(pos) => self.pos = (self.pos as i64 + pos) as u64, }; + self.active = None; std::task::Poll::Ready(Ok(self.pos)) } } + +/// Adaptor which wraps the asynchronous interface of [ObjectStore::put_multipart](https://docs.rs/object_store/latest/object_store/trait.ObjectStore.html#tymethod.put_multipart) +/// exposing a synchronous interface which implements `std::io::Write`. +/// +/// This allows it to be used in sync code which would otherwise write to a simple File or byte stream, +/// such as with `polars::prelude::CsvWriter`. +pub struct CloudWriter { + // Hold a reference to the store + object_store: Arc, + // The path in the object_store which we want to write to + path: Path, + // ID of a partially-done upload, used to abort the upload on error + multipart_id: MultipartId, + // Internal writer, constructed at creation + writer: Box, +} + +impl CloudWriter { + /// Construct a new CloudWriter, re-using the given `object_store` + /// + /// Creates a new (current-thread) Tokio runtime + /// which bridges the sync writing process with the async ObjectStore multipart uploading. + /// TODO: Naming? + pub async fn new_with_object_store( + object_store: Arc, + path: Path, + ) -> PolarsResult { + let build_result = Self::build_writer(&object_store, &path).await; + match build_result { + Err(error) => Err(PolarsError::from(error)), + Ok((multipart_id, writer)) => Ok(CloudWriter { + object_store, + path, + multipart_id, + writer, + }), + } + } + + /// Constructs a new CloudWriter from a path and an optional set of CloudOptions. + /// + /// Wrapper around `CloudWriter::new_with_object_store` that is useful if you only have a single write task. + /// TODO: Naming? 
+ pub async fn new(uri: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult { + let (cloud_location, object_store) = + crate::cloud::build_object_store(uri, cloud_options).await?; + Self::new_with_object_store(object_store, cloud_location.prefix.into()).await + } + + async fn build_writer( + object_store: &Arc, + path: &Path, + ) -> object_store::Result<(MultipartId, Box)> { + let (multipart_id, s3_writer) = object_store.put_multipart(path).await?; + Ok((multipart_id, s3_writer)) + } + + async fn abort(&self) -> PolarsResult<()> { + self.object_store + .abort_multipart(&self.path, &self.multipart_id) + .await + .map_err(to_compute_err) + } +} + +impl std::io::Write for CloudWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + get_runtime().block_on(async { + let res = self.writer.write(buf).await; + if res.is_err() { + let _ = self.abort().await; + } + res + }) + } + + fn flush(&mut self) -> std::io::Result<()> { + get_runtime().block_on(async { + let res = self.writer.flush().await; + if res.is_err() { + let _ = self.abort().await; + } + res + }) + } +} + +impl Drop for CloudWriter { + fn drop(&mut self) { + let _ = get_runtime().block_on(self.writer.shutdown()); + } +} + +#[cfg(feature = "csv")] +#[cfg(test)] +mod tests { + use object_store::ObjectStore; + use polars_core::df; + use polars_core::prelude::{DataFrame, NamedFrom}; + + use super::*; + + fn example_dataframe() -> DataFrame { + df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap() + } + + #[test] + fn csv_to_local_objectstore_cloudwriter() { + use crate::csv::CsvWriter; + use crate::prelude::SerWriter; + + let mut df = example_dataframe(); + + let object_store: Arc = Arc::new( + object_store::local::LocalFileSystem::new_with_prefix(std::env::temp_dir()) + .expect("Could not initialize connection"), + ); + + let path: object_store::path::Path = "cloud_writer_example.csv".into(); + + let mut cloud_writer = get_runtime() + .block_on(CloudWriter::new_with_object_store(object_store, path)) + .unwrap(); + CsvWriter::new(&mut cloud_writer) + .finish(&mut df) + .expect("Could not write dataframe as CSV to remote location"); + } + + // Skip this tests on Windows since it does not have a convenient /tmp/ location. + #[cfg_attr(target_os = "windows", ignore)] + #[test] + fn cloudwriter_from_cloudlocation_test() { + use crate::csv::CsvWriter; + use crate::prelude::SerWriter; + + let mut df = example_dataframe(); + + let mut cloud_writer = get_runtime() + .block_on(CloudWriter::new( + "file:///tmp/cloud_writer_example2.csv", + None, + )) + .unwrap(); + + CsvWriter::new(&mut cloud_writer) + .finish(&mut df) + .expect("Could not write dataframe as CSV to remote location"); + } +} diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs index f1798be9fa3d..ddb76c7a2936 100644 --- a/crates/polars-io/src/cloud/glob.rs +++ b/crates/polars-io/src/cloud/glob.rs @@ -2,12 +2,13 @@ use futures::future::ready; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use polars_arrow::error::polars_bail; -use polars_core::cloud::CloudOptions; use polars_core::error::to_compute_err; use polars_core::prelude::{polars_ensure, polars_err, PolarsError, PolarsResult}; use regex::Regex; use url::Url; +use super::*; + const DELIMITER: char = '/'; /// Split the url in @@ -95,11 +96,17 @@ impl CloudLocation { let key = parsed.path(); let bucket = parsed .host() - .ok_or(polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", url))? 
+ .ok_or_else( + || polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", url), + )? .to_string(); (bucket, key) }; - let (mut prefix, expansion) = extract_prefix_expansion(key)?; + + let key = percent_encoding::percent_decode_str(key) + .decode_utf8() + .map_err(to_compute_err)?; + let (mut prefix, expansion) = extract_prefix_expansion(&key)?; if is_local && key.starts_with(DELIMITER) { prefix.insert(0, DELIMITER); } @@ -159,7 +166,7 @@ pub async fn glob(url: &str, cloud_options: Option<&CloudOptions>) -> PolarsResu expansion, }, store, - ) = super::build(url, cloud_options)?; + ) = super::build_object_store(url, cloud_options).await?; let matcher = Matcher::new(prefix.clone(), expansion.as_deref())?; let list_stream = store diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs index 66cb604bc298..6118a4bb9a76 100644 --- a/crates/polars-io/src/cloud/mod.rs +++ b/crates/polars-io/src/cloud/mod.rs @@ -1,79 +1,31 @@ //! Interface with cloud storage through the object_store crate. +#[cfg(feature = "cloud")] +use std::borrow::Cow; +#[cfg(feature = "cloud")] use std::str::FromStr; +#[cfg(feature = "cloud")] +use std::sync::Arc; +#[cfg(feature = "cloud")] use object_store::local::LocalFileSystem; +#[cfg(feature = "cloud")] use object_store::ObjectStore; -use polars_core::cloud::{CloudOptions, CloudType}; +#[cfg(feature = "cloud")] use polars_core::prelude::{polars_bail, PolarsError, PolarsResult}; +#[cfg(feature = "cloud")] mod adaptors; +#[cfg(feature = "cloud")] mod glob; +#[cfg(feature = "cloud")] +mod object_store_setup; +pub mod options; + +#[cfg(feature = "cloud")] pub use adaptors::*; +#[cfg(feature = "cloud")] pub use glob::*; - -type BuildResult = PolarsResult<(CloudLocation, Box)>; - -#[allow(dead_code)] -fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { - polars_bail!( - ComputeError: - "feature '{}' must be enabled in order to use '{}' cloud urls", feature, scheme, - ); -} -#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] -fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { - polars_bail!( - ComputeError: - "configuration '{}' must be provided in order to use '{}' cloud urls", feature, scheme, - ); -} -/// Build an ObjectStore based on the URL and passed in url. Return the cloud location and an implementation of the object store. -pub fn build(url: &str, _options: Option<&CloudOptions>) -> BuildResult { - let cloud_location = CloudLocation::new(url)?; - let store = match CloudType::from_str(url)? 
{ - CloudType::File => { - let local = LocalFileSystem::new(); - Ok::<_, PolarsError>(Box::new(local) as Box) - }, - CloudType::Aws => { - #[cfg(feature = "aws")] - match _options { - Some(options) => { - let store = options.build_aws(&cloud_location.bucket)?; - Ok::<_, PolarsError>(Box::new(store) as Box) - }, - _ => return err_missing_configuration("aws", &cloud_location.scheme), - } - #[cfg(not(feature = "aws"))] - return err_missing_feature("aws", &cloud_location.scheme); - }, - CloudType::Gcp => { - #[cfg(feature = "gcp")] - match _options { - Some(options) => { - let store = options.build_gcp(&cloud_location.bucket)?; - Ok::<_, PolarsError>(Box::new(store) as Box) - }, - _ => return err_missing_configuration("gcp", &cloud_location.scheme), - } - #[cfg(not(feature = "gcp"))] - return err_missing_feature("gcp", &cloud_location.scheme); - }, - CloudType::Azure => { - { - #[cfg(feature = "azure")] - match _options { - Some(options) => { - let store = options.build_azure(&cloud_location.bucket)?; - Ok::<_, PolarsError>(Box::new(store) as Box) - }, - _ => return err_missing_configuration("azure", &cloud_location.scheme), - } - } - #[cfg(not(feature = "azure"))] - return err_missing_feature("azure", &cloud_location.scheme); - }, - }?; - Ok((cloud_location, store)) -} +#[cfg(feature = "cloud")] +pub use object_store_setup::*; +pub use options::*; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs new file mode 100644 index 000000000000..91119dbbf248 --- /dev/null +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -0,0 +1,99 @@ +use once_cell::sync::Lazy; +pub use options::*; +use tokio::sync::RwLock; + +use super::*; + +type CacheKey = (CloudType, Option); + +/// A very simple cache that only stores a single object-store. +/// This greatly reduces the query times as multiple object stores (when reading many small files) +/// get rate limited when querying the DNS (can take up to 5s). +#[allow(clippy::type_complexity)] +static OBJECT_STORE_CACHE: Lazy)>>> = + Lazy::new(Default::default); + +type BuildResult = PolarsResult<(CloudLocation, Arc)>; + +#[allow(dead_code)] +fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { + polars_bail!( + ComputeError: + "feature '{}' must be enabled in order to use '{}' cloud urls", feature, scheme, + ); +} +#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] +fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { + polars_bail!( + ComputeError: + "configuration '{}' must be provided in order to use '{}' cloud urls", feature, scheme, + ); +} + +/// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. 
+pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { + let cloud_location = CloudLocation::new(url)?; + + let cloud_type = CloudType::from_str(url)?; + let options = options.cloned(); + let key = (cloud_type, options); + + { + let cache = OBJECT_STORE_CACHE.read().await; + if let Some((stored_key, store)) = cache.as_ref() { + if stored_key == &key { + return Ok((cloud_location, store.clone())); + } + } + } + + let store = match key.0 { + CloudType::File => { + let local = LocalFileSystem::new(); + Ok::<_, PolarsError>(Arc::new(local) as Arc) + }, + CloudType::Aws => { + #[cfg(feature = "aws")] + { + let options = key + .1 + .as_ref() + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(Default::default())); + let store = options.build_aws(url).await?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + } + #[cfg(not(feature = "aws"))] + return err_missing_feature("aws", &cloud_location.scheme); + }, + CloudType::Gcp => { + #[cfg(feature = "gcp")] + match key.1.as_ref() { + Some(options) => { + let store = options.build_gcp(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + }, + _ => return err_missing_configuration("gcp", &cloud_location.scheme), + } + #[cfg(not(feature = "gcp"))] + return err_missing_feature("gcp", &cloud_location.scheme); + }, + CloudType::Azure => { + { + #[cfg(feature = "azure")] + match key.1.as_ref() { + Some(options) => { + let store = options.build_azure(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + }, + _ => return err_missing_configuration("azure", &cloud_location.scheme), + } + } + #[cfg(not(feature = "azure"))] + return err_missing_feature("azure", &cloud_location.scheme); + }, + }?; + let mut cache = OBJECT_STORE_CACHE.write().await; + *cache = Some((key, store.clone())); + Ok((cloud_location, store)) +} diff --git a/crates/polars-core/src/cloud.rs b/crates/polars-io/src/cloud/options.rs similarity index 56% rename from crates/polars-core/src/cloud.rs rename to crates/polars-io/src/cloud/options.rs index 7ec42c083bc9..78aedaf8e90f 100644 --- a/crates/polars-core/src/cloud.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -12,15 +12,26 @@ use object_store::azure::MicrosoftAzureBuilder; use object_store::gcp::GoogleCloudStorageBuilder; #[cfg(feature = "gcp")] pub use object_store::gcp::GoogleConfigKey; -#[cfg(feature = "async")] +#[cfg(feature = "cloud")] use object_store::ObjectStore; -use polars_error::{polars_bail, polars_err}; -#[cfg(feature = "serde-lazy")] +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +use object_store::{BackoffConfig, RetryConfig}; +#[cfg(feature = "aws")] +use once_cell::sync::Lazy; +use polars_core::error::{PolarsError, PolarsResult}; +use polars_error::*; +#[cfg(feature = "aws")] +use polars_utils::cache::FastFixedCache; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "async")] +#[cfg(feature = "aws")] +use smartstring::alias::String as SmartString; +#[cfg(feature = "cloud")] use url::Url; -use crate::error::{PolarsError, PolarsResult}; +#[cfg(feature = "aws")] +static BUCKET_REGION: Lazy>> = + Lazy::new(|| std::sync::Mutex::new(FastFixedCache::new(32))); /// The type of the config keys must satisfy the following requirements: /// 1. must be easily collected into a HashMap, the type required by the object_crate API. 
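Putting the new `build_object_store` together: it resolves the URL to a `CloudLocation`, reuses the single-entry cache when scheme and options match, and hands back a shared `ObjectStore`. A hedged sketch, assuming the `cloud` feature (and that the backend feature for the URL's scheme is compiled in):

```rust
use polars_error::PolarsResult;
use polars_io::cloud::build_object_store;

async fn peek(url: &str) -> PolarsResult<()> {
    // Repeated calls with the same scheme + options hit the cache instead of
    // re-resolving DNS / rebuilding the store.
    let (location, store) = build_object_store(url, None).await?;
    let _meta = store.head(&location.prefix.into()).await?;
    Ok(())
}
```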
@@ -32,7 +43,7 @@ use crate::error::{PolarsError, PolarsResult}; type Configs = Vec<(T, String)>; #[derive(Clone, Debug, Default, PartialEq)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Options to connect to various cloud providers. pub struct CloudOptions { #[cfg(feature = "aws")] @@ -41,6 +52,7 @@ pub struct CloudOptions { azure: Option>, #[cfg(feature = "gcp")] gcp: Option>, + pub max_retries: usize, } #[allow(dead_code)] @@ -63,6 +75,7 @@ where .collect::>>() } +#[derive(PartialEq)] pub enum CloudType { Aws, Azure, @@ -73,23 +86,31 @@ pub enum CloudType { impl FromStr for CloudType { type Err = PolarsError; - #[cfg(feature = "async")] + #[cfg(feature = "cloud")] fn from_str(url: &str) -> Result { - let parsed = Url::parse(url).map_err(polars_error::to_compute_err)?; + let parsed = Url::parse(url).map_err(to_compute_err)?; Ok(match parsed.scheme() { - "s3" => Self::Aws, - "az" | "adl" | "abfs" => Self::Azure, + "s3" | "s3a" => Self::Aws, + "az" | "azure" | "adl" | "abfs" | "abfss" => Self::Azure, "gs" | "gcp" | "gcs" => Self::Gcp, "file" => Self::File, _ => polars_bail!(ComputeError: "unknown url scheme"), }) } - #[cfg(not(feature = "async"))] + #[cfg(not(feature = "cloud"))] fn from_str(_s: &str) -> Result { polars_bail!(ComputeError: "at least one of the cloud features must be enabled"); } } +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +fn get_retry_config(max_retries: usize) -> RetryConfig { + RetryConfig { + backoff: BackoffConfig::default(), + max_retries, + retry_timeout: std::time::Duration::from_secs(10), + } +} impl CloudOptions { /// Set the configuration for AWS connections. This is the preferred API from rust. @@ -107,22 +128,58 @@ impl CloudOptions { self } - /// Build the ObjectStore implementation for AWS. + /// Build the [`ObjectStore`] implementation for AWS. 
#[cfg(feature = "aws")] - pub fn build_aws(&self, bucket_name: &str) -> PolarsResult { - let options = self - .aws - .as_ref() - .ok_or_else(|| polars_err!(ComputeError: "`aws` configuration missing"))?; - - let mut builder = AmazonS3Builder::new(); - for (key, value) in options.iter() { - builder = builder.with_config(*key, value); + pub async fn build_aws(&self, url: &str) -> PolarsResult { + let options = self.aws.as_ref(); + let mut builder = AmazonS3Builder::from_env().with_url(url); + if let Some(options) = options { + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } } + + if builder + .get_config_value(&AmazonS3ConfigKey::DefaultRegion) + .is_none() + && builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + { + let bucket = crate::cloud::CloudLocation::new(url)?.bucket; + let region = { + let bucket_region = BUCKET_REGION.lock().unwrap(); + bucket_region.get(bucket.as_str()).cloned() + }; + + match region { + Some(region) => { + builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str()) + }, + None => { + polars_warn!("'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning."); + let result = reqwest::Client::builder() + .build() + .unwrap() + .head(format!("https://{bucket}.s3.amazonaws.com")) + .send() + .await + .map_err(to_compute_err)?; + if let Some(region) = result.headers().get("x-amz-bucket-region") { + let region = + std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?; + let mut bucket_region = BUCKET_REGION.lock().unwrap(); + bucket_region.insert(bucket.into(), region.into()); + builder = builder.with_config(AmazonS3ConfigKey::Region, region) + } + }, + }; + }; + builder - .with_bucket_name(bucket_name) + .with_retry(get_retry_config(self.max_retries)) .build() - .map_err(polars_error::to_compute_err) + .map_err(to_compute_err) } /// Set the configuration for Azure connections. This is the preferred API from rust. @@ -140,22 +197,22 @@ impl CloudOptions { self } - /// Build the ObjectStore implementation for Azure. + /// Build the [`ObjectStore`] implementation for Azure. #[cfg(feature = "azure")] - pub fn build_azure(&self, container_name: &str) -> PolarsResult { - let options = self - .azure - .as_ref() - .ok_or_else(|| polars_err!(ComputeError: "`azure` configuration missing"))?; - - let mut builder = MicrosoftAzureBuilder::new(); - for (key, value) in options.iter() { - builder = builder.with_config(*key, value); + pub fn build_azure(&self, url: &str) -> PolarsResult { + let options = self.azure.as_ref(); + let mut builder = MicrosoftAzureBuilder::from_env(); + if let Some(options) = options { + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } } + builder - .with_container_name(container_name) + .with_url(url) + .with_retry(get_retry_config(self.max_retries)) .build() - .map_err(polars_error::to_compute_err) + .map_err(to_compute_err) } /// Set the configuration for GCP connections. This is the preferred API from rust. @@ -173,22 +230,22 @@ impl CloudOptions { self } - /// Build the ObjectStore implementation for GCP. + /// Build the [`ObjectStore`] implementation for GCP. 
#[cfg(feature = "gcp")] - pub fn build_gcp(&self, bucket_name: &str) -> PolarsResult { - let options = self - .gcp - .as_ref() - .ok_or_else(|| polars_err!(ComputeError: "`gcp` configuration missing"))?; - - let mut builder = GoogleCloudStorageBuilder::new(); - for (key, value) in options.iter() { - builder = builder.with_config(*key, value); + pub fn build_gcp(&self, url: &str) -> PolarsResult { + let options = self.gcp.as_ref(); + let mut builder = GoogleCloudStorageBuilder::from_env(); + if let Some(options) = options { + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } } + builder - .with_bucket_name(bucket_name) + .with_url(url) + .with_retry(get_retry_config(self.max_retries)) .build() - .map_err(polars_error::to_compute_err) + .map_err(to_compute_err) } /// Parse a configuration from a Hashmap. This is the interface from Python. diff --git a/crates/polars-io/src/csv/mod.rs b/crates/polars-io/src/csv/mod.rs index 483012a42299..33da2fc51ca5 100644 --- a/crates/polars-io/src/csv/mod.rs +++ b/crates/polars-io/src/csv/mod.rs @@ -18,7 +18,7 @@ //! //! CsvWriter::new(&mut file) //! .has_header(true) -//! .with_delimiter(b',') +//! .with_separator(b',') //! .finish(df) //! } //! ``` @@ -66,8 +66,7 @@ pub use write::{BatchedWriter, CsvWriter, QuoteStyle}; pub use write_impl::SerializeOptions; use crate::csv::read_impl::CoreReader; -use crate::csv::utils::get_reader_bytes; use crate::mmap::MmapBytesReader; use crate::predicates::PhysicalIoExpr; -use crate::utils::resolve_homedir; +use crate::utils::{get_reader_bytes, resolve_homedir}; use crate::{RowCount, SerReader, SerWriter}; diff --git a/crates/polars-io/src/csv/parser.rs b/crates/polars-io/src/csv/parser.rs index b89d5cbcb297..1b7880f1352e 100644 --- a/crates/polars-io/src/csv/parser.rs +++ b/crates/polars-io/src/csv/parser.rs @@ -30,20 +30,20 @@ pub(crate) fn next_line_position_naive(input: &[u8], eol_char: u8) -> Option, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Option { fn accept_line( line: &[u8], expected_fields: usize, - delimiter: u8, + separator: u8, eol_char: u8, quote_char: Option, ) -> bool { let mut count = 0usize; - for (field, _) in SplitFields::new(line, delimiter, quote_char, eol_char) { - if memchr2_iter(delimiter, eol_char, field).count() >= expected_fields { + for (field, _) in SplitFields::new(line, separator, quote_char, eol_char) { + if memchr2_iter(separator, eol_char, field).count() >= expected_fields { return false; } count += 1; @@ -95,10 +95,10 @@ pub(crate) fn next_line_position( match (line, expected_fields) { // count the fields, and determine if they are equal to what we expect from the schema (Some(line), Some(expected_fields)) => { - if accept_line(line, expected_fields, delimiter, eol_char, quote_char) { + if accept_line(line, expected_fields, separator, eol_char, quote_char) { let mut valid = true; for line in lines.take(2) { - if !accept_line(line, expected_fields, delimiter, eol_char, quote_char) { + if !accept_line(line, expected_fields, separator, eol_char, quote_char) { valid = false; break; } @@ -160,13 +160,13 @@ pub(crate) fn skip_whitespace(input: &[u8]) -> &[u8] { } #[inline] -/// Can be used to skip whitespace, but exclude the delimiter +/// Can be used to skip whitespace, but exclude the separator pub(crate) fn skip_whitespace_exclude(input: &[u8], exclude: u8) -> &[u8] { skip_condition(input, |b| b != exclude && (is_whitespace(b))) } #[inline] -/// Can be used to skip whitespace, but exclude the delimiter +/// Can be used 
to skip whitespace, but exclude the separator pub(crate) fn skip_whitespace_line_ending_exclude( input: &[u8], exclude: u8, @@ -188,7 +188,7 @@ pub(crate) fn get_line_stats( n_lines: usize, eol_char: u8, expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, ) -> Option<(f32, f32)> { let mut lengths = Vec::with_capacity(n_lines); @@ -204,7 +204,7 @@ pub(crate) fn get_line_stats( let pos = next_line_position( bytes_trunc, Some(expected_fields), - delimiter, + separator, quote_char, eol_char, )?; @@ -350,7 +350,7 @@ fn skip_this_line(bytes: &[u8], quote: Option, eol_char: u8) -> &[u8] { pub(super) fn parse_lines<'a>( mut bytes: &'a [u8], offset: usize, - delimiter: u8, + separator: u8, comment_char: Option, quote_char: Option, eol_char: u8, @@ -391,9 +391,9 @@ pub(super) fn parse_lines<'a>( // only when we have one column \n should not be skipped // other widths should have commas. bytes = if schema_len > 1 { - skip_whitespace_line_ending_exclude(bytes, delimiter, eol_char) + skip_whitespace_line_ending_exclude(bytes, separator, eol_char) } else { - skip_whitespace_exclude(bytes, delimiter) + skip_whitespace_exclude(bytes, separator) }; if bytes.is_empty() { return Ok(original_bytes_len); @@ -416,7 +416,7 @@ pub(super) fn parse_lines<'a>( let mut next_projected = unsafe { projection_iter.next().unwrap_unchecked() }; let mut processed_fields = 0; - let mut iter = SplitFields::new(bytes, delimiter, quote_char, eol_char); + let mut iter = SplitFields::new(bytes, separator, quote_char, eol_char); let mut idx = 0u32; let mut read_sol = 0; loop { diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 5f0b3a228596..4d8527b70b80 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -109,7 +109,7 @@ where projection: Option>, /// Optional column names to project/ select. columns: Option>, - delimiter: Option, + separator: Option, pub(crate) schema: Option, encoding: CsvEncoding, n_threads: Option, @@ -204,9 +204,9 @@ where self } - /// Set the CSV file's column delimiter as a byte character - pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.delimiter = Some(delimiter); + /// Set the CSV file's column separator as a byte character + pub fn with_separator(mut self, separator: u8) -> Self { + self.separator = Some(separator); self } @@ -310,8 +310,8 @@ where } /// Set the `char` used as quote char. The default is `b'"'`. If set to `[None]` quoting is disabled. 
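A short sketch of the renamed reader options (`with_delimiter` becoming `with_separator`), assuming the `csv` feature; the file path is hypothetical, and the builder calls are limited to those shown in this diff plus the usual `from_path`/`finish` entry points:

use polars_core::prelude::*;
use polars_io::prelude::*;

fn read_semicolon_csv() -> PolarsResult<DataFrame> {
    // Separator and quote char are plain bytes / Option<u8>, as above.
    CsvReader::from_path("data.csv")?
        .has_header(true)
        .with_separator(b';')
        .with_quote_char(Some(b'"'))
        .finish()
}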
- pub fn with_quote_char(mut self, quote: Option) -> Self { - self.quote_char = quote; + pub fn with_quote_char(mut self, quote_char: Option) -> Self { + self.quote_char = quote_char; self } @@ -358,7 +358,7 @@ impl<'a, R: MmapBytesReader + 'a> CsvReader<'a, R> { self.skip_rows_before_header, std::mem::take(&mut self.projection), self.max_records, - self.delimiter, + self.separator, self.has_header, self.ignore_errors, self.schema.clone(), @@ -481,7 +481,7 @@ impl<'a> CsvReader<'a, Box> { let (inferred_schema, _, _) = infer_file_schema( &reader_bytes, - self.delimiter.unwrap_or(b','), + self.separator.unwrap_or(b','), self.max_records, self.has_header, None, @@ -510,7 +510,7 @@ impl<'a> CsvReader<'a, Box> { let (inferred_schema, _, _) = infer_file_schema( &reader_bytes, - self.delimiter.unwrap_or(b','), + self.separator.unwrap_or(b','), self.max_records, self.has_header, None, @@ -543,7 +543,7 @@ where max_records: Some(128), skip_rows_before_header: 0, projection: None, - delimiter: None, + separator: None, has_header: true, ignore_errors: false, schema: None, @@ -584,7 +584,7 @@ where #[cfg(feature = "dtype-categorical")] if _has_cat { - _cat_lock = Some(polars_core::IUseStringCache::hold()) + _cat_lock = Some(polars_core::StringCacheHolder::hold()) } let mut csv_reader = self.core_reader(Some(Arc::new(schema)), to_cast)?; @@ -602,7 +602,7 @@ where }) .unwrap_or(false); if has_cat { - _cat_lock = Some(polars_core::IUseStringCache::hold()) + _cat_lock = Some(polars_core::StringCacheHolder::hold()) } } let mut csv_reader = self.core_reader(self.schema.clone(), vec![])?; diff --git a/crates/polars-io/src/csv/read_impl/batched_mmap.rs b/crates/polars-io/src/csv/read_impl/batched_mmap.rs index 18824d5e08f1..f0299ca40fe9 100644 --- a/crates/polars-io/src/csv/read_impl/batched_mmap.rs +++ b/crates/polars-io/src/csv/read_impl/batched_mmap.rs @@ -13,7 +13,7 @@ pub(crate) fn get_file_chunks_iterator( chunk_size: usize, bytes: &[u8], expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) { @@ -27,7 +27,7 @@ pub(crate) fn get_file_chunks_iterator( let end_pos = match next_line_position( &bytes[search_pos..], Some(expected_fields), - delimiter, + separator, quote_char, eol_char, ) { @@ -49,7 +49,7 @@ struct ChunkOffsetIter<'a> { // not a promise, but something we want rows_per_batch: usize, expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, } @@ -68,7 +68,7 @@ impl<'a> Iterator for ChunkOffsetIter<'a> { let bytes_first_row = next_line_position( &self.bytes[self.last_offset + 2..], Some(self.expected_fields), - self.delimiter, + self.separator, self.quote_char, self.eol_char, ) @@ -84,7 +84,7 @@ impl<'a> Iterator for ChunkOffsetIter<'a> { self.rows_per_batch * bytes_first_row, self.bytes, self.expected_fields, - self.delimiter, + self.separator, self.quote_char, self.eol_char, ); @@ -124,7 +124,7 @@ impl<'a> CoreReader<'a> { n_chunks: offset_batch_size, rows_per_batch: self.chunk_size, expected_fields: self.schema.len(), - delimiter: self.delimiter, + separator: self.separator, quote_char: self.quote_char, eol_char: self.eol_char, }; @@ -136,7 +136,7 @@ impl<'a> CoreReader<'a> { // RAII structure that will ensure we maintain a global stringcache #[cfg(feature = "dtype-categorical")] let _cat_lock = if _has_cat { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } else { None }; @@ -164,7 +164,7 @@ impl<'a> CoreReader<'a> { truncate_ragged_lines: self.truncate_ragged_lines, n_rows: 
self.n_rows, encoding: self.encoding, - delimiter: self.delimiter, + separator: self.separator, schema: self.schema, rows_read: 0, _cat_lock, @@ -192,11 +192,11 @@ pub struct BatchedCsvReaderMmap<'a> { ignore_errors: bool, n_rows: Option, encoding: CsvEncoding, - delimiter: u8, + separator: u8, schema: SchemaRef, rows_read: IdxSize, #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, + _cat_lock: Option, #[cfg(not(feature = "dtype-categorical"))] _cat_lock: Option, } @@ -233,7 +233,7 @@ impl<'a> BatchedCsvReaderMmap<'a> { .map(|(bytes_offset_thread, stop_at_nbytes)| { let mut df = read_chunk( bytes, - self.delimiter, + self.separator, self.schema.as_ref(), self.ignore_errors, &self.projection, diff --git a/crates/polars-io/src/csv/read_impl/batched_read.rs b/crates/polars-io/src/csv/read_impl/batched_read.rs index af3831f00b70..9e8e6b6e6836 100644 --- a/crates/polars-io/src/csv/read_impl/batched_read.rs +++ b/crates/polars-io/src/csv/read_impl/batched_read.rs @@ -14,7 +14,7 @@ pub(crate) fn get_offsets( chunk_size: usize, bytes: &[u8], expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) { @@ -29,7 +29,7 @@ pub(crate) fn get_offsets( let end_pos = match next_line_position( &bytes[search_pos..], Some(expected_fields), - delimiter, + separator, quote_char, eol_char, ) { @@ -57,7 +57,7 @@ struct ChunkReader<'a> { // not a promise, but something we want rows_per_batch: usize, expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, } @@ -67,7 +67,7 @@ impl<'a> ChunkReader<'a> { file: &'a File, rows_per_batch: usize, expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, page_size: u64, @@ -85,7 +85,7 @@ impl<'a> ChunkReader<'a> { n_chunks: 16, rows_per_batch, expected_fields, - delimiter, + separator, quote_char, eol_char, } @@ -132,7 +132,7 @@ impl<'a> ChunkReader<'a> { bytes_first_row = next_line_position( &self.buf[2..], Some(self.expected_fields), - self.delimiter, + self.separator, self.quote_char, self.eol_char, ); @@ -179,7 +179,7 @@ impl<'a> ChunkReader<'a> { self.rows_per_batch * bytes_first_row, &self.buf, self.expected_fields, - self.delimiter, + self.separator, self.quote_char, self.eol_char, ); @@ -206,7 +206,7 @@ impl<'a> CoreReader<'a> { file, self.chunk_size, self.schema.len(), - self.delimiter, + self.separator, self.quote_char, self.eol_char, 4096, @@ -219,7 +219,7 @@ impl<'a> CoreReader<'a> { // RAII structure that will ensure we maintain a global stringcache #[cfg(feature = "dtype-categorical")] let _cat_lock = if _has_cat { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } else { None }; @@ -247,7 +247,7 @@ impl<'a> CoreReader<'a> { truncate_ragged_lines: self.truncate_ragged_lines, n_rows: self.n_rows, encoding: self.encoding, - delimiter: self.delimiter, + separator: self.separator, schema: self.schema, rows_read: 0, _cat_lock, @@ -275,11 +275,11 @@ pub struct BatchedCsvReaderRead<'a> { truncate_ragged_lines: bool, n_rows: Option, encoding: CsvEncoding, - delimiter: u8, + separator: u8, schema: SchemaRef, rows_read: IdxSize, #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, + _cat_lock: Option, #[cfg(not(feature = "dtype-categorical"))] _cat_lock: Option, } @@ -330,7 +330,7 @@ impl<'a> BatchedCsvReaderRead<'a> { let stop_at_n_bytes = chunk.len(); let mut df = read_chunk( chunk, - self.delimiter, + self.separator, self.schema.as_ref(), self.ignore_errors, &self.projection, diff --git 
a/crates/polars-io/src/csv/read_impl/mod.rs b/crates/polars-io/src/csv/read_impl/mod.rs index 2b13585f80bf..3d9b43adc15c 100644 --- a/crates/polars-io/src/csv/read_impl/mod.rs +++ b/crates/polars-io/src/csv/read_impl/mod.rs @@ -110,7 +110,7 @@ pub(crate) struct CoreReader<'a> { encoding: CsvEncoding, n_threads: Option, has_header: bool, - delimiter: u8, + separator: u8, sample_size: usize, chunk_size: usize, low_memory: bool, @@ -191,7 +191,7 @@ impl<'a> CoreReader<'a> { mut skip_rows: usize, mut projection: Option>, max_records: Option, - delimiter: Option, + separator: Option, has_header: bool, ignore_errors: bool, schema: Option, @@ -228,7 +228,7 @@ impl<'a> CoreReader<'a> { } // check if schema should be inferred - let delimiter = delimiter.unwrap_or(b','); + let separator = separator.unwrap_or(b','); let mut schema = match schema { Some(schema) => schema, @@ -239,14 +239,14 @@ impl<'a> CoreReader<'a> { // again after decompression. #[cfg(any(feature = "decompress", feature = "decompress-fast"))] if let Some(b) = - decompress(&reader_bytes, n_rows, delimiter, quote_char, eol_char) + decompress(&reader_bytes, n_rows, separator, quote_char, eol_char) { reader_bytes = ReaderBytes::Owned(b); } let (inferred_schema, _, _) = infer_file_schema( &reader_bytes, - delimiter, + separator, max_records, has_header, schema_overwrite.as_deref(), @@ -300,7 +300,7 @@ impl<'a> CoreReader<'a> { encoding, n_threads, has_header, - delimiter, + separator, sample_size, chunk_size, low_memory, @@ -325,7 +325,7 @@ impl<'a> CoreReader<'a> { let starting_point_offset = bytes.as_ptr() as usize; // Skip all leading white space and the occasional utf8-bom - bytes = skip_whitespace_exclude(skip_bom(bytes), self.delimiter); + bytes = skip_whitespace_exclude(skip_bom(bytes), self.separator); // \n\n can be a empty string row of a single column // in other cases we skip it. if self.schema.len() > 1 { @@ -354,7 +354,7 @@ impl<'a> CoreReader<'a> { // we don't pass expected fields // as we want to skip all rows // no matter the no. of fields - _ => next_line_position(bytes, None, self.delimiter, self.quote_char, eol_char), + _ => next_line_position(bytes, None, self.separator, self.quote_char, eol_char), } .ok_or_else(|| polars_err!(NoData: "not enough lines to skip"))?; @@ -391,7 +391,7 @@ impl<'a> CoreReader<'a> { self.sample_size, self.eol_char, self.schema.len(), - self.delimiter, + self.separator, self.quote_char, ) { if logging { @@ -415,7 +415,7 @@ impl<'a> CoreReader<'a> { if let Some(pos) = next_line_position( &bytes[n_bytes..], Some(self.schema.len()), - self.delimiter, + self.separator, self.quote_char, self.eol_char, ) { @@ -471,7 +471,7 @@ impl<'a> CoreReader<'a> { bytes, n_file_chunks, self.schema.len(), - self.delimiter, + self.separator, self.quote_char, self.eol_char, ); @@ -515,7 +515,7 @@ impl<'a> CoreReader<'a> { for i in projection { let (_, dtype) = self.schema.get_at_index(*i).ok_or_else(|| { polars_err!( - ComputeError: + OutOfBounds: "projection index {} is out of bounds for CSV schema with {} columns", i, self.schema.len(), ) @@ -557,23 +557,7 @@ impl<'a> CoreReader<'a> { // An empty file with a schema should return an empty DataFrame with that schema if bytes.is_empty() { - // TODO! 
add DataFrame::new_from_schema - let buffers = init_buffers( - &projection, - 0, - &self.schema, - &self.init_string_size_stats(&str_columns, 0), - self.quote_char, - self.encoding, - self.ignore_errors, - )?; - let df = DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - ); - return Ok(df); + return Ok(DataFrame::from(self.schema.as_ref())); } // all the buffers returned from the threads @@ -585,7 +569,6 @@ impl<'a> CoreReader<'a> { file_chunks .into_par_iter() .map(|(bytes_offset_thread, stop_at_nbytes)| { - let delimiter = self.delimiter; let schema = self.schema.as_ref(); let ignore_errors = self.ignore_errors; let projection = &projection; @@ -615,7 +598,7 @@ impl<'a> CoreReader<'a> { read += parse_lines( local_bytes, offset, - delimiter, + self.separator, self.comment_char, self.quote_char, self.eol_char, @@ -681,7 +664,7 @@ impl<'a> CoreReader<'a> { .map(|(bytes_offset_thread, stop_at_nbytes)| { let mut df = read_chunk( bytes, - self.delimiter, + self.separator, self.schema.as_ref(), self.ignore_errors, &projection, @@ -733,7 +716,7 @@ impl<'a> CoreReader<'a> { parse_lines( remaining_bytes, 0, - self.delimiter, + self.separator, self.comment_char, self.quote_char, self.eol_char, @@ -811,7 +794,7 @@ fn update_string_stats( #[allow(clippy::too_many_arguments)] fn read_chunk( bytes: &[u8], - delimiter: u8, + separator: u8, schema: &Schema, ignore_errors: bool, projection: &[usize], @@ -852,7 +835,7 @@ fn read_chunk( read += parse_lines( local_bytes, offset, - delimiter, + separator, comment_char, quote_char, eol_char, diff --git a/crates/polars-io/src/csv/splitfields.rs b/crates/polars-io/src/csv/splitfields.rs index 7e00aefc53dd..1804cea8559e 100644 --- a/crates/polars-io/src/csv/splitfields.rs +++ b/crates/polars-io/src/csv/splitfields.rs @@ -4,7 +4,7 @@ mod inner { /// This exists solely because we cannot split the lines naively as pub(crate) struct SplitFields<'a> { v: &'a [u8], - delimiter: u8, + separator: u8, finished: bool, quote_char: u8, quoting: bool, @@ -14,13 +14,13 @@ mod inner { impl<'a> SplitFields<'a> { pub(crate) fn new( slice: &'a [u8], - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Self { Self { v: slice, - delimiter, + separator, finished: false, quote_char: quote_char.unwrap_or(b'"'), quoting: quote_char.is_some(), @@ -44,7 +44,7 @@ mod inner { } fn eof_oel(&self, current_ch: u8) -> bool { - current_ch == self.delimiter || current_ch == self.eol_char + current_ch == self.separator || current_ch == self.eol_char } } @@ -59,7 +59,7 @@ mod inner { } let mut needs_escaping = false; - // There can be strings with delimiters: + // There can be strings with separators: // "Street, City", // Safety: @@ -157,33 +157,33 @@ mod inner { /// This exists solely because we cannot split the lines naively as pub(crate) struct SplitFields<'a> { pub v: &'a [u8], - delimiter: u8, + separator: u8, pub finished: bool, quote_char: u8, quoting: bool, eol_char: u8, - simd_delimiter: SimdVec, + simd_separator: SimdVec, simd_eol_char: SimdVec, } impl<'a> SplitFields<'a> { pub(crate) fn new( slice: &'a [u8], - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Self { - let simd_delimiter = SimdVec::splat(delimiter); + let simd_separator = SimdVec::splat(separator); let simd_eol_char = SimdVec::splat(eol_char); Self { v: slice, - delimiter, + separator, finished: false, quote_char: quote_char.unwrap_or(b'"'), quoting: quote_char.is_some(), eol_char, - simd_delimiter, + simd_separator, 
simd_eol_char, } } @@ -204,7 +204,7 @@ mod inner { } fn eof_oel(&self, current_ch: u8) -> bool { - current_ch == self.delimiter || current_ch == self.eol_char + current_ch == self.separator || current_ch == self.eol_char } } @@ -219,7 +219,7 @@ mod inner { } let mut needs_escaping = false; - // There can be strings with delimiters: + // There can be strings with separators: // "Street, City", // Safety: @@ -279,8 +279,8 @@ mod inner { .unwrap_unchecked_release(); let simd_bytes = SimdVec::from(lane); let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char); - let has_delimiter = simd_bytes.simd_eq(self.simd_delimiter); - let has_any = has_delimiter.bitor(has_eol_char); + let has_separator = simd_bytes.simd_eq(self.simd_separator); + let has_any = has_separator.bitor(has_eol_char); if has_any.any() { // soundness we can transmute because we have the same alignment let has_any = std::mem::transmute::< diff --git a/crates/polars-io/src/csv/utils.rs b/crates/polars-io/src/csv/utils.rs index e9c89d1a2ab7..ba8cc68ee63c 100644 --- a/crates/polars-io/src/csv/utils.rs +++ b/crates/polars-io/src/csv/utils.rs @@ -1,29 +1,29 @@ use std::borrow::Cow; +#[cfg(any(feature = "decompress", feature = "decompress-fast"))] use std::io::Read; use std::mem::MaybeUninit; -use once_cell::sync::Lazy; use polars_core::datatypes::PlHashSet; use polars_core::prelude::*; #[cfg(feature = "polars-time")] use polars_time::chunkedarray::utf8::infer as date_infer; #[cfg(feature = "polars-time")] use polars_time::prelude::utf8::Pattern; -use regex::{Regex, RegexBuilder}; #[cfg(any(feature = "decompress", feature = "decompress-fast"))] use crate::csv::parser::next_line_position_naive; use crate::csv::parser::{next_line_position, skip_bom, skip_line_ending, SplitLines}; use crate::csv::splitfields::SplitFields; use crate::csv::CsvEncoding; -use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::mmap::ReaderBytes; use crate::prelude::NullValues; +use crate::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE}; pub(crate) fn get_file_chunks( bytes: &[u8], n_chunks: usize, expected_fields: usize, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Vec<(usize, usize)> { @@ -41,7 +41,7 @@ pub(crate) fn get_file_chunks( let end_pos = match next_line_position( &bytes[search_pos..], Some(expected_fields), - delimiter, + separator, quote_char, eol_char, ) { @@ -57,50 +57,6 @@ pub(crate) fn get_file_chunks( offsets } -pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( - reader: &'a mut R, -) -> PolarsResult> { - // we have a file so we can mmap - if let Some(file) = reader.to_file() { - let mmap = unsafe { memmap::Mmap::map(file)? }; - - // somehow bck thinks borrows alias - // this is sound as file was already bound to 'a - use std::fs::File; - let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; - Ok(ReaderBytes::Mapped(mmap, file)) - } else { - // we can get the bytes for free - if reader.to_bytes().is_some() { - // duplicate .to_bytes() is necessary to satisfy the borrow checker - Ok(ReaderBytes::Borrowed((*reader).to_bytes().unwrap())) - } else { - // we have to read to an owned buffer to get the bytes. 
- let mut bytes = Vec::with_capacity(1024 * 128); - reader.read_to_end(&mut bytes)?; - if !bytes.is_empty() - && (bytes[bytes.len() - 1] != b'\n' || bytes[bytes.len() - 1] != b'\r') - { - bytes.push(b'\n') - } - Ok(ReaderBytes::Owned(bytes)) - } - } -} - -static FLOAT_RE: Lazy = Lazy::new(|| { - Regex::new(r"^\s*[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap() -}); - -static INTEGER_RE: Lazy = Lazy::new(|| Regex::new(r"^\s*-?(\d+)$").unwrap()); - -static BOOLEAN_RE: Lazy = Lazy::new(|| { - RegexBuilder::new(r"^\s*(true)$|^(false)$") - .case_insensitive(true) - .build() - .unwrap() -}); - /// Infer the data type of a record fn infer_field_schema(string: &str, try_parse_dates: bool) -> DataType { // when quoting is enabled in the reader, these quotes aren't escaped, we default to @@ -178,7 +134,7 @@ pub(crate) fn parse_bytes_with_encoding( #[allow(clippy::too_many_arguments)] pub fn infer_file_schema_inner( reader_bytes: &ReaderBytes, - delimiter: u8, + separator: u8, max_read_rows: Option, has_header: bool, schema_overwrite: Option<&Schema>, @@ -243,7 +199,7 @@ pub fn infer_file_schema_inner( } } - let byterecord = SplitFields::new(header_line, delimiter, quote_char, eol_char); + let byterecord = SplitFields::new(header_line, separator, quote_char, eol_char); if has_header { let headers = byterecord .map(|(slice, needs_escaping)| { @@ -277,8 +233,8 @@ pub fn infer_file_schema_inner( .map(|(i, _s)| format!("column_{}", i + 1)) .collect(); // needed because SplitLines does not return the \n char, so SplitFields does not catch - // the latest value if ending with a delimiter. - if header_line.ends_with(&[delimiter]) { + // the latest value if ending with a separator. + if header_line.ends_with(&[separator]) { column_names.push(format!("column_{}", column_names.len() + 1)) } column_names @@ -292,7 +248,7 @@ pub fn infer_file_schema_inner( return infer_file_schema_inner( &ReaderBytes::Owned(buf), - delimiter, + separator, max_read_rows, has_header, schema_overwrite, @@ -366,7 +322,7 @@ pub fn infer_file_schema_inner( } } - let mut record = SplitFields::new(line, delimiter, quote_char, eol_char); + let mut record = SplitFields::new(line, separator, quote_char, eol_char); for i in 0..header_length { if let Some((slice, needs_escaping)) = record.next() { @@ -478,7 +434,7 @@ pub fn infer_file_schema_inner( rb.push(eol_char); return infer_file_schema_inner( &ReaderBytes::Owned(rb), - delimiter, + separator, max_read_rows, has_header, schema_overwrite, @@ -509,7 +465,7 @@ pub fn infer_file_schema_inner( #[allow(clippy::too_many_arguments)] pub fn infer_file_schema( reader_bytes: &ReaderBytes, - delimiter: u8, + separator: u8, max_read_rows: Option, has_header: bool, schema_overwrite: Option<&Schema>, @@ -526,7 +482,7 @@ pub fn infer_file_schema( ) -> PolarsResult<(Schema, usize, usize)> { infer_file_schema_inner( reader_bytes, - delimiter, + separator, max_read_rows, has_header, schema_overwrite, @@ -560,7 +516,7 @@ pub fn is_compressed(bytes: &[u8]) -> bool { fn decompress_impl( decoder: &mut R, n_rows: Option, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Option> { @@ -592,7 +548,7 @@ fn decompress_impl( } // now that we have enough, we compute the number of fields (also takes embedding into account) expected_fields = - SplitFields::new(&out, delimiter, quote_char, eol_char).count(); + SplitFields::new(&out, separator, quote_char, eol_char).count(); break; } } @@ -605,7 +561,7 @@ fn decompress_impl( match next_line_position( &out[buf_pos + 1..], 
Some(expected_fields), - delimiter, + separator, quote_char, eol_char, ) { @@ -633,16 +589,16 @@ fn decompress_impl( pub(crate) fn decompress( bytes: &[u8], n_rows: Option, - delimiter: u8, + separator: u8, quote_char: Option, eol_char: u8, ) -> Option> { if bytes.starts_with(&GZIP) { let mut decoder = flate2::read::MultiGzDecoder::new(bytes); - decompress_impl(&mut decoder, n_rows, delimiter, quote_char, eol_char) + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) } else if bytes.starts_with(&ZLIB0) || bytes.starts_with(&ZLIB1) || bytes.starts_with(&ZLIB2) { let mut decoder = flate2::read::ZlibDecoder::new(bytes); - decompress_impl(&mut decoder, n_rows, delimiter, quote_char, eol_char) + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) } else { None } diff --git a/crates/polars-io/src/csv/write.rs b/crates/polars-io/src/csv/write.rs index f9c31a827805..0752cfac872b 100644 --- a/crates/polars-io/src/csv/write.rs +++ b/crates/polars-io/src/csv/write.rs @@ -9,12 +9,14 @@ pub enum QuoteStyle { /// This puts quotes around every field. Always. Always, /// This puts quotes around fields only when necessary. - // They are necessary when fields contain a quote, delimiter or record terminator. Quotes are also necessary when writing an empty record (which is indistinguishable from a record with one empty field). + // They are necessary when fields contain a quote, separator or record terminator. Quotes are also necessary when writing an empty record (which is indistinguishable from a record with one empty field). // This is the default. #[default] Necessary, /// This puts quotes around all fields that are non-numeric. Namely, when writing a field that does not parse as a valid float or integer, then quotes will be used even if they aren’t strictly necessary. NonNumeric, + /// Never quote any fields, even if it would produce invalid CSV data. + Never, } /// Write a DataFrame to csv. @@ -67,9 +69,9 @@ where self } - /// Set the CSV file's column delimiter as a byte character. - pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.options.delimiter = delimiter; + /// Set the CSV file's column separator as a byte character. + pub fn with_separator(mut self, separator: u8) -> Self { + self.options.separator = separator; self } @@ -112,8 +114,8 @@ where } /// Set the single byte character used for quoting. 
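On the write side, a comparable sketch with the renamed setters (`with_delimiter` to `with_separator`, `with_quoting_char` to `with_quote_char`); `df` and the output path are hypothetical, and the new `QuoteStyle::Never` variant introduced above is the option to reach for when quoting must be suppressed entirely:

use std::fs::File;
use polars_core::prelude::*;
use polars_io::prelude::*;

fn write_pipe_separated(df: &mut DataFrame) -> PolarsResult<()> {
    let mut file = File::create("out.csv")?;
    CsvWriter::new(&mut file)
        .has_header(true)
        .with_separator(b'|')
        .with_quote_char(b'"')
        .finish(df)
}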
- pub fn with_quoting_char(mut self, char: u8) -> Self { - self.options.quote = char; + pub fn with_quote_char(mut self, char: u8) -> Self { + self.options.quote_char = char; self } diff --git a/crates/polars-io/src/csv/write_impl.rs b/crates/polars-io/src/csv/write_impl.rs index 88edf18f2900..edf15bc9c730 100644 --- a/crates/polars-io/src/csv/write_impl.rs +++ b/crates/polars-io/src/csv/write_impl.rs @@ -8,7 +8,6 @@ use std::io::Write; use arrow::temporal_conversions; #[cfg(feature = "timezones")] use chrono::TimeZone; -use lexical_core::{FormattedSize, ToLexical}; use memchr::{memchr, memchr2}; use polars_arrow::time_zone::Tz; use polars_core::prelude::*; @@ -22,44 +21,47 @@ use serde::{Deserialize, Serialize}; use super::write::QuoteStyle; fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> std::io::Result<()> { + if options.quote_style == QuoteStyle::Never { + return write!(f, "{v}"); + } + let quote = options.quote_char as char; if v.is_empty() { - write!(f, "\"\"") - } else { - let needs_escaping = memchr(options.quote, v.as_bytes()).is_some(); - - if needs_escaping { - let replaced = unsafe { - // Replace from single quote " to double quote "". - v.replace( - std::str::from_utf8_unchecked(&[options.quote]), - std::str::from_utf8_unchecked(&[options.quote, options.quote]), - ) - }; - return write!(f, "\"{replaced}\""); - } - let surround_with_quotes = match options.quote_style { - QuoteStyle::Always | QuoteStyle::NonNumeric => true, - QuoteStyle::Necessary => memchr2(options.delimiter, b'\n', v.as_bytes()).is_some(), + return write!(f, "{quote}{quote}"); + } + let needs_escaping = memchr(options.quote_char, v.as_bytes()).is_some(); + if needs_escaping { + let replaced = unsafe { + // Replace from single quote " to double quote "". 
+ v.replace( + std::str::from_utf8_unchecked(&[options.quote_char]), + std::str::from_utf8_unchecked(&[options.quote_char, options.quote_char]), + ) }; - - let quote = options.quote as char; - if surround_with_quotes { - write!(f, "{quote}{v}{quote}") - } else { - write!(f, "{v}") - } + return write!(f, "{quote}{replaced}{quote}"); + } + let surround_with_quotes = match options.quote_style { + QuoteStyle::Always | QuoteStyle::NonNumeric => true, + QuoteStyle::Necessary => memchr2(options.separator, b'\n', v.as_bytes()).is_some(), + QuoteStyle::Never => false, + }; + + if surround_with_quotes { + write!(f, "{quote}{v}{quote}") + } else { + write!(f, "{v}") } } -fn fast_float_write(f: &mut Vec, n: N, write_size: usize) -> std::io::Result<()> { - let len = f.len(); - f.reserve(write_size); - unsafe { - let buffer = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size); - let written_n = n.to_lexical(buffer).len(); - f.set_len(len + written_n); - } - Ok(()) +fn fast_float_write(f: &mut Vec, val: I) { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) +} + +fn write_integer(f: &mut Vec, val: I) { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) } unsafe fn write_anyvalue( @@ -84,7 +86,7 @@ unsafe fn write_anyvalue( }, _ => { // Then we deal with the numeric types - let quote = options.quote as char; + let quote = options.quote_char as char; let mut end_with_quote = matches!(options.quote_style, QuoteStyle::Always); if end_with_quote { @@ -94,20 +96,50 @@ unsafe fn write_anyvalue( match value { AnyValue::Null => write!(f, "{}", &options.null), - AnyValue::Int8(v) => write!(f, "{v}"), - AnyValue::Int16(v) => write!(f, "{v}"), - AnyValue::Int32(v) => write!(f, "{v}"), - AnyValue::Int64(v) => write!(f, "{v}"), - AnyValue::UInt8(v) => write!(f, "{v}"), - AnyValue::UInt16(v) => write!(f, "{v}"), - AnyValue::UInt32(v) => write!(f, "{v}"), - AnyValue::UInt64(v) => write!(f, "{v}"), + AnyValue::Int8(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int16(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int32(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int64(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt8(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt16(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt32(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt64(v) => { + write_integer(f, v); + Ok(()) + }, AnyValue::Float32(v) => match &options.float_precision { - None => fast_float_write(f, v, f32::FORMATTED_SIZE_DECIMAL), + None => { + fast_float_write(f, v); + Ok(()) + }, Some(precision) => write!(f, "{v:.precision$}"), }, AnyValue::Float64(v) => match &options.float_precision { - None => fast_float_write(f, v, f64::FORMATTED_SIZE_DECIMAL), + None => { + fast_float_write(f, v); + Ok(()) + }, Some(precision) => write!(f, "{v:.precision$}"), }, _ => { @@ -154,18 +186,19 @@ unsafe fn write_anyvalue( }, _ => ndt.format(datetime_format), }; - return write!(f, "{formatted}").map_err(|_|{ - + let str_result = write!(f, "{formatted}"); + if str_result.is_err() { let datetime_format = unsafe { *datetime_formats.get_unchecked(i) }; let type_name = if tz.is_some() { "DateTime" } else { "NaiveDateTime" }; - polars_err!( - ComputeError: "cannot format {} with format '{}'", type_name, datetime_format, - ) - }); + polars_bail!( + ComputeError: "cannot format {} with format '{}'", type_name, 
datetime_format, + ) + }; + str_result }, #[cfg(feature = "dtype-time")] AnyValue::Time(v) => { @@ -205,10 +238,10 @@ pub struct SerializeOptions { pub datetime_format: Option, /// Used for [`DataType::Float64`] and [`DataType::Float32`]. pub float_precision: Option, - /// Used as separator/delimiter. - pub delimiter: u8, + /// Used as separator. + pub separator: u8, /// Quoting character. - pub quote: u8, + pub quote_char: u8, /// Null value representation. pub null: String, /// String appended after every row. @@ -223,8 +256,8 @@ impl Default for SerializeOptions { time_format: None, datetime_format: None, float_precision: None, - delimiter: b',', - quote: b'"', + separator: b',', + quote_char: b'"', null: String::new(), line_terminator: "\n".into(), quote_style: Default::default(), @@ -269,10 +302,10 @@ pub(crate) fn write( // Check that the double quote is valid UTF-8. polars_ensure!( - std::str::from_utf8(&[options.quote, options.quote]).is_ok(), + std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(), ComputeError: "quote char results in invalid utf-8", ); - let delimiter = char::from(options.delimiter); + let separator = char::from(options.separator); let (datetime_formats, time_zones): (Vec<&str>, Vec>) = df .get_columns() @@ -406,7 +439,7 @@ pub(crate) fn write( } let current_ptr = col as *const SeriesIter; if current_ptr != last_ptr { - write!(&mut write_buffer, "{delimiter}").unwrap() + write!(&mut write_buffer, "{separator}").unwrap() } } if !finished { @@ -422,7 +455,7 @@ pub(crate) fn write( }); // rayon will ensure the right order - result_buf.par_extend(par_iter); + POOL.install(|| result_buf.par_extend(par_iter)); for buf in result_buf.drain(..) { let mut buf = buf?; @@ -455,7 +488,7 @@ pub(crate) fn write_header( } writer.write_all( escaped_names - .join(std::str::from_utf8(&[options.delimiter]).unwrap()) + .join(std::str::from_utf8(&[options.separator]).unwrap()) .as_bytes(), )?; writer.write_all(options.line_terminator.as_bytes())?; diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs index ee8e7cc9560e..d8320e0c48fa 100644 --- a/crates/polars-io/src/ipc/mmap.rs +++ b/crates/polars-io/src/ipc/mmap.rs @@ -67,6 +67,7 @@ impl ArrowReader for MMapChunkIter<'_> { } } +#[cfg(feature = "ipc")] impl IpcReader { pub(super) fn finish_memmapped( &mut self, diff --git a/crates/polars-io/src/ipc/mod.rs b/crates/polars-io/src/ipc/mod.rs index 7b74b8486935..1366aa84324f 100644 --- a/crates/polars-io/src/ipc/mod.rs +++ b/crates/polars-io/src/ipc/mod.rs @@ -6,7 +6,7 @@ mod ipc_file; #[cfg(feature = "ipc_streaming")] mod ipc_stream; mod mmap; -#[cfg(feature = "ipc")] +#[cfg(any(feature = "ipc", feature = "ipc_streaming"))] mod write; #[cfg(all(feature = "async", feature = "ipc"))] mod write_async; diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index c5c7d8fe503f..16c03d8a5349 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -67,9 +67,7 @@ use std::ops::Deref; use arrow::array::StructArray; pub use arrow::error::Result as ArrowResult; -pub use arrow::io::json; use polars_arrow::conversion::chunk_to_struct; -use polars_arrow::utils::CustomIterTools; use polars_core::error::to_compute_err; use polars_core::prelude::*; use polars_core::utils::try_get_supertype; @@ -141,13 +139,14 @@ where match self.json_format { JsonFormat::JsonLines => { - let serializer = arrow_ndjson::write::Serializer::new(batches, vec![]); - let writer = arrow_ndjson::write::FileWriter::new(&mut 
self.buffer, serializer); + let serializer = polars_json::ndjson::write::Serializer::new(batches, vec![]); + let writer = + polars_json::ndjson::write::FileWriter::new(&mut self.buffer, serializer); writer.collect::>()?; }, JsonFormat::Json => { - let serializer = json::write::Serializer::new(batches, vec![]); - json::write::write(&mut self.buffer, serializer)?; + let serializer = polars_json::json::write::Serializer::new(batches, vec![]); + polars_json::json::write::write(&mut self.buffer, serializer)?; }, } @@ -216,14 +215,13 @@ where let mut_schema = Arc::make_mut(&mut schema); overwrite_schema(mut_schema, overwrite)?; } + DataType::Struct(schema.iter_fields().collect()).to_arrow() } else { // infer - if let BorrowedValue::Array(values) = &json_value { - polars_ensure!(self.schema_overwrite.is_none() && self.schema.is_none(), ComputeError: "schema arguments not yet supported for Array json"); - + let inner_dtype = if let BorrowedValue::Array(values) = &json_value { // struct types may have missing fields so find supertype - let dtype = values + values .iter() .take(self.infer_schema_len.unwrap_or(usize::MAX)) .map(|value| { @@ -231,37 +229,45 @@ where .map_err(PolarsError::from) .map(|dt| DataType::from(&dt)) }) - .fold_first_(|l, r| { + .reduce(|l, r| { let l = l?; let r = r?; try_get_supertype(&l, &r) }) - .unwrap()?; - let dtype = DataType::List(Box::new(dtype)); - dtype.to_arrow() + .unwrap()? + .to_arrow() } else { - let dtype = infer(&json_value)?; - if let Some(overwrite) = self.schema_overwrite { - let ArrowDataType::Struct(fields) = dtype else { - polars_bail!(ComputeError: "can only deserialize json objects") - }; + infer(&json_value)? + }; - let mut schema = Schema::from_iter(fields.iter()); - overwrite_schema(&mut schema, overwrite)?; + if let Some(overwrite) = self.schema_overwrite { + let ArrowDataType::Struct(fields) = inner_dtype else { + polars_bail!(ComputeError: "can only deserialize json objects") + }; - DataType::Struct( - schema - .into_iter() - .map(|(name, dt)| Field::new(&name, dt)) - .collect(), - ) - .to_arrow() - } else { - dtype - } + let mut schema = Schema::from_iter(fields.iter()); + overwrite_schema(&mut schema, overwrite)?; + + DataType::Struct( + schema + .into_iter() + .map(|(name, dt)| Field::new(&name, dt)) + .collect(), + ) + .to_arrow() + } else { + inner_dtype } }; + let dtype = if let BorrowedValue::Array(_) = &json_value { + ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new( + "item", dtype, true, + ))) + } else { + dtype + }; + let arr = polars_json::json::deserialize(&json_value, dtype)?; let arr = arr.as_any().downcast_ref::().ok_or_else( || polars_err!(ComputeError: "can only deserialize json objects"), diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs index e3f5f25b737e..f57dfa4aa6ce 100644 --- a/crates/polars-io/src/lib.rs +++ b/crates/polars-io/src/lib.rs @@ -4,8 +4,7 @@ #[cfg(feature = "avro")] pub mod avro; -#[cfg(feature = "cloud")] -mod cloud; +pub mod cloud; #[cfg(any(feature = "csv", feature = "json"))] pub mod csv; #[cfg(feature = "parquet")] @@ -19,12 +18,6 @@ pub mod ndjson; #[cfg(feature = "cloud")] pub use crate::cloud::glob as async_glob; -#[cfg(any( - feature = "csv", - feature = "parquet", - feature = "ipc", - feature = "json" -))] pub mod mmap; mod options; #[cfg(feature = "parquet")] @@ -33,10 +26,14 @@ pub mod predicates; pub mod prelude; #[cfg(all(test, feature = "csv"))] mod tests; -pub(crate) mod utils; +pub mod utils; +use once_cell::sync::Lazy; +use regex::Regex; #[cfg(feature = 
"partition")] pub mod partition; +#[cfg(feature = "async")] +pub mod pl_async; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; @@ -169,9 +166,13 @@ pub(crate) fn finish_reader( } } +static CLOUD_URL: Lazy = + Lazy::new(|| Regex::new(r"^(s3a?|gs|gcs|file|abfss?|azure|az|adl)://").unwrap()); + /// Check if the path is a cloud url. pub fn is_cloud_url>(p: P) -> bool { - p.as_ref().starts_with("s3://") - || p.as_ref().starts_with("file://") - || p.as_ref().starts_with("gcs://") + match p.as_ref().as_os_str().to_str() { + Some(s) => CLOUD_URL.is_match(s), + _ => false, + } } diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index 726ea8d78bd3..dda5108d2f9c 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -3,14 +3,12 @@ use std::io::Cursor; use std::path::PathBuf; pub use arrow::array::StructArray; -pub use arrow::io::ndjson as arrow_ndjson; use num_traits::pow::Pow; use polars_core::prelude::*; use polars_core::utils::accumulate_dataframes_vertical; use polars_core::POOL; use rayon::prelude::*; -use crate::csv::utils::*; use crate::mmap::{MmapBytesReader, ReaderBytes}; use crate::ndjson::buffer::*; use crate::prelude::*; diff --git a/crates/polars-io/src/parquet/async_impl.rs b/crates/polars-io/src/parquet/async_impl.rs index 961776f8dfb0..ce5ecb7ce7f4 100644 --- a/crates/polars-io/src/parquet/async_impl.rs +++ b/crates/polars-io/src/parquet/async_impl.rs @@ -2,44 +2,44 @@ use std::ops::Range; use std::sync::Arc; -use arrow::io::parquet::read::{ - self as parquet2_read, read_columns_async, ColumnChunkMetaData, RowGroupMetaData, -}; +use arrow::io::parquet::read::{self as parquet2_read, RowGroupMetaData}; use arrow::io::parquet::write::FileMetaData; -use futures::future::BoxFuture; -use futures::lock::Mutex; -use futures::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use bytes::Bytes; +use futures::future::try_join_all; use object_store::path::Path as ObjectPath; use object_store::ObjectStore; -use polars_core::cloud::CloudOptions; use polars_core::config::verbose; use polars_core::datatypes::PlHashMap; use polars_core::error::{to_compute_err, PolarsResult}; use polars_core::prelude::*; use polars_core::schema::Schema; +use smartstring::alias::String as SmartString; -use super::cloud::{build, CloudLocation, CloudReader}; +use super::cloud::{build_object_store, CloudLocation, CloudReader}; use super::mmap; use super::mmap::ColumnStore; -use super::read_impl::FetchRowGroups; +use crate::cloud::CloudOptions; pub struct ParquetObjectStore { - store: Arc>>, + store: Arc, path: ObjectPath, length: Option, - metadata: Option, + metadata: Option>, } impl ParquetObjectStore { - pub fn from_uri(uri: &str, options: Option<&CloudOptions>) -> PolarsResult { - let (CloudLocation { prefix, .. }, store) = build(uri, options)?; - let store = Arc::new(Mutex::from(store)); + pub async fn from_uri( + uri: &str, + options: Option<&CloudOptions>, + metadata: Option>, + ) -> PolarsResult { + let (CloudLocation { prefix, .. 
}, store) = build_object_store(uri, options).await?; Ok(ParquetObjectStore { store, - path: prefix.into(), + path: ObjectPath::from_url_path(prefix).map_err(to_compute_err)?, length: None, - metadata: None, + metadata, }) } @@ -48,9 +48,13 @@ impl ParquetObjectStore { if self.length.is_some() { return Ok(()); } - let path = self.path.clone(); - let locked_store = self.store.lock().await; - self.length = Some(locked_store.head(&path).await.map_err(to_compute_err)?.size as u64); + self.length = Some( + self.store + .head(&self.path) + .await + .map_err(to_compute_err)? + .size as u64, + ); Ok(()) } @@ -75,102 +79,115 @@ impl ParquetObjectStore { let path = self.path.clone(); let length = self.length; let mut reader = CloudReader::new(length, object_store, path); + parquet2_read::read_metadata_async(&mut reader) .await .map_err(to_compute_err) } /// Fetch and memoize the metadata of the parquet file. - pub async fn get_metadata(&mut self) -> PolarsResult<&FileMetaData> { - self.initialize_length().await?; + pub async fn get_metadata(&mut self) -> PolarsResult<&Arc> { if self.metadata.is_none() { - self.metadata = Some(self.fetch_metadata().await?); + self.metadata = Some(Arc::new(self.fetch_metadata().await?)); } Ok(self.metadata.as_ref().unwrap()) } } -/// A vector of downloaded RowGroups. -/// A RowGroup will have 1 or more columns, for each column we store: -/// - a reference to its metadata -/// - the actual content as downloaded from object storage (generally cloud). -type RowGroupChunks<'a> = Vec)>>; +async fn read_single_column_async( + async_reader: &ParquetObjectStore, + start: usize, + length: usize, +) -> PolarsResult<(u64, Bytes)> { + let chunk = async_reader + .store + .get_range(&async_reader.path, start..start + length) + .await + .map_err(to_compute_err)?; + Ok((start as u64, chunk)) +} + +async fn read_columns_async2( + async_reader: &ParquetObjectStore, + ranges: &[(u64, u64)], +) -> PolarsResult> { + let futures = ranges.iter().map(|(start, length)| async { + read_single_column_async(async_reader, *start as usize, *length as usize).await + }); + + try_join_all(futures).await +} /// Download rowgroups for the column whose indexes are given in `projection`. /// We concurrently download the columns for each field. -#[tokio::main(flavor = "current_thread")] -async fn download_projection<'a: 'b, 'b>( - projection: &[usize], - row_groups: &'a [RowGroupMetaData], - schema: &ArrowSchema, - async_reader: &'b ParquetObjectStore, -) -> PolarsResult> { - let fields = projection - .iter() - .map(|i| schema.fields[*i].name.clone()) - .collect::>(); - - let reader_factory = || { - let object_store = async_reader.store.clone(); - let path = async_reader.path.clone(); - Box::pin(futures::future::ready(Ok(CloudReader::new( - async_reader.length, - object_store, - path, - )))) - } - as BoxFuture<'static, std::result::Result>; - +async fn download_projection( + fields: &[SmartString], + row_groups: &[RowGroupMetaData], + async_reader: &Arc, +) -> PolarsResult>> { // Build the cartesian product of the fields and the row groups. - let product = fields - .into_iter() - .flat_map(|f| row_groups.iter().map(move |r| (f.clone(), r))); - - // Download them all concurrently. 
- stream::iter(product) - .then(move |(name, row_group)| async move { + let product_futures = fields + .iter() + .flat_map(|name| row_groups.iter().map(move |r| (name.clone(), r))) + .map(|(name, row_group)| async move { let columns = row_group.columns(); - read_columns_async(reader_factory, columns, name.as_ref()) - .map_err(to_compute_err) - .await - }) - .try_collect() - .await + let ranges = columns + .iter() + .filter_map(|meta| { + if meta.descriptor().path_in_schema[0] == name.as_str() { + Some(meta.byte_range()) + } else { + None + } + }) + .collect::>(); + let async_reader = async_reader.clone(); + let handle = + tokio::spawn(async move { read_columns_async2(&async_reader, &ranges).await }); + handle.await.unwrap() + }); + + // Download concurrently + futures::future::try_join_all(product_futures).await } -pub(crate) struct FetchRowGroupsFromObjectStore { - reader: ParquetObjectStore, +pub struct FetchRowGroupsFromObjectStore { + reader: Arc, row_groups_metadata: Vec, - projection: Vec, + projected_fields: Vec, logging: bool, - schema: ArrowSchema, } impl FetchRowGroupsFromObjectStore { pub fn new( reader: ParquetObjectStore, metadata: &FileMetaData, - projection: &Option>, + schema: SchemaRef, + projection: Option<&[usize]>, ) -> PolarsResult { - let schema = parquet2_read::schema::infer_schema(metadata)?; let logging = verbose(); - let projection = projection - .to_owned() - .unwrap_or_else(|| (0usize..schema.fields.len()).collect::>()); + let projected_fields = projection + .map(|projection| { + projection + .iter() + .map(|i| schema.get_at_index(*i).unwrap().0.clone()) + .collect::>() + }) + .unwrap_or_else(|| schema.iter().map(|tpl| tpl.0).cloned().collect()); Ok(FetchRowGroupsFromObjectStore { - reader, + reader: Arc::new(reader), row_groups_metadata: metadata.row_groups.to_owned(), - projection, + projected_fields, logging, - schema, }) } -} -impl FetchRowGroups for FetchRowGroupsFromObjectStore { - fn fetch_row_groups(&mut self, row_groups: Range) -> PolarsResult { + pub(crate) async fn fetch_row_groups( + &mut self, + row_groups: Range, + ) -> PolarsResult { // Fetch the required row groups. let row_groups = &self .row_groups_metadata @@ -184,21 +201,19 @@ impl FetchRowGroups for FetchRowGroupsFromObjectStore { // Package in the format required by ColumnStore. 
let downloaded = - download_projection(&self.projection, row_groups, &self.schema, &self.reader)?; + download_projection(&self.projected_fields, row_groups, &self.reader).await?; + if self.logging { eprintln!( "BatchedParquetReader: fetched {} row_groups for {} fields, yielding {} column chunks.", row_groups.len(), - self.projection.len(), + self.projected_fields.len(), downloaded.len(), ); } let downloaded_per_filepos = downloaded .into_iter() - .flat_map(|rg| { - rg.into_iter() - .map(|(meta, data)| (meta.byte_range().0, data)) - }) + .flat_map(|rg| rg.into_iter()) .collect::>(); if self.logging { diff --git a/crates/polars-io/src/parquet/mmap.rs b/crates/polars-io/src/parquet/mmap.rs index 8ad0a3d29e6b..8615d75b88fb 100644 --- a/crates/polars-io/src/parquet/mmap.rs +++ b/crates/polars-io/src/parquet/mmap.rs @@ -3,6 +3,7 @@ use arrow::io::parquet::read::{ column_iter_to_arrays, get_field_columns, ArrayIter, BasicDecompressor, ColumnChunkMetaData, PageReader, }; +use bytes::Bytes; #[cfg(feature = "async")] use polars_core::datatypes::PlHashMap; @@ -22,7 +23,7 @@ use super::*; pub enum ColumnStore<'a> { Local(&'a [u8]), #[cfg(feature = "async")] - Fetched(PlHashMap>), + Fetched(PlHashMap), } /// For local files memory maps all columns that are part of the parquet field `field_name`. @@ -52,7 +53,7 @@ fn _mmap_single_column<'a>( "mmap_columns: column with start {start} must be prefetched in ColumnStore.\n" ) }); - entry.as_slice() + entry.as_ref() }, }; (meta, chunk) diff --git a/crates/polars-io/src/parquet/predicates.rs b/crates/polars-io/src/parquet/predicates.rs index 9454ac431f3b..02262660e384 100644 --- a/crates/polars-io/src/parquet/predicates.rs +++ b/crates/polars-io/src/parquet/predicates.rs @@ -1,113 +1,18 @@ -use arrow::compute::concatenate::concatenate; use arrow::io::parquet::read::statistics::{deserialize, Statistics}; use arrow::io::parquet::read::RowGroupMetaData; use polars_core::prelude::*; -use crate::predicates::PhysicalIoExpr; +use crate::predicates::{BatchStats, ColumnStats, PhysicalIoExpr}; use crate::ArrowResult; -/// The statistics for a column in a Parquet file -/// they typically hold -/// - max value -/// - min value -/// - null_count -#[cfg_attr(debug_assertions, derive(Debug))] -pub struct ColumnStats(Statistics, Field); - impl ColumnStats { - pub fn dtype(&self) -> DataType { - self.1.data_type().clone() - } - - pub fn null_count(&self) -> Option { - match self.1.data_type() { - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => None, - _ => { - // the array holds the null count for every row group - // so we sum them to get them of the whole file. 
- Series::try_from(("", self.0.null_count.clone())) - .unwrap() - .sum() - }, - } - } - - pub fn to_min_max(&self) -> Option { - let max_val = &*self.0.max_value; - let min_val = &*self.0.min_value; - - let dtype = DataType::from(min_val.data_type()); - - if Self::use_min_max(dtype) { - let arr = concatenate(&[min_val, max_val]).unwrap(); - let s = Series::try_from(("", arr)).unwrap(); - if s.null_count() > 0 { - None - } else { - Some(s) - } - } else { - None - } - } - - pub fn to_min(&self) -> Option { - let min_val = self.0.min_value.clone(); - let dtype = DataType::from(min_val.data_type()); - - if !Self::use_min_max(dtype) || min_val.len() != 1 { - return None; - } - - let s = Series::try_from(("", min_val)).unwrap(); - if s.null_count() > 0 { - None - } else { - Some(s) - } - } - - pub fn to_max(&self) -> Option { - let max_val = self.0.max_value.clone(); - let dtype = DataType::from(max_val.data_type()); - - if !Self::use_min_max(dtype) || max_val.len() != 1 { - return None; - } - - let s = Series::try_from(("", max_val)).unwrap(); - if s.null_count() > 0 { - None - } else { - Some(s) - } - } - - #[cfg(feature = "dtype-binary")] - fn use_min_max(dtype: DataType) -> bool { - dtype.is_numeric() || matches!(dtype, DataType::Utf8) || matches!(dtype, DataType::Binary) - } - - #[cfg(not(feature = "dtype-binary"))] - fn use_min_max(dtype: DataType) -> bool { - dtype.is_numeric() || matches!(dtype, DataType::Utf8) - } -} - -/// A collection of column stats with a known schema. -pub struct BatchStats { - schema: Schema, - stats: Vec, -} - -impl BatchStats { - pub fn get_stats(&self, column: &str) -> polars_core::error::PolarsResult<&ColumnStats> { - self.schema.try_index_of(column).map(|i| &self.stats[i]) - } - - pub fn schema(&self) -> &Schema { - &self.schema + fn from_arrow_stats(stats: Statistics, field: &ArrowField) -> Self { + Self::new( + field.into(), + Some(Series::try_from(("", stats.null_count)).unwrap()), + Some(Series::try_from(("", stats.min_value)).unwrap()), + Some(Series::try_from(("", stats.max_value)).unwrap()), + ) } } @@ -128,13 +33,13 @@ pub(crate) fn collect_statistics( Some(rg) => deserialize(fld, &md[rg..rg + 1])?, }; schema.with_column((&fld.name).into(), (&fld.data_type).into()); - stats.push(ColumnStats(st, fld.into())); + stats.push(ColumnStats::from_arrow_stats(st, fld)); } Ok(if stats.is_empty() { None } else { - Some(BatchStats { schema, stats }) + Some(BatchStats::new(schema, stats)) }) } diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index c21a2892c8ba..544c216975fd 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -3,13 +3,15 @@ use std::sync::Arc; use arrow::io::parquet::read; use arrow::io::parquet::write::FileMetaData; -#[cfg(feature = "cloud")] -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +#[cfg(feature = "cloud")] +use polars_core::utils::concat_df; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use super::read_impl::FetchRowGroupsFromMmapReader; +#[cfg(feature = "cloud")] +use crate::cloud::CloudOptions; use crate::mmap::MmapBytesReader; #[cfg(feature = "cloud")] use crate::parquet::async_impl::FetchRowGroupsFromObjectStore; @@ -17,6 +19,8 @@ use crate::parquet::async_impl::FetchRowGroupsFromObjectStore; use crate::parquet::async_impl::ParquetObjectStore; use crate::parquet::read_impl::read_parquet; pub use crate::parquet::read_impl::BatchedParquetReader; +#[cfg(feature = "cloud")] +use crate::predicates::apply_predicate; use 
crate::predicates::PhysicalIoExpr; use crate::prelude::*; use crate::RowCount; @@ -47,7 +51,8 @@ pub struct ParquetReader { parallel: ParallelStrategy, row_count: Option, low_memory: bool, - metadata: Option, + metadata: Option>, + hive_partition_columns: Option>, use_statistics: bool, } @@ -74,10 +79,11 @@ impl ParquetReader { self.parallel, self.row_count, self.use_statistics, + self.hive_partition_columns.as_deref(), ) .map(|mut df| { if rechunk { - df.align_chunks(); + df.as_single_chunk_par(); }; df }) @@ -141,9 +147,14 @@ impl ParquetReader { Ok(metadata.num_rows) } - fn get_metadata(&mut self) -> PolarsResult<&FileMetaData> { + pub fn with_hive_partition_columns(mut self, columns: Option>) -> Self { + self.hive_partition_columns = columns; + self + } + + pub fn get_metadata(&mut self) -> PolarsResult<&Arc> { if self.metadata.is_none() { - self.metadata = Some(read::read_metadata(&mut self.reader)?); + self.metadata = Some(Arc::new(read::read_metadata(&mut self.reader)?)); } Ok(self.metadata.as_ref().unwrap()) } @@ -151,9 +162,9 @@ impl ParquetReader { impl ParquetReader { pub fn batched(mut self, chunk_size: usize) -> PolarsResult { - let metadata = read::read_metadata(&mut self.reader)?; + let metadata = self.get_metadata()?.clone(); - let row_group_fetcher = Box::new(FetchRowGroupsFromMmapReader::new(Box::new(self.reader))?); + let row_group_fetcher = FetchRowGroupsFromMmapReader::new(Box::new(self.reader))?.into(); BatchedParquetReader::new( row_group_fetcher, metadata, @@ -162,6 +173,7 @@ impl ParquetReader { self.row_count, chunk_size, self.use_statistics, + self.hive_partition_columns, ) } } @@ -180,6 +192,7 @@ impl SerReader for ParquetReader { low_memory: false, metadata: None, use_statistics: true, + hive_partition_columns: None, } } @@ -206,6 +219,7 @@ impl SerReader for ParquetReader { self.parallel, self.row_count, self.use_statistics, + self.hive_partition_columns.as_deref(), ) .map(|mut df| { if self.rechunk { @@ -221,43 +235,35 @@ impl SerReader for ParquetReader { #[cfg(feature = "cloud")] pub struct ParquetAsyncReader { reader: ParquetObjectStore, - rechunk: bool, n_rows: Option, + rechunk: bool, projection: Option>, row_count: Option, - low_memory: bool, use_statistics: bool, + hive_partition_columns: Option>, + schema: Option, } #[cfg(feature = "cloud")] impl ParquetAsyncReader { - pub fn from_uri( + pub async fn from_uri( uri: &str, cloud_options: Option<&CloudOptions>, + schema: Option, + metadata: Option>, ) -> PolarsResult { Ok(ParquetAsyncReader { - reader: ParquetObjectStore::from_uri(uri, cloud_options)?, + reader: ParquetObjectStore::from_uri(uri, cloud_options, metadata).await?, rechunk: false, n_rows: None, projection: None, row_count: None, - low_memory: false, use_statistics: true, + hive_partition_columns: None, + schema, }) } - /// Fetch the file info in a synchronous way to for the query planning phase. 
- #[tokio::main(flavor = "current_thread")] - pub async fn file_info( - uri: &str, - options: Option<&CloudOptions>, - ) -> PolarsResult<(Schema, usize)> { - let mut reader = ParquetAsyncReader::from_uri(uri, options)?; - let schema = reader.schema().await?; - let num_rows = reader.num_rows().await?; - Ok((schema, num_rows)) - } - pub async fn schema(&mut self) -> PolarsResult { self.reader.schema().await } @@ -280,11 +286,6 @@ impl ParquetAsyncReader { self } - pub fn set_low_memory(mut self, low_memory: bool) -> Self { - self.low_memory = low_memory; - self - } - pub fn with_projection(mut self, projection: Option>) -> Self { self.projection = projection; self @@ -297,14 +298,21 @@ impl ParquetAsyncReader { self } - #[tokio::main(flavor = "current_thread")] + pub fn with_hive_partition_columns(mut self, columns: Option>) -> Self { + self.hive_partition_columns = columns; + self + } + pub async fn batched(mut self, chunk_size: usize) -> PolarsResult { - let metadata = self.reader.get_metadata().await?.to_owned(); - let row_group_fetcher = Box::new(FetchRowGroupsFromObjectStore::new( + let metadata = self.reader.get_metadata().await?.clone(); + // row group fetched deals with projection + let row_group_fetcher = FetchRowGroupsFromObjectStore::new( self.reader, &metadata, - &self.projection, - )?); + self.schema.unwrap(), + self.projection.as_deref(), + )? + .into(); BatchedParquetReader::new( row_group_fetcher, metadata, @@ -313,6 +321,37 @@ impl ParquetAsyncReader { self.row_count, chunk_size, self.use_statistics, + self.hive_partition_columns, ) } + + pub async fn get_metadata(&mut self) -> PolarsResult<&Arc> { + self.reader.get_metadata().await + } + + pub async fn finish( + self, + predicate: Option>, + ) -> PolarsResult { + let rechunk = self.rechunk; + + // batched reader deals with slice pushdown + let reader = self.batched(usize::MAX).await?; + let mut iter = reader.iter(16); + + let mut chunks = Vec::with_capacity(16); + while let Some(result) = iter.next_().await { + let out = result.and_then(|mut df| { + apply_predicate(&mut df, predicate.as_deref(), true)?; + Ok(df) + })?; + chunks.push(out) + } + let mut df = concat_df(&chunks)?; + + if rechunk { + df.as_single_chunk_par(); + } + Ok(df) + } } diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index e9cf8d71d47d..5f99e770b343 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -14,14 +14,32 @@ use rayon::prelude::*; use super::mmap::ColumnStore; use crate::mmap::{MmapBytesReader, ReaderBytes}; +#[cfg(feature = "cloud")] +use crate::parquet::async_impl::FetchRowGroupsFromObjectStore; use crate::parquet::mmap::mmap_columns; use crate::parquet::predicates::read_this_row_group; use crate::parquet::{mmap, ParallelStrategy}; use crate::predicates::{apply_predicate, arrow_schema_to_empty_df, PhysicalIoExpr}; -use crate::prelude::utils::get_reader_bytes; -use crate::utils::apply_projection; +use crate::utils::{apply_projection, get_reader_bytes}; use crate::RowCount; +fn enlarge_data_type(mut data_type: ArrowDataType) -> ArrowDataType { + match data_type { + ArrowDataType::Utf8 => { + data_type = ArrowDataType::LargeUtf8; + }, + ArrowDataType::Binary => { + data_type = ArrowDataType::LargeBinary; + }, + ArrowDataType::List(mut inner_field) => { + inner_field.data_type = enlarge_data_type(inner_field.data_type); + data_type = ArrowDataType::LargeList(inner_field); + }, + _ => {}, + } + data_type +} + fn column_idx_to_series( column_i: 
usize, md: &RowGroupMetaData, @@ -32,16 +50,7 @@ fn column_idx_to_series( ) -> PolarsResult { let mut field = schema.fields[column_i].clone(); - match field.data_type { - ArrowDataType::Utf8 => { - field.data_type = ArrowDataType::LargeUtf8; - }, - ArrowDataType::Binary => { - field.data_type = ArrowDataType::LargeBinary; - }, - ArrowDataType::List(fld) => field.data_type = ArrowDataType::LargeList(fld), - _ => {}, - } + field.data_type = enlarge_data_type(field.data_type); let columns = mmap_columns(store, md.columns(), &field.name); let iter = mmap::to_deserializer(columns, field.clone(), remaining_rows, Some(chunk_size))?; @@ -85,8 +94,18 @@ pub(super) fn array_iter_to_series( } } +/// Materializes hive partitions. +fn materialize_hive_partitions(df: &mut DataFrame, hive_partition_columns: Option<&[Series]>) { + if let Some(hive_columns) = hive_partition_columns { + let num_rows = df.height(); + + for s in hive_columns { + unsafe { df.with_column_unchecked(s.new_from_index(0, num_rows)) }; + } + } +} + #[allow(clippy::too_many_arguments)] -// might parallelize over columns fn rg_to_dfs( store: &mmap::ColumnStore, previous_row_count: &mut IdxSize, @@ -100,6 +119,58 @@ fn rg_to_dfs( parallel: ParallelStrategy, projection: &[usize], use_statistics: bool, + hive_partition_columns: Option<&[Series]>, +) -> PolarsResult> { + if let ParallelStrategy::Columns | ParallelStrategy::None = parallel { + rg_to_dfs_optionally_par_over_columns( + store, + previous_row_count, + row_group_start, + row_group_end, + remaining_rows, + file_metadata, + schema, + predicate, + row_count, + parallel, + projection, + use_statistics, + hive_partition_columns, + ) + } else { + rg_to_dfs_par_over_rg( + store, + row_group_start, + row_group_end, + previous_row_count, + remaining_rows, + file_metadata, + schema, + predicate, + row_count, + projection, + use_statistics, + hive_partition_columns, + ) + } +} + +#[allow(clippy::too_many_arguments)] +// might parallelize over columns +fn rg_to_dfs_optionally_par_over_columns( + store: &mmap::ColumnStore, + previous_row_count: &mut IdxSize, + row_group_start: usize, + row_group_end: usize, + remaining_rows: &mut usize, + file_metadata: &FileMetaData, + schema: &ArrowSchema, + predicate: Option>, + row_count: Option, + parallel: ParallelStrategy, + projection: &[usize], + use_statistics: bool, + hive_partition_columns: Option<&[Series]>, ) -> PolarsResult> { let mut dfs = Vec::with_capacity(row_group_end - row_group_start); @@ -149,6 +220,7 @@ fn rg_to_dfs( if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(*previous_row_count + rc.offset)); } + materialize_hive_partitions(&mut df, hive_partition_columns); apply_predicate(&mut df, predicate.as_deref(), true)?; @@ -164,7 +236,7 @@ fn rg_to_dfs( #[allow(clippy::too_many_arguments)] // parallelizes over row groups -fn rg_to_dfs_par( +fn rg_to_dfs_par_over_rg( store: &mmap::ColumnStore, row_group_start: usize, row_group_end: usize, @@ -176,6 +248,7 @@ fn rg_to_dfs_par( row_count: Option, projection: &[usize], use_statistics: bool, + hive_partition_columns: Option<&[Series]>, ) -> PolarsResult> { // compute the limits per row group and the row count offsets let row_groups = file_metadata @@ -223,6 +296,7 @@ fn rg_to_dfs_par( if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); } + materialize_hive_partitions(&mut df, hive_partition_columns); apply_predicate(&mut df, predicate.as_deref(), false)?; @@ -243,6 +317,7 @@ pub fn read_parquet( mut parallel: 
ParallelStrategy, row_count: Option, use_statistics: bool, + hive_partition_columns: Option<&[Series]>, ) -> PolarsResult { let file_metadata = metadata .map(Ok) @@ -252,10 +327,10 @@ pub fn read_parquet( // if there are multiple row groups and categorical data // we need a string cache // we keep it alive until the end of the function - let _string_cache = if n_row_groups > 1 { + let _sc = if n_row_groups > 1 { #[cfg(feature = "dtype-categorical")] { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } #[cfg(not(feature = "dtype-categorical"))] { @@ -284,37 +359,22 @@ pub fn read_parquet( let reader = ReaderBytes::from(&reader); let bytes = reader.deref(); let store = mmap::ColumnStore::Local(bytes); - let dfs = match parallel { - ParallelStrategy::Columns | ParallelStrategy::None => rg_to_dfs( - &store, - &mut 0, - 0, - n_row_groups, - &mut limit, - &file_metadata, - schema, - predicate, - row_count, - parallel, - &projection, - use_statistics, - )?, - ParallelStrategy::RowGroups => rg_to_dfs_par( - &store, - 0, - file_metadata.row_groups.len(), - &mut 0, - &mut limit, - &file_metadata, - schema, - predicate, - row_count, - &projection, - use_statistics, - )?, - // auto should already be replaced by Columns or RowGroups - ParallelStrategy::Auto => unimplemented!(), - }; + + let dfs = rg_to_dfs( + &store, + &mut 0, + 0, + n_row_groups, + &mut limit, + &file_metadata, + schema, + predicate, + row_count, + parallel, + &projection, + use_statistics, + hive_partition_columns, + )?; if dfs.is_empty() { let schema = if let Cow::Borrowed(_) = projection { @@ -322,20 +382,20 @@ pub fn read_parquet( } else { Cow::Borrowed(schema) }; - Ok(arrow_schema_to_empty_df(&schema)) + let mut df = arrow_schema_to_empty_df(&schema); + if let Some(parts) = hive_partition_columns { + for s in parts { + // SAFETY: length is equal + unsafe { df.with_column_unchecked(s.clear()) }; + } + } + Ok(df) } else { accumulate_dataframes_vertical(dfs) } } -/// Provide RowGroup content to the BatchedReader. -/// This allows us to share the code to do in-memory processing for different use cases. -pub trait FetchRowGroups: Sync + Send { - /// Fetch the row groups in the given range and package them in a ColumnStore. - fn fetch_row_groups(&mut self, row_groups: Range) -> PolarsResult; -} - -pub(crate) struct FetchRowGroupsFromMmapReader(ReaderBytes<'static>); +pub struct FetchRowGroupsFromMmapReader(ReaderBytes<'static>); impl FetchRowGroupsFromMmapReader { pub fn new(mut reader: Box) -> PolarsResult { @@ -350,23 +410,50 @@ impl FetchRowGroupsFromMmapReader { let reader_bytes = get_reader_bytes(reader_ptr)?; Ok(FetchRowGroupsFromMmapReader(reader_bytes)) } + async fn fetch_row_groups(&mut self, _row_groups: Range) -> PolarsResult { + Ok(mmap::ColumnStore::Local(self.0.deref())) + } } -/// There is nothing to do when fetching a mmap-ed file. -impl FetchRowGroups for FetchRowGroupsFromMmapReader { - fn fetch_row_groups(&mut self, _row_groups: Range) -> PolarsResult { - Ok(mmap::ColumnStore::Local(self.0.deref())) +// We couldn't use a trait as async trait gave very hard HRT lifetime errors. +// Maybe a puzzle for another day. 
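The comment above motivates replacing a trait object with a concrete enum because async trait objects ran into hard higher-ranked lifetime errors; the `RowGroupFetcher` enum that follows is that replacement. Below is a small, self-contained sketch of the same pattern, not the Polars code: each concrete source keeps its own `async fn`, and the enum forwards to it, giving dynamic dispatch without an async trait.

```rust
// Illustrative stand-ins for a local and a cloud row-group source.
struct LocalSource;
struct CloudSource;

impl LocalSource {
    async fn fetch(&mut self, n: usize) -> Vec<u8> {
        vec![0; n] // pretend this reads from an mmap
    }
}

impl CloudSource {
    async fn fetch(&mut self, n: usize) -> Vec<u8> {
        vec![1; n] // pretend this downloads from object storage
    }
}

// The enum plays the role of `dyn Fetch` without needing an async trait.
enum Fetcher {
    Local(LocalSource),
    Cloud(CloudSource),
}

impl Fetcher {
    async fn fetch(&mut self, n: usize) -> Vec<u8> {
        match self {
            Fetcher::Local(s) => s.fetch(n).await,
            Fetcher::Cloud(s) => s.fetch(n).await,
        }
    }
}

// Usage is identical regardless of which variant is held.
async fn demo(mut f: Fetcher) -> usize {
    f.fetch(8).await.len()
}
```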
+pub enum RowGroupFetcher { + #[cfg(feature = "cloud")] + ObjectStore(FetchRowGroupsFromObjectStore), + Local(FetchRowGroupsFromMmapReader), +} + +#[cfg(feature = "cloud")] +impl From for RowGroupFetcher { + fn from(value: FetchRowGroupsFromObjectStore) -> Self { + RowGroupFetcher::ObjectStore(value) + } +} + +impl From for RowGroupFetcher { + fn from(value: FetchRowGroupsFromMmapReader) -> Self { + RowGroupFetcher::Local(value) + } +} + +impl RowGroupFetcher { + async fn fetch_row_groups(&mut self, _row_groups: Range) -> PolarsResult { + match self { + RowGroupFetcher::Local(f) => f.fetch_row_groups(_row_groups).await, + #[cfg(feature = "cloud")] + RowGroupFetcher::ObjectStore(f) => f.fetch_row_groups(_row_groups).await, + } } } pub struct BatchedParquetReader { // use to keep ownership #[allow(dead_code)] - row_group_fetcher: Box, + row_group_fetcher: RowGroupFetcher, limit: usize, projection: Vec, - schema: ArrowSchema, - metadata: FileMetaData, + schema: Arc, + metadata: Arc, row_count: Option, rows_read: IdxSize, row_group_offset: usize, @@ -375,19 +462,22 @@ pub struct BatchedParquetReader { parallel: ParallelStrategy, chunk_size: usize, use_statistics: bool, + hive_partition_columns: Option>, } impl BatchedParquetReader { + #[allow(clippy::too_many_arguments)] pub fn new( - row_group_fetcher: Box, - metadata: FileMetaData, + row_group_fetcher: RowGroupFetcher, + metadata: Arc, limit: usize, projection: Option>, row_count: Option, chunk_size: usize, use_statistics: bool, + hive_partition_columns: Option>, ) -> PolarsResult { - let schema = read::schema::infer_schema(&metadata)?; + let schema = Arc::new(read::schema::infer_schema(&metadata)?); let n_row_groups = metadata.row_groups.len(); let projection = projection.unwrap_or_else(|| (0usize..schema.fields.len()).collect::>()); @@ -413,55 +503,37 @@ impl BatchedParquetReader { parallel, chunk_size, use_statistics, + hive_partition_columns, }) } - pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { + pub async fn next_batches(&mut self, n: usize) -> PolarsResult>> { // fill up fifo stack if self.row_group_offset <= self.n_row_groups && self.chunks_fifo.len() < n { let row_group_start = self.row_group_offset; let row_group_end = std::cmp::min(self.row_group_offset + n, self.n_row_groups); let store = self .row_group_fetcher - .fetch_row_groups(row_group_start..row_group_end)?; - let dfs = match self.parallel { - ParallelStrategy::Columns => { - let dfs = rg_to_dfs( - &store, - &mut self.rows_read, - row_group_start, - row_group_end, - &mut self.limit, - &self.metadata, - &self.schema, - None, - self.row_count.clone(), - ParallelStrategy::Columns, - &self.projection, - self.use_statistics, - )?; - self.row_group_offset += n; - dfs - }, - ParallelStrategy::RowGroups => { - let dfs = rg_to_dfs_par( - &store, - self.row_group_offset, - std::cmp::min(self.row_group_offset + n, self.n_row_groups), - &mut self.rows_read, - &mut self.limit, - &self.metadata, - &self.schema, - None, - self.row_count.clone(), - &self.projection, - self.use_statistics, - )?; - self.row_group_offset += n; - dfs - }, - _ => unimplemented!(), - }; + .fetch_row_groups(row_group_start..row_group_end) + .await?; + + let dfs = rg_to_dfs( + &store, + &mut self.rows_read, + row_group_start, + row_group_end, + &mut self.limit, + &self.metadata, + &self.schema, + None, + self.row_count.clone(), + self.parallel, + &self.projection, + self.use_statistics, + self.hive_partition_columns.as_deref(), + )?; + + self.row_group_offset += n; // case where there is no data in 
the file // the streaming engine needs at least a single chunk if self.rows_read == 0 && dfs.is_empty() { @@ -511,28 +583,30 @@ impl BatchedParquetReader { } /// Turn the batched reader into an iterator. - pub fn iter(self, batch_size: usize) -> BatchedParquetIter { + #[cfg(feature = "async")] + pub fn iter(self, batches_per_iter: usize) -> BatchedParquetIter { BatchedParquetIter { - batch_size, + batches_per_iter, inner: self, current_batch: vec![].into_iter(), } } } +#[cfg(feature = "async")] pub struct BatchedParquetIter { - batch_size: usize, + batches_per_iter: usize, inner: BatchedParquetReader, current_batch: std::vec::IntoIter, } -impl Iterator for BatchedParquetIter { - type Item = PolarsResult; - - fn next(&mut self) -> Option { +#[cfg(feature = "async")] +impl BatchedParquetIter { + // todo! implement stream + pub(crate) async fn next_(&mut self) -> Option> { match self.current_batch.next() { Some(df) => Some(Ok(df)), - None => match self.inner.next_batches(self.batch_size) { + None => match self.inner.next_batches(self.batches_per_iter).await { Err(e) => Some(Err(e)), Ok(opt_batch) => { let batch = opt_batch?; diff --git a/crates/polars-io/src/pl_async.rs b/crates/polars-io/src/pl_async.rs new file mode 100644 index 000000000000..42f56fa1acca --- /dev/null +++ b/crates/polars-io/src/pl_async.rs @@ -0,0 +1,79 @@ +use std::collections::BTreeSet; +use std::future::Future; +use std::ops::Deref; +use std::sync::RwLock; + +use once_cell::sync::Lazy; +use polars_core::POOL; +use tokio::runtime::{Builder, Runtime}; + +pub struct RuntimeManager { + rt: Runtime, + blocking_rayon_threads: RwLock>, +} + +impl RuntimeManager { + fn new() -> Self { + let rt = Builder::new_multi_thread() + .worker_threads(std::cmp::max(POOL.current_num_threads() / 2, 4)) + .enable_io() + .enable_time() + .build() + .unwrap(); + + Self { + rt, + blocking_rayon_threads: Default::default(), + } + } + + /// Keep track of rayon threads that drive the runtime. Every thread + /// only allows a single runtime. If this thread calls block_on and this + /// rayon thread is already driving an async execution we must start a new thread + /// otherwise we panic. This can happen when we parallelize reads over 100s of files. + pub fn block_on_potential_spawn(&'static self, future: F) -> F::Output + where + F: Future + Send, + F::Output: Send, + { + if let Some(thread_id) = POOL.current_thread_index() { + if self + .blocking_rayon_threads + .read() + .unwrap() + .contains(&thread_id) + { + std::thread::scope(|s| s.spawn(|| self.rt.block_on(future)).join().unwrap()) + } else { + self.blocking_rayon_threads + .write() + .unwrap() + .insert(thread_id); + let out = self.rt.block_on(future); + self.blocking_rayon_threads + .write() + .unwrap() + .remove(&thread_id); + out + } + } + // Assumption that the main thread never runs rayon tasks, so we wouldn't be rescheduled + // on the main thread and thus we can always block. 
+ else { + self.rt.block_on(future) + } + } + + pub fn block_on(&self, future: F) -> F::Output + where + F: Future, + { + self.rt.block_on(future) + } +} + +static RUNTIME: Lazy = Lazy::new(RuntimeManager::new); + +pub fn get_runtime() -> &'static RuntimeManager { + RUNTIME.deref() +} diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs index 4a3a675105bf..6beac2a3af12 100644 --- a/crates/polars-io/src/predicates.rs +++ b/crates/polars-io/src/predicates.rs @@ -1,22 +1,22 @@ use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; pub trait PhysicalIoExpr: Send + Sync { - /// Take a `DataFrame` and produces a boolean `Series` that serves + /// Take a [`DataFrame`] and produces a boolean [`Series`] that serves /// as a predicate mask fn evaluate(&self, df: &DataFrame) -> PolarsResult; /// Can take &dyn Statistics and determine of a file should be /// read -> `true` /// or not -> `false` - #[cfg(feature = "parquet")] fn as_stats_evaluator(&self) -> Option<&dyn StatsEvaluator> { None } } -#[cfg(feature = "parquet")] pub trait StatsEvaluator { - fn should_read(&self, stats: &crate::parquet::predicates::BatchStats) -> PolarsResult; + fn should_read(&self, stats: &BatchStats) -> PolarsResult; } #[cfg(feature = "parquet")] @@ -47,3 +47,149 @@ pub(crate) fn apply_predicate( } Ok(()) } + +/// The statistics for a column in a Parquet file +/// or Hive partition. +/// they typically hold +/// - max value +/// - min value +/// - null_count +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ColumnStats { + field: Field, + // The array may hold the null count for every row group, + // or for a single row group. + null_count: Option, + min_value: Option, + max_value: Option, +} + +impl ColumnStats { + pub fn new( + field: Field, + null_count: Option, + min_value: Option, + max_value: Option, + ) -> Self { + Self { + field, + null_count, + min_value, + max_value, + } + } + + pub fn from_column_literal(s: Series) -> Self { + debug_assert_eq!(s.len(), 1); + Self { + field: s.field().into_owned(), + null_count: None, + min_value: Some(s.clone()), + max_value: Some(s), + } + } + + pub fn dtype(&self) -> &DataType { + self.field.data_type() + } + + pub fn null_count(&self) -> Option { + match self.field.data_type() { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => None, + _ => { + let s = self.null_count.as_ref()?; + // if all null, there are no statistics. 
+ if s.null_count() != s.len() { + s.sum() + } else { + None + } + }, + } + } + + pub fn to_min_max(&self) -> Option { + let max_val = self.max_value.as_ref()?; + let min_val = self.min_value.as_ref()?; + + let dtype = min_val.dtype(); + + if Self::use_min_max(dtype) { + let mut min_max_values = min_val.clone(); + min_max_values.append(max_val).unwrap(); + if min_max_values.null_count() > 0 { + None + } else { + Some(min_max_values) + } + } else { + None + } + } + + pub fn get_min_state(&self) -> Option<&Series> { + self.min_value.as_ref() + } + + pub fn to_min(&self) -> Option<&Series> { + let min_val = self.min_value.as_ref()?; + let dtype = min_val.dtype(); + + if !Self::use_min_max(dtype) || min_val.len() != 1 { + return None; + } + + if min_val.null_count() > 0 { + None + } else { + Some(min_val) + } + } + + pub fn to_max(&self) -> Option<&Series> { + let max_val = self.max_value.as_ref()?; + let dtype = max_val.dtype(); + + if !Self::use_min_max(dtype) || max_val.len() != 1 { + return None; + } + + if max_val.null_count() > 0 { + None + } else { + Some(max_val) + } + } + + fn use_min_max(dtype: &DataType) -> bool { + dtype.is_numeric() || matches!(dtype, DataType::Utf8 | DataType::Binary | DataType::Boolean) + } +} + +/// A collection of column stats with a known schema. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] +pub struct BatchStats { + schema: Schema, + stats: Vec, +} + +impl BatchStats { + pub fn new(schema: Schema, stats: Vec) -> Self { + Self { schema, stats } + } + + pub fn get_stats(&self, column: &str) -> polars_core::error::PolarsResult<&ColumnStats> { + self.schema.try_index_of(column).map(|i| &self.stats[i]) + } + + pub fn schema(&self) -> &Schema { + &self.schema + } + + pub fn column_stats(&self) -> &[ColumnStats] { + self.stats.as_ref() + } +} diff --git a/crates/polars-io/src/prelude.rs b/crates/polars-io/src/prelude.rs index f62a5e357187..2d1362c6970f 100644 --- a/crates/polars-io/src/prelude.rs +++ b/crates/polars-io/src/prelude.rs @@ -12,7 +12,8 @@ pub use crate::ndjson::core::*; #[cfg(feature = "parquet")] pub use crate::parquet::*; pub use crate::utils::*; -pub use crate::{SerReader, SerWriter}; +pub use crate::{cloud, SerReader, SerWriter}; + #[cfg(test)] pub(crate) fn create_df() -> DataFrame { let s0 = Series::new("days", [0, 1, 2, 3, 4].as_ref()); diff --git a/crates/polars-io/src/utils.rs b/crates/polars-io/src/utils.rs index ae15d7982152..2bf09380cd46 100644 --- a/crates/polars-io/src/utils.rs +++ b/crates/polars-io/src/utils.rs @@ -1,8 +1,12 @@ +use std::io::Read; use std::path::{Path, PathBuf}; +use once_cell::sync::Lazy; use polars_core::frame::DataFrame; use polars_core::prelude::*; +use regex::{Regex, RegexBuilder}; +use crate::mmap::{MmapBytesReader, ReaderBytes}; #[cfg(any( feature = "ipc", feature = "ipc_streaming", @@ -11,6 +15,37 @@ use polars_core::prelude::*; ))] use crate::ArrowSchema; +pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( + reader: &'a mut R, +) -> PolarsResult> { + // we have a file so we can mmap + if let Some(file) = reader.to_file() { + let mmap = unsafe { memmap::Mmap::map(file)? 
}; + + // somehow bck thinks borrows alias + // this is sound as file was already bound to 'a + use std::fs::File; + let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; + Ok(ReaderBytes::Mapped(mmap, file)) + } else { + // we can get the bytes for free + if reader.to_bytes().is_some() { + // duplicate .to_bytes() is necessary to satisfy the borrow checker + Ok(ReaderBytes::Borrowed((*reader).to_bytes().unwrap())) + } else { + // we have to read to an owned buffer to get the bytes. + let mut bytes = Vec::with_capacity(1024 * 128); + reader.read_to_end(&mut bytes)?; + if !bytes.is_empty() + && (bytes[bytes.len() - 1] != b'\n' || bytes[bytes.len() - 1] != b'\r') + { + bytes.push(b'\n') + } + Ok(ReaderBytes::Owned(bytes)) + } + } +} + // used by python polars pub fn resolve_homedir(path: &Path) -> PathBuf { // replace "~" with home directory @@ -118,6 +153,51 @@ pub(crate) fn overwrite_schema( Ok(()) } +pub static FLOAT_RE: Lazy = Lazy::new(|| { + Regex::new(r"^\s*[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap() +}); + +pub static INTEGER_RE: Lazy = Lazy::new(|| Regex::new(r"^\s*-?(\d+)$").unwrap()); + +pub static BOOLEAN_RE: Lazy = Lazy::new(|| { + RegexBuilder::new(r"^\s*(true)$|^(false)$") + .case_insensitive(true) + .build() + .unwrap() +}); + +pub fn materialize_projection( + with_columns: Option<&[String]>, + schema: &Schema, + hive_partitions: Option<&[Series]>, + has_row_count: bool, +) -> Option> { + match hive_partitions { + None => with_columns.map(|with_columns| { + with_columns + .iter() + .map(|name| schema.index_of(name).unwrap() - has_row_count as usize) + .collect() + }), + Some(part_cols) => { + with_columns.map(|with_columns| { + with_columns + .iter() + .flat_map(|name| { + // the hive partitions are added at the end of the schema, but we don't want to project + // them from the file + if part_cols.iter().any(|s| s.name() == name.as_str()) { + None + } else { + Some(schema.index_of(name).unwrap() - has_row_count as usize) + } + }) + .collect() + }) + }, + } +} + #[cfg(test)] mod tests { use std::path::PathBuf; diff --git a/crates/polars-json/Cargo.toml b/crates/polars-json/Cargo.toml index 8a9d0d53fb9a..583ade78928e 100644 --- a/crates/polars-json/Cargo.toml +++ b/crates/polars-json/Cargo.toml @@ -9,14 +9,18 @@ repository = { workspace = true } description = "JSON related logic for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-error = { workspace = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } +chrono = { workspace = true } fallible-streaming-iterator = { version = "0.1" } hashbrown = { workspace = true } indexmap = { workspace = true } +itoa = { workspace = true } num-traits = { workspace = true } +ryu = { workspace = true } simd-json = { workspace = true } +streaming-iterator = { workspace = true } diff --git a/crates/polars-json/README.md b/crates/polars-json/README.md new file mode 100644 index 000000000000..d0fa8cab5a8b --- /dev/null +++ b/crates/polars-json/README.md @@ -0,0 +1,5 @@ +# polars-json + +`polars-json` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, provides functionalities to handle JSON objects. 
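The `FLOAT_RE`, `INTEGER_RE`, and `BOOLEAN_RE` patterns added to `crates/polars-io/src/utils.rs` in the hunk above are used for dtype inference on text input. A small, hedged check of what the float pattern accepts; the pattern is copied verbatim from the hunk, while the test values are illustrative:

```rust
use once_cell::sync::Lazy;
use regex::Regex;

// Same pattern as FLOAT_RE in the hunk above.
static FLOAT_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"^\s*[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap()
});

fn main() {
    assert!(FLOAT_RE.is_match("3.14"));
    assert!(FLOAT_RE.is_match("-1e-3"));
    assert!(FLOAT_RE.is_match("inf"));
    // A bare integer is left to INTEGER_RE, so it is not a float match.
    assert!(!FLOAT_RE.is_match("42"));
}
```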
+ +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs index 3baa2ea0b61f..0ce3831a46eb 100644 --- a/crates/polars-json/src/json/deserialize.rs +++ b/crates/polars-json/src/json/deserialize.rs @@ -85,9 +85,18 @@ fn deserialize_list<'a, A: Borrow>>( inner.extend(value.iter()); validity.push(true); offsets - .try_push_usize(value.len()) + .try_push(value.len()) .expect("List offset is too large :/"); }, + BorrowedValue::Static(StaticNode::Null) => { + validity.push(false); + offsets.extend_constant(1) + }, + value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { + inner.push(value); + validity.push(true); + offsets.try_push(1).expect("List offset is too large :/"); + }, _ => { validity.push(false); offsets.extend_constant(1); diff --git a/crates/polars-json/src/json/mod.rs b/crates/polars-json/src/json/mod.rs index 1ab9c2dd15ce..d39d7513c431 100644 --- a/crates/polars-json/src/json/mod.rs +++ b/crates/polars-json/src/json/mod.rs @@ -5,3 +5,4 @@ pub use deserialize::deserialize; pub use infer_schema::{infer, infer_records_schema}; use polars_error::*; use polars_utils::aliases::*; +pub mod write; diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs new file mode 100644 index 000000000000..343bae73e520 --- /dev/null +++ b/crates/polars-json/src/json/write/mod.rs @@ -0,0 +1,157 @@ +//! APIs to write to JSON +mod serialize; +mod utf8; + +use std::io::Write; + +use arrow::array::Array; +use arrow::chunk::Chunk; +use arrow::datatypes::Schema; +use arrow::error::Error; +use arrow::io::iterator::StreamingIterator; +pub use fallible_streaming_iterator::*; +pub(crate) use serialize::new_serializer; +use serialize::serialize; + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid JSON +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. + pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// [`FallibleStreamingIterator`] that serializes a [`Chunk`] into bytes of JSON +/// in a (pandas-compatible) record-oriented format. +/// +/// # Implementation +/// Advancing this iterator is CPU-bounded. +pub struct RecordSerializer<'a> { + schema: Schema, + index: usize, + end: usize, + iterators: Vec + Send + Sync + 'a>>, + buffer: Vec, +} + +impl<'a> RecordSerializer<'a> { + /// Creates a new [`RecordSerializer`]. 
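The new `Serializer` and `write` entry points above implement JSON writing as a `FallibleStreamingIterator` over byte blocks. Here is a hedged, self-contained sketch of driving them; it assumes the workspace's `arrow` alias for arrow2 and the `polars_json::json::write` paths shown in this diff, and the expected output reflects how `write` wraps blocks in `[...]`.

```rust
use arrow::array::{Array, Int32Array};
use polars_json::json::write::{write, Serializer};

fn main() -> Result<(), arrow::error::Error> {
    // One block containing a single primitive array.
    let array: Box<dyn Array> = Box::new(Int32Array::from_slice([1, 2, 3]));
    let blocks = Serializer::new(std::iter::once(Ok(array)), vec![]);

    // `write` advances the streaming iterator and wraps the blocks in `[...]`.
    let mut out = Vec::new();
    write(&mut out, blocks)?;
    assert_eq!(String::from_utf8(out).unwrap(), "[1,2,3]");
    Ok(())
}
```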
+ pub fn new(schema: Schema, chunk: &'a Chunk, buffer: Vec) -> Self + where + A: AsRef, + { + let end = chunk.len(); + let iterators = chunk + .arrays() + .iter() + .map(|arr| new_serializer(arr.as_ref(), 0, usize::MAX)) + .collect(); + + Self { + schema, + index: 0, + end, + iterators, + buffer, + } + } +} + +impl<'a> FallibleStreamingIterator for RecordSerializer<'a> { + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + if self.index == self.end { + return Ok(()); + } + + let mut is_first_row = true; + write!(&mut self.buffer, "{{")?; + for (f, ref mut it) in self.schema.fields.iter().zip(self.iterators.iter_mut()) { + if !is_first_row { + write!(&mut self.buffer, ",")?; + } + write!(&mut self.buffer, "\"{}\":", f.name)?; + + self.buffer.extend_from_slice(it.next().unwrap()); + is_first_row = false; + } + write!(&mut self.buffer, "}}")?; + + self.index += 1; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// Writes valid JSON from an iterator of (assumed JSON-encoded) bytes to `writer` +pub fn write(writer: &mut W, mut blocks: I) -> Result<(), Error> +where + W: std::io::Write, + I: FallibleStreamingIterator, +{ + writer.write_all(&[b'['])?; + let mut is_first_row = true; + while let Some(block) = blocks.next()? { + if !is_first_row { + writer.write_all(&[b','])?; + } + is_first_row = false; + writer.write_all(block)?; + } + writer.write_all(&[b']'])?; + Ok(()) +} diff --git a/crates/polars-json/src/json/write/serialize.rs b/crates/polars-json/src/json/write/serialize.rs new file mode 100644 index 000000000000..7622d006e761 --- /dev/null +++ b/crates/polars-json/src/json/write/serialize.rs @@ -0,0 +1,522 @@ +use std::io::Write; + +use arrow::array::*; +use arrow::bitmap::utils::ZipValidity; +use arrow::datatypes::{DataType, IntegerType, TimeUnit}; +use arrow::io::iterator::BufStreamingIterator; +use arrow::offset::Offset; +#[cfg(feature = "chrono-tz")] +use arrow::temporal_conversions::parse_offset_tz; +use arrow::temporal_conversions::{ + date32_to_date, date64_to_date, duration_ms_to_duration, duration_ns_to_duration, + duration_s_to_duration, duration_us_to_duration, parse_offset, timestamp_ms_to_datetime, + timestamp_ns_to_datetime, timestamp_s_to_datetime, timestamp_to_datetime, + timestamp_us_to_datetime, +}; +use arrow::types::NativeType; +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use streaming_iterator::StreamingIterator; + +use super::utf8; + +fn write_integer(buf: &mut Vec, val: I) { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + buf.extend_from_slice(value.as_bytes()) +} + +fn write_float(f: &mut Vec, val: I) { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) +} + +fn materialize_serializer<'a, I, F, T>( + f: F, + iterator: I, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: 'a, + I: Iterator + Send + Sync + 'a, + F: FnMut(T, &mut Vec) + Send + Sync + 'a, +{ + if offset > 0 || take < usize::MAX { + Box::new(BufStreamingIterator::new( + iterator.skip(offset).take(take), + f, + vec![], + )) + } else { + Box::new(BufStreamingIterator::new(iterator, f, vec![])) + } +} + +fn boolean_serializer<'a>( + array: &'a BooleanArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option, buf: &mut Vec| match x { + Some(true) => buf.extend_from_slice(b"true"), + Some(false) => 
buf.extend_from_slice(b"false"), + None => buf.extend_from_slice(b"null"), + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn null_serializer( + len: usize, + offset: usize, + take: usize, +) -> Box + Send + Sync> { + let f = |_x: (), buf: &mut Vec| buf.extend_from_slice(b"null"); + materialize_serializer(f, std::iter::repeat(()).take(len), offset, take) +} + +fn primitive_serializer<'a, T: NativeType + itoa::Integer>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + write_integer(buf, *x) + } else { + buf.extend(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn float_serializer<'a, T>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: num_traits::Float + NativeType + ryu::Float, +{ + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + if T::is_nan(*x) || T::is_infinite(*x) { + buf.extend(b"null") + } else { + write_float(buf, *x) + } + } else { + buf.extend(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn dictionary_utf8_serializer<'a, K: DictionaryKey, O: Offset>( + array: &'a DictionaryArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let iter = array.iter_typed::>().unwrap().skip(offset); + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, iter, offset, take) +} + +fn utf8_serializer<'a, O: Offset>( + array: &'a Utf8Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn struct_serializer<'a>( + array: &'a StructArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // {"a": [1, 2, 3], "b": [a, b, c], "c": {"a": [1, 2, 3]}} + // [ + // {"a": 1, "b": a, "c": {"a": 1}}, + // {"a": 2, "b": b, "c": {"a": 2}}, + // {"a": 3, "b": c, "c": {"a": 3}}, + // ] + // + let mut serializers = array + .values() + .iter() + .map(|x| x.as_ref()) + .map(|arr| new_serializer(arr, offset, take)) + .collect::>(); + let names = array.fields().iter().map(|f| f.name.as_str()); + + Box::new(BufStreamingIterator::new( + ZipValidity::new_with_validity(0..array.len(), array.validity()), + move |maybe, buf| { + if maybe.is_some() { + let names = names.clone(); + let mut record: Vec<(&str, &[u8])> = Default::default(); + serializers + .iter_mut() + .zip(names) + // `unwrap` is infalible because `array.len()` equals `len` on `Chunk` + .for_each(|(iter, name)| { + let item = iter.next().unwrap(); + record.push((name, item)); + }); + serialize_item(buf, &record, true); + } else { + serializers.iter_mut().for_each(|iter| { + let _ = iter.next(); + }); + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn list_serializer<'a, O: Offset>( + array: &'a ListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // [[1, 2], [3]] + // [ + // [1, 2], + // [3] + // ] + // + let offsets = array.offsets().as_slice(); + let start = offsets[0].to_usize(); + let end = offsets.last().unwrap().to_usize(); + let mut serializer = new_serializer(array.values().as_ref(), start, end - start); + + let f = move |offset: Option<&[O]>, buf: &mut Vec| { + if 
let Some(offset) = offset { + let length = (offset[1] - offset[0]).to_usize(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + } else { + buf.extend(b"null"); + } + }; + + let iter = + ZipValidity::new_with_validity(array.offsets().buffer().windows(2), array.validity()); + materialize_serializer(f, iter, offset, take) +} + +fn fixed_size_list_serializer<'a>( + array: &'a FixedSizeListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let mut serializer = new_serializer(array.values().as_ref(), offset, take); + + Box::new(BufStreamingIterator::new( + ZipValidity::new(0..array.len(), array.validity().map(|x| x.iter())), + move |ix, buf| { + if ix.is_some() { + let length = array.size(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + } else { + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn date_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> NaiveDate + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let nd = convert(*x); + write!(buf, "\"{nd}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn duration_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> Duration + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let duration = convert(*x); + write!(buf, "\"{duration}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_serializer<'a, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + F: Fn(i64) -> NaiveDateTime + 'static + Send + Sync, +{ + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let ndt = convert(*x); + write!(buf, "\"{ndt}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_tz_serializer<'a>( + array: &'a PrimitiveArray, + time_unit: TimeUnit, + tz: &str, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match parse_offset(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + #[cfg(feature = "chrono-tz")] + _ => match parse_offset_tz(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + _ => { + panic!("Timezone {} is invalid or not supported", tz); + }, + }, + #[cfg(not(feature = "chrono-tz"))] + _ => 
{ + panic!("Invalid Offset format (must be [-]00:00) or chrono-tz feature not active"); + }, + } +} + +pub(crate) fn new_serializer<'a>( + array: &'a dyn Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match array.data_type().to_logical_type() { + DataType::Boolean => { + boolean_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Float32 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Float64 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Utf8 => { + utf8_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::LargeUtf8 => { + utf8_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Struct(_) => { + struct_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::FixedSizeList(_, _) => { + fixed_size_list_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::List(_) => { + list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::LargeList(_) => { + list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + other @ DataType::Dictionary(k, v, _) => match (k, &**v) { + (IntegerType::UInt32, DataType::LargeUtf8) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + dictionary_utf8_serializer::(array, offset, take) + }, + _ => { + todo!("Writing {:?} to JSON", other) + }, + }, + DataType::Date32 => date_serializer( + array.as_any().downcast_ref().unwrap(), + date32_to_date, + offset, + take, + ), + DataType::Date64 => date_serializer( + array.as_any().downcast_ref().unwrap(), + date64_to_date, + offset, + take, + ), + DataType::Timestamp(tu, None) => { + let convert = match tu { + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + TimeUnit::Second => timestamp_s_to_datetime, + }; + timestamp_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + DataType::Timestamp(time_unit, Some(tz)) => timestamp_tz_serializer( + array.as_any().downcast_ref().unwrap(), + *time_unit, + tz, + offset, + take, + ), + DataType::Duration(tu) => { + let convert = match tu { + TimeUnit::Nanosecond => duration_ns_to_duration, + TimeUnit::Microsecond => duration_us_to_duration, + TimeUnit::Millisecond => duration_ms_to_duration, + TimeUnit::Second => duration_s_to_duration, + }; + duration_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + DataType::Null => 
null_serializer(array.len(), offset, take), + other => todo!("Writing {:?} to JSON", other), + } +} + +fn serialize_item(buffer: &mut Vec, record: &[(&str, &[u8])], is_first_row: bool) { + if !is_first_row { + buffer.push(b','); + } + buffer.push(b'{'); + let mut first_item = true; + for (key, value) in record { + if !first_item { + buffer.push(b','); + } + first_item = false; + utf8::write_str(buffer, key).unwrap(); + buffer.push(b':'); + buffer.extend(*value); + } + buffer.push(b'}'); +} + +/// Serializes `array` to a valid JSON to `buffer` +/// # Implementation +/// This operation is CPU-bounded +pub(crate) fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + + (0..array.len()).for_each(|i| { + if i != 0 { + buffer.push(b','); + } + buffer.extend_from_slice(serializer.next().unwrap()); + }); +} diff --git a/crates/polars-json/src/json/write/utf8.rs b/crates/polars-json/src/json/write/utf8.rs new file mode 100644 index 000000000000..941d73379c3d --- /dev/null +++ b/crates/polars-json/src/json/write/utf8.rs @@ -0,0 +1,138 @@ +// Adapted from https://github.com/serde-rs/json/blob/f901012df66811354cb1d490ad59480d8fdf77b5/src/ser.rs +use std::io; + +pub fn write_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: io::Write, +{ + writer.write_all(b"\"")?; + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + writer.write_all(&bytes[start..i])?; + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + write_char_escape(writer, char_escape)?; + + start = i + 1; + } + + if start != bytes.len() { + writer.write_all(&bytes[start..])?; + } + writer.write_all(b"\"") +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. +static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +/// Represents a character escape code in a type-safe manner. 
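The `write_str` helper and its 256-entry `ESCAPE` table above implement JSON string escaping so that unescaped runs can be copied in bulk. A simplified, self-contained sketch of the same escape set follows; it is not the Polars code and escapes per character rather than per run:

```rust
// Minimal JSON string escaping covering the same cases as the table above:
// quotes, backslashes, the common control characters, and \u00XX for the
// remaining ASCII control bytes.
fn escape_json_str(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\u{08}' => out.push_str("\\b"),
            '\t' => out.push_str("\\t"),
            '\n' => out.push_str("\\n"),
            '\u{0C}' => out.push_str("\\f"),
            '\r' => out.push_str("\\r"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

fn main() {
    assert_eq!(escape_json_str("a\"b\n"), "\"a\\\"b\\n\"");
}
```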
+pub enum CharEscape { + /// An escaped quote `"` + Quote, + /// An escaped reverse solidus `\` + ReverseSolidus, + // An escaped solidus `/` + //Solidus, + /// An escaped backspace character (usually escaped as `\b`) + Backspace, + /// An escaped form feed character (usually escaped as `\f`) + FormFeed, + /// An escaped line feed character (usually escaped as `\n`) + LineFeed, + /// An escaped carriage return character (usually escaped as `\r`) + CarriageReturn, + /// An escaped tab character (usually escaped as `\t`) + Tab, + /// An escaped ASCII plane control character (usually escaped as + /// `\u00XX` where `XX` are two hex characters) + AsciiControl(u8), +} + +impl CharEscape { + #[inline] + fn from_escape_table(escape: u8, byte: u8) -> CharEscape { + match escape { + self::BB => CharEscape::Backspace, + self::TT => CharEscape::Tab, + self::NN => CharEscape::LineFeed, + self::FF => CharEscape::FormFeed, + self::RR => CharEscape::CarriageReturn, + self::QU => CharEscape::Quote, + self::BS => CharEscape::ReverseSolidus, + self::UU => CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +#[inline] +fn write_char_escape(writer: &mut W, char_escape: CharEscape) -> io::Result<()> +where + W: io::Write, +{ + use self::CharEscape::*; + + let s = match char_escape { + Quote => b"\\\"", + ReverseSolidus => b"\\\\", + //Solidus => b"\\/", + Backspace => b"\\b", + FormFeed => b"\\f", + LineFeed => b"\\n", + CarriageReturn => b"\\r", + Tab => b"\\t", + AsciiControl(byte) => { + static HEX_DIGITS: [u8; 16] = *b"0123456789abcdef"; + let bytes = &[ + b'\\', + b'u', + b'0', + b'0', + HEX_DIGITS[(byte >> 4) as usize], + HEX_DIGITS[(byte & 0xF) as usize], + ]; + return writer.write_all(bytes); + }, + }; + + writer.write_all(s) +} diff --git a/crates/polars-json/src/ndjson/mod.rs b/crates/polars-json/src/ndjson/mod.rs index 429b1096b1ae..2076715e711f 100644 --- a/crates/polars-json/src/ndjson/mod.rs +++ b/crates/polars-json/src/ndjson/mod.rs @@ -3,5 +3,6 @@ use polars_arrow::prelude::*; use polars_error::*; pub mod deserialize; mod file; +pub mod write; pub use file::{infer, infer_iter}; diff --git a/crates/polars-json/src/ndjson/write.rs b/crates/polars-json/src/ndjson/write.rs new file mode 100644 index 000000000000..5cbda120711f --- /dev/null +++ b/crates/polars-json/src/ndjson/write.rs @@ -0,0 +1,118 @@ +//! APIs to serialize and write to [NDJSON](http://ndjson.org/). +use std::io::Write; + +use arrow::array::Array; +use arrow::error::Error; +pub use fallible_streaming_iterator::FallibleStreamingIterator; + +use super::super::json::write::new_serializer; + +fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + (0..array.len()).for_each(|_| { + buffer.extend_from_slice(serializer.next().unwrap()); + buffer.push(b'\n'); + }); +} + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid NDJSON +/// where every line is an element of the array. +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. 
+ pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// An iterator adapter that receives an implementer of [`Write`] and +/// an implementer of [`FallibleStreamingIterator`] (such as [`Serializer`]) +/// and writes a valid NDJSON +/// # Implementation +/// Advancing this iterator mixes CPU-bounded (serializing arrays) tasks and IO-bounded (write to the writer). +pub struct FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + writer: W, + iterator: I, +} + +impl FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + /// Creates a new [`FileWriter`]. + pub fn new(writer: W, iterator: I) -> Self { + Self { writer, iterator } + } + + /// Returns the inner content of this iterator + /// + /// There are two use-cases for this function: + /// * to continue writing to its writer + /// * to re-use an internal buffer of its iterator + pub fn into_inner(self) -> (W, I) { + (self.writer, self.iterator) + } +} + +impl Iterator for FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + type Item = Result<(), Error>; + + fn next(&mut self) -> Option { + let item = self.iterator.next().transpose()?; + Some(item.and_then(|x| { + self.writer.write_all(x)?; + Ok(()) + })) + } +} diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 31414d8cd7f3..92d7a4f9913e 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -9,15 +9,15 @@ repository = { workspace = true } description = "Lazy query engine for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = ["lazy", "csv"], default-features = false } -polars-json = { version = "0.32.0", path = "../polars-json", optional = true } -polars-ops = { version = "0.32.0", path = "../polars-ops", default-features = false } -polars-pipe = { version = "0.32.0", path = "../polars-pipe", optional = true } -polars-plan = { version = "0.32.0", path = "../polars-plan" } -polars-time = { version = "0.32.0", path = "../polars-time", optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-io = { workspace = true, features = ["lazy"] } +polars-json = { workspace = true, optional = true } +polars-ops = { workspace = true } +polars-pipe = { workspace = true, optional = true } +polars-plan = { workspace = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } bitflags = { workspace = true } @@ -26,6 +26,7 @@ once_cell = { workspace = true } pyo3 = { workspace = true, optional = true } rayon = { workspace = true } smartstring = { workspace = true } +tokio = { workspace = true, optional = true } 
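The NDJSON `Serializer` and `FileWriter` added above mirror the JSON writer but emit one line per array element. A hedged wiring sketch, under the same assumptions as the earlier JSON example about the `arrow` alias and crate paths:

```rust
use arrow::array::{Array, Int32Array};
use polars_json::ndjson::write::{FileWriter, Serializer};

fn main() -> Result<(), arrow::error::Error> {
    let array: Box<dyn Array> = Box::new(Int32Array::from_slice([1, 2]));
    let serializer = Serializer::new(std::iter::once(Ok(array)), vec![]);

    // FileWriter is an Iterator: each step serializes one array and writes its
    // rows as newline-delimited JSON into the inner writer.
    let mut out = Vec::new();
    let writer = FileWriter::new(&mut out, serializer);
    writer.collect::<Result<(), _>>()?;

    assert_eq!(String::from_utf8(out).unwrap(), "1\n2\n");
    Ok(())
}
```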
[dev-dependencies] serde_json = { workspace = true } @@ -34,36 +35,35 @@ serde_json = { workspace = true } version_check = { workspace = true } [features] -nightly = ["polars-core/nightly", "polars-pipe/nightly", "polars-plan/nightly"] -compile = ["polars-plan/compile"] -streaming = ["chunked_ids", "polars-pipe/compile", "polars-plan/streaming"] -default = ["compile"] -parquet = ["polars-core/parquet", "polars-io/parquet", "polars-plan/parquet", "polars-pipe/parquet"] +nightly = ["polars-core/nightly", "polars-pipe?/nightly", "polars-plan/nightly"] +streaming = ["chunked_ids", "polars-pipe", "polars-plan/streaming", "polars-ops/chunked_ids"] +parquet = ["polars-core/parquet", "polars-io/parquet", "polars-plan/parquet", "polars-pipe?/parquet"] async = [ "polars-plan/async", "polars-io/cloud", - "polars-pipe/async", - "streaming", + "polars-pipe?/async", ] -ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe/ipc"] +cloud = ["async", "polars-pipe?/cloud", "polars-plan/cloud", "tokio"] +cloud_write = ["cloud"] +ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe?/ipc"] json = ["polars-io/json", "polars-plan/json", "polars-json"] -csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe/csv"] +csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe?/csv"] temporal = ["dtype-datetime", "dtype-date", "dtype-time", "dtype-duration", "polars-plan/temporal"] # debugging purposes fmt = ["polars-core/fmt", "polars-plan/fmt"] strings = ["polars-plan/strings"] future = [] -dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe/dtype-u8"] -dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe/dtype-u16"] -dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe/dtype-i8"] -dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe/dtype-i16"] -dtype-decimal = ["polars-plan/dtype-decimal", "polars-pipe/dtype-decimal"] +dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe?/dtype-u8"] +dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe?/dtype-u16"] +dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe?/dtype-i8"] +dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe?/dtype-i16"] +dtype-decimal = ["polars-plan/dtype-decimal", "polars-pipe?/dtype-decimal"] dtype-date = ["polars-plan/dtype-date", "polars-time/dtype-date", "temporal"] dtype-datetime = ["polars-plan/dtype-datetime", "polars-time/dtype-datetime", "temporal"] dtype-duration = ["polars-plan/dtype-duration", "polars-time/dtype-duration", "temporal"] dtype-time = ["polars-core/dtype-time", "temporal"] -dtype-array = ["polars-plan/dtype-array", "polars-pipe/dtype-array", "polars-ops/dtype-array"] -dtype-categorical = ["polars-plan/dtype-categorical", "polars-pipe/dtype-categorical"] +dtype-array = ["polars-plan/dtype-array", "polars-pipe?/dtype-array", "polars-ops/dtype-array"] +dtype-categorical = ["polars-plan/dtype-categorical", "polars-pipe?/dtype-categorical"] dtype-struct = ["polars-plan/dtype-struct"] object = ["polars-plan/object"] date_offset = ["polars-plan/date_offset"] @@ -81,10 +81,11 @@ approx_unique = ["polars-plan/approx_unique"] is_in = ["polars-plan/is_in", "polars-ops/is_in"] repeat_by = ["polars-plan/repeat_by"] round_series = ["polars-plan/round_series", "polars-ops/round_series"] -is_first = ["polars-plan/is_first"] +is_first_distinct = ["polars-plan/is_first_distinct"] +is_last_distinct = ["polars-plan/is_last_distinct"] is_unique = ["polars-plan/is_unique"] -cross_join = ["polars-plan/cross_join", "polars-pipe/cross_join", "polars-ops/cross_join"] -asof_join = ["polars-plan/asof_join", "polars-time"] +cross_join = ["polars-plan/cross_join", 
"polars-pipe?/cross_join", "polars-ops/cross_join"] +asof_join = ["polars-plan/asof_join", "polars-time", "polars-ops/asof_join"] concat_str = ["polars-plan/concat_str"] range = ["polars-plan/range"] mode = ["polars-plan/mode"] @@ -108,7 +109,7 @@ unique_counts = ["polars-plan/unique_counts"] log = ["polars-plan/log"] list_eval = [] cumulative_eval = [] -chunked_ids = ["polars-plan/chunked_ids", "polars-core/chunked_ids"] +chunked_ids = ["polars-plan/chunked_ids", "polars-core/chunked_ids", "polars-ops/chunked_ids"] list_to_struct = ["polars-plan/list_to_struct"] python = ["pyo3", "polars-plan/python", "polars-core/python", "polars-io/python"] row_hash = ["polars-plan/row_hash"] @@ -129,16 +130,18 @@ serde = [ "polars-plan/serde", "polars-arrow/serde", "polars-core/serde-lazy", - "polars-time/serde", + "polars-time?/serde", "polars-io/serde", "polars-ops/serde", ] fused = ["polars-plan/fused", "polars-ops/fused"] list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] +list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"] cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] rle = ["polars-plan/rle", "polars-ops/rle"] extract_groups = ["polars-plan/extract_groups"] +peaks = ["polars-plan/peaks"] binary_encoding = ["polars-plan/binary_encoding"] diff --git a/crates/polars-lazy/README.md b/crates/polars-lazy/README.md index abfbaed58b21..81f5cb01d220 100644 --- a/crates/polars-lazy/README.md +++ b/crates/polars-lazy/README.md @@ -1,7 +1,5 @@ # polars-lazy -`polars-lazy` is a lazy query engine for the Polars DataFrame library. It allows you to perform operations on DataFrames in a lazy manner, only executing them when necessary. This can lead to significant performance improvements for large datasets. +`polars-lazy` serves as the lazy query engine for the [Polars](https://crates.io/crates/polars) DataFrame library. It allows you to perform operations on DataFrames in a lazy manner, only executing them when necessary. This can lead to significant performance improvements for large datasets. -## Features - -Please refer to the parent `polars` crate for a comprehensive list of features. +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index d8eeb8fde8c1..95dbf6b5f97d 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -7,7 +7,7 @@ use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub(crate) fn eval_field_to_dtype(f: &Field, expr: &Expr, list: bool) -> Field { - // dummy df to determine output dtype + // Dummy df to determine output dtype. let dtype = f .data_type() .inner_dtype() @@ -41,7 +41,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { /// Run an expression over a sliding window that increases `1` slot every iteration. /// /// # Warning - /// this can be really slow as it can have `O(n^2)` complexity. Don't use this for operations + /// This can be really slow as it can have `O(n^2)` complexity. Don't use this for operations /// that visit all elements. 
fn cumulative_eval(self, expr: Expr, min_periods: usize, parallel: bool) -> Expr { let this = self.into_expr(); @@ -50,7 +50,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { let name = s.name().to_string(); s.rename(""); - // ensure we get the new schema + // Ensure we get the new schema. let output_field = eval_field_to_dtype(s.field().as_ref(), &expr, false); let expr = expr.clone(); diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index bc4247ce71ba..a62aa3b29448 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -118,12 +118,11 @@ pub(crate) fn concat_impl>( #[cfg(feature = "diagonal_concat")] /// Concat [LazyFrame]s diagonally. /// Calls [`concat`][concat()] internally. -pub fn diag_concat_lf>( - lfs: L, - rechunk: bool, - parallel: bool, +pub fn concat_lf_diagonal>( + inputs: L, + args: UnionArgs, ) -> PolarsResult { - let lfs = lfs.as_ref().to_vec(); + let lfs = inputs.as_ref(); let schemas = lfs .iter() .map(|lf| lf.schema()) @@ -143,12 +142,12 @@ pub fn diag_concat_lf>( } }); } - let lfs_with_all_columns = lfs - .into_iter() + .iter() // Zip Frames with their Schemas .zip(schemas) - .map(|(mut lf, lf_schema)| { + .map(|(lf, lf_schema)| { + let mut lf = lf.clone(); for (name, dtype) in total_schema.iter() { // If a name from Total Schema is not present - append if lf_schema.get_field(name).is_none() { @@ -163,19 +162,11 @@ pub fn diag_concat_lf>( .map(|col_name| col(col_name)) .collect::>(), ); - Ok(reordered_lf) }) .collect::>>()?; - concat( - lfs_with_all_columns, - UnionArgs { - rechunk, - parallel, - to_supertypes: false, - }, - ) + concat(lfs_with_all_columns, args) } #[derive(Clone, Copy)] @@ -195,7 +186,7 @@ impl Default for UnionArgs { } } -/// Concat multiple +/// Concat multiple [`LazyFrame`]s vertically. pub fn concat>(inputs: L, args: UnionArgs) -> PolarsResult { concat_impl( inputs, @@ -206,7 +197,7 @@ pub fn concat>(inputs: L, args: UnionArgs) -> PolarsResult ) } -/// Collect all `LazyFrame` computations. +/// Collect all [`LazyFrame`] computations. pub fn collect_all(lfs: I) -> PolarsResult> where I: IntoParallelIterator, @@ -241,7 +232,15 @@ mod test { "d" => [1, 2] ]?; - let out = diag_concat_lf(&[a.lazy(), b.lazy(), c.lazy()], false, false)?.collect()?; + let out = concat_lf_diagonal( + &[a.lazy(), b.lazy(), c.lazy()], + UnionArgs { + rechunk: false, + parallel: false, + ..Default::default() + }, + )? 
+ .collect()?; let expected = df![ "a" => [Some(1), Some(2), None, None, Some(5), Some(7)], diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index 5e2b851c0904..82a9d2e35540 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -24,25 +24,24 @@ fn offsets_to_groups(offsets: &[i64]) -> Option { let mut start = offsets[0]; let end = *offsets.last().unwrap(); let fits_into_idx = (end - start) <= IdxSize::MAX as i64; + if !fits_into_idx { + return None; + } - if fits_into_idx { - let groups = offsets - .iter() - .skip(1) - .map(|end| { - let offset = start as IdxSize; - let len = (*end - start) as IdxSize; - start = *end; - [offset, len] - }) - .collect(); - Some(GroupsProxy::Slice { - groups, - rolling: false, + let groups = offsets + .iter() + .skip(1) + .map(|end| { + let offset = start as IdxSize; + let len = (*end - start) as IdxSize; + start = *end; + [offset, len] }) - } else { - None - } + .collect(); + Some(GroupsProxy::Slice { + groups, + rolling: false, + }) } fn run_per_sublist( @@ -75,7 +74,7 @@ fn run_per_sublist( }) }) .collect(); - err = m_err.lock().unwrap().take(); + err = m_err.into_inner().unwrap(); ca } else { let mut df_container = DataFrame::new_no_checks(vec![]); @@ -119,10 +118,10 @@ fn run_on_group_by_engine( let arr = lst.downcast_iter().next().unwrap(); let groups = offsets_to_groups(arr.offsets()).unwrap(); - // list elements in a series + // List elements in a series. let values = Series::try_from(("", arr.values().clone())).unwrap(); let inner_dtype = lst.inner_dtype(); - // ensure we use the logical type + // Ensure we use the logical type. let values = values.cast(&inner_dtype).unwrap(); let df_context = DataFrame::new_no_checks(vec![values]); @@ -130,15 +129,14 @@ fn run_on_group_by_engine( let state = ExecutionState::new(); let mut ac = phys_expr.evaluate_on_groups(&df_context, &groups, &state)?; - let mut out = match ac.agg_state() { + let out = match ac.agg_state() { AggState::AggregatedFlat(_) | AggState::Literal(_) => { let out = ac.aggregated(); out.as_list().into_series() }, _ => ac.aggregated(), }; - out.rename(name); - Ok(Some(out)) + Ok(Some(out.with_name(name))) } pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { @@ -182,15 +180,10 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { } let fits_idx_size = lst.get_values_size() <= (IdxSize::MAX as usize); - // if a users passes a return type to `apply` - // e.g. `return_dtype=pl.Int64` - // this fails as the list builder expects `List` - // so let's skip that for now + // If a users passes a return type to `apply`, e.g. `return_dtype=pl.Int64`, + // this fails as the list builder expects `List`, so let's skip that for now. let is_user_apply = || { - expr.into_iter().any(|e| match e { - Expr::AnonymousFunction { options, .. } => options.fmt_str == MAP_LIST_NAME, - _ => false, - }) + expr.into_iter().any(|e| matches!(e, Expr::AnonymousFunction { options, .. } if options.fmt_str == MAP_LIST_NAME)) }; if fits_idx_size && s.null_count() == 0 && !is_user_apply() { diff --git a/crates/polars-lazy/src/dsl/mod.rs b/crates/polars-lazy/src/dsl/mod.rs index 95d475e61ced..1093697c5644 100644 --- a/crates/polars-lazy/src/dsl/mod.rs +++ b/crates/polars-lazy/src/dsl/mod.rs @@ -2,15 +2,17 @@ //! //! This DSL revolves around the [`Expr`] type, which represents an abstract //! operation on a DataFrame, such as mapping over a column, filtering, group_by, or aggregation. -//! 
In general, functions on [`LazyFrame`](crate::frame::LazyFrame)s consume the LazyFrame and produce a new LazyFrame representing +//! In general, functions on [`LazyFrame`]s consume the [`LazyFrame`] and produce a new [`LazyFrame`] representing //! the result of applying the function and passed expressions to the consumed LazyFrame. //! At runtime, when [`LazyFrame::collect`](crate::frame::LazyFrame::collect) is called, the expressions that comprise -//! the LazyFrame's logical plan are materialized on the actual underlying Series. +//! the [`LazyFrame`]'s logical plan are materialized on the actual underlying Series. //! For instance, `let expr = col("x").pow(lit(2)).alias("x2");` would produce an expression representing the abstract //! operation of squaring the column `"x"` and naming the resulting column `"x2"`, and to apply this operation to a -//! LazyFrame, you'd use `let lazy_df = lazy_df.with_column(expr);`. +//! [`LazyFrame`], you'd use `let lazy_df = lazy_df.with_column(expr);`. //! (Of course, a column named `"x"` must either exist in the original DataFrame or be produced by one of the preceding -//! operations on the LazyFrame.) +//! operations on the [`LazyFrame`].) +//! +//! [`LazyFrame`]: crate::frame::LazyFrame //! //! There are many, many free functions that this module exports that produce an [`Expr`] from scratch; [`col`] and //! [`lit`] are two examples. @@ -28,7 +30,7 @@ //! that will yield an `f64` column (instead of `bool`), or `col("string") - col("f64")`, which would attempt //! to subtract an `f64` Series from a `string` Series. //! These kinds of invalid operations will only yield an error at runtime, when -//! [`collect`](crate::frame::LazyFrame::collect) is called on the LazyFrame. +//! [`collect`](crate::frame::LazyFrame::collect) is called on the [`LazyFrame`]. #[cfg(any(feature = "cumulative_eval", feature = "list_eval"))] mod eval; diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index dea0a304668b..9800fa38aeb1 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1,18 +1,8 @@ -//! Lazy variant of a [DataFrame](polars_core::frame::DataFrame). -#[cfg(feature = "csv")] -mod csv; -#[cfg(feature = "ipc")] -mod ipc; -#[cfg(feature = "json")] -mod ndjson; -#[cfg(feature = "parquet")] -mod parquet; +//! Lazy variant of a [DataFrame]. 
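As a compact illustration of the flow the `dsl` module docs above describe (not part of this diff): build an expression, attach it to a `LazyFrame`, and materialize it with `collect`. The input frame and the column name `"x"` are placeholders; the polars-lazy prelude is assumed in scope.

```rust
fn square_x(df: DataFrame) -> PolarsResult<DataFrame> {
    // Abstract operation: square column "x" and name the result "x2".
    let x2 = col("x").pow(lit(2)).alias("x2");
    // Nothing runs until collect() materializes the logical plan.
    df.lazy().with_column(x2).collect()
}
```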
#[cfg(feature = "python")] mod python; -mod anonymous_scan; mod err; -mod file_list_reader; #[cfg(feature = "pivot")] pub mod pivot; @@ -33,9 +23,9 @@ pub use ndjson::*; pub use parquet::*; use polars_arrow::prelude::QuantileInterpolOptions; use polars_core::frame::explode::MeltArgs; -use polars_core::frame::hash_join::{JoinType, JoinValidation}; use polars_core::prelude::*; use polars_io::RowCount; +use polars_plan::dsl::all_horizontal; pub use polars_plan::frame::{AllowedOptimizations, OptState}; use polars_plan::global::FETCH_ROWS; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] @@ -46,7 +36,7 @@ use smartstring::alias::String as SmartString; use crate::fallible; use crate::physical_plan::executors::Executor; -use crate::physical_plan::planner::create_physical_plan; +use crate::physical_plan::planner::{create_physical_expr, create_physical_plan}; use crate::physical_plan::state::ExecutionState; #[cfg(feature = "streaming")] use crate::physical_plan::streaming::insert_streaming_nodes; @@ -140,6 +130,8 @@ impl LazyFrame { #[cfg(feature = "cse")] comm_subexpr_elim: false, streaming: false, + eager: false, + fast_projection: false, }) } @@ -193,6 +185,11 @@ impl LazyFrame { self } + pub fn _with_eager(mut self, toggle: bool) -> Self { + self.opt_state.eager = toggle; + self + } + /// Return a String describing the naive (un-optimized) logical plan. pub fn describe_plan(&self) -> String { self.logical_plan.describe() @@ -348,7 +345,7 @@ impl LazyFrame { /// } /// ``` pub fn reverse(self) -> Self { - self.select_local(vec![col("*").reverse()]) + self.select(vec![col("*").reverse()]) } /// Check the if the `names` are available in the `schema`, if not @@ -431,22 +428,14 @@ impl LazyFrame { I: IntoIterator, T: AsRef, { - let columns: Vec = columns + let to_drop = columns .into_iter() - .map(|name| name.as_ref().into()) - .collect(); - self.drop_columns_impl(columns) - } + .map(|s| s.as_ref().to_string()) + .collect::>(); - #[allow(clippy::ptr_arg)] - fn drop_columns_impl(self, columns: Vec) -> Self { - if let Some(lp) = self.check_names(&columns, None) { - lp - } else { - self.map_private(FunctionNode::Drop { - names: columns.into(), - }) - } + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().drop_columns(to_drop).build(); + Self::from_logical_plan(lp, opt_state) } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -454,7 +443,7 @@ impl LazyFrame { /// /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. pub fn shift(self, periods: i64) -> Self { - self.select_local(vec![col("*").shift(periods)]) + self.select(vec![col("*").shift(periods)]) } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -462,7 +451,7 @@ impl LazyFrame { /// /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. pub fn shift_and_fill>(self, periods: i64, fill_value: E) -> Self { - self.select_local(vec![col("*").shift_and_fill(periods, fill_value.into())]) + self.select(vec![col("*").shift_and_fill(periods, fill_value.into())]) } /// Fill None values in the DataFrame with an expression. 
@@ -562,7 +551,25 @@ impl LazyFrame { ); opt_state.comm_subplan_elim = false; } - let lp_top = optimize(self.logical_plan, opt_state, lp_arena, expr_arena, scratch)?; + let lp_top = optimize( + self.logical_plan, + opt_state, + lp_arena, + expr_arena, + scratch, + Some(&|node, expr_arena| { + let phys_expr = create_physical_expr( + node, + Context::Default, + expr_arena, + None, + &mut Default::default(), + ) + .ok()?; + let io_expr = phys_expr_to_io_expr(phys_expr); + Some(io_expr) + }), + )?; if streaming { #[cfg(feature = "streaming")] @@ -604,9 +611,9 @@ impl LazyFrame { None }; - // file sink should be replaced + // sink should be replaced let no_file_sink = if check_sink { - !matches!(lp_arena.get(lp_top), ALogicalPlan::FileSink { .. }) + !matches!(lp_arena.get(lp_top), ALogicalPlan::Sink { .. }) } else { true }; @@ -665,9 +672,9 @@ impl LazyFrame { #[cfg(feature = "parquet")] pub fn sink_parquet(mut self, path: PathBuf, options: ParquetWriteOptions) -> PolarsResult<()> { self.opt_state.streaming = true; - self.logical_plan = LogicalPlan::FileSink { + self.logical_plan = LogicalPlan::Sink { input: Box::new(self.logical_plan), - payload: FileSinkOptions { + payload: SinkType::File { path: Arc::new(path), file_type: FileType::Parquet(options), }, @@ -682,15 +689,44 @@ impl LazyFrame { Ok(()) } + /// Stream a query result into a parquet file on an ObjectStore-compatible cloud service. This is useful if the final result doesn't fit + /// into memory, and where you do not want to write to a local file but to a location in the cloud. + /// This method will return an error if the query cannot be completely done in a + /// streaming fashion. + #[cfg(all(feature = "cloud_write", feature = "parquet"))] + pub fn sink_parquet_cloud( + mut self, + uri: String, + cloud_options: Option, + parquet_options: ParquetWriteOptions, + ) -> PolarsResult<()> { + self.opt_state.streaming = true; + self.logical_plan = LogicalPlan::Sink { + input: Box::new(self.logical_plan), + payload: SinkType::Cloud { + uri: Arc::new(uri), + cloud_options, + file_type: FileType::Parquet(parquet_options), + }, + }; + let (mut state, mut physical_plan, is_streaming) = self.prepare_collect(true)?; + polars_ensure!( + is_streaming, + ComputeError: "cannot run the whole query in a streaming order" + ); + let _ = physical_plan.execute(&mut state)?; + Ok(()) + } + /// Stream a query result into an ipc/arrow file. This is useful if the final result doesn't fit /// into memory. This methods will return an error if the query cannot be completely done in a /// streaming fashion. #[cfg(feature = "ipc")] pub fn sink_ipc(mut self, path: PathBuf, options: IpcWriterOptions) -> PolarsResult<()> { self.opt_state.streaming = true; - self.logical_plan = LogicalPlan::FileSink { + self.logical_plan = LogicalPlan::Sink { input: Box::new(self.logical_plan), - payload: FileSinkOptions { + payload: SinkType::File { path: Arc::new(path), file_type: FileType::Ipc(options), }, @@ -711,9 +747,9 @@ impl LazyFrame { #[cfg(feature = "csv")] pub fn sink_csv(mut self, path: PathBuf, options: CsvWriterOptions) -> PolarsResult<()> { self.opt_state.streaming = true; - self.logical_plan = LogicalPlan::FileSink { + self.logical_plan = LogicalPlan::Sink { input: Box::new(self.logical_plan), - payload: FileSinkOptions { + payload: SinkType::File { path: Arc::new(path), file_type: FileType::Csv(options), }, @@ -752,7 +788,7 @@ impl LazyFrame { /// Select (and optionally rename, with [`alias`](crate::dsl::Expr::alias)) columns from the query. 
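A hedged sketch of the streaming sinks added above (not part of this diff): stream a larger-than-memory query result straight to a parquet file. `ParquetWriteOptions::default()` is assumed to exist and the path is a placeholder; the call errors if the plan cannot run fully as a streaming query.

```rust
fn sink_to_disk(lf: LazyFrame) -> PolarsResult<()> {
    // Enables streaming mode internally and writes chunks as they are produced.
    lf.sink_parquet(
        std::path::PathBuf::from("result.parquet"),
        ParquetWriteOptions::default(),
    )
}
```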
/// - /// Columns can be selected with [`col`](crate::dsl::col); + /// Columns can be selected with [`col`]; /// If you want to select all columns use `col("*")`. /// /// # Example @@ -777,7 +813,13 @@ impl LazyFrame { /// ``` pub fn select>(self, exprs: E) -> Self { let exprs = exprs.as_ref().to_vec(); - self.select_impl(exprs, ProjectionOptions { run_parallel: true }) + self.select_impl( + exprs, + ProjectionOptions { + run_parallel: true, + duplicate_check: true, + }, + ) } pub fn select_seq>(self, exprs: E) -> Self { @@ -786,6 +828,7 @@ impl LazyFrame { exprs, ProjectionOptions { run_parallel: false, + duplicate_check: true, }, ) } @@ -796,14 +839,6 @@ impl LazyFrame { Self::from_logical_plan(lp, opt_state) } - /// A projection that doesn't get optimized and may drop projections if they are not in - /// schema after optimization - fn select_local(self, exprs: Vec) -> Self { - let opt_state = self.get_opt_state(); - let lp = self.get_plan_builder().project_local(exprs).build(); - Self::from_logical_plan(lp, opt_state) - } - /// Performs a "group-by" on a `LazyFrame`, producing a [`LazyGroupBy`], which can subsequently be aggregated. /// /// Takes a list of expressions to group on. @@ -966,6 +1001,38 @@ impl LazyFrame { } } + /// Left anti join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn anti_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .anti_join(other, col("foo"), col("bar").cast(DataType::Utf8)) + /// } + /// ``` + #[cfg(feature = "semi_anti_join")] + pub fn anti_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Anti), + ) + } + + /// Creates the cartesian product from both frames, preserving the order of the left keys. + #[cfg(feature = "cross_join")] + pub fn cross_join(self, other: LazyFrame) -> LazyFrame { + self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross)) + } + /// Left join this query with another lazy query. /// /// Matches on the values of the expressions `left_on` and `right_on`. For more @@ -977,7 +1044,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// fn join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// fn left_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { /// ldf /// .left_join(other, col("foo"), col("bar")) /// } @@ -991,6 +1058,31 @@ impl LazyFrame { ) } + /// Inner join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). 
+ /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn inner_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .inner_join(other, col("foo"), col("bar").cast(DataType::Utf8)) + /// } + /// ``` + pub fn inner_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Inner), + ) + } + /// Outer join this query with another lazy query. /// /// Matches on the values of the expressions `left_on` and `right_on`. For more @@ -1002,7 +1094,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// fn join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// fn outer_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { /// ldf /// .outer_join(other, col("foo"), col("bar")) /// } @@ -1016,7 +1108,7 @@ impl LazyFrame { ) } - /// Inner join this query with another lazy query. + /// Left semi join this query with another lazy query. /// /// Matches on the values of the expressions `left_on` and `right_on`. For more /// flexible join logic, see [`join`](LazyFrame::join) or @@ -1027,26 +1119,21 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// fn join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// fn semi_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { /// ldf - /// .inner_join(other, col("foo"), col("bar").cast(DataType::Utf8)) + /// .semi_join(other, col("foo"), col("bar").cast(DataType::Utf8)) /// } /// ``` - pub fn inner_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + #[cfg(feature = "semi_anti_join")] + pub fn semi_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { self.join( other, [left_on.into()], [right_on.into()], - JoinArgs::new(JoinType::Inner), + JoinArgs::new(JoinType::Semi), ) } - /// Creates the cartesian product from both frames, preserving the order of the left keys. - #[cfg(feature = "cross_join")] - pub fn cross_join(self, other: LazyFrame) -> LazyFrame { - self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross)) - } - /// Generic function to join two LazyFrames. /// /// `join` can join on multiple columns, given as two list of expressions, and with a @@ -1120,6 +1207,7 @@ impl LazyFrame { vec![expr], ProjectionOptions { run_parallel: false, + duplicate_check: true, }, ) .build(); @@ -1142,7 +1230,13 @@ impl LazyFrame { /// ``` pub fn with_columns>(self, exprs: E) -> LazyFrame { let exprs = exprs.as_ref().to_vec(); - self.with_columns_impl(exprs, ProjectionOptions { run_parallel: true }) + self.with_columns_impl( + exprs, + ProjectionOptions { + run_parallel: true, + duplicate_check: true, + }, + ) } /// Add multiple columns to a DataFrame, but evaluate them sequentially. @@ -1152,6 +1246,7 @@ impl LazyFrame { exprs, ProjectionOptions { run_parallel: false, + duplicate_check: true, }, ) } @@ -1177,14 +1272,14 @@ impl LazyFrame { /// /// Aggregated columns will have the same names as the original columns. pub fn max(self) -> LazyFrame { - self.select_local(vec![col("*").max()]) + self.select(vec![col("*").max()]) } /// Aggregate all the columns as their minimum values. /// /// Aggregated columns will have the same names as the original columns. 
pub fn min(self) -> LazyFrame { - self.select_local(vec![col("*").min()]) + self.select(vec![col("*").min()]) } /// Aggregate all the columns as their sum values. @@ -1197,7 +1292,7 @@ impl LazyFrame { /// silently wrap. /// - String columns will sum to None. pub fn sum(self) -> LazyFrame { - self.select_local(vec![col("*").sum()]) + self.select(vec![col("*").sum()]) } /// Aggregate all the columns as their mean values. @@ -1205,7 +1300,7 @@ impl LazyFrame { /// - Boolean and integer columns are converted to `f64` before computing the mean. /// - String columns will have a mean of None. pub fn mean(self) -> LazyFrame { - self.select_local(vec![col("*").mean()]) + self.select(vec![col("*").mean()]) } /// Aggregate all the columns as their median values. @@ -1214,12 +1309,12 @@ impl LazyFrame { /// susceptible to overflow before this conversion occurs. /// - String columns will sum to None. pub fn median(self) -> LazyFrame { - self.select_local(vec![col("*").median()]) + self.select(vec![col("*").median()]) } /// Aggregate all the columns as their quantile values. pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> LazyFrame { - self.select_local(vec![col("*").quantile(quantile, interpol)]) + self.select(vec![col("*").quantile(quantile, interpol)]) } /// Aggregate all the columns as their standard deviation values. @@ -1235,7 +1330,7 @@ impl LazyFrame { /// /// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.std.html#) pub fn std(self, ddof: u8) -> LazyFrame { - self.select_local(vec![col("*").std(ddof)]) + self.select(vec![col("*").std(ddof)]) } /// Aggregate all the columns as their variance values. @@ -1248,7 +1343,7 @@ impl LazyFrame { /// /// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.var.html#) pub fn var(self, ddof: u8) -> LazyFrame { - self.select_local(vec![col("*").var(ddof)]) + self.select(vec![col("*").var(ddof)]) } /// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode). @@ -1265,7 +1360,7 @@ impl LazyFrame { /// Aggregate all the columns as the sum of their null value count. pub fn null_count(self) -> LazyFrame { - self.select_local(vec![col("*").null_count()]) + self.select(vec![col("*").null_count()]) } /// Drop non-unique rows and maintain the order of kept rows. @@ -1454,8 +1549,9 @@ impl LazyFrame { LogicalPlan::Scan { file_options: options, file_info, + scan_type, .. - } => { + } if !matches!(scan_type, FileScan::Anonymous { .. }) => { options.row_count = Some(RowCount { name: name.to_string(), offset: offset.unwrap_or(0), @@ -1533,7 +1629,7 @@ pub struct LazyGroupBy { impl LazyGroupBy { /// Group by and aggregate. /// - /// Select a column with [col](crate::dsl::col) and choose an aggregation. + /// Select a column with [col] and choose an aggregation. /// If you want to aggregate all columns use `col("*")`. /// /// # Example diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index cc87a027b2e6..433261ab41ee 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -8,24 +8,30 @@ //! //! # Lazy DSL //! -//! The lazy API of polars replaces the eager `DataFrame` with the [`LazyFrame`](crate::frame::LazyFrame), through which +//! The lazy API of polars replaces the eager [`DataFrame`] with the [`LazyFrame`], through which //! the lazy API is exposed. -//! The `LazyFrame` represents a logical execution plan: a sequence of operations to perform on a concrete data source. -//! 
These operations are not executed until we call [`collect`](crate::frame::LazyFrame::collect). +//! The [`LazyFrame`] represents a logical execution plan: a sequence of operations to perform on a concrete data source. +//! These operations are not executed until we call [`collect`]. //! This allows polars to optimize/reorder the query which may lead to faster queries or fewer type errors. //! -//! In general, a `LazyFrame` requires a concrete data source — a `DataFrame`, a file on disk, etc. — which polars-lazy +//! [`DataFrame`]: polars_core::frame::DataFrame +//! [`LazyFrame`]: crate::frame::LazyFrame +//! [`collect`]: crate::frame::LazyFrame::collect +//! +//! In general, a [`LazyFrame`] requires a concrete data source — a [`DataFrame`], a file on disk, etc. — which polars-lazy //! then applies the user-specified sequence of operations to. -//! To obtain a `LazyFrame` from an existing `DataFrame`, we call the [`lazy`](crate::frame::IntoLazy::lazy) method on -//! the `DataFrame`. -//! A `LazyFrame` can also be obtained through the lazy versions of file readers, such as [`LazyCsvReader`](crate::frame::LazyCsvReader). +//! To obtain a [`LazyFrame`] from an existing [`DataFrame`], we call the [`lazy`](crate::frame::IntoLazy::lazy) method on +//! the [`DataFrame`]. +//! A [`LazyFrame`] can also be obtained through the lazy versions of file readers, such as [`LazyCsvReader`](crate::frame::LazyCsvReader). //! //! The other major component of the polars lazy API is [`Expr`](crate::dsl::Expr), which represents an operation to be -//! performed on a `LazyFrame`, such as mapping over a column, filtering, or groupby-aggregation. -//! `Expr` and the functions that produce them can be found in the [dsl module](crate::dsl). +//! performed on a [`LazyFrame`], such as mapping over a column, filtering, or groupby-aggregation. +//! [`Expr`] and the functions that produce them can be found in the [dsl module](crate::dsl). +//! +//! [`Expr`]: crate::dsl::Expr //! -//! Most operations on a `LazyFrame` consume the `LazyFrame` and return a new `LazyFrame` with the updated plan. -//! If you need to use the same `LazyFrame` multiple times, you should [`clone`](crate::frame::LazyFrame::clone) it, and optionally +//! Most operations on a [`LazyFrame`] consume the [`LazyFrame`] and return a new [`LazyFrame`] with the updated plan. +//! If you need to use the same [`LazyFrame`] multiple times, you should [`clone`](crate::frame::LazyFrame::clone) it, and optionally //! [`cache`](crate::frame::LazyFrame::cache) it beforehand. //! //! ## Examples @@ -200,6 +206,7 @@ pub mod dsl; pub mod frame; pub mod physical_plan; pub mod prelude; +mod scan; #[cfg(test)] mod tests; pub mod utils; diff --git a/crates/polars-lazy/src/physical_plan/executors/ext_context.rs b/crates/polars-lazy/src/physical_plan/executors/ext_context.rs index 7ad75181b89b..05d09f82b9a0 100644 --- a/crates/polars-lazy/src/physical_plan/executors/ext_context.rs +++ b/crates/polars-lazy/src/physical_plan/executors/ext_context.rs @@ -13,14 +13,15 @@ impl Executor for ExternalContext { println!("run ExternalContext") } } - let df = self.input.execute(state)?; + // we evaluate contexts first as input may has pushed exprs. 
let contexts = self .contexts .iter_mut() .map(|e| e.execute(state)) .collect::>>()?; - state.ext_contexts = Arc::new(contexts); + let df = self.input.execute(state)?; + Ok(df) } } diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs index b6d890cbac0a..8c9bef244bbf 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs @@ -17,6 +17,26 @@ pub(crate) struct GroupByRollingExec { pub(crate) apply: Option>, } +unsafe fn update_keys(keys: &mut [Series], groups: &GroupsProxy) { + match groups { + GroupsProxy::Idx(groups) => { + let first = groups.first(); + // we don't use agg_first here, because the group + // can be empty, but we still want to know the first value + // of that group + for key in keys.iter_mut() { + *key = key.take_unchecked_from_slice(first); + } + }, + GroupsProxy::Slice { groups, .. } => { + for key in keys.iter_mut() { + let indices = groups.iter().map(|[first, _len]| *first).collect_ca(""); + *key = key.take_unchecked(&indices); + } + }, + } +} + impl GroupByRollingExec { #[cfg(feature = "dynamic_group_by")] fn execute_impl( @@ -58,26 +78,7 @@ impl GroupByRollingExec { // the ordering has changed due to the group_by if !keys.is_empty() { - unsafe { - match groups { - GroupsProxy::Idx(groups) => { - let first = groups.first(); - // we don't use agg_first here, because the group - // can be empty, but we still want to know the first value - // of that group - for key in keys.iter_mut() { - *key = key.take_unchecked_from_slice(first).unwrap(); - } - }, - GroupsProxy::Slice { groups, .. } => { - for key in keys.iter_mut() { - let iter = &mut groups.iter().map(|[first, _len]| *first as usize) - as &mut dyn TakeIterator; - *key = key.take_iter_unchecked(iter); - } - }, - } - } + unsafe { update_keys(&mut keys, groups) } }; let agg_columns = evaluate_aggs(&df, &self.aggs, groups, state)?; diff --git a/crates/polars-lazy/src/physical_plan/executors/join.rs b/crates/polars-lazy/src/physical_plan/executors/join.rs index b4c31da43e63..fa84d46e7a84 100644 --- a/crates/polars-lazy/src/physical_plan/executors/join.rs +++ b/crates/polars-lazy/src/physical_plan/executors/join.rs @@ -1,5 +1,3 @@ -use polars_core::frame::hash_join::JoinArgs; - use super::*; pub struct JoinExec { diff --git a/crates/polars-lazy/src/physical_plan/executors/mod.rs b/crates/polars-lazy/src/physical_plan/executors/mod.rs index e43a99c2d53f..562c6c5b5e40 100644 --- a/crates/polars-lazy/src/physical_plan/executors/mod.rs +++ b/crates/polars-lazy/src/physical_plan/executors/mod.rs @@ -5,7 +5,7 @@ mod filter; mod group_by; mod group_by_dynamic; mod group_by_partitioned; -mod group_by_rolling; +pub(super) mod group_by_rolling; mod join; mod projection; mod projection_utils; @@ -36,7 +36,7 @@ pub(super) use self::group_by::*; pub(super) use self::group_by_dynamic::*; pub(super) use self::group_by_partitioned::*; #[cfg(feature = "dynamic_group_by")] -pub(super) use self::group_by_rolling::*; +pub(super) use self::group_by_rolling::GroupByRollingExec; pub(super) use self::join::*; pub(super) use self::projection::*; #[cfg(feature = "python")] diff --git a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs index 70e33fb986a0..0b8277145fce 100644 --- a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs +++ 
b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs @@ -1,3 +1,5 @@ +use polars_utils::format_smartstring; +use polars_utils::iter::EnumerateIdxTrait; use smartstring::alias::String as SmartString; use super::*; @@ -17,6 +19,8 @@ pub(super) fn profile_name( } } +type IdAndExpression = (u32, Arc); + fn execute_projection_cached_window_fns( df: &DataFrame, exprs: &[Arc], @@ -32,31 +36,39 @@ fn execute_projection_cached_window_fns( #[allow(clippy::type_complexity)] // String: partition_name, // u32: index, - let mut windows: Vec<(String, Vec<(u32, Arc)>)> = vec![]; + let mut windows: PlHashMap> = PlHashMap::default(); + #[cfg(feature = "dynamic_group_by")] + let mut rolling: PlHashMap<&RollingGroupOptions, Vec> = PlHashMap::default(); let mut other = Vec::with_capacity(exprs.len()); // first we partition the window function by the values they group over. // the group_by values should be cached - let mut index = 0u32; - exprs.iter().for_each(|phys| { - index += 1; + exprs.iter().enumerate_u32().for_each(|(index, phys)| { let e = phys.as_expression().unwrap(); let mut is_window = false; for e in e.into_iter() { - if let Expr::Window { partition_by, .. } = e { - let group_by = format!("{:?}", partition_by.as_slice()); - if let Some(tpl) = windows.iter_mut().find(|tpl| tpl.0 == group_by) { - tpl.1.push((index, phys.clone())) - } else { - windows.push((group_by, vec![(index, phys.clone())])) - } + if let Expr::Window { + partition_by, + options, + .. + } = e + { + let entry = match options { + WindowType::Over(_) => { + let group_by = format_smartstring!("{:?}", partition_by.as_slice()); + windows.entry(group_by).or_insert_with(Vec::new) + }, + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => rolling.entry(options).or_insert_with(Vec::new), + }; + entry.push((index, phys.clone())); is_window = true; break; } } if !is_window { - other.push((index, phys)) + other.push((index, phys.as_ref())) } }); @@ -67,6 +79,31 @@ fn execute_projection_cached_window_fns( .collect::>>() })?; + // Run partitioned rolling expressions. + // Per partition we run in parallel. We compute the groups before and store them once per partition. + // The rolling expression knows how to fetch the groups. + #[cfg(feature = "dynamic_group_by")] + for (options, partition) in rolling { + // clear the cache for every partitioned group + let state = state.split(); + let (_time_key, _keys, groups) = df.group_by_rolling(vec![], options)?; + // Set the groups so all expressions in partition can use it. + // Create a separate scope, so the lock is dropped, otherwise we deadlock when the + // rolling expression try to get read access. 
+ { + let mut groups_map = state.group_tuples.write().unwrap(); + groups_map.insert(options.index_column.to_string(), groups); + } + + let results = POOL.install(|| { + partition + .par_iter() + .map(|(idx, expr)| expr.evaluate(df, &state).map(|s| (*idx, s))) + .collect::>>() + })?; + selected_columns.extend_from_slice(&results); + } + for partition in windows { // clear the cache for every partitioned group let mut state = state.split(); @@ -185,7 +222,9 @@ pub(super) fn check_expand_literals( mut selected_columns: Vec, zero_length: bool, ) -> PolarsResult { - let first_len = selected_columns[0].len(); + let Some(first_len) = selected_columns.get(0).map(|s| s.len()) else { + return Ok(DataFrame::empty()); + }; let mut df_height = 0; let mut all_equal_len = true; { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs index 80b2b2e3aa95..8542432b31bc 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs @@ -29,7 +29,7 @@ impl CsvExec { .unwrap() .has_header(self.options.has_header) .with_dtypes(Some(self.schema.clone())) - .with_delimiter(self.options.delimiter) + .with_separator(self.options.separator) .with_ignore_errors(self.options.ignore_errors) .with_skip_rows(self.options.skip_rows) .with_n_rows(n_rows) diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index 79c7b43dccc6..e5ee49c06a16 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -19,8 +19,9 @@ impl IpcExec { &mut self.schema, self.file_options.n_rows, self.file_options.row_count.is_some(), + None, ); - IpcReader::new(file) + IpcReader::new(file.unwrap()) .with_n_rows(n_rows) .with_row_count(std::mem::take(&mut self.file_options.row_count)) .set_rechunk(self.file_options.rechunk) diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs b/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs index 07788ff10a74..2c6d8ac24d88 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs @@ -8,6 +8,7 @@ mod ndjson; mod parquet; use std::mem; +use std::ops::Deref; #[cfg(feature = "csv")] pub(crate) use csv::CsvExec; @@ -41,18 +42,19 @@ fn prepare_scan_args( schema: &mut SchemaRef, n_rows: Option, has_row_count: bool, -) -> (std::fs::File, Projection, StopNRows, Predicate) { - let file = std::fs::File::open(path).unwrap(); + hive_partitions: Option<&[Series]>, +) -> (Option, Projection, StopNRows, Predicate) { + let file = std::fs::File::open(path).ok(); let with_columns = mem::take(with_columns); let schema = mem::take(schema); - let projection: Option> = with_columns.map(|with_columns| { - with_columns - .iter() - .map(|name| schema.index_of(name).unwrap() - has_row_count as usize) - .collect() - }); + let projection = materialize_projection( + with_columns.as_deref().map(|cols| cols.deref()), + &schema, + hive_partitions, + has_row_count, + ); let n_rows = _set_n_rows_for_scan(n_rows); let predicate = predicate.clone().map(phys_expr_to_io_expr); @@ -102,43 +104,51 @@ impl Executor for DataFrameExec { pub(crate) struct AnonymousScanExec { pub(crate) function: Arc, - pub(crate) options: AnonymousScanOptions, + pub(crate) file_options: FileScanOptions, + pub(crate) file_info: FileInfo, pub(crate) predicate: 
Option>, + pub(crate) output_schema: Option, pub(crate) predicate_has_windows: bool, } impl Executor for AnonymousScanExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + let mut args = AnonymousScanArgs { + n_rows: self.file_options.n_rows, + with_columns: self.file_options.with_columns.clone(), + schema: self.file_info.schema.clone(), + output_schema: self.output_schema.clone(), + predicate: None, + }; + if self.predicate.is_some() { + state.insert_has_window_function_flag() + } + match (self.function.allows_predicate_pushdown(), &self.predicate) { (true, Some(predicate)) => state.record( || { - self.options.predicate = predicate.as_expression().cloned(); - self.function.scan(self.options.clone()) + args.predicate = predicate.as_expression().cloned(); + self.function.scan(args) }, "anonymous_scan".into(), ), - (false, Some(predicate)) => { - if self.predicate_has_windows { - state.insert_has_window_function_flag() - } - state.record(|| { - let mut df = self.function.scan(self.options.clone())?; - let s = predicate.evaluate(&df, state)?; - if self.predicate_has_windows { - state.clear_window_expr_cache() - } - let mask = s.bool().map_err( - |_| polars_err!(ComputeError: "filter predicate was not of type boolean"), - )?; - df = df.filter(mask)?; - - Ok(df) - },"anonymous_scan".into()) - }, - _ => state.record( - || self.function.scan(self.options.clone()), + (false, Some(predicate)) => state.record( + || { + let mut df = self.function.scan(args)?; + let s = predicate.evaluate(&df, state)?; + if self.predicate_has_windows { + state.clear_window_expr_cache() + } + let mask = s.bool().map_err( + |_| polars_err!(ComputeError: "filter predicate was not of type boolean"), + )?; + df = df.filter(mask)?; + + Ok(df) + }, "anonymous_scan".into(), ), + _ => state.record(|| self.function.scan(args), "anonymous_scan".into()), } } } diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs index a6f6cedda2f7..2d2a6368cd2b 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs @@ -1,13 +1,13 @@ use polars_core::error::to_compute_err; use super::*; -use crate::prelude::{AnonymousScan, AnonymousScanOptions, LazyJsonLineReader}; +use crate::prelude::{AnonymousScan, LazyJsonLineReader}; impl AnonymousScan for LazyJsonLineReader { fn as_any(&self) -> &dyn std::any::Any { self } - fn scan(&self, scan_opts: AnonymousScanOptions) -> PolarsResult { + fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult { let schema = scan_opts.output_schema.unwrap_or(scan_opts.schema); JsonLineReader::from_path(&self.path)? .with_schema(schema) @@ -19,7 +19,12 @@ impl AnonymousScan for LazyJsonLineReader { .finish() } - fn schema(&self, infer_schema_length: Option) -> PolarsResult { + fn schema(&self, infer_schema_length: Option) -> PolarsResult { + // Short-circuit schema inference if the schema has been explicitly provided. 
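A hedged sketch (not part of this diff) of a custom source written against the new `AnonymousScanArgs` shape constructed in `AnonymousScanExec::execute` above. The field names (`with_columns`, `n_rows`) and the `scan`/`schema`/`as_any` method set are taken from this diff; the exact type of `with_columns` and everything else here is an assumption.

```rust
struct InMemoryScan {
    df: DataFrame,
}

impl AnonymousScan for InMemoryScan {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn scan(&self, args: AnonymousScanArgs) -> PolarsResult<DataFrame> {
        // Honour projection and row-limit pushdown when the planner provides them.
        let mut df = match args.with_columns.as_deref() {
            Some(columns) => self.df.select(columns)?,
            None => self.df.clone(),
        };
        if let Some(n) = args.n_rows {
            df = df.head(Some(n));
        }
        Ok(df)
    }

    fn schema(&self, _infer_schema_length: Option<usize>) -> PolarsResult<SchemaRef> {
        // Mirrors the ndjson reader below: the schema is now returned as an Arc'd SchemaRef.
        Ok(Arc::new(self.df.schema()))
    }
}
```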
+ if let Some(schema) = &self.schema { + return Ok(schema.clone()); + } + let f = polars_utils::open_file(&self.path)?; let mut reader = std::io::BufReader::new(f); @@ -27,7 +32,7 @@ impl AnonymousScan for LazyJsonLineReader { polars_json::ndjson::infer(&mut reader, infer_schema_length).map_err(to_compute_err)?; let schema = Schema::from_iter(StructArray::get_fields(&data_type)); - Ok(schema) + Ok(Arc::new(schema)) } fn allows_projection_pushdown(&self) -> bool { true diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 987b92fbdcff..4bbae2254a87 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -1,56 +1,97 @@ use std::path::PathBuf; -use polars_core::cloud::CloudOptions; +use polars_core::utils::arrow::io::parquet::read::FileMetaData; +use polars_io::cloud::CloudOptions; +use polars_io::is_cloud_url; use super::*; -#[allow(dead_code)] pub struct ParquetExec { path: PathBuf, - schema: SchemaRef, + file_info: FileInfo, predicate: Option>, options: ParquetOptions, + #[allow(dead_code)] cloud_options: Option, file_options: FileScanOptions, + metadata: Option>, } impl ParquetExec { pub(crate) fn new( path: PathBuf, - schema: SchemaRef, + file_info: FileInfo, predicate: Option>, options: ParquetOptions, cloud_options: Option, file_options: FileScanOptions, + metadata: Option>, ) -> Self { ParquetExec { path, - schema, + file_info, predicate, options, cloud_options, file_options, + metadata, } } fn read(&mut self) -> PolarsResult { + let hive_partitions = self + .file_info + .hive_parts + .as_ref() + .map(|hive| hive.materialize_partition_columns()); + let (file, projection, n_rows, predicate) = prepare_scan_args( &self.path, &self.predicate, &mut self.file_options.with_columns, - &mut self.schema, + &mut self.file_info.schema.clone(), self.file_options.n_rows, self.file_options.row_count.is_some(), + hive_partitions.as_deref(), ); - ParquetReader::new(file) - .with_n_rows(n_rows) - .read_parallel(self.options.parallel) - .with_row_count(mem::take(&mut self.file_options.row_count)) - .set_rechunk(self.file_options.rechunk) - .set_low_memory(self.options.low_memory) - .use_statistics(self.options.use_statistics) - ._finish_with_scan_ops(predicate, projection.as_ref().map(|v| v.as_ref())) + if let Some(file) = file { + ParquetReader::new(file) + .with_n_rows(n_rows) + .read_parallel(self.options.parallel) + .with_row_count(mem::take(&mut self.file_options.row_count)) + .set_rechunk(self.file_options.rechunk) + .set_low_memory(self.options.low_memory) + .use_statistics(self.options.use_statistics) + .with_hive_partition_columns(hive_partitions) + ._finish_with_scan_ops(predicate, projection.as_ref().map(|v| v.as_ref())) + } else if is_cloud_url(self.path.as_path()) { + #[cfg(feature = "cloud")] + { + polars_io::pl_async::get_runtime().block_on_potential_spawn(async { + let reader = ParquetAsyncReader::from_uri( + &self.path.to_string_lossy(), + self.cloud_options.as_ref(), + Some(self.file_info.schema.clone()), + self.metadata.clone(), + ) + .await? 
+ .with_n_rows(n_rows) + .with_row_count(mem::take(&mut self.file_options.row_count)) + .with_projection(projection) + .use_statistics(self.options.use_statistics) + .with_hive_partition_columns(hive_partitions); + + reader.finish(predicate).await + }) + } + #[cfg(not(feature = "cloud"))] + { + panic!("activate cloud feature") + } + } else { + polars_bail!(ComputeError: "could not read {}", self.path.display()) + } } } diff --git a/crates/polars-lazy/src/physical_plan/executors/sort.rs b/crates/polars-lazy/src/physical_plan/executors/sort.rs index b3c64d0ab61b..6d31fe41ea84 100644 --- a/crates/polars-lazy/src/physical_plan/executors/sort.rs +++ b/crates/polars-lazy/src/physical_plan/executors/sort.rs @@ -20,10 +20,10 @@ impl SortExec { .enumerate() .map(|(i, e)| { let mut s = e.evaluate(&df, state)?; - // polars core will try to set the sorted columns as sorted - // this should only be done with simple col("foo") expressions + // Polars core will try to set the sorted columns as sorted. + // This should only be done with simple col("foo") expressions, // therefore we rename more complex expressions so that - // polars core does not match these + // polars core does not match these. if !matches!(e.as_expression(), Some(&Expr::Column(_))) { s.rename(&format!("_POLARS_SORT_BY_{i}")); } diff --git a/crates/polars-lazy/src/physical_plan/executors/union.rs b/crates/polars-lazy/src/physical_plan/executors/union.rs index 0d175904bce0..924218d0f837 100644 --- a/crates/polars-lazy/src/physical_plan/executors/union.rs +++ b/crates/polars-lazy/src/physical_plan/executors/union.rs @@ -103,7 +103,13 @@ impl Executor for UnionExec { .collect::>>() }); - concat_df(out?.iter().flat_map(|dfs| dfs.iter())) + concat_df(out?.iter().flat_map(|dfs| dfs.iter())).map(|df| { + if let Some((offset, len)) = self.options.slice { + df.slice(offset, len) + } else { + df + } + }) } .map(|mut df| { if self.options.rechunk { diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index 3f7b700dde44..4689afa2dd2b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -61,7 +61,7 @@ impl PhysicalExpr for AggregationExpr { let out = rename_series(agg_s, &keep_name); return Ok(AggregationContext::new(out, Cow::Borrowed(groups), true)) } else { - polars_bail!(ComputeError: "cannot aggregate as {}, the column is already aggregated"); + polars_bail!(ComputeError: "cannot aggregate as {}, the column is already aggregated", self.agg_type); } }, _ => () @@ -427,7 +427,7 @@ impl PartitionedAggregation for AggregationExpr { let ca = unsafe { // Safety // The indexes of the group_by operation are never out of bounds - ca.take_unchecked(idx.into()) + ca.take_unchecked(idx) }; process_group(ca)?; } diff --git a/crates/polars-lazy/src/physical_plan/expressions/alias.rs b/crates/polars-lazy/src/physical_plan/expressions/alias.rs index d9cc5cb73511..5d89fc009236 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/alias.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/alias.rs @@ -20,9 +20,9 @@ impl AliasExpr { expr, } } - fn finish(&self, mut input: Series) -> Series { - input.rename(&self.name); - input + + fn finish(&self, input: Series) -> Series { + input.with_name(&self.name) } } @@ -68,6 +68,7 @@ impl PhysicalExpr for AliasExpr { fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { Some(self) } + fn 
is_valid_aggregation(&self) -> bool { self.physical_expr.is_valid_aggregation() } @@ -81,10 +82,8 @@ impl PartitionedAggregation for AliasExpr { state: &ExecutionState, ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); - agg.evaluate_partitioned(df, groups, state).map(|mut s| { - s.rename(&self.name); - s - }) + let s = agg.evaluate_partitioned(df, groups, state)?; + Ok(s.with_name(&self.name)) } fn finalize( @@ -94,9 +93,7 @@ impl PartitionedAggregation for AliasExpr { state: &ExecutionState, ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); - agg.finalize(partitioned, groups, state).map(|mut s| { - s.rename(&self.name); - s - }) + let s = agg.finalize(partitioned, groups, state)?; + Ok(s.with_name(&self.name)) } } diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index b24d14c48e60..ef9b129ad9b1 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -5,9 +5,7 @@ use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "parquet")] -use polars_io::parquet::predicates::BatchStats; -#[cfg(feature = "parquet")] -use polars_io::predicates::StatsEvaluator; +use polars_io::predicates::{BatchStats, StatsEvaluator}; #[cfg(feature = "parquet")] use polars_plan::dsl::FunctionExpr; use rayon::prelude::*; @@ -89,14 +87,14 @@ impl ApplyExpr { } } - /// evaluates and flattens `Option` to `Series`. + /// Evaluates and flattens `Option` to `Series`. fn eval_and_flatten(&self, inputs: &mut [Series]) -> PolarsResult { - self.function.call_udf(inputs).map(|opt_out| { - opt_out.unwrap_or_else(|| { - let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap(); - Series::full_null(field.name(), 1, field.data_type()) - }) - }) + if let Some(out) = self.function.call_udf(inputs)? { + Ok(out) + } else { + let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap(); + Ok(Series::full_null(field.name(), 1, field.data_type())) + } } fn apply_single_group_aware<'a>( &self, @@ -112,14 +110,11 @@ impl ApplyExpr { let name = s.name().to_string(); let agg = ac.aggregated(); - // collection of empty list leads to a null dtype - // see: #3687 + // Collection of empty list leads to a null dtype. See: #3687. if agg.len() == 0 { - // create input for the function to determine the output dtype - // see #3946 + // Create input for the function to determine the output dtype, see #3946. let agg = agg.list().unwrap(); let input_dtype = agg.inner_dtype(); - let input = Series::full_null("", 0, &input_dtype); let output = self.eval_and_flatten(&mut [input])?; @@ -133,12 +128,11 @@ impl ApplyExpr { if self.pass_name_to_apply { s.rename(&name); } - let mut container = [s]; - self.function.call_udf(&mut container) + self.function.call_udf(&mut [s]) }, }; - let mut ca: ListChunked = if self.allow_threading { + let ca: ListChunked = if self.allow_threading { POOL.install(|| { agg.list() .unwrap() @@ -154,11 +148,10 @@ impl ApplyExpr { .collect::>()? }; - ca.rename(&name); - self.finish_apply_groups(ac, ca) + self.finish_apply_groups(ac, ca.with_name(&name)) } - /// Apply elementwise e.g. ignore the group/list indices + /// Apply elementwise e.g. ignore the group/list indices. 
fn apply_single_elementwise<'a>( &self, mut ac: AggregationContext<'a>, @@ -189,27 +182,29 @@ impl ApplyExpr { let schema = self.get_input_schema(df); let field = self.to_field(&schema)?; - // aggregate representation of the aggregation contexts + // Aggregate representation of the aggregation contexts, // then unpack the lists and finally create iterators from this list chunked arrays. let mut iters = acs .iter_mut() - .map(|ac| ac.iter_groups(self.pass_name_to_apply)) + .map(|ac| + // SAFETY: unstable series never lives longer than the iterator. + unsafe { ac.iter_groups(self.pass_name_to_apply) }) .collect::>(); - // length of the items to iterate over + // Length of the items to iterate over. let len = iters[0].size_hint().0; if len == 0 { let out = Series::full_null(field.name(), 0, &field.dtype); - drop(iters); - // take the first aggregation context that as that is the input series + + // Take the first aggregation context that as that is the input series. let mut ac = acs.swap_remove(0); ac.with_series(out, true, Some(&self.expr))?; return Ok(ac); } - let mut ca: ListChunked = (0..len) + let ca = (0..len) .map(|_| { container.clear(); for iter in &mut iters { @@ -220,12 +215,12 @@ impl ApplyExpr { } self.function.call_udf(&mut container) }) - .collect::>()?; + .collect::>()? + .with_name(&field.name); - ca.rename(&field.name); drop(iters); - // take the first aggregation context that as that is the input series + // Take the first aggregation context that as that is the input series. let ac = acs.swap_remove(0); self.finish_apply_groups(ac, ca) } @@ -266,15 +261,13 @@ impl PhysicalExpr for ApplyExpr { }?; if self.allow_rename { - return self.eval_and_flatten(&mut inputs); - } - let in_name = inputs[0].name().to_string(); - let mut out = self.eval_and_flatten(&mut inputs)?; - if in_name != out.name() { - out.rename(&in_name); + self.eval_and_flatten(&mut inputs) + } else { + let in_name = inputs[0].name().to_string(); + Ok(self.eval_and_flatten(&mut inputs)?.with_name(&in_name)) } - Ok(out) } + #[allow(clippy::ptr_arg)] fn evaluate_on_groups<'a>( &self, @@ -367,8 +360,8 @@ fn apply_multiple_elementwise<'a>( check_lengths: bool, ) -> PolarsResult> { match acs.first().unwrap().agg_state() { - // a fast path that doesn't drop groups of the first arg - // this doesn't require group re-computation + // A fast path that doesn't drop groups of the first arg. + // This doesn't require group re-computation. AggState::AggregatedList(s) => { let ca = s.list().unwrap(); @@ -378,10 +371,10 @@ fn apply_multiple_elementwise<'a>( .collect::>(); let out = ca.apply_to_inner(&|s| { - let mut args = vec![s]; + let mut args = Vec::with_capacity(other.len() + 1); + args.push(s); args.extend_from_slice(&other); - let out = function.call_udf(&mut args)?.unwrap(); - Ok(out) + Ok(function.call_udf(&mut args)?.unwrap()) })?; let mut ac = acs.swap_remove(0); ac.with_series(out.into_series(), true, None)?; @@ -392,9 +385,8 @@ fn apply_multiple_elementwise<'a>( .iter_mut() .enumerate() .map(|(i, ac)| { - // make sure the groups are updated because we are about to throw away - // the series length information - // only on first iteration + // Make sure the groups are updated because we are about to throw away + // the series length information, only on the first iteration. 
if let (0, UpdateGroups::WithSeriesLen) = (i, &ac.update_groups) { ac.groups(); } @@ -409,7 +401,7 @@ fn apply_multiple_elementwise<'a>( check_map_output_len(input_len, s.len(), expr)?; } - // take the first aggregation context that as that is the input series + // Take the first aggregation context that as that is the input series. let mut ac = acs.swap_remove(0); ac.with_series_and_args(s, false, None, true)?; Ok(ac) @@ -421,14 +413,13 @@ fn apply_multiple_elementwise<'a>( impl StatsEvaluator for ApplyExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { let read = self.should_read_impl(stats)?; - - let state = ExecutionState::new(); - - if state.verbose() && read { - eprintln!("parquet file must be read, statistics not sufficient for predicate.") - } else if state.verbose() && !read { - eprintln!("parquet file can be skipped, the statistics were sufficient to apply the predicate.") - }; + if ExecutionState::new().verbose() { + if read { + eprintln!("parquet file must be read, statistics not sufficient for predicate.") + } else { + eprintln!("parquet file can be skipped, the statistics were sufficient to apply the predicate.") + } + } Ok(read) } @@ -443,8 +434,8 @@ impl ApplyExpr { } => (function, input), _ => return Ok(true), }; - // ensure the input of the function is only a `col(..)` - // if it does any arithmetic the code below is flawed + // Ensure the input of the function is only a `col(..)`. + // If it does any arithmetic the code below is flawed. if !matches!(input[0], Expr::Column(_)) { return Ok(true); } @@ -463,52 +454,23 @@ impl ApplyExpr { }, #[cfg(feature = "is_in")] FunctionExpr::Boolean(BooleanFunction::IsIn) => { - let root = match expr_to_leaf_column_name(&input[0]) { - Ok(root) => root, - Err(_) => return Ok(true), - }; - - let input: &Series = match &input[1] { - Expr::Literal(LiteralValue::Series(s)) => s, - _ => return Ok(true), + let should_read = || -> Option { + let root = expr_to_leaf_column_name(&input[0]).ok()?; + let Expr::Literal(LiteralValue::Series(input)) = &input[1] else { + return None; + }; + #[allow(clippy::explicit_auto_deref)] + let input: &Series = &**input; + let st = stats.get_stats(&root).ok()?; + let min = st.to_min()?; + let max = st.to_max()?; + + let all_smaller = || Some(ChunkCompare::lt(input, min).ok()?.all()); + let all_bigger = || Some(ChunkCompare::gt(input, max).ok()?.all()); + Some(!all_smaller()? && !all_bigger()?) 
}; - match stats.get_stats(&root).ok() { - Some(st) => { - let min = match st.to_min() { - Some(min) => min, - None => return Ok(true), - }; - - let max = match st.to_max() { - Some(max) => max, - None => return Ok(true), - }; - - // all wanted values are smaller than minimum - // don't need to read - if ChunkCompare::<&Series>::lt(input, &min) - .ok() - .map(|ca| ca.all()) - == Some(true) - { - return Ok(false); - } - - // all wanted values are bigger than maximum - // don't need to read - if ChunkCompare::<&Series>::gt(input, &max) - .ok() - .map(|ca| ca.all()) - == Some(true) - { - return Ok(false); - } - - Ok(true) - }, - None => Ok(true), - } + Ok(should_read().unwrap_or(true)) }, _ => Ok(true), } @@ -526,14 +488,11 @@ impl PartitionedAggregation for ApplyExpr { let s = a.evaluate_partitioned(df, groups, state)?; if self.allow_rename { - return self.eval_and_flatten(&mut [s]); - } - let in_name = s.name().to_string(); - let mut out = self.eval_and_flatten(&mut [s])?; - if in_name != out.name() { - out.rename(&in_name); + self.eval_and_flatten(&mut [s]) + } else { + let in_name = s.name().to_string(); + Ok(self.eval_and_flatten(&mut [s])?.with_name(&in_name)) } - Ok(out) } fn finalize( diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs index 74c1b7833b7a..502562f9056e 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs @@ -46,14 +46,12 @@ fn apply_operator_owned(left: Series, right: Series, op: Operator) -> PolarsResu pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResult { use DataType::*; match op { - Operator::Gt => ChunkCompare::<&Series>::gt(left, right).map(|ca| ca.into_series()), - Operator::GtEq => ChunkCompare::<&Series>::gt_eq(left, right).map(|ca| ca.into_series()), - Operator::Lt => ChunkCompare::<&Series>::lt(left, right).map(|ca| ca.into_series()), - Operator::LtEq => ChunkCompare::<&Series>::lt_eq(left, right).map(|ca| ca.into_series()), - Operator::Eq => ChunkCompare::<&Series>::equal(left, right).map(|ca| ca.into_series()), - Operator::NotEq => { - ChunkCompare::<&Series>::not_equal(left, right).map(|ca| ca.into_series()) - }, + Operator::Gt => ChunkCompare::gt(left, right).map(|ca| ca.into_series()), + Operator::GtEq => ChunkCompare::gt_eq(left, right).map(|ca| ca.into_series()), + Operator::Lt => ChunkCompare::lt(left, right).map(|ca| ca.into_series()), + Operator::LtEq => ChunkCompare::lt_eq(left, right).map(|ca| ca.into_series()), + Operator::Eq => ChunkCompare::equal(left, right).map(|ca| ca.into_series()), + Operator::NotEq => ChunkCompare::not_equal(left, right).map(|ca| ca.into_series()), Operator::Plus => Ok(left + right), Operator::Minus => Ok(left - right), Operator::Multiply => Ok(left * right), @@ -90,18 +88,14 @@ impl BinaryExpr { ac_r: AggregationContext, aggregated: bool, ) -> PolarsResult> { - // we want to be able to mutate in place - // so we take the lhs to make sure that we drop + // We want to be able to mutate in place, so we take the lhs to make sure that we drop. let lhs = ac_l.series().clone(); let rhs = ac_r.series().clone(); - // drop lhs so that we might operate in place - { - let _ = ac_l.take(); - } + // Drop lhs so that we might operate in place. 
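The "operate in place" comment below this point refers to reference counting: the binary operation can only reuse the left-hand buffer once no other handle to it is alive, which is why the extra reference is dropped first. A simplified copy-on-write sketch with std `Arc` (a stand-in, not the Polars Series internals):

use std::sync::Arc;

fn add_one(mut values: Arc<Vec<i64>>) -> Arc<Vec<i64>> {
    // `make_mut` mutates the existing allocation when this Arc is the only
    // owner and clones it otherwise (copy-on-write).
    for x in Arc::make_mut(&mut values).iter_mut() {
        *x += 1;
    }
    values
}

fn main() {
    let holder = Arc::new(vec![1_i64, 2, 3]);
    let extra_ref = Arc::clone(&holder); // a second reference blocks the in-place path
    drop(extra_ref); // dropping it makes the buffer uniquely owned again
    let out = add_one(holder);
    assert_eq!(*out, vec![2, 3, 4]);
    println!("{out:?}");
}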
+ drop(ac_l.take()); let out = apply_operator_owned(lhs, rhs, self.op)?; - ac_l.with_series(out, aggregated, Some(&self.expr))?; Ok(ac_l) } @@ -112,33 +106,26 @@ impl BinaryExpr { mut ac_r: AggregationContext<'a>, ) -> PolarsResult> { let name = ac_l.series().name().to_string(); - let mut ca: ListChunked = ac_l - .iter_groups(false) - .zip(ac_r.iter_groups(false)) - .map(|(l, r)| { - match (l, r) { - (Some(l), Some(r)) => { - let l = l.as_ref(); - let r = r.as_ref(); - Some(apply_operator(l, r, self.op)) - }, - _ => None, - } - .transpose() - }) - .collect::>()?; - ca.rename(&name); + // SAFETY: unstable series never lives longer than the iterator. + let ca = unsafe { + ac_l.iter_groups(false) + .zip(ac_r.iter_groups(false)) + .map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op))) + .map(|opt_res| opt_res.transpose()) + .collect::>()? + .with_name(&name) + }; - // try if we can reuse the groups + // Try if we can reuse the groups. use AggState::*; match (ac_l.agg_state(), ac_r.agg_state()) { - // no need to change update groups + // No need to change update groups. (AggregatedList(_), _) => {}, - // we can take the groups of the rhs + // We can take the groups of the rhs. (_, AggregatedList(_)) if matches!(ac_r.update_groups, UpdateGroups::No) => { ac_l.groups = ac_r.groups }, - // we must update the groups + // We must update the groups. _ => { ac_l.with_update_groups(UpdateGroups::WithSeriesLen); }, @@ -155,43 +142,38 @@ impl PhysicalExpr for BinaryExpr { } fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { - // window functions may set a global state that determine their output + // Window functions may set a global state that determine their output // state, so we don't let them run in parallel as they race - // they also saturate the thread pool by themselves, so that's fine + // they also saturate the thread pool by themselves, so that's fine. let has_window = state.has_window(); - // streaming takes care of parallelism, don't parallelize here, as it - // increases contention + // Streaming takes care of parallelism, don't parallelize here, as it + // increases contention. #[cfg(feature = "streaming")] let in_streaming = state.in_streaming_engine(); #[cfg(not(feature = "streaming"))] let in_streaming = false; - let (lhs, rhs) = if has_window { + let (lhs, rhs); + if has_window { let mut state = state.split(); state.remove_cache_window_flag(); - ( - self.left.evaluate(df, &state), - self.right.evaluate(df, &state), - ) - } - // literals are free, don't pay par cost - else if in_streaming || self.has_literal { - ( - self.left.evaluate(df, state), - self.right.evaluate(df, state), - ) + lhs = self.left.evaluate(df, &state)?; + rhs = self.right.evaluate(df, &state)?; + } else if in_streaming || self.has_literal { + // Literals are free, don't pay par cost. + lhs = self.left.evaluate(df, state)?; + rhs = self.right.evaluate(df, state)?; } else { - POOL.install(|| { + let (opt_lhs, opt_rhs) = POOL.install(|| { rayon::join( || self.left.evaluate(df, state), || self.right.evaluate(df, state), ) - }) + }); + (lhs, rhs) = (opt_lhs?, opt_rhs?); }; - let lhs = lhs?; - let rhs = rhs?; polars_ensure!( lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1, expr = self.expr, @@ -256,133 +238,113 @@ impl PhysicalExpr for BinaryExpr { } fn is_valid_aggregation(&self) -> bool { - // we don't want: - // col(a) == lit(1) - - // we do want - // col(a).sum() == lit(1) + // We don't want: col(a) == lit(1). + // We do want col(a).sum() == lit(1). 
(!self.left.is_literal() && self.left.is_valid_aggregation()) - | (!self.right.is_literal() && self.right.is_valid_aggregation()) + || (!self.right.is_literal() && self.right.is_valid_aggregation()) } } #[cfg(feature = "parquet")] mod stats { - use polars_io::parquet::predicates::BatchStats; - use polars_io::predicates::StatsEvaluator; + use polars_io::predicates::{BatchStats, StatsEvaluator}; use super::*; fn apply_operator_stats_eq(min_max: &Series, literal: &Series) -> bool { - // literal is greater than max, don't need to read - if ChunkCompare::<&Series>::gt(literal, min_max) - .ok() - .map(|s| s.all()) - == Some(true) - { + use ChunkCompare as C; + // Literal is greater than max, don't need to read. + if C::gt(literal, min_max).map(|s| s.all()).unwrap_or(false) { return false; } - // literal is smaller than min, don't need to read - if ChunkCompare::<&Series>::lt(literal, min_max) - .ok() - .map(|s| s.all()) - == Some(true) - { + // Literal is smaller than min, don't need to read. + if C::lt(literal, min_max).map(|s| s.all()).unwrap_or(false) { return false; } true } + fn apply_operator_stats_neq(min_max: &Series, literal: &Series) -> bool { + if min_max.len() < 2 || min_max.null_count() > 0 { + return true; + } + use ChunkCompare as C; + + // First check proofs all values are the same (e.g. min/max is the same) + // Second check proofs all values are equal, so we can skip as we search + // for non-equal values. + if min_max.get(0).unwrap() == min_max.get(1).unwrap() + && C::equal(literal, min_max).map(|s| s.all()).unwrap_or(false) + { + return false; + } + true + } + fn apply_operator_stats_rhs_lit(min_max: &Series, literal: &Series, op: Operator) -> bool { + use ChunkCompare as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), + Operator::NotEq => apply_operator_stats_neq(min_max, literal), // col > lit // e.g. - // [min, - // max] > 0 + // [min, max] > 0 // - // [-1, - // 2] > 0 + // [-1, 2] > 0 // // [false, true] -> true -> read Operator::Gt => { - // literal is bigger than max value - // selection needs all rows - ChunkCompare::<&Series>::gt(min_max, literal) - .ok() - .map(|s| s.any()) - == Some(true) + // Literal is bigger than max value, selection needs all rows. + C::gt(min_max, literal).map(|s| s.any()).unwrap_or(false) }, // col >= lit Operator::GtEq => { - // literal is bigger than max value - // selection needs all rows - ChunkCompare::<&Series>::gt_eq(min_max, literal) - .ok() - .map(|ca| ca.any()) - == Some(true) + // Literal is bigger than max value, selection needs all rows. + C::gt_eq(min_max, literal).map(|s| s.any()).unwrap_or(false) }, // col < lit Operator::Lt => { - // literal is smaller than min value - // selection needs all rows - ChunkCompare::<&Series>::lt(min_max, literal) - .ok() - .map(|ca| ca.any()) - == Some(true) + // Literal is smaller than min value, selection needs all rows. + C::lt(min_max, literal).map(|s| s.any()).unwrap_or(false) }, // col <= lit Operator::LtEq => { - // literal is smaller than min value - // selection needs all rows - ChunkCompare::<&Series>::lt_eq(min_max, literal) - .ok() - .map(|ca| ca.any()) - == Some(true) + // Literal is smaller than min value, selection needs all rows. 
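The stats hunks above decide from a column's min/max statistics whether a row group can be skipped for a comparison against a literal. A minimal sketch of that decision using plain integers and a hypothetical `should_read` helper, not the Polars Series/ChunkCompare API:

#[derive(Clone, Copy)]
enum CmpOp {
    Gt,
    GtEq,
    Lt,
    LtEq,
}

// Returns true when the row group must be read, false when the statistics
// prove that no row can satisfy `col <op> lit`.
fn should_read(min: i64, max: i64, lit: i64, op: CmpOp) -> bool {
    match op {
        // col > lit: read only if even the maximum exceeds the literal.
        CmpOp::Gt => max > lit,
        CmpOp::GtEq => max >= lit,
        // col < lit: read only if even the minimum is below the literal.
        CmpOp::Lt => min < lit,
        CmpOp::LtEq => min <= lit,
    }
}

fn main() {
    // Row group with values in [10, 20].
    assert!(!should_read(10, 20, 25, CmpOp::Gt)); // `col > 25` can be skipped
    assert!(should_read(10, 20, 15, CmpOp::GtEq)); // `col >= 15` may match
    assert!(!should_read(10, 20, 5, CmpOp::Lt)); // `col < 5` can be skipped
    assert!(should_read(10, 20, 10, CmpOp::LtEq)); // `col <= 10` may match
    println!("ok");
}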
+ C::lt_eq(min_max, literal).map(|s| s.any()).unwrap_or(false) }, - // default: read the file + // Default: read the file _ => true, } } fn apply_operator_stats_lhs_lit(literal: &Series, min_max: &Series, op: Operator) -> bool { + use ChunkCompare as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), + Operator::NotEq => apply_operator_stats_eq(min_max, literal), Operator::Gt => { - // literal is bigger than max value - // selection needs all rows - ChunkCompare::<&Series>::gt(literal, min_max) - .ok() - .map(|ca| ca.any()) - == Some(true) + // Literal is bigger than max value, selection needs all rows. + C::gt(literal, min_max).map(|ca| ca.any()).unwrap_or(false) }, Operator::GtEq => { - // literal is bigger than max value - // selection needs all rows - ChunkCompare::<&Series>::gt_eq(literal, min_max) - .ok() + // Literal is bigger than max value, selection needs all rows. + C::gt_eq(literal, min_max) .map(|ca| ca.any()) - == Some(true) + .unwrap_or(false) }, Operator::Lt => { - // literal is smaller than min value - // selection needs all rows - ChunkCompare::<&Series>::lt(literal, min_max) - .ok() - .map(|ca| ca.any()) - == Some(true) + // Literal is smaller than min value, selection needs all rows. + C::lt(literal, min_max).map(|ca| ca.any()).unwrap_or(false) }, Operator::LtEq => { - // literal is smaller than min value - // selection needs all rows - ChunkCompare::<&Series>::lt_eq(literal, min_max) - .ok() + // Literal is smaller than min value, selection needs all rows. + C::lt_eq(literal, min_max) .map(|ca| ca.any()) - == Some(true) + .unwrap_or(false) }, - // default: read the file + // Default: read the file. _ => true, } } @@ -393,19 +355,21 @@ mod stats { use Expr::*; use Operator::*; if !self.expr.into_iter().all(|e| match e { - BinaryExpr { op, .. } => !matches!( - op, - Multiply | Divide | TrueDivide | FloorDivide | Modulus | NotEq - ), + BinaryExpr { op, .. 
} => { + !matches!(op, Multiply | Divide | TrueDivide | FloorDivide | Modulus) + }, Column(_) | Literal(_) | Alias(_, _) => true, _ => false, }) { return Ok(true); } - let schema = stats.schema(); - let fld_l = self.left.to_field(schema)?; - let fld_r = self.right.to_field(schema)?; + let Some(fld_l) = self.left.to_field(schema).ok() else { + return Ok(true); + }; + let Some(fld_r) = self.right.to_field(schema).ok() else { + return Ok(true); + }; #[cfg(debug_assertions)] { @@ -447,7 +411,7 @@ mod stats { }, } }, - // default: read the file + // Default: read the file _ => Ok(true), }; out.map(|read| { diff --git a/crates/polars-lazy/src/physical_plan/expressions/count.rs b/crates/polars-lazy/src/physical_plan/expressions/count.rs index 24ab27aefa40..e7a102b20f55 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/count.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/count.rs @@ -20,6 +20,7 @@ impl PhysicalExpr for CountExpr { fn as_expression(&self) -> Option<&Expr> { Some(&self.expr) } + fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult { Ok(Series::new("count", [df.height() as IdxSize])) } @@ -30,12 +31,11 @@ impl PhysicalExpr for CountExpr { groups: &'a GroupsProxy, _state: &ExecutionState, ) -> PolarsResult> { - let mut ca = groups.group_count(); - ca.rename(COUNT); + let ca = groups.group_count().with_name(COUNT); let s = ca.into_series(); - Ok(AggregationContext::new(s, Cow::Borrowed(groups), true)) } + fn to_field(&self, _input_schema: &Schema) -> PolarsResult { Ok(Field::new(COUNT, IDX_DTYPE)) } @@ -69,10 +69,8 @@ impl PartitionedAggregation for CountExpr { groups: &GroupsProxy, _state: &ExecutionState, ) -> PolarsResult { - // safety: - // groups are in bounds - let mut agg = unsafe { partitioned.agg_sum(groups) }; - agg.rename(COUNT); - Ok(agg) + // SAFETY: groups are in bounds. + let agg = unsafe { partitioned.agg_sum(groups) }; + Ok(agg.with_name(COUNT)) } } diff --git a/crates/polars-lazy/src/physical_plan/expressions/filter.rs b/crates/polars-lazy/src/physical_plan/expressions/filter.rs index a3408a377a2c..6095354960ef 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/filter.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/filter.rs @@ -49,18 +49,21 @@ impl PhysicalExpr for FilterExpr { let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?); if ac_predicate.is_aggregated() || ac_s.is_aggregated() { - let preds = ac_predicate.iter_groups(false); + // SAFETY: unstable series never lives longer than the iterator. + let preds = unsafe { ac_predicate.iter_groups(false) }; let s = ac_s.aggregated(); let ca = s.list()?; - let mut out = ca - .amortized_iter() - .zip(preds) - .map(|(opt_s, opt_pred)| match (opt_s, opt_pred) { - (Some(s), Some(pred)) => s.as_ref().filter(pred.as_ref().bool()?).map(Some), - _ => Ok(None), - }) - .collect::>()?; - out.rename(s.name()); + // SAFETY: unstable series never lives longer than the iterator. + let out = unsafe { + ca.amortized_iter() + .zip(preds) + .map(|(opt_s, opt_pred)| match (opt_s, opt_pred) { + (Some(s), Some(pred)) => s.as_ref().filter(pred.as_ref().bool()?).map(Some), + _ => Ok(None), + }) + .collect::>()? 
+ .with_name(s.name()) + }; ac_s.with_series(out.into_series(), true, Some(&self.expr))?; ac_s.update_groups = WithSeriesLen; Ok(ac_s) @@ -69,12 +72,11 @@ impl PhysicalExpr for FilterExpr { let predicate_s = ac_predicate.flat_naive(); let predicate = predicate_s.bool()?; - // all values true don't do anything - if predicate.all() { + // All values true - don't do anything. + if let Some(true) = predicate.all_kleene() { return Ok(ac_s); } - // all values false - // create empty groups + // All values false - create empty groups. let groups = if !predicate.any() { let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::>(); GroupsProxy::Slice { @@ -82,7 +84,7 @@ impl PhysicalExpr for FilterExpr { rolling: false, } } - // filter the indexes that are true + // Filter the indexes that are true. else { let predicate = predicate.rechunk(); let predicate = predicate.downcast_iter().next().unwrap(); @@ -94,15 +96,11 @@ impl PhysicalExpr for FilterExpr { .map(|(first, idx)| unsafe { let idx: Vec = idx .iter() - // Safety: - // just checked bounds in short circuited lhs - .filter_map(|i| { - match predicate.value(*i as usize) + .copied() + .filter(|i| { + // SAFETY: just checked bounds in short circuited lhs. + predicate.value(*i as usize) && predicate.is_valid_unchecked(*i as usize) - { - true => Some(*i), - _ => None, - } }) .collect(); @@ -117,9 +115,8 @@ impl PhysicalExpr for FilterExpr { .par_iter() .map(|&[first, len]| unsafe { let idx: Vec = (first..first + len) - // Safety: - // just checked bounds in short circuited lhs .filter(|&i| { + // SAFETY: just checked bounds in short circuited lhs predicate.value(i as usize) && predicate.is_valid_unchecked(i as usize) }) diff --git a/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs b/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs index 186d6392ba47..a70760b07e15 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/group_iter.rs @@ -5,7 +5,10 @@ use polars_core::series::unstable::UnstableSeries; use super::*; impl<'a> AggregationContext<'a> { - pub(super) fn iter_groups( + /// # Safety + /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive + /// longer than the iterator is UB. 
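The safety note above concerns amortized iteration: every step reuses one backing buffer, so an item must not outlive its iterator. A small safe-Rust sketch (the `AmortizedIter` type is hypothetical) where the borrow checker enforces the same rule that the unsafe `iter_groups` asks callers to uphold manually:

struct AmortizedIter {
    buf: Vec<i64>,
    next_val: i64,
    remaining: usize,
}

impl AmortizedIter {
    fn new(remaining: usize) -> Self {
        Self { buf: Vec::new(), next_val: 0, remaining }
    }

    // A lending `next`: the returned slice borrows the internal buffer, so the
    // borrow must end before `next` can be called again.
    fn next(&mut self) -> Option<&[i64]> {
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        self.buf.clear();
        self.buf.push(self.next_val);
        self.next_val += 1;
        Some(&self.buf)
    }
}

fn main() {
    let mut iter = AmortizedIter::new(3);
    while let Some(group) = iter.next() {
        // Safe: `group` is used and dropped within this iteration.
        println!("{group:?}");
    }
    // Storing `group` across iterations would not compile here; the unsafe
    // iterator above opts out of that check, hence the safety comment.
}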
+ pub(super) unsafe fn iter_groups( &mut self, keep_names: bool, ) -> Box>> + '_> { diff --git a/crates/polars-lazy/src/physical_plan/expressions/mod.rs b/crates/polars-lazy/src/physical_plan/expressions/mod.rs index 23dec5117d7c..de041d0c6112 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/mod.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/mod.rs @@ -8,6 +8,8 @@ mod count; mod filter; mod group_iter; mod literal; +#[cfg(feature = "dynamic_group_by")] +mod rolling; mod slice; mod sort; mod sortby; @@ -31,6 +33,8 @@ use polars_arrow::utils::CustomIterTools; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_io::predicates::PhysicalIoExpr; +#[cfg(feature = "dynamic_group_by")] +pub(crate) use rolling::RollingExpr; pub(crate) use slice::*; pub(crate) use sort::*; pub(crate) use sortby::*; @@ -265,23 +269,25 @@ impl<'a> AggregationContext<'a> { }); }, _ => { - let groups = self - .series() - .list() - .expect("impl error, should be a list at this point") - .amortized_iter() - .map(|s| { - if let Some(s) = s { - let len = s.as_ref().len() as IdxSize; - let new_offset = offset + len; - let out = [offset, len]; - offset = new_offset; - out - } else { - [offset, 0] - } - }) - .collect_trusted(); + // SAFETY: unstable series never lives longer than the iterator. + let groups = unsafe { + self.series() + .list() + .expect("impl error, should be a list at this point") + .amortized_iter() + .map(|s| { + if let Some(s) = s { + let len = s.as_ref().len() as IdxSize; + let new_offset = offset + len; + let out = [offset, len]; + offset = new_offset; + out + } else { + [offset, 0] + } + }) + .collect_trusted() + }; self.groups = Cow::Owned(GroupsProxy::Slice { groups, rolling: false, @@ -621,7 +627,7 @@ impl PhysicalIoExpr for PhysicalIoHelper { } } -pub(super) fn phys_expr_to_io_expr(expr: Arc) -> Arc { +pub(crate) fn phys_expr_to_io_expr(expr: Arc) -> Arc { let has_window_function = if let Some(expr) = expr.as_expression() { expr.into_iter() .any(|expr| matches!(expr, Expr::Window { .. })) diff --git a/crates/polars-lazy/src/physical_plan/expressions/rolling.rs b/crates/polars-lazy/src/physical_plan/expressions/rolling.rs new file mode 100644 index 000000000000..2cdf54af60c0 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/expressions/rolling.rs @@ -0,0 +1,58 @@ +use super::*; + +pub(crate) struct RollingExpr { + /// the root column that the Function will be applied on. + /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index + /// TODO! support keys? + /// The challenge is that the group_by will reorder the results and the + /// keys, and time index would need to be updated, or the result should be joined back + /// For now, don't support it. + /// + /// A function Expr. i.e. Mean, Median, Max, etc. + pub(crate) function: Expr, + pub(crate) phys_function: Arc, + pub(crate) out_name: Option>, + pub(crate) options: RollingGroupOptions, + pub(crate) expr: Expr, +} + +impl PhysicalExpr for RollingExpr { + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let groups_map = state.group_tuples.read().unwrap(); + // Groups must be set by expression runner. + let groups = groups_map + .get(self.options.index_column.as_str()) + .expect("impl error"); + + let mut out = self + .phys_function + .evaluate_on_groups(df, groups, state)? 
+ .finalize(); + polars_ensure!(out.len() == groups.len(), agg_len = out.len(), groups.len()); + if let Some(name) = &self.out_name { + out.rename(name.as_ref()); + } + Ok(out) + } + + fn evaluate_on_groups<'a>( + &self, + _df: &DataFrame, + _groups: &'a GroupsProxy, + _state: &ExecutionState, + ) -> PolarsResult> { + polars_bail!(InvalidOperation: "rolling expression not allowed in aggregation"); + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.function.to_field(input_schema, Context::Default) + } + + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn is_valid_aggregation(&self) -> bool { + false + } +} diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 473c43e5befc..4f95d3e555d3 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -84,13 +84,8 @@ impl PhysicalExpr for SortExpr { groups .par_iter() .map(|(first, idx)| { - // Safety: - // Group tuples are always in bounds - let group = unsafe { - series.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }; + // SAFETY: group tuples are always in bounds. + let group = unsafe { series.take_slice_unchecked(idx) }; let sorted_idx = group.arg_sort(sort_options); let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx); diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index fac35f8e4009..feb398432b64 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -1,4 +1,3 @@ -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use polars_core::frame::group_by::{GroupsIndicator, GroupsProxy}; @@ -34,15 +33,138 @@ impl SortByExpr { fn prepare_descending(descending: &[bool], by_len: usize) -> Vec { match (descending.len(), by_len) { - // equal length + // Equal length. (n_rdescending, n) if n_rdescending == n => descending.to_vec(), - // none given all false + // None given all false. (0, n) => vec![false; n], - // broadcast first + // Broadcast first. (_, n) => vec![descending[0]; n], } } +static ERR_MSG: &str = "expressions in 'sort_by' produced a different number of groups"; + +fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> { + polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| { + a.len() == b.len() + }), ComputeError: ERR_MSG); + Ok(()) +} + +fn sort_by_groups_single_by( + indicator: GroupsIndicator, + sort_by_s: &Series, + descending: &[bool], +) -> PolarsResult<(IdxSize, Vec)> { + let new_idx = match indicator { + GroupsIndicator::Idx((_, idx)) => { + // SAFETY: group tuples are always in bounds. + let group = unsafe { sort_by_s.take_slice_unchecked(idx) }; + + let sorted_idx = group.arg_sort(SortOptions { + descending: descending[0], + // We are already in par iter. + multithreaded: false, + ..Default::default() + }); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + }, + GroupsIndicator::Slice([first, len]) => { + let group = sort_by_s.slice(first as i64, len as usize); + let sorted_idx = group.arg_sort(SortOptions { + descending: descending[0], + // We are already in par iter. 
+ multithreaded: false, + ..Default::default() + }); + map_sorted_indices_to_group_slice(&sorted_idx, first) + }, + }; + let first = new_idx + .first() + .ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?; + + Ok((*first, new_idx)) +} + +fn sort_by_groups_no_match_single<'a>( + mut ac_in: AggregationContext<'a>, + mut ac_by: AggregationContext<'a>, + descending: bool, + expr: &Expr, +) -> PolarsResult> { + let s_in = ac_in.aggregated(); + let s_by = ac_by.aggregated(); + let mut s_in = s_in.list().unwrap().clone(); + let mut s_by = s_by.list().unwrap().clone(); + + let ca: PolarsResult = POOL.install(|| { + s_in.par_iter_indexed() + .zip(s_by.par_iter_indexed()) + .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) { + (Some(s), Some(s_sort_by)) => { + polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression"); + let idx = s_sort_by.arg_sort(SortOptions { + descending, + // We are already in par iter. + multithreaded: false, + ..Default::default() + }); + Ok(Some(unsafe { s.take_unchecked(&idx) })) + }, + _ => Ok(None), + }) + .collect() + }); + let s = ca?.with_name(s_in.name()).into_series(); + ac_in.with_series(s, true, Some(expr))?; + Ok(ac_in) +} + +fn sort_by_groups_multiple_by( + indicator: GroupsIndicator, + sort_by_s: &[Series], + descending: &[bool], +) -> PolarsResult<(IdxSize, Vec)> { + let new_idx = match indicator { + GroupsIndicator::Idx((_first, idx)) => { + // SAFETY: group tuples are always in bounds. + let groups = sort_by_s + .iter() + .map(|s| unsafe { s.take_slice_unchecked(idx) }) + .collect::>(); + + let options = SortMultipleOptions { + other: groups[1..].to_vec(), + descending: descending.to_owned(), + multithreaded: false, + }; + + let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + }, + GroupsIndicator::Slice([first, len]) => { + let groups = sort_by_s + .iter() + .map(|s| s.slice(first as i64, len as usize)) + .collect::>(); + + let options = SortMultipleOptions { + other: groups[1..].to_vec(), + descending: descending.to_owned(), + multithreaded: false, + }; + let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); + map_sorted_indices_to_group_slice(&sorted_idx, first) + }, + }; + let first = new_idx + .first() + .ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?; + + Ok((*first, new_idx)) +} + impl PhysicalExpr for SortByExpr { fn as_expression(&self) -> Option<&Expr> { Some(&self.expr) @@ -91,9 +213,8 @@ impl PhysicalExpr for SortByExpr { sorted_idx.len(), series.len() ); - // Safety: - // sorted index are within bounds - unsafe { series.take_unchecked(&sorted_idx) } + // SAFETY: sorted index are within bounds. 
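The group-wise sort in these hunks arg-sorts the values of one group and then maps the sorted local positions back to the group's global row indices. A minimal sketch with hypothetical helpers (`arg_sort`, `map_sorted_to_global`), not the Polars implementations:

fn arg_sort(values: &[i64], descending: bool) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    idx.sort_by_key(|&i| values[i]);
    if descending {
        idx.reverse();
    }
    idx
}

// `group_idx` holds the global row numbers of one group; `sorted` holds
// positions into that group produced by `arg_sort`.
fn map_sorted_to_global(sorted: &[usize], group_idx: &[u32]) -> Vec<u32> {
    sorted.iter().map(|&i| group_idx[i]).collect()
}

fn main() {
    let sort_by = [30_i64, 10, 20]; // values of one group
    let group_idx = [7_u32, 2, 5]; // global rows of that group
    let local = arg_sort(&sort_by, false); // [1, 2, 0]
    let global = map_sorted_to_global(&local, &group_idx);
    assert_eq!(global, vec![2, 5, 7]);
    println!("{global:?}");
}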
+ unsafe { Ok(series.take_unchecked(&sorted_idx)) } } #[allow(clippy::ptr_arg)] @@ -104,210 +225,90 @@ impl PhysicalExpr for SortByExpr { state: &ExecutionState, ) -> PolarsResult> { let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; - // if the length of the sort_by argument differs - // we raise an error - let invalid = AtomicBool::new(false); + let descending = prepare_descending(&self.descending, self.by.len()); + + let mut ac_sort_by = self + .by + .iter() + .map(|e| e.evaluate_on_groups(df, groups, state)) + .collect::>>()?; + let mut sort_by_s = ac_sort_by + .iter() + .map(|s| { + let s = s.flat_naive(); + match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_) => s.into_owned(), + _ => s.to_physical_repr().into_owned(), + } + }) + .collect::>(); - // the groups of the lhs of the expressions do not match the series values - // we must take the slower path. - if !matches!(ac_in.update_groups, UpdateGroups::No) { + // A check up front to ensure the input expressions have the same number of total elements. + for sort_by_s in &sort_by_s { polars_ensure!( - self.by.len() <= 1, expr = self.expr, ComputeError: - "this expression is not supported for more than two sort columns" + sort_by_s.len() == ac_in.flat_naive().len(), expr = self.expr, ComputeError: + "the expression in `sort_by` argument must result in the same length" ); - let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?; - let sort_by = ac_sort_by.aggregated(); - let mut sort_by = sort_by.list().unwrap().clone(); - let s = ac_in.aggregated(); - let mut s = s.list().unwrap().clone(); - - let descending = self.descending[0]; - let mut ca: ListChunked = POOL.install(|| { - s.par_iter_indexed() - .zip(sort_by.par_iter_indexed()) - .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) { - (Some(s), Some(s_sort_by)) => { - if s.len() != s_sort_by.len() { - invalid.store(true, Ordering::Relaxed); - None - } else { - let idx = s_sort_by.arg_sort(SortOptions { - descending, - // we are already in par iter. - multithreaded: false, - ..Default::default() - }); - Some(unsafe { s.take_unchecked(&idx).unwrap() }) - } - }, - _ => None, - }) - .collect() - }); - ca.rename(s.name()); - let s = ca.into_series(); - ac_in.with_series(s, true, Some(&self.expr))?; - Ok(ac_in) - } else { - let descending = prepare_descending(&self.descending, self.by.len()); - - let (groups, ordered_by_group_operation) = if self.by.len() == 1 { - let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?; - let sort_by_s = ac_sort_by.flat_naive().into_owned(); - polars_ensure!( - sort_by_s.len() == ac_in.flat_naive().len(), expr = self.expr, ComputeError: - "the expression in `sort_by` argument must result in the same length" - ); - let ordered_by_group_operation = matches!( - ac_sort_by.update_groups, - UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen - ); - let groups = ac_sort_by.groups(); - - let groups = POOL.install(|| { - groups - .par_iter() - .map(|indicator| { - let new_idx = match indicator { - GroupsIndicator::Idx((_, idx)) => { - // Safety: - // Group tuples are always in bounds - let group = unsafe { - sort_by_s.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }; - - let sorted_idx = group.arg_sort(SortOptions { - descending: descending[0], - // we are already in par iter. 
- multithreaded: false, - ..Default::default() - }); - map_sorted_indices_to_group_idx(&sorted_idx, idx) - }, - GroupsIndicator::Slice([first, len]) => { - let group = sort_by_s.slice(first as i64, len as usize); - let sorted_idx = group.arg_sort(SortOptions { - descending: descending[0], - // we are already in par iter. - multithreaded: false, - ..Default::default() - }); - map_sorted_indices_to_group_slice(&sorted_idx, first) - }, - }; - let first = new_idx.first().unwrap_or_else(|| { - invalid.store(true, Ordering::Relaxed); - &0 - }); - - (*first, new_idx) - }) - .collect() - }); + } - (GroupsProxy::Idx(groups), ordered_by_group_operation) - } else { - let mut ac_sort_by = self - .by - .iter() - .map(|e| e.evaluate_on_groups(df, groups, state)) - .collect::>>()?; - let sort_by_s = ac_sort_by - .iter() - .map(|s| { - let s = s.flat_naive(); - match s.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_) => s.into_owned(), - _ => s.to_physical_repr().into_owned(), - } - }) - .collect::>(); + let ordered_by_group_operation = matches!( + ac_sort_by[0].update_groups, + UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen + ); - for sort_by_s in &sort_by_s { - polars_ensure!( - sort_by_s.len() == ac_in.flat_naive().len(), expr = self.expr, ComputeError: - "the expression in `sort_by` argument must result in the same length" - ); - } + let groups = if self.by.len() == 1 { + let mut ac_sort_by = ac_sort_by.pop().unwrap(); - let ordered_by_group_operation = matches!( - ac_sort_by[0].update_groups, - UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen + // The groups of the lhs of the expressions do not match the series values, + // we must take the slower path. + if !matches!(ac_in.update_groups, UpdateGroups::No) { + return sort_by_groups_no_match_single( + ac_in, + ac_sort_by, + self.descending[0], + &self.expr, ); - let groups = ac_sort_by[0].groups(); + }; + + let sort_by_s = sort_by_s.pop().unwrap(); + let groups = ac_sort_by.groups(); - let groups = POOL.install(|| { + let (check, groups) = POOL.join( + || check_groups(groups, ac_in.groups()), + || { groups .par_iter() .map(|indicator| { - let new_idx = match indicator { - GroupsIndicator::Idx((_first, idx)) => { - // Safety: - // Group tuples are always in bounds - let groups = sort_by_s - .iter() - .map(|s| unsafe { - s.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }) - .collect::>(); - - let options = SortMultipleOptions { - other: groups[1..].to_vec(), - descending: descending.clone(), - multithreaded: false, - }; - - let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); - map_sorted_indices_to_group_idx(&sorted_idx, idx) - }, - GroupsIndicator::Slice([first, len]) => { - let groups = sort_by_s - .iter() - .map(|s| s.slice(first as i64, len as usize)) - .collect::>(); - - let options = SortMultipleOptions { - other: groups[1..].to_vec(), - descending: descending.clone(), - multithreaded: false, - }; - let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); - map_sorted_indices_to_group_slice(&sorted_idx, first) - }, - }; - let first = new_idx.first().unwrap_or_else(|| { - invalid.store(true, Ordering::Relaxed); - &0 - }); - - (*first, new_idx) + sort_by_groups_single_by(indicator, &sort_by_s, &descending) }) - .collect() - }); - - (GroupsProxy::Idx(groups), ordered_by_group_operation) - }; - polars_ensure!( - !invalid.load(Ordering::Relaxed), expr = self.expr, ComputeError: - "the expression in `sort_by` argument must result in the same length" + .collect::>() 
+ }, ); + check?; - // if the rhs is already aggregated once, - // it is reordered by the group_by operation - // we must ensure that we are as well. - if ordered_by_group_operation { - let s = ac_in.aggregated(); - ac_in.with_series(s.explode().unwrap(), false, None)?; - } + GroupsProxy::Idx(groups?) + } else { + let groups = ac_sort_by[0].groups(); - ac_in.with_groups(groups); - Ok(ac_in) + let groups = POOL.install(|| { + groups + .par_iter() + .map(|indicator| sort_by_groups_multiple_by(indicator, &sort_by_s, &descending)) + .collect::>() + }); + GroupsProxy::Idx(groups?) + }; + + // If the rhs is already aggregated once, it is reordered by the + // group_by operation - we must ensure that we are as well. + if ordered_by_group_operation { + let s = ac_in.aggregated(); + ac_in.with_series(s.explode().unwrap(), false, None)?; } + + ac_in.with_groups(groups); + Ok(ac_in) } fn to_field(&self, input_schema: &Schema) -> PolarsResult { diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index 22b14b517dfa..b07023372a72 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -22,9 +22,7 @@ impl TakeExpr { series: Series, ) -> PolarsResult { let idx = self.idx.evaluate(df, state)?; - let nulls_before_cast = idx.null_count(); - let idx = idx.cast(&IDX_DTYPE)?; if idx.null_count() != nulls_before_cast { self.oob_err()?; @@ -64,17 +62,16 @@ impl PhysicalExpr for TakeExpr { let idx = idx.idx().unwrap(); // The indexes are AggregatedFlat, meaning they are a single values pointing into - // a group. - // If we zip this with the first of each group -> `idx + firs` then we can + // a group. If we zip this with the first of each group -> `idx + firs` then we can // simply use a take operation on the whole array instead of per group. - // The groups maybe scattered all over the place, so we sort by group + // The groups maybe scattered all over the place, so we sort by group. ac.sort_by_groups(); - // A previous aggregation may have updated the groups + // A previous aggregation may have updated the groups. let groups = ac.groups(); - // Determine the take indices + // Determine the take indices. let idx: IdxCa = match groups.as_ref() { GroupsProxy::Idx(groups) => { if groups.all().iter().zip(idx).any(|(g, idx)| match idx { @@ -108,7 +105,7 @@ impl PhysicalExpr for TakeExpr { return Ok(ac); }, AggState::AggregatedList(s) => s.list().unwrap().clone(), - // Maybe a literal as well, this needs a different path + // Maybe a literal as well, this needs a different path. AggState::NotAggregated(_) => { let s = idx.aggregated(); s.list().unwrap().clone() @@ -123,13 +120,13 @@ impl PhysicalExpr for TakeExpr { Some(idx) => { if idx != 0 { // We must make sure that the column we take from is sorted by - // groups otherwise we might point into the wrong group + // groups otherwise we might point into the wrong group. ac.sort_by_groups() } // Make sure that we look at the updated groups. let groups = ac.groups(); - // we offset the groups first by idx; + // We offset the groups first by idx. 
let idx: NoNull = match groups.as_ref() { GroupsProxy::Idx(groups) => { if groups.all().iter().any(|g| idx >= g.len() as IdxSize) { @@ -169,24 +166,17 @@ impl PhysicalExpr for TakeExpr { let s = idx.cast(&DataType::List(Box::new(IDX_DTYPE)))?; let idx = s.list().unwrap(); - let mut taken = ac - .aggregated() - .list() - .unwrap() - .amortized_iter() - .zip(idx.amortized_iter()) - .map(|(s, idx)| { - s.and_then(|s| { - idx.map(|idx| { - let idx = idx.as_ref().idx().unwrap(); - s.as_ref().take(idx) - }) - }) - .transpose() - }) - .collect::>()?; - - taken.rename(ac.series().name()); + let taken = unsafe { + ac.aggregated() + .list() + .unwrap() + .amortized_iter() + .zip(idx.amortized_iter()) + .map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap()))) + .map(|opt_res| opt_res.transpose()) + .collect::>()? + .with_name(ac.series().name()) + }; ac.with_series(taken.into_series(), true, Some(&self.expr))?; ac.with_update_groups(UpdateGroups::WithGroupsLen); diff --git a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs index 1efad2cfdbdf..ffb54b6c366d 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -13,7 +13,7 @@ pub struct TernaryExpr { truthy: Arc, falsy: Arc, expr: Expr, - // can expensive on small data to run literals in parallel + // Can be expensive on small data to run literals in parallel. run_par: bool, } @@ -55,31 +55,34 @@ fn finish_as_iters<'a>( mut ac_falsy: AggregationContext<'a>, mut ac_mask: AggregationContext<'a>, ) -> PolarsResult> { - let mut ca: ListChunked = ac_truthy - .iter_groups(false) - .zip(ac_falsy.iter_groups(false)) - .zip(ac_mask.iter_groups(false)) - .map(|((truthy, falsy), mask)| { - match (truthy, falsy, mask) { - (Some(truthy), Some(falsy), Some(mask)) => Some( - truthy - .as_ref() - .zip_with(mask.as_ref().bool()?, falsy.as_ref()), - ), - _ => None, - } - .transpose() - }) - .collect::>()?; - - ca.rename(ac_truthy.series().name()); - // aggregation leaves only a single chunks + // SAFETY: unstable series never lives longer than the iterator. + let ca = unsafe { + ac_truthy + .iter_groups(false) + .zip(ac_falsy.iter_groups(false)) + .zip(ac_mask.iter_groups(false)) + .map(|((truthy, falsy), mask)| { + match (truthy, falsy, mask) { + (Some(truthy), Some(falsy), Some(mask)) => Some( + truthy + .as_ref() + .zip_with(mask.as_ref().bool()?, falsy.as_ref()), + ), + _ => None, + } + .transpose() + }) + .collect::>()? + .with_name(ac_truthy.series().name()) + }; + + // Aggregation leaves only a single chunk. let arr = ca.downcast_iter().next().unwrap(); let list_vals_len = arr.values().len(); - let mut out = ca.into_series(); + let mut out = ca.into_series(); if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() && - // exploded list should be equal to groups length + // Exploded list should be equal to groups length. list_vals_len == ac_truthy.groups.len() { out = out.explode()? @@ -93,16 +96,16 @@ impl PhysicalExpr for TernaryExpr { fn as_expression(&self) -> Option<&Expr> { Some(&self.expr) } + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let mut state = state.split(); - // don't cache window functions as they run in parallel + // Don't cache window functions as they run in parallel. 
state.remove_cache_window_flag(); let mask_series = self.predicate.evaluate(df, &state)?; let mut mask = mask_series.bool()?.clone(); let op_truthy = || self.truthy.evaluate(df, &state); let op_falsy = || self.falsy.evaluate(df, &state); - let (truthy, falsy) = if self.run_par { POOL.install(|| rayon::join(op_truthy, op_falsy)) } else { @@ -122,9 +125,9 @@ impl PhysicalExpr for TernaryExpr { } expand_lengths(&mut truthy, &mut falsy, &mut mask); - truthy.zip_with(&mask, &falsy) } + fn to_field(&self, input_schema: &Schema) -> PolarsResult { self.truthy.to_field(input_schema) } @@ -138,15 +141,15 @@ impl PhysicalExpr for TernaryExpr { ) -> PolarsResult> { let aggregation_predicate = self.predicate.is_valid_aggregation(); if !aggregation_predicate { - // unwrap will not fail as it is not an aggregation expression. + // Unwrap will not fail as it is not an aggregation expression. eprintln!( "The predicate '{}' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the group_by operation would. This behavior is experimental and may be subject to change", self.predicate.as_expression().unwrap() ) } + let op_mask = || self.predicate.evaluate_on_groups(df, groups, state); let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state); let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state); - let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par { POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy))) } else { @@ -159,10 +162,10 @@ impl PhysicalExpr for TernaryExpr { let mask_s = ac_mask.flat_naive(); - // BIG TODO: find which branches are never hit and remove them + // BIG TODO: find which branches are never hit and remove them. use AggState::*; match (ac_truthy.agg_state(), ac_falsy.agg_state()) { - // all branches are aggregated-flat or literal + // All branches are aggregated-flat or literal // mask -> aggregated-flat // truthy -> aggregated-flat | literal // falsy -> aggregated-flat | literal @@ -175,23 +178,23 @@ impl PhysicalExpr for TernaryExpr { let mut falsy = falsy.clone(); let mut mask = ac_mask.series().bool()?.clone(); expand_lengths(&mut truthy, &mut falsy, &mut mask); - let mut out = truthy.zip_with(&mask, &falsy).unwrap(); - out.rename(truthy.name()); - ac_truthy.with_series(out, true, Some(&self.expr))?; + let out = truthy.zip_with(&mask, &falsy).unwrap(); + ac_truthy.with_series(out.with_name(truthy.name()), true, Some(&self.expr))?; Ok(ac_truthy) }, - // we cannot flatten a list because that changes the order, so we apply over groups + // We cannot flatten a list because that changes the order, so we apply over groups. (AggregatedList(_), NotAggregated(_)) | (NotAggregated(_), AggregatedList(_)) => { finish_as_iters(ac_truthy, ac_falsy, ac_mask) }, - // then: + + // Then: // col().shift() - // otherwise: + // Otherwise: // None (AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => { if !aggregation_predicate { - // experimental elementwise behavior tested in `test_binary_agg_context_1` + // Experimental elementwise behavior tested in `test_binary_agg_context_1`. 
return finish_as_iters(ac_truthy, ac_falsy, ac_mask); } let mask = mask_s.bool()?; @@ -208,16 +211,12 @@ impl PhysicalExpr for TernaryExpr { let s = ac_truthy.aggregated(); let ca = s.list().unwrap(); check_length(ca, mask)?; - let mut out: ListChunked = ca + let out = ca .into_iter() .zip(mask) - .map(|(truthy, take)| match (truthy, take) { - (Some(v), Some(true)) => Some(v), - (Some(_), Some(false)) => None, - _ => None, - }) - .collect_trusted(); - out.rename(ac_truthy.series().name()); + .map(|(truthy, take)| if take? { truthy } else { None }) + .collect_trusted::() + .with_name(ac_truthy.series().name()); ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; Ok(ac_truthy) } else if ac_truthy.is_literal() @@ -226,38 +225,30 @@ impl PhysicalExpr for TernaryExpr { let s = ac_falsy.aggregated(); let ca = s.list().unwrap(); check_length(ca, mask)?; - let mut out: ListChunked = ca + let out = ca .into_iter() .zip(mask) - .map(|(falsy, take)| match (falsy, take) { - (Some(_), Some(true)) => None, - (Some(v), Some(false)) => Some(v), - _ => None, - }) - .collect_trusted(); - out.rename(ac_truthy.series().name()); + .map(|(falsy, take)| if take? { None } else { falsy }) + .collect_trusted::() + .with_name(ac_truthy.series().name()); ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; Ok(ac_truthy) } - // then: + // Then: // col().shift() - // otherwise: + // Otherwise: // lit(list) else if ac_truthy.is_literal() { let literal = ac_truthy.series(); let s = ac_falsy.aggregated(); let ca = s.list().unwrap(); check_length(ca, mask)?; - let mut out: ListChunked = ca + let out = ca .into_iter() .zip(mask) - .map(|(falsy, take)| match (falsy, take) { - (Some(_), Some(true)) => Some(literal.clone()), - (Some(v), Some(false)) => Some(v), - _ => None, - }) - .collect_trusted(); - out.rename(ac_truthy.series().name()); + .map(|(falsy, take)| if take? { Some(literal.clone()) } else { falsy }) + .collect_trusted::() + .with_name(ac_truthy.series().name()); ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; Ok(ac_truthy) } else { @@ -265,25 +256,21 @@ impl PhysicalExpr for TernaryExpr { let s = ac_truthy.aggregated(); let ca = s.list().unwrap(); check_length(ca, mask)?; - let mut out: ListChunked = ca + let out = ca .into_iter() .zip(mask) - .map(|(truthy, take)| match (truthy, take) { - (Some(v), Some(true)) => Some(v), - (Some(_), Some(false)) => Some(literal.clone()), - _ => None, - }) - .collect_trusted(); - out.rename(ac_truthy.series().name()); + .map(|(truthy, take)| if take? { truthy } else { Some(literal.clone()) }) + .collect_trusted::() + .with_name(ac_truthy.series().name()); ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; Ok(ac_truthy) } }, // Both are or a flat series or aggregated into a list - // so we can flatten the Series an apply the operators + // so we can flatten the Series an apply the operators. _ => { - // inspect the predicate and if it is consisting - // if arity/binary and some aggregation we apply as iters as + // Inspect the predicate and if it is consisting + // of arity/binary and some aggregation we apply as iters as // it gets complicated quickly. // For instance: // when(col(..) > min(..)).then(..).otherwise(..) @@ -309,7 +296,7 @@ impl PhysicalExpr for TernaryExpr { } if !aggregation_predicate { - // experimental elementwise behavior tested in `test_binary_agg_context_1` + // Experimental elementwise behavior tested in `test_binary_agg_context_1`. 
return finish_as_iters(ac_truthy, ac_falsy, ac_mask); } let mut mask = mask_s.bool()?.clone(); @@ -318,7 +305,7 @@ impl PhysicalExpr for TernaryExpr { expand_lengths(&mut truthy, &mut falsy, &mut mask); let out = truthy.zip_with(&mask, &falsy)?; - // because of the flattening we don't have to do that anymore + // Because of the flattening we don't have to do that anymore. if matches!(ac_truthy.update_groups, UpdateGroups::WithSeriesLen) { ac_truthy.with_update_groups(UpdateGroups::No); } diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index fb26b0e6fb5c..0d0add86dea3 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -4,13 +4,13 @@ use std::sync::Arc; use polars_arrow::export::arrow::array::PrimitiveArray; use polars_core::export::arrow::bitmap::Bitmap; use polars_core::frame::group_by::{GroupBy, GroupsProxy}; -use polars_core::frame::hash_join::{ - default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, -}; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::_split_offsets; use polars_core::{downcast_as_macro_arg_physical, POOL}; +use polars_ops::frame::join::{ + default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, +}; use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; use polars_utils::sync::SyncPtr; @@ -29,7 +29,7 @@ pub struct WindowExpr { /// A function Expr. i.e. Mean, Median, Max, etc. pub(crate) function: Expr, pub(crate) phys_function: Arc, - pub(crate) options: WindowOptions, + pub(crate) mapping: WindowMapping, pub(crate) expr: Expr, } @@ -120,7 +120,7 @@ impl WindowExpr { // Safety: // groups should always be in bounds. - unsafe { flattened.take_unchecked(&idx) } + unsafe { Ok(flattened.take_unchecked(&idx)) } } #[allow(clippy::too_many_arguments)] @@ -321,7 +321,7 @@ impl WindowExpr { sorted_keys: bool, gb: &GroupBy, ) -> PolarsResult { - match (self.options.mapping, agg_state) { + match (self.mapping, agg_state) { // Explode // `(col("x").sum() * col("y")).list().over("groups").flatten()` (WindowMapping::Explode, _) => Ok(MapStrategy::Explode), @@ -423,7 +423,7 @@ impl PhysicalExpr for WindowExpr { let explicit_list_agg = self.is_explicit_list_agg(); // if we flatten this column we need to make sure the groups are sorted. - let mut sort_groups = matches!(self.options.mapping, WindowMapping::Explode) || + let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) || // if not // `col().over()` // and not @@ -454,7 +454,7 @@ impl PhysicalExpr for WindowExpr { cache_key.push_str(s.name()); } - let mut gt_map = state.group_tuples.lock().unwrap(); + let mut gt_map = state.group_tuples.write().unwrap(); // we run sequential and partitioned // and every partition run the cache should be empty so we expect a max of 1. debug_assert!(gt_map.len() <= 1); @@ -495,7 +495,7 @@ impl PhysicalExpr for WindowExpr { // Worst case is that a categorical is created with indexes from the string // cache which is fine, as the physical representation is undefined. 
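Elsewhere in these hunks the window-expression cache of group tuples moves from a `Mutex` to an `RwLock`, since it is written once per cache key and then read by many expressions. A small sketch with std types (the cache layout here is illustrative, not the Polars `ExecutionState`):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

type GroupsCache = Arc<RwLock<HashMap<String, Vec<u32>>>>;

fn cache_groups(cache: &GroupsCache, key: &str, groups: Vec<u32>) {
    // Exclusive access is only needed while inserting.
    cache.write().unwrap().insert(key.to_string(), groups);
}

fn read_groups(cache: &GroupsCache, key: &str) -> Option<Vec<u32>> {
    // Many threads may hold the read lock at the same time.
    cache.read().unwrap().get(key).cloned()
}

fn main() {
    let cache: GroupsCache = Arc::new(RwLock::new(HashMap::new()));
    cache_groups(&cache, "over(groups)", vec![0, 3, 5]);
    assert_eq!(read_groups(&cache, "over(groups)"), Some(vec![0, 3, 5]));
    println!("cached");
}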
#[cfg(feature = "dtype-categorical")] - let _sc = polars_core::IUseStringCache::hold(); + let _sc = polars_core::StringCacheHolder::hold(); let mut ac = self.run_aggregation(df, state, &gb)?; use MapStrategy::*; @@ -635,9 +635,7 @@ fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Se match join_opt_ids { Either::Left(ids) => unsafe { - out_column.take_opt_iter_unchecked( - &mut ids.iter().map(|&opt_i| opt_i.map(|i| i as usize)), - ) + out_column.take_unchecked(&ids.iter().copied().collect_ca("")) }, Either::Right(ids) => unsafe { out_column._take_opt_chunked_unchecked(ids) }, } @@ -645,16 +643,14 @@ fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Se #[cfg(not(feature = "chunked_ids"))] unsafe { - out_column.take_opt_iter_unchecked( - &mut join_opt_ids.iter().map(|&opt_i| opt_i.map(|i| i as usize)), - ) + out_column.take_unchecked(&join_opt_ids.iter().copied().collect_ca("")) } } fn cache_gb(gb: GroupBy, state: &ExecutionState, cache_key: &str) { if state.cache_window() { let groups = gb.take_groups(); - let mut gt_map = state.group_tuples.lock().unwrap(); + let mut gt_map = state.group_tuples.write().unwrap(); gt_map.insert(cache_key.to_string(), groups); } } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 2411539969d0..0e614c4cb683 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -95,59 +95,71 @@ pub(crate) fn create_physical_expr( Window { mut function, partition_by, - order_by: _, options, } => { - state.set_window(); - // TODO! Order by - let group_by = create_physical_expressions( - &partition_by, - Context::Default, - expr_arena, - schema, - state, - )?; - - // set again as the state can be reset state.set_window(); let phys_function = create_physical_expr(function, Context::Aggregation, expr_arena, schema, state)?; - let mut out_name = None; - let mut apply_columns = aexpr_to_leaf_names(function, expr_arena); - // sort and then dedup removes consecutive duplicates == all duplicates - apply_columns.sort(); - apply_columns.dedup(); - - if apply_columns.is_empty() { - if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) { - apply_columns.push(Arc::from("literal")) - } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Count)) { - apply_columns.push(Arc::from("count")) - } else { - let e = node_to_expr(function, expr_arena); - polars_bail!( - ComputeError: - "cannot apply a window function, did not find a root column; \ - this is likely due to a syntax error in this expression: {:?}", e - ); - } - } + let mut out_name = None; if let Alias(expr, name) = expr_arena.get(function) { function = *expr; out_name = Some(name.clone()); }; - let function = node_to_expr(function, expr_arena); + let function_expr = node_to_expr(function, expr_arena); + let expr = node_to_expr(expression, expr_arena); - Ok(Arc::new(WindowExpr { - group_by, - apply_columns, - out_name, - function, - phys_function, - options, - expr: node_to_expr(expression, expr_arena), - })) + match options { + WindowType::Over(mapping) => { + // set again as the state can be reset + state.set_window(); + // TODO! 
Order by + let group_by = create_physical_expressions( + &partition_by, + Context::Default, + expr_arena, + schema, + state, + )?; + let mut apply_columns = aexpr_to_leaf_names(function, expr_arena); + // sort and then dedup removes consecutive duplicates == all duplicates + apply_columns.sort(); + apply_columns.dedup(); + + if apply_columns.is_empty() { + if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) { + apply_columns.push(Arc::from("literal")) + } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Count)) { + apply_columns.push(Arc::from("count")) + } else { + let e = node_to_expr(function, expr_arena); + polars_bail!( + ComputeError: + "cannot apply a window function, did not find a root column; \ + this is likely due to a syntax error in this expression: {:?}", e + ); + } + } + + Ok(Arc::new(WindowExpr { + group_by, + apply_columns, + out_name, + function: function_expr, + phys_function, + mapping, + expr, + })) + }, + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => Ok(Arc::new(RollingExpr { + function: function_expr, + phys_function, + out_name, + options, + expr, + })), + } }, Literal(value) => { state.local.has_lit = true; diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 5a08853631dd..45af5e1f3a11 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -150,9 +150,16 @@ pub fn create_physical_plan( match logical_plan { #[cfg(feature = "python")] PythonScan { options, .. } => Ok(Box::new(executors::PythonScanExec { options })), - FileSink { .. } => panic!( - "sink_parquet not yet supported in standard engine. Use 'collect().write_parquet()'" - ), + Sink { payload, .. } => { + match payload { + SinkType::Memory => panic!("Memory Sink not supported in the standard engine."), + SinkType::File{file_type, ..} => panic!( + "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_parquet()'" + ), + #[cfg(feature = "cloud")] + SinkType::Cloud{..} => panic!("Cloud Sink not supported in standard engine.") + } + } Union { inputs, options } => { let inputs = inputs .into_iter() @@ -189,6 +196,7 @@ pub fn create_physical_plan( predicate, file_options, } => { + let mut state = ExpressionConversionState::default(); let predicate = predicate .map(|pred| { create_physical_expr( @@ -196,7 +204,7 @@ pub fn create_physical_plan( Context::Default, expr_arena, output_schema.as_ref(), - &mut Default::default(), + &mut state, ) }) .map_or(Ok(None), |v| v.map(Some))?; @@ -224,14 +232,30 @@ pub fn create_physical_plan( FileScan::Parquet { options, cloud_options, + metadata } => Ok(Box::new(executors::ParquetExec::new( path, - file_info.schema, + file_info, predicate, options, cloud_options, file_options, + metadata ))), + FileScan::Anonymous { + function, + .. + } => { + Ok(Box::new(executors::AnonymousScanExec { + function, + predicate, + file_options, + file_info, + output_schema, + predicate_has_windows: state.has_windows, + })) + + } } }, Projection { @@ -269,34 +293,6 @@ pub fn create_physical_plan( options, })) }, - LocalProjection { - expr, - input, - schema: _schema, - .. 
- } => { - let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); - - let input = create_physical_plan(input, lp_arena, expr_arena)?; - let mut state = ExpressionConversionState::new(POOL.current_num_threads() > expr.len()); - let phys_expr = create_physical_expressions( - &expr, - Context::Default, - expr_arena, - Some(&input_schema), - &mut state, - )?; - Ok(Box::new(executors::ProjectionExec { - input, - cse_exprs: vec![], - expr: phys_expr, - has_windows: state.has_windows, - input_schema, - #[cfg(test)] - schema: _schema, - options: Default::default(), - })) - }, DataFrameScan { df, projection, @@ -323,33 +319,6 @@ pub fn create_physical_plan( predicate_has_windows: state.has_windows, })) }, - AnonymousScan { - function, - predicate, - options, - output_schema, - .. - } => { - let mut state = ExpressionConversionState::default(); - let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone()); - let predicate = predicate - .map(|pred| { - create_physical_expr( - pred, - Context::Default, - expr_arena, - output_schema.as_ref(), - &mut state, - ) - }) - .map_or(Ok(None), |v| v.map(Some))?; - Ok(Box::new(executors::AnonymousScanExec { - function, - predicate, - options, - predicate_has_windows: state.has_windows, - })) - }, Sort { input, by_column, diff --git a/crates/polars-lazy/src/physical_plan/state.rs b/crates/polars-lazy/src/physical_plan/state.rs index ebf1e501885c..c8db863a126e 100644 --- a/crates/polars-lazy/src/physical_plan/state.rs +++ b/crates/polars-lazy/src/physical_plan/state.rs @@ -6,8 +6,8 @@ use bitflags::bitflags; use once_cell::sync::OnceCell; use polars_core::config::verbose; use polars_core::frame::group_by::GroupsProxy; -use polars_core::frame::hash_join::ChunkJoinOptIds; use polars_core::prelude::*; +use polars_ops::prelude::ChunkJoinOptIds; #[cfg(any(feature = "parquet", feature = "csv", feature = "ipc"))] use polars_plan::logical_plan::FileFingerPrint; @@ -16,7 +16,7 @@ use super::file_cache::FileCache; use crate::physical_plan::node_timer::NodeTimer; pub type JoinTuplesCache = Arc>>; -pub type GroupsProxyCache = Arc>>; +pub type GroupsProxyCache = Arc>>; bitflags! 
{ #[repr(transparent)] @@ -149,7 +149,7 @@ impl ExecutionState { schema_cache: Default::default(), #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] file_cache: FileCache::new(finger_prints), - group_tuples: Arc::new(Mutex::new(PlHashMap::default())), + group_tuples: Arc::new(RwLock::new(PlHashMap::default())), join_tuples: Arc::new(Mutex::new(PlHashMap::default())), branch_idx: 0, flags: AtomicU8::new(StateFlags::init().as_u8()), @@ -204,7 +204,7 @@ impl ExecutionState { /// Clear the cache used by the Window expressions pub(crate) fn clear_window_expr_cache(&self) { { - let mut lock = self.group_tuples.lock().unwrap(); + let mut lock = self.group_tuples.write().unwrap(); lock.clear(); } let mut lock = self.join_tuples.lock().unwrap(); diff --git a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs index 473b429f8780..02f6f636e9b4 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/checks.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{JoinArgs, JoinType}; +use polars_ops::prelude::*; use polars_plan::prelude::*; pub(super) fn is_streamable_sort(args: &SortArguments) -> bool { diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index baf18e913ef3..240fb7aaa7ec 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -122,7 +122,7 @@ pub(super) fn construct( // the file sink is always to the top of the tree // not every branch has a final sink. For instance rhs join branches if let Some(node) = branch.get_final_sink() { - if matches!(lp_arena.get(node), ALogicalPlan::FileSink { .. }) { + if matches!(lp_arena.get(node), ALogicalPlan::Sink { .. }) { final_sink = Some(node) } } @@ -191,20 +191,16 @@ pub(super) fn construct( return Ok(None); }; let insertion_location = match lp_arena.get(final_sink) { - FileSink { + // this was inserted only during conversion and does not exist + // in the original tree, so we take the input, as that's where + // we connect into the original tree. + Sink { input, - payload: FileSinkOptions { file_type, .. }, - } => { - // this was inserted only during conversion and does not exist - // in the original tree, so we take the input, as that's where - // we connect into the original tree. - if matches!(file_type, FileType::Memory) { - *input - } else { - // default case if the tree ended with a file_sink - final_sink - } - }, + payload: SinkType::Memory, + } => *input, + // Other sinks were not inserted during conversion, + // so they are returned as-is + Sink { .. } => final_sink, _ => unreachable!(), }; // keep the original around for formatting purposes diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 20369fafd237..e23357f58d40 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -37,13 +37,10 @@ fn process_non_streamable_node( fn insert_file_sink(mut root: Node, lp_arena: &mut Arena) -> Node { // The pipelines need a final sink, we insert that here. // this allows us to split at joins/unions and share a sink - if !matches!(lp_arena.get(root), ALogicalPlan::FileSink { .. 
}) { - root = lp_arena.add(ALogicalPlan::FileSink { + if !matches!(lp_arena.get(root), ALogicalPlan::Sink { .. }) { + root = lp_arena.add(ALogicalPlan::Sink { input: root, - payload: FileSinkOptions { - path: Default::default(), - file_type: FileType::Memory, - }, + payload: SinkType::Memory, }) } root @@ -74,9 +71,11 @@ pub(crate) fn insert_streaming_nodes( // to streaming allow_partial: bool, ) -> PolarsResult { + scratch.clear(); + // this is needed to determine which side of the joins should be // traversed first - set_estimated_row_counts(root, lp_arena, expr_arena, 0); + set_estimated_row_counts(root, lp_arena, expr_arena, 0, scratch); scratch.clear(); @@ -154,7 +153,7 @@ pub(crate) fn insert_streaming_nodes( state.operators_sinks.push(PipelineNode::Sink(root)); stack.push((*input, state, current_idx)) }, - FileSink { input, .. } => { + Sink { input, .. } => { state.streamable = true; state.operators_sinks.push(PipelineNode::Sink(root)); stack.push((*input, state, current_idx)) diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index 177817d7d973..81c31fe943db 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -1,4 +1,5 @@ pub(crate) use polars_ops::prelude::*; +pub use polars_ops::prelude::{JoinArgs, JoinType, JoinValidation}; pub use polars_plan::logical_plan::{ AnonymousScan, AnonymousScanOptions, Literal, LiteralValue, LogicalPlan, Null, NULL, }; @@ -18,3 +19,4 @@ pub(crate) use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; pub use crate::frame::*; pub use crate::physical_plan::expressions::*; +pub(crate) use crate::scan::*; diff --git a/crates/polars-lazy/src/frame/anonymous_scan.rs b/crates/polars-lazy/src/scan/anonymous_scan.rs similarity index 97% rename from crates/polars-lazy/src/frame/anonymous_scan.rs rename to crates/polars-lazy/src/scan/anonymous_scan.rs index e28871b78efe..2a26305eb84b 100644 --- a/crates/polars-lazy/src/frame/anonymous_scan.rs +++ b/crates/polars-lazy/src/scan/anonymous_scan.rs @@ -6,7 +6,7 @@ use crate::prelude::*; #[derive(Clone)] pub struct ScanArgsAnonymous { pub infer_schema_length: Option, - pub schema: Option, + pub schema: Option, pub skip_rows: Option, pub n_rows: Option, pub row_count: Option, diff --git a/crates/polars-lazy/src/frame/csv.rs b/crates/polars-lazy/src/scan/csv.rs similarity index 92% rename from crates/polars-lazy/src/frame/csv.rs rename to crates/polars-lazy/src/scan/csv.rs index be497c336388..639b508dd7b2 100644 --- a/crates/polars-lazy/src/frame/csv.rs +++ b/crates/polars-lazy/src/scan/csv.rs @@ -1,8 +1,9 @@ use std::path::{Path, PathBuf}; use polars_core::prelude::*; -use polars_io::csv::utils::{get_reader_bytes, infer_file_schema}; +use polars_io::csv::utils::infer_file_schema; use polars_io::csv::{CsvEncoding, NullValues}; +use polars_io::utils::get_reader_bytes; use polars_io::RowCount; use crate::frame::LazyFileListReader; @@ -12,7 +13,8 @@ use crate::prelude::*; #[cfg(feature = "csv")] pub struct LazyCsvReader<'a> { path: PathBuf, - delimiter: u8, + paths: Vec, + separator: u8, has_header: bool, ignore_errors: bool, skip_rows: usize, @@ -38,10 +40,15 @@ pub struct LazyCsvReader<'a> { #[cfg(feature = "csv")] impl<'a> LazyCsvReader<'a> { + pub fn new_paths(paths: Vec) -> Self { + Self::new("").with_paths(paths) + } + pub fn new(path: impl AsRef) -> Self { LazyCsvReader { path: path.as_ref().to_owned(), - delimiter: b',', + paths: vec![], + separator: b',', has_header: true, ignore_errors: false, skip_rows: 0, @@ -133,10 +140,10 @@ 
impl<'a> LazyCsvReader<'a> { self } - /// Set the CSV file's column delimiter as a byte character + /// Set the CSV file's column separator as a byte character #[must_use] - pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.delimiter = delimiter; + pub fn with_separator(mut self, separator: u8) -> Self { + self.separator = separator; self } @@ -224,12 +231,12 @@ impl<'a> LazyCsvReader<'a> { where F: Fn(Schema) -> PolarsResult, { - let mut file = if let Some(mut paths) = self.glob()? { + let mut file = if let Some(mut paths) = self.iter_paths()? { let path = match paths.next() { Some(globresult) => globresult?, None => polars_bail!(ComputeError: "globbing pattern did not match any files"), }; - polars_utils::open_file(&path) + polars_utils::open_file(path) } else { polars_utils::open_file(&self.path) }?; @@ -238,7 +245,7 @@ impl<'a> LazyCsvReader<'a> { let (schema, _, _) = infer_file_schema( &reader_bytes, - self.delimiter, + self.separator, self.infer_schema_length, self.has_header, // we set it to None and modify them after the schema is updated @@ -269,7 +276,7 @@ impl LazyFileListReader for LazyCsvReader<'_> { fn finish_no_glob(self) -> PolarsResult { let mut lf: LazyFrame = LogicalPlanBuilder::scan_csv( self.path, - self.delimiter, + self.separator, self.has_header, self.ignore_errors, self.skip_rows, @@ -301,11 +308,20 @@ impl LazyFileListReader for LazyCsvReader<'_> { &self.path } + fn paths(&self) -> &[PathBuf] { + &self.paths + } + fn with_path(mut self, path: PathBuf) -> Self { self.path = path; self } + fn with_paths(mut self, paths: Vec) -> Self { + self.paths = paths; + self + } + fn rechunk(&self) -> bool { self.rechunk } diff --git a/crates/polars-lazy/src/frame/file_list_reader.rs b/crates/polars-lazy/src/scan/file_list_reader.rs similarity index 68% rename from crates/polars-lazy/src/frame/file_list_reader.rs rename to crates/polars-lazy/src/scan/file_list_reader.rs index 8824406f2599..1ba125312508 100644 --- a/crates/polars-lazy/src/frame/file_list_reader.rs +++ b/crates/polars-lazy/src/scan/file_list_reader.rs @@ -1,17 +1,17 @@ use std::path::{Path, PathBuf}; -use polars_core::cloud::CloudOptions; use polars_core::error::to_compute_err; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::{is_cloud_url, RowCount}; use crate::prelude::*; -pub type GlobIterator = Box>>; +pub type PathIterator = Box>>; // cloud_options is used only with async feature #[allow(unused_variables)] -fn polars_glob(pattern: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult { +fn polars_glob(pattern: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult { if is_cloud_url(pattern) { #[cfg(feature = "async")] { @@ -34,12 +34,14 @@ fn polars_glob(pattern: &str, cloud_options: Option<&CloudOptions>) -> PolarsRes /// Use [LazyFileListReader::finish] to get the final [LazyFrame]. pub trait LazyFileListReader: Clone { /// Get the final [LazyFrame]. - fn finish(self) -> PolarsResult { - if let Some(paths) = self.glob()? { + fn finish(mut self) -> PolarsResult { + if let Some(paths) = self.iter_paths()? 
{ let lfs = paths - .map(|r| { + .enumerate() + .map(|(i, r)| { let path = r?; - self.clone() + let lf = self + .clone() .with_path(path.clone()) .with_rechunk(false) .finish_no_glob() @@ -47,7 +49,15 @@ pub trait LazyFileListReader: Clone { polars_err!( ComputeError: "error while reading {}: {}", path.display(), e ) - }) + }); + + if i == 0 { + let lf = lf?; + self.set_known_schema(lf.schema()?); + Ok(lf) + } else { + lf + } }) .collect::>>()?; @@ -88,11 +98,18 @@ pub trait LazyFileListReader: Clone { /// It can be potentially a glob pattern. fn path(&self) -> &Path; + fn paths(&self) -> &[PathBuf]; + /// Set path of the scanned file. /// Support glob patterns. #[must_use] fn with_path(self, path: PathBuf) -> Self; + /// Set paths of the scanned files. + /// Doesn't glob patterns. + #[must_use] + fn with_paths(self, paths: Vec) -> Self; + /// Rechunk the memory to contiguous chunks when parsing is done. fn rechunk(&self) -> bool; @@ -112,15 +129,31 @@ pub trait LazyFileListReader: Clone { None } + /// Set a schema on first glob pattern, so that others don't have to fetch metadata + /// from cloud + fn known_schema(&self) -> Option { + None + } + + fn set_known_schema(&mut self, _known_schema: SchemaRef) {} + /// Get list of files referenced by this reader. /// /// Returns [None] if path is not a glob pattern. - fn glob(&self) -> PolarsResult> { - let path_str = self.path().to_string_lossy(); - if path_str.contains('*') || path_str.contains('?') || path_str.contains('[') { - polars_glob(&path_str, self.cloud_options()).map(Some) + fn iter_paths(&self) -> PolarsResult> { + let paths = self.paths(); + if paths.is_empty() { + let path_str = self.path().to_string_lossy(); + if path_str.contains('*') || path_str.contains('?') || path_str.contains('[') { + polars_glob(&path_str, self.cloud_options()).map(Some) + } else { + Ok(None) + } } else { - Ok(None) + polars_ensure!(self.path().to_string_lossy() == "", InvalidOperation: "expected only a single path argument"); + // Lint is incorrect as we need static lifetime. 
+ #[allow(clippy::unnecessary_to_owned)] + Ok(Some(Box::new(paths.to_vec().into_iter().map(Ok)))) } } } diff --git a/crates/polars-lazy/src/frame/ipc.rs b/crates/polars-lazy/src/scan/ipc.rs similarity index 82% rename from crates/polars-lazy/src/frame/ipc.rs rename to crates/polars-lazy/src/scan/ipc.rs index 9401982f0698..e369b990f4df 100644 --- a/crates/polars-lazy/src/frame/ipc.rs +++ b/crates/polars-lazy/src/scan/ipc.rs @@ -30,11 +30,16 @@ impl Default for ScanArgsIpc { struct LazyIpcReader { args: ScanArgsIpc, path: PathBuf, + paths: Vec, } impl LazyIpcReader { fn new(path: PathBuf, args: ScanArgsIpc) -> Self { - Self { args, path } + Self { + args, + path, + paths: vec![], + } } } @@ -70,11 +75,20 @@ impl LazyFileListReader for LazyIpcReader { self.path.as_path() } + fn paths(&self) -> &[PathBuf] { + &self.paths + } + fn with_path(mut self, path: PathBuf) -> Self { self.path = path; self } + fn with_paths(mut self, paths: Vec) -> Self { + self.paths = paths; + self + } + fn rechunk(&self) -> bool { self.args.rechunk } @@ -98,4 +112,10 @@ impl LazyFrame { pub fn scan_ipc(path: impl AsRef, args: ScanArgsIpc) -> PolarsResult { LazyIpcReader::new(path.as_ref().to_owned(), args).finish() } + + pub fn scan_ipc_files(paths: Vec, args: ScanArgsIpc) -> PolarsResult { + LazyIpcReader::new(PathBuf::new(), args) + .with_paths(paths) + .finish() + } } diff --git a/crates/polars-lazy/src/scan/mod.rs b/crates/polars-lazy/src/scan/mod.rs new file mode 100644 index 000000000000..ed8705a01b8a --- /dev/null +++ b/crates/polars-lazy/src/scan/mod.rs @@ -0,0 +1,12 @@ +pub(super) mod anonymous_scan; +#[cfg(feature = "csv")] +pub(super) mod csv; +pub(super) mod file_list_reader; +#[cfg(feature = "ipc")] +pub(super) mod ipc; +#[cfg(feature = "json")] +pub(super) mod ndjson; +#[cfg(feature = "parquet")] +pub(super) mod parquet; + +use file_list_reader::*; diff --git a/crates/polars-lazy/src/frame/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs similarity index 83% rename from crates/polars-lazy/src/frame/ndjson.rs rename to crates/polars-lazy/src/scan/ndjson.rs index 45a5edbe4656..bd21e00ddf6a 100644 --- a/crates/polars-lazy/src/frame/ndjson.rs +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -3,24 +3,31 @@ use std::path::{Path, PathBuf}; use polars_core::prelude::*; use polars_io::RowCount; -use super::{LazyFileListReader, LazyFrame, ScanArgsAnonymous}; +use super::*; +use crate::prelude::{LazyFrame, ScanArgsAnonymous}; #[derive(Clone)] pub struct LazyJsonLineReader { pub(crate) path: PathBuf, + paths: Vec, pub(crate) batch_size: Option, pub(crate) low_memory: bool, pub(crate) rechunk: bool, - pub(crate) schema: Option, + pub(crate) schema: Option, pub(crate) row_count: Option, pub(crate) infer_schema_length: Option, pub(crate) n_rows: Option, } impl LazyJsonLineReader { + pub fn new_paths(paths: Vec) -> Self { + Self::new(PathBuf::new()).with_paths(paths) + } + pub fn new(path: impl AsRef) -> Self { LazyJsonLineReader { path: path.as_ref().to_path_buf(), + paths: vec![], batch_size: None, low_memory: false, rechunk: true, @@ -45,6 +52,7 @@ impl LazyJsonLineReader { } /// Set the number of rows to use when inferring the json schema. /// the default is 100 rows. + /// Ignored when the schema is specified explicitly using [`Self::with_schema`]. /// Setting to `None` will do a full table scan, very slow. 
#[must_use] pub fn with_infer_schema_length(mut self, num_rows: Option) -> Self { @@ -53,8 +61,8 @@ impl LazyJsonLineReader { } /// Set the JSON file's schema #[must_use] - pub fn with_schema(mut self, schema: Schema) -> Self { - self.schema = Some(schema); + pub fn with_schema(mut self, schema: Option) -> Self { + self.schema = schema; self } @@ -90,11 +98,20 @@ impl LazyFileListReader for LazyJsonLineReader { &self.path } + fn paths(&self) -> &[PathBuf] { + &self.paths + } + fn with_path(mut self, path: PathBuf) -> Self { self.path = path; self } + fn with_paths(mut self, paths: Vec) -> Self { + self.paths = paths; + self + } + fn rechunk(&self) -> bool { self.rechunk } diff --git a/crates/polars-lazy/src/frame/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs similarity index 69% rename from crates/polars-lazy/src/frame/parquet.rs rename to crates/polars-lazy/src/scan/parquet.rs index c71ed3b7821d..d4f42c014163 100644 --- a/crates/polars-lazy/src/frame/parquet.rs +++ b/crates/polars-lazy/src/scan/parquet.rs @@ -1,7 +1,7 @@ use std::path::{Path, PathBuf}; -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::parquet::ParallelStrategy; use polars_io::RowCount; @@ -17,6 +17,7 @@ pub struct ScanArgsParquet { pub low_memory: bool, pub cloud_options: Option, pub use_statistics: bool, + pub hive_partitioning: bool, } impl Default for ScanArgsParquet { @@ -30,6 +31,7 @@ impl Default for ScanArgsParquet { low_memory: false, cloud_options: None, use_statistics: true, + hive_partitioning: false, } } } @@ -38,16 +40,24 @@ impl Default for ScanArgsParquet { struct LazyParquetReader { args: ScanArgsParquet, path: PathBuf, + paths: Vec, + known_schema: Option, } impl LazyParquetReader { fn new(path: PathBuf, args: ScanArgsParquet) -> Self { - Self { args, path } + Self { + args, + path, + paths: vec![], + known_schema: None, + } } } impl LazyFileListReader for LazyParquetReader { - fn finish_no_glob(self) -> PolarsResult { + fn finish_no_glob(mut self) -> PolarsResult { + let known_schema = self.known_schema(); let row_count = self.args.row_count; let path = self.path; let mut lf: LazyFrame = LogicalPlanBuilder::scan_parquet( @@ -60,6 +70,8 @@ impl LazyFileListReader for LazyParquetReader { self.args.low_memory, self.args.cloud_options, self.args.use_statistics, + self.args.hive_partitioning, + known_schema, )? 
.build() .into(); @@ -68,6 +80,7 @@ impl LazyFileListReader for LazyParquetReader { if let Some(row_count) = row_count { lf = lf.with_row_count(&row_count.name, Some(row_count.offset)) } + self.known_schema = Some(lf.schema()?); lf.opt_state.file_caching = true; Ok(lf) @@ -77,11 +90,20 @@ impl LazyFileListReader for LazyParquetReader { self.path.as_path() } + fn paths(&self) -> &[PathBuf] { + &self.paths + } + fn with_path(mut self, path: PathBuf) -> Self { self.path = path; self } + fn with_paths(mut self, paths: Vec) -> Self { + self.paths = paths; + self + } + fn rechunk(&self) -> bool { self.args.rechunk } @@ -99,6 +121,13 @@ impl LazyFileListReader for LazyParquetReader { self.args.n_rows } + fn known_schema(&self) -> Option { + self.known_schema.clone() + } + fn set_known_schema(&mut self, known_schema: SchemaRef) { + self.known_schema = Some(known_schema); + } + fn row_count(&self) -> Option<&RowCount> { self.args.row_count.as_ref() } @@ -109,4 +138,11 @@ impl LazyFrame { pub fn scan_parquet(path: impl AsRef, args: ScanArgsParquet) -> PolarsResult { LazyParquetReader::new(path.as_ref().to_owned(), args).finish() } + + /// Create a LazyFrame directly from a parquet scan. + pub fn scan_parquet_files(paths: Vec, args: ScanArgsParquet) -> PolarsResult { + LazyParquetReader::new(PathBuf::new(), args) + .with_paths(paths) + .finish() + } } diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index d0e620056ea1..87e9e594052b 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -52,6 +52,7 @@ fn test_agg_unique_first() -> PolarsResult<()> { } #[test] +#[cfg(feature = "csv")] fn test_lazy_agg_scan() { let lf = scan_foods_csv; let df = lf().min().collect().unwrap(); @@ -183,7 +184,7 @@ fn test_power_in_agg_list1() -> PolarsResult<()> { .collect()?; let agg = out.column("foo")?.list()?; - let first = agg.get(0).unwrap(); + let first = agg.get_as_series(0).unwrap(); let vals = first.f64()?; assert_eq!(Vec::from(vals), &[Some(1.0), Some(4.0), Some(25.0)]); diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 1a880980e00b..d7eae6909f45 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -241,7 +241,6 @@ fn test_csv_globbing() -> PolarsResult<()> { let df = lf.clone().collect()?; assert_eq!(df.shape(), (100, 4)); let df = LazyCsvReader::new(glob).finish()?.slice(20, 60).collect()?; - dbg!(&full_df, &df); assert!(full_df.slice(20, 60).frame_equal(&df)); let mut expr_arena = Arena::with_capacity(16); @@ -435,10 +434,10 @@ fn scan_predicate_on_set_null_values() -> PolarsResult<()> { #[test] fn scan_anonymous_fn() -> PolarsResult<()> { - let function = Arc::new(|_scan_opts: AnonymousScanOptions| Ok(fruits_cars())); + let function = Arc::new(|_scan_opts: AnonymousScanArgs| Ok(fruits_cars())); let args = ScanArgsAnonymous { - schema: Some(fruits_cars().schema()), + schema: Some(Arc::new(fruits_cars().schema())), ..ScanArgsAnonymous::default() }; diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs index 2e1d991c1a95..641e250f6dd7 100644 --- a/crates/polars-lazy/src/tests/mod.rs +++ b/crates/polars-lazy/src/tests/mod.rs @@ -78,10 +78,13 @@ fn init_files() { match ext { ".parquet" => { - ParquetWriter::new(f) - .with_statistics(true) - .finish(&mut df) - .unwrap(); + #[cfg(feature = "parquet")] + { + ParquetWriter::new(f) + .with_statistics(true) + .finish(&mut df) + .unwrap(); + } }, 
".ipc" => { IpcWriter::new(f).finish(&mut df).unwrap(); diff --git a/crates/polars-lazy/src/tests/optimization_checks.rs b/crates/polars-lazy/src/tests/optimization_checks.rs index ab6dcb57b177..b3763e2a13ab 100644 --- a/crates/polars-lazy/src/tests/optimization_checks.rs +++ b/crates/polars-lazy/src/tests/optimization_checks.rs @@ -38,6 +38,25 @@ pub(crate) fn predicate_at_scan(q: LazyFrame) -> bool { }) } +pub(crate) fn predicate_at_all_scans(q: LazyFrame) -> bool { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + (&lp_arena).iter(lp).all(|(_, lp)| { + use ALogicalPlan::*; + matches!( + lp, + DataFrameScan { + selection: Some(_), + .. + } | Scan { + predicate: Some(_), + .. + } + ) + }) +} + pub(crate) fn is_pipeline(q: LazyFrame) -> bool { let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); @@ -147,6 +166,7 @@ fn test_no_left_join_pass() -> PolarsResult<()> { } #[test] +#[cfg(feature = "parquet")] pub fn test_simple_slice() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); let q = scan_foods_parquet(false).limit(3); @@ -166,6 +186,7 @@ pub fn test_simple_slice() -> PolarsResult<()> { } #[test] +#[cfg(feature = "parquet")] #[cfg(feature = "cse")] pub fn test_slice_pushdown_join() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); @@ -202,6 +223,7 @@ pub fn test_slice_pushdown_join() -> PolarsResult<()> { } #[test] +#[cfg(feature = "parquet")] pub fn test_slice_pushdown_group_by() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); let q = scan_foods_parquet(false).limit(100); @@ -231,6 +253,7 @@ pub fn test_slice_pushdown_group_by() -> PolarsResult<()> { } #[test] +#[cfg(feature = "parquet")] pub fn test_slice_pushdown_sort() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); let q = scan_foods_parquet(false).limit(100); @@ -527,6 +550,7 @@ fn test_with_column_prune() -> PolarsResult<()> { } #[test] +#[cfg(feature = "csv")] fn test_slice_at_scan_group_by() -> PolarsResult<()> { let ldf = scan_foods_csv(); diff --git a/crates/polars-lazy/src/tests/predicate_queries.rs b/crates/polars-lazy/src/tests/predicate_queries.rs index 0a854bf420a6..13c39ac620a2 100644 --- a/crates/polars-lazy/src/tests/predicate_queries.rs +++ b/crates/polars-lazy/src/tests/predicate_queries.rs @@ -1,6 +1,7 @@ use super::*; #[test] +#[cfg(feature = "parquet")] fn test_multiple_roots() -> PolarsResult<()> { let mut expr_arena = Arena::with_capacity(16); let mut lp_arena = Arena::with_capacity(8); @@ -268,3 +269,36 @@ fn test_predicate_on_join_suffix_4788() -> PolarsResult<()> { Ok(()) } + +#[test] +fn test_push_join_col_predicates_to_both_sides_7247() -> PolarsResult<()> { + let df1 = df! { + "a" => ["a1", "a2"], + "b" => ["b1", "b2"], + }?; + let df2 = df! 
{ + "a" => ["a1", "a1", "a2"], + "b2" => ["b1", "b1", "b2"], + "c" => ["a1", "c", "a2"] + }?; + let df = df1.lazy().join( + df2.lazy(), + [col("a"), col("b")], + [col("a"), col("b2")], + JoinArgs::new(JoinType::Inner), + ); + let q = df + .filter(col("a").eq(lit("a1"))) + .filter(col("a").eq(col("c"))); + + predicate_at_all_scans(q.clone()); + + let out = q.collect()?; + let expected = df![ + "a" => ["a1"], + "b" => ["b1"], + "c" => ["a1"], + ]?; + assert_eq!(out, expected); + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index fe9035c0544a..00e4ca1a17fd 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -231,6 +231,7 @@ fn test_lazy_query_2() { } #[test] +#[cfg(feature = "csv")] fn test_lazy_query_3() { // query checks if schema of scanning is not changed by aggregation let _ = scan_foods_csv() @@ -299,7 +300,7 @@ fn test_lazy_query_5() { .unwrap() .list() .unwrap() - .get(0) + .get_as_series(0) .unwrap(); assert_eq!(s.len(), 2); let s = out @@ -307,7 +308,7 @@ fn test_lazy_query_5() { .unwrap() .list() .unwrap() - .get(0) + .get_as_series(0) .unwrap(); assert_eq!(s.len(), 2); } @@ -643,6 +644,7 @@ fn test_type_coercion() { } #[test] +#[cfg(feature = "csv")] fn test_lazy_partition_agg() { let df = df! { "foo" => &[1, 1, 2, 2, 3], @@ -670,7 +672,7 @@ fn test_lazy_partition_agg() { .collect() .unwrap(); let cat_agg_list = out.select_at_idx(1).unwrap(); - let fruit_series = cat_agg_list.list().unwrap().get(0).unwrap(); + let fruit_series = cat_agg_list.list().unwrap().get_as_series(0).unwrap(); let fruit_list = fruit_series.i64().unwrap(); assert_eq!( Vec::from(fruit_list), @@ -1132,20 +1134,17 @@ fn test_fill_forward() -> PolarsResult<()> { let out = df .lazy() - .select([col("b").forward_fill(None).over_with_options( - [col("a")], - WindowOptions { - mapping: WindowMapping::Join, - }, - )]) + .select([col("b") + .forward_fill(None) + .over_with_options([col("a")], WindowMapping::Join.into())]) .collect()?; let agg = out.column("b")?.list()?; - let a: Series = agg.get(0).unwrap(); + let a: Series = agg.get_as_series(0).unwrap(); assert!(a.series_equal(&Series::new("b", &[1, 1]))); - let a: Series = agg.get(2).unwrap(); + let a: Series = agg.get_as_series(2).unwrap(); assert!(a.series_equal(&Series::new("b", &[1, 1]))); - let a: Series = agg.get(1).unwrap(); + let a: Series = agg.get_as_series(1).unwrap(); assert_eq!(a.null_count(), 1); Ok(()) } @@ -1297,12 +1296,7 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> { col("B") .shift(1) .filter(col("B").shift(1).gt(lit(4))) - .over_with_options( - [col("fruits")], - WindowOptions { - mapping: WindowMapping::Join, - }, - ) + .over_with_options([col("fruits")], WindowMapping::Join.into()) .alias("filtered"), ]) .collect()?; @@ -1310,7 +1304,7 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> { assert_eq!( out.column("filtered")? .list()? - .get(0) + .get_as_series(0) .unwrap() .i32()? .get(0) @@ -1320,14 +1314,21 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> { assert_eq!( out.column("filtered")? .list()? - .get(1) + .get_as_series(1) .unwrap() .i32()? .get(0) .unwrap(), 5 ); - assert_eq!(out.column("filtered")?.list()?.get(2).unwrap().len(), 0); + assert_eq!( + out.column("filtered")? + .list()? 
+ .get_as_series(2) + .unwrap() + .len(), + 0 + ); Ok(()) } @@ -1355,6 +1356,7 @@ fn test_lazy_ternary_predicate_pushdown() -> PolarsResult<()> { } #[test] +#[cfg(feature = "dtype-categorical")] fn test_categorical_addition() -> PolarsResult<()> { let df = fruits_cars(); @@ -1463,6 +1465,7 @@ fn test_list_in_select_context() -> PolarsResult<()> { } #[test] +#[cfg(feature = "round_series")] fn test_round_after_agg() -> PolarsResult<()> { let df = fruits_cars(); @@ -1564,7 +1567,7 @@ fn test_group_by_rank() -> PolarsResult<()> { .collect()?; let out = out.column("B")?; - let out = out.list()?.get(1).unwrap(); + let out = out.list()?.get_as_series(1).unwrap(); let out = out.idx()?; assert_eq!(Vec::from(out), &[Some(1)]); @@ -1654,12 +1657,7 @@ fn test_single_ranked_group() -> PolarsResult<()> { }, None, ) - .over_with_options( - [col("group")], - WindowOptions { - mapping: WindowMapping::Join, - }, - )]) + .over_with_options([col("group")], WindowMapping::Join.into())]) .collect()?; let out = out.column("value")?.explode()?; @@ -1695,6 +1693,7 @@ fn empty_df() -> PolarsResult<()> { } #[test] +#[cfg(feature = "abs")] fn test_apply_flatten() -> PolarsResult<()> { let df = df![ "A"=> [1.1435, 2.223456, 3.44732, -1.5234, -2.1238, -3.2923], diff --git a/crates/polars-lazy/src/utils.rs b/crates/polars-lazy/src/utils.rs index 0c824509b00a..e8fa1ed4df79 100644 --- a/crates/polars-lazy/src/utils.rs +++ b/crates/polars-lazy/src/utils.rs @@ -11,16 +11,8 @@ pub(crate) fn agg_source_paths( ) { lp_arena.iter(root_lp).for_each(|(_, lp)| { use ALogicalPlan::*; - match lp { - Scan { path, .. } => { - paths.insert(path.clone()); - }, - // always block parallel on anonymous sources - // as we cannot know if they will lock or not. - AnonymousScan { .. } => { - paths.insert("anonymous".into()); - }, - _ => {}, + if let Scan { path, .. 
} = lp { + paths.insert(path.clone()); } }) } diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 7084dea297a9..e0e85a31804f 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -9,25 +9,37 @@ repository = { workspace = true } description = "More operations on Polars data structures" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-core = { version = "0.32.0", path = "../polars-core", features = [], default-features = false } -polars-json = { version = "0.32.0", optional = true, path = "../polars-json", default-features = false } -polars-utils = { version = "0.32.0", path = "../polars-utils", default-features = false } +polars-arrow = { workspace = true, default-features = false } +polars-core = { workspace = true, features = ["algorithm_group_by"], default-features = false } +polars-error = { workspace = true } +polars-json = { workspace = true, optional = true } +polars-utils = { workspace = true, default-features = false } +ahash = { workspace = true } argminmax = { version = "0.6.1", default-features = false, features = ["float"] } arrow = { workspace = true } -base64 = { version = "0.21", optional = true } +base64 = { workspace = true, optional = true } +bytemuck = { workspace = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } either = { workspace = true } -hex = { version = "0.4", optional = true } +hashbrown = { workspace = true } +hex = { workspace = true, optional = true } indexmap = { workspace = true } jsonpath_lib = { version = "0.3", optional = true, git = "https://github.com/ritchie46/jsonpath", branch = "improve_compiled" } memchr = { workspace = true } +num-traits = { workspace = true } +rand = { workspace = true, optional = true, features = ["small_rng", "std"] } +rand_distr = { workspace = true, optional = true } +rayon = { workspace = true } +regex = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } smartstring = { workspace = true } +[dev-dependencies] +rand = { workspace = true } + [build-dependencies] version_check = { workspace = true } @@ -51,13 +63,16 @@ propagate_nans = [] performant = ["polars-core/performant", "fused"] big_idx = ["polars-core/bigidx"] round_series = [] -is_first = [] +is_first_distinct = [] +is_last_distinct = [] is_unique = [] approx_unique = [] fused = [] cutqcut = ["dtype-categorical", "dtype-struct"] rle = ["dtype-struct"] timezones = ["chrono-tz", "chrono"] +random = ["rand", "rand_distr"] +rank = ["rand"] # extra utilities for BinaryChunked binary_encoding = ["base64", "hex"] @@ -75,19 +90,25 @@ string_from_radix = ["polars-core/strings"] extract_jsonpath = ["serde_json", "jsonpath_lib", "polars-json"] log = [] hash = [] +zip_with = ["polars-core/zip_with"] group_by_list = ["polars-core/group_by_list"] rolling_window = ["polars-core/rolling_window"] moment = ["polars-core/moment"] +mode = [] search_sorted = [] merge_sorted = [] top_k = [] pivot = ["polars-core/reinterpret"] -cross_join = ["polars-core/cross_join"] +cross_join = [] chunked_ids = ["polars-core/chunked_ids"] asof_join = ["polars-core/asof_join"] -semi_anti_join = ["polars-core/semi_anti_join"] +semi_anti_join = [] list_take = [] list_sets = [] list_any_all = [] +list_drop_nulls = [] extract_groups = ["dtype-struct", "polars-core/regex"] is_in = ["polars-core/reinterpret"] +convert_index = [] +repeat_by = [] +peaks = 
[] diff --git a/crates/polars-ops/README.md b/crates/polars-ops/README.md index 3f5d3c005fbe..9c575ee43613 100644 --- a/crates/polars-ops/README.md +++ b/crates/polars-ops/README.md @@ -1,5 +1,5 @@ # polars-ops -`polars-ops` is a sub-crate that provides more operations on Polars data structures. +`polars-ops` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, providing extended operations on Polars data structures. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-ops/src/chunked_array/binary/namespace.rs b/crates/polars-ops/src/chunked_array/binary/namespace.rs index 59c444b4ac1d..400f6023a94b 100644 --- a/crates/polars-ops/src/chunked_array/binary/namespace.rs +++ b/crates/polars-ops/src/chunked_array/binary/namespace.rs @@ -6,26 +6,27 @@ use base64::engine::general_purpose; #[cfg(feature = "binary_encoding")] use base64::Engine as _; use memchr::memmem::find; +use polars_core::prelude::arity::binary_elementwise_values; use super::*; pub trait BinaryNameSpaceImpl: AsBinary { /// Check if binary contains given literal - fn contains(&self, lit: &[u8]) -> PolarsResult { + fn contains(&self, lit: &[u8]) -> BooleanChunked { let ca = self.as_binary(); let f = |s: &[u8]| find(s, lit).is_some(); - let mut out: BooleanChunked = if !ca.has_validity() { - ca.into_no_null_iter().map(f).collect() - } else { - ca.into_iter().map(|opt_s| opt_s.map(f)).collect() - }; - out.rename(ca.name()); - Ok(out) + ca.apply_values_generic(f) } - /// Check if strings contain a given literal - fn contains_literal(&self, lit: &[u8]) -> PolarsResult { - self.contains(lit) + fn contains_chunked(&self, lit: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match lit.len() { + 1 => match lit.get(0) { + Some(lit) => ca.contains(lit), + None => BooleanChunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise_values(ca, lit, |src, lit| find(src, lit).is_some()), + } } /// Check if strings ends with a substring @@ -46,6 +47,28 @@ pub trait BinaryNameSpaceImpl: AsBinary { out } + fn starts_with_chunked(&self, prefix: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match prefix.len() { + 1 => match prefix.get(0) { + Some(s) => self.starts_with(s), + None => BooleanChunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise_values(ca, prefix, |s, sub| s.starts_with(sub)), + } + } + + fn ends_with_chunked(&self, suffix: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match suffix.len() { + 1 => match suffix.get(0) { + Some(s) => self.ends_with(s), + None => BooleanChunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise_values(ca, suffix, |s, sub| s.ends_with(sub)), + } + } + #[cfg(feature = "binary_encoding")] fn hex_decode(&self, strict: bool) -> PolarsResult { let ca = self.as_binary(); diff --git a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs index 00e99fa1ef80..cc3713d91562 100644 --- a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs +++ b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs @@ -2,9 +2,9 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, }; use chrono::NaiveDateTime; -use chrono_tz::Tz; +use chrono_tz::{Tz, UTC}; use 
polars_arrow::kernels::convert_to_naive_local; -use polars_core::chunked_array::ops::arity::try_binary_elementwise_values; +use polars_core::chunked_array::ops::arity::try_binary_elementwise; use polars_core::prelude::*; fn parse_time_zone(s: &str) -> PolarsResult { @@ -20,6 +20,17 @@ pub fn replace_time_zone( let from_time_zone = datetime.time_zone().as_deref().unwrap_or("UTC"); let from_tz = parse_time_zone(from_time_zone)?; let to_tz = parse_time_zone(time_zone.unwrap_or("UTC"))?; + if (from_tz == to_tz) + & ((from_tz == UTC) + | ((ambiguous.len() == 1) & (unsafe { ambiguous.get_unchecked(0) } == Some("raise")))) + { + let mut out = datetime + .0 + .clone() + .into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); + out.set_sorted_flag(datetime.is_sorted_flag()); + return Ok(out); + } let timestamp_to_datetime: fn(i64) -> NaiveDateTime = match datetime.time_unit() { TimeUnit::Milliseconds => timestamp_ms_to_datetime, TimeUnit::Microseconds => timestamp_us_to_datetime, @@ -31,7 +42,7 @@ pub fn replace_time_zone( TimeUnit::Nanoseconds => datetime_to_timestamp_ns, }; let out = match ambiguous.len() { - 1 => match ambiguous.get(0) { + 1 => match unsafe { ambiguous.get_unchecked(0) } { Some(ambiguous) => datetime.0.try_apply(|timestamp| { let ndt = timestamp_to_datetime(timestamp); Ok(datetime_to_timestamp(convert_to_naive_local( @@ -40,14 +51,17 @@ pub fn replace_time_zone( }), _ => Ok(datetime.0.apply(|_| None)), }, - _ => { - try_binary_elementwise_values(datetime, ambiguous, |timestamp: i64, ambiguous: &str| { - let ndt = timestamp_to_datetime(timestamp); - Ok::(datetime_to_timestamp(convert_to_naive_local( - &from_tz, &to_tz, ndt, ambiguous, - )?)) - }) - }, + _ => try_binary_elementwise(datetime, ambiguous, |timestamp_opt, ambiguous_opt| { + match (timestamp_opt, ambiguous_opt) { + (Some(timestamp), Some(ambiguous)) => { + let ndt = timestamp_to_datetime(timestamp); + Ok(Some(datetime_to_timestamp(convert_to_naive_local( + &from_tz, &to_tz, ndt, ambiguous, + )?))) + }, + _ => Ok(None), + } + }), }; let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); if from_time_zone == "UTC" && ambiguous.len() == 1 && ambiguous.get(0).unwrap() == "raise" { diff --git a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs new file mode 100644 index 000000000000..57d8c6ff2de8 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs @@ -0,0 +1,227 @@ +use arrow::array::Array; +use arrow::bitmap::bitmask::BitMask; +use arrow::compute::concatenate::concatenate_validities; +use bytemuck::allocation::zeroed_vec; +use polars_core::prelude::gather::check_bounds_ca; +use polars_core::prelude::*; +use polars_utils::index::check_bounds; + +/// # Safety +/// For each index pair, pair.0 < len && pair.1 < ca.null_count() must hold. +unsafe fn gather_skip_nulls_idx_pairs_unchecked<'a, T: PolarsDataType>( + ca: &'a ChunkedArray, + mut index_pairs: Vec<(IdxSize, IdxSize)>, + len: usize, +) -> Vec> { + if index_pairs.is_empty() { + return zeroed_vec(len); + } + + // We sort by gather index so we can do the null scan in one pass. 
+ index_pairs.sort_unstable_by_key(|t| t.1); + let mut pair_iter = index_pairs.iter().copied(); + let (mut out_idx, mut nonnull_idx); + (out_idx, nonnull_idx) = pair_iter.next().unwrap(); + + let mut out: Vec> = zeroed_vec(len); + let mut nonnull_prev_arrays = 0; + 'outer: for arr in ca.downcast_iter() { + let arr_nonnull_len = arr.len() - arr.null_count(); + let mut arr_scan_offset = 0; + let mut nonnull_before_offset = 0; + let mask = arr.validity().map(BitMask::from_bitmap).unwrap_or_default(); + + // Is our next nonnull_idx in this array? + while nonnull_idx as usize - nonnull_prev_arrays < arr_nonnull_len { + let nonnull_idx_in_arr = nonnull_idx as usize - nonnull_prev_arrays; + + let phys_idx_in_arr = if arr.null_count() == 0 { + // Happy fast path for full non-null array. + nonnull_idx_in_arr + } else { + mask.nth_set_bit_idx(nonnull_idx_in_arr - nonnull_before_offset, arr_scan_offset) + .unwrap() + }; + + unsafe { + let val = arr.value_unchecked(phys_idx_in_arr); + *out.get_unchecked_mut(out_idx as usize) = val.into(); + } + + arr_scan_offset = phys_idx_in_arr; + nonnull_before_offset = nonnull_idx_in_arr; + + let Some(next_pair) = pair_iter.next() else { + break 'outer; + }; + (out_idx, nonnull_idx) = next_pair; + } + + nonnull_prev_arrays += arr_nonnull_len; + } + + out +} + +pub trait ChunkGatherSkipNulls: Sized { + fn gather_skip_nulls(&self, indices: &I) -> PolarsResult; +} + +impl ChunkGatherSkipNulls<[IdxSize]> for ChunkedArray +where + ChunkedArray: ChunkFilter, +{ + fn gather_skip_nulls(&self, indices: &[IdxSize]) -> PolarsResult { + if self.null_count() == 0 { + return self.take(indices); + } + + // If we want many indices it's probably better to do a normal gather on + // a dense array. + if indices.len() >= self.len() / 4 { + return ChunkFilter::filter(self, &self.is_not_null()) + .unwrap() + .take(indices); + } + + let bound = self.len() - self.null_count(); + check_bounds(indices, bound as IdxSize)?; + + let index_pairs: Vec<_> = indices + .iter() + .enumerate() + .map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx)) + .collect(); + let gathered = + unsafe { gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.len()) }; + let arr = T::Array::from_zeroable_vec(gathered, self.dtype().clone()); + Ok(ChunkedArray::from_chunk_iter_like(self, [arr])) + } +} + +impl ChunkGatherSkipNulls for ChunkedArray +where + ChunkedArray: ChunkFilter, +{ + fn gather_skip_nulls(&self, indices: &IdxCa) -> PolarsResult { + if self.null_count() == 0 { + return self.take(indices); + } + + // If we want many indices it's probably better to do a normal gather on + // a dense array. + if indices.len() >= self.len() / 4 { + return ChunkFilter::filter(self, &self.is_not_null()) + .unwrap() + .take(indices); + } + + let bound = self.len() - self.null_count(); + check_bounds_ca(indices, bound as IdxSize)?; + + let index_pairs: Vec<_> = if indices.null_count() == 0 { + indices + .downcast_iter() + .flat_map(|arr| arr.values_iter()) + .enumerate() + .map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx)) + .collect() + } else { + // Filter *after* the enumerate so we place the non-null gather + // requests at the right places. 
+ indices + .downcast_iter() + .flat_map(|arr| arr.iter()) + .enumerate() + .filter_map(|(out_idx, nonnull_idx)| Some((out_idx as IdxSize, *nonnull_idx?))) + .collect() + }; + let gathered = unsafe { + gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.as_ref().len()) + }; + + let mut arr = T::Array::from_zeroable_vec(gathered, self.dtype().clone()); + if indices.null_count() > 0 { + let array_refs: Vec<&dyn Array> = indices.chunks().iter().map(|x| &**x).collect(); + arr = arr.with_validity_typed(concatenate_validities(&array_refs)); + } + Ok(ChunkedArray::from_chunk_iter_like(self, [arr])) + } +} + +#[cfg(test)] +mod test { + use std::ops::Range; + + use rand::distributions::uniform::SampleUniform; + use rand::prelude::*; + use rand::rngs::SmallRng; + + use super::*; + + fn random_vec( + rng: &mut R, + val: Range, + len_range: Range, + ) -> Vec { + let n = rng.gen_range(len_range); + (0..n).map(|_| rng.gen_range(val.clone())).collect() + } + + fn random_filter(rng: &mut R, v: &[T], pr: Range) -> Vec> { + let p = rng.gen_range(pr); + let rand_filter = |x| Some(x).filter(|_| rng.gen::() < p); + v.iter().cloned().map(rand_filter).collect() + } + + fn ref_gather_nulls(v: Vec>, idx: Vec>) -> Option>> { + let v: Vec = v.into_iter().flatten().collect(); + if idx.iter().any(|oi| oi.map(|i| i >= v.len()) == Some(true)) { + return None; + } + Some(idx.into_iter().map(|i| Some(v[i?])).collect()) + } + + fn test_equal_ref(ca: &UInt32Chunked, idx_ca: &IdxCa) { + let ref_ca: Vec> = ca.into_iter().collect(); + let ref_idx_ca: Vec> = + (&idx_ca).into_iter().map(|i| Some(i? as usize)).collect(); + let gather = ca.gather_skip_nulls(idx_ca).ok(); + let ref_gather = ref_gather_nulls(ref_ca, ref_idx_ca); + assert_eq!(gather.map(|ca| ca.into_iter().collect()), ref_gather); + } + + fn gather_skip_nulls_check(ca: &UInt32Chunked, idx_ca: &IdxCa) { + test_equal_ref(ca, idx_ca); + test_equal_ref(&ca.rechunk(), idx_ca); + test_equal_ref(ca, &idx_ca.rechunk()); + test_equal_ref(&ca.rechunk(), &idx_ca.rechunk()); + } + + #[rustfmt::skip] + #[test] + fn test_gather_skip_nulls() { + let mut rng = SmallRng::seed_from_u64(0xdeadbeef); + + for _test in 0..20 { + let num_elem_chunks = rng.gen_range(1..10); + let elem_chunks: Vec<_> = (0..num_elem_chunks).map(|_| random_vec(&mut rng, 0..u32::MAX, 0..100)).collect(); + let null_elem_chunks: Vec<_> = elem_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect(); + let num_nonnull_elems: usize = null_elem_chunks.iter().map(|c| c.iter().filter(|x| x.is_some()).count()).sum(); + + let num_idx_chunks = rng.gen_range(1..10); + let idx_chunks: Vec<_> = (0..num_idx_chunks).map(|_| random_vec(&mut rng, 0..num_nonnull_elems as IdxSize, 0..200)).collect(); + let null_idx_chunks: Vec<_> = idx_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect(); + + let nonnull_ca = UInt32Chunked::from_chunk_iter("", elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let ca = UInt32Chunked::from_chunk_iter("", null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let nonnull_idx_ca = IdxCa::from_chunk_iter("", idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let idx_ca = IdxCa::from_chunk_iter("", null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + + gather_skip_nulls_check(&ca, &idx_ca); + gather_skip_nulls_check(&ca, &nonnull_idx_ca); + gather_skip_nulls_check(&nonnull_ca, &idx_ca); + gather_skip_nulls_check(&nonnull_ca, &nonnull_idx_ca); + } + } +} diff --git 
a/crates/polars-ops/src/chunked_array/interpolate.rs b/crates/polars-ops/src/chunked_array/interpolate.rs index 1d824200f824..b06c06574960 100644 --- a/crates/polars-ops/src/chunked_array/interpolate.rs +++ b/crates/polars-ops/src/chunked_array/interpolate.rs @@ -60,29 +60,6 @@ where } } -#[inline] -fn unsigned_interp(low: T, high: T, steps: IdxSize, steps_n: T, av: &mut Vec) -where - T: Sub - + Mul - + Add - + Div - + NumCast - + PartialOrd - + Copy, -{ - if high >= low { - signed_interp::(low, high, steps, steps_n, av) - } else { - let diff = low - high; - for step_i in (1..steps).rev() { - let step_i: T = NumCast::from(step_i).unwrap(); - let v = linear_itp(high, step_i, diff, steps_n); - av.push(v) - } - } -} - fn interpolate_impl(chunked_arr: &ChunkedArray, interpolation_branch: I) -> ChunkedArray where T: PolarsNumericType, @@ -196,34 +173,46 @@ fn interpolate_linear(s: &Series) -> Series { let logical = s.dtype(); let s = s.to_physical_repr(); - let out = match s.dtype() { - #[cfg(feature = "dtype-i8")] - DataType::Int8 => linear_interp_signed(s.i8().unwrap()), - #[cfg(feature = "dtype-i16")] - DataType::Int16 => linear_interp_signed(s.i16().unwrap()), - DataType::Int32 => linear_interp_signed(s.i32().unwrap()), - DataType::Int64 => linear_interp_signed(s.i64().unwrap()), - #[cfg(feature = "dtype-u8")] - DataType::UInt8 => linear_interp_unsigned(s.u8().unwrap()), - #[cfg(feature = "dtype-u16")] - DataType::UInt16 => linear_interp_unsigned(s.u16().unwrap()), - DataType::UInt32 => linear_interp_unsigned(s.u32().unwrap()), - DataType::UInt64 => linear_interp_unsigned(s.u64().unwrap()), - DataType::Float32 => linear_interp_unsigned(s.f32().unwrap()), - DataType::Float64 => linear_interp_unsigned(s.f64().unwrap()), - _ => s.as_ref().clone(), + + let out = if matches!( + logical, + DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time + ) { + match s.dtype() { + // Datetime, Time, or Duration + DataType::Int64 => linear_interp_signed(s.i64().unwrap()), + // Date + DataType::Int32 => linear_interp_signed(s.i32().unwrap()), + _ => unreachable!(), + } + } else { + match s.dtype() { + DataType::Float32 => linear_interp_signed(s.f32().unwrap()), + DataType::Float64 => linear_interp_signed(s.f64().unwrap()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap()) + }, + _ => s.as_ref().clone(), + } }; - out.cast(logical).unwrap() + match logical { + DataType::Date + | DataType::Datetime(_, _) + | DataType::Duration(_) + | DataType::Time => out.cast(logical).unwrap(), + _ => out, + } }, } } -fn linear_interp_unsigned(ca: &ChunkedArray) -> Series -where - ChunkedArray: IntoSeries, -{ - interpolate_impl(ca, unsigned_interp::).into_series() -} fn linear_interp_signed(ca: &ChunkedArray) -> Series where ChunkedArray: IntoSeries, diff --git a/crates/polars-ops/src/chunked_array/list/count.rs b/crates/polars-ops/src/chunked_array/list/count.rs index 546bebba969b..f066552eb934 100644 --- a/crates/polars-ops/src/chunked_array/list/count.rs +++ b/crates/polars-ops/src/chunked_array/list/count.rs @@ -25,7 +25,7 @@ fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> Vec { } #[cfg(feature = "list_count")] -pub fn list_count_match(ca: &ListChunked, value: AnyValue) -> PolarsResult { +pub fn list_count_matches(ca: &ListChunked, value: AnyValue) -> PolarsResult { let value = 
Series::new("", [value]); let ca = ca.apply_to_inner(&|s| { diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index 7ffca3ffcf68..1852d8ca3a5c 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -78,24 +78,18 @@ pub(super) fn list_min_function(ca: &ListChunked) -> Series { match ca.inner_dtype() { DataType::Boolean => { let out: BooleanChunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().bool().unwrap().min())) - .collect_trusted(); + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().bool().unwrap().min())); out.into_series() }, dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(dt, |$T| { - let out: ChunkedArray<$T> = ca - .amortized_iter() - .map(|opt_s| - { + + let out: ChunkedArray<$T> = ca.apply_amortized_generic(|opt_s| { let s = opt_s?; let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); ca.min() - } - ) - .collect_trusted(); - out.into_series() + }); + out.into_series() }) }, _ => ca @@ -184,24 +178,19 @@ pub(super) fn list_max_function(ca: &ListChunked) -> Series { match ca.inner_dtype() { DataType::Boolean => { let out: BooleanChunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().bool().unwrap().max())) - .collect_trusted(); + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().bool().unwrap().max())); out.into_series() }, dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(dt, |$T| { - let out: ChunkedArray<$T> = ca - .amortized_iter() - .map(|opt_s| - { + + let out: ChunkedArray<$T> = ca.apply_amortized_generic(|opt_s| { let s = opt_s?; let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); ca.max() - } - ) - .collect_trusted(); - out.into_series() + }); + out.into_series() + }) }, _ => ca diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index ddf44c70f96a..ad6d1ce34994 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -10,7 +10,7 @@ use polars_core::export::num::ToPrimitive; use polars_core::export::num::{NumCast, Signed, Zero}; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; -use polars_core::utils::{try_get_supertype, CustomIterTools}; +use polars_core::utils::try_get_supertype; use super::*; #[cfg(feature = "list_any_all")] @@ -76,42 +76,84 @@ fn cast_rhs( pub trait ListNameSpaceImpl: AsList { /// In case the inner dtype [`DataType::Utf8`], the individual items will be joined into a /// single string separated by `separator`. 
- fn lst_join(&self, separator: &str) -> PolarsResult { + fn lst_join(&self, separator: &Utf8Chunked) -> PolarsResult { let ca = self.as_list(); match ca.inner_dtype() { - DataType::Utf8 => { - // used to amortize heap allocs - let mut buf = String::with_capacity(128); - - let mut builder = Utf8ChunkedBuilder::new( - ca.name(), - ca.len(), - ca.get_values_size() + separator.len() * ca.len(), - ); - - ca.amortized_iter().for_each(|opt_s| { - let opt_val = opt_s.map(|s| { - // make sure that we don't write values of previous iteration - buf.clear(); - let ca = s.as_ref().utf8().unwrap(); - let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); - - for val in iter { - buf.write_str(val).unwrap(); - buf.write_str(separator).unwrap(); - } - // last value should not have a separator, so slice that off - // saturating sub because there might have been nothing written. - &buf[..buf.len().saturating_sub(separator.len())] - }); - builder.append_option(opt_val) - }); - Ok(builder.finish()) + DataType::Utf8 => match separator.len() { + 1 => match separator.get(0) { + Some(separator) => self.join_literal(separator), + _ => Ok(Utf8Chunked::full_null(ca.name(), ca.len())), + }, + _ => self.join_many(separator), }, dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"), } } + fn join_literal(&self, separator: &str) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = Utf8ChunkedBuilder::new( + ca.name(), + ca.len(), + ca.get_values_size() + separator.len() * ca.len(), + ); + + ca.for_each_amortized(|opt_s| { + let opt_val = opt_s.map(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().utf8().unwrap(); + let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + for val in iter { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + &buf[..buf.len().saturating_sub(separator.len())] + }); + builder.append_option(opt_val) + }); + Ok(builder.finish()) + } + + fn join_many(&self, separator: &Utf8Chunked) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = + Utf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size() + ca.len()); + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + ca.amortized_iter() + .zip(separator) + .for_each(|(opt_s, opt_sep)| match opt_sep { + Some(separator) => { + let opt_val = opt_s.map(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().utf8().unwrap(); + let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + for val in iter { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. 
+ &buf[..buf.len().saturating_sub(separator.len())] + }); + builder.append_option(opt_val) + }, + _ => builder.append_null(), + }) + } + Ok(builder.finish()) + } + fn lst_max(&self) -> Series { list_max_function(self.as_list()) } @@ -197,22 +239,18 @@ pub trait ListNameSpaceImpl: AsList { fn lst_arg_min(&self) -> IdxCa { let ca = self.as_list(); - let mut out: IdxCa = ca - .amortized_iter() - .map(|opt_s| opt_s.and_then(|s| s.as_ref().arg_min().map(|idx| idx as IdxSize))) - .collect_trusted(); - out.rename(ca.name()); - out + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_min().map(|idx| idx as IdxSize)) + }) + .with_name(ca.name()) } fn lst_arg_max(&self) -> IdxCa { let ca = self.as_list(); - let mut out: IdxCa = ca - .amortized_iter() - .map(|opt_s| opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize))) - .collect_trusted(); - out.rename(ca.name()); - out + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize)) + }) + .with_name(ca.name()) } #[cfg(feature = "diff")] @@ -221,10 +259,26 @@ pub trait ListNameSpaceImpl: AsList { ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior)) } - fn lst_shift(&self, periods: i64) -> ListChunked { + fn lst_shift(&self, periods: &Series) -> PolarsResult { let ca = self.as_list(); - let out = ca.apply_amortized(|s| s.as_ref().shift(periods)); - self.same_type(out) + let periods_s = periods.cast(&DataType::Int64)?; + let periods = periods_s.i64()?; + let out = match periods.len() { + 1 => { + if let Some(periods) = periods.get(0) { + ca.apply_amortized(|s| s.as_ref().shift(periods)) + } else { + ListChunked::full_null_with_dtype(ca.name(), ca.len(), &ca.inner_dtype()) + } + }, + _ => ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(periods)) => Some(s.as_ref().shift(periods)), + _ => None, + } + }), + }; + Ok(self.same_type(out)) } fn lst_slice(&self, offset: i64, length: usize) -> ListChunked { @@ -268,41 +322,47 @@ pub trait ListNameSpaceImpl: AsList { let index_typed_index = |idx: &Series| { let idx = idx.cast(&IDX_DTYPE).unwrap(); - list_ca - .amortized_iter() - .map(|s| { - s.map(|s| { - let s = s.as_ref(); - take_series(s, idx.clone(), null_on_oob) + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + list_ca + .amortized_iter() + .map(|s| { + s.map(|s| { + let s = s.as_ref(); + take_series(s, idx.clone(), null_on_oob) + }) + .transpose() }) - .transpose() - }) - .collect::>() - .map(|mut ca| { - ca.rename(list_ca.name()); - ca.into_series() - }) + .collect::>() + .map(|mut ca| { + ca.rename(list_ca.name()); + ca.into_series() + }) + } }; use DataType::*; match idx.dtype() { List(_) => { let idx_ca = idx.list().unwrap(); - let mut out = list_ca - .amortized_iter() - .zip(idx_ca) - .map(|(opt_s, opt_idx)| { - { - match (opt_s, opt_idx) { - (Some(s), Some(idx)) => { - Some(take_series(s.as_ref(), idx, null_on_oob)) - }, - _ => None, + // SAFETY: unstable series never lives longer than the iterator. + let mut out = unsafe { + list_ca + .amortized_iter() + .zip(idx_ca) + .map(|(opt_s, opt_idx)| { + { + match (opt_s, opt_idx) { + (Some(s), Some(idx)) => { + Some(take_series(s.as_ref(), idx, null_on_oob)) + }, + _ => None, + } } - } - .transpose() - }) - .collect::>()?; + .transpose() + }) + .collect::>()? 
+ }; out.rename(list_ca.name()); Ok(out.into_series()) @@ -313,14 +373,17 @@ pub trait ListNameSpaceImpl: AsList { if min >= 0 { index_typed_index(idx) } else { - let mut out = list_ca - .amortized_iter() - .map(|opt_s| { - opt_s - .map(|s| take_series(s.as_ref(), idx.clone(), null_on_oob)) - .transpose() - }) - .collect::>()?; + // SAFETY: unstable series never lives longer than the iterator. + let mut out = unsafe { + list_ca + .amortized_iter() + .map(|opt_s| { + opt_s + .map(|s| take_series(s.as_ref(), idx.clone(), null_on_oob)) + .transpose() + }) + .collect::>()? + }; out.rename(list_ca.name()); Ok(out.into_series()) } @@ -332,6 +395,13 @@ pub trait ListNameSpaceImpl: AsList { } } + #[cfg(feature = "list_drop_nulls")] + fn lst_drop_nulls(&self) -> ListChunked { + let list_ca = self.as_list(); + + list_ca.apply_amortized(|s| s.as_ref().drop_nulls()) + } + fn lst_concat(&self, other: &[Series]) -> PolarsResult { let ca = self.as_list(); let other_len = other.len(); @@ -371,7 +441,7 @@ pub trait ListNameSpaceImpl: AsList { .iter() .flat_map(|s| { let lst = s.list().unwrap(); - lst.get(0) + lst.get_as_series(0) }) .collect::>(); // there was a None, so all values will be None @@ -422,7 +492,8 @@ pub trait ListNameSpaceImpl: AsList { let mut iters = Vec::with_capacity(other_len + 1); for s in other.iter_mut() { - iters.push(s.list()?.amortized_iter()) + // SAFETY: unstable series never lives longer than the iterator. + iters.push(unsafe { s.list()?.amortized_iter() }) } let mut first_iter = ca.into_iter(); let mut builder = get_list_builder( diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index 299020e4dbea..0d6d7bcd8dba 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -75,52 +75,38 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Series // TODO: add fast path for smaller ints? 
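The sum_mean.rs hunk that follows makes the same `apply_amortized_generic` substitution for the per-row sum and mean. As a reference for the row-wise semantics, a plain-Rust sketch of the mean case (illustrative only; the real kernel dispatches on the inner dtype and reuses an amortized Series):

fn list_mean(rows: &[Option<Vec<f64>>]) -> Vec<Option<f64>> {
    rows.iter()
        .map(|row| {
            row.as_ref().and_then(|vals| {
                // A null row stays null; an empty list has no mean.
                if vals.is_empty() {
                    None
                } else {
                    Some(vals.iter().sum::<f64>() / vals.len() as f64)
                }
            })
        })
        .collect()
}

fn main() {
    let rows = vec![Some(vec![1.0, 2.0, 3.0]), None, Some(vec![])];
    assert_eq!(list_mean(&rows), vec![Some(2.0), None, None]);
}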
let mut out = match inner_dtype { Boolean => { - let out: IdxCa = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: IdxCa = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, UInt32 => { - let out: UInt32Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: UInt32Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, UInt64 => { - let out: UInt64Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: UInt64Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, Int32 => { - let out: Int32Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: Int32Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, Int64 => { - let out: Int64Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, Float32 => { - let out: Float32Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: Float32Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, Float64 => { - let out: Float64Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().sum())) - .collect(); + let out: Float64Chunked = + ca.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().sum::())); out.into_series() }, // slowest sum_as_series path @@ -198,21 +184,15 @@ pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Se pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series { return match ca.inner_dtype() { DataType::Float32 => { - let mut out: Float32Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32))) - .collect(); - - out.rename(ca.name()); + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32))) + .with_name(ca.name()); out.into_series() }, _ => { - let mut out: Float64Chunked = ca - .amortized_iter() - .map(|s| s.and_then(|s| s.as_ref().mean())) - .collect(); - - out.rename(ca.name()); + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean())) + .with_name(ca.name()); out.into_series() }, }; diff --git a/crates/polars-ops/src/chunked_array/mod.rs b/crates/polars-ops/src/chunked_array/mod.rs index 62052893ac9f..f085b960800d 100644 --- a/crates/polars-ops/src/chunked_array/mod.rs +++ b/crates/polars-ops/src/chunked_array/mod.rs @@ -8,12 +8,21 @@ mod interpolate; pub mod list; #[cfg(feature = "propagate_nans")] pub mod nan_propagating_aggregate; +#[cfg(feature = "peaks")] +pub mod peaks; mod set; mod strings; mod sum; #[cfg(feature = "top_k")] mod top_k; +#[cfg(feature = "mode")] +pub mod mode; + +pub mod gather_skip_nulls; +#[cfg(feature = "repeat_by")] +mod repeat_by; + pub use binary::*; #[cfg(feature = "timezones")] pub use datetime::*; @@ -22,6 +31,8 @@ pub use interpolate::*; pub use list::*; #[allow(unused_imports)] use polars_core::prelude::*; +#[cfg(feature = "repeat_by")] +pub use repeat_by::*; pub use set::ChunkedSet; pub use strings::*; #[cfg(feature = "top_k")] diff --git a/crates/polars-ops/src/chunked_array/mode.rs 
b/crates/polars-ops/src/chunked_array/mode.rs new file mode 100644 index 000000000000..2402f3276a8a --- /dev/null +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -0,0 +1,127 @@ +use polars_arrow::utils::CustomIterTools; +use polars_core::frame::group_by::IntoGroupsProxy; +use polars_core::prelude::*; +use polars_core::with_match_physical_integer_polars_type; + +fn mode_primitive(ca: &ChunkedArray) -> PolarsResult> +where + ChunkedArray: IntoGroupsProxy + ChunkTake<[IdxSize]>, +{ + if ca.is_empty() { + return Ok(ca.clone()); + } + let groups = ca.group_tuples(true, false).unwrap(); + let idx = mode_indices(groups); + + // Safety: + // group indices are in bounds + Ok(unsafe { ca.take_unchecked(idx.as_slice()) }) +} + +fn mode_f32(ca: &Float32Chunked) -> PolarsResult { + let s = ca.apply_as_ints(|v| mode(v).unwrap()); + let ca = s.f32().unwrap().clone(); + Ok(ca) +} + +fn mode_64(ca: &Float64Chunked) -> PolarsResult { + let s = ca.apply_as_ints(|v| mode(v).unwrap()); + let ca = s.f64().unwrap().clone(); + Ok(ca) +} + +fn mode_indices(groups: GroupsProxy) -> Vec { + match groups { + GroupsProxy::Idx(groups) => { + let mut groups = groups.into_iter().collect_trusted::>(); + groups.sort_unstable_by_key(|k| k.1.len()); + let last = &groups.last().unwrap(); + let max_occur = last.1.len(); + groups + .iter() + .rev() + .take_while(|v| v.1.len() == max_occur) + .map(|v| v.0) + .collect() + }, + GroupsProxy::Slice { groups, .. } => { + let last = groups.last().unwrap(); + let max_occur = last[1]; + + groups + .iter() + .rev() + .take_while(|v| { + let len = v[1]; + len == max_occur + }) + .map(|v| v[0]) + .collect() + }, + } +} + +pub fn mode(s: &Series) -> PolarsResult { + let s_phys = s.to_physical_repr(); + let out = match s_phys.dtype() { + DataType::Binary => mode_primitive(s_phys.binary().unwrap())?.into_series(), + DataType::Boolean => mode_primitive(s_phys.bool().unwrap())?.into_series(), + DataType::Float32 => mode_f32(s_phys.f32().unwrap())?.into_series(), + DataType::Float64 => mode_64(s_phys.f64().unwrap())?.into_series(), + DataType::Utf8 => mode_primitive(&s_phys.utf8().unwrap().as_binary())?.into_series(), + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref(); + mode_primitive(ca)?.into_series() + }) + }, + _ => polars_bail!(opq = mode, s.dtype()), + }; + // # Safety: Casting back into the original from physical representation + unsafe { out.cast_unchecked(s.dtype()) } +} + +#[cfg(test)] +mod test { + use polars_core::prelude::*; + + use super::{mode, mode_primitive}; + + #[test] + fn mode_test() { + let ca = Int32Chunked::from_slice("test", &[0, 1, 2, 3, 4, 4, 5, 6, 5, 0]); + let mut result = mode_primitive(&ca).unwrap().to_vec(); + result.sort_by_key(|a| a.unwrap()); + assert_eq!(&result, &[Some(0), Some(4), Some(5)]); + + let ca = Int32Chunked::from_slice("test", &[1, 1]); + let mut result = mode_primitive(&ca).unwrap().to_vec(); + result.sort_by_key(|a| a.unwrap()); + assert_eq!(&result, &[Some(1)]); + + let ca = Int32Chunked::from_slice("test", &[]); + let mut result = mode_primitive(&ca).unwrap().to_vec(); + result.sort_by_key(|a| a.unwrap()); + assert_eq!(result, &[]); + + let ca = Float32Chunked::from_slice("test", &[1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0]); + let result = mode_primitive(&ca).unwrap().to_vec(); + assert_eq!(result, &[Some(3.0f32)]); + + let ca = Utf8Chunked::from_slice("test", &["test", "test", "test", "another test"]); + let result = mode_primitive(&ca).unwrap(); + let 
vec_result4: Vec> = result.into_iter().collect(); + assert_eq!(vec_result4, &[Some("test")]); + + let mut ca_builder = CategoricalChunkedBuilder::new("test", 5); + ca_builder.append_value("test"); + ca_builder.append_value("test"); + ca_builder.append_value("test2"); + ca_builder.append_value("test2"); + ca_builder.append_value("test2"); + let s = ca_builder.finish().into_series(); + let result = mode(&s).unwrap(); + assert_eq!(result.str_value(0).unwrap(), "test2"); + assert_eq!(result.len(), 1); + } +} diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs index ebf7f0d1545d..8d89d95b2c1b 100644 --- a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs +++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs @@ -1,13 +1,9 @@ -use std::cmp::Ordering; - use polars_arrow::export::arrow::array::Array; use polars_arrow::kernels::rolling; use polars_arrow::kernels::rolling::no_nulls::{MaxWindow, MinWindow}; -use polars_arrow::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min}; use polars_arrow::kernels::take_agg::{ take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked, }; -use polars_arrow::utils::CustomIterTools; use polars_core::export::num::Bounded; use polars_core::frame::group_by::aggregations::{ _agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls, @@ -15,20 +11,26 @@ use polars_core::frame::group_by::aggregations::{ }; use polars_core::prelude::*; -#[inline] -fn nan_min(a: T, b: T) -> T { - if let Ordering::Less = compare_fn_nan_min(&a, &b) { +#[inline(always)] +fn nan_min(a: T, b: T) -> T { + // If b is nan, min is nan, because the comparison failed. We have + // to poison the result if a is nan. + let min = if a < b { a } else { b }; + if a.is_nan() { a } else { - b + min } } -#[inline] -fn nan_max(a: T, b: T) -> T { - if let Ordering::Greater = compare_fn_nan_max(&a, &b) { + +#[inline(always)] +fn nan_max(a: T, b: T) -> T { + // See nan_min. + let max = if a > b { a } else { b }; + if a.is_nan() { a } else { - b + max } } @@ -40,18 +42,12 @@ where let mut cum_agg = None; ca.downcast_iter().for_each(|arr| { let agg = if arr.null_count() == 0 { - arr.values().iter().copied().fold_first_(min_or_max_fn) + arr.values().iter().copied().reduce(min_or_max_fn) } else { arr.iter() .unwrap_optional() - .map(|opt| opt.copied()) - .fold_first_(|a, b| match (a, b) { - (Some(a), Some(b)) => Some(min_or_max_fn(a, b)), - (None, Some(b)) => Some(b), - (Some(a), None) => Some(a), - (None, None) => None, - }) - .flatten() + .filter_map(|opt| opt.copied()) + .reduce(min_or_max_fn) }; match cum_agg { None => cum_agg = agg, @@ -118,7 +114,7 @@ where idx.len() as IdxSize, ), _ => { - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; ca_nan_agg(&take, nan_max) }, } @@ -190,7 +186,7 @@ where idx.len() as IdxSize, ), _ => { - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; ca_nan_agg(&take, nan_min) }, } @@ -235,7 +231,7 @@ where } /// # Safety -/// `groups` must be in bounds +/// `groups` must be in bounds. pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsProxy) -> Series { match s.dtype() { DataType::Float32 => { @@ -251,7 +247,7 @@ pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsProxy) -> Series { } /// # Safety -/// `groups` must be in bounds +/// `groups` must be in bounds. 
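The rewritten `nan_min`/`nan_max` above drop the `compare_fn_nan_*` helpers in favour of explicit NaN poisoning: if either operand is NaN the result is NaN, otherwise it is the ordinary min/max. A standalone check of those semantics:

fn nan_min(a: f64, b: f64) -> f64 {
    // If b is NaN the comparison fails and `min` is b (NaN); if a is NaN we return a instead.
    let min = if a < b { a } else { b };
    if a.is_nan() {
        a
    } else {
        min
    }
}

fn main() {
    assert_eq!(nan_min(1.0, 2.0), 1.0);
    assert!(nan_min(f64::NAN, 2.0).is_nan());
    assert!(nan_min(2.0, f64::NAN).is_nan());
}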
pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsProxy) -> Series { match s.dtype() { DataType::Float32 => { diff --git a/crates/polars-ops/src/chunked_array/peaks.rs b/crates/polars-ops/src/chunked_array/peaks.rs new file mode 100644 index 000000000000..60d4e0be8ae0 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/peaks.rs @@ -0,0 +1,16 @@ +use num_traits::Zero; +use polars_core::prelude::*; + +/// Get a boolean mask of the local maximum peaks. +pub fn peak_max(ca: &ChunkedArray) -> BooleanChunked { + let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); + let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); + ChunkedArray::lt(&shift_left, ca) & ChunkedArray::lt(&shift_right, ca) +} + +/// Get a boolean mask of the local minimum peaks. +pub fn peak_min(ca: &ChunkedArray) -> BooleanChunked { + let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); + let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); + ChunkedArray::gt(&shift_left, ca) & ChunkedArray::gt(&shift_right, ca) +} diff --git a/crates/polars-ops/src/chunked_array/repeat_by.rs b/crates/polars-ops/src/chunked_array/repeat_by.rs new file mode 100644 index 000000000000..e20b40896df7 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/repeat_by.rs @@ -0,0 +1,134 @@ +use arrow::array::ListArray; +use polars_arrow::array::ListFromIter; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +type LargeListArray = ListArray; + +fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> { + polars_ensure!( + (length_srs == length_by) | (length_by == 1) | (length_srs == 1), + ComputeError: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}", + length_srs, length_by + ); + Ok(()) +} + +fn new_by(by: &IdxCa, len: usize) -> IdxCa { + IdxCa::new( + "", + std::iter::repeat(by.get(0).unwrap()) + .take(len) + .collect::>(), + ) +} + +fn repeat_by_primitive(ca: &ChunkedArray, by: &IdxCa) -> PolarsResult +where + T: PolarsNumericType, +{ + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { + LargeListArray::from_iter_primitive_trusted_len(iter, T::get_dtype().to_arrow()) + } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_primitive(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_primitive(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_bool(ca: &BooleanChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. 
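The `repeat_by_*` kernels in this new file all implement the same row-wise rule before handing off to the arrow list builders: repeat each value `by` times, propagate a null count as a null row, and broadcast whichever side has length 1. A simplified equal-length sketch of that rule (illustrative; the broadcasting and builder parts are omitted):

fn repeat_by(values: &[Option<i32>], by: &[Option<usize>]) -> Vec<Option<Vec<Option<i32>>>> {
    values
        .iter()
        .zip(by)
        // A null repeat count yields a null row; otherwise the value (null or not) is repeated.
        .map(|(v, n)| n.map(|n| vec![*v; n]))
        .collect()
}

fn main() {
    let out = repeat_by(&[Some(1), Some(2)], &[Some(3), Some(0)]);
    assert_eq!(out[0], Some(vec![Some(1), Some(1), Some(1)]));
    assert_eq!(out[1], Some(vec![]));
}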
+ unsafe { LargeListArray::from_iter_bool_trusted_len(iter) } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_bool(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_bool(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_binary(ca: &BinaryChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_binary_trusted_len(iter, ca.len()) } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_binary(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_binary(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +pub fn repeat_by(s: &Series, by: &IdxCa) -> PolarsResult { + let s_phys = s.to_physical_repr(); + use DataType::*; + let out = match s_phys.dtype() { + Boolean => repeat_by_bool(s_phys.bool().unwrap(), by), + Utf8 => { + let ca = s_phys.utf8().unwrap(); + repeat_by_binary(&ca.as_binary(), by) + }, + Binary => repeat_by_binary(s_phys.binary().unwrap(), by), + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref(); + repeat_by_primitive(ca, by) + }) + }, + _ => polars_bail!(opq = repeat_by, s.dtype()), + }; + out.and_then(|ca| { + let logical_type = s.dtype(); + ca.apply_to_inner(&|s| unsafe { s.cast_unchecked(logical_type) }) + }) +} diff --git a/crates/polars-ops/src/chunked_array/set.rs b/crates/polars-ops/src/chunked_array/set.rs index 05848a6a7a78..705ff0d7039c 100644 --- a/crates/polars-ops/src/chunked_array/set.rs +++ b/crates/polars-ops/src/chunked_array/set.rs @@ -4,6 +4,7 @@ use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::arrow::bitmap::MutableBitmap; use polars_core::utils::arrow::types::NativeType; +use polars_utils::index::check_bounds; pub trait ChunkedSet { fn set_at_idx2(self, idx: &[IdxSize], values: V) -> PolarsResult @@ -27,19 +28,6 @@ fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> { Ok(()) } -fn check_bounds(idx: &[IdxSize], len: IdxSize) -> PolarsResult<()> { - let mut inbounds = true; - - for &i in idx { - if i >= len { - // we will not break here as that prevents SIMD - inbounds = false; - } - } - polars_ensure!(inbounds, ComputeError: "set indices are out of bounds"); - Ok(()) -} - trait PolarsOpsNumericType: PolarsNumericType {} impl PolarsOpsNumericType for UInt8Type {} diff --git a/crates/polars-ops/src/chunked_array/strings/case.rs b/crates/polars-ops/src/chunked_array/strings/case.rs index eb9829e292dd..cd4993287349 100644 --- a/crates/polars-ops/src/chunked_array/strings/case.rs +++ b/crates/polars-ops/src/chunked_array/strings/case.rs @@ -1,14 +1,9 @@ -#[cfg(feature = "nightly")] -use core::unicode::conversions; - use polars_core::prelude::Utf8Chunked; -// inlined from std +// Inlined from std. 
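`convert_while_ascii` below converts the leading ASCII run with a cheap byte-wise routine and leaves the remainder for the full Unicode path. A safe, simplified sketch of that split strategy (the real helper works on raw bytes and uninitialised capacity for speed; the function name here is illustrative):

fn to_lowercase_ascii_fast(s: &str) -> String {
    // The end of the leading ASCII run is always a char boundary in UTF-8.
    let ascii_len = s.bytes().take_while(|b| b.is_ascii()).count();
    let (ascii, rest) = s.split_at(ascii_len);
    let mut out = ascii.to_ascii_lowercase();
    // Fall back to the full Unicode conversion only for the non-ASCII tail.
    out.extend(rest.chars().flat_map(char::to_lowercase));
    out
}

fn main() {
    assert_eq!(to_lowercase_ascii_fast("HELLO Straße"), "hello straße");
}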
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec) { - unsafe { - out.set_len(0); - out.reserve(b.len()); - } + out.clear(); + out.reserve(b.len()); const USIZE_SIZE: usize = std::mem::size_of::(); const MAGIC_UNROLL: usize = 2; @@ -18,69 +13,43 @@ fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec) { let mut i = 0; unsafe { while i + N <= b.len() { - // Safety: we have checks the sizes `b` and `out` to know that our + // SAFETY: we have checks the sizes `b` and `out`. let in_chunk = b.get_unchecked(i..i + N); let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N); let mut bits = 0; for j in 0..MAGIC_UNROLL { - // read the bytes 1 usize at a time (unaligned since we haven't checked the alignment) - // safety: in_chunk is valid bytes in the range + // Read the bytes 1 usize at a time (unaligned since we haven't checked the alignment). + // SAFETY: in_chunk is valid bytes in the range. bits |= in_chunk.as_ptr().cast::().add(j).read_unaligned(); } - // if our chunks aren't ascii, then return only the prior bytes as init + // If our chunks aren't ascii, then return only the prior bytes as init. if bits & NONASCII_MASK != 0 { break; } - // perform the case conversions on N bytes (gets heavily autovec'd) + // Perform the case conversions on N bytes (gets heavily autovec'd). for j in 0..N { - // safety: in_chunk and out_chunk is valid bytes in the range + // SAFETY: in_chunk and out_chunk are valid bytes in the range. let out = out_chunk.get_unchecked_mut(j); out.write(convert(in_chunk.get_unchecked(j))); } - // mark these bytes as initialised + // Mark these bytes as initialised. i += N; } out.set_len(i); } } -#[cfg(not(feature = "nightly"))] -pub(super) fn to_lowercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked { - // this amortizes allocations and will not be freed - // so will have size of max(len) - let mut buf = Vec::new(); - - // this is one that will be set if we cannot convert ascii - // this length will change every iteration we must use this - let mut buf2 = Vec::new(); - let f = |s: &'a str| { - convert_while_ascii(s.as_bytes(), u8::to_ascii_lowercase, &mut buf); - let slice = if buf.len() < s.len() { - buf2 = s.to_lowercase().into_bytes(); - buf2.as_ref() - } else { - buf.as_ref() - }; - // extend lifetime - // lifetime is bound to 'a - let slice = unsafe { std::str::from_utf8_unchecked(slice) }; - unsafe { std::mem::transmute::<&str, &'a str>(slice) } - }; - ca.apply_mut(f) -} - -#[cfg(feature = "nightly")] fn to_lowercase_helper(s: &str, buf: &mut Vec) { convert_while_ascii(s.as_bytes(), u8::to_ascii_lowercase, buf); - // Safety: we know this is a valid char boundary since - // out.len() is only progressed if ascii bytes are found + // SAFETY: we know this is a valid char boundary since + // out.len() is only progressed if ASCII bytes are found. let rest = unsafe { s.get_unchecked(buf.len()..) }; - // Safety: We have written only valid ASCII to our vec + // SAFETY: We have written only valid ASCII to our vec. 
let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(buf)) }; for (i, c) in rest[..].char_indices() { @@ -92,18 +61,7 @@ fn to_lowercase_helper(s: &str, buf: &mut Vec) { // See https://github.com/rust-lang/rust/issues/26035 map_uppercase_sigma(rest, i, &mut s) } else { - match conversions::to_lower(c) { - [a, '\0', _] => s.push(a), - [a, b, '\0'] => { - s.push(a); - s.push(b); - }, - [a, b, c] => { - s.push(a); - s.push(b); - s.push(c); - }, - } + s.extend(c.to_lowercase()); } } @@ -124,136 +82,80 @@ fn to_lowercase_helper(s: &str, buf: &mut Vec) { None => false, } } - // put buf back for next iteration + + // Put buf back for next iteration. *buf = s.into_bytes(); } -// inlined from std -#[cfg(feature = "nightly")] pub(super) fn to_lowercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked { - // amortize allocation + // Amortize allocation. let mut buf = Vec::new(); - let f = |s: &'a str| { + let f = |s: &'a str| -> &'a str { to_lowercase_helper(s, &mut buf); - - // extend lifetime - // lifetime is bound to 'a + // SAFETY: apply_mut will copy value from buf before next iteration. let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; unsafe { std::mem::transmute::<&str, &'a str>(slice) } }; ca.apply_mut(f) } -#[cfg(not(feature = "nightly"))] +// Inlined from std. pub(super) fn to_uppercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked { - // this amortizes allocations and will not be freed - // so will have size of max(len) + // Amortize allocation. let mut buf = Vec::new(); - - // this is one that will be set if we cannot convert ascii - // this length will change every iteration we must use this - let mut buf2 = Vec::new(); - let f = |s: &'a str| { + let f = |s: &'a str| -> &'a str { convert_while_ascii(s.as_bytes(), u8::to_ascii_uppercase, &mut buf); - let slice = if buf.len() < s.len() { - buf2 = s.to_uppercase().into_bytes(); - buf2.as_ref() - } else { - buf.as_ref() - }; - // extend lifetime - // lifetime is bound to 'a - let slice = unsafe { std::str::from_utf8_unchecked(slice) }; - unsafe { std::mem::transmute::<&str, &'a str>(slice) } - }; - ca.apply_mut(f) -} -#[inline] -#[cfg(feature = "nightly")] -fn push_char_to_upper(c: char, s: &mut String) { - match conversions::to_upper(c) { - [a, '\0', _] => s.push(a), - [a, b, '\0'] => { - s.push(a); - s.push(b); - }, - [a, b, c] => { - s.push(a); - s.push(b); - s.push(c); - }, - } -} - -// inlined from std -#[cfg(feature = "nightly")] -pub(super) fn to_uppercase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked { - // amortize allocation - let mut buf = Vec::new(); - let f = |s: &'a str| { - convert_while_ascii(s.as_bytes(), u8::to_ascii_uppercase, &mut buf); - - // Safety: we know this is a valid char boundary since - // out.len() is only progressed if ascii bytes are found + // SAFETY: we know this is a valid char boundary since + // out.len() is only progressed if ascii bytes are found. let rest = unsafe { s.get_unchecked(buf.len()..) }; - // Safety: We have written only valid ASCII to our vec + // SAFETY: We have written only valid ASCII to our vec. let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) }; for c in rest.chars() { - push_char_to_upper(c, &mut s); + s.extend(c.to_uppercase()); } - // put buf back for next iteration + // Put buf back for next iteration. buf = s.into_bytes(); - // extend lifetime - // lifetime is bound to 'a + // SAFETY: apply_mut will copy value from buf before next iteration. 
let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; unsafe { std::mem::transmute::<&str, &'a str>(slice) } }; ca.apply_mut(f) } -#[cfg(feature = "nightly")] pub(super) fn to_titlecase<'a>(ca: &'a Utf8Chunked) -> Utf8Chunked { - // amortize allocation + // Amortize allocation. let mut buf = Vec::new(); - // temporary scratch - // we have a double copy as we first convert to lowercase - // and then copy to `buf` + // Temporary scratch space. + // We have a double copy as we first convert to lowercase and then copy to `buf`. let mut scratch = Vec::new(); - let f = |s: &'a str| { - unsafe { - buf.set_len(0); - } - // this helper sets scratch len to 0 + let f = |s: &'a str| -> &'a str { to_lowercase_helper(s, &mut scratch); - - let mut next_is_upper = true; - let lowercased = unsafe { std::str::from_utf8_unchecked(&scratch) }; + // SAFETY: the buffer is clear, empty string is valid UTF-8. + buf.clear(); let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) }; + let mut next_is_upper = true; for c in lowercased.chars() { - let is_whitespace = c.is_whitespace(); - if is_whitespace || !next_is_upper { - next_is_upper = is_whitespace; - s.push(c); + if next_is_upper { + s.extend(c.to_uppercase()); } else { - next_is_upper = false; - push_char_to_upper(c, &mut s); + s.push(c); } + next_is_upper = c.is_whitespace(); } - // put buf back for next iteration + // Put buf back for next iteration. buf = s.into_bytes(); - // extend lifetime - // lifetime is bound to 'a + // SAFETY: apply_mut will copy value from buf before next iteration. let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; unsafe { std::mem::transmute::<&str, &'a str>(slice) } }; diff --git a/crates/polars-ops/src/chunked_array/strings/concat.rs b/crates/polars-ops/src/chunked_array/strings/concat.rs new file mode 100644 index 000000000000..caddf3eee3fd --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/concat.rs @@ -0,0 +1,160 @@ +use arrow::array::Utf8Array; +use polars_arrow::array::default_arrays::FromDataUtf8; +use polars_core::prelude::*; + +// Vertically concatenate all strings in a Utf8Chunked. +pub fn str_concat(ca: &Utf8Chunked, delimiter: &str) -> Utf8Chunked { + if ca.len() <= 1 { + return ca.clone(); + } + + // Calculate capacity. + let null_str_len = 4; + let capacity = + ca.get_values_size() + ca.null_count() * null_str_len + delimiter.len() * (ca.len() - 1); + + let mut buf = String::with_capacity(capacity); + let mut first = true; + for arr in ca.downcast_iter() { + for val in arr.into_iter() { + if !first { + buf.push_str(delimiter); + } + + if let Some(s) = val { + buf.push_str(s); + } else { + buf.push_str("null"); + } + + first = false; + } + } + + let buf = buf.into_bytes(); + let offsets = vec![0, buf.len() as i64]; + let arr = unsafe { Utf8Array::from_data_unchecked_default(offsets.into(), buf.into(), None) }; + Utf8Chunked::with_chunk(ca.name(), arr) +} + +enum ColumnIter { + Iter(I), + Broadcast(T), +} + +/// Horizontally concatenate all strings. +/// +/// Each array should have length 1 or a length equal to the maximum length. +pub fn hor_str_concat(cas: &[&Utf8Chunked], delimiter: &str) -> PolarsResult { + if cas.is_empty() { + return Ok(Utf8Chunked::full_null("", 0)); + } + if cas.len() == 1 { + return Ok(cas[0].clone()); + } + + // Calculate the post-broadcast length and ensure everything is consistent. 
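The length computation that follows encodes the broadcasting rule of `hor_str_concat`: every column must have either the common length or length 1 (which broadcasts), and all-unit-length input stays at length 1. The same rule in isolation:

fn broadcast_len(lengths: &[usize]) -> Option<usize> {
    // The post-broadcast length is the longest non-unit length, or 1 if every column has length 1.
    let len = lengths.iter().copied().filter(|l| *l != 1).max().unwrap_or(1);
    if lengths.iter().all(|l| *l == 1 || *l == len) {
        Some(len)
    } else {
        None // mixed, non-broadcastable lengths
    }
}

fn main() {
    assert_eq!(broadcast_len(&[3, 1, 3]), Some(3));
    assert_eq!(broadcast_len(&[1, 1]), Some(1));
    assert_eq!(broadcast_len(&[2, 3]), None);
}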
+ let len = cas + .iter() + .map(|ca| ca.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + polars_ensure!( + cas.iter().all(|ca| ca.len() == 1 || ca.len() == len), + ComputeError: "all series in `hor_str_concat` should have equal or unit length" + ); + + let has_empty_ca = cas.iter().any(|ca| ca.is_empty()); + if has_empty_ca { + return Ok(Utf8Chunked::full_null(cas[0].name(), 0)); + } + + // Calculate total capacity needed. + let tot_strings_bytes: usize = cas + .iter() + .map(|ca| { + let bytes = ca.get_values_size(); + if ca.len() == 1 { + len * bytes + } else { + bytes + } + }) + .sum(); + let capacity = tot_strings_bytes + (cas.len() - 1) * delimiter.len() * len; + let mut builder = Utf8ChunkedBuilder::new(cas[0].name(), len, capacity); + + // Broadcast if appropriate. + let mut cols: Vec<_> = cas + .iter() + .map(|ca| { + if ca.len() > 1 { + ColumnIter::Iter(ca.into_iter()) + } else { + ColumnIter::Broadcast(ca.get(0)) + } + }) + .collect(); + + // Build concatenated string. + let mut buf = String::with_capacity(1024); + for _row in 0..len { + let mut has_null = false; + for (i, col) in cols.iter_mut().enumerate() { + if i > 0 { + buf.push_str(delimiter); + } + + let val = match col { + ColumnIter::Iter(i) => i.next().unwrap(), + ColumnIter::Broadcast(s) => *s, + }; + if let Some(s) = val { + buf.push_str(s); + } else { + has_null = true; + } + } + + if has_null { + builder.append_null(); + } else { + builder.append_value(&buf) + } + buf.clear(); + } + + Ok(builder.finish()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_str_concat() { + let ca = Int32Chunked::new("foo", &[Some(1), None, Some(3)]); + let ca_str = ca.cast(&DataType::Utf8).unwrap(); + let out = str_concat(&ca_str.utf8().unwrap(), "-"); + + let out = out.get(0); + assert_eq!(out, Some("1-null-3")); + } + + #[test] + fn test_hor_str_concat() { + let a = Utf8Chunked::new("a", &["foo", "bar"]); + let b = Utf8Chunked::new("b", &["spam", "ham"]); + + let out = hor_str_concat(&[&a, &b], "_").unwrap(); + assert_eq!(Vec::from(&out), &[Some("foo_spam"), Some("bar_ham")]); + + let c = Utf8Chunked::new("b", &["literal"]); + let out = hor_str_concat(&[&a, &b, &c], "_").unwrap(); + assert_eq!( + Vec::from(&out), + &[Some("foo_spam_literal"), Some("bar_ham_literal")] + ); + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index 3caaec8a9dba..ee858f7548e5 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -1,6 +1,8 @@ #[cfg(feature = "strings")] mod case; #[cfg(feature = "strings")] +mod concat; +#[cfg(feature = "strings")] mod extract; #[cfg(feature = "extract_jsonpath")] mod json_path; @@ -10,12 +12,24 @@ mod justify; mod namespace; #[cfg(feature = "strings")] mod replace; +#[cfg(feature = "strings")] +mod split; +#[cfg(feature = "strings")] +mod strip; +#[cfg(feature = "strings")] +mod substring; +#[cfg(feature = "strings")] +pub use concat::*; #[cfg(feature = "extract_jsonpath")] pub use json_path::*; #[cfg(feature = "strings")] pub use namespace::*; use polars_core::prelude::*; +#[cfg(feature = "strings")] +pub use split::*; +#[cfg(feature = "strings")] +pub use strip::*; pub trait AsUtf8 { fn as_utf8(&self) -> &Utf8Chunked; diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index ab9b4919abed..b38f6e9f9590 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ 
b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -2,17 +2,27 @@ use base64::engine::general_purpose; #[cfg(feature = "string_encoding")] use base64::Engine as _; -use polars_arrow::export::arrow::compute::substring::substring; -use polars_arrow::export::arrow::{self}; use polars_arrow::kernels::string::*; #[cfg(feature = "string_from_radix")] use polars_core::export::num::Num; -use polars_core::export::regex::{escape, Regex}; +use polars_core::export::regex::Regex; +use polars_core::prelude::arity::*; +use polars_utils::cache::FastFixedCache; +use regex::escape; use super::*; #[cfg(feature = "binary_encoding")] use crate::chunked_array::binary::BinaryNameSpaceImpl; +// We need this to infer the right lifetimes for the match closure. +#[inline(always)] +fn infer_re_match(f: F) -> F +where + F: for<'a, 'b> FnMut(Option<&'a str>, Option<&'b str>) -> Option, +{ + f +} + pub trait Utf8NameSpaceImpl: AsUtf8 { #[cfg(not(feature = "binary_encoding"))] fn hex_decode(&self) -> PolarsResult { @@ -85,16 +95,65 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(out) } + fn contains_chunked( + &self, + pat: &Utf8Chunked, + literal: bool, + strict: bool, + ) -> PolarsResult { + let ca = self.as_utf8(); + match pat.len() { + 1 => match pat.get(0) { + Some(pat) => { + if literal { + ca.contains_literal(pat) + } else { + ca.contains(pat, strict) + } + }, + None => Ok(BooleanChunked::full_null(ca.name(), ca.len())), + }, + _ => { + if literal { + Ok(binary_elementwise_values(ca, pat, |src, pat| { + src.contains(pat) + })) + } else if strict { + // A sqrt(n) regex cache is not too small, not too large. + let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); + try_binary_elementwise(ca, pat, |opt_src, opt_pat| match (opt_src, opt_pat) { + (Some(src), Some(pat)) => { + let reg = reg_cache.try_get_or_insert_with(pat, |p| Regex::new(p))?; + Ok(Some(reg.is_match(src))) + }, + _ => Ok(None), + }) + } else { + // A sqrt(n) regex cache is not too small, not too large. + let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); + Ok(binary_elementwise( + ca, + pat, + infer_re_match(|src, pat| { + let reg = reg_cache.try_get_or_insert_with(pat?, |p| Regex::new(p)); + Some(reg.ok()?.is_match(src?)) + }), + )) + } + }, + } + } + /// Get the length of the string values as number of chars. - fn str_n_chars(&self) -> UInt32Chunked { + fn str_len_chars(&self) -> UInt32Chunked { let ca = self.as_utf8(); - ca.apply_kernel_cast(&string_nchars) + ca.apply_kernel_cast(&string_len_chars) } /// Get the length of the string values as number of bytes. - fn str_lengths(&self) -> UInt32Chunked { + fn str_len_bytes(&self) -> UInt32Chunked { let ca = self.as_utf8(); - ca.apply_kernel_cast(&string_lengths) + ca.apply_kernel_cast(&string_len_bytes) } /// Return a copy of the string left filled with ASCII '0' digits to make a string of length width. @@ -132,18 +191,11 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { let res_reg = Regex::new(pat); let opt_reg = if strict { Some(res_reg?) 
} else { res_reg.ok() }; - let mut out: BooleanChunked = match (opt_reg, ca.has_validity()) { - (Some(reg), false) => ca - .into_no_null_iter() - .map(|s: &str| reg.is_match(s)) - .collect(), - (Some(reg), true) => ca - .into_iter() - .map(|opt_s| opt_s.map(|s: &str| reg.is_match(s))) - .collect(), - (None, _) => ca.into_iter().map(|_| None).collect(), + let out: BooleanChunked = if let Some(reg) = opt_reg { + ca.apply_values_generic(|s| reg.is_match(s)) + } else { + BooleanChunked::full_null(ca.name(), ca.len()) }; - out.rename(ca.name()); Ok(out) } @@ -152,25 +204,7 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { // note: benchmarking shows that the regex engine is actually // faster at finding literal matches than str::contains. // ref: https://github.com/pola-rs/polars/pull/6811 - self.contains(escape(lit).as_str(), true) - } - - /// Check if strings ends with a substring - fn ends_with(&self, sub: &str) -> BooleanChunked { - let ca = self.as_utf8(); - let f = |s: &str| s.ends_with(sub); - let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); - out - } - - /// Check if strings starts with a substring - fn starts_with(&self, sub: &str) -> BooleanChunked { - let ca = self.as_utf8(); - let f = |s: &str| s.starts_with(sub); - let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); - out + self.contains(regex::escape(lit).as_str(), true) } /// Replace the leftmost regex-matched (sub)string with another string @@ -258,14 +292,14 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { })); } - // amortize allocation + // Amortize allocation. let mut buf = String::new(); let f = move |s: &'a str| { buf.clear(); let mut changed = false; - // See: str.replace + // See: str.replace. let mut last_end = 0; for (start, part) in s.match_indices(pat) { changed = true; @@ -276,8 +310,7 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { buf.push_str(unsafe { s.get_unchecked(last_end..s.len()) }); if changed { - // extend lifetime - // lifetime is bound to 'a + // Extend lifetime, lifetime is bound to 'a. let slice = buf.as_str(); unsafe { std::mem::transmute::<&str, &'a str>(slice) } } else { @@ -288,36 +321,98 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(ca.apply_mut(f)) } - /// Extract the nth capture group from pattern + /// Extract the nth capture group from pattern. fn extract(&self, pat: &str, group_index: usize) -> PolarsResult { let ca = self.as_utf8(); super::extract::extract_group(ca, pat, group_index) } - /// Extract each successive non-overlapping regex match in an individual string as an array + /// Extract each successive non-overlapping regex match in an individual string as an array. 
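`extract_all` below no longer peeks to special-case the no-match path; it simply appends every non-overlapping match it finds. The per-value operation, using the `regex` crate directly (illustrative helper, not the ChunkedArray method):

use regex::Regex;

fn extract_all(s: &str, pat: &str) -> Result<Vec<String>, regex::Error> {
    let re = Regex::new(pat)?;
    // find_iter yields successive non-overlapping matches from left to right.
    Ok(re.find_iter(s).map(|m| m.as_str().to_string()).collect())
}

fn main() {
    assert_eq!(extract_all("a1b22c333", r"\d+").unwrap(), vec!["1", "22", "333"]);
}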
fn extract_all(&self, pat: &str) -> PolarsResult { let ca = self.as_utf8(); let reg = Regex::new(pat)?; let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); - for opt_s in ca.into_iter() { match opt_s { None => builder.append_null(), - Some(s) => { - let mut iter = reg.find_iter(s).map(|m| m.as_str()).peekable(); - if iter.peek().is_some() { - builder.append_values_iter(iter); - } else { - builder.append_null() - } - }, + Some(s) => builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())), } } Ok(builder.finish()) } - /// Extract each successive non-overlapping regex match in an individual string as an array + fn strip_chars(&self, pat: &Series) -> PolarsResult { + let ca = self.as_utf8(); + if pat.dtype() == &DataType::Null { + Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim()))) + } else { + Ok(strip_chars(ca, pat.utf8()?)) + } + } + + fn strip_chars_start(&self, pat: &Series) -> PolarsResult { + let ca = self.as_utf8(); + if pat.dtype() == &DataType::Null { + return Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_start()))); + } else { + Ok(strip_chars_start(ca, pat.utf8()?)) + } + } + + fn strip_chars_end(&self, pat: &Series) -> PolarsResult { + let ca = self.as_utf8(); + if pat.dtype() == &DataType::Null { + return Ok(ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end()))); + } else { + Ok(strip_chars_end(ca, pat.utf8()?)) + } + } + + fn strip_prefix(&self, prefix: &Utf8Chunked) -> Utf8Chunked { + let ca = self.as_utf8(); + strip_prefix(ca, prefix) + } + + fn strip_suffix(&self, suffix: &Utf8Chunked) -> Utf8Chunked { + let ca = self.as_utf8(); + strip_suffix(ca, suffix) + } + + #[cfg(feature = "dtype-struct")] + fn split_exact(&self, by: &Utf8Chunked, n: usize) -> PolarsResult { + let ca = self.as_utf8(); + + split_to_struct(ca, by, n + 1, |s, by| s.split(by)) + } + + #[cfg(feature = "dtype-struct")] + fn split_exact_inclusive(&self, by: &Utf8Chunked, n: usize) -> PolarsResult { + let ca = self.as_utf8(); + + split_to_struct(ca, by, n + 1, |s, by| s.split_inclusive(by)) + } + + #[cfg(feature = "dtype-struct")] + fn splitn(&self, by: &Utf8Chunked, n: usize) -> PolarsResult { + let ca = self.as_utf8(); + + split_to_struct(ca, by, n, |s, by| s.splitn(n, by)) + } + + fn split(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + split_helper(ca, by, str::split) + } + + fn split_inclusive(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + split_helper(ca, by, str::split_inclusive) + } + + /// Extract each successive non-overlapping regex match in an individual string as an array. fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult { let ca = self.as_utf8(); polars_ensure!( @@ -326,60 +421,85 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { pat.len(), ca.len(), ); + // A sqrt(n) regex cache is not too small, not too large. 
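The bounded cache created below avoids recompiling a regex for every row when the pattern column repeats values. The same idea with a plain `HashMap` (the real code uses `FastFixedCache` from polars-utils, which additionally caps the number of cached entries):

use std::collections::HashMap;

use regex::Regex;

fn match_many(srcs: &[&str], pats: &[&str]) -> Result<Vec<bool>, regex::Error> {
    let mut cache: HashMap<String, Regex> = HashMap::new();
    srcs.iter()
        .zip(pats)
        .map(|(&s, &p)| {
            // Compile each distinct pattern once and reuse it for later rows.
            if !cache.contains_key(p) {
                cache.insert(p.to_string(), Regex::new(p)?);
            }
            Ok(cache[p].is_match(s))
        })
        .collect()
}

fn main() {
    let out = match_many(&["abc", "abd"], &[r"c$", r"c$"]).unwrap();
    assert_eq!(out, vec![true, false]);
}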
+ let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); - - for (opt_s, opt_pat) in ca.into_iter().zip(pat) { - match (opt_s, opt_pat) { - (_, None) | (None, _) => builder.append_null(), - (Some(s), Some(pat)) => { - let reg = Regex::new(pat)?; - let mut iter = reg.find_iter(s).map(|m| m.as_str()).peekable(); - if iter.peek().is_some() { - builder.append_values_iter(iter); - } else { - builder.append_null() - } - }, - } - } + binary_elementwise_for_each(ca, pat, |opt_s, opt_pat| match (opt_s, opt_pat) { + (_, None) | (None, _) => builder.append_null(), + (Some(s), Some(pat)) => { + let reg = reg_cache.get_or_insert_with(pat, |p| Regex::new(p).unwrap()); + builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())); + }, + }); Ok(builder.finish()) } #[cfg(feature = "extract_groups")] - /// Extract all capture groups from pattern and return as a struct + /// Extract all capture groups from pattern and return as a struct. fn extract_groups(&self, pat: &str, dtype: &DataType) -> PolarsResult { let ca = self.as_utf8(); super::extract::extract_groups(ca, pat, dtype) } /// Count all successive non-overlapping regex matches. - fn count_match(&self, pat: &str) -> PolarsResult { + fn count_matches(&self, pat: &str, literal: bool) -> PolarsResult { let ca = self.as_utf8(); - let reg = Regex::new(pat)?; + let reg = if literal { + Regex::new(escape(pat).as_str())? + } else { + Regex::new(pat)? + }; - let mut out: UInt32Chunked = ca - .into_iter() - .map(|opt_s| opt_s.map(|s| reg.find_iter(s).count() as u32)) - .collect(); - out.rename(ca.name()); - Ok(out) + Ok(ca.apply_generic(|opt_s| opt_s.map(|s| reg.find_iter(s).count() as u32))) + } + + /// Count all successive non-overlapping regex matches. + fn count_matches_many(&self, pat: &Utf8Chunked, literal: bool) -> PolarsResult { + let ca = self.as_utf8(); + polars_ensure!( + ca.len() == pat.len(), + ComputeError: "pattern's length: {} does not match that of the argument series: {}", + pat.len(), ca.len(), + ); + + // A sqrt(n) regex cache is not too small, not too large. + let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); + let op = move |opt_s: Option<&str>, opt_pat: Option<&str>| -> PolarsResult> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + let reg = reg_cache.get_or_insert_with(pat, |p| { + if literal { + Regex::new(escape(p).as_str()).unwrap() + } else { + Regex::new(p).unwrap() + } + }); + Ok(Some(reg.find_iter(s).count() as u32)) + }, + _ => Ok(None), + } + }; + + let out: UInt32Chunked = try_binary_elementwise(ca, pat, op)?; + + Ok(out.with_name(ca.name())) } - /// Modify the strings to their lowercase equivalent + /// Modify the strings to their lowercase equivalent. #[must_use] fn to_lowercase(&self) -> Utf8Chunked { let ca = self.as_utf8(); case::to_lowercase(ca) } - /// Modify the strings to their uppercase equivalent + /// Modify the strings to their uppercase equivalent. #[must_use] fn to_uppercase(&self) -> Utf8Chunked { let ca = self.as_utf8(); case::to_uppercase(ca) } - /// Modify the strings to their titlecase equivalent + /// Modify the strings to their titlecase equivalent. #[must_use] #[cfg(feature = "nightly")] fn to_titlecase(&self) -> Utf8Chunked { @@ -387,24 +507,23 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { case::to_titlecase(ca) } - /// Concat with the values from a second Utf8Chunked + /// Concat with the values from a second Utf8Chunked. 
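`count_matches` above is the renamed `count_match` with a new `literal` flag: when set, the pattern is escaped before compilation so regex metacharacters are matched verbatim. The core of that, standalone:

use regex::Regex;

fn count_matches(s: &str, pat: &str, literal: bool) -> Result<u32, regex::Error> {
    let re = if literal {
        // Escape metacharacters so the pattern is matched as-is.
        Regex::new(&regex::escape(pat))?
    } else {
        Regex::new(pat)?
    };
    Ok(re.find_iter(s).count() as u32)
}

fn main() {
    assert_eq!(count_matches("1.2.3", ".", true).unwrap(), 2); // two literal dots
    assert_eq!(count_matches("1.2.3", ".", false).unwrap(), 5); // regex '.' matches any char
}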
#[must_use] fn concat(&self, other: &Utf8Chunked) -> Utf8Chunked { let ca = self.as_utf8(); ca + other } - /// Slice the string values + /// Slice the string values. + /// /// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`. /// `start` can be negative, in which case the start counts from the end of the string. - fn str_slice(&self, start: i64, length: Option) -> PolarsResult { + fn str_slice(&self, start: i64, length: Option) -> Utf8Chunked { let ca = self.as_utf8(); - let chunks = ca + let iter = ca .downcast_iter() - .map(|c| substring(c, start, &length)) - .collect::>()?; - // SAFETY: these are all the same type. - unsafe { Ok(Utf8Chunked::from_chunks(ca.name(), chunks)) } + .map(|c| substring::utf8_substring(c, start, &length)); + Utf8Chunked::from_chunk_iter_like(ca, iter) } } diff --git a/crates/polars-ops/src/chunked_array/strings/split.rs b/crates/polars-ops/src/chunked_array/strings/split.rs new file mode 100644 index 000000000000..e5fe046652fa --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/split.rs @@ -0,0 +1,113 @@ +#[cfg(feature = "dtype-struct")] +use polars_arrow::export::arrow::array::{MutableArray, MutableUtf8Array}; +use polars_core::chunked_array::ops::arity::binary_elementwise_for_each; + +use super::*; + +#[cfg(feature = "dtype-struct")] +pub fn split_to_struct<'a, F, I>( + ca: &'a Utf8Chunked, + by: &'a Utf8Chunked, + n: usize, + op: F, +) -> PolarsResult +where + F: Fn(&'a str, &'a str) -> I, + I: Iterator, +{ + let mut arrs = (0..n) + .map(|_| MutableUtf8Array::::with_capacity(ca.len())) + .collect::>(); + + if by.len() == 1 { + if let Some(by) = by.get(0) { + ca.for_each(|opt_s| match opt_s { + None => { + for arr in &mut arrs { + arr.push_null() + } + }, + Some(s) => { + let mut arr_iter = arrs.iter_mut(); + let split_iter = op(s, by); + (split_iter) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + // fill the remaining with null + for arr in arr_iter { + arr.push_null() + } + }, + }); + } else { + for arr in &mut arrs { + arr.push_null() + } + } + } else { + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let mut arr_iter = arrs.iter_mut(); + let split_iter = op(s, by); + (split_iter) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + // fill the remaining with null + for arr in arr_iter { + arr.push_null() + } + }, + _ => { + for arr in &mut arrs { + arr.push_null() + } + }, + }) + } + + let fields = arrs + .into_iter() + .enumerate() + .map(|(i, mut arr)| { + Series::try_from((format!("field_{i}").as_str(), arr.as_box())).unwrap() + }) + .collect::>(); + + StructChunked::new(ca.name(), &fields) +} + +pub fn split_helper<'a, F, I>(ca: &'a Utf8Chunked, by: &'a Utf8Chunked, op: F) -> ListChunked +where + F: Fn(&'a str, &'a str) -> I, + I: Iterator, +{ + if by.len() == 1 { + if let Some(by) = by.get(0) { + let mut builder = + ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + ca.for_each(|opt_s| match opt_s { + Some(s) => { + let iter = op(s, by); + builder.append_values_iter(iter) + }, + _ => builder.append_null(), + }); + builder.finish() + } else { + ListChunked::full_null_with_dtype(ca.name(), ca.len(), &DataType::Utf8) + } + } else { + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let 
iter = op(s, by); + builder.append_values_iter(iter); + }, + _ => builder.append_null(), + }); + + builder.finish() + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/strip.rs b/crates/polars-ops/src/chunked_array/strings/strip.rs new file mode 100644 index 000000000000..1d58f4fd5f7a --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/strip.rs @@ -0,0 +1,139 @@ +use polars_core::prelude::arity::binary_elementwise; + +use super::*; + +fn strip_chars_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim()), + _ => None, + } +} + +fn strip_chars_start_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_start_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_start_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim_start()), + _ => None, + } +} + +fn strip_chars_end_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_end_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_end_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim_end()), + _ => None, + } +} + +fn strip_prefix_binary<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_prefix(prefix?).unwrap_or(s?)) +} + +fn strip_suffix_binary<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_suffix(suffix?).unwrap_or(s?)) +} + +pub fn strip_chars(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + ca.apply_generic(|opt_s| { + opt_s.map(|s| s.trim_matches(pat.chars().next().unwrap())) + }) + } else { + ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_matches(|c| pat.contains(c)))) + } + } else { + ca.apply_generic(|opt_s| opt_s.map(|s| s.trim())) + } + }, + _ => binary_elementwise(ca, pat, strip_chars_binary), + } +} + +pub fn strip_chars_start(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + ca.apply_generic(|opt_s| { + opt_s.map(|s| s.trim_start_matches(pat.chars().next().unwrap())) + }) + } else { + ca.apply_generic(|opt_s| { + opt_s.map(|s| s.trim_start_matches(|c| pat.contains(c))) + }) + } + } else { + ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_start())) + } + }, + _ => binary_elementwise(ca, pat, strip_chars_start_binary), + } +} + +pub fn strip_chars_end(ca: &Utf8Chunked, pat: &Utf8Chunked) -> Utf8Chunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + ca.apply_generic(|opt_s| { + opt_s.map(|s| s.trim_end_matches(pat.chars().next().unwrap())) + }) + } else { + ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end_matches(|c| pat.contains(c)))) + } + } else { + ca.apply_generic(|opt_s| opt_s.map(|s| s.trim_end())) + } + }, + _ => binary_elementwise(ca, pat, strip_chars_end_binary), + } +} + 
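The `strip_chars*` functions above treat the pattern as a set of characters to trim rather than a substring, with a fast path when the set is a single character and plain whitespace trimming when the pattern is null. The same semantics on a bare `&str` (illustrative helper):

fn strip_chars<'a>(s: &'a str, pat: Option<&str>) -> &'a str {
    match pat {
        // A null pattern trims whitespace, mirroring the DataType::Null branch in the namespace.
        None => s.trim(),
        // Fast path: a single character can be passed to trim_matches directly.
        Some(pat) if pat.chars().count() == 1 => s.trim_matches(pat.chars().next().unwrap()),
        // Otherwise trim any character contained in the pattern.
        Some(pat) => s.trim_matches(|c| pat.contains(c)),
    }
}

fn main() {
    assert_eq!(strip_chars("xxhixy", Some("xy")), "hi");
    assert_eq!(strip_chars("  hi  ", None), "hi");
}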
+pub fn strip_prefix(ca: &Utf8Chunked, prefix: &Utf8Chunked) -> Utf8Chunked { + match prefix.len() { + 1 => match prefix.get(0) { + Some(prefix) => { + ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s))) + }, + _ => Utf8Chunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, prefix, strip_prefix_binary), + } +} + +pub fn strip_suffix(ca: &Utf8Chunked, suffix: &Utf8Chunked) -> Utf8Chunked { + match suffix.len() { + 1 => match suffix.get(0) { + Some(suffix) => { + ca.apply_generic(|opt_s| opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s))) + }, + _ => Utf8Chunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, suffix, strip_suffix_binary), + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs new file mode 100644 index 000000000000..e485e25dd216 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -0,0 +1,51 @@ +use arrow::array::Utf8Array; + +/// Returns a Utf8Array with a substring starting from `start` and with optional length `length` of each of the elements in `array`. +/// `start` can be negative, in which case the start counts from the end of the string. +pub(super) fn utf8_substring( + array: &Utf8Array, + start: i64, + length: &Option, +) -> Utf8Array { + let length = length.map(|v| v as usize); + + let iter = array.values_iter().map(|str_val| { + // compute where we should start slicing this entry. + let start = if start >= 0 { + start as usize + } else { + let start = (0i64 - start) as usize; + str_val + .char_indices() + .rev() + .nth(start) + .map(|(idx, _)| idx + 1) + .unwrap_or(0) + }; + + let mut iter_chars = str_val.char_indices(); + if let Some((start_idx, _)) = iter_chars.nth(start) { + // length of the str + let len_end = str_val.len() - start_idx; + + // length to slice + let length = length.unwrap_or(len_end); + + if length == 0 { + return ""; + } + // compute + let end_idx = iter_chars + .nth(length.saturating_sub(1)) + .map(|(idx, _)| idx) + .unwrap_or(str_val.len()); + + &str_val[start_idx..end_idx] + } else { + "" + } + }); + + let new = Utf8Array::::from_trusted_len_values_iter(iter); + new.with_validity(array.validity().cloned()) +} diff --git a/crates/polars-ops/src/chunked_array/top_k.rs b/crates/polars-ops/src/chunked_array/top_k.rs index c89236cbaf04..47a338b202c0 100644 --- a/crates/polars-ops/src/chunked_array/top_k.rs +++ b/crates/polars-ops/src/chunked_array/top_k.rs @@ -70,19 +70,34 @@ where } } -pub fn top_k(s: &Series, k: usize, descending: bool) -> PolarsResult { - if s.is_empty() { - return Ok(s.clone()); +pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { + let src = &s[0]; + let k_s = &s[1]; + + if src.is_empty() { + return Ok(src.clone()); } - let dtype = s.dtype(); - let s = s.to_physical_repr(); + polars_ensure!( + k_s.len() == 1, + ComputeError: "k must be a single value." + ); - macro_rules! dispatch { - ($ca:expr) => {{ - top_k_impl($ca, k, descending).into_series() - }}; - } + let k_s = k_s.cast(&IDX_DTYPE)?; + let k = k_s.idx()?; + + let dtype = src.dtype(); - downcast_as_macro_arg_physical!(&s, dispatch).cast(dtype) + if let Some(k) = k.get(0) { + let s = src.to_physical_repr(); + macro_rules! 
dispatch { + ($ca:expr) => {{ + top_k_impl($ca, k as usize, descending).into_series() + }}; + } + + downcast_as_macro_arg_physical!(&s, dispatch).cast(dtype) + } else { + Ok(Series::full_null(src.name(), src.len(), dtype)) + } } diff --git a/crates/polars-ops/src/frame/hashing.rs b/crates/polars-ops/src/frame/hashing.rs new file mode 100644 index 000000000000..245f125edfb9 --- /dev/null +++ b/crates/polars-ops/src/frame/hashing.rs @@ -0,0 +1,97 @@ +use std::hash::Hash; + +use ahash::RandomState; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use polars_arrow::trusted_len::TrustedLen; +use polars_arrow::utils::CustomIterTools; +use polars_core::hashing::partition::this_partition; +use polars_core::prelude::*; +use polars_core::utils::_set_partition_size; +use polars_core::POOL; +use rayon::prelude::*; + +pub(crate) fn prepare_hashed_relation_threaded( + iters: Vec, +) -> Vec), RandomState>> +where + I: Iterator + Send + TrustedLen, + T: Send + Hash + Eq + Sync + Copy, +{ + let n_partitions = _set_partition_size(); + let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None); + + // We will create a hashtable in every thread. + // We use the hash to partition the keys to the matching hashtable. + // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. + POOL.install(|| { + (0..n_partitions) + .into_par_iter() + .map(|partition_no| { + let build_hasher = build_hasher.clone(); + let hashes_and_keys = &hashes_and_keys; + let partition_no = partition_no as u64; + let mut hash_tbl: HashMap), RandomState> = + HashMap::with_hasher(build_hasher); + + let n_threads = n_partitions as u64; + let mut offset = 0; + for hashes_and_keys in hashes_and_keys { + let len = hashes_and_keys.len(); + hashes_and_keys + .iter() + .enumerate() + .for_each(|(idx, (h, k))| { + let idx = idx as IdxSize; + // partition hashes by thread no. 
+ // So only a part of the hashes go to this hashmap + if this_partition(*h, partition_no, n_threads) { + let idx = idx + offset; + let entry = hash_tbl + .raw_entry_mut() + // uses the key to check equality to find and entry + .from_key_hashed_nocheck(*h, k); + + match entry { + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck(*h, *k, (false, vec![idx])); + }, + RawEntryMut::Occupied(mut entry) => { + let (_k, v) = entry.get_key_value_mut(); + v.1.push(idx); + }, + } + } + }); + + offset += len as IdxSize; + } + hash_tbl + }) + .collect() + }) +} + +pub(crate) fn create_hash_and_keys_threaded_vectorized( + iters: Vec, + build_hasher: Option, +) -> (Vec>, RandomState) +where + I: IntoIterator + Send, + I::IntoIter: TrustedLen, + T: Send + Hash + Eq, +{ + let build_hasher = build_hasher.unwrap_or_default(); + let hashes = POOL.install(|| { + iters + .into_par_iter() + .map(|iter| { + // create hashes and keys + iter.into_iter() + .map(|val| (build_hasher.hash_one(&val), val)) + .collect_trusted::>() + }) + .collect() + }); + (hashes, build_hasher) +} diff --git a/crates/polars-core/src/frame/hash_join/args.rs b/crates/polars-ops/src/frame/join/args.rs similarity index 95% rename from crates/polars-core/src/frame/hash_join/args.rs rename to crates/polars-ops/src/frame/join/args.rs index 5b36b3353034..af8509f7642b 100644 --- a/crates/polars-core/src/frame/hash_join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -15,8 +15,11 @@ pub type ChunkJoinOptIds = Vec>; #[cfg(not(feature = "chunked_ids"))] pub type ChunkJoinIds = Vec; -/// [ChunkIdx, DfIdx] -pub type ChunkId = [IdxSize; 2]; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "asof_join")] +use super::asof::AsOfOptions; #[derive(Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -125,7 +128,7 @@ impl JoinValidation { if !self.needs_checks() { return Ok(()); } - polars_ensure!(n_keys == 1, ComputeError: "{validation} not yet supported for multiple keys"); + polars_ensure!(n_keys == 1, ComputeError: "{self} validation on a {join_type} is not yet supported for multiple keys"); polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Outer | JoinType::Left), ComputeError: "{self} validation on a {join_type} join is not supported"); Ok(()) diff --git a/crates/polars-core/src/frame/asof_join/asof.rs b/crates/polars-ops/src/frame/join/asof/default.rs similarity index 98% rename from crates/polars-core/src/frame/asof_join/asof.rs rename to crates/polars-ops/src/frame/join/asof/default.rs index 7edbd1372bae..ba46ca214b74 100644 --- a/crates/polars-core/src/frame/asof_join/asof.rs +++ b/crates/polars-ops/src/frame/join/asof/default.rs @@ -245,7 +245,6 @@ pub(super) fn join_asof_nearest_with_tolerance< } // We made it to the window: matches are now possible, start measuring distance. - found_window = true; let current_dist = if val_l > val_r { val_l - val_r } else { @@ -259,10 +258,15 @@ pub(super) fn join_asof_nearest_with_tolerance< break; } } else { - // We'ved moved farther away, so the last element was the match. 
- out.push(Some(offset - 1)); + // We'ved moved farther away, so the last element was the match if it's within tolerance + if found_window { + out.push(Some(offset - 1)); + } else { + out.push(None); + } break; } + found_window = true; offset += 1; } } diff --git a/crates/polars-core/src/frame/asof_join/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs similarity index 96% rename from crates/polars-core/src/frame/asof_join/groups.rs rename to crates/polars-ops/src/frame/join/asof/groups.rs index 9c980c935b3b..a8e31f69af50 100644 --- a/crates/polars-core/src/frame/asof_join/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -5,19 +5,15 @@ use std::ops::{Add, Sub}; use ahash::RandomState; use arrow::types::NativeType; use num_traits::{Bounded, Zero}; +use polars_core::hashing::partition::AsU64; +use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; +use polars_core::utils::{split_ca, split_df}; +use polars_core::POOL; use rayon::prelude::*; use smartstring::alias::String as SmartString; use super::*; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; -#[cfg(feature = "dtype-categorical")] -use crate::frame::hash_join::_check_categorical_src; -use crate::frame::hash_join::{ - build_tables, get_hash_tbl_threaded_join_partitioned, multiple_keys as mk, prepare_bytes, -}; -use crate::hashing::{df_rows_to_hashes_threaded_vertical, AsU64}; -use crate::utils::{split_ca, split_df}; -use crate::POOL; +use crate::frame::IntoDf; pub(super) unsafe fn join_asof_backward_with_indirection_and_tolerance< T: PartialOrd + Copy + Sub + Debug, @@ -131,7 +127,6 @@ pub(super) unsafe fn join_asof_nearest_with_indirection_and_tolerance< } // We made it to the window: matches are now possible, start measuring distance. - found_window = true; let current_dist = if val_l > val_r { val_l - val_r } else { @@ -145,9 +140,14 @@ pub(super) unsafe fn join_asof_nearest_with_indirection_and_tolerance< } prev_offset = offset; } else { - // We'ved moved farther away, so the last element was the match. - return (Some(prev_offset), idx - 1); + // We'ved moved farther away, so the last element was the match if it's within tolerance + if found_window { + return (Some(prev_offset), idx - 1); + } else { + return (None, n_right - 1); + } } + found_window = true; } // This should be unreachable. @@ -398,7 +398,7 @@ where // assume the result tuples equal length of the no. of hashes processed by this thread. let mut results = Vec::with_capacity(vals_left.len()); - let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut right_tbl_offsets = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); vals_left.iter().enumerate().for_each(|(idx_a, k)| { let idx_a = (idx_a + offset) as IdxSize; @@ -526,7 +526,7 @@ where // assume the result tuples equal length of the no. of hashes processed by this thread. 
let mut results = Vec::with_capacity(vals_left.len()); - let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut right_tbl_offsets = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); vals_left.iter().enumerate().for_each(|(idx_a, k)| { let idx_a = (idx_a + offset) as IdxSize; @@ -621,9 +621,9 @@ where let dfs_a = split_df(a, n_threads).unwrap(); let dfs_b = split_df(b, n_threads).unwrap(); - let (build_hashes, random_state) = df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); + let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); let (probe_hashes, _) = - df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); + _df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); let hash_tbls = mk::create_probe_table(&build_hashes, b); // early drop to reduce memory pressure @@ -644,7 +644,7 @@ where // assume the result tuples equal length of the no. of hashes processed by this thread. let mut results = Vec::with_capacity(probe_hashes.len()); - let mut right_tbl_offsets = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut right_tbl_offsets = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); let local_offset = offset; @@ -759,10 +759,10 @@ fn dispatch_join( Ok(out) } -impl DataFrame { +pub trait AsofJoinBy: IntoDf { #[allow(clippy::too_many_arguments)] #[doc(hidden)] - pub fn _join_asof_by( + fn _join_asof_by( &self, other: &DataFrame, left_on: &str, @@ -774,7 +774,8 @@ impl DataFrame { suffix: Option<&str>, slice: Option<(i64, usize)>, ) -> PolarsResult { - let left_asof = self.column(left_on)?.to_physical_repr(); + let self_df = self.to_df(); + let left_asof = self_df.column(left_on)?.to_physical_repr(); let right_asof = other.column(right_on)?.to_physical_repr(); let right_asof_name = right_asof.name(); let left_asof_name = left_asof.name(); @@ -785,7 +786,7 @@ impl DataFrame { left_by.is_empty() && right_by.is_empty(), )?; - let mut left_by = self.select(left_by)?; + let mut left_by = self_df.select(left_by)?; let mut right_by = other.select(right_by)?; unsafe { @@ -838,7 +839,7 @@ impl DataFrame { .collect(); let other = DataFrame::new_no_checks(cols); - let mut left = self.clone(); + let mut left = self_df.clone(); let mut right_join_tuples = &*right_join_tuples; if let Some((offset, len)) = slice { @@ -846,15 +847,9 @@ impl DataFrame { right_join_tuples = slice_slice(right_join_tuples, offset, len); } - // Safety: - // join tuples are in bounds - let right_df = unsafe { - other.take_opt_iter_unchecked( - right_join_tuples - .iter() - .map(|opt_idx| opt_idx.map(|idx| idx as usize)), - ) - }; + // SAFETY: join tuples are in bounds. + let right_df = + unsafe { other.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) }; _finish_join(left, right_df, suffix) } @@ -863,7 +858,7 @@ impl DataFrame { /// The keys must be sorted to perform an asof join. This is a special implementation of an asof join /// that searches for the nearest keys within a subgroup set by `by`. 
#[allow(clippy::too_many_arguments)] - pub fn join_asof_by( + fn join_asof_by( &self, other: &DataFrame, left_on: &str, @@ -877,14 +872,17 @@ impl DataFrame { I: IntoIterator, S: AsRef, { + let self_df = self.to_df(); let left_by = left_by.into_iter().map(|s| s.as_ref().into()).collect(); let right_by = right_by.into_iter().map(|s| s.as_ref().into()).collect(); - self._join_asof_by( + self_df._join_asof_by( other, left_on, right_on, left_by, right_by, strategy, tolerance, None, None, ) } } +impl AsofJoinBy for DataFrame {} + #[cfg(test)] mod test { use super::*; diff --git a/crates/polars-core/src/frame/asof_join/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs similarity index 52% rename from crates/polars-core/src/frame/asof_join/mod.rs rename to crates/polars-ops/src/frame/join/asof/mod.rs index c496c670d696..eb1765930035 100644 --- a/crates/polars-core/src/frame/asof_join/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -1,15 +1,24 @@ -mod asof; +mod default; mod groups; use std::borrow::Cow; -use asof::*; +use default::*; +pub(super) use groups::AsofJoinBy; use num_traits::Bounded; +use polars_core::prelude::*; +use polars_core::utils::{ensure_sorted_arg, slice_slice}; +use polars_core::with_match_physical_numeric_polars_type; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; -use crate::prelude::*; -use crate::utils::{ensure_sorted_arg, slice_slice}; +#[cfg(feature = "dtype-categorical")] +use super::_check_categorical_src; +use super::{ + _finish_join, build_tables, get_hash_tbl_threaded_join_partitioned, multiple_keys as mk, + prepare_bytes, +}; +use crate::frame::IntoDf; #[derive(Clone, Debug, PartialEq, Eq, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -63,66 +72,64 @@ pub enum AsofStrategy { Nearest, } -impl ChunkedArray +pub(crate) fn join_asof( + input_ca: &ChunkedArray, + other: &Series, + strategy: AsofStrategy, + tolerance: Option>, +) -> PolarsResult>> where T: PolarsNumericType, T::Native: Bounded + PartialOrd, { - pub(crate) fn join_asof( - &self, - other: &Series, - strategy: AsofStrategy, - tolerance: Option>, - ) -> PolarsResult>> { - let other = self.unpack_series_matching_type(other)?; - - // cont_slice requires a single chunk - let ca = self.rechunk(); - let other = other.rechunk(); - - let out = match strategy { - AsofStrategy::Forward => match tolerance { - None => join_asof_forward(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), - Some(tolerance) => { - let tolerance = tolerance.extract::().unwrap(); - join_asof_forward_with_tolerance( - ca.cont_slice().unwrap(), - other.cont_slice().unwrap(), - tolerance, - ) - }, + let other = input_ca.unpack_series_matching_type(other)?; + + // cont_slice requires a single chunk + let ca = input_ca.rechunk(); + let other = other.rechunk(); + + let out = match strategy { + AsofStrategy::Forward => match tolerance { + None => join_asof_forward(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), + Some(tolerance) => { + let tolerance = tolerance.extract::().unwrap(); + join_asof_forward_with_tolerance( + ca.cont_slice().unwrap(), + other.cont_slice().unwrap(), + tolerance, + ) }, - AsofStrategy::Backward => match tolerance { - None => join_asof_backward(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), - Some(tolerance) => { - let tolerance = tolerance.extract::().unwrap(); - join_asof_backward_with_tolerance( - self.cont_slice().unwrap(), - other.cont_slice().unwrap(), - tolerance, - ) - }, + }, + 
AsofStrategy::Backward => match tolerance { + None => join_asof_backward(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), + Some(tolerance) => { + let tolerance = tolerance.extract::().unwrap(); + join_asof_backward_with_tolerance( + input_ca.cont_slice().unwrap(), + other.cont_slice().unwrap(), + tolerance, + ) }, - AsofStrategy::Nearest => match tolerance { - None => join_asof_nearest(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), - Some(tolerance) => { - let tolerance = tolerance.extract::().unwrap(); - join_asof_nearest_with_tolerance( - self.cont_slice().unwrap(), - other.cont_slice().unwrap(), - tolerance, - ) - }, + }, + AsofStrategy::Nearest => match tolerance { + None => join_asof_nearest(ca.cont_slice().unwrap(), other.cont_slice().unwrap()), + Some(tolerance) => { + let tolerance = tolerance.extract::().unwrap(); + join_asof_nearest_with_tolerance( + input_ca.cont_slice().unwrap(), + other.cont_slice().unwrap(), + tolerance, + ) }, - }; - Ok(out) - } + }, + }; + Ok(out) } -impl DataFrame { +pub trait AsofJoin: IntoDf { #[doc(hidden)] #[allow(clippy::too_many_arguments)] - pub fn _join_asof( + fn _join_asof( &self, other: &DataFrame, left_on: &str, @@ -132,7 +139,8 @@ impl DataFrame { suffix: Option, slice: Option<(i64, usize)>, ) -> PolarsResult { - let left_key = self.column(left_on)?; + let self_df = self.to_df(); + let left_key = self_df.column(left_on)?; let right_key = other.column(right_on)?; check_asof_columns(left_key, right_key, true)?; @@ -140,37 +148,35 @@ impl DataFrame { let right_key = right_key.to_physical_repr(); let take_idx = match left_key.dtype() { - DataType::Int64 => left_key - .i64() - .unwrap() - .join_asof(&right_key, strategy, tolerance), - DataType::Int32 => left_key - .i32() - .unwrap() - .join_asof(&right_key, strategy, tolerance), - DataType::UInt64 => left_key - .u64() - .unwrap() - .join_asof(&right_key, strategy, tolerance), - DataType::UInt32 => left_key - .u32() - .unwrap() - .join_asof(&right_key, strategy, tolerance), - DataType::Float32 => left_key - .f32() - .unwrap() - .join_asof(&right_key, strategy, tolerance), - DataType::Float64 => left_key - .f64() - .unwrap() - .join_asof(&right_key, strategy, tolerance), + DataType::Int64 => { + let ca = left_key.i64().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, + DataType::Int32 => { + let ca = left_key.i32().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, + DataType::UInt64 => { + let ca = left_key.u64().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, + DataType::UInt32 => { + let ca = left_key.u32().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, + DataType::Float32 => { + let ca = left_key.f32().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, + DataType::Float64 => { + let ca = left_key.f64().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) + }, _ => { let left_key = left_key.cast(&DataType::Int32).unwrap(); let right_key = right_key.cast(&DataType::Int32).unwrap(); - left_key - .i32() - .unwrap() - .join_asof(&right_key, strategy, tolerance) + let ca = left_key.i32().unwrap(); + join_asof(ca, &right_key, strategy, tolerance) }, }?; @@ -186,7 +192,7 @@ impl DataFrame { Cow::Borrowed(other) }; - let mut left = self.clone(); + let mut left = self_df.clone(); let mut take_idx = &*take_idx; if let Some((offset, len)) = slice { @@ -194,22 +200,15 @@ impl DataFrame { take_idx = slice_slice(take_idx, offset, len); } - // Safety: - // join tuples are in bounds - let right_df = unsafe { - 
other.take_opt_iter_unchecked( - take_idx - .iter() - .map(|opt_idx| opt_idx.map(|idx| idx as usize)), - ) - }; + // SAFETY: join tuples are in bounds. + let right_df = unsafe { other.take_unchecked(&take_idx.iter().copied().collect_ca("")) }; _finish_join(left, right_df, suffix.as_deref()) } /// This is similar to a left-join except that we match on nearest key rather than equal keys. /// The keys must be sorted to perform an asof join - pub fn join_asof( + fn join_asof( &self, other: &DataFrame, left_on: &str, @@ -221,3 +220,5 @@ impl DataFrame { self._join_asof(other, left_on, right_on, strategy, tolerance, suffix, None) } } + +impl AsofJoin for DataFrame {} diff --git a/crates/polars-ops/src/frame/join/checks.rs b/crates/polars-ops/src/frame/join/checks.rs new file mode 100644 index 000000000000..b87baa2f5e4c --- /dev/null +++ b/crates/polars-ops/src/frame/join/checks.rs @@ -0,0 +1,10 @@ +use super::*; + +/// If Categorical types are created without a global string cache or under +/// a different global string cache the mapping will be incorrect. +pub(crate) fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> { + if let (DataType::Categorical(Some(l)), DataType::Categorical(Some(r))) = (l, r) { + polars_ensure!(l.same_src(r), string_cache_mismatch); + } + Ok(()) +} diff --git a/crates/polars-core/src/frame/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs similarity index 78% rename from crates/polars-core/src/frame/cross_join.rs rename to crates/polars-ops/src/frame/join/cross_join.rs index 46d73cb85a74..1bc596960b2c 100644 --- a/crates/polars-core/src/frame/cross_join.rs +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -1,9 +1,9 @@ +use polars_core::series::IsSorted; +use polars_core::utils::{concat_df_unchecked, slice_offsets, CustomIterTools, NoNull}; +use polars_core::POOL; use smartstring::alias::String as SmartString; -use crate::prelude::*; -use crate::series::IsSorted; -use crate::utils::{concat_df_unchecked, slice_offsets, CustomIterTools, NoNull}; -use crate::POOL; +use super::*; fn slice_take( total_rows: IdxSize, @@ -41,14 +41,15 @@ fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, us slice_take(total_rows, n_rows_right, slice, inner) } -impl DataFrame { +pub trait CrossJoin: IntoDf { fn cross_join_dfs( &self, other: &DataFrame, slice: Option<(i64, usize)>, parallel: bool, ) -> PolarsResult<(DataFrame, DataFrame)> { - let n_rows_left = self.height() as IdxSize; + let df_self = self.to_df(); + let n_rows_left = df_self.height() as IdxSize; let n_rows_right = other.height() as IdxSize; let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else { polars_bail!( @@ -57,7 +58,7 @@ impl DataFrame { ); }; if n_rows_left == 0 || n_rows_right == 0 { - return Ok((self.clear(), other.clear())); + return Ok((df_self.clear(), other.clear())); } // the left side has the Nth row combined with every row from right. 
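The comment above describes the cartesian layout the cross join materialises. As a standard-library-only sketch (not part of this patch; `cross_join_indices` is a hypothetical name): for a left frame of height `n_left` and a right frame of height `n_right`, output row `i` pairs left row `i / n_right` with right row `i % n_right`, so the order of the left keys is preserved.

// Sketch of the (left, right) row index pairs a cross join produces.
fn cross_join_indices(n_left: usize, n_right: usize) -> Vec<(usize, usize)> {
    (0..n_left * n_right)
        .map(|i| (i / n_right, i % n_right))
        .collect()
}

fn main() {
    // 2 left rows x 3 right rows -> 6 pairs, grouped by the left index.
    assert_eq!(
        cross_join_indices(2, 3),
        vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
    );
}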
@@ -71,7 +72,7 @@ impl DataFrame { let create_left_df = || { // Safety: // take left is in bounds - unsafe { self.take_unchecked(&take_left(total_rows, n_rows_right, slice)) } + unsafe { df_self.take_unchecked(&take_left(total_rows, n_rows_right, slice)) } }; let create_right_df = || { @@ -97,7 +98,7 @@ impl DataFrame { #[doc(hidden)] /// used by streaming - pub fn _cross_join_with_names( + fn _cross_join_with_names( &self, other: &DataFrame, names: &[SmartString], @@ -105,7 +106,7 @@ impl DataFrame { let (mut l_df, r_df) = self.cross_join_dfs(other, None, false)?; unsafe { - l_df.get_columns_mut().extend_from_slice(&r_df.columns); + l_df.get_columns_mut().extend_from_slice(r_df.get_columns()); l_df.get_columns_mut() .iter_mut() @@ -120,7 +121,7 @@ impl DataFrame { } /// Creates the cartesian product from both frames, preserves the order of the left keys. - pub fn cross_join( + fn cross_join( &self, other: &DataFrame, suffix: Option<&str>, @@ -132,31 +133,4 @@ impl DataFrame { } } -#[cfg(test)] -mod test { - use super::*; - use crate::df; - - #[test] - fn test_cross_join() -> PolarsResult<()> { - let df_a = df![ - "a" => [1, 2], - "b" => ["foo", "spam"] - ]?; - - let df_b = df![ - "b" => ["a", "b", "c"] - ]?; - - let out = df_a.cross_join(&df_b, None, None)?; - let expected = df![ - "a" => [1, 1, 1, 2, 2, 2], - "b" => ["foo", "foo", "foo", "spam", "spam", "spam"], - "b_right" => ["a", "b", "c", "a", "b", "c"] - ]?; - - assert!(out.frame_equal(&expected)); - - Ok(()) - } -} +impl CrossJoin for DataFrame {} diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs new file mode 100644 index 000000000000..56fd3b11f0f3 --- /dev/null +++ b/crates/polars-ops/src/frame/join/general.rs @@ -0,0 +1,47 @@ +use super::*; + +pub fn _join_suffix_name(name: &str, suffix: &str) -> String { + format!("{name}{suffix}") +} + +/// Utility method to finish a join. 
+#[doc(hidden)] +pub fn _finish_join( + mut df_left: DataFrame, + mut df_right: DataFrame, + suffix: Option<&str>, +) -> PolarsResult { + let mut left_names = PlHashSet::with_capacity(df_left.width()); + + df_left.get_columns().iter().for_each(|series| { + left_names.insert(series.name()); + }); + + let mut rename_strs = Vec::with_capacity(df_right.width()); + + df_right.get_columns().iter().for_each(|series| { + if left_names.contains(series.name()) { + rename_strs.push(series.name().to_owned()) + } + }); + let suffix = suffix.unwrap_or("_right"); + + for name in rename_strs { + df_right.rename(&name, &_join_suffix_name(&name, suffix))?; + } + + drop(left_names); + df_left.hstack_mut(df_right.get_columns())?; + Ok(df_left) +} + +#[cfg(feature = "chunked_ids")] +pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> Vec { + let mut vals = Vec::with_capacity(len); + + for (chunk_i, chunk) in chunks.iter().enumerate() { + vals.extend((0..chunk.len()).map(|array_i| [chunk_i as IdxSize, array_i as IdxSize])) + } + + vals +} diff --git a/crates/polars-core/src/frame/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs similarity index 63% rename from crates/polars-core/src/frame/hash_join/mod.rs rename to crates/polars-ops/src/frame/join/hash_join/mod.rs index f2f7500447da..c86f3d517def 100644 --- a/crates/polars-core/src/frame/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -1,5 +1,4 @@ -mod args; -pub(crate) mod multiple_keys; +pub(super) mod multiple_keys; pub(super) mod single_keys; mod single_keys_dispatch; mod single_keys_inner; @@ -10,47 +9,24 @@ mod single_keys_semi_anti; pub(super) mod sort_merge; mod zip_outer; -use std::fmt::{Debug, Display, Formatter}; -use std::hash::{BuildHasher, Hash, Hasher}; - -use ahash::RandomState; pub use args::*; -#[cfg(feature = "chunked_ids")] -use arrow::Either; -use hashbrown::hash_map::{Entry, RawEntryMut}; -use hashbrown::HashMap; -use polars_arrow::utils::CustomIterTools; -use rayon::prelude::*; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; -#[cfg(feature = "asof_join")] -pub(crate) use single_keys::build_tables; +pub use multiple_keys::private_left_join_multiple_keys; +pub(super) use multiple_keys::*; +use polars_core::utils::{_set_partition_size, slice_slice, split_ca}; +use polars_core::POOL; +pub(super) use single_keys::*; #[cfg(feature = "asof_join")] -pub(crate) use single_keys_dispatch::prepare_bytes; +pub(super) use single_keys_dispatch::prepare_bytes; +pub use single_keys_dispatch::SeriesJoin; +use single_keys_inner::*; use single_keys_left::*; use single_keys_outer::*; #[cfg(feature = "semi_anti_join")] use single_keys_semi_anti::*; pub use sort_merge::*; -pub(crate) use zip_outer::*; +pub(super) use zip_outer::zip_outer_join_column; -pub use self::multiple_keys::private_left_join_multiple_keys; -use crate::datatypes::PlHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; -pub use crate::frame::hash_join::multiple_keys::{ - _inner_join_multiple_keys, _left_join_multiple_keys, _outer_join_multiple_keys, -}; -#[cfg(feature = "semi_anti_join")] -pub use crate::frame::hash_join::multiple_keys::{ - _left_anti_multiple_keys, _left_semi_multiple_keys, -}; -use crate::hashing::{ - create_hash_and_keys_threaded_vectorized, prepare_hashed_relation_threaded, this_partition, - AsU64, BytesHash, -}; -use crate::prelude::*; -use crate::utils::{_set_partition_size, slice_slice, split_ca}; -use crate::POOL; +pub use super::*; pub fn default_join_ids() -> 
ChunkJoinOptIds { #[cfg(feature = "chunked_ids")] @@ -86,19 +62,7 @@ pub(super) use det_hash_prone_order; use polars_arrow::conversion::primitive_to_vec; use polars_utils::hash_to_partition; -use crate::series::IsSorted; - -/// If Categorical types are created without a global string cache or under -/// a different global string cache the mapping will be incorrect. -#[cfg(feature = "dtype-categorical")] -pub fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> { - if let (DataType::Categorical(Some(l)), DataType::Categorical(Some(r))) = (l, r) { - polars_ensure!(l.same_src(r), string_cache_mismatch); - } - Ok(()) -} - -pub(crate) unsafe fn get_hash_tbl_threaded_join_partitioned( +pub(super) unsafe fn get_hash_tbl_threaded_join_partitioned( h: u64, hash_tables: &[Item], len: u64, @@ -117,48 +81,14 @@ unsafe fn get_hash_tbl_threaded_join_mut_partitioned( hash_tables.get_unchecked_mut(i) } -pub fn _join_suffix_name(name: &str, suffix: &str) -> String { - format!("{name}{suffix}") -} - -/// Utility method to finish a join. -#[doc(hidden)] -pub fn _finish_join( - mut df_left: DataFrame, - mut df_right: DataFrame, - suffix: Option<&str>, -) -> PolarsResult { - let mut left_names = PlHashSet::with_capacity(df_left.width()); - - df_left.columns.iter().for_each(|series| { - left_names.insert(series.name()); - }); - - let mut rename_strs = Vec::with_capacity(df_right.width()); - - df_right.columns.iter().for_each(|series| { - if left_names.contains(series.name()) { - rename_strs.push(series.name().to_owned()) - } - }); - let suffix = suffix.unwrap_or("_right"); - - for name in rename_strs { - df_right.rename(&name, &_join_suffix_name(&name, suffix))?; - } - - drop(left_names); - df_left.hstack_mut(&df_right.columns)?; - Ok(df_left) -} - -impl DataFrame { +pub trait JoinDispatch: IntoDf { /// # Safety /// Join tuples must be in bounds #[cfg(feature = "chunked_ids")] unsafe fn create_left_df_chunked(&self, chunk_ids: &[ChunkId], left_join: bool) -> DataFrame { - if left_join && chunk_ids.len() == self.height() { - self.clone() + let df_self = self.to_df(); + if left_join && chunk_ids.len() == df_self.height() { + df_self.clone() } else { // left join keys are in ascending order let sorted = if left_join { @@ -166,20 +96,21 @@ impl DataFrame { } else { IsSorted::Not }; - self.take_chunked_unchecked(chunk_ids, sorted) + df_self._take_chunked_unchecked(chunk_ids, sorted) } } /// # Safety /// Join tuples must be in bounds - pub unsafe fn _create_left_df_from_slice( + unsafe fn _create_left_df_from_slice( &self, join_tuples: &[IdxSize], left_join: bool, sorted_tuple_idx: bool, ) -> DataFrame { - if left_join && join_tuples.len() == self.height() { - self.clone() + let df_self = self.to_df(); + if left_join && join_tuples.len() == df_self.height() { + df_self.clone() } else { // left join tuples are always in ascending order let sorted = if left_join || sorted_tuple_idx { @@ -188,24 +119,25 @@ impl DataFrame { IsSorted::Not }; - self._take_unchecked_slice_sorted(join_tuples, true, sorted) + df_self._take_unchecked_slice_sorted(join_tuples, true, sorted) } } #[cfg(not(feature = "chunked_ids"))] - pub fn _finish_left_join( + fn _finish_left_join( &self, ids: LeftJoinIds, other: &DataFrame, args: JoinArgs, ) -> PolarsResult { + let ca_self = self.to_df(); let (left_idx, right_idx) = ids; let materialize_left = || { let mut left_idx = &*left_idx; if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - unsafe { self._create_left_df_from_slice(left_idx, 
true, true) } + unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) } }; let materialize_right = || { @@ -213,11 +145,7 @@ impl DataFrame { if let Some((offset, len)) = args.slice { right_idx = slice_slice(right_idx, offset, len); } - unsafe { - other.take_opt_iter_unchecked( - right_idx.iter().map(|opt_i| opt_i.map(|i| i as usize)), - ) - } + unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -225,12 +153,13 @@ impl DataFrame { } #[cfg(feature = "chunked_ids")] - pub fn _finish_left_join( + fn _finish_left_join( &self, ids: LeftJoinIds, other: &DataFrame, args: JoinArgs, ) -> PolarsResult { + let ca_self = self.to_df(); let suffix = &args.suffix; let slice = args.slice; let (left_idx, right_idx) = ids; @@ -240,14 +169,14 @@ impl DataFrame { if let Some((offset, len)) = slice { left_idx = slice_slice(left_idx, offset, len); } - unsafe { self._create_left_df_from_slice(left_idx, true, true) } + unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) } }, ChunkJoinIds::Right(left_idx) => { let mut left_idx = &*left_idx; if let Some((offset, len)) = slice { left_idx = slice_slice(left_idx, offset, len); } - unsafe { self.create_left_df_chunked(left_idx, true) } + unsafe { ca_self.create_left_df_chunked(left_idx, true) } }, }; @@ -257,18 +186,14 @@ impl DataFrame { if let Some((offset, len)) = slice { right_idx = slice_slice(right_idx, offset, len); } - unsafe { - other.take_opt_iter_unchecked( - right_idx.iter().map(|opt_i| opt_i.map(|i| i as usize)), - ) - } + unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }, ChunkJoinOptIds::Right(right_idx) => { let mut right_idx = &*right_idx; if let Some((offset, len)) = slice { right_idx = slice_slice(right_idx, offset, len); } - unsafe { other.take_opt_chunked_unchecked(right_idx) } + unsafe { other._take_opt_chunked_unchecked(right_idx) } }, }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -276,7 +201,7 @@ impl DataFrame { _finish_join(df_left, df_right, suffix.as_deref()) } - pub fn _left_join_from_series( + fn _left_join_from_series( &self, other: &DataFrame, s_left: &Series, @@ -284,11 +209,12 @@ impl DataFrame { args: JoinArgs, verbose: bool, ) -> PolarsResult { + let ca_self = self.to_df(); #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; // ensure that the chunks are aligned otherwise we go OOB - let mut left = self.clone(); + let mut left = ca_self.clone(); let mut s_left = s_left.clone(); let mut right = other.clone(); let mut s_right = s_right.clone(); @@ -307,46 +233,52 @@ impl DataFrame { #[cfg(feature = "semi_anti_join")] /// # Safety /// `idx` must be in bounds - pub unsafe fn _finish_anti_semi_join( + unsafe fn _finish_anti_semi_join( &self, mut idx: &[IdxSize], slice: Option<(i64, usize)>, ) -> DataFrame { + let ca_self = self.to_df(); if let Some((offset, len)) = slice { idx = slice_slice(idx, offset, len); } // idx from anti-semi join should always be sorted - self._take_unchecked_slice_sorted(idx, true, IsSorted::Ascending) + ca_self._take_unchecked_slice_sorted(idx, true, IsSorted::Ascending) } #[cfg(feature = "semi_anti_join")] - pub fn _semi_anti_join_from_series( + fn _semi_anti_join_from_series( &self, s_left: &Series, s_right: &Series, slice: Option<(i64, usize)>, anti: bool, ) -> PolarsResult { + let ca_self = self.to_df(); #[cfg(feature = "dtype-categorical")] 
_check_categorical_src(s_left.dtype(), s_right.dtype())?; let idx = s_left.hash_join_semi_anti(s_right, anti); // Safety: // indices are in bounds - Ok(unsafe { self._finish_anti_semi_join(&idx, slice) }) + Ok(unsafe { ca_self._finish_anti_semi_join(&idx, slice) }) } - pub fn _outer_join_from_series( + fn _outer_join_from_series( &self, other: &DataFrame, s_left: &Series, s_right: &Series, args: JoinArgs, ) -> PolarsResult { + let ca_self = self.to_df(); #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; // store this so that we can keep original column order. - let join_column_index = self.iter().position(|s| s.name() == s_left.name()).unwrap(); + let join_column_index = ca_self + .iter() + .position(|s| s.name() == s_left.name()) + .unwrap(); // Get the indexes of the joined relations let opt_join_tuples = s_left.hash_join_outer(s_right, args.validation)?; @@ -359,30 +291,38 @@ impl DataFrame { // Take the left and right dataframes by join tuples let (mut df_left, df_right) = POOL.join( || unsafe { - self.drop(s_left.name()).unwrap().take_opt_iter_unchecked( - opt_join_tuples + ca_self.drop(s_left.name()).unwrap().take_unchecked( + &opt_join_tuples .iter() - .map(|(left, _right)| left.map(|i| i as usize)), + .copied() + .map(|(left, _right)| left) + .collect_ca("outer-join-left-indices"), ) }, || unsafe { - other.drop(s_right.name()).unwrap().take_opt_iter_unchecked( - opt_join_tuples + other.drop(s_right.name()).unwrap().take_unchecked( + &opt_join_tuples .iter() - .map(|(_left, right)| right.map(|i| i as usize)), + .copied() + .map(|(_left, right)| right) + .collect_ca("outer-join-right-indices"), ) }, ); - let mut s = s_left - .to_physical_repr() - .zip_outer_join_column(&s_right.to_physical_repr(), opt_join_tuples); - s.rename(s_left.name()); + let s = unsafe { + zip_outer_join_column( + &s_left.to_physical_repr(), + &s_right.to_physical_repr(), + opt_join_tuples, + ) + .with_name(s_left.name()) + }; let s = match s_left.dtype() { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => { let ca_left = s_left.categorical().unwrap(); - let new_rev_map = ca_left.merge_categorical_map(s_right.categorical().unwrap())?; + let new_rev_map = ca_left._merge_categorical_map(s_right.categorical().unwrap())?; let logical = s.u32().unwrap().clone(); // safety: // categorical maps are merged @@ -402,3 +342,5 @@ impl DataFrame { _finish_join(df_left, df_right, args.suffix.as_deref()) } } + +impl JoinDispatch for DataFrame {} diff --git a/crates/polars-core/src/frame/hash_join/multiple_keys.rs b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs similarity index 93% rename from crates/polars-core/src/frame/hash_join/multiple_keys.rs rename to crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs index 157584328978..4cf38665ae54 100644 --- a/crates/polars-core/src/frame/hash_join/multiple_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs @@ -1,19 +1,14 @@ use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; -use rayon::prelude::*; +use polars_core::hashing::{ + populate_multiple_key_hashmap, IdBuildHasher, IdxHash, _HASHMAP_INIT_SIZE, +}; +use polars_core::utils::{_set_partition_size, split_df}; +use polars_core::POOL; use super::*; -use crate::frame::group_by::hashing::{populate_multiple_key_hashmap, HASHMAP_INIT_SIZE}; -use crate::frame::hash_join::{ - get_hash_tbl_threaded_join_mut_partitioned, get_hash_tbl_threaded_join_partitioned, -}; -use crate::hashing::{df_rows_to_hashes_threaded_vertical, 
this_partition, IdBuildHasher, IdxHash}; -use crate::prelude::*; -use crate::utils::series::_to_physical_and_bit_repr; -use crate::utils::{_set_partition_size, split_df}; -use crate::POOL; -/// Compare the rows of two DataFrames +/// Compare the rows of two [`DataFrame`]s pub(crate) unsafe fn compare_df_rows2( left: &DataFrame, right: &DataFrame, @@ -43,7 +38,7 @@ pub(crate) fn create_probe_table( .map(|part_no| { let part_no = part_no as u64; let mut hash_tbl: HashMap, IdBuildHasher> = - HashMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); let n_partitions = n_partitions as u64; let mut offset = 0; @@ -92,7 +87,7 @@ fn create_build_table_outer( (0..n_partitions).into_par_iter().map(|part_no| { let part_no = part_no as u64; let mut hash_tbl: HashMap), IdBuildHasher> = - HashMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); let n_partitions = n_partitions as u64; let mut offset = 0; @@ -187,9 +182,9 @@ pub fn _inner_join_multiple_keys( let dfs_a = split_df(a, n_threads).unwrap(); let dfs_b = split_df(b, n_threads).unwrap(); - let (build_hashes, random_state) = df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); + let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); let (probe_hashes, _) = - df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); + _df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); let hash_tbls = create_probe_table(&build_hashes, b); // early drop to reduce memory pressure @@ -269,9 +264,9 @@ pub fn _left_join_multiple_keys( let dfs_a = split_df(a, n_threads).unwrap(); let dfs_b = split_df(b, n_threads).unwrap(); - let (build_hashes, random_state) = df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); + let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); let (probe_hashes, _) = - df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); + _df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); let hash_tbls = create_probe_table(&build_hashes, b); // early drop to reduce memory pressure @@ -353,7 +348,7 @@ pub(crate) fn create_build_table_semi_anti( (0..n_partitions).into_par_iter().map(|part_no| { let part_no = part_no as u64; let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); let n_partitions = n_partitions as u64; let mut offset = 0; @@ -400,9 +395,9 @@ pub(crate) fn semi_anti_join_multiple_keys_impl<'a>( let dfs_a = split_df(a, n_threads).unwrap(); let dfs_b = split_df(b, n_threads).unwrap(); - let (build_hashes, random_state) = df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); + let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); let (probe_hashes, _) = - df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); + _df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); let hash_tbls = create_build_table_semi_anti(&build_hashes, b); // early drop to reduce memory pressure @@ -555,9 +550,9 @@ pub fn _outer_join_multiple_keys( let dfs_a = split_df(a, n_threads).unwrap(); let dfs_b = split_df(b, n_threads).unwrap(); - let (build_hashes, random_state) = 
df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); + let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&dfs_b, None).unwrap(); let (probe_hashes, _) = - df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); + _df_rows_to_hashes_threaded_vertical(&dfs_a, Some(random_state)).unwrap(); let mut hash_tbls = create_build_table_outer(&build_hashes, b); // early drop to reduce memory pressure diff --git a/crates/polars-core/src/frame/hash_join/single_keys.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs similarity index 72% rename from crates/polars-core/src/frame/hash_join/single_keys.rs rename to crates/polars-ops/src/frame/join/hash_join/single_keys.rs index c19833643a50..7fac84d9c758 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs @@ -1,9 +1,10 @@ use super::*; -pub(crate) fn build_tables(keys: Vec) -> Vec>> +pub(crate) fn build_tables(keys: Vec) -> Vec>> where T: Send + Hash + Eq + Sync + Copy + AsU64, - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Send + Sync + Clone, + // ::IntoIter: TrustedLen, { let n_partitions = _set_partition_size(); @@ -17,21 +18,21 @@ where let partition_no = partition_no as u64; let mut hash_tbl: PlHashMap> = - PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); let n_partitions = n_partitions as u64; let mut offset = 0; for keys in &keys { - let keys = keys.as_ref(); - let len = keys.len() as IdxSize; + let keys = keys.clone().into_iter(); + let len = keys.size_hint().1.unwrap() as IdxSize; let mut cnt = 0; - keys.iter().for_each(|k| { + keys.for_each(|k| { let idx = cnt + offset; cnt += 1; if this_partition(k.as_u64(), partition_no, n_partitions) { - let entry = hash_tbl.entry(*k); + let entry = hash_tbl.entry(k); match entry { Entry::Vacant(entry) => { @@ -53,14 +54,15 @@ where } // we determine the offset so that we later know which index to store in the join tuples -pub(super) fn probe_to_offsets(probe: &[IntoSlice]) -> Vec +pub(super) fn probe_to_offsets(probe: &[I]) -> Vec where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Clone, + // ::IntoIter: TrustedLen, T: Send + Hash + Eq + Sync + Copy + AsU64, { probe .iter() - .map(|ph| ph.as_ref().len()) + .map(|ph| ph.clone().into_iter().size_hint().1.unwrap()) .scan(0, |state, val| { let out = *state; *state += val; diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs new file mode 100644 index 000000000000..fefb207af236 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -0,0 +1,460 @@ +use num_traits::NumCast; + +use super::*; +use crate::series::SeriesSealed; + +pub trait SeriesJoin: SeriesSealed + Sized { + #[doc(hidden)] + fn hash_join_left( + &self, + other: &Series, + validate: JoinValidation, + ) -> PolarsResult { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, false)?; + + use DataType::*; + match lhs.dtype() { + Utf8 => { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + lhs.hash_join_left(&rhs, JoinValidation::ManyToMany) + }, + Binary => { + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = 
rhs.iter().map(|v| v.as_slice()).collect::>(); + hash_join_tuples_left(lhs, rhs, None, None, validate) + }, + _ => { + if s_self.bit_repr_is_large() { + let lhs = lhs.bit_repr_large(); + let rhs = rhs.bit_repr_large(); + num_group_join_left(&lhs, &rhs, validate) + } else { + let lhs = lhs.bit_repr_small(); + let rhs = rhs.bit_repr_small(); + num_group_join_left(&lhs, &rhs, validate) + } + }, + } + } + + #[cfg(feature = "semi_anti_join")] + fn hash_join_semi_anti(&self, other: &Series, anti: bool) -> Vec { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + + use DataType::*; + match lhs.dtype() { + Utf8 => { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + lhs.hash_join_semi_anti(&rhs, anti) + }, + Binary => { + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary(lhs, rhs, false); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + if anti { + hash_join_tuples_left_anti(lhs, rhs) + } else { + hash_join_tuples_left_semi(lhs, rhs) + } + }, + _ => { + if s_self.bit_repr_is_large() { + let lhs = lhs.bit_repr_large(); + let rhs = rhs.bit_repr_large(); + num_group_join_anti_semi(&lhs, &rhs, anti) + } else { + let lhs = lhs.bit_repr_small(); + let rhs = rhs.bit_repr_small(); + num_group_join_anti_semi(&lhs, &rhs, anti) + } + }, + } + } + + // returns the join tuples and whether or not the lhs tuples are sorted + fn hash_join_inner( + &self, + other: &Series, + validate: JoinValidation, + ) -> PolarsResult<(InnerJoinIds, bool)> { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true)?; + + use DataType::*; + match lhs.dtype() { + Utf8 => { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + lhs.hash_join_inner(&rhs, JoinValidation::ManyToMany) + }, + Binary => { + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + Ok(( + hash_join_tuples_inner(lhs, rhs, swapped, validate)?, + !swapped, + )) + }, + _ => { + if s_self.bit_repr_is_large() { + let lhs = s_self.bit_repr_large(); + let rhs = other.bit_repr_large(); + group_join_inner::(&lhs, &rhs, validate) + } else { + let lhs = s_self.bit_repr_small(); + let rhs = other.bit_repr_small(); + group_join_inner::(&lhs, &rhs, validate) + } + }, + } + } + + fn hash_join_outer( + &self, + other: &Series, + validate: JoinValidation, + ) -> PolarsResult, Option)>> { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true)?; + + use DataType::*; + match lhs.dtype() { + Utf8 => { + let lhs = lhs.cast(&Binary).unwrap(); + let rhs = rhs.cast(&Binary).unwrap(); + lhs.hash_join_outer(&rhs, JoinValidation::ManyToMany) + }, + Binary => { + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary(lhs, rhs, true); + let lhs = lhs.iter().collect::>(); + let rhs = rhs.iter().collect::>(); + hash_join_tuples_outer(lhs, rhs, swapped, validate) + }, + _ => { + if s_self.bit_repr_is_large() { + let lhs = s_self.bit_repr_large(); + let rhs = other.bit_repr_large(); + hash_join_outer(&lhs, &rhs, 
validate) + } else { + let lhs = s_self.bit_repr_small(); + let rhs = other.bit_repr_small(); + hash_join_outer(&lhs, &rhs, validate) + } + }, + } + } +} + +impl SeriesJoin for Series {} + +fn chunks_as_slices(splitted: &[ChunkedArray]) -> Vec<&[T::Native]> +where + T: PolarsNumericType, +{ + splitted + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect() +} + +fn get_arrays(cas: &[ChunkedArray]) -> Vec<&T::Array> { + cas.iter().flat_map(|arr| arr.downcast_iter()).collect() +} + +fn group_join_inner( + left: &ChunkedArray, + right: &ChunkedArray, + validate: JoinValidation, +) -> PolarsResult<(InnerJoinIds, bool)> +where + T: PolarsDataType, + for<'a> &'a T::Array: IntoIterator>>, + for<'a> T::Physical<'a>: Hash + Eq + Send + AsU64 + Copy + Send + Sync, +{ + let n_threads = POOL.current_num_threads(); + let (a, b, swapped) = det_hash_prone_order!(left, right); + let splitted_a = split_ca(a, n_threads).unwrap(); + let splitted_b = split_ca(b, n_threads).unwrap(); + let splitted_a = get_arrays(&splitted_a); + let splitted_b = get_arrays(&splitted_b); + + match (left.null_count(), right.null_count()) { + (0, 0) => { + let first = &splitted_a[0]; + if first.as_slice().is_some() { + let splitted_a = splitted_a + .iter() + .map(|arr| arr.as_slice().unwrap()) + .collect::>(); + let splitted_b = splitted_b + .iter() + .map(|arr| arr.as_slice().unwrap()) + .collect::>(); + Ok(( + hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate)?, + !swapped, + )) + } else { + Ok(( + hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate)?, + !swapped, + )) + } + }, + _ => Ok(( + hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate)?, + !swapped, + )), + } +} + +#[cfg(feature = "chunked_ids")] +fn create_mappings( + chunks_left: &[ArrayRef], + chunks_right: &[ArrayRef], + left_len: usize, + right_len: usize, +) -> (Option>, Option>) { + let mapping_left = || { + if chunks_left.len() > 1 { + Some(create_chunked_index_mapping(chunks_left, left_len)) + } else { + None + } + }; + + let mapping_right = || { + if chunks_right.len() > 1 { + Some(create_chunked_index_mapping(chunks_right, right_len)) + } else { + None + } + }; + + POOL.join(mapping_left, mapping_right) +} + +#[cfg(not(feature = "chunked_ids"))] +fn create_mappings( + _chunks_left: &[ArrayRef], + _chunks_right: &[ArrayRef], + _left_len: usize, + _right_len: usize, +) -> (Option>, Option>) { + (None, None) +} + +fn num_group_join_left( + left: &ChunkedArray, + right: &ChunkedArray, + validate: JoinValidation, +) -> PolarsResult +where + T: PolarsIntegerType, + T::Native: Hash + Eq + Send + AsU64, + Option: AsU64, +{ + let n_threads = POOL.current_num_threads(); + let splitted_a = split_ca(left, n_threads).unwrap(); + let splitted_b = split_ca(right, n_threads).unwrap(); + match ( + left.null_count(), + right.null_count(), + left.chunks().len(), + right.chunks().len(), + ) { + (0, 0, 1, 1) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + hash_join_tuples_left(keys_a, keys_b, None, None, validate) + }, + (0, 0, _, _) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + + let (mapping_left, mapping_right) = + create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); + hash_join_tuples_left( + keys_a, + keys_b, + mapping_left.as_deref(), + mapping_right.as_deref(), + validate, + ) + }, + _ => { + let keys_a = get_arrays(&splitted_a); + let keys_b = get_arrays(&splitted_b); + 
let (mapping_left, mapping_right) = + create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); + hash_join_tuples_left( + keys_a, + keys_b, + mapping_left.as_deref(), + mapping_right.as_deref(), + validate, + ) + }, + } +} + +fn hash_join_outer( + ca_in: &ChunkedArray, + other: &ChunkedArray, + validate: JoinValidation, +) -> PolarsResult, Option)>> +where + T: PolarsIntegerType + Sync, + T::Native: Eq + Hash + NumCast, +{ + let (a, b, swapped) = det_hash_prone_order!(ca_in, other); + + let n_partitions = _set_partition_size(); + let splitted_a = split_ca(a, n_partitions).unwrap(); + let splitted_b = split_ca(b, n_partitions).unwrap(); + + match (a.null_count(), b.null_count()) { + (0, 0) => { + let iters_a = splitted_a + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect::>(); + let iters_b = splitted_b + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect::>(); + hash_join_tuples_outer(iters_a, iters_b, swapped, validate) + }, + _ => { + let iters_a = splitted_a + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.iter())) + .collect::>(); + let iters_b = splitted_b + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.iter())) + .collect::>(); + hash_join_tuples_outer(iters_a, iters_b, swapped, validate) + }, + } +} + +pub fn prepare_bytes<'a>( + been_split: &'a [BinaryChunked], + hb: &RandomState, +) -> Vec>> { + POOL.install(|| { + been_split + .par_iter() + .map(|ca| { + ca.into_iter() + .map(|opt_b| { + let hash = hb.hash_one(opt_b); + BytesHash::new(opt_b, hash) + }) + .collect::>() + }) + .collect() + }) +} + +fn prepare_binary<'a>( + ca: &'a BinaryChunked, + other: &'a BinaryChunked, + // In inner join and outer join, the shortest relation will be used to create a hash table. + // In left join, always use the right side to create. + build_shortest_table: bool, +) -> ( + Vec>>, + Vec>>, + bool, + RandomState, +) { + let n_threads = POOL.current_num_threads(); + + let (a, b, swapped) = if build_shortest_table { + det_hash_prone_order!(ca, other) + } else { + (ca, other, false) + }; + + let hb = RandomState::default(); + let splitted_a = split_ca(a, n_threads).unwrap(); + let splitted_b = split_ca(b, n_threads).unwrap(); + let str_hashes_a = prepare_bytes(&splitted_a, &hb); + let str_hashes_b = prepare_bytes(&splitted_b, &hb); + + // SAFETY: + // Splitting a Ca keeps the same buffers, so the lifetime is still valid. 
+ let str_hashes_a = unsafe { std::mem::transmute::<_, Vec>>>(str_hashes_a) }; + let str_hashes_b = unsafe { std::mem::transmute::<_, Vec>>>(str_hashes_b) }; + + (str_hashes_a, str_hashes_b, swapped, hb) +} + +#[cfg(feature = "semi_anti_join")] +fn num_group_join_anti_semi( + left: &ChunkedArray, + right: &ChunkedArray, + anti: bool, +) -> Vec +where + T: PolarsIntegerType, + T::Native: Hash + Eq + Send + AsU64, + Option: AsU64, +{ + let n_threads = POOL.current_num_threads(); + let splitted_a = split_ca(left, n_threads).unwrap(); + let splitted_b = split_ca(right, n_threads).unwrap(); + match ( + left.null_count(), + right.null_count(), + left.chunks().len(), + right.chunks().len(), + ) { + (0, 0, 1, 1) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b) + } else { + hash_join_tuples_left_semi(keys_a, keys_b) + } + }, + (0, 0, _, _) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b) + } else { + hash_join_tuples_left_semi(keys_a, keys_b) + } + }, + _ => { + let keys_a = get_arrays(&splitted_a); + let keys_b = get_arrays(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b) + } else { + hash_join_tuples_left_semi(keys_a, keys_b) + } + }, + } +} diff --git a/crates/polars-core/src/frame/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs similarity index 84% rename from crates/polars-core/src/frame/hash_join/single_keys_inner.rs rename to crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs index a877fa516ab5..cae187cd74ef 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_inner.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs @@ -1,14 +1,11 @@ +use polars_core::utils::flatten; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::sync::SyncPtr; -use super::single_keys::build_tables; use super::*; -use crate::frame::hash_join::single_keys::probe_to_offsets; -use crate::utils::flatten; -/// Probe the build table and add tuples to the results (inner join) -pub(super) fn probe_inner( - probe: &[T], +pub(super) fn probe_inner( + probe: I, hash_tbls: &[PlHashMap>], results: &mut Vec<(IdxSize, IdxSize)>, local_offset: IdxSize, @@ -16,16 +13,18 @@ pub(super) fn probe_inner( swap_fn: F, ) where T: Send + Hash + Eq + Sync + Copy + AsU64, + I: IntoIterator, + // ::IntoIter: TrustedLen, F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize), { assert!(hash_tbls.len().is_power_of_two()); - probe.iter().enumerate_idx().for_each(|(idx_a, k)| { + probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| { let idx_a = idx_a + local_offset; // probe table that contains the hashed value let current_probe_table = unsafe { get_hash_tbl_threaded_join_partitioned(k.as_u64(), hash_tbls, n_tables) }; - let value = current_probe_table.get(k); + let value = current_probe_table.get(&k); if let Some(indexes_b) = value { let tuples = indexes_b.iter().map(|&idx_b| swap_fn(idx_a, idx_b)); @@ -34,22 +33,25 @@ pub(super) fn probe_inner( }); } -pub(super) fn hash_join_tuples_inner( - probe: Vec, - build: Vec, +pub(super) fn hash_join_tuples_inner( + probe: Vec, + build: Vec, // Because b should be the shorter relation we could need to swap to keep left left and right right. 
swapped: bool, validate: JoinValidation, ) -> PolarsResult<(Vec, Vec)> where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Send + Sync + Copy, + // ::IntoIter: TrustedLen, T: Send + Hash + Eq + Sync + Copy + AsU64, { // NOTE: see the left join for more elaborate comments - // first we hash one relation let hash_tbls = if validate.needs_checks() { - let expected_size = build.iter().map(|v| v.as_ref().len()).sum(); + let expected_size = build + .iter() + .map(|v| v.into_iter().size_hint().1.unwrap()) + .sum(); let hash_tbls = build_tables(build); let build_size = hash_tbls.iter().map(|m| m.len()).sum(); validate.validate_build(build_size, expected_size, swapped)?; @@ -68,10 +70,10 @@ where .into_par_iter() .zip(offsets) .map(|(probe, offset)| { - let probe = probe.as_ref(); + let probe = probe.into_iter(); // local reference let hash_tbls = &hash_tbls; - let mut results = Vec::with_capacity(probe.len()); + let mut results = Vec::with_capacity(probe.size_hint().1.unwrap()); let local_offset = offset as IdxSize; // branch is to hoist swap out of the inner loop. diff --git a/crates/polars-core/src/frame/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs similarity index 90% rename from crates/polars-core/src/frame/hash_join/single_keys_left.rs rename to crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs index 9fd51a285106..6b80c1f38d5d 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_left.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs @@ -1,7 +1,6 @@ -use super::single_keys::build_tables; +use polars_core::utils::flatten::flatten_par; + use super::*; -use crate::frame::hash_join::single_keys::probe_to_offsets; -use crate::utils::flatten::flatten_par; #[cfg(feature = "chunked_ids")] unsafe fn apply_mapping(idx: Vec, chunk_mapping: &[ChunkId]) -> Vec { @@ -98,9 +97,9 @@ pub(super) fn flatten_left_join_ids(result: Vec) -> LeftJoinIds { } } -pub(super) fn hash_join_tuples_left( - probe: Vec, - build: Vec, +pub(super) fn hash_join_tuples_left( + probe: Vec, + build: Vec, // map the global indices to [chunk_idx, array_idx] // only needed if we have non contiguous memory chunk_mapping_left: Option<&[ChunkId]>, @@ -108,12 +107,15 @@ pub(super) fn hash_join_tuples_left( validate: JoinValidation, ) -> PolarsResult where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator, + ::IntoIter: Send + Sync + Clone, T: Send + Hash + Eq + Sync + Copy + AsU64, { + let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); + let build = build.into_iter().map(|i| i.into_iter()).collect::>(); // first we hash one relation let hash_tbls = if validate.needs_checks() { - let expected_size = build.iter().map(|v| v.as_ref().len()).sum(); + let expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum(); let hash_tbls = build_tables(build); let build_size = hash_tbls.iter().map(|m| m.len()).sum(); validate.validate_build(build_size, expected_size, false)?; @@ -138,13 +140,12 @@ where .map(move |(probe, offset)| { // local reference let hash_tbls = &hash_tbls; - let probe = probe.as_ref(); // assume the result tuples equal length of the no. of hashes processed by this thread. 
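The inner-join kernel here follows the classic build/probe pattern: hash the shorter relation into a table, then scan the longer one and emit `(probe_idx, build_idx)` pairs. A self-contained sketch of that pattern (editor's illustration; `hash_join_inner` is a hypothetical name, not the Polars API):

use std::collections::HashMap;

fn hash_join_inner(probe: &[i64], build: &[i64]) -> Vec<(usize, usize)> {
    // Build phase: index every build-side key.
    let mut table: HashMap<i64, Vec<usize>> = HashMap::with_capacity(build.len());
    for (idx, key) in build.iter().enumerate() {
        table.entry(*key).or_default().push(idx);
    }
    // Probe phase: emit a pair for every match.
    let mut out = Vec::with_capacity(probe.len());
    for (probe_idx, key) in probe.iter().enumerate() {
        if let Some(matches) = table.get(key) {
            out.extend(matches.iter().map(|&build_idx| (probe_idx, build_idx)));
        }
    }
    out
}

fn main() {
    let left = [1i64, 2, 3, 2];
    let right = [2i64, 3, 3];
    println!("{:?}", hash_join_inner(&left, &right)); // [(1, 0), (2, 1), (2, 2), (3, 0)]
}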
- let mut result_idx_left = Vec::with_capacity(probe.len()); - let mut result_idx_right = Vec::with_capacity(probe.len()); + let mut result_idx_left = Vec::with_capacity(probe.size_hint().1.unwrap()); + let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap()); - probe.iter().enumerate().for_each(|(idx_a, k)| { + probe.enumerate().for_each(|(idx_a, k)| { let idx_a = (idx_a + offset) as IdxSize; // probe table that contains the hashed value let current_probe_table = unsafe { @@ -152,7 +153,7 @@ where }; // we already hashed, so we don't have to hash again. - let value = current_probe_table.get(k); + let value = current_probe_table.get(&k); match value { // left and right matches diff --git a/crates/polars-core/src/frame/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs similarity index 89% rename from crates/polars-core/src/frame/hash_join/single_keys_outer.rs rename to crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs index 676f77818294..a0c34a432a76 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs @@ -67,10 +67,14 @@ pub(super) fn hash_join_tuples_outer( validate: JoinValidation, ) -> PolarsResult, Option)>> where - I: Iterator + Send + TrustedLen, - J: Iterator + Send + TrustedLen, + I: IntoIterator, + J: IntoIterator, + ::IntoIter: TrustedLen + Send, + ::IntoIter: TrustedLen + Send, T: Hash + Eq + Copy + Sync + Send, { + let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); + let build = build.into_iter().map(|i| i.into_iter()).collect::>(); // This function is partially multi-threaded. // Parts that are done in parallel: // - creation of the probe tables @@ -79,8 +83,14 @@ where // during the probe phase values are removed from the tables, that's done single threaded to // keep it lock free. 
- let size = probe.iter().map(|a| a.size_hint().0).sum::() - + build.iter().map(|b| b.size_hint().0).sum::(); + let size = probe + .iter() + .map(|a| a.size_hint().1.unwrap()) + .sum::() + + build + .iter() + .map(|b| b.size_hint().1.unwrap()) + .sum::(); let mut results = Vec::with_capacity(size); // prepare hash table diff --git a/crates/polars-core/src/frame/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs similarity index 72% rename from crates/polars-core/src/frame/hash_join/single_keys_semi_anti.rs rename to crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs index 616711c03c79..254c31e8baf2 100644 --- a/crates/polars-core/src/frame/hash_join/single_keys_semi_anti.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs @@ -1,11 +1,10 @@ use super::*; -use crate::frame::hash_join::single_keys::probe_to_offsets; /// Only keeps track of membership in right table -pub(super) fn create_probe_table_semi_anti(keys: Vec) -> Vec> +pub(super) fn create_probe_table_semi_anti(keys: Vec) -> Vec> where T: Send + Hash + Eq + Sync + Copy + AsU64, - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Copy + Send + Sync, { let n_partitions = _set_partition_size(); @@ -16,14 +15,13 @@ where (0..n_partitions).into_par_iter().map(|partition_no| { let partition_no = partition_no as u64; - let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(HASHMAP_INIT_SIZE); + let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); let n_partitions = n_partitions as u64; for keys in &keys { - let keys = keys.as_ref(); - keys.iter().for_each(|k| { + keys.into_iter().for_each(|k| { if this_partition(k.as_u64(), partition_no, n_partitions) { - hash_tbl.insert(*k); + hash_tbl.insert(k); } }); } @@ -33,12 +31,12 @@ where .collect() } -pub(super) fn semi_anti_impl( - probe: Vec, - build: Vec, +pub(super) fn semi_anti_impl( + probe: Vec, + build: Vec, ) -> impl ParallelIterator where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Copy + Send + Sync, T: Send + Hash + Eq + Sync + Copy + AsU64, { // first we hash one relation @@ -60,12 +58,12 @@ where .flat_map(move |(probe, offset)| { // local reference let hash_sets = &hash_sets; - let probe = probe.as_ref(); + let probe_iter = probe.into_iter(); // assume the result tuples equal length of the no. of hashes processed by this thread. - let mut results = Vec::with_capacity(probe.len()); + let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); - probe.iter().enumerate().for_each(|(idx_a, k)| { + probe_iter.enumerate().for_each(|(idx_a, k)| { let idx_a = (idx_a + offset) as IdxSize; // probe table that contains the hashed value let current_probe_table = unsafe { @@ -73,7 +71,7 @@ where }; // we already hashed, so we don't have to hash again. 
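The semi/anti kernels only need to know whether a probe key exists in the build side, which is why `create_probe_table_semi_anti` stores a `PlHashSet` rather than index lists. A minimal sketch of the idea (editor's illustration; `semi_anti_mask` is a hypothetical name):

use std::collections::HashSet;

fn semi_anti_mask(probe: &[i64], build: &[i64], anti: bool) -> Vec<bool> {
    let build_set: HashSet<i64> = build.iter().copied().collect();
    probe
        .iter()
        .map(|key| build_set.contains(key) != anti) // an anti join simply inverts the mask
        .collect()
}

fn main() {
    let left = [1i64, 2, 3];
    let right = [2i64, 4];
    println!("semi: {:?}", semi_anti_mask(&left, &right, false)); // [false, true, false]
    println!("anti: {:?}", semi_anti_mask(&left, &right, true));  // [true, false, true]
}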
- let value = current_probe_table.get(k); + let value = current_probe_table.get(&k); match value { // left and right matches @@ -87,12 +85,9 @@ where }) } -pub(super) fn hash_join_tuples_left_anti( - probe: Vec, - build: Vec, -) -> Vec +pub(super) fn hash_join_tuples_left_anti(probe: Vec, build: Vec) -> Vec where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Copy + Send + Sync, T: Send + Hash + Eq + Sync + Copy + AsU64, { semi_anti_impl(probe, build) @@ -101,12 +96,9 @@ where .collect() } -pub(super) fn hash_join_tuples_left_semi( - probe: Vec, - build: Vec, -) -> Vec +pub(super) fn hash_join_tuples_left_semi(probe: Vec, build: Vec) -> Vec where - IntoSlice: AsRef<[T]> + Send + Sync, + I: IntoIterator + Copy + Send + Sync, T: Send + Hash + Eq + Sync + Copy + AsU64, { semi_anti_impl(probe, build) diff --git a/crates/polars-core/src/frame/hash_join/sort_merge.rs b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs similarity index 98% rename from crates/polars-core/src/frame/hash_join/sort_merge.rs rename to crates/polars-ops/src/frame/join/hash_join/sort_merge.rs index 6db0b0a4e086..24f2692d21b8 100644 --- a/crates/polars-core/src/frame/hash_join/sort_merge.rs +++ b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs @@ -1,11 +1,11 @@ #[cfg(feature = "performant")] use polars_arrow::kernels::sorted_join; - -use super::*; #[cfg(feature = "performant")] -use crate::utils::_split_offsets; +use polars_core::utils::_split_offsets; #[cfg(feature = "performant")] -use crate::utils::flatten::flatten_par; +use polars_core::utils::flatten::flatten_par; + +use super::*; #[cfg(feature = "performant")] fn par_sorted_merge_left_impl( @@ -169,7 +169,7 @@ fn to_left_join_ids(left_idx: Vec, right_idx: Vec>) -> #[cfg(feature = "performant")] fn create_reverse_map_from_arg_sort(mut arg_sort: IdxCa) -> Vec { - let arr = arg_sort.chunks.pop().unwrap(); + let arr = unsafe { arg_sort.chunks_mut() }.pop().unwrap(); primitive_to_vec::(arr).unwrap() } @@ -226,7 +226,7 @@ pub fn _sort_or_hash_inner( multithreaded: true, maintain_order: false, }); - let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() }; + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_inner_no_nulls(s_left, &s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); @@ -253,7 +253,7 @@ pub fn _sort_or_hash_inner( multithreaded: true, maintain_order: false, }); - let s_left = unsafe { s_left.take_unchecked(&sort_idx).unwrap() }; + let s_left = unsafe { s_left.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_inner_no_nulls(&s_left, s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); @@ -322,7 +322,7 @@ pub(super) fn sort_or_hash_left( multithreaded: true, maintain_order: false, }); - let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() }; + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_left(s_left, &s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); diff --git a/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs new file mode 100644 index 000000000000..ba23f32c2910 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs @@ -0,0 +1,181 @@ +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub(crate) unsafe fn zip_outer_join_column( + left_column: &Series, + right_column: &Series, + opt_join_tuples: &[(Option, Option)], +) -> Series { + match 
left_column.dtype() { + DataType::Null => { + Series::full_null(left_column.name(), opt_join_tuples.len(), &DataType::Null) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_) => { + let left_column = left_column.categorical().unwrap(); + let new_rev_map = left_column + ._merge_categorical_map(right_column.categorical().unwrap()) + .unwrap(); + let left = left_column.logical(); + let right = right_column + .categorical() + .unwrap() + .logical() + .clone() + .into_series(); + + let cats = zip_outer_join_column_ca(left, &right, opt_join_tuples); + let cats = cats.u32().unwrap().clone(); + + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked(cats, new_rev_map).into_series() + } + }, + DataType::Utf8 => { + let left_column = left_column.cast(&DataType::Binary).unwrap(); + let right_column = right_column.cast(&DataType::Binary).unwrap(); + let out = zip_outer_join_column_ca( + left_column.binary().unwrap(), + &right_column, + opt_join_tuples, + ); + out.cast_unchecked(&DataType::Utf8).unwrap() + }, + DataType::Binary => { + zip_outer_join_column_ca(left_column.binary().unwrap(), right_column, opt_join_tuples) + }, + DataType::Boolean => { + zip_outer_join_column_ca(left_column.bool().unwrap(), right_column, opt_join_tuples) + }, + logical_type => { + let lhs_phys = left_column.to_physical_repr(); + let rhs_phys = right_column.to_physical_repr(); + + let out = with_match_physical_numeric_polars_type!(lhs_phys.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs_phys.as_ref().as_ref().as_ref(); + + zip_outer_join_column_ca(lhs, &rhs_phys, opt_join_tuples) + }); + out.cast_unchecked(logical_type).unwrap() + }, + } +} + +fn get_value T>( + opt_left_idx: Option, + opt_right_idx: Option, + left_arr: A, + right_arr: A, + getter: F, +) -> T { + if let Some(left_idx) = opt_left_idx { + getter(left_arr, left_idx as usize) + } else { + unsafe { + let right_idx = opt_right_idx.unwrap_unchecked(); + getter(right_arr, right_idx as usize) + } + } +} + +// TODO! improve this once we have a proper scatter. +// Two scatters should do it. Can also improve the `opt_join_tuples` format then. +unsafe fn zip_outer_join_column_ca<'a, T>( + left_column: &'a ChunkedArray, + right_column: &Series, + opt_join_tuples: &[(Option, Option)], +) -> Series +where + T: PolarsDataType, + ChunkedArray: IntoSeries, + T::Physical<'a>: Copy, +{ + let right_ca = left_column + .unpack_series_matching_type(right_column) + .unwrap(); + + let tuples_iter = opt_join_tuples.iter(); + + // No nulls. + if left_column.null_count() == 0 && right_ca.null_count() == 0 { + // Single chunk case. 
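`zip_outer_join_column_ca` materialises one output column from the `(Option<left_idx>, Option<right_idx>)` join tuples, taking the left value whenever the left index is present and falling back to the right otherwise. A stripped-down sketch of that per-row logic (editor's illustration; `zip_outer_column` and the plain `usize` indices are simplifying assumptions):

fn zip_outer_column<T: Copy>(
    left: &[T],
    right: &[T],
    join_tuples: &[(Option<usize>, Option<usize>)],
) -> Vec<Option<T>> {
    join_tuples
        .iter()
        .map(|(l, r)| match (l, r) {
            (Some(i), _) => Some(left[*i]),     // left match wins when both sides matched
            (None, Some(j)) => Some(right[*j]), // otherwise take the right-side value
            (None, None) => None,               // should not occur for a real outer join
        })
        .collect()
}

fn main() {
    let left = [10i32, 20];
    let right = [30i32];
    let tuples = [(Some(0), None), (Some(1), Some(0)), (None, Some(0))];
    println!("{:?}", zip_outer_column(&left, &right, &tuples)); // [Some(10), Some(20), Some(30)]
}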
+ if left_column.chunks().len() == 1 && right_column.chunks().len() == 1 { + let left_arr = left_column.downcast_iter().next().unwrap(); + let right_arr = right_ca.downcast_iter().next().unwrap(); + + match (left_arr.as_slice(), right_arr.as_slice()) { + (Some(left_slice), Some(right_slice)) => tuples_iter + .map(|(opt_left_idx, opt_right_idx)| { + get_value( + *opt_left_idx, + *opt_right_idx, + left_slice, + right_slice, + |slice, idx| *slice.get_unchecked(idx), + ) + }) + .collect_ca_trusted_like(left_column) + .into_series(), + _ => tuples_iter + .map(|(opt_left_idx, opt_right_idx)| { + get_value( + *opt_left_idx, + *opt_right_idx, + left_arr, + right_arr, + |slice, idx| slice.value_unchecked(idx), + ) + }) + .collect_ca_trusted_like(left_column) + .into_series(), + } + } else { + tuples_iter + .map(|(opt_left_idx, opt_right_idx)| { + get_value( + *opt_left_idx, + *opt_right_idx, + left_column, + right_ca, + |slice, idx| slice.value_unchecked(idx), + ) + }) + .collect_ca_trusted_like(left_column) + .into_series() + } + + // Nulls. + } else { + // Single chunk case. + if left_column.chunks().len() == 1 && right_column.chunks().len() == 1 { + let left_arr = left_column.downcast_iter().next().unwrap(); + let right_arr = right_ca.downcast_iter().next().unwrap(); + tuples_iter + .map(|(opt_left_idx, opt_right_idx)| { + get_value( + *opt_left_idx, + *opt_right_idx, + left_arr, + right_arr, + |slice, idx| slice.get_unchecked(idx), + ) + }) + .collect_ca_trusted_like(left_column) + .into_series() + } else { + tuples_iter + .map(|(opt_left_idx, opt_right_idx)| { + get_value( + *opt_left_idx, + *opt_right_idx, + left_column, + right_ca, + |slice, idx| slice.get_unchecked(idx), + ) + }) + .collect_ca_trusted_like(left_column) + .into_series() + } + } +} diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 2c325962b02e..4c7fff623464 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -1,34 +1,49 @@ +mod args; +#[cfg(feature = "asof_join")] +mod asof; +#[cfg(feature = "dtype-categorical")] +mod checks; +mod cross_join; +mod general; +mod hash_join; #[cfg(feature = "merge_sorted")] mod merge_sorted; + #[cfg(feature = "chunked_ids")] use std::borrow::Cow; - +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; + +use ahash::RandomState; +pub use args::*; +#[cfg(feature = "asof_join")] +use asof::AsofJoinBy; +#[cfg(feature = "asof_join")] +pub use asof::{AsOfOptions, AsofJoin, AsofStrategy}; +#[cfg(feature = "dtype-categorical")] +pub(crate) use checks::*; +pub use cross_join::CrossJoin; +#[cfg(feature = "chunked_ids")] +use either::Either; +#[cfg(feature = "chunked_ids")] +use general::create_chunked_index_mapping; +pub use general::{_finish_join, _join_suffix_name}; +pub use hash_join::*; +use hashbrown::hash_map::{Entry, RawEntryMut}; +use hashbrown::HashMap; #[cfg(feature = "merge_sorted")] pub use merge_sorted::_merge_sorted_dfs; -use polars_core::frame::hash_join::*; +use polars_arrow::trusted_len::TrustedLen; +use polars_core::hashing::partition::{this_partition, AsU64}; +use polars_core::hashing::{BytesHash, _df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; use polars_core::prelude::*; +pub(super) use polars_core::series::IsSorted; use polars_core::utils::{_to_physical_and_bit_repr, slice_slice}; use polars_core::POOL; +use rayon::prelude::*; -use super::*; - -macro_rules! 
det_hash_prone_order { - ($self:expr, $other:expr) => {{ - // The shortest relation will be used to create a hash table. - let left_first = $self.len() > $other.len(); - let a; - let b; - if left_first { - a = $self; - b = $other; - } else { - b = $self; - a = $other; - } - - (a, b, !left_first) - }}; -} +use super::hashing::{create_hash_and_keys_threaded_vectorized, prepare_hashed_relation_threaded}; +use super::IntoDf; pub trait DataFrameJoinOps: IntoDf { /// Generic join method. Can be used to join on multiple columns. @@ -295,25 +310,29 @@ pub trait DataFrameJoinOps: IntoDf { // Take the left and right dataframes by join tuples let (df_left, df_right) = POOL.join( || unsafe { - remove_selected(left_df, &selected_left).take_opt_iter_unchecked( - opt_join_tuples + remove_selected(left_df, &selected_left).take_unchecked( + &opt_join_tuples .iter() - .map(|(left, _right)| left.map(|i| i as usize)), + .map(|(left, _right)| *left) + .collect_ca(""), ) }, || unsafe { - remove_selected(other, &selected_right).take_opt_iter_unchecked( - opt_join_tuples + remove_selected(other, &selected_right).take_unchecked( + &opt_join_tuples .iter() - .map(|(_left, right)| right.map(|i| i as usize)), + .map(|(_left, right)| *right) + .collect_ca(""), ) }, ); // Allocate a new vec for df_left so that the keys are left and then other values. let mut keys = Vec::with_capacity(selected_left.len() + df_left.width()); for (s_left, s_right) in selected_left.iter().zip(&selected_right) { - let mut s = s_left.zip_outer_join_column(s_right, opt_join_tuples); - s.rename(s_left.name()); + let s = unsafe { + zip_outer_join_column(s_left, s_right, opt_join_tuples) + .with_name(s_left.name()) + }; keys.push(s) } keys.extend_from_slice(df_left.get_columns()); diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 789458b4ea0e..ce7f02018e65 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -1,4 +1,5 @@ -mod join; +mod hashing; +pub mod join; #[cfg(feature = "pivot")] pub mod pivot; @@ -45,7 +46,7 @@ pub trait DataFrameOps: IntoDf { /// }.unwrap(); /// /// let dummies = df.to_dummies(None, false).unwrap(); - /// dbg!(dummies); + /// println!("{}", dummies); /// # } /// ``` /// Outputs: diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index df020f010548..08fbb75424c3 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -67,7 +67,7 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { let ca = s.u32().unwrap(); ca.reinterpret_signed().cast(logical_type).unwrap() }, - _ => s.cast(logical_type).unwrap(), + _ => unsafe { s.cast_unchecked(logical_type).unwrap() }, } } diff --git a/crates/polars-ops/src/lib.rs b/crates/polars-ops/src/lib.rs index 341d85969ad6..a6237ffca036 100644 --- a/crates/polars-ops/src/lib.rs +++ b/crates/polars-ops/src/lib.rs @@ -5,6 +5,6 @@ extern crate core; pub mod chunked_array; #[cfg(feature = "pivot")] pub use frame::pivot; -mod frame; +pub mod frame; pub mod prelude; mod series; diff --git a/crates/polars-ops/src/prelude.rs b/crates/polars-ops/src/prelude.rs index 2e929c977c2e..1f0717945b49 100644 --- a/crates/polars-ops/src/prelude.rs +++ b/crates/polars-ops/src/prelude.rs @@ -4,5 +4,6 @@ pub(crate) use {crate::series::*, polars_core::export::rayon::prelude::*}; pub use crate::chunked_array::*; #[cfg(feature = "merge_sorted")] pub use crate::frame::_merge_sorted_dfs; +pub use crate::frame::join::*; pub use 
crate::frame::{DataFrameJoinOps, DataFrameOps}; pub use crate::series::*; diff --git a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs index 133e4e3f8298..7df61317d9bc 100644 --- a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs +++ b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs @@ -16,10 +16,10 @@ //! assert_eq!(hllp.count(), 2); //! ``` -use std::hash::{BuildHasher, Hash, Hasher}; +use std::hash::Hash; use std::marker::PhantomData; -use polars_core::export::ahash::{AHasher, RandomState}; +use polars_core::export::ahash::RandomState; /// The greater is P, the smaller the error. const HLL_P: usize = 14_usize; @@ -85,9 +85,7 @@ where /// reasonable performance. #[inline] fn hash_value(&self, obj: &T) -> u64 { - let mut hasher: AHasher = SEED.build_hasher(); - obj.hash(&mut hasher); - hasher.finish() + SEED.hash_one(obj) } /// Adds an element to the HyperLogLog. diff --git a/crates/polars-ops/src/series/ops/approx_unique.rs b/crates/polars-ops/src/series/ops/approx_unique.rs index c526591e2cb3..1ff1c9d21236 100644 --- a/crates/polars-ops/src/series/ops/approx_unique.rs +++ b/crates/polars-ops/src/series/ops/approx_unique.rs @@ -58,7 +58,7 @@ fn dispatcher(s: &Series) -> PolarsResult { /// let s = Series::new("s", [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]); /// /// let approx_count = approx_n_unique(&s).unwrap(); -/// dbg!(approx_count); +/// println!("{}", approx_count); /// # } /// ``` /// Outputs: diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 613f833ab497..929bed548a7c 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -1,7 +1,6 @@ use argminmax::ArgMinMax; use arrow::array::Array; -use arrow::bitmap::utils::{BitChunkIterExact, BitChunksExact}; -use arrow::bitmap::Bitmap; +use polars_arrow::bit_util::*; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; @@ -121,96 +120,6 @@ fn arg_min_bool(ca: &BooleanChunked) -> Option { } } -#[inline] -fn get_leading_zeroes(chunk: u64) -> u32 { - if cfg!(target_endian = "little") { - chunk.trailing_zeros() - } else { - chunk.leading_zeros() - } -} - -#[inline] -fn get_leading_ones(chunk: u64) -> u32 { - if cfg!(target_endian = "little") { - chunk.trailing_ones() - } else { - chunk.leading_ones() - } -} - -fn first_set_bit_impl(mut mask_chunks: I) -> usize -where - I: BitChunkIterExact, -{ - let mut total = 0usize; - let size = 64; - for chunk in &mut mask_chunks { - let pos = get_leading_zeroes(chunk); - if pos != size { - return total + pos as usize; - } else { - total += size as usize - } - } - if let Some(pos) = mask_chunks.remainder_iter().position(|v| v) { - total += pos; - return total; - } - // all null, return the first - 0 -} - -fn first_set_bit(mask: &Bitmap) -> usize { - if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { - return 0; - } - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - first_set_bit_impl(mask_chunks) - } else { - let mask_chunks = mask.chunks::(); - first_set_bit_impl(mask_chunks) - } -} - -fn first_unset_bit_impl(mut mask_chunks: I) -> usize -where - I: BitChunkIterExact, -{ - let mut total = 0usize; - let size = 64; - for chunk in &mut mask_chunks { - let pos = get_leading_ones(chunk); - if pos != size { - return total + pos as usize; - } 
else { - total += size as usize - } - } - if let Some(pos) = mask_chunks.remainder_iter().position(|v| !v) { - total += pos; - return total; - } - // all null, return the first - 0 -} - -fn first_unset_bit(mask: &Bitmap) -> usize { - if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { - return 0; - } - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - first_unset_bit_impl(mask_chunks) - } else { - let mask_chunks = mask.chunks::(); - first_unset_bit_impl(mask_chunks) - } -} - fn arg_min_str(ca: &Utf8Chunked) -> Option { if ca.is_empty() || ca.null_count() == ca.len() { return None; diff --git a/crates/polars-ops/src/series/ops/clip.rs b/crates/polars-ops/src/series/ops/clip.rs new file mode 100644 index 000000000000..170e7961d6a2 --- /dev/null +++ b/crates/polars-ops/src/series/ops/clip.rs @@ -0,0 +1,151 @@ +use num_traits::{clamp, clamp_max, clamp_min}; +use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise}; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +fn clip_helper( + ca: &ChunkedArray, + min: &ChunkedArray, + max: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + match (min.len(), max.len()) { + (1, 1) => match (min.get(0), max.get(0)) { + (Some(min), Some(max)) => { + ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max))) + }, + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + (1, _) => match min.get(0) { + Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) { + (Some(s), Some(max)) => Some(clamp(s, min, max)), + _ => None, + }), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + (_, 1) => match max.get(0) { + Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) { + (Some(s), Some(min)) => Some(clamp(s, min, max)), + _ => None, + }), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + _ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| { + match (opt_s, opt_min, opt_max) { + (Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)), + _ => None, + } + }), + } +} + +fn clip_min_max_helper( + ca: &ChunkedArray, + bound: &ChunkedArray, + op: F, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, + F: Fn(T::Native, T::Native) -> T::Native, +{ + match bound.len() { + 1 => match bound.get(0) { + Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))), + _ => ChunkedArray::::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { + (Some(s), Some(bound)) => Some(op(s, bound)), + _ => None, + }), + } +} + +/// Clamp underlying values to the `min` and `max` values. +pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { + polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + + let original_type = s.dtype(); + // cast min & max to the dtype of s first. 
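The clipping helpers here reduce to `clamp` applied elementwise, with nulls passed through and scalar bounds broadcast. The same idea over plain `Option<f64>` values, using only the standard library (editor's illustration; `clip_values` is a hypothetical name):

fn clip_values(values: &[Option<f64>], min: f64, max: f64) -> Vec<Option<f64>> {
    values
        .iter()
        .copied()
        .map(|opt| opt.map(|v| v.clamp(min, max))) // nulls stay null
        .collect()
}

fn main() {
    let vals = [Some(-5.0), Some(0.5), None, Some(9.0)];
    println!("{:?}", clip_values(&vals, 0.0, 1.0)); // [Some(0.0), Some(0.5), None, Some(1.0)]
}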
+ let (min, max) = (min.cast(s.dtype())?, max.cast(s.dtype())?); + + let (s, min, max) = ( + s.to_physical_repr(), + min.to_physical_repr(), + max.to_physical_repr(), + ); + + match s.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); + let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); + let out = clip_helper(ca, min, max).into_series(); + if original_type.is_logical(){ + out.cast(original_type) + }else{ + Ok(out) + } + }) + }, + dt => polars_bail!(opq = clippy, dt), + } +} + +/// Clamp underlying values to the `max` value. +pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { + polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + + let original_type = s.dtype(); + // cast max to the dtype of s first. + let max = max.cast(s.dtype())?; + + let (s, max) = (s.to_physical_repr(), max.to_physical_repr()); + + match s.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); + let out = clip_min_max_helper(ca, max, clamp_max).into_series(); + if original_type.is_logical(){ + out.cast(original_type) + }else{ + Ok(out) + } + }) + }, + dt => polars_bail!(opq = clippy_max, dt), + } +} + +/// Clamp underlying values to the `min` value. +pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { + polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported."); + + let original_type = s.dtype(); + // cast min to the dtype of s first. + let min = min.cast(s.dtype())?; + + let (s, min) = (s.to_physical_repr(), min.to_physical_repr()); + + match s.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); + let out = clip_min_max_helper(ca, min, clamp_min).into_series(); + if original_type.is_logical(){ + out.cast(original_type) + }else{ + Ok(out) + } + }) + }, + dt => polars_bail!(opq = clippy_min, dt), + } +} diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs new file mode 100644 index 000000000000..d7053f03e63d --- /dev/null +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -0,0 +1,80 @@ +use std::ops::{BitAnd, BitOr}; + +use polars_core::prelude::*; +use polars_core::POOL; +use rayon::prelude::*; + +pub fn sum_horizontal(s: &[Series]) -> PolarsResult { + let out = POOL + .install(|| { + s.par_iter() + .try_fold( + || UInt32Chunked::new("", &[0u32]).into_series(), + |acc, b| { + PolarsResult::Ok( + acc.fill_null(FillNullStrategy::Zero)? + + b.fill_null(FillNullStrategy::Zero)?, + ) + }, + ) + .try_reduce( + || UInt32Chunked::new("", &[0u32]).into_series(), + |a, b| { + PolarsResult::Ok( + a.fill_null(FillNullStrategy::Zero)? + + b.fill_null(FillNullStrategy::Zero)?, + ) + }, + ) + })? 
+ .with_name("sum"); + Ok(out) +} + +pub fn any_horizontal(s: &[Series]) -> PolarsResult { + let out = POOL + .install(|| { + s.par_iter() + .try_fold( + || BooleanChunked::new("", &[false]), + |acc, b| { + let b = b.cast(&DataType::Boolean)?; + let b = b.bool()?; + PolarsResult::Ok((&acc).bitor(b)) + }, + ) + .try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b))) + })? + .with_name("any"); + Ok(out.into_series()) +} + +pub fn all_horizontal(s: &[Series]) -> PolarsResult { + let out = POOL + .install(|| { + s.par_iter() + .try_fold( + || BooleanChunked::new("", &[true]), + |acc, b| { + let b = b.cast(&DataType::Boolean)?; + let b = b.bool()?; + PolarsResult::Ok((&acc).bitand(b)) + }, + ) + .try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b))) + })? + .with_name("all"); + Ok(out.into_series()) +} + +#[cfg(feature = "zip_with")] +pub fn max_horizontal(s: &[Series]) -> PolarsResult> { + let df = DataFrame::new_no_checks(Vec::from(s)); + df.hmax().map(|opt_s| opt_s.map(|s| s.with_name("max"))) +} + +#[cfg(feature = "zip_with")] +pub fn min_horizontal(s: &[Series]) -> PolarsResult> { + let df = DataFrame::new_no_checks(Vec::from(s)); + df.hmin().map(|opt_s| opt_s.map(|s| s.with_name("min"))) +} diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs new file mode 100644 index 000000000000..a73b5378dd51 --- /dev/null +++ b/crates/polars-ops/src/series/ops/index.rs @@ -0,0 +1,65 @@ +use std::fmt::Debug; + +use polars_core::error::{polars_bail, polars_ensure, PolarsError, PolarsResult}; +use polars_core::export::num::{FromPrimitive, Signed, ToPrimitive, Zero}; +use polars_core::prelude::{ChunkedArray, DataType, IdxCa, PolarsIntegerType, Series, IDX_DTYPE}; +use polars_utils::IdxSize; + +fn convert(ca: &ChunkedArray, target_len: usize) -> PolarsResult +where + T: PolarsIntegerType, + IdxSize: TryFrom, + >::Error: Debug, + T::Native: FromPrimitive + Signed + Zero, +{ + let len = + i64::from_usize(target_len).ok_or_else(|| PolarsError::ComputeError("overflow".into()))?; + + let zero = T::Native::zero(); + + ca.try_apply_values_generic(|v| { + if v >= zero { + Ok(IdxSize::try_from(v).unwrap()) + } else { + IdxSize::from_i64(len + v.to_i64().unwrap()).ok_or_else(|| { + PolarsError::OutOfBounds( + format!( + "index {} is out of bounds for series of len {}", + v, target_len + ) + .into(), + ) + }) + } + }) +} + +pub fn convert_to_unsigned_index(s: &Series, target_len: usize) -> PolarsResult { + let dtype = s.dtype(); + polars_ensure!(dtype.is_integer(), InvalidOperation: "expected integers as index"); + if dtype.is_unsigned() { + let out = s.cast(&IDX_DTYPE).unwrap(); + return Ok(out.idx().unwrap().clone()); + } + match dtype { + DataType::Int64 => { + let ca = s.i64().unwrap(); + convert(ca, target_len) + }, + DataType::Int32 => { + let ca = s.i32().unwrap(); + convert(ca, target_len) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + let ca = s.i16().unwrap(); + convert(ca, target_len) + }, + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { + let ca = s.i8().unwrap(); + convert(ca, target_len) + }, + _ => unreachable!(), + } +} diff --git a/crates/polars-ops/src/series/ops/is_first.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs similarity index 58% rename from crates/polars-ops/src/series/ops/is_first.rs rename to crates/polars-ops/src/series/ops/is_first_distinct.rs index be7c726100b6..9542394c00ef 100644 --- a/crates/polars-ops/src/series/ops/is_first.rs +++ 
b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -2,13 +2,11 @@ use std::hash::Hash; use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; +use polars_arrow::bit_util::*; use polars_arrow::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; - -use crate::series::ops::arg_min_max::arg_max_bool; - -fn is_first_numeric(ca: &ChunkedArray) -> BooleanChunked +fn is_first_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, T::Native: Hash + Eq, @@ -23,7 +21,7 @@ where BooleanChunked::from_chunk_iter(ca.name(), chunks) } -fn is_first_bin(ca: &BinaryChunked) -> BooleanChunked { +fn is_first_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { let mut unique = PlHashSet::new(); let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { arr.into_iter() @@ -34,22 +32,46 @@ fn is_first_bin(ca: &BinaryChunked) -> BooleanChunked { BooleanChunked::from_chunk_iter(ca.name(), chunks) } -fn is_first_boolean(ca: &BooleanChunked) -> BooleanChunked { +fn is_first_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { let mut out = MutableBitmap::with_capacity(ca.len()); out.extend_constant(ca.len(), false); - if let Some(index) = arg_max_bool(ca) { - out.set(index, true) - } - if let Some(index) = ca.first_non_null() { - out.set(index, true) - } + if ca.null_count() == ca.len() { + out.set(0, true); + } else { + let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + if ca.null_count() == 0 { + let (true_index, false_index) = + find_first_true_false_no_null(arr.values().chunks::()); + if let Some(idx) = true_index { + out.set(idx, true) + } + if let Some(idx) = false_index { + out.set(idx, true) + } + } else { + let (true_index, false_index, null_index) = find_first_true_false_null( + arr.values().chunks::(), + arr.validity().unwrap().chunks::(), + ); + if let Some(idx) = true_index { + out.set(idx, true) + } + if let Some(idx) = false_index { + out.set(idx, true) + } + if let Some(idx) = null_index { + out.set(idx, true) + } + } + } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); BooleanChunked::with_chunk(ca.name(), arr) } #[cfg(feature = "dtype-struct")] -fn is_first_struct(s: &Series) -> PolarsResult { +fn is_first_distinct_struct(s: &Series) -> PolarsResult { let groups = s.group_tuples(true, false)?; let first = groups.take_group_firsts(); let mut out = MutableBitmap::with_capacity(s.len()); @@ -65,7 +87,7 @@ fn is_first_struct(s: &Series) -> PolarsResult { } #[cfg(feature = "group_by_list")] -fn is_first_list(ca: &ListChunked) -> PolarsResult { +fn is_first_distinct_list(ca: &ListChunked) -> PolarsResult { let groups = ca.group_tuples(true, false)?; let first = groups.take_group_firsts(); let mut out = MutableBitmap::with_capacity(ca.len()); @@ -80,45 +102,52 @@ fn is_first_list(ca: &ListChunked) -> PolarsResult { Ok(BooleanChunked::with_chunk(ca.name(), arr)) } -pub fn is_first(s: &Series) -> PolarsResult { +pub fn is_first_distinct(s: &Series) -> PolarsResult { + // fast path. 
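`is_first_distinct_numeric` and `is_first_distinct_bin` both rely on the fact that a hash-set `insert` (here `PlHashSet`) returns `true` only the first time a value is seen. That is the whole trick, as this self-contained sketch shows (editor's illustration; `is_first_distinct_values` is a hypothetical name):

use std::collections::HashSet;

fn is_first_distinct_values(values: &[i64]) -> Vec<bool> {
    let mut seen = HashSet::new();
    values.iter().map(|v| seen.insert(*v)).collect() // true only on first occurrence
}

fn main() {
    let vals = [3i64, 1, 3, 2, 1];
    println!("{:?}", is_first_distinct_values(&vals)); // [true, true, false, true, false]
}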
+ if s.len() == 0 { + return Ok(BooleanChunked::full_null(s.name(), 0)); + } else if s.len() == 1 { + return Ok(BooleanChunked::new(s.name(), &[true])); + } + let s = s.to_physical_repr(); use DataType::*; let out = match s.dtype() { Boolean => { let ca = s.bool().unwrap(); - is_first_boolean(ca) + is_first_distinct_boolean(ca) }, Binary => { let ca = s.binary().unwrap(); - is_first_bin(ca) + is_first_distinct_bin(ca) }, Utf8 => { let s = s.cast(&Binary).unwrap(); - return is_first(&s); + return is_first_distinct(&s); }, Float32 => { let ca = s.bit_repr_small(); - is_first_numeric(&ca) + is_first_distinct_numeric(&ca) }, Float64 => { let ca = s.bit_repr_large(); - is_first_numeric(&ca) + is_first_distinct_numeric(&ca) }, dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - is_first_numeric(ca) + is_first_distinct_numeric(ca) }) }, #[cfg(feature = "dtype-struct")] - Struct(_) => return is_first_struct(&s), + Struct(_) => return is_first_distinct_struct(&s), #[cfg(feature = "group_by_list")] List(inner) if inner.is_numeric() => { let ca = s.list().unwrap(); - return is_first_list(ca); + return is_first_distinct_list(ca); }, - dt => polars_bail!(opq = is_first, dt), + dt => polars_bail!(opq = is_first_distinct, dt), }; Ok(out) } diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index b18fefb66dfb..cbc7f822eeda 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -7,8 +7,7 @@ use polars_core::with_match_physical_integer_polars_type; fn is_in_helper<'a, T>(ca: &'a ChunkedArray, other: &Series) -> PolarsResult where T: PolarsDataType, - ChunkedArray: HasUnderlyingArray, - < as HasUnderlyingArray>::ArrayT as StaticArray>::ValueT<'a>: Hash + Eq + Copy, + T::Physical<'a>: Hash + Eq + Copy, { let mut set = PlHashSet::with_capacity(other.len()); @@ -41,19 +40,16 @@ where let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); - other - .list()? - .amortized_iter() - .map(|opt_s| { - opt_s.map(|s| { - let ca = s.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == value) - }) == Some(true) - }) - .collect_trusted() + other.list()?.apply_amortized_generic(|opt_s| { + Some(opt_s.map(|s| { + let ca = s.as_ref().unpack::().unwrap(); + ca.into_iter().any(|a| a == value) + }) == Some(true)) + }) } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); - ca_in.into_iter() + // SAFETY: unstable series never lives longer than the iterator. 
+ unsafe { ca_in.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { @@ -62,7 +58,7 @@ where } _ => false, }) - .collect_trusted() + .collect_trusted()} }; ca.rename(ca_in.name()); Ok(ca) @@ -93,22 +89,18 @@ fn is_in_utf8(ca_in: &Utf8Chunked, other: &Series) -> PolarsResult { - let mut ca: BooleanChunked = other - .amortized_iter() - .map(|opt_s| opt_s.map(|s| s.as_ref().null_count() > 0) == Some(true)) - .collect_trusted(); - ca.rename(ca_in.name()); - Ok(ca) - }, + None => Ok(other + .apply_amortized_generic(|opt_s| { + opt_s.map(|s| Some(s.as_ref().null_count() > 0) == Some(true)) + }) + .with_name(ca_in.name())), Some(value) => { match rev_map.find(value) { // all false None => Ok(BooleanChunked::full(ca_in.name(), false, other.len())), - Some(idx) => { - let mut ca: BooleanChunked = other - .amortized_iter() - .map(|opt_s| { + Some(idx) => Ok(other + .apply_amortized_generic(|opt_s| { + Some( opt_s.map(|s| { let s = s.as_ref().to_physical_repr(); let ca = s.as_ref().u32().unwrap(); @@ -117,12 +109,10 @@ fn is_in_utf8(ca_in: &Utf8Chunked, other: &Series) -> PolarsResult PolarsResult { let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); - other - .list()? - .amortized_iter() - .map(|opt_b| { - opt_b.map(|s| { - let ca = s.as_ref().unpack::().unwrap(); - ca.into_iter().any(|a| a == value) - }) == Some(true) - }) - .collect_trusted() + + other.list()?.apply_amortized_generic(|opt_b| { + Some(opt_b.map(|s| { + let ca = s.as_ref().unpack::().unwrap(); + ca.into_iter().any(|a| a == value) + }) == Some(true)) + }) } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); - ca_in.into_iter() + // SAFETY: unstable series never lives longer than the iterator. + unsafe { ca_in.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { @@ -167,7 +155,7 @@ fn is_in_binary(ca_in: &BinaryChunked, other: &Series) -> PolarsResult false, }) - .collect_trusted() + .collect_trusted()} }; ca.rename(ca_in.name()); Ok(ca) @@ -188,7 +176,8 @@ fn is_in_boolean(ca_in: &BooleanChunked, other: &Series) -> PolarsResult { let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { let value = ca_in.get(0); - // safety: we know the iterators len + // SAFETY: we know the iterators len + // SAFETY: unstable series never lives longer than the iterator. unsafe { other .list()? 
@@ -204,7 +193,8 @@ fn is_in_boolean(ca_in: &BooleanChunked, other: &Series) -> PolarsResult { @@ -214,7 +204,7 @@ fn is_in_boolean(ca_in: &BooleanChunked, other: &Series) -> PolarsResult false, }) .collect_trusted() - }; + }}; ca.rename(ca_in.name()); Ok(ca) } @@ -226,7 +216,7 @@ fn is_in_boolean(ca_in: &BooleanChunked, other: &Series) -> PolarsResult PolarsResult { - let ca = series.as_ref().struct_().unwrap(); - ca.into_iter().any(|a| a == val) - }, - _ => false, - }) - .collect() + unsafe { + ca_in + .into_iter() + .zip(other.list()?.amortized_iter()) + .map(|(value, series)| match (value, series) { + (val, Some(series)) => { + let ca = series.as_ref().struct_().unwrap(); + ca.into_iter().any(|a| a == val) + }, + _ => false, + }) + .collect() + } }; ca.rename(ca_in.name()); Ok(ca) @@ -349,7 +340,7 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult { match s.dtype() { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => { - use polars_core::frame::hash_join::_check_categorical_src; + use crate::frame::join::_check_categorical_src; _check_categorical_src(s.dtype(), other.dtype())?; let ca = s.categorical().unwrap(); let ca = ca.logical(); @@ -413,6 +404,11 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult { is_in_numeric(ca, other) }) }, - dt => polars_bail!(opq = is_int, dt), + DataType::Null => { + let series_bool = s.cast(&DataType::Boolean)?; + let ca = series_bool.bool().unwrap(); + Ok(ca.clone()) + }, + dt => polars_bail!(opq = is_in, dt), } } diff --git a/crates/polars-ops/src/series/ops/is_last_distinct.rs b/crates/polars-ops/src/series/ops/is_last_distinct.rs new file mode 100644 index 000000000000..57c388f2c5fc --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_last_distinct.rs @@ -0,0 +1,181 @@ +use std::hash::Hash; + +use arrow::array::BooleanArray; +use arrow::bitmap::MutableBitmap; +use polars_arrow::utils::CustomIterTools; +use polars_core::prelude::*; +use polars_core::utils::NoNull; +use polars_core::with_match_physical_integer_polars_type; + +pub fn is_last_distinct(s: &Series) -> PolarsResult { + // fast path. + if s.len() == 0 { + return Ok(BooleanChunked::full_null(s.name(), 0)); + } else if s.len() == 1 { + return Ok(BooleanChunked::new(s.name(), &[true])); + } + + let s = s.to_physical_repr(); + + use DataType::*; + let out = match s.dtype() { + Boolean => { + let ca = s.bool().unwrap(); + is_last_distinct_boolean(ca) + }, + Binary => { + let ca = s.binary().unwrap(); + is_last_distinct_bin(ca) + }, + Utf8 => { + let s = s.cast(&Binary).unwrap(); + return is_last_distinct(&s); + }, + Float32 => { + let ca = s.bit_repr_small(); + is_last_distinct_numeric(&ca) + }, + Float64 => { + let ca = s.bit_repr_large(); + is_last_distinct_numeric(&ca) + }, + dt if dt.is_numeric() => { + with_match_physical_integer_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + is_last_distinct_numeric(ca) + }) + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => return is_last_distinct_struct(&s), + #[cfg(feature = "group_by_list")] + List(inner) if inner.is_numeric() => { + let ca = s.list().unwrap(); + return is_last_distinct_list(ca); + }, + dt => polars_bail!(opq = is_last_distinct, dt), + }; + Ok(out) +} + +fn is_last_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + if ca.null_count() == ca.len() { + out.set(ca.len() - 1, true); + } + // TODO supports fast path. 
+ else { + let mut first_true_found = false; + let mut first_false_found = false; + let mut first_null_found = false; + let mut all_found = false; + let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + arr.into_iter() + .enumerate() + .rev() + .find_map(|(idx, val)| match val { + Some(true) if !first_true_found => { + first_true_found = true; + all_found &= first_true_found; + out.set(idx, true); + if all_found { + Some(()) + } else { + None + } + }, + Some(false) if !first_false_found => { + first_false_found = true; + all_found &= first_false_found; + out.set(idx, true); + if all_found { + Some(()) + } else { + None + } + }, + None if !first_null_found => { + first_null_found = true; + all_found &= first_null_found; + out.set(idx, true); + if all_found { + Some(()) + } else { + None + } + }, + _ => None, + }); + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + BooleanChunked::with_chunk(ca.name(), arr) +} + +fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { + let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let mut unique = PlHashSet::new(); + let mut new_ca: BooleanChunked = arr + .into_iter() + .rev() + .map(|opt_v| unique.insert(opt_v)) + .collect_reversed::>() + .into_inner(); + new_ca.rename(ca.name()); + new_ca +} + +fn is_last_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked +where + T: PolarsNumericType, + T::Native: Hash + Eq, +{ + let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let mut unique = PlHashSet::new(); + let mut new_ca: BooleanChunked = arr + .into_iter() + .rev() + .map(|opt_v| unique.insert(opt_v)) + .collect_reversed::>() + .into_inner(); + new_ca.rename(ca.name()); + new_ca +} + +#[cfg(feature = "dtype-struct")] +fn is_last_distinct_struct(s: &Series) -> PolarsResult { + let groups = s.group_tuples(true, false)?; + // SAFETY: all groups have at least a single member + let last = unsafe { groups.take_group_lasts() }; + let mut out = MutableBitmap::with_capacity(s.len()); + out.extend_constant(s.len(), false); + + for idx in last { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(s.name(), arr)) +} + +#[cfg(feature = "group_by_list")] +fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult { + let groups = ca.group_tuples(true, false)?; + // SAFETY: all groups have at least a single member + let last = unsafe { groups.take_group_lasts() }; + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + for idx in last { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(ca.name(), arr)) +} diff --git a/crates/polars-ops/src/series/ops/is_unique.rs b/crates/polars-ops/src/series/ops/is_unique.rs index 51bac025ac2f..d321f627979f 100644 --- a/crates/polars-ops/src/series/ops/is_unique.rs +++ b/crates/polars-ops/src/series/ops/is_unique.rs @@ -1,7 +1,6 @@ -use std::hash::Hash; - use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; +use arrow::util::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; @@ -10,7 +9,7 @@ fn is_unique_ca<'a, T>(ca: &'a ChunkedArray, invert: bool) -> BooleanChunked where T: PolarsDataType, &'a 
ChunkedArray: IntoIterator, - <<&'a ChunkedArray as IntoIterator>::IntoIter as IntoIterator>::Item: Hash + Eq, + <<&'a ChunkedArray as IntoIterator>::IntoIter as IntoIterator>::Item: TotalHash + TotalEq, { let len = ca.len(); let mut idx_key = PlHashMap::new(); @@ -19,7 +18,7 @@ where // just toggle a boolean that's false if a group has multiple entries. ca.into_iter().enumerate().for_each(|(idx, key)| { idx_key - .entry(key) + .entry(TotalOrdWrap(key)) .and_modify(|v: &mut (IdxSize, bool)| v.1 = false) .or_insert((idx as IdxSize, true)); }); @@ -56,12 +55,12 @@ fn dispatcher(s: &Series, invert: bool) -> PolarsResult { is_unique_ca(ca, invert) }, Float32 => { - let ca = s.bit_repr_small(); - is_unique_ca(&ca, invert) + let ca = s.f32().unwrap(); + is_unique_ca(ca, invert) }, Float64 => { - let ca = s.bit_repr_large(); - is_unique_ca(&ca, invert) + let ca = s.f64().unwrap(); + is_unique_ca(ca, invert) }, #[cfg(feature = "dtype-struct")] Struct(_) => { diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 6abde8b44886..d4c10d7fd078 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -2,20 +2,28 @@ mod approx_algo; #[cfg(feature = "approx_unique")] mod approx_unique; mod arg_min_max; +mod clip; #[cfg(feature = "cutqcut")] mod cut; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] mod fused; -#[cfg(feature = "is_first")] -mod is_first; +mod horizontal; +#[cfg(feature = "convert_index")] +mod index; +#[cfg(feature = "is_first_distinct")] +mod is_first_distinct; #[cfg(feature = "is_in")] mod is_in; +#[cfg(feature = "is_last_distinct")] +mod is_last_distinct; #[cfg(feature = "is_unique")] mod is_unique; #[cfg(feature = "log")] mod log; +#[cfg(feature = "rank")] +mod rank; #[cfg(feature = "rle")] mod rle; #[cfg(feature = "rolling_window")] @@ -30,21 +38,29 @@ pub use approx_algo::*; #[cfg(feature = "approx_unique")] pub use approx_unique::*; pub use arg_min_max::ArgAgg; +pub use clip::*; #[cfg(feature = "cutqcut")] pub use cut::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] pub use fused::*; -#[cfg(feature = "is_first")] -pub use is_first::*; +pub use horizontal::*; +#[cfg(feature = "convert_index")] +pub use index::*; +#[cfg(feature = "is_first_distinct")] +pub use is_first_distinct::*; #[cfg(feature = "is_in")] pub use is_in::*; +#[cfg(feature = "is_last_distinct")] +pub use is_last_distinct::*; #[cfg(feature = "is_unique")] pub use is_unique::*; #[cfg(feature = "log")] pub use log::*; use polars_core::prelude::*; +#[cfg(feature = "rank")] +pub use rank::*; #[cfg(feature = "rle")] pub use rle::*; #[cfg(feature = "rolling_window")] diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs new file mode 100644 index 000000000000..41f9b4ca8eb9 --- /dev/null +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -0,0 +1,339 @@ +use arrow::array::BooleanArray; +use arrow::compute::concatenate::concatenate_validities; +use polars_core::prelude::*; +#[cfg(feature = "random")] +use rand::prelude::SliceRandom; +use rand::prelude::*; +#[cfg(feature = "random")] +use rand::{rngs::SmallRng, SeedableRng}; + +use crate::prelude::SeriesSealed; + +#[derive(Copy, Clone)] +pub enum RankMethod { + Average, + Min, + Max, + Dense, + Ordinal, + #[cfg(feature = "random")] + Random, +} + +// We might want to add a `nulls_last` or `null_behavior` field. 
+#[derive(Copy, Clone)] +pub struct RankOptions { + pub method: RankMethod, + pub descending: bool, +} + +impl Default for RankOptions { + fn default() -> Self { + Self { + method: RankMethod::Dense, + descending: false, + } + } +} + +#[cfg(feature = "random")] +fn get_random_seed() -> u64 { + let mut rng = SmallRng::from_entropy(); + + rng.next_u64() +} + +unsafe fn rank_impl(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) { + let mut ties_indices = Vec::with_capacity(128); + let mut idx_it = idxs.downcast_iter().flat_map(|arr| arr.values_iter()); + let Some(first_idx) = idx_it.next() else { + return; + }; + ties_indices.push(*first_idx); + + for (eq_idx, idx) in idx_it.enumerate() { + if neq.value_unchecked(eq_idx) { + flush_ties(&mut ties_indices); + ties_indices.clear() + } + + ties_indices.push(*idx); + } + flush_ties(&mut ties_indices); +} + +fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> Series { + let len = s.len(); + let null_count = s.null_count(); + match len { + 1 => { + return match method { + Average => Series::new(s.name(), &[1.0f64]), + _ => Series::new(s.name(), &[1 as IdxSize]), + }; + }, + 0 => { + return match method { + Average => Float64Chunked::from_slice(s.name(), &[]).into_series(), + _ => IdxCa::from_slice(s.name(), &[]).into_series(), + }; + }, + _ => {}, + } + + if null_count == len { + return match method { + Average => Float64Chunked::full_null(s.name(), len).into_series(), + _ => IdxCa::full_null(s.name(), len).into_series(), + }; + } + + let sort_idx_ca = s + .arg_sort(SortOptions { + descending, + nulls_last: true, + ..Default::default() + }) + .slice(0, len - null_count); + + let chunk_refs: Vec<_> = s.chunks().iter().map(|c| &**c).collect(); + let validity = concatenate_validities(&chunk_refs); + + use RankMethod::*; + if let Ordinal = method { + let mut out = vec![0 as IdxSize; s.len()]; + let mut rank = 0; + for arr in sort_idx_ca.downcast_iter() { + for i in arr.values_iter() { + out[*i as usize] = rank + 1; + rank += 1; + } + } + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + } else { + let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) }; + let not_consecutive_same = sorted_values + .slice(1, sorted_values.len() - 1) + .not_equal(&sorted_values.slice(0, sorted_values.len() - 1)) + .unwrap() + .rechunk(); + let neq = not_consecutive_same.downcast_iter().next().unwrap(); + + let mut rank = 1; + match method { + #[cfg(feature = "random")] + Random => unsafe { + let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed)); + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + ties.shuffle(&mut rng); + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank; + rank += 1; + } + }); + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + }, + Average => unsafe { + let mut out = vec![0.0; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + let first = rank; + rank += ties.len() as IdxSize; + let last = rank - 1; + let avg = 0.5 * (first as f64 + last as f64); + for i in ties { + *out.get_unchecked_mut(*i as usize) = avg; + } + }); + Float64Chunked::new_from_owned_with_null_bitmap(s.name(), out, validity) + .into_series() + }, + Min => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + for i in ties.iter() { + *out.get_unchecked_mut(*i as usize) = rank; + } + rank += ties.len() as IdxSize; + }); + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, 
validity).into_series() + }, + Max => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + rank += ties.len() as IdxSize; + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank - 1; + } + }); + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + }, + Dense => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank; + } + rank += 1; + }); + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + }, + Ordinal => unreachable!(), + } + } +} + +pub trait SeriesRank: SeriesSealed { + fn rank(&self, options: RankOptions, seed: Option) -> Series { + rank(self.as_series(), options.method, options.descending, seed) + } +} + +impl SeriesRank for Series {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rank() -> PolarsResult<()> { + let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]); + + let out = rank(&s, RankMethod::Ordinal, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]); + + #[cfg(feature = "random")] + { + let out = rank(&s, RankMethod::Random, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out[0], 2); + assert_eq!(out[6], 1); + assert_eq!(out[1] + out[3] + out[4], 12); + assert_eq!(out[2] + out[5], 13); + assert_ne!(out[1], out[3]); + assert_ne!(out[1], out[4]); + assert_ne!(out[3], out[4]); + } + + let out = rank(&s, RankMethod::Dense, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]); + + let out = rank(&s, RankMethod::Max, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]); + + let out = rank(&s, RankMethod::Min, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]); + + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]); + + let s = Series::new( + "a", + &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)], + ); + + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_iter() + .collect::>(); + + assert_eq!( + out, + &[ + Some(2.0f64), + Some(3.5), + Some(5.0), + Some(3.5), + None, + None, + Some(1.0) + ] + ); + let s = Series::new( + "a", + &[ + Some(5), + Some(6), + Some(4), + None, + Some(78), + Some(4), + Some(2), + Some(8), + ], + ); + let out = rank(&s, RankMethod::Max, false, None) + .idx()? + .into_iter() + .collect::>(); + assert_eq!( + out, + &[ + Some(4), + Some(5), + Some(3), + None, + Some(7), + Some(3), + Some(1), + Some(6) + ] + ); + + Ok(()) + } + + #[test] + fn test_rank_all_null() -> PolarsResult<()> { + let s = UInt32Chunked::new("", &[None, None, None]).into_series(); + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_iter() + .collect::>(); + assert_eq!(out, &[None, None, None]); + let out = rank(&s, RankMethod::Dense, false, None) + .idx()? 
+ .into_iter() + .collect::>(); + assert_eq!(out, &[None, None, None]); + Ok(()) + } + + #[test] + fn test_rank_empty() { + let s = UInt32Chunked::from_slice("", &[]).into_series(); + let out = rank(&s, RankMethod::Average, false, None); + assert_eq!(out.dtype(), &DataType::Float64); + let out = rank(&s, RankMethod::Max, false, None); + assert_eq!(out.dtype(), &IDX_DTYPE); + } + + #[test] + fn test_rank_reverse() -> PolarsResult<()> { + let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]); + let out = rank(&s, RankMethod::Dense, true, None) + .idx()? + .into_iter() + .collect::>(); + assert_eq!(out, &[None, Some(2 as IdxSize), Some(2), Some(1), None]); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/series/ops/rle.rs b/crates/polars-ops/src/series/ops/rle.rs index 15605a901772..31417de8fd0d 100644 --- a/crates/polars-ops/src/series/ops/rle.rs +++ b/crates/polars-ops/src/series/ops/rle.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; +/// Get the lengths of runs of identical values. pub fn rle(s: &Series) -> PolarsResult { let (s1, s2) = (s.slice(0, s.len() - 1), s.slice(1, s.len())); let s_neq = s1.not_equal_missing(&s2)?; @@ -22,6 +23,7 @@ pub fn rle(s: &Series) -> PolarsResult { Ok(StructChunked::new("rle", &outvals)?.into_series()) } +/// Similar to `rle`, but maps values to run IDs. pub fn rle_id(s: &Series) -> PolarsResult { if s.len() == 0 { return Ok(Series::new_empty("id", &DataType::UInt32)); diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml index 592e713722e7..e1c3e8899744 100644 --- a/crates/polars-pipe/Cargo.toml +++ b/crates/polars-pipe/Cargo.toml @@ -9,16 +9,17 @@ repository = { workspace = true } description = "Lazy query engine for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", default-features = false } -polars-ops = { version = "0.32.0", path = "../polars-ops", features = ["search_sorted"] } -polars-plan = { version = "0.32.0", path = "../polars-plan", default-features = false, features = ["compile"] } -polars-row = { version = "0.32.0", path = "../polars-row" } -polars-utils = { version = "0.32.0", path = "../polars-utils", features = ["sysinfo"] } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random", "rows", "chunked_ids"] } +polars-io = { workspace = true, features = ["ipc"] } +polars-ops = { workspace = true, features = ["search_sorted"] } +polars-plan = { workspace = true } +polars-row = { workspace = true } +polars-utils = { workspace = true, features = ["sysinfo"] } +tokio = { workspace = true, optional = true } -crossbeam-channel = { version = "0.5", optional = true } -crossbeam-queue = { version = "0.3", optional = true } +crossbeam-channel = { workspace = true } +crossbeam-queue = { version = "0.3" } enum_dispatch = { version = "0.3" } hashbrown = { workspace = true } num-traits = { workspace = true } @@ -29,13 +30,13 @@ smartstring = { workspace = true } version_check = { workspace = true } [features] -compile = ["crossbeam-channel", "crossbeam-queue", "polars-io/ipc"] csv = ["polars-plan/csv", "polars-io/csv"] -parquet = ["polars-plan/parquet", "polars-io/parquet"] +cloud = ["async", "polars-io/cloud", "polars-plan/cloud", "tokio"] +parquet = ["polars-plan/parquet", "polars-io/parquet", 
"polars-io/async"] ipc = ["polars-plan/ipc", "polars-io/ipc"] async = ["polars-plan/async", "polars-io/async"] nightly = ["polars-core/nightly", "polars-utils/nightly", "hashbrown/nightly"] -cross_join = ["polars-core/cross_join"] +cross_join = ["polars-ops/cross_join"] dtype-u8 = ["polars-core/dtype-u8"] dtype-u16 = ["polars-core/dtype-u16"] dtype-i8 = ["polars-core/dtype-i8"] @@ -44,4 +45,4 @@ dtype-decimal = ["polars-core/dtype-decimal"] dtype-array = ["polars-core/dtype-array"] dtype-categorical = ["polars-core/dtype-categorical"] trigger_ooc = [] -test = ["compile", "polars-core/chunked_ids"] +test = ["polars-core/chunked_ids"] diff --git a/crates/polars-pipe/README.md b/crates/polars-pipe/README.md index 9578d9703c10..1186ce9a898a 100644 --- a/crates/polars-pipe/README.md +++ b/crates/polars-pipe/README.md @@ -1,5 +1,5 @@ -# Polars Pipe +# polars-pipe -`polars-pipe` is a sub-crate that provides OOC (out of core) algorithms to the polars physical plans. +`polars-pipe` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, introducing OOC (out of core) algorithms to polars physical plans. -Not intended for external usage. +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs index d8acca09b92a..1501da49d5fa 100644 --- a/crates/polars-pipe/src/executors/operators/projection.rs +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -3,18 +3,19 @@ use std::sync::Arc; use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_core::schema::SchemaRef; +use smartstring::alias::String as SmartString; use crate::expressions::PhysicalPipedExpr; use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; #[derive(Clone)] pub(crate) struct FastProjectionOperator { - columns: Arc<[Arc]>, + columns: Arc<[SmartString]>, input_schema: SchemaRef, } impl FastProjectionOperator { - pub(crate) fn new(columns: Arc<[Arc]>, input_schema: SchemaRef) -> Self { + pub(crate) fn new(columns: Arc<[SmartString]>, input_schema: SchemaRef) -> Self { Self { columns, input_schema, diff --git a/crates/polars-pipe/src/executors/sinks/file_sink.rs b/crates/polars-pipe/src/executors/sinks/file_sink.rs index 3845b03e590f..67dea31c355b 100644 --- a/crates/polars-pipe/src/executors/sinks/file_sink.rs +++ b/crates/polars-pipe/src/executors/sinks/file_sink.rs @@ -4,12 +4,13 @@ use std::thread::JoinHandle; use crossbeam_channel::{bounded, Receiver, Sender}; use polars_core::prelude::*; +#[cfg(feature = "csv")] use polars_io::csv::CsvWriter; #[cfg(feature = "parquet")] use polars_io::parquet::ParquetWriter; #[cfg(feature = "ipc")] use polars_io::prelude::IpcWriter; -#[cfg(feature = "ipc")] +#[cfg(any(feature = "ipc", feature = "csv"))] use polars_io::SerWriter; use polars_plan::prelude::*; @@ -23,7 +24,7 @@ trait SinkWriter { } #[cfg(feature = "parquet")] -impl SinkWriter for polars_io::parquet::BatchedWriter { +impl SinkWriter for polars_io::parquet::BatchedWriter { fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { self.write_batch(df) } @@ -35,7 +36,7 @@ impl SinkWriter for polars_io::parquet::BatchedWriter { } #[cfg(feature = "ipc")] -impl SinkWriter for polars_io::ipc::BatchedWriter { +impl SinkWriter for polars_io::ipc::BatchedWriter { fn _write_batch(&mut self, df: &DataFrame) -> 
PolarsResult<()> { self.write_batch(df) } @@ -78,7 +79,7 @@ impl ParquetSink { .set_parallel(false) .batched(schema)?; - let writer = Box::new(writer) as Box; + let writer = Box::new(writer) as Box; let morsels_per_sink = morsels_per_sink(); let backpressure = morsels_per_sink * 2; @@ -98,6 +99,49 @@ impl ParquetSink { } } +#[cfg(all(feature = "parquet", feature = "cloud"))] +pub struct ParquetCloudSink {} +#[cfg(all(feature = "parquet", feature = "cloud"))] +impl ParquetCloudSink { + #[allow(clippy::new_ret_no_self)] + #[tokio::main(flavor = "current_thread")] + pub async fn new( + uri: &str, + cloud_options: Option<&polars_io::cloud::CloudOptions>, + parquet_options: ParquetWriteOptions, + schema: &Schema, + ) -> PolarsResult { + let cloud_writer = polars_io::cloud::CloudWriter::new(uri, cloud_options).await?; + let writer = ParquetWriter::new(cloud_writer) + .with_compression(parquet_options.compression) + .with_data_pagesize_limit(parquet_options.data_pagesize_limit) + .with_statistics(parquet_options.statistics) + .with_row_group_size(parquet_options.row_group_size) + // This is important! Otherwise we will deadlock + // See: #7074 + .set_parallel(false) + .batched(schema)?; + + let writer = Box::new(writer) as Box; + + let morsels_per_sink = morsels_per_sink(); + let backpressure = morsels_per_sink * 2; + let (sender, receiver) = bounded(backpressure); + + let io_thread_handle = Arc::new(Some(init_writer_thread( + receiver, + writer, + parquet_options.maintain_order, + morsels_per_sink, + ))); + + Ok(FilesSink { + sender, + io_thread_handle, + }) + } +} + #[cfg(feature = "ipc")] pub struct IpcSink {} #[cfg(feature = "ipc")] @@ -109,7 +153,7 @@ impl IpcSink { .with_compression(options.compression) .batched(schema)?; - let writer = Box::new(writer) as Box; + let writer = Box::new(writer) as Box; let morsels_per_sink = morsels_per_sink(); let backpressure = morsels_per_sink * 2; @@ -138,9 +182,9 @@ impl CsvSink { let file = std::fs::File::create(path)?; let writer = CsvWriter::new(file) .has_header(options.has_header) - .with_delimiter(options.serialize_options.delimiter) + .with_separator(options.serialize_options.separator) .with_line_terminator(options.serialize_options.line_terminator) - .with_quoting_char(options.serialize_options.quote) + .with_quote_char(options.serialize_options.quote_char) .with_batch_size(options.batch_size) .with_datetime_format(options.serialize_options.datetime_format) .with_date_format(options.serialize_options.date_format) @@ -170,10 +214,10 @@ impl CsvSink { } } -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] fn init_writer_thread( receiver: Receiver>, - mut writer: Box, + mut writer: Box, maintain_order: bool, // this is used to determine when a batch of chunks should be written to disk // all chunks per push should be collected to determine in which order they should diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs index 9e3276b9e0a6..247df7658e46 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -107,9 +107,9 @@ pub fn can_convert_to_hash_agg( } /// # Returns: -/// - input_dtype: dtype that goes into the agg expression -/// - physical expr: physical expression that produces the input of the aggregation -/// - aggregation function: the aggregation function +/// - 
input_dtype: dtype that goes into the agg expression +/// - physical expr: physical expression that produces the input of the aggregation +/// - aggregation function: the aggregation function pub(crate) fn convert_to_hash_agg( node: Node, expr_arena: &Arena, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs index 7898ee36d9cb..b5475767d875 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs @@ -42,7 +42,7 @@ impl MeanAgg { impl AggregateFn for MeanAgg where - K::POLARSTYPE: PolarsNumericType, + K::PolarsType: PolarsNumericType, K: NumericNative + Add, ::Simd: Add::Simd> + Sum, { @@ -107,7 +107,7 @@ where let arr = values.chunks().get_unchecked(0); arr.sliced_unchecked(offset as usize, length as usize) }; - let dtype = K::POLARSTYPE::get_dtype().to_arrow(); + let dtype = K::PolarsType::get_dtype().to_arrow(); let arr = polars_arrow::compute::cast::cast(arr.as_ref(), &dtype).unwrap(); let arr = unsafe { arr.as_any() diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs index 486ea1631424..b44faa172ef1 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs @@ -89,7 +89,7 @@ where length: IdxSize, values: &Series, ) { - let ca: &ChunkedArray = values.as_ref().as_ref(); + let ca: &ChunkedArray = values.as_ref().as_ref(); let arr = ca.downcast_iter().next().unwrap(); let arr = unsafe { arr.slice_typed_unchecked(offset as usize, length as usize) }; // convince the compiler that K::POLARSTYPE::Native == K diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs index 2133910065da..bb79272f21b2 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs @@ -34,7 +34,7 @@ impl + NumCast> SumAgg { impl AggregateFn for SumAgg where - K::POLARSTYPE: PolarsNumericType, + K::PolarsType: PolarsNumericType, K: NumericNative + Add, ::Simd: Add::Simd> + Sum, { @@ -89,7 +89,7 @@ where let arr = values.chunks().get_unchecked(0); arr.sliced_unchecked(offset as usize, length as usize) }; - let dtype = K::POLARSTYPE::get_dtype().to_arrow(); + let dtype = K::PolarsType::get_dtype().to_arrow(); let arr = polars_arrow::compute::cast::cast(arr.as_ref(), &dtype).unwrap(); let arr = unsafe { arr.as_any() diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index bfd43fb33436..708b9f5dd71b 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -6,6 +6,7 @@ use std::vec; use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; +use polars_ops::prelude::CrossJoin as CrossJoinTrait; use smartstring::alias::String as SmartString; use crate::operators::{ diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index e848f2ff5367..3963a2efcacc 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -4,14 +4,15 @@ use 
std::sync::Arc; use hashbrown::hash_map::RawEntryMut; use polars_arrow::export::arrow::array::BinaryArray; +use polars_core::datatypes::ChunkId; use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; -use polars_core::frame::hash_join::ChunkId; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; use polars_utils::hash_to_partition; use polars_utils::slice::GetSaferUnchecked; +use super::*; use crate::executors::sinks::joins::inner_left::GenericJoinProbe; use crate::executors::sinks::utils::{hash_rows, load_vec}; use crate::executors::sinks::HASHMAP_INIT_SIZE; diff --git a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs index 7433dfa848cc..eaeeb80629b4 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs @@ -2,11 +2,13 @@ use std::borrow::Cow; use std::sync::Arc; use polars_arrow::export::arrow::array::BinaryArray; +use polars_core::datatypes::ChunkId; use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; -use polars_core::frame::hash_join::{ChunkId, _finish_join}; use polars_core::prelude::*; use polars_core::series::IsSorted; +use polars_ops::frame::join::_finish_join; +use polars_ops::prelude::JoinType; use polars_row::RowsEncoded; use polars_utils::hash_to_partition; use polars_utils::slice::GetSaferUnchecked; diff --git a/crates/polars-pipe/src/executors/sinks/joins/mod.rs b/crates/polars-pipe/src/executors/sinks/joins/mod.rs index 117cdb726872..f906f5f1d190 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/mod.rs @@ -6,3 +6,4 @@ mod inner_left; #[cfg(feature = "cross_join")] pub(crate) use cross::*; pub(crate) use generic_build::GenericBuild; +use polars_ops::prelude::JoinType; diff --git a/crates/polars-pipe/src/executors/sinks/mod.rs b/crates/polars-pipe/src/executors/sinks/mod.rs index 328ab178a9e6..8c9b46366da7 100644 --- a/crates/polars-pipe/src/executors/sinks/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/mod.rs @@ -1,4 +1,4 @@ -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] mod file_sink; pub(crate) mod group_by; mod io; @@ -10,7 +10,7 @@ mod slice; mod sort; mod utils; -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] pub(crate) use file_sink::*; pub(crate) use joins::*; pub(crate) use ordered::*; diff --git a/crates/polars-pipe/src/executors/sinks/ordered.rs b/crates/polars-pipe/src/executors/sinks/ordered.rs index 4cea0a032c3d..156dad5f9a9e 100644 --- a/crates/polars-pipe/src/executors/sinks/ordered.rs +++ b/crates/polars-pipe/src/executors/sinks/ordered.rs @@ -1,6 +1,8 @@ use std::any::Any; use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; use crate::operators::{ chunks_to_df_unchecked, DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, @@ -10,11 +12,15 @@ use crate::operators::{ #[derive(Clone)] pub struct OrderedSink { chunks: Vec, + schema: SchemaRef, } impl OrderedSink { - pub fn new() -> Self { - OrderedSink { chunks: vec![] } + pub fn new(schema: SchemaRef) -> Self { + OrderedSink { + chunks: vec![], + schema, + } } fn sort(&mut self) { @@ -41,6 +47,11 @@ impl Sink for OrderedSink { Box::new(self.clone()) } fn 
finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + if self.chunks.is_empty() { + return Ok(FinalizedSink::Finished(DataFrame::from( + self.schema.as_ref(), + ))); + } self.sort(); let chunks = std::mem::take(&mut self.chunks); Ok(FinalizedSink::Finished(chunks_to_df_unchecked(chunks))) diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 752e507517a8..671632955a03 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -89,9 +89,12 @@ impl SortSink { self.dump(true)?; } }; - self.current_chunks_size += chunk_bytes; - self.current_chunk_rows += chunk.data.height(); - self.chunks.push(chunk.data); + // don't add empty dataframes + if chunk.data.height() > 0 || self.chunks.is_empty() { + self.current_chunks_size += chunk_bytes; + self.current_chunk_rows += chunk.data.height(); + self.chunks.push(chunk.data); + } Ok(()) } @@ -211,11 +214,13 @@ impl Sink for SortSink { } pub(super) fn sort_accumulated( - df: DataFrame, + mut df: DataFrame, sort_idx: usize, descending: bool, slice: Option<(i64, usize)>, ) -> PolarsResult { + // This is needed because we can have empty blocks and we require chunks to have single chunks. + df.as_single_chunk_par(); let sort_column = df.get_columns()[sort_idx].clone(); df.sort_impl( vec![sort_column], diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 32f98b0194a7..34431ade9c17 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -68,7 +68,7 @@ fn finalize_dataframe( // we decode the row-encoded binary column // this will be decoded into multiple columns - // this are the columns we sorted by + // these are the columns we sorted by // those need to be inserted at the `sort_idx` position // in the `DataFrame`. 
if can_decode { diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index 1053ff1d236c..b1297c39e07c 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -63,7 +63,7 @@ impl CsvSource { .unwrap() .has_header(options.has_header) .with_dtypes(Some(self.schema.clone())) - .with_delimiter(options.delimiter) + .with_separator(options.separator) .with_ignore_errors(options.ignore_errors) .with_skip_rows(options.skip_rows) .with_n_rows(n_rows) diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index 4cb11f6b1688..b8afaf1b8208 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -1,13 +1,18 @@ +use std::ops::Deref; use std::path::PathBuf; +use std::sync::Arc; -use polars_core::cloud::CloudOptions; use polars_core::error::PolarsResult; -use polars_core::schema::*; +use polars_core::utils::arrow::io::parquet::read::FileMetaData; use polars_core::POOL; +use polars_io::cloud::CloudOptions; use polars_io::parquet::{BatchedParquetReader, ParquetReader}; +use polars_io::pl_async::get_runtime; +use polars_io::prelude::materialize_projection; #[cfg(feature = "async")] use polars_io::prelude::ParquetAsyncReader; use polars_io::{is_cloud_url, SerReader}; +use polars_plan::logical_plan::FileInfo; use polars_plan::prelude::{FileScanOptions, ParquetOptions}; use polars_utils::IdxSize; @@ -23,7 +28,8 @@ pub struct ParquetSource { file_options: Option, #[allow(dead_code)] cloud_options: Option, - schema: Option, + metadata: Option>, + file_info: FileInfo, verbose: bool, } @@ -35,13 +41,23 @@ impl ParquetSource { let path = self.path.take().unwrap(); let options = self.options.take().unwrap(); let file_options = self.file_options.take().unwrap(); - let schema = self.schema.take().unwrap(); - let projection: Option> = file_options.with_columns.map(|with_columns| { - with_columns - .iter() - .map(|name| schema.index_of(name).unwrap()) - .collect() - }); + let schema = self.file_info.schema.clone(); + + let hive_partitions = self + .file_info + .hive_parts + .as_ref() + .map(|hive| hive.materialize_partition_columns()); + + let projection = materialize_projection( + file_options + .with_columns + .as_deref() + .map(|cols| cols.deref()), + &schema, + hive_partitions.as_deref(), + false, + ); let n_cols = projection.as_ref().map(|v| v.len()).unwrap_or(schema.len()); let chunk_size = determine_chunk_size(n_cols, self.n_threads)?; @@ -60,12 +76,27 @@ impl ParquetSource { #[cfg(feature = "async")] { let uri = path.to_string_lossy(); - ParquetAsyncReader::from_uri(&uri, self.cloud_options.as_ref())? + polars_io::pl_async::get_runtime().block_on(async { + ParquetAsyncReader::from_uri( + &uri, + self.cloud_options.as_ref(), + Some(self.file_info.schema.clone()), + self.metadata.clone(), + ) + .await? .with_n_rows(file_options.n_rows) .with_row_count(file_options.row_count) .with_projection(projection) .use_statistics(options.use_statistics) - .batched(chunk_size)? + .with_hive_partition_columns( + self.file_info + .hive_parts + .as_ref() + .map(|hive| hive.materialize_partition_columns()), + ) + .batched(chunk_size) + .await + })? 
} } else { let file = std::fs::File::open(path).unwrap(); @@ -75,6 +106,12 @@ impl ParquetSource { .with_row_count(file_options.row_count) .with_projection(projection) .use_statistics(options.use_statistics) + .with_hive_partition_columns( + self.file_info + .hive_parts + .as_ref() + .map(|hive| hive.materialize_partition_columns()), + ) .batched(chunk_size)? }; self.batched_reader = Some(batched_reader); @@ -86,8 +123,9 @@ impl ParquetSource { path: PathBuf, options: ParquetOptions, cloud_options: Option, + metadata: Option>, file_options: FileScanOptions, - schema: SchemaRef, + file_info: FileInfo, verbose: bool, ) -> PolarsResult { let n_threads = POOL.current_num_threads(); @@ -100,7 +138,8 @@ impl ParquetSource { file_options: Some(file_options), path: Some(path), cloud_options, - schema: Some(schema), + metadata, + file_info, verbose, }) } @@ -111,11 +150,12 @@ impl Source for ParquetSource { if self.batched_reader.is_none() { self.init_reader()?; } - let batches = self - .batched_reader - .as_mut() - .unwrap() - .next_batches(self.n_threads)?; + let batches = get_runtime().block_on( + self.batched_reader + .as_mut() + .unwrap() + .next_batches(self.n_threads), + )?; Ok(match batches { None => SourceResult::Finished, Some(batches) => SourceResult::GotMoreData( diff --git a/crates/polars-pipe/src/expressions.rs b/crates/polars-pipe/src/expressions.rs index 5c4a78cc896f..af7bbc9dab28 100644 --- a/crates/polars-pipe/src/expressions.rs +++ b/crates/polars-pipe/src/expressions.rs @@ -6,7 +6,7 @@ use polars_plan::dsl::Expr; use crate::operators::DataChunk; pub trait PhysicalPipedExpr: Send + Sync { - /// Take a `DataFrame` and produces a boolean `Series` that serves + /// Take a [`DataFrame`] and produces a boolean [`Series`] that serves /// as a predicate mask fn evaluate(&self, chunk: &DataChunk, lazy_state: &dyn Any) -> PolarsResult; diff --git a/crates/polars-pipe/src/lib.rs b/crates/polars-pipe/src/lib.rs index 4a63657adb8a..b2724e9a8981 100644 --- a/crates/polars-pipe/src/lib.rs +++ b/crates/polars-pipe/src/lib.rs @@ -1,13 +1,8 @@ extern crate core; -#[cfg(feature = "compile")] mod executors; -#[cfg(feature = "compile")] pub mod expressions; -#[cfg(feature = "compile")] pub mod operators; -#[cfg(feature = "compile")] pub mod pipeline; -#[cfg(feature = "compile")] pub use operators::SExecutionContext; diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 1f1d6ff4b5e2..0c2d48ffa89e 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use hashbrown::hash_map::Entry; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; +use polars_ops::prelude::JoinType; use polars_plan::prelude::*; use crate::executors::operators::HstackOperator; @@ -13,7 +14,7 @@ use crate::executors::sinks::group_by::GenericGroupby2; use crate::executors::sinks::*; use crate::executors::{operators, sources}; use crate::expressions::PhysicalPipedExpr; -use crate::operators::{Operator, Sink, Source}; +use crate::operators::{Operator, Sink as SinkTrait, Source}; use crate::pipeline::PipeLine; fn exprs_to_physical( @@ -99,13 +100,15 @@ where FileScan::Parquet { options: parquet_options, cloud_options, + metadata, } => { let src = sources::ParquetSource::new( path, parquet_options, cloud_options, + metadata, file_options, - file_info.schema, + file_info, verbose, )?; Ok(Box::new(src) as Box) @@ -122,31 +125,69 @@ pub fn get_sink( lp_arena: &Arena, 
expr_arena: &mut Arena, to_physical: &F, -) -> PolarsResult> +) -> PolarsResult> where F: Fn(Node, &Arena, Option<&SchemaRef>) -> PolarsResult>, { use ALogicalPlan::*; let out = match lp_arena.get(node) { - FileSink { input, payload } => { - let path = payload.path.as_ref().as_path(); + Sink { input, payload } => { let input_schema = lp_arena.get(*input).schema(lp_arena); - match &payload.file_type { - #[cfg(feature = "parquet")] - FileType::Parquet(options) => { - Box::new(ParquetSink::new(path, *options, input_schema.as_ref())?) - as Box + match payload { + SinkType::Memory => { + Box::new(OrderedSink::new(input_schema.into_owned())) as Box }, - #[cfg(feature = "ipc")] - FileType::Ipc(options) => { - Box::new(IpcSink::new(path, *options, input_schema.as_ref())?) as Box + SinkType::File { + path, file_type, .. + } => { + let path = path.as_ref().as_path(); + match &file_type { + #[cfg(feature = "parquet")] + FileType::Parquet(options) => { + Box::new(ParquetSink::new(path, *options, input_schema.as_ref())?) + as Box + }, + #[cfg(feature = "ipc")] + FileType::Ipc(options) => { + Box::new(IpcSink::new(path, *options, input_schema.as_ref())?) + as Box + }, + #[cfg(feature = "csv")] + FileType::Csv(options) => { + Box::new(CsvSink::new(path, options.clone(), input_schema.as_ref())?) + as Box + }, + #[allow(unreachable_patterns)] + _ => unreachable!(), + } }, - #[cfg(feature = "csv")] - FileType::Csv(options) => { - Box::new(CsvSink::new(path, options.clone(), input_schema.as_ref())?) - as Box + #[cfg(feature = "cloud")] + SinkType::Cloud { + uri, + file_type, + cloud_options, + } => { + let uri = uri.as_ref().as_str(); + let input_schema = lp_arena.get(*input).schema(lp_arena); + let cloud_options = &cloud_options; + match &file_type { + #[cfg(feature = "parquet")] + FileType::Parquet(parquet_options) => Box::new(ParquetCloudSink::new( + uri, + cloud_options.as_ref(), + *parquet_options, + input_schema.as_ref(), + )?) + as Box, + #[cfg(feature = "ipc")] + FileType::Ipc(_ipc_options) => { + // TODO: support Ipc as well + todo!("For now, only parquet cloud files are supported"); + }, + #[allow(unreachable_patterns)] + _ => unreachable!(), + } }, - FileType::Memory => Box::new(OrderedSink::new()) as Box, } }, Join { @@ -163,7 +204,7 @@ where match &options.args.how { #[cfg(feature = "cross_join")] JoinType::Cross => { - Box::new(CrossJoin::new(options.args.suffix().into())) as Box + Box::new(CrossJoin::new(options.args.suffix().into())) as Box }, join_type @ JoinType::Inner | join_type @ JoinType::Left => { let input_schema_left = lp_arena.get(*input_left).schema(lp_arena); @@ -195,14 +236,14 @@ where swapped, join_columns_left, join_columns_right, - )) as Box + )) as Box }, _ => unimplemented!(), } }, Slice { offset, len, .. 
} => { let slice = SliceSink::new(*offset as u64, *len as usize); - Box::new(slice) as Box + Box::new(slice) as Box }, Sort { input, @@ -218,7 +259,7 @@ where let index = input_schema.try_index_of(by_column.as_ref())?; let sort_sink = SortSink::new(index, args.clone(), input_schema); - Box::new(sort_sink) as Box + Box::new(sort_sink) as Box } else { let sort_idx = by_column .iter() @@ -229,7 +270,7 @@ where .collect::>>()?; let sort_sink = SortSinkMultiple::new(args.clone(), input_schema, sort_idx); - Box::new(sort_sink) as Box + Box::new(sort_sink) as Box } }, Distinct { input, options } => { @@ -368,7 +409,7 @@ where input_schema, output_schema.clone(), options.slice, - )) as Box + )) as Box }) }, (DataType::Utf8, 1) => Box::new(group_by::Utf8GroupbySink::new( @@ -378,7 +419,7 @@ where input_schema, output_schema.clone(), options.slice, - )) as Box, + )) as Box, _ => Box::new(GenericGroupby2::new( key_columns, aggregation_columns, @@ -493,7 +534,7 @@ where Box::new(op) as Box }, MapFunction { - function: FunctionNode::FastProjection { columns }, + function: FunctionNode::FastProjection { columns, .. }, input, } => { let input_schema = lp_arena.get(*input).schema(lp_arena); @@ -527,7 +568,7 @@ pub fn create_pipeline( expr_arena: &mut Arena, to_physical: F, verbose: bool, - sink_cache: &mut PlHashMap>, + sink_cache: &mut PlHashMap>, ) -> PolarsResult where F: Fn(Node, &Arena, Option<&SchemaRef>) -> PolarsResult>, diff --git a/crates/polars-pipe/src/pipeline/mod.rs b/crates/polars-pipe/src/pipeline/mod.rs index eced14cd6ece..f61d5e1b329e 100644 --- a/crates/polars-pipe/src/pipeline/mod.rs +++ b/crates/polars-pipe/src/pipeline/mod.rs @@ -30,6 +30,6 @@ pub(crate) fn determine_chunk_size(n_cols: usize, n_threads: usize) -> PolarsRes ) } else { let thread_factor = std::cmp::max(12 / n_threads, 1); - Ok(std::cmp::max(50_000 / n_cols * thread_factor, 1000)) + Ok(std::cmp::max(50_000 / n_cols.max(1) * thread_factor, 1000)) } } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index aefc78e45f4f..3bd23550b514 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -12,12 +12,14 @@ description = "Lazy query engine for the Polars DataFrame library" doctest = false [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = ["lazy", "csv"], default-features = false } -polars-ops = { version = "0.32.0", path = "../polars-ops", default-features = false } -polars-time = { version = "0.32.0", path = "../polars-time", optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +libloading = { version = "0.8.0", optional = true } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-ffi = { workspace = true, optional = true } +polars-io = { workspace = true, features = ["lazy"], default-features = false } +polars-ops = { workspace = true, features = ["zip_with"], default-features = false } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } @@ -26,6 +28,7 @@ chrono-tz = { workspace = true, optional = true } ciborium = { workspace = true, optional = true } futures = { workspace = true, optional = true } once_cell = { workspace = 
true } +percent-encoding = { workspace = true } pyo3 = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true, optional = true } @@ -40,9 +43,6 @@ version_check = { workspace = true } # debugging utility debugging = [] python = ["dep:pyo3", "ciborium"] -# make sure we don't compile unneeded things even though -# this dependency gets activated -compile = [] serde = [ "dep:serde", "polars-core/serde-lazy", @@ -50,10 +50,10 @@ serde = [ "polars-io/serde", "polars-ops/serde", ] -default = ["compile"] streaming = [] parquet = ["polars-core/parquet", "polars-io/parquet"] -async = ["polars-io/cloud"] +async = ["polars-io/async"] +cloud = ["async", "polars-io/cloud"] ipc = ["polars-io/ipc"] json = ["polars-io/json"] csv = ["polars-io/csv"] @@ -89,15 +89,16 @@ extract_jsonpath = ["polars-ops/extract_jsonpath"] # operations approx_unique = ["polars-ops/approx_unique"] is_in = ["polars-ops/is_in"] -repeat_by = ["polars-core/repeat_by"] +repeat_by = ["polars-ops/repeat_by"] round_series = ["polars-core/round_series"] -is_first = ["polars-core/is_first", "polars-ops/is_first"] +is_first_distinct = ["polars-core/is_first_distinct", "polars-ops/is_first_distinct"] +is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct"] is_unique = ["polars-ops/is_unique"] -cross_join = ["polars-core/cross_join"] +cross_join = ["polars-ops/cross_join"] asof_join = ["polars-core/asof_join", "polars-time", "polars-ops/asof_join"] -concat_str = ["polars-core/concat_str"] +concat_str = [] range = [] -mode = ["polars-core/mode"] +mode = ["polars-ops/mode"] cum_agg = ["polars-core/cum_agg"] interpolate = ["polars-ops/interpolate"] rolling_window = [ @@ -106,7 +107,7 @@ rolling_window = [ "polars-ops/rolling_window", "polars-time/rolling_window", ] -rank = ["polars-core/rank"] +rank = ["polars-ops/rank"] diff = ["polars-core/diff", "polars-ops/diff"] pct_change = ["polars-core/pct_change"] moment = ["polars-core/moment", "polars-ops/moment"] @@ -128,18 +129,22 @@ merge_sorted = ["polars-ops/merge_sorted"] meta = [] pivot = ["polars-core/rows", "polars-ops/pivot"] top_k = ["polars-ops/top_k"] -semi_anti_join = ["polars-core/semi_anti_join", "polars-ops/semi_anti_join"] +semi_anti_join = ["polars-ops/semi_anti_join"] cse = [] propagate_nans = ["polars-ops/propagate_nans"] coalesce = [] fused = ["polars-ops/fused"] list_sets = ["polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all"] +list_drop_nulls = ["polars-ops/list_drop_nulls"] cutqcut = ["polars-ops/cutqcut"] rle = ["polars-ops/rle"] extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"] +ffi_plugin = ["libloading", "polars-ffi"] +hive_partitions = [] +peaks = ["polars-ops/peaks"] -bigidx = ["polars-arrow/bigidx", "polars-core/bigidx", "polars-utils/bigidx"] +bigidx = ["polars-core/bigidx"] panic_on_schema = [] diff --git a/crates/polars-plan/README.md b/crates/polars-plan/README.md new file mode 100644 index 000000000000..23d78053d6da --- /dev/null +++ b/crates/polars-plan/README.md @@ -0,0 +1,5 @@ +# polars-plan + +`polars-plan` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library that provides the source code responsible for Polars logical planning. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage.
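For context on the `rank = ["polars-ops/rank"]` feature rewiring above, which points at the `rank.rs` module added to `polars-ops` earlier in this diff, the following is a minimal usage sketch rather than part of the change set. It assumes the `rank` feature is enabled and that `RankOptions`, `RankMethod`, and the `SeriesRank` trait are re-exported through `polars_ops::prelude` (that exact import path, and the `Option<u64>` type of the `seed` argument, are assumptions); the expected dense ranks mirror the test case in `rank.rs`.

```rust
// Sketch only (not part of this diff): exercising the SeriesRank trait added in
// crates/polars-ops/src/series/ops/rank.rs, with the "rank" feature enabled.
use polars_core::prelude::*;
use polars_ops::prelude::{RankMethod, RankOptions, SeriesRank}; // assumed re-export path

fn main() -> PolarsResult<()> {
    let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]);

    // Dense ranking: ties share a rank and no gaps are left after a tie group.
    let ranks = s.rank(
        RankOptions {
            method: RankMethod::Dense,
            descending: false,
        },
        None, // seed is only used by RankMethod::Random
    );

    let out: Vec<IdxSize> = ranks.idx()?.into_no_null_iter().collect();
    assert_eq!(out, &[2 as IdxSize, 3, 4, 3, 3, 4, 1]); // matches the test in rank.rs
    Ok(())
}
```

Note that `RankMethod::Average` returns a `Float64` Series instead of an `IdxSize` one, as the empty and length-1 branches in `rank.rs` show.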
diff --git a/crates/polars-plan/src/dot.rs b/crates/polars-plan/src/dot.rs index b8be9a047bec..a581eb7dafe6 100644 --- a/crates/polars-plan/src/dot.rs +++ b/crates/polars-plan/src/dot.rs @@ -131,20 +131,6 @@ impl LogicalPlan { let (mut branch, id) = id; match self { - AnonymousScan { - file_info, options, .. - } => self.write_scan( - acc_str, - prev_node, - "ANONYMOUS SCAN", - Path::new(""), - options.with_columns.as_deref().map(|cols| cols.as_slice()), - file_info.schema.len(), - &options.predicate, - branch, - id, - id_map, - ), Union { inputs, .. } => { let current_node = DotNode { branch, @@ -231,21 +217,6 @@ impl LogicalPlan { self.write_dot(acc_str, prev_node, current_node, id_map)?; input.dot(acc_str, (branch, id + 1), current_node, id_map) }, - LocalProjection { expr, input, .. } => { - let schema = input.schema().map_err(|_| { - eprintln!("could not determine schema"); - std::fmt::Error - })?; - - let fmt = format!("LOCAL π {}/{}", expr.len(), schema.len(),); - let current_node = DotNode { - branch, - id, - fmt: &fmt, - }; - self.write_dot(acc_str, prev_node, current_node, id_map)?; - input.dot(acc_str, (branch, id + 1), current_node, id_map) - }, Aggregate { input, keys, aggs, .. } => { @@ -406,11 +377,16 @@ impl LogicalPlan { self.write_dot(acc_str, prev_node, current_node, id_map)?; input.dot(acc_str, (branch, id + 1), current_node, id_map) }, - FileSink { input, .. } => { + Sink { input, payload, .. } => { let current_node = DotNode { branch, id, - fmt: "FILE_SINK", + fmt: match payload { + SinkType::Memory => "SINK (MEMORY)", + SinkType::File { .. } => "SINK (FILE)", + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => "SINK (CLOUD)", + }, }; self.write_dot(acc_str, prev_node, current_node, id_map)?; input.dot(acc_str, (branch, id + 1), current_node, id_map) diff --git a/crates/polars-plan/src/dsl/arithmetic.rs b/crates/polars-plan/src/dsl/arithmetic.rs index 5cf14575cba3..874d42d7192a 100644 --- a/crates/polars-plan/src/dsl/arithmetic.rs +++ b/crates/polars-plan/src/dsl/arithmetic.rs @@ -51,14 +51,12 @@ impl Expr { /// Raise expression to the power `exponent` pub fn pow>(self, exponent: E) -> Self { - Expr::Function { - input: vec![self, exponent.into()], - function: FunctionExpr::Pow(PowFunction::Generic), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - ..Default::default() - }, - } + self.map_many_private( + FunctionExpr::Pow(PowFunction::Generic), + &[exponent.into()], + false, + false, + ) } /// Compute the square root of the given expression @@ -116,14 +114,7 @@ impl Expr { /// Compute the inverse tangent of the given expression, with the angle expressed as the argument of a complex number #[cfg(feature = "trigonometry")] pub fn arctan2(self, x: Self) -> Self { - Expr::Function { - input: vec![self, x], - function: FunctionExpr::Atan2, - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - ..Default::default() - }, - } + self.map_many_private(FunctionExpr::Atan2, &[x], false, false) } /// Compute the hyperbolic cosine of the given expression diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index 5a979d87d723..5c8aab05c579 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -5,21 +5,32 @@ pub struct BinaryNameSpace(pub(crate) Expr); impl BinaryNameSpace { /// Check if a binary value contains a literal binary. 
- pub fn contains_literal>(self, pat: S) -> Expr { - let pat = pat.as_ref().into(); - self.0 - .map_private(BinaryFunction::Contains { pat, literal: true }.into()) + pub fn contains_literal(self, pat: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::Contains), + &[pat], + true, + true, + ) } /// Check if a binary value ends with the given sequence. - pub fn ends_with>(self, sub: S) -> Expr { - let sub = sub.as_ref().into(); - self.0.map_private(BinaryFunction::EndsWith(sub).into()) + pub fn ends_with(self, sub: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), + &[sub], + true, + true, + ) } /// Check if a binary value starts with the given sequence. - pub fn starts_with>(self, sub: S) -> Expr { - let sub = sub.as_ref().into(); - self.0.map_private(BinaryFunction::StartsWith(sub).into()) + pub fn starts_with(self, sub: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), + &[sub], + true, + true, + ) } } diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index 7e630c0fcfa2..84e9023f5994 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -1,7 +1,4 @@ -use polars_time::prelude::TemporalMethods; - use super::*; -use crate::prelude::function_expr::TemporalFunction; /// Specialized expressions for [`Series`] with dates/datetimes. pub struct DateLikeNameSpace(pub(crate) Expr); @@ -11,10 +8,10 @@ impl DateLikeNameSpace { /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). pub fn to_string(self, format: &str) -> Expr { let format = format.to_string(); - let function = move |s: Series| TemporalMethods::to_string(&s, &format).map(Some); self.0 - .map(function, GetOutput::from_type(DataType::Utf8)) - .with_fmt("to_string") + .map_private(FunctionExpr::TemporalExpr(TemporalFunction::ToString( + format, + ))) } /// Convert from Date/Time/Datetime into Utf8 with the given format. @@ -27,70 +24,26 @@ impl DateLikeNameSpace { /// Change the underlying [`TimeUnit`]. And update the data accordingly. pub fn cast_time_unit(self, tu: TimeUnit) -> Expr { - self.0.map( - move |s| match s.dtype() { - DataType::Datetime(_, _) => { - let ca = s.datetime().unwrap(); - Ok(Some(ca.cast_time_unit(tu).into_series())) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(_) => { - let ca = s.duration().unwrap(); - Ok(Some(ca.cast_time_unit(tu).into_series())) - }, - dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), - }, - GetOutput::map_dtype(move |dtype| match dtype { - DataType::Duration(_) => DataType::Duration(tu), - DataType::Datetime(_, tz) => DataType::Datetime(tu, tz.clone()), - _ => panic!("expected duration or datetime"), - }), - ) + self.0 + .map_private(FunctionExpr::TemporalExpr(TemporalFunction::CastTimeUnit( + tu, + ))) } /// Change the underlying [`TimeUnit`] of the [`Series`]. This does not modify the data. 
pub fn with_time_unit(self, tu: TimeUnit) -> Expr { - self.0.map( - move |s| match s.dtype() { - DataType::Datetime(_, _) => { - let mut ca = s.datetime().unwrap().clone(); - ca.set_time_unit(tu); - Ok(Some(ca.into_series())) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(_) => { - let mut ca = s.duration().unwrap().clone(); - ca.set_time_unit(tu); - Ok(Some(ca.into_series())) - }, - dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), - }, - GetOutput::same_type(), - ) + self.0 + .map_private(FunctionExpr::TemporalExpr(TemporalFunction::WithTimeUnit( + tu, + ))) } /// Change the underlying [`TimeZone`] of the [`Series`]. This does not modify the data. #[cfg(feature = "timezones")] pub fn convert_time_zone(self, time_zone: TimeZone) -> Expr { - let time_zone_clone = time_zone.clone(); - self.0.map( - move |s| match s.dtype() { - DataType::Datetime(_, Some(_)) => { - let mut ca = s.datetime().unwrap().clone(); - ca.set_time_zone(time_zone.clone())?; - Ok(Some(ca.into_series())) - }, - _ => polars_bail!( - ComputeError: - "cannot call `convert_time_zone` on tz-naive; set a time zone first \ - with `replace_time_zone`" - ), - }, - GetOutput::map_dtype(move |dtype| match dtype { - DataType::Datetime(tu, _) => DataType::Datetime(*tu, Some(time_zone_clone.clone())), - _ => panic!("expected datetime"), - }), - ) + self.0.map_private(FunctionExpr::TemporalExpr( + TemporalFunction::ConvertTimeZone(time_zone), + )) } /// Get the year of a Date/Datetime @@ -215,10 +168,11 @@ impl DateLikeNameSpace { .map_private(FunctionExpr::TemporalExpr(TemporalFunction::TimeStamp(tu))) } - pub fn truncate(self, options: TruncateOptions, ambiguous: Expr) -> Expr { + pub fn truncate(self, every: Expr, offset: String, ambiguous: Expr) -> Expr { self.0.map_many_private( - FunctionExpr::TemporalExpr(TemporalFunction::Truncate(options)), - &[ambiguous], + FunctionExpr::TemporalExpr(TemporalFunction::Truncate(offset)), + &[every, ambiguous], + true, false, ) } @@ -251,20 +205,23 @@ impl DateLikeNameSpace { .map_private(FunctionExpr::TemporalExpr(TemporalFunction::DSTOffset)) } - pub fn round>(self, every: S, offset: S) -> Expr { + pub fn round>(self, every: S, offset: S, ambiguous: Expr) -> Expr { let every = every.as_ref().into(); let offset = offset.as_ref().into(); - self.0 - .map_private(FunctionExpr::TemporalExpr(TemporalFunction::Round( - every, offset, - ))) + self.0.map_many_private( + FunctionExpr::TemporalExpr(TemporalFunction::Round(every, offset)), + &[ambiguous], + false, + false, + ) } /// Offset this `Date/Datetime` by a given offset [`Duration`]. /// This will take leap years/ months into account. #[cfg(feature = "date_offset")] - pub fn offset_by(self, by: Duration) -> Expr { - self.0.map_private(FunctionExpr::DateOffset(by)) + pub fn offset_by(self, by: Expr) -> Expr { + self.0 + .map_many_private(FunctionExpr::DateOffset, &[by], false, false) } #[cfg(feature = "timezones")] @@ -273,6 +230,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::ReplaceTimeZone(time_zone)), &[ambiguous], false, + false, ) } @@ -281,6 +239,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::Combine(tu)), &[time], false, + false, ) } } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 18eecfe055e5..dd7b2eeec8fa 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -120,8 +120,7 @@ pub enum Expr { /// Also has the input. i.e. 
avg("foo") function: Box, partition_by: Vec, - order_by: Option>, - options: WindowOptions, + options: WindowType, }, Wildcard, Slice { @@ -154,6 +153,7 @@ pub enum Expr { output_type: GetOutput, options: FunctionOptions, }, + SubPlan(SpecialEq>, Vec), /// Expressions in this node should only be expanding /// e.g. /// `Expr::Columns` diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 2c8731216aca..4d8c81d45e69 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -11,6 +11,10 @@ use super::*; /// A wrapper trait for any closure `Fn(Vec) -> PolarsResult` pub trait SeriesUdf: Send + Sync { + fn as_any(&self) -> &dyn std::any::Any { + unimplemented!("as_any not implemented for this 'opaque' function") + } + fn call_udf(&self, s: &mut [Series]) -> PolarsResult>; fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { @@ -162,6 +166,27 @@ impl<'a> Deserialize<'a> for SpecialEq { } } +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for SpecialEq> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + let t = LogicalPlan::deserialize(deserializer)?; + Ok(SpecialEq(Arc::new(t))) + } +} + impl SpecialEq { pub fn new(val: T) -> Self { SpecialEq(val) diff --git a/crates/polars-plan/src/dsl/function_expr/binary.rs b/crates/polars-plan/src/dsl/function_expr/binary.rs index 8ca4c7eaa256..0aa8688dde13 100644 --- a/crates/polars-plan/src/dsl/function_expr/binary.rs +++ b/crates/polars-plan/src/dsl/function_expr/binary.rs @@ -6,9 +6,9 @@ use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] pub enum BinaryFunction { - Contains { pat: Vec, literal: bool }, - StartsWith(Vec), - EndsWith(Vec), + Contains, + StartsWith, + EndsWith, } impl Display for BinaryFunction { @@ -16,29 +16,36 @@ impl Display for BinaryFunction { use BinaryFunction::*; let s = match self { Contains { .. 
} => "contains", - StartsWith(_) => "starts_with", - EndsWith(_) => "ends_with", + StartsWith => "starts_with", + EndsWith => "ends_with", }; write!(f, "bin.{s}") } } -pub(super) fn contains(s: &Series, pat: &[u8], literal: bool) -> PolarsResult { - let ca = s.binary()?; - if literal { - ca.contains_literal(pat).map(|ca| ca.into_series()) - } else { - ca.contains(pat).map(|ca| ca.into_series()) - } +pub(super) fn contains(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let lit = s[1].binary()?; + Ok(ca.contains_chunked(lit).with_name(ca.name()).into_series()) } -pub(super) fn ends_with(s: &Series, sub: &[u8]) -> PolarsResult { - let ca = s.binary()?; - Ok(ca.ends_with(sub).into_series()) +pub(super) fn ends_with(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let suffix = s[1].binary()?; + + Ok(ca + .ends_with_chunked(suffix) + .with_name(ca.name()) + .into_series()) } -pub(super) fn starts_with(s: &Series, sub: &[u8]) -> PolarsResult { - let ca = s.binary()?; - Ok(ca.starts_with(sub).into_series()) +pub(super) fn starts_with(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let prefix = s[1].binary()?; + + Ok(ca + .starts_with_chunked(prefix) + .with_name(ca.name()) + .into_series()) } impl From for FunctionExpr { diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index 2f534f95de93..ed829773ce66 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -1,10 +1,7 @@ -use std::ops::{BitAnd, BitOr, Not}; - -use polars_core::POOL; -use rayon::prelude::*; +use std::ops::Not; use super::*; -use crate::{map, wrap}; +use crate::{map, map_as_slice, wrap}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] @@ -15,15 +12,16 @@ pub enum BooleanFunction { All { ignore_nulls: bool, }, - IsNot, IsNull, IsNotNull, IsFinite, IsInfinite, IsNan, IsNotNan, - #[cfg(feature = "is_first")] - IsFirst, + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct, + #[cfg(feature = "is_last_distinct")] + IsLastDistinct, #[cfg(feature = "is_unique")] IsUnique, #[cfg(feature = "is_unique")] @@ -32,6 +30,7 @@ pub enum BooleanFunction { IsIn, AllHorizontal, AnyHorizontal, + Not, } impl BooleanFunction { @@ -51,15 +50,16 @@ impl Display for BooleanFunction { let s = match self { All { .. } => "all", Any { .. 
} => "any", - IsNot => "is_not", IsNull => "is_null", IsNotNull => "is_not_null", IsFinite => "is_finite", IsInfinite => "is_infinite", IsNan => "is_nan", IsNotNan => "is_not_nan", - #[cfg(feature = "is_first")] - IsFirst => "is_first", + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => "is_first_distinct", + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => "is_last_distinct", #[cfg(feature = "is_unique")] IsUnique => "is_unique", #[cfg(feature = "is_unique")] @@ -68,6 +68,7 @@ impl Display for BooleanFunction { IsIn => "is_in", AnyHorizontal => "any_horizontal", AllHorizontal => "all_horizontal", + Not => "not_", }; write!(f, "{s}") } @@ -79,23 +80,25 @@ impl From for SpecialEq> { match func { Any { ignore_nulls } => map!(any, ignore_nulls), All { ignore_nulls } => map!(all, ignore_nulls), - IsNot => map!(is_not), IsNull => map!(is_null), IsNotNull => map!(is_not_null), IsFinite => map!(is_finite), IsInfinite => map!(is_infinite), IsNan => map!(is_nan), IsNotNan => map!(is_not_nan), - #[cfg(feature = "is_first")] - IsFirst => map!(is_first), + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => map!(is_first_distinct), + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => map!(is_last_distinct), #[cfg(feature = "is_unique")] IsUnique => map!(is_unique), #[cfg(feature = "is_unique")] IsDuplicated => map!(is_duplicated), #[cfg(feature = "is_in")] IsIn => wrap!(is_in), - AllHorizontal => wrap!(all_horizontal), - AnyHorizontal => wrap!(any_horizontal), + AllHorizontal => map_as_slice!(all_horizontal), + AnyHorizontal => map_as_slice!(any_horizontal), + Not => map!(not_), } } } @@ -124,10 +127,6 @@ fn all(s: &Series, ignore_nulls: bool) -> PolarsResult { } } -fn is_not(s: &Series) -> PolarsResult { - Ok(s.bool()?.not().into_series()) -} - fn is_null(s: &Series) -> PolarsResult { Ok(s.is_null().into_series()) } @@ -152,9 +151,14 @@ pub(super) fn is_not_nan(s: &Series) -> PolarsResult { s.is_not_nan().map(|ca| ca.into_series()) } -#[cfg(feature = "is_first")] -fn is_first(s: &Series) -> PolarsResult { - polars_ops::prelude::is_first(s).map(|ca| ca.into_series()) +#[cfg(feature = "is_first_distinct")] +fn is_first_distinct(s: &Series) -> PolarsResult { + polars_ops::prelude::is_first_distinct(s).map(|ca| ca.into_series()) +} + +#[cfg(feature = "is_last_distinct")] +fn is_last_distinct(s: &Series) -> PolarsResult { + polars_ops::prelude::is_last_distinct(s).map(|ca| ca.into_series()) } #[cfg(feature = "is_unique")] @@ -174,36 +178,14 @@ fn is_in(s: &mut [Series]) -> PolarsResult> { polars_ops::prelude::is_in(left, other).map(|ca| Some(ca.into_series())) } -fn any_horizontal(s: &mut [Series]) -> PolarsResult> { - let mut out = POOL.install(|| { - s.par_iter() - .try_fold( - || BooleanChunked::new("", &[false]), - |acc, b| { - let b = b.cast(&DataType::Boolean)?; - let b = b.bool()?; - PolarsResult::Ok((&acc).bitor(b)) - }, - ) - .try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b))) - })?; - out.rename("any"); - Ok(Some(out.into_series())) -} - -fn all_horizontal(s: &mut [Series]) -> PolarsResult> { - let mut out = POOL.install(|| { - s.par_iter() - .try_fold( - || BooleanChunked::new("", &[true]), - |acc, b| { - let b = b.cast(&DataType::Boolean)?; - let b = b.bool()?; - PolarsResult::Ok((&acc).bitand(b)) - }, - ) - .try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b))) - })?; - out.rename("all"); - Ok(Some(out.into_series())) +fn any_horizontal(s: &[Series]) -> PolarsResult { + polars_ops::prelude::any_horizontal(s) +} + +fn 
all_horizontal(s: &[Series]) -> PolarsResult { + polars_ops::prelude::all_horizontal(s) +} + +fn not_(s: &Series) -> PolarsResult { + Ok(s.bool()?.not().into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/cat.rs b/crates/polars-plan/src/dsl/function_expr/cat.rs index 98e4410c511d..1da5edc225f9 100644 --- a/crates/polars-plan/src/dsl/function_expr/cat.rs +++ b/crates/polars-plan/src/dsl/function_expr/cat.rs @@ -25,7 +25,7 @@ impl Display for CategoricalFunction { SetOrdering { .. } => "set_ordering", GetCategories => "get_categories", }; - write!(f, "{s}") + write!(f, "cat.{s}") } } diff --git a/crates/polars-plan/src/dsl/function_expr/clip.rs b/crates/polars-plan/src/dsl/function_expr/clip.rs index 97ebaf326813..2f643857e1a2 100644 --- a/crates/polars-plan/src/dsl/function_expr/clip.rs +++ b/crates/polars-plan/src/dsl/function_expr/clip.rs @@ -1,14 +1,10 @@ use super::*; -pub(super) fn clip( - s: Series, - min: Option>, - max: Option>, -) -> PolarsResult { - match (min, max) { - (Some(min), Some(max)) => s.clip(min, max), - (Some(min), None) => s.clip_min(min), - (None, Some(max)) => s.clip_max(max), +pub(super) fn clip(s: &[Series], has_min: bool, has_max: bool) -> PolarsResult { + match (has_min, has_max) { + (true, true) => polars_ops::prelude::clip(&s[0], &s[1], &s[2]), + (true, false) => polars_ops::prelude::clip_min(&s[0], &s[1]), + (false, true) => polars_ops::prelude::clip_max(&s[0], &s[1]), _ => unreachable!(), } } diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs new file mode 100644 index 000000000000..00c180d0ba4a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -0,0 +1,6 @@ +use polars_core::prelude::*; + +#[cfg(feature = "dtype-struct")] +pub fn as_struct(s: &[Series]) -> PolarsResult { + Ok(StructChunked::new(s[0].name(), s)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/correlation.rs b/crates/polars-plan/src/dsl/function_expr/correlation.rs index 86777305d2e3..c56114487501 100644 --- a/crates/polars-plan/src/dsl/function_expr/correlation.rs +++ b/crates/polars-plan/src/dsl/function_expr/correlation.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Copy, Clone, PartialEq, Debug)] +#[derive(Copy, Clone, PartialEq, Debug, Hash)] pub enum CorrelationMethod { Pearson, #[cfg(all(feature = "rank", feature = "propagate_nans"))] @@ -41,46 +41,21 @@ fn covariance(s: &[Series]) -> PolarsResult { let b = &s[1]; let name = "cov"; - let s = match a.dtype() { - DataType::Float32 => { - let ca_a = a.f32().unwrap(); - let ca_b = b.f32().unwrap(); - Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)]) - }, - DataType::Float64 => { - let ca_a = a.f64().unwrap(); - let ca_b = b.f64().unwrap(); - Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)]) - }, - DataType::Int32 => { - let ca_a = a.i32().unwrap(); - let ca_b = b.i32().unwrap(); - Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)]) - }, - DataType::Int64 => { - let ca_a = a.i64().unwrap(); - let ca_b = b.i64().unwrap(); - Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)]) - }, - DataType::UInt32 => { - let ca_a = a.u32().unwrap(); - let ca_b = b.u32().unwrap(); - Series::new(name, &[polars_core::functions::cov_i(ca_a, ca_b)]) - }, - DataType::UInt64 => { - let ca_a = a.u64().unwrap(); - let ca_b = b.u64().unwrap(); - Series::new(name, 
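The `clip` rewrite earlier in this hunk follows the same pattern: the bounds are no longer literal payloads baked into the expression node but extra input Series, with `has_min`/`has_max` recording which bounds are present so the kernel knows how to index the slice. A minimal sketch of that dispatch, mirroring the new `clip::clip`:

    use polars_core::prelude::*;

    // Illustrative only: the bounds travel as trailing inputs, so the flags
    // decide whether s[1]/s[2] exist and which polars-ops kernel to call.
    fn clip_dispatch(s: &[Series], has_min: bool, has_max: bool) -> PolarsResult<Series> {
        match (has_min, has_max) {
            (true, true) => polars_ops::prelude::clip(&s[0], &s[1], &s[2]),
            (true, false) => polars_ops::prelude::clip_min(&s[0], &s[1]),
            (false, true) => polars_ops::prelude::clip_max(&s[0], &s[1]),
            (false, false) => unreachable!("clip requires at least one bound"),
        }
    }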
&[polars_core::functions::cov_i(ca_a, ca_b)]) - }, + use polars_core::functions::cov; + let ret = match a.dtype() { + DataType::Float32 => cov(a.f32().unwrap(), b.f32().unwrap()), + DataType::Float64 => cov(a.f64().unwrap(), b.f64().unwrap()), + DataType::Int32 => cov(a.i32().unwrap(), b.i32().unwrap()), + DataType::Int64 => cov(a.i64().unwrap(), b.i64().unwrap()), + DataType::UInt32 => cov(a.u32().unwrap(), b.u32().unwrap()), + DataType::UInt64 => cov(a.u64().unwrap(), b.u64().unwrap()), _ => { let a = a.cast(&DataType::Float64)?; let b = b.cast(&DataType::Float64)?; - let ca_a = a.f64().unwrap(); - let ca_b = b.f64().unwrap(); - Series::new(name, &[polars_core::functions::cov_f(ca_a, ca_b)]) + cov(a.f64().unwrap(), b.f64().unwrap()) }, }; - Ok(s) + Ok(Series::new(name, &[ret])) } fn pearson_corr(s: &[Series], ddof: u8) -> PolarsResult { @@ -88,67 +63,20 @@ fn pearson_corr(s: &[Series], ddof: u8) -> PolarsResult { let b = &s[1]; let name = "pearson_corr"; - let s = match a.dtype() { - DataType::Float32 => { - let ca_a = a.f32().unwrap(); - let ca_b = b.f32().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)], - ) - }, - DataType::Float64 => { - let ca_a = a.f64().unwrap(); - let ca_b = b.f64().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)], - ) - }, - DataType::Int32 => { - let ca_a = a.i32().unwrap(); - let ca_b = b.i32().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)], - ) - }, - DataType::Int64 => { - let ca_a = a.i64().unwrap(); - let ca_b = b.i64().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)], - ) - }, - DataType::UInt32 => { - let ca_a = a.u32().unwrap(); - let ca_b = b.u32().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)], - ) - }, - DataType::UInt64 => { - let ca_a = a.u64().unwrap(); - let ca_b = b.u64().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_i(ca_a, ca_b, ddof)], - ) - }, + use polars_core::functions::pearson_corr; + let ret = match a.dtype() { + DataType::Float32 => pearson_corr(a.f32().unwrap(), b.f32().unwrap(), ddof), + DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap(), ddof), + DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap(), ddof), + DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap(), ddof), + DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap(), ddof), _ => { let a = a.cast(&DataType::Float64)?; let b = b.cast(&DataType::Float64)?; - let ca_a = a.f64().unwrap(); - let ca_b = b.f64().unwrap(); - Series::new( - name, - &[polars_core::functions::pearson_corr_f(ca_a, ca_b, ddof)], - ) + pearson_corr(a.f64().unwrap(), b.f64().unwrap(), ddof) }, }; - Ok(s) + Ok(Series::new(name, &[ret])) } #[cfg(all(feature = "rank", feature = "propagate_nans"))] diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 143fc14d80bc..858d3cbf8cba 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -30,8 +30,13 @@ pub enum TemporalFunction { Millisecond, Microsecond, Nanosecond, + ToString(String), + CastTimeUnit(TimeUnit), + WithTimeUnit(TimeUnit), + #[cfg(feature = "timezones")] + ConvertTimeZone(TimeZone), TimeStamp(TimeUnit), - Truncate(TruncateOptions), + Truncate(String), #[cfg(feature = "date_offset")] 
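The covariance/correlation cleanup above collapses six near-identical per-dtype arms into calls to a single generic `cov` / `pearson_corr` helper and builds the one-row result Series once at the end. The shape of the pattern, as a sketch (integer arms elided; unsupported dtypes fall back through a Float64 cast, as in the code above):

    use polars_core::functions::cov;
    use polars_core::prelude::*;

    // Dispatch once on the physical dtype, call one generic kernel, and only
    // wrap the scalar into a Series at the very end.
    fn cov_series(a: &Series, b: &Series) -> PolarsResult<Series> {
        let ret = match a.dtype() {
            DataType::Float32 => cov(a.f32().unwrap(), b.f32().unwrap()),
            DataType::Float64 => cov(a.f64().unwrap(), b.f64().unwrap()),
            _ => {
                let a = a.cast(&DataType::Float64)?;
                let b = b.cast(&DataType::Float64)?;
                cov(a.f64().unwrap(), b.f64().unwrap())
            },
        };
        Ok(Series::new("cov", &[ret]))
    }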
MonthStart, #[cfg(feature = "date_offset")] @@ -43,26 +48,6 @@ pub enum TemporalFunction { Round(String, String), #[cfg(feature = "timezones")] ReplaceTimeZone(Option), - DateRange { - every: Duration, - closed: ClosedWindow, - time_unit: Option, - time_zone: Option, - }, - DateRanges { - every: Duration, - closed: ClosedWindow, - time_unit: Option, - time_zone: Option, - }, - TimeRange { - every: Duration, - closed: ClosedWindow, - }, - TimeRanges { - every: Duration, - closed: ClosedWindow, - }, Combine(TimeUnit), DatetimeFunction { time_unit: TimeUnit, @@ -70,6 +55,63 @@ pub enum TemporalFunction { }, } +impl TemporalFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use TemporalFunction::*; + match self { + Year | IsoYear => mapper.with_dtype(DataType::Int32), + Month | Quarter | Week | WeekDay | Day | OrdinalDay | Hour | Minute | Millisecond + | Microsecond | Nanosecond | Second => mapper.with_dtype(DataType::UInt32), + ToString(_) => mapper.with_dtype(DataType::Utf8), + WithTimeUnit(_) => mapper.with_same_dtype(), + CastTimeUnit(tu) => mapper.try_map_dtype(|dt| match dt { + DataType::Duration(_) => Ok(DataType::Duration(*tu)), + DataType::Datetime(_, tz) => Ok(DataType::Datetime(*tu, tz.clone())), + dtype => polars_bail!(ComputeError: "expected duration or datetime, got {}", dtype), + }), + #[cfg(feature = "timezones")] + ConvertTimeZone(tz) => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(tu, _) => Ok(DataType::Datetime(*tu, Some(tz.clone()))), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + }), + TimeStamp(_) => mapper.with_dtype(DataType::Int64), + IsLeapYear => mapper.with_dtype(DataType::Boolean), + Time => mapper.with_dtype(DataType::Time), + Date => mapper.with_dtype(DataType::Date), + Datetime => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(tu, _) => Ok(DataType::Datetime(*tu, None)), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + }), + Truncate(_) => mapper.with_same_dtype(), + #[cfg(feature = "date_offset")] + MonthStart => mapper.with_same_dtype(), + #[cfg(feature = "date_offset")] + MonthEnd => mapper.with_same_dtype(), + #[cfg(feature = "timezones")] + BaseUtcOffset => mapper.with_dtype(DataType::Duration(TimeUnit::Milliseconds)), + #[cfg(feature = "timezones")] + DSTOffset => mapper.with_dtype(DataType::Duration(TimeUnit::Milliseconds)), + Round(..) => mapper.with_same_dtype(), + #[cfg(feature = "timezones")] + ReplaceTimeZone(tz) => mapper.map_datetime_dtype_timezone(tz.as_ref()), + DatetimeFunction { + time_unit, + time_zone, + } => Ok(Field::new( + "datetime", + DataType::Datetime(*time_unit, time_zone.clone()), + )), + Combine(tu) => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(_, tz) => Ok(DataType::Datetime(*tu, tz.clone())), + DataType::Date => Ok(DataType::Datetime(*tu, None)), + dtype => { + polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype) + }, + }), + } + } +} + impl Display for TemporalFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use TemporalFunction::*; @@ -92,6 +134,11 @@ impl Display for TemporalFunction { Millisecond => "millisecond", Microsecond => "microsecond", Nanosecond => "nanosecond", + ToString(_) => "to_string", + #[cfg(feature = "timezones")] + ConvertTimeZone(_) => "convert_time_zone", + CastTimeUnit(_) => "cast_time_unit", + WithTimeUnit(_) => "with_time_unit", TimeStamp(tu) => return write!(f, "dt.timestamp({tu})"), Truncate(..) 
=> "truncate", #[cfg(feature = "date_offset")] @@ -105,11 +152,7 @@ impl Display for TemporalFunction { Round(..) => "round", #[cfg(feature = "timezones")] ReplaceTimeZone(_) => "replace_time_zone", - DateRange { .. } => return write!(f, "date_range"), - DateRanges { .. } => return write!(f, "date_ranges"), - TimeRange { .. } => return write!(f, "time_range"), - TimeRanges { .. } => return write!(f, "time_ranges"), - DatetimeFunction { .. } => return write!(f, "datetime"), + DatetimeFunction { .. } => return write!(f, "dt.datetime"), Combine(_) => "combine", }; write!(f, "dt.{s}") @@ -224,28 +267,75 @@ pub(super) fn nanosecond(s: &Series) -> PolarsResult { pub(super) fn timestamp(s: &Series, tu: TimeUnit) -> PolarsResult { s.timestamp(tu).map(|ca| ca.into_series()) } +pub(super) fn to_string(s: &Series, format: &str) -> PolarsResult { + TemporalMethods::to_string(s, format) +} +#[cfg(feature = "timezones")] +pub(super) fn convert_time_zone(s: &Series, time_zone: &TimeZone) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, Some(_)) => { + let mut ca = s.datetime()?.clone(); + ca.set_time_zone(time_zone.clone())?; + Ok(ca.into_series()) + }, + _ => polars_bail!( + ComputeError: + "cannot call `convert_time_zone` on tz-naive; set a time zone first \ + with `replace_time_zone`" + ), + } +} +pub(super) fn with_time_unit(s: &Series, tu: TimeUnit) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, _) => { + let mut ca = s.datetime()?.clone(); + ca.set_time_unit(tu); + Ok(ca.into_series()) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + let mut ca = s.duration()?.clone(); + ca.set_time_unit(tu); + Ok(ca.into_series()) + }, + dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), + } +} +pub(super) fn cast_time_unit(s: &Series, tu: TimeUnit) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, _) => { + let ca = s.datetime()?; + Ok(ca.cast_time_unit(tu).into_series()) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + let ca = s.duration()?; + Ok(ca.cast_time_unit(tu).into_series()) + }, + dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), + } +} -pub(super) fn truncate(s: &[Series], options: &TruncateOptions) -> PolarsResult { +pub(super) fn truncate(s: &[Series], offset: &str) -> PolarsResult { let time_series = &s[0]; - let ambiguous = &s[1].utf8().unwrap(); + let every = s[1].utf8()?; + let ambiguous = s[2].utf8()?; + let mut out = match time_series.dtype() { DataType::Datetime(_, tz) => match tz { #[cfg(feature = "timezones")] Some(tz) => time_series - .datetime() - .unwrap() - .truncate(options, tz.parse::().ok().as_ref(), ambiguous)? + .datetime()? + .truncate(tz.parse::().ok().as_ref(), every, offset, ambiguous)? .into_series(), _ => time_series - .datetime() - .unwrap() - .truncate(options, None, ambiguous)? + .datetime()? + .truncate(None, every, offset, ambiguous)? .into_series(), }, DataType::Date => time_series - .date() - .unwrap() - .truncate(options, None, ambiguous)? + .date()? + .truncate(None, every, offset, ambiguous)? 
.into_series(), dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), }; @@ -320,24 +410,32 @@ pub(super) fn dst_offset(s: &Series) -> PolarsResult { } } -pub(super) fn round(s: &Series, every: &str, offset: &str) -> PolarsResult { +pub(super) fn round(s: &[Series], every: &str, offset: &str) -> PolarsResult { let every = Duration::parse(every); let offset = Duration::parse(offset); - Ok(match s.dtype() { + + let time_series = &s[0]; + let ambiguous = s[1].utf8()?; + + Ok(match time_series.dtype() { DataType::Datetime(_, tz) => match tz { #[cfg(feature = "timezones")] - Some(tz) => s + Some(tz) => time_series .datetime() .unwrap() - .round(every, offset, tz.parse::().ok().as_ref())? + .round(every, offset, tz.parse::().ok().as_ref(), ambiguous)? .into_series(), - _ => s + _ => time_series .datetime() .unwrap() - .round(every, offset, None)? + .round(every, offset, None, ambiguous)? .into_series(), }, - DataType::Date => s.date().unwrap().round(every, offset, None)?.into_series(), + DataType::Date => time_series + .date() + .unwrap() + .round(every, offset, None, ambiguous)? + .into_series(), dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), }) } diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index d0f652c7be59..bbe11390fb01 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -40,3 +40,34 @@ pub(super) fn replace_time_zone(s: &[Series], time_zone: Option<&str>) -> Polars let s2 = &s[1].utf8().unwrap(); Ok(polars_ops::prelude::replace_time_zone(ca, time_zone, s2)?.into_series()) } + +#[cfg(feature = "dtype-struct")] +pub(super) fn value_counts(s: &Series, sort: bool, parallel: bool) -> PolarsResult { + s.value_counts(sort, parallel) + .map(|df| df.into_struct(s.name()).into_series()) +} + +#[cfg(feature = "unique_counts")] +pub(super) fn unique_counts(s: &Series) -> PolarsResult { + Ok(s.unique_counts().into_series()) +} + +pub(super) fn backward_fill(s: &Series, limit: FillNullLimit) -> PolarsResult { + s.fill_null(FillNullStrategy::Backward(limit)) +} + +pub(super) fn forward_fill(s: &Series, limit: FillNullLimit) -> PolarsResult { + s.fill_null(FillNullStrategy::Forward(limit)) +} + +pub(super) fn sum_horizontal(s: &[Series]) -> PolarsResult { + polars_ops::prelude::sum_horizontal(s) +} + +pub(super) fn max_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::max_horizontal(s) +} + +pub(super) fn min_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::min_horizontal(s) +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index a2f4dca9f007..8cf674fd2dc5 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -9,19 +9,37 @@ pub enum ListFunction { Concat, #[cfg(feature = "is_in")] Contains, + #[cfg(feature = "list_drop_nulls")] + DropNulls, Slice, + Shift, Get, #[cfg(feature = "list_take")] Take(bool), #[cfg(feature = "list_count")] - CountMatch, + CountMatches, Sum, + Length, + Max, + Min, + Mean, + ArgMin, + ArgMax, + #[cfg(feature = "diff")] + Diff { + n: i64, + null_behavior: NullBehavior, + }, + Sort(SortOptions), + Reverse, + Unique(bool), #[cfg(feature = "list_sets")] SetOperation(SetOperation), #[cfg(feature = "list_any_all")] Any, #[cfg(feature = "list_any_all")] All, + Join, } impl Display for ListFunction { @@ -32,35 +50,63 @@ impl 
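In the `dt.truncate` / `dt.round` changes earlier in this hunk, `every` (for truncate) and `ambiguous` now arrive as input columns rather than fixed options, so the slice is unpacked positionally. A small illustrative sketch of that layout (the real code then forwards the pieces to the dtype-specific truncate/round call):

    use polars_core::prelude::*;

    // Positional layout after the change (illustrative):
    //   truncate: s[0] = temporal col, s[1] = every (utf8), s[2] = ambiguous (utf8);
    //             `offset` stays a plain &str parameter.
    //   round:    s[0] = temporal col, s[1] = ambiguous (utf8);
    //             `every` and `offset` stay &str parameters.
    fn unpack_truncate_inputs(s: &[Series]) -> PolarsResult<(Series, Utf8Chunked, Utf8Chunked)> {
        Ok((s[0].clone(), s[1].utf8()?.clone(), s[2].utf8()?.clone()))
    }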
Display for ListFunction { Concat => "concat", #[cfg(feature = "is_in")] Contains => "contains", + #[cfg(feature = "list_drop_nulls")] + DropNulls => "drop_nulls", Slice => "slice", + Shift => "shift", Get => "get", #[cfg(feature = "list_take")] Take(_) => "take", #[cfg(feature = "list_count")] - CountMatch => "count", + CountMatches => "count", Sum => "sum", + Min => "min", + Max => "max", + Mean => "mean", + ArgMin => "arg_min", + ArgMax => "arg_max", + #[cfg(feature = "diff")] + Diff { .. } => "diff", + Length => "length", + Sort(_) => "sort", + Reverse => "reverse", + Unique(is_stable) => { + if *is_stable { + "unique_stable" + } else { + "unique" + } + }, #[cfg(feature = "list_sets")] - SetOperation(s) => return write!(f, "{s}"), + SetOperation(s) => return write!(f, "list.{s}"), #[cfg(feature = "list_any_all")] Any => "any", #[cfg(feature = "list_any_all")] All => "all", + Join => "join", }; - write!(f, "{name}") + write!(f, "list.{name}") } } #[cfg(feature = "is_in")] pub(super) fn contains(args: &mut [Series]) -> PolarsResult> { let list = &args[0]; - let is_in = &args[1]; + let item = &args[1]; - polars_ops::prelude::is_in(is_in, list).map(|mut ca| { + polars_ops::prelude::is_in(item, list).map(|mut ca| { ca.rename(list.name()); Some(ca.into_series()) }) } +#[cfg(feature = "list_drop_nulls")] +pub(super) fn drop_nulls(s: &Series) -> PolarsResult { + let list = s.list()?; + + Ok(list.lst_drop_nulls().into_series()) +} + fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> { polars_ensure!( slice_len == ca_len, @@ -71,6 +117,13 @@ fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsR Ok(()) } +pub(super) fn shift(s: &[Series]) -> PolarsResult { + let list = s[0].list()?; + let periods = &s[1]; + + list.lst_shift(periods).map(|ok| ok.into_series()) +} + pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { let s = &args[0]; let list_ca = s.list()?; @@ -95,14 +148,17 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { let length_ca = length_s.cast(&DataType::Int64)?; let length_ca = length_ca.i64().unwrap(); - list_ca - .amortized_iter() - .zip(length_ca) - .map(|(opt_s, opt_length)| match (opt_s, opt_length) { - (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)), - _ => None, - }) - .collect_trusted() + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + list_ca + .amortized_iter() + .zip(length_ca) + .map(|(opt_s, opt_length)| match (opt_s, opt_length) { + (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)), + _ => None, + }) + .collect_trusted() + } }, (offset_len, 1) => { check_slice_arg_shape(offset_len, list_ca.len(), "offset")?; @@ -113,14 +169,17 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { .unwrap_or(usize::MAX); let offset_ca = offset_s.cast(&DataType::Int64)?; let offset_ca = offset_ca.i64().unwrap(); - list_ca - .amortized_iter() - .zip(offset_ca) - .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) { - (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)), - _ => None, - }) - .collect_trusted() + // SAFETY: unstable series never lives longer than the iterator. 
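The `unsafe` blocks introduced around `amortized_iter` in `list::slice` make an existing invariant explicit: the iterator hands out a reused, amortized inner Series, so the borrowed value must not outlive the iteration (hence the SAFETY comments). A minimal sketch of the same pattern, using only the calls already visible in this hunk:

    use polars_core::prelude::*;

    // Collect the length of every sub-list without materialising a fresh
    // inner Series per element. SAFETY: the amortized series handed out by
    // `amortized_iter` is only used inside the closure and never escapes.
    fn sublist_lengths(list: &ListChunked) -> Vec<Option<usize>> {
        unsafe {
            list.amortized_iter()
                .map(|opt_s| opt_s.map(|s| s.as_ref().len()))
                .collect()
        }
    }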
+ unsafe { + list_ca + .amortized_iter() + .zip(offset_ca) + .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) { + (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)), + _ => None, + }) + .collect_trusted() + } }, _ => { check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?; @@ -132,19 +191,22 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { let length_ca = length_s.cast(&DataType::Int64)?; let length_ca = length_ca.i64().unwrap(); - list_ca - .amortized_iter() - .zip(offset_ca) - .zip(length_ca) - .map( - |((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) { - (Some(s), Some(offset), Some(length)) => { - Some(s.as_ref().slice(offset, length as usize)) - }, - _ => None, - }, - ) - .collect_trusted() + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + list_ca + .amortized_iter() + .zip(offset_ca) + .zip(length_ca) + .map(|((opt_s, opt_offset), opt_length)| { + match (opt_s, opt_offset, opt_length) { + (Some(s), Some(offset), Some(length)) => { + Some(s.as_ref().slice(offset, length as usize)) + }, + _ => None, + } + }) + .collect_trusted() + } }, }; out.rename(s.name()); @@ -210,7 +272,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { }) .collect::(); let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); - unsafe { s.take_unchecked(&take_by) }.map(Some) + unsafe { Ok(Some(s.take_unchecked(&take_by))) } }, len => polars_bail!( ComputeError: @@ -238,7 +300,7 @@ pub(super) fn take(args: &[Series], null_on_oob: bool) -> PolarsResult { } #[cfg(feature = "list_count")] -pub(super) fn count_match(args: &[Series]) -> PolarsResult { +pub(super) fn count_matches(args: &[Series]) -> PolarsResult { let s = &args[0]; let element = &args[1]; polars_ensure!( @@ -247,13 +309,58 @@ pub(super) fn count_match(args: &[Series]) -> PolarsResult { element.len() ); let ca = s.list()?; - list_count_match(ca, element.get(0).unwrap()) + list_count_matches(ca, element.get(0).unwrap()) } pub(super) fn sum(s: &Series) -> PolarsResult { Ok(s.list()?.lst_sum()) } +pub(super) fn length(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_lengths().into_series()) +} + +pub(super) fn max(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_max()) +} + +pub(super) fn min(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_min()) +} + +pub(super) fn mean(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_mean()) +} + +pub(super) fn arg_min(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_arg_min().into_series()) +} + +pub(super) fn arg_max(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_arg_max().into_series()) +} + +#[cfg(feature = "diff")] +pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult { + Ok(s.list()?.lst_diff(n, null_behavior)?.into_series()) +} + +pub(super) fn sort(s: &Series, options: SortOptions) -> PolarsResult { + Ok(s.list()?.lst_sort(options).into_series()) +} + +pub(super) fn reverse(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_reverse().into_series()) +} + +pub(super) fn unique(s: &Series, is_stable: bool) -> PolarsResult { + if is_stable { + Ok(s.list()?.lst_unique_stable()?.into_series()) + } else { + Ok(s.list()?.lst_unique()?.into_series()) + } +} + #[cfg(feature = "list_sets")] pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResult { let s0 = &s[0]; @@ -270,3 +377,9 @@ pub(super) fn lst_any(s: &Series) -> PolarsResult { pub(super) fn lst_all(s: &Series) -> PolarsResult { s.list()?.lst_all() } + +pub(super) fn join(s: &[Series]) -> 
PolarsResult { + let ca = s[0].list()?; + let separator = s[1].utf8()?; + Ok(ca.lst_join(separator)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index ef3e5b1c3f8f..7f51397f8dfe 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -11,6 +11,7 @@ mod bounds; mod cat; #[cfg(feature = "round_series")] mod clip; +mod coerce; mod concat; mod correlation; mod cum; @@ -24,6 +25,10 @@ mod list; #[cfg(feature = "log")] mod log; mod nan; +#[cfg(feature = "peaks")] +mod peaks; +#[cfg(feature = "ffi_plugin")] +mod plugin; mod pow; #[cfg(feature = "random")] mod random; @@ -35,7 +40,7 @@ mod rolling; mod round; #[cfg(feature = "row_hash")] mod row_hash; -mod schema; +pub(super) mod schema; #[cfg(feature = "search_sorted")] mod search_sorted; mod shift_and_fill; @@ -110,7 +115,7 @@ pub enum FunctionExpr { #[cfg(feature = "range")] Range(RangeFunction), #[cfg(feature = "date_offset")] - DateOffset(polars_time::Duration), + DateOffset, #[cfg(feature = "trigonometry")] Trigonometry(TrigonometricFunction), #[cfg(feature = "trigonometry")] @@ -132,19 +137,18 @@ pub enum FunctionExpr { DropNans, #[cfg(feature = "round_series")] Clip { - min: Option>, - max: Option>, + has_min: bool, + has_max: bool, }, ListExpr(ListFunction), #[cfg(feature = "dtype-array")] ArrayExpr(ArrayFunction), #[cfg(feature = "dtype-struct")] StructExpr(StructFunction), + #[cfg(feature = "dtype-struct")] + AsStruct, #[cfg(feature = "top_k")] - TopK { - k: usize, - descending: bool, - }, + TopK(bool), Shift(i64), Cumcount { reverse: bool, @@ -162,6 +166,13 @@ pub enum FunctionExpr { reverse: bool, }, Reverse, + #[cfg(feature = "dtype-struct")] + ValueCounts { + sort: bool, + parallel: bool, + }, + #[cfg(feature = "unique_counts")] + UniqueCounts, Boolean(BooleanFunction), #[cfg(feature = "approx_unique")] ApproxNUnique, @@ -204,6 +215,10 @@ pub enum FunctionExpr { method: correlation::CorrelationMethod, ddof: u8, }, + #[cfg(feature = "peaks")] + PeakMin, + #[cfg(feature = "peaks")] + PeakMax, #[cfg(feature = "cutqcut")] Cut { breaks: Vec, @@ -230,18 +245,41 @@ pub enum FunctionExpr { seed: Option, }, SetSortedFlag(IsSorted), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { + lib: Arc, + symbol: Arc, + }, + BackwardFill { + limit: FillNullLimit, + }, + ForwardFill { + limit: FillNullLimit, + }, + SumHorizontal, + MaxHorizontal, + MinHorizontal, } impl Hash for FunctionExpr { fn hash(&self, state: &mut H) { std::mem::discriminant(self).hash(state); match self { + FunctionExpr::Pow(f) => f.hash(state), + #[cfg(feature = "search_sorted")] + FunctionExpr::SearchSorted(f) => f.hash(state), FunctionExpr::BinaryExpr(f) => f.hash(state), FunctionExpr::Boolean(f) => f.hash(state), #[cfg(feature = "strings")] FunctionExpr::StringExpr(f) => f.hash(state), + FunctionExpr::ListExpr(f) => f.hash(state), + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(f) => f.hash(state), + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(f) => f.hash(state), #[cfg(feature = "random")] FunctionExpr::Random { method, .. } => method.hash(state), + FunctionExpr::Correlation { method, .. 
} => method.hash(state), #[cfg(feature = "range")] FunctionExpr::Range(f) => f.hash(state), #[cfg(feature = "temporal")] @@ -250,10 +288,17 @@ impl Hash for FunctionExpr { FunctionExpr::Trigonometry(f) => f.hash(state), #[cfg(feature = "fused")] FunctionExpr::Fused(f) => f.hash(state), + #[cfg(feature = "diff")] + FunctionExpr::Diff(_, null_behavior) => null_behavior.hash(state), #[cfg(feature = "interpolate")] FunctionExpr::Interpolate(f) => f.hash(state), #[cfg(feature = "dtype-categorical")] FunctionExpr::Categorical(f) => f.hash(state), + #[cfg(feature = "ffi_plugin")] + FunctionExpr::FfiPlugin { lib, symbol } => { + lib.hash(state); + symbol.hash(state); + }, _ => {}, } } @@ -282,7 +327,7 @@ impl Display for FunctionExpr { #[cfg(feature = "range")] Range(func) => return write!(f, "{func}"), #[cfg(feature = "date_offset")] - DateOffset(_) => "dt.offset_by", + DateOffset => "dt.offset_by", #[cfg(feature = "trigonometry")] Trigonometry(func) => return write!(f, "{func}"), #[cfg(feature = "trigonometry")] @@ -295,23 +340,35 @@ impl Display for FunctionExpr { ShiftAndFill { .. } => "shift_and_fill", DropNans => "drop_nans", #[cfg(feature = "round_series")] - Clip { min, max } => match (min, max) { - (Some(_), Some(_)) => "clip", - (None, Some(_)) => "clip_max", - (Some(_), None) => "clip_min", + Clip { has_min, has_max } => match (has_min, has_max) { + (true, true) => "clip", + (false, true) => "clip_max", + (true, false) => "clip_min", _ => unreachable!(), }, ListExpr(func) => return write!(f, "{func}"), #[cfg(feature = "dtype-struct")] StructExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "dtype-struct")] + AsStruct => "as_struct", #[cfg(feature = "top_k")] - TopK { .. } => "top_k", + TopK(descending) => { + if *descending { + "bottom_k" + } else { + "top_k" + } + }, Shift(_) => "shift", Cumcount { .. } => "cumcount", Cumsum { .. } => "cumsum", Cumprod { .. } => "cumprod", Cummin { .. } => "cummin", Cummax { .. } => "cummax", + #[cfg(feature = "dtype-struct")] + ValueCounts { .. } => "value_counts", + #[cfg(feature = "unique_counts")] + UniqueCounts => "unique_counts", Reverse => "reverse", Boolean(func) => return write!(f, "{func}"), #[cfg(feature = "approx_unique")] @@ -353,6 +410,10 @@ impl Display for FunctionExpr { ArrayExpr(af) => return Display::fmt(af, f), ConcatExpr(_) => "concat_expr", Correlation { method, .. } => return Display::fmt(method, f), + #[cfg(feature = "peaks")] + PeakMin => "peak_min", + #[cfg(feature = "peaks")] + PeakMax => "peak_max", #[cfg(feature = "cutqcut")] Cut { .. } => "cut", #[cfg(feature = "cutqcut")] @@ -365,6 +426,13 @@ impl Display for FunctionExpr { #[cfg(feature = "random")] Random { method, .. } => method.into(), SetSortedFlag(_) => "set_sorted", + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"), + BackwardFill { .. } => "backward_fill", + ForwardFill { .. } => "forward_fill", + SumHorizontal => "sum_horizontal", + MaxHorizontal => "max_horizontal", + MinHorizontal => "min_horizontal", }; write!(f, "{s}") } @@ -380,6 +448,7 @@ macro_rules! wrap { // Fn(&[Series], args) // all expression arguments are in the slice. // the first element is the root expression. +#[macro_export] macro_rules! 
map_as_slice { ($func:path) => {{ let f = move |s: &mut [Series]| { @@ -482,8 +551,8 @@ impl From for SpecialEq> { Range(func) => func.into(), #[cfg(feature = "date_offset")] - DateOffset(offset) => { - map_owned!(temporal::date_offset, offset) + DateOffset => { + map_as_slice!(temporal::date_offset) }, #[cfg(feature = "trigonometry")] @@ -512,8 +581,8 @@ impl From for SpecialEq> { }, DropNans => map_owned!(nan::drop_nans), #[cfg(feature = "round_series")] - Clip { min, max } => { - map_owned!(clip::clip, min.clone(), max.clone()) + Clip { has_min, has_max } => { + map_as_slice!(clip::clip, has_min, has_max) }, ListExpr(lf) => { use ListFunction::*; @@ -521,19 +590,34 @@ impl From for SpecialEq> { Concat => wrap!(list::concat), #[cfg(feature = "is_in")] Contains => wrap!(list::contains), + #[cfg(feature = "list_drop_nulls")] + DropNulls => map!(list::drop_nulls), Slice => wrap!(list::slice), + Shift => map_as_slice!(list::shift), Get => wrap!(list::get), #[cfg(feature = "list_take")] Take(null_ob_oob) => map_as_slice!(list::take, null_ob_oob), #[cfg(feature = "list_count")] - CountMatch => map_as_slice!(list::count_match), + CountMatches => map_as_slice!(list::count_matches), Sum => map!(list::sum), + Length => map!(list::length), + Max => map!(list::max), + Min => map!(list::min), + Mean => map!(list::mean), + ArgMin => map!(list::arg_min), + ArgMax => map!(list::arg_max), + #[cfg(feature = "diff")] + Diff { n, null_behavior } => map!(list::diff, n, null_behavior), + Sort(options) => map!(list::sort, options), + Reverse => map!(list::reverse), + Unique(is_stable) => map!(list::unique, is_stable), #[cfg(feature = "list_sets")] SetOperation(s) => map_as_slice!(list::set_operation, s), #[cfg(feature = "list_any_all")] Any => map!(list::lst_any), #[cfg(feature = "list_any_all")] All => map!(list::lst_all), + Join => map_as_slice!(list::join), } }, #[cfg(feature = "dtype-array")] @@ -552,11 +636,16 @@ impl From for SpecialEq> { match sf { FieldByIndex(index) => map!(struct_::get_by_index, index), FieldByName(name) => map!(struct_::get_by_name, name.clone()), + RenameFields(names) => map!(struct_::rename_fields, names.clone()), } }, + #[cfg(feature = "dtype-struct")] + AsStruct => { + map_as_slice!(coerce::as_struct) + }, #[cfg(feature = "top_k")] - TopK { k, descending } => { - map!(top_k, k, descending) + TopK(descending) => { + map_as_slice!(top_k, descending) }, Shift(periods) => map!(dispatch::shift, periods), Cumcount { reverse } => map!(cum::cumcount, reverse), @@ -564,6 +653,10 @@ impl From for SpecialEq> { Cumprod { reverse } => map!(cum::cumprod, reverse), Cummin { reverse } => map!(cum::cummin, reverse), Cummax { reverse } => map!(cum::cummax, reverse), + #[cfg(feature = "dtype-struct")] + ValueCounts { sort, parallel } => map!(dispatch::value_counts, sort, parallel), + #[cfg(feature = "unique_counts")] + UniqueCounts => map!(dispatch::unique_counts), Reverse => map!(dispatch::reverse), Boolean(func) => func.into(), #[cfg(feature = "approx_unique")] @@ -599,6 +692,10 @@ impl From for SpecialEq> { Fused(op) => map_as_slice!(fused::fused, op), ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk), Correlation { method, ddof } => map_as_slice!(correlation::corr, ddof, method), + #[cfg(feature = "peaks")] + PeakMin => map!(peaks::peak_min), + #[cfg(feature = "peaks")] + PeakMax => map!(peaks::peak_max), #[cfg(feature = "cutqcut")] Cut { breaks, @@ -633,8 +730,31 @@ impl From for SpecialEq> { RLEID => map!(rle_id), ToPhysical => map!(dispatch::to_physical), #[cfg(feature = 
"random")] - Random { method, seed } => map!(random::random, method, seed), + Random { method, seed } => { + use RandomMethod::*; + match method { + Shuffle => map!(random::shuffle, seed), + SampleFrac { + frac, + with_replacement, + shuffle, + } => map!(random::sample_frac, frac, with_replacement, shuffle, seed), + SampleN { + with_replacement, + shuffle, + } => map_as_slice!(random::sample_n, with_replacement, shuffle, seed), + } + }, SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol, .. } => unsafe { + map_as_slice!(plugin::call_plugin, lib.as_ref(), symbol.as_ref()) + }, + BackwardFill { limit } => map!(dispatch::backward_fill, limit), + ForwardFill { limit } => map!(dispatch::forward_fill, limit), + SumHorizontal => map_as_slice!(dispatch::sum_horizontal), + MaxHorizontal => wrap!(dispatch::max_horizontal), + MinHorizontal => wrap!(dispatch::min_horizontal), } } } @@ -646,8 +766,8 @@ impl From for SpecialEq> { match func { #[cfg(feature = "regex")] Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict), - CountMatch(pat) => { - map!(strings::count_match, &pat) + CountMatches(literal) => { + map_as_slice!(strings::count_matches, literal) }, EndsWith { .. } => map_as_slice!(strings::ends_with), StartsWith { .. } => map_as_slice!(strings::starts_with), @@ -661,8 +781,8 @@ impl From for SpecialEq> { ExtractGroups { pat, dtype } => { map!(strings::extract_groups, &pat, &dtype) }, - NChars => map!(strings::n_chars), - Length => map!(strings::lengths), + LenBytes => map!(strings::len_bytes), + LenChars => map!(strings::len_chars), #[cfg(feature = "string_justify")] Zfill(alignment) => { map!(strings::zfill, alignment) @@ -679,6 +799,13 @@ impl From for SpecialEq> { Strptime(dtype, options) => { map_as_slice!(strings::strptime, dtype.clone(), &options) }, + Split(inclusive) => { + map_as_slice!(strings::split, inclusive) + }, + #[cfg(feature = "dtype-struct")] + SplitExact { n, inclusive } => map_as_slice!(strings::split_exact, n, inclusive), + #[cfg(feature = "dtype-struct")] + SplitN(n) => map_as_slice!(strings::splitn, n), #[cfg(feature = "concat_str")] ConcatVertical(delimiter) => map!(strings::concat, &delimiter), #[cfg(feature = "concat_str")] @@ -689,9 +816,11 @@ impl From for SpecialEq> { Lowercase => map!(strings::lowercase), #[cfg(feature = "nightly")] Titlecase => map!(strings::titlecase), - Strip(matches) => map!(strings::strip, matches.as_deref()), - LStrip(matches) => map!(strings::lstrip, matches.as_deref()), - RStrip(matches) => map!(strings::rstrip, matches.as_deref()), + StripChars => map_as_slice!(strings::strip_chars), + StripCharsStart => map_as_slice!(strings::strip_chars_start), + StripCharsEnd => map_as_slice!(strings::strip_chars_end), + StripPrefix => map_as_slice!(strings::strip_prefix), + StripSuffix => map_as_slice!(strings::strip_suffix), #[cfg(feature = "string_from_radix")] FromRadix(radix, strict) => map!(strings::from_radix, radix, strict), Slice(start, length) => map!(strings::str_slice, start, length), @@ -711,14 +840,14 @@ impl From for SpecialEq> { fn from(func: BinaryFunction) -> Self { use BinaryFunction::*; match func { - Contains { pat, literal } => { - map!(binary::contains, &pat, literal) + Contains => { + map_as_slice!(binary::contains) }, - EndsWith(sub) => { - map!(binary::ends_with, &sub) + EndsWith => { + map_as_slice!(binary::ends_with) }, - StartsWith(sub) => { - map!(binary::starts_with, &sub) + StartsWith => { + 
map_as_slice!(binary::starts_with) }, } } @@ -747,9 +876,14 @@ impl From for SpecialEq> { Millisecond => map!(datetime::millisecond), Microsecond => map!(datetime::microsecond), Nanosecond => map!(datetime::nanosecond), + ToString(format) => map!(datetime::to_string, &format), TimeStamp(tu) => map!(datetime::timestamp, tu), - Truncate(truncate_options) => { - map_as_slice!(datetime::truncate, &truncate_options) + #[cfg(feature = "timezones")] + ConvertTimeZone(tz) => map!(datetime::convert_time_zone, &tz), + WithTimeUnit(tu) => map!(datetime::with_time_unit, tu), + CastTimeUnit(tu) => map!(datetime::cast_time_unit, tu), + Truncate(offset) => { + map_as_slice!(datetime::truncate, &offset) }, #[cfg(feature = "date_offset")] MonthStart => map!(datetime::month_start), @@ -759,62 +893,12 @@ impl From for SpecialEq> { BaseUtcOffset => map!(datetime::base_utc_offset), #[cfg(feature = "timezones")] DSTOffset => map!(datetime::dst_offset), - Round(every, offset) => map!(datetime::round, &every, &offset), + Round(every, offset) => map_as_slice!(datetime::round, &every, &offset), #[cfg(feature = "timezones")] ReplaceTimeZone(tz) => { map_as_slice!(dispatch::replace_time_zone, tz.as_deref()) }, Combine(tu) => map_as_slice!(temporal::combine, tu), - DateRange { - every, - closed, - time_unit, - time_zone, - } => { - map_as_slice!( - temporal::temporal_range_dispatch, - "date", - every, - closed, - time_unit, - time_zone.clone() - ) - }, - DateRanges { - every, - closed, - time_unit, - time_zone, - } => { - map_as_slice!( - temporal::temporal_ranges_dispatch, - "date_range", - every, - closed, - time_unit, - time_zone.clone() - ) - }, - TimeRange { every, closed } => { - map_as_slice!( - temporal::temporal_range_dispatch, - "time", - every, - closed, - None, - None - ) - }, - TimeRanges { every, closed } => { - map_as_slice!( - temporal::temporal_ranges_dispatch, - "time_range", - every, - closed, - None, - None - ) - }, DatetimeFunction { time_unit, time_zone, @@ -824,18 +908,3 @@ impl From for SpecialEq> { } } } - -#[cfg(feature = "range")] -impl From for SpecialEq> { - fn from(func: RangeFunction) -> Self { - use RangeFunction::*; - match func { - IntRange { step } => { - map_as_slice!(range::int_range, step) - }, - IntRanges { step } => { - map_as_slice!(range::int_ranges, step) - }, - } - } -} diff --git a/crates/polars-plan/src/dsl/function_expr/peaks.rs b/crates/polars-plan/src/dsl/function_expr/peaks.rs new file mode 100644 index 000000000000..bd3ce01b975c --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/peaks.rs @@ -0,0 +1,36 @@ +use polars_core::with_match_physical_numeric_polars_type; +use polars_ops::chunked_array::peaks::{peak_max as pmax, peak_min as pmin}; + +use super::*; + +pub(super) fn peak_min(s: &Series) -> PolarsResult { + let s = s.to_physical_repr(); + let s = match s.dtype() { + DataType::Boolean => polars_bail!(opq = peak_min, DataType::Boolean), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => pmin(s.decimal()?).into_series(), + dt => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + pmin(ca).into_series() + }) + }, + }; + Ok(s) +} + +pub(super) fn peak_max(s: &Series) -> PolarsResult { + let s = s.to_physical_repr(); + let s = match s.dtype() { + DataType::Boolean => polars_bail!(opq = peak_max, DataType::Boolean), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => pmax(s.decimal()?).into_series(), + dt => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let 
ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + pmax(ca).into_series() + }) + }, + }; + Ok(s) +} diff --git a/crates/polars-plan/src/dsl/function_expr/plugin.rs b/crates/polars-plan/src/dsl/function_expr/plugin.rs new file mode 100644 index 000000000000..6c8113a54aac --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/plugin.rs @@ -0,0 +1,75 @@ +use std::sync::RwLock; + +use arrow::ffi::{import_field_from_c, ArrowSchema}; +use libloading::Library; +use once_cell::sync::Lazy; +use polars_ffi::*; + +use super::*; + +static LOADED: Lazy>> = Lazy::new(Default::default); + +fn get_lib(lib: &str) -> PolarsResult<&'static Library> { + let lib_map = LOADED.read().unwrap(); + if let Some(library) = lib_map.get(lib) { + // lifetime is static as we never remove libraries. + Ok(unsafe { std::mem::transmute::<&Library, &'static Library>(library) }) + } else { + drop(lib_map); + let library = unsafe { + Library::new(lib).map_err(|e| { + PolarsError::ComputeError(format!("error loading dynamic library: {e}").into()) + })? + }; + + let mut lib_map = LOADED.write().unwrap(); + lib_map.insert(lib.to_string(), library); + drop(lib_map); + + get_lib(lib) + } +} + +pub(super) unsafe fn call_plugin(s: &[Series], lib: &str, symbol: &str) -> PolarsResult { + let lib = get_lib(lib)?; + + let symbol: libloading::Symbol< + unsafe extern "C" fn(*const SeriesExport, usize) -> SeriesExport, + > = lib.get(symbol.as_bytes()).unwrap(); + + let n_args = s.len(); + + let input = s.iter().map(export_series).collect::>(); + let slice_ptr = input.as_ptr(); + let out = symbol(slice_ptr, n_args); + + for e in input { + std::mem::forget(e); + } + + import_series(out) +} + +pub(super) unsafe fn plugin_field( + fields: &[Field], + lib: &str, + symbol: &str, +) -> PolarsResult { + let lib = get_lib(lib)?; + + let symbol: libloading::Symbol ArrowSchema> = + lib.get(symbol.as_bytes()).unwrap(); + + // we deallocate the fields buffer + let fields = fields + .iter() + .map(|field| arrow::ffi::export_field_to_c(&field.to_arrow())) + .collect::>() + .into_boxed_slice(); + let n_args = fields.len(); + let slice_ptr = fields.as_ptr(); + let out = symbol(slice_ptr, n_args); + + let arrow_field = import_field_from_c(&out)?; + Ok(Field::from(&arrow_field)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/random.rs b/crates/polars-plan/src/dsl/function_expr/random.rs index cb8b7c586915..9555671abaf0 100644 --- a/crates/polars-plan/src/dsl/function_expr/random.rs +++ b/crates/polars-plan/src/dsl/function_expr/random.rs @@ -10,7 +10,6 @@ use super::*; pub enum RandomMethod { Shuffle, SampleN { - n: usize, with_replacement: bool, shuffle: bool, }, @@ -27,18 +26,39 @@ impl Hash for RandomMethod { } } -pub(super) fn random(s: &Series, method: RandomMethod, seed: Option) -> PolarsResult { - match method { - RandomMethod::Shuffle => Ok(s.shuffle(seed)), - RandomMethod::SampleFrac { - frac, - with_replacement, - shuffle, - } => s.sample_frac(frac, with_replacement, shuffle, seed), - RandomMethod::SampleN { - n, - with_replacement, - shuffle, - } => s.sample_n(n, with_replacement, shuffle, seed), +pub(super) fn shuffle(s: &Series, seed: Option) -> PolarsResult { + Ok(s.shuffle(seed)) +} + +pub(super) fn sample_frac( + s: &Series, + frac: f64, + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + s.sample_frac(frac, with_replacement, shuffle, seed) +} + +pub(super) fn sample_n( + s: &[Series], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let src = &s[0]; + let n_s = 
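The new `FfiPlugin` expression above resolves user plugins at runtime: the shared library is loaded once via `libloading`, cached for the lifetime of the process, and the requested symbol is looked up by name with the `*const SeriesExport, usize -> SeriesExport` calling convention shown in `call_plugin`. A stripped-down sketch of that host-side lookup (error handling and the per-path cache omitted; assumes the `libloading` and `polars_ffi` crates exactly as used above):

    use libloading::{Library, Symbol};
    use polars_ffi::SeriesExport;

    type PluginFn = unsafe extern "C" fn(*const SeriesExport, usize) -> SeriesExport;

    // Load the dylib and resolve the exported entry point by name; the real
    // code keeps the Library in a process-wide map so the symbol stays valid.
    unsafe fn resolve_plugin(path: &str, symbol: &str) -> Result<(), libloading::Error> {
        let lib = Library::new(path)?;
        let func: Symbol<PluginFn> = lib.get(symbol.as_bytes())?;
        let _ = func; // would be called with exported Series, as in call_plugin
        Ok(())
    }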
&s[1]; + + polars_ensure!( + n_s.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n_s = n_s.cast(&IDX_DTYPE)?; + let n = n_s.idx()?; + + match n.get(0) { + Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed), + None => Ok(Series::new_empty(src.name(), src.dtype())), } } diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs new file mode 100644 index 000000000000..84406821f493 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -0,0 +1,178 @@ +use polars_core::prelude::*; +use polars_core::series::Series; +use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; +use polars_time::{datetime_range_impl, ClosedWindow, Duration}; + +use super::datetime_range::{datetime_range, datetime_ranges}; +use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar}; +use crate::dsl::function_expr::FieldsMapper; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn temporal_range( + s: &[Series], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> PolarsResult { + if s[0].dtype() == &DataType::Date && interval.is_full_days() { + date_range(s, interval, closed) + } else { + let mut s = datetime_range(s, interval, closed, time_unit, time_zone)?; + s.rename("date"); + Ok(s) + } +} + +pub(super) fn temporal_ranges( + s: &[Series], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> PolarsResult { + if s[0].dtype() == &DataType::Date && interval.is_full_days() { + date_ranges(s, interval, closed) + } else { + let mut s = datetime_ranges(s, interval, closed, time_unit, time_zone)?; + s.rename("date_range"); + Ok(s) + } +} + +fn date_range(s: &[Series], interval: Duration, closed: ClosedWindow) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + + let dtype = DataType::Date; + let start = temporal_series_to_i64_scalar(start) * MILLISECONDS_IN_DAY; + let end = temporal_series_to_i64_scalar(end) * MILLISECONDS_IN_DAY; + + let result = datetime_range_impl( + "date", + start, + end, + interval, + closed, + TimeUnit::Milliseconds, + None, + )? + .cast(&dtype)?; + + Ok(result.into_series()) +} + +fn date_ranges(s: &[Series], interval: Duration, closed: ClosedWindow) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + polars_ensure!( + start.len() == end.len(), + ComputeError: "`start` and `end` must have the same length", + ); + + let start = date_series_to_i64_ca(start)? * MILLISECONDS_IN_DAY; + let end = date_series_to_i64_ca(end)? 
* MILLISECONDS_IN_DAY; + + let mut builder = ListPrimitiveChunkedBuilder::::new( + "date_range", + start.len(), + start.len() * CAPACITY_FACTOR, + DataType::Int32, + ); + for (start, end) in start.as_ref().into_iter().zip(&end) { + match (start, end) { + (Some(start), Some(end)) => { + // TODO: Implement an i32 version of `date_range_impl` + let rng = datetime_range_impl( + "", + start, + end, + interval, + closed, + TimeUnit::Milliseconds, + None, + )?; + let rng = rng.cast(&DataType::Date).unwrap(); + let rng = rng.to_physical_repr(); + let rng = rng.i32().unwrap(); + builder.append_slice(rng.cont_slice().unwrap()) + }, + _ => builder.append_null(), + } + } + let list = builder.finish().into_series(); + + let to_type = DataType::List(Box::new(DataType::Date)); + list.cast(&to_type) +} +fn date_series_to_i64_ca(s: &Series) -> PolarsResult> { + let s = s.cast(&DataType::Int64)?; + let result = s.i64().unwrap(); + Ok(result.clone()) +} + +impl<'a> FieldsMapper<'a> { + pub(super) fn map_to_date_range_dtype( + &self, + interval: &Duration, + time_unit: Option<&TimeUnit>, + time_zone: Option<&str>, + ) -> PolarsResult { + let data_dtype = self.map_to_supertype()?.dtype; + match data_dtype { + DataType::Datetime(tu, tz) => { + map_datetime_to_date_range_dtype(tu, tz, time_unit, time_zone) + }, + DataType::Date => { + let schema_dtype = map_date_to_date_range_dtype(interval, time_unit, time_zone); + Ok(schema_dtype) + }, + _ => polars_bail!(ComputeError: "expected Date or Datetime, got {}", data_dtype), + } + } +} + +fn map_datetime_to_date_range_dtype( + data_time_unit: TimeUnit, + data_time_zone: Option, + given_time_unit: Option<&TimeUnit>, + given_time_zone: Option<&str>, +) -> PolarsResult { + let schema_time_zone = match (data_time_zone, given_time_zone) { + (Some(data_tz), Some(given_tz)) => { + polars_ensure!( + data_tz == given_tz, + ComputeError: format!( + "`time_zone` does not match the data\ + \n\nData has time zone '{}', got '{}'.", data_tz, given_tz) + ); + Some(data_tz) + }, + (_, Some(given_tz)) => Some(given_tz.to_string()), + (Some(data_tz), None) => Some(data_tz), + (_, _) => None, + }; + let schema_time_unit = given_time_unit.unwrap_or(&data_time_unit); + + let schema_dtype = DataType::Datetime(*schema_time_unit, schema_time_zone); + Ok(schema_dtype) +} +fn map_date_to_date_range_dtype( + interval: &Duration, + time_unit: Option<&TimeUnit>, + time_zone: Option<&str>, +) -> DataType { + if interval.is_full_days() { + DataType::Date + } else if let Some(tu) = time_unit { + DataType::Datetime(*tu, time_zone.map(String::from)) + } else if interval.nanoseconds() % 1000 != 0 { + DataType::Datetime(TimeUnit::Nanoseconds, time_zone.map(String::from)) + } else { + DataType::Datetime(TimeUnit::Microseconds, time_zone.map(String::from)) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs new file mode 100644 index 000000000000..874656cd55fc --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -0,0 +1,212 @@ +use polars_core::prelude::*; +use polars_core::series::Series; +use polars_time::{datetime_range_impl, ClosedWindow, Duration}; + +use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar}; +use crate::dsl::function_expr::FieldsMapper; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn datetime_range( + s: &[Series], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) 
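The dtype resolution for the new range functions is worth calling out: with Date inputs, a whole-day `interval` keeps the output as Date, while any sub-daily interval promotes to Datetime, defaulting to microseconds unless the interval needs nanosecond precision. That is what `map_date_to_date_range_dtype` above encodes; a compact restatement of the rule, ignoring an explicitly passed `time_unit`/`time_zone` (which override it):

    use polars_core::prelude::*;
    use polars_time::Duration;

    // Mirrors map_date_to_date_range_dtype above: full-day steps stay Date,
    // anything finer becomes Datetime(us), or Datetime(ns) if the interval
    // is not an exact multiple of a microsecond.
    fn date_range_output_dtype(interval: &Duration) -> DataType {
        if interval.is_full_days() {
            DataType::Date
        } else if interval.nanoseconds() % 1_000 != 0 {
            DataType::Datetime(TimeUnit::Nanoseconds, None)
        } else {
            DataType::Datetime(TimeUnit::Microseconds, None)
        }
    }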
-> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + + // Note: `start` and `end` have already been cast to their supertype, + // so only `start`'s dtype needs to be matched against. + #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. + let mut dtype = match (start.dtype(), time_unit) { + (DataType::Date, time_unit) => { + if let Some(tu) = time_unit { + DataType::Datetime(tu, None) + } else if interval.nanoseconds() % 1_000 != 0 { + DataType::Datetime(TimeUnit::Nanoseconds, None) + } else { + DataType::Datetime(TimeUnit::Microseconds, None) + } + }, + // overwrite nothing, keep as-is + (DataType::Datetime(_, _), None) => start.dtype().clone(), + // overwrite time unit, keep timezone + (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), + _ => unreachable!(), + }; + + let (start, end) = match dtype { + #[cfg(feature = "timezones")] + DataType::Datetime(_, Some(_)) => ( + polars_ops::prelude::replace_time_zone( + start.cast(&dtype)?.datetime().unwrap(), + None, + &Utf8Chunked::from_iter(std::iter::once("raise")), + )? + .into_series(), + polars_ops::prelude::replace_time_zone( + end.cast(&dtype)?.datetime().unwrap(), + None, + &Utf8Chunked::from_iter(std::iter::once("raise")), + )? + .into_series(), + ), + _ => (start.cast(&dtype)?, end.cast(&dtype)?), + }; + + // overwrite time zone, if specified + match (&dtype, &time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(tu, _), Some(tz)) => { + dtype = DataType::Datetime(*tu, Some(tz.clone())); + }, + _ => {}, + }; + + let start = temporal_series_to_i64_scalar(&start); + let end = temporal_series_to_i64_scalar(&end); + + let result = match dtype { + DataType::Datetime(tu, ref tz) => { + datetime_range_impl("datetime", start, end, interval, closed, tu, tz.as_ref())? + }, + _ => unimplemented!(), + }; + Ok(result.cast(&dtype).unwrap().into_series()) +} + +pub(super) fn datetime_ranges( + s: &[Series], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + polars_ensure!( + start.len() == end.len(), + ComputeError: "`start` and `end` must have the same length", + ); + + // Note: `start` and `end` have already been cast to their supertype, + // so only `start`'s dtype needs to be matched against. + #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. + let mut dtype = match (start.dtype(), time_unit) { + (DataType::Date, time_unit) => { + if let Some(tu) = time_unit { + DataType::Datetime(tu, None) + } else if interval.nanoseconds() % 1_000 != 0 { + DataType::Datetime(TimeUnit::Nanoseconds, None) + } else { + DataType::Datetime(TimeUnit::Microseconds, None) + } + }, + // overwrite nothing, keep as-is + (DataType::Datetime(_, _), None) => start.dtype().clone(), + // overwrite time unit, keep timezone + (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), + _ => unreachable!(), + }; + + let (start, end) = match dtype { + #[cfg(feature = "timezones")] + DataType::Datetime(_, Some(_)) => ( + polars_ops::prelude::replace_time_zone( + start.cast(&dtype)?.datetime().unwrap(), + None, + &Utf8Chunked::from_iter(std::iter::once("raise")), + )? + .into_series() + .to_physical_repr() + .cast(&DataType::Int64)?, + polars_ops::prelude::replace_time_zone( + end.cast(&dtype)?.datetime().unwrap(), + None, + &Utf8Chunked::from_iter(std::iter::once("raise")), + )? 
+ .into_series() + .to_physical_repr() + .cast(&DataType::Int64)?, + ), + _ => ( + start + .cast(&dtype)? + .to_physical_repr() + .cast(&DataType::Int64)?, + end.cast(&dtype)? + .to_physical_repr() + .cast(&DataType::Int64)?, + ), + }; + + // overwrite time zone, if specified + match (&dtype, &time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(tu, _), Some(tz)) => { + dtype = DataType::Datetime(*tu, Some(tz.clone())); + }, + _ => {}, + }; + + let start = start.i64().unwrap(); + let end = end.i64().unwrap(); + + let list = match dtype { + DataType::Datetime(tu, ref tz) => { + let mut builder = ListPrimitiveChunkedBuilder::::new( + "datetime_range", + start.len(), + start.len() * CAPACITY_FACTOR, + DataType::Int64, + ); + for (start, end) in start.into_iter().zip(end) { + match (start, end) { + (Some(start), Some(end)) => { + let rng = + datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?; + builder.append_slice(rng.cont_slice().unwrap()) + }, + _ => builder.append_null(), + } + } + builder.finish().into_series() + }, + _ => unimplemented!(), + }; + + let to_type = DataType::List(Box::new(dtype)); + list.cast(&to_type) +} + +impl<'a> FieldsMapper<'a> { + pub(super) fn map_to_datetime_range_dtype( + &self, + time_unit: Option<&TimeUnit>, + time_zone: Option<&str>, + ) -> PolarsResult { + let data_dtype = self.map_to_supertype()?.dtype; + + let (data_tu, data_tz) = if let DataType::Datetime(tu, tz) = data_dtype { + (tu, tz) + } else { + (TimeUnit::Microseconds, None) + }; + + let tu = match time_unit { + Some(tu) => *tu, + None => data_tu, + }; + let tz = match time_zone { + Some(tz) => Some(tz.to_string()), + None => data_tz, + }; + + Ok(DataType::Datetime(tu, tz)) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs similarity index 72% rename from crates/polars-plan/src/dsl/function_expr/range.rs rename to crates/polars-plan/src/dsl/function_expr/range/int_range.rs index e348f17ba7c0..3af3933f1b0b 100644 --- a/crates/polars-plan/src/dsl/function_expr/range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -1,83 +1,27 @@ -use super::*; +use polars_core::prelude::*; +use polars_core::series::{IsSorted, Series}; -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] -pub enum RangeFunction { - IntRange { step: i64 }, - IntRanges { step: i64 }, -} - -impl Display for RangeFunction { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use RangeFunction::*; - match self { - IntRange { .. } => write!(f, "int_range"), - IntRanges { .. } => write!(f, "int_ranges"), - } - } -} - -fn int_range_impl(start: T::Native, end: T::Native, step: i64) -> PolarsResult -where - T: PolarsNumericType, - ChunkedArray: IntoSeries, - std::ops::Range: Iterator, - std::ops::RangeInclusive: DoubleEndedIterator, -{ - let name = "int"; - - let mut ca = match step { - 0 => polars_bail!(InvalidOperation: "step must not be zero"), - 1 => ChunkedArray::::from_iter_values(name, start..end), - 2.. 
=> ChunkedArray::::from_iter_values(name, (start..end).step_by(step as usize)), - _ => { - polars_ensure!(start > end, InvalidOperation: "range must be decreasing if 'step' is negative"); - ChunkedArray::::from_iter_values( - name, - (end..=start).rev().step_by(step.unsigned_abs() as usize), - ) - }, - }; - - let is_sorted = if end < start { - IsSorted::Descending - } else { - IsSorted::Ascending - }; - ca.set_sorted_flag(is_sorted); - - Ok(ca.into_series()) -} +use super::utils::ensure_range_bounds_contain_exactly_one_value; pub(super) fn int_range(s: &[Series], step: i64) -> PolarsResult { let start = &s[0]; let end = &s[1]; + ensure_range_bounds_contain_exactly_one_value(start, end)?; + match start.dtype() { dt if dt == &IDX_DTYPE => { - let start = start - .idx()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; + let start = start.idx()?.get(0).unwrap(); let end = end.cast(&IDX_DTYPE)?; - let end = end - .idx()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; + let end = end.idx()?.get(0).unwrap(); int_range_impl::(start, end, step) }, _ => { let start = start.cast(&DataType::Int64)?; let end = end.cast(&DataType::Int64)?; - let start = start - .i64()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; - let end = end - .i64()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; + let start = start.i64()?.get(0).unwrap(); + let end = end.i64()?.get(0).unwrap(); int_range_impl::(start, end, step) }, } @@ -150,9 +94,9 @@ pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult { builder.append_iter_values((start_v..end_v).step_by(step as usize)); }, _ => builder.append_iter_values( - (end_v..=start_v) - .rev() - .step_by(step.unsigned_abs() as usize), + (end_v..start_v) + .step_by(step.unsigned_abs() as usize) + .map(|x| start_v - (x - end_v)), ), }, _ => builder.append_null(), @@ -161,3 +105,36 @@ pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult { Ok(builder.finish().into_series()) } + +fn int_range_impl(start: T::Native, end: T::Native, step: i64) -> PolarsResult +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, + std::ops::Range: DoubleEndedIterator, +{ + let name = "int"; + + let mut ca = match step { + 0 => polars_bail!(InvalidOperation: "step must not be zero"), + 1 => ChunkedArray::::from_iter_values(name, start..end), + 2.. 
=> ChunkedArray::::from_iter_values(name, (start..end).step_by(step as usize)), + _ => { + polars_ensure!(start > end, InvalidOperation: "range must be decreasing if 'step' is negative"); + ChunkedArray::::from_iter_values( + name, + (end..start) + .step_by(step.unsigned_abs() as usize) + .map(|x| start - (x - end)), + ) + }, + }; + + let is_sorted = if end < start { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(is_sorted); + + Ok(ca.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs new file mode 100644 index 000000000000..8c51d266d461 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -0,0 +1,258 @@ +#[cfg(feature = "temporal")] +mod date_range; +#[cfg(feature = "dtype-datetime")] +mod datetime_range; +mod int_range; +#[cfg(feature = "dtype-time")] +mod time_range; +mod utils; + +use std::fmt::{Display, Formatter}; + +use polars_core::prelude::*; +use polars_core::series::Series; +#[cfg(feature = "temporal")] +use polars_time::{ClosedWindow, Duration}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::dsl::function_expr::FieldsMapper; +use crate::dsl::SpecialEq; +use crate::map_as_slice; +use crate::prelude::SeriesUdf; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum RangeFunction { + IntRange { + step: i64, + }, + IntRanges { + step: i64, + }, + #[cfg(feature = "temporal")] + DateRange { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "temporal")] + DateRanges { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRange { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "dtype-time")] + TimeRange { + interval: Duration, + closed: ClosedWindow, + }, + #[cfg(feature = "dtype-time")] + TimeRanges { + interval: Duration, + closed: ClosedWindow, + }, +} + +impl RangeFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use RangeFunction::*; + let field = match self { + IntRange { .. } => Field::new("int", DataType::Int64), + IntRanges { .. 
} => Field::new("int_range", DataType::List(Box::new(DataType::Int64))), + #[cfg(feature = "temporal")] + DateRange { + interval, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let dtype = mapper.map_to_date_range_dtype( + interval, + time_unit.as_ref(), + time_zone.as_deref(), + )?; + return Ok(Field::new("date", dtype)); + }, + #[cfg(feature = "temporal")] + DateRanges { + interval, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let inner_dtype = mapper.map_to_date_range_dtype( + interval, + time_unit.as_ref(), + time_zone.as_deref(), + )?; + return Ok(Field::new( + "date_range", + DataType::List(Box::new(inner_dtype)), + )); + }, + #[cfg(feature = "temporal")] + DatetimeRange { + interval: _, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let dtype = + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_deref())?; + return Ok(Field::new("datetime", dtype)); + }, + #[cfg(feature = "temporal")] + DatetimeRanges { + interval: _, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let inner_dtype = + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_deref())?; + return Ok(Field::new( + "datetime_range", + DataType::List(Box::new(inner_dtype)), + )); + }, + #[cfg(feature = "dtype-time")] + TimeRange { .. } => { + return Ok(Field::new("time", DataType::Time)); + }, + #[cfg(feature = "dtype-time")] + TimeRanges { .. } => { + return Ok(Field::new( + "time_range", + DataType::List(Box::new(DataType::Time)), + )); + }, + }; + Ok(field) + } +} + +impl Display for RangeFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RangeFunction::*; + let s = match self { + IntRange { .. } => "int_range", + IntRanges { .. } => "int_ranges", + #[cfg(feature = "temporal")] + DateRange { .. } => "date_range", + #[cfg(feature = "temporal")] + DateRanges { .. } => "date_ranges", + #[cfg(feature = "dtype-datetime")] + DatetimeRange { .. } => "datetime_range", + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { .. } => "datetime_ranges", + #[cfg(feature = "dtype-time")] + TimeRange { .. } => "time_range", + #[cfg(feature = "dtype-time")] + TimeRanges { .. 
} => "time_ranges", + }; + write!(f, "{s}") + } +} + +impl From for SpecialEq> { + fn from(func: RangeFunction) -> Self { + use RangeFunction::*; + match func { + IntRange { step } => { + map_as_slice!(int_range::int_range, step) + }, + IntRanges { step } => { + map_as_slice!(int_range::int_ranges, step) + }, + #[cfg(feature = "temporal")] + DateRange { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + date_range::temporal_range, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "temporal")] + DateRanges { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + date_range::temporal_ranges, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRange { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + datetime_range::datetime_range, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + datetime_range::datetime_ranges, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "dtype-time")] + TimeRange { interval, closed } => { + map_as_slice!(time_range::time_range, interval, closed) + }, + #[cfg(feature = "dtype-time")] + TimeRanges { interval, closed } => { + map_as_slice!(time_range::time_ranges, interval, closed) + }, + } + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs new file mode 100644 index 000000000000..3dd3e44fb904 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -0,0 +1,68 @@ +use polars_core::prelude::*; +use polars_core::series::Series; +use polars_time::{time_range_impl, ClosedWindow, Duration}; + +use super::utils::{ensure_range_bounds_contain_exactly_one_value, temporal_series_to_i64_scalar}; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn time_range( + s: &[Series], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + + let dtype = DataType::Time; + let start = temporal_series_to_i64_scalar(&start.cast(&dtype)?); + let end = temporal_series_to_i64_scalar(&end.cast(&dtype)?); + + let out = time_range_impl("time", start, end, interval, closed)?; + Ok(out.cast(&dtype).unwrap().into_series()) +} + +pub(super) fn time_ranges( + s: &[Series], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + polars_ensure!( + start.len() == end.len(), + ComputeError: "`start` and `end` must have the same length", + ); + + let start = time_series_to_i64_ca(start)?; + let end = time_series_to_i64_ca(end)?; + + let mut builder = ListPrimitiveChunkedBuilder::::new( + "time_range", + start.len(), + start.len() * CAPACITY_FACTOR, + DataType::Int64, + ); + for (start, end) in start.as_ref().into_iter().zip(&end) { + match (start, end) { + (Some(start), Some(end)) => { + let rng = time_range_impl("", start, end, interval, closed)?; + builder.append_slice(rng.cont_slice().unwrap()) + }, + _ => builder.append_null(), + } + } + let list = builder.finish().into_series(); + + let to_type = DataType::List(Box::new(DataType::Time)); + list.cast(&to_type) +} +fn time_series_to_i64_ca(s: &Series) -> PolarsResult> { + let s = s.cast(&DataType::Time)?; + let s = 
s.to_physical_repr(); + let result = s.i64().unwrap(); + Ok(result.clone()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/utils.rs b/crates/polars-plan/src/dsl/function_expr/range/utils.rs new file mode 100644 index 000000000000..1ff8da6b2ba3 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/utils.rs @@ -0,0 +1,25 @@ +use polars_core::prelude::{polars_bail, polars_ensure, PolarsResult}; +use polars_core::series::Series; + +pub(super) fn temporal_series_to_i64_scalar(s: &Series) -> i64 { + s.to_physical_repr() + .get(0) + .unwrap() + .extract::() + .unwrap() +} + +pub(super) fn ensure_range_bounds_contain_exactly_one_value( + start: &Series, + end: &Series, +) -> PolarsResult<()> { + polars_ensure!( + start.len() == 1, + ComputeError: "`start` must contain exactly one value, got {} values", start.len() + ); + polars_ensure!( + end.len() == 1, + ComputeError: "`end` must contain exactly one value, got {} values", end.len() + ); + Ok(()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 37760a3c4742..3d8996e74431 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -27,112 +27,15 @@ impl FunctionExpr { BinaryExpr(s) => { use BinaryFunction::*; match s { - Contains { .. } | EndsWith(_) | StartsWith(_) => { - mapper.with_dtype(DataType::Boolean) - }, + Contains { .. } | EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), } }, #[cfg(feature = "temporal")] - TemporalExpr(fun) => { - use TemporalFunction::*; - let dtype = match fun { - Year | IsoYear => DataType::Int32, - Month | Quarter | Week | WeekDay | Day | OrdinalDay | Hour | Minute - | Millisecond | Microsecond | Nanosecond | Second => DataType::UInt32, - TimeStamp(_) => DataType::Int64, - IsLeapYear => DataType::Boolean, - Time => DataType::Time, - Date => DataType::Date, - Datetime => match mapper.with_same_dtype().unwrap().dtype { - DataType::Datetime(tu, _) => DataType::Datetime(tu, None), - dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), - }, - Truncate(..) => mapper.with_same_dtype().unwrap().dtype, - #[cfg(feature = "date_offset")] - MonthStart => mapper.with_same_dtype().unwrap().dtype, - #[cfg(feature = "date_offset")] - MonthEnd => mapper.with_same_dtype().unwrap().dtype, - #[cfg(feature = "timezones")] - BaseUtcOffset => DataType::Duration(TimeUnit::Milliseconds), - #[cfg(feature = "timezones")] - DSTOffset => DataType::Duration(TimeUnit::Milliseconds), - Round(..) => mapper.with_same_dtype().unwrap().dtype, - #[cfg(feature = "timezones")] - ReplaceTimeZone(tz) => return mapper.map_datetime_dtype_timezone(tz.as_ref()), - DateRange { - every, - closed: _, - time_unit, - time_zone, - } => { - // output dtype may change based on `every`, `time_unit`, and `time_zone` - let dtype = mapper.map_to_date_range_dtype( - every, - time_unit.as_ref(), - time_zone.as_deref(), - )?; - return Ok(Field::new("date", dtype)); - }, - DateRanges { - every, - closed: _, - time_unit, - time_zone, - } => { - // output dtype may change based on `every`, `time_unit`, and `time_zone` - let inner_dtype = mapper.map_to_date_range_dtype( - every, - time_unit.as_ref(), - time_zone.as_deref(), - )?; - return Ok(Field::new( - "date_range", - DataType::List(Box::new(inner_dtype)), - )); - }, - - TimeRange { .. } => { - return Ok(Field::new("time", DataType::Time)); - }, - TimeRanges { .. 
} => { - return Ok(Field::new( - "time_range", - DataType::List(Box::new(DataType::Time)), - )); - }, - DatetimeFunction { - time_unit, - time_zone, - } => { - return Ok(Field::new( - "datetime", - DataType::Datetime(*time_unit, time_zone.clone()), - )); - }, - Combine(tu) => match mapper.with_same_dtype().unwrap().dtype { - DataType::Datetime(_, tz) => DataType::Datetime(*tu, tz), - DataType::Date => DataType::Datetime(*tu, None), - dtype => { - polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype) - }, - }, - }; - mapper.with_dtype(dtype) - }, - + TemporalExpr(fun) => fun.get_field(mapper), #[cfg(feature = "range")] - Range(fun) => { - use RangeFunction::*; - let field = match fun { - IntRange { .. } => Field::new("int", DataType::Int64), - IntRanges { .. } => { - Field::new("int_range", DataType::List(Box::new(DataType::Int64))) - }, - }; - Ok(field) - }, + Range(func) => func.get_field(mapper), #[cfg(feature = "date_offset")] - DateOffset(_) => mapper.with_same_dtype(), + DateOffset { .. } => mapper.with_same_dtype(), #[cfg(feature = "trigonometry")] Trigonometry(_) => mapper.map_to_float_dtype(), #[cfg(feature = "trigonometry")] @@ -152,19 +55,34 @@ impl FunctionExpr { Concat => mapper.map_to_list_supertype(), #[cfg(feature = "is_in")] Contains => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "list_drop_nulls")] + DropNulls => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), + Shift => mapper.with_same_dtype(), Get => mapper.map_to_list_inner_dtype(), #[cfg(feature = "list_take")] Take(_) => mapper.with_same_dtype(), #[cfg(feature = "list_count")] - CountMatch => mapper.with_dtype(IDX_DTYPE), + CountMatches => mapper.with_dtype(IDX_DTYPE), Sum => mapper.nested_sum_type(), + Min => mapper.map_to_list_inner_dtype(), + Max => mapper.map_to_list_inner_dtype(), + Mean => mapper.with_dtype(DataType::Float64), + ArgMin => mapper.with_dtype(IDX_DTYPE), + ArgMax => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "diff")] + Diff { .. 
} => mapper.with_same_dtype(), + Sort(_) => mapper.with_same_dtype(), + Reverse => mapper.with_same_dtype(), + Unique(_) => mapper.with_same_dtype(), + Length => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "list_sets")] SetOperation(_) => mapper.with_same_dtype(), #[cfg(feature = "list_any_all")] Any => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_any_all")] All => mapper.with_dtype(DataType::Boolean), + Join => mapper.with_dtype(DataType::Utf8), } }, #[cfg(feature = "dtype-array")] @@ -183,39 +101,23 @@ impl FunctionExpr { } }, #[cfg(feature = "dtype-struct")] - StructExpr(s) => { - use polars_core::utils::slice_offsets; - use StructFunction::*; - match s { - FieldByIndex(index) => { - let (index, _) = slice_offsets(*index, 0, fields.len()); - if let DataType::Struct(flds) = &fields[0].dtype { - flds.get(index).cloned().ok_or_else( - || polars_err!(ComputeError: "index out of bounds in `struct.field`") - ) - } else { - polars_bail!( - ComputeError: "expected struct dtype, got: `{}`", &fields[0].dtype - ) - } - }, - FieldByName(name) => { - if let DataType::Struct(flds) = &fields[0].dtype { - let fld = flds - .iter() - .find(|fld| fld.name() == name.as_ref()) - .ok_or_else( - || polars_err!(StructFieldNotFound: "{}", name.as_ref()), - )?; - Ok(fld.clone()) - } else { - polars_bail!(StructFieldNotFound: "{}", name.as_ref()); - } - }, - } - }, + AsStruct => Ok(Field::new( + fields[0].name(), + DataType::Struct(fields.to_vec()), + )), + #[cfg(feature = "dtype-struct")] + StructExpr(s) => s.get_field(mapper), #[cfg(feature = "top_k")] - TopK { .. } => mapper.with_same_dtype(), + TopK(_) => mapper.with_same_dtype(), + #[cfg(feature = "dtype-struct")] + ValueCounts { .. } => mapper.map_dtype(|dt| { + DataType::Struct(vec![ + Field::new(fields[0].name().as_str(), dt.clone()), + Field::new("counts", IDX_DTYPE), + ]) + }), + #[cfg(feature = "unique_counts")] + UniqueCounts => mapper.with_dtype(IDX_DTYPE), Shift(..) | Reverse => mapper.with_same_dtype(), Boolean(func) => func.get_field(mapper), #[cfg(feature = "dtype-categorical")] @@ -241,7 +143,10 @@ impl FunctionExpr { dt => dt.clone(), }), #[cfg(feature = "interpolate")] - Interpolate(_) => mapper.with_same_dtype(), + Interpolate(method) => match method { + InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), + InterpolationMethod::Nearest => mapper.with_same_dtype(), + }, ShrinkType => { // we return the smallest type this can return // this might not be correct once the actual data @@ -275,10 +180,46 @@ impl FunctionExpr { Fused(_) => mapper.map_to_supertype(), ConcatExpr(_) => mapper.map_to_supertype(), Correlation { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "peaks")] + PeakMin => mapper.with_same_dtype(), + #[cfg(feature = "peaks")] + PeakMax => mapper.with_same_dtype(), + #[cfg(feature = "cutqcut")] + Cut { + include_breaks: false, + .. + } => mapper.with_dtype(DataType::Categorical(None)), + #[cfg(feature = "cutqcut")] + Cut { + include_breaks: true, + .. + } => { + let name = fields[0].name(); + let name_bin = format!("{}_bin", name); + let struct_dt = DataType::Struct(vec![ + Field::new("brk", DataType::Float64), + Field::new(name_bin.as_str(), DataType::Categorical(None)), + ]); + mapper.with_dtype(struct_dt) + }, #[cfg(feature = "cutqcut")] - Cut { .. } => mapper.with_dtype(DataType::Categorical(None)), + QCut { + include_breaks: false, + .. + } => mapper.with_dtype(DataType::Categorical(None)), #[cfg(feature = "cutqcut")] - QCut { .. 
} => mapper.with_dtype(DataType::Categorical(None)), + QCut { + include_breaks: true, + .. + } => { + let name = fields[0].name(); + let name_bin = format!("{}_bin", name); + let struct_dt = DataType::Struct(vec![ + Field::new("brk", DataType::Float64), + Field::new(name_bin.as_str(), DataType::Categorical(None)), + ]); + mapper.with_dtype(struct_dt) + }, #[cfg(feature = "rle")] RLE => mapper.map_dtype(|dt| { DataType::Struct(vec![ @@ -292,47 +233,85 @@ impl FunctionExpr { #[cfg(feature = "random")] Random { .. } => mapper.with_same_dtype(), SetSortedFlag(_) => mapper.with_same_dtype(), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol } => unsafe { + plugin::plugin_field(fields, lib, &format!("__polars_field_{}", symbol.as_ref())) + }, + BackwardFill { .. } => mapper.with_same_dtype(), + ForwardFill { .. } => mapper.with_same_dtype(), + SumHorizontal => mapper.map_to_supertype(), + MaxHorizontal => mapper.map_to_supertype(), + MinHorizontal => mapper.map_to_supertype(), } } } -pub(super) struct FieldsMapper<'a> { +pub struct FieldsMapper<'a> { fields: &'a [Field], } impl<'a> FieldsMapper<'a> { + pub fn new(fields: &'a [Field]) -> Self { + Self { fields } + } + /// Field with the same dtype. - pub(super) fn with_same_dtype(&self) -> PolarsResult { + pub fn with_same_dtype(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.clone()) } /// Set a dtype. - pub(super) fn with_dtype(&self, dtype: DataType) -> PolarsResult { + pub fn with_dtype(&self, dtype: DataType) -> PolarsResult { Ok(Field::new(self.fields[0].name(), dtype)) } /// Map a single dtype. - pub(super) fn map_dtype(&self, func: impl Fn(&DataType) -> DataType) -> PolarsResult { + pub fn map_dtype(&self, func: impl Fn(&DataType) -> DataType) -> PolarsResult { let dtype = func(self.fields[0].data_type()); Ok(Field::new(self.fields[0].name(), dtype)) } + pub fn get_fields_lens(&self) -> usize { + self.fields.len() + } + + /// Map a single field with a potentially failing mapper function. + pub fn try_map_field( + &self, + func: impl Fn(&Field) -> PolarsResult, + ) -> PolarsResult { + func(&self.fields[0]) + } + /// Map to a float supertype. - pub(super) fn map_to_float_dtype(&self) -> PolarsResult { + pub fn map_to_float_dtype(&self) -> PolarsResult { self.map_dtype(|dtype| match dtype { DataType::Float32 => DataType::Float32, _ => DataType::Float64, }) } + /// Map to a float supertype if numeric, else preserve + pub fn map_numeric_to_float_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| { + if dtype.is_numeric() { + match dtype { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + } + } else { + dtype.clone() + } + }) + } + /// Map to a physical type. - pub(super) fn to_physical_type(&self) -> PolarsResult { + pub fn to_physical_type(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.to_physical()) } /// Map a single dtype with a potentially failing mapper function. - #[cfg(any(feature = "timezones", feature = "dtype-array"))] - pub(super) fn try_map_dtype( + pub fn try_map_dtype( &self, func: impl Fn(&DataType) -> PolarsResult, ) -> PolarsResult { @@ -341,7 +320,7 @@ impl<'a> FieldsMapper<'a> { } /// Map all dtypes with a potentially failing mapper function. - pub(super) fn try_map_dtypes( + pub fn try_map_dtypes( &self, func: impl Fn(&[&DataType]) -> PolarsResult, ) -> PolarsResult { @@ -357,7 +336,7 @@ impl<'a> FieldsMapper<'a> { } /// Map the dtype to the "supertype" of all fields. 
- pub(super) fn map_to_supertype(&self) -> PolarsResult { + pub fn map_to_supertype(&self) -> PolarsResult { let mut first = self.fields[0].clone(); let mut st = first.data_type().clone(); for field in &self.fields[1..] { @@ -368,7 +347,7 @@ impl<'a> FieldsMapper<'a> { } /// Map the dtype to the dtype of the list elements. - pub(super) fn map_to_list_inner_dtype(&self) -> PolarsResult { + pub fn map_to_list_inner_dtype(&self) -> PolarsResult { let mut first = self.fields[0].clone(); let dt = first .data_type() @@ -379,73 +358,8 @@ impl<'a> FieldsMapper<'a> { Ok(first) } - #[cfg(feature = "temporal")] - pub(super) fn map_to_date_range_dtype( - &self, - every: &Duration, - time_unit: Option<&TimeUnit>, - time_zone: Option<&str>, - ) -> PolarsResult { - let data_dtype = self.map_to_supertype()?.dtype; - match data_dtype { - DataType::Datetime(tu, tz) => { - self.map_datetime_to_date_range_dtype(tu, tz, time_unit, time_zone) - }, - DataType::Date => { - let schema_dtype = self.map_date_to_date_range_dtype(every, time_unit, time_zone); - Ok(schema_dtype) - }, - _ => polars_bail!(ComputeError: "expected Date or Datetime, got {}", data_dtype), - } - } - #[cfg(feature = "temporal")] - fn map_datetime_to_date_range_dtype( - &self, - data_time_unit: TimeUnit, - data_time_zone: Option, - given_time_unit: Option<&TimeUnit>, - given_time_zone: Option<&str>, - ) -> PolarsResult { - let schema_time_zone = match (data_time_zone, given_time_zone) { - (Some(data_tz), Some(given_tz)) => { - polars_ensure!( - data_tz == given_tz, - ComputeError: format!( - "`time_zone` does not match the data\ - \n\nData has time zone '{}', got '{}'.", data_tz, given_tz) - ); - Some(data_tz) - }, - (_, Some(given_tz)) => Some(given_tz.to_string()), - (Some(data_tz), None) => Some(data_tz), - (_, _) => None, - }; - let schema_time_unit = given_time_unit.unwrap_or(&data_time_unit); - - let schema_dtype = DataType::Datetime(*schema_time_unit, schema_time_zone); - Ok(schema_dtype) - } - #[cfg(feature = "temporal")] - fn map_date_to_date_range_dtype( - &self, - every: &Duration, - time_unit: Option<&TimeUnit>, - time_zone: Option<&str>, - ) -> DataType { - let nsecs = every.nanoseconds(); - if nsecs == 0 { - DataType::Date - } else if let Some(tu) = time_unit { - DataType::Datetime(*tu, time_zone.map(String::from)) - } else if nsecs % 1000 != 0 { - DataType::Datetime(TimeUnit::Nanoseconds, time_zone.map(String::from)) - } else { - DataType::Datetime(TimeUnit::Microseconds, time_zone.map(String::from)) - } - } - /// Map the dtypes to the "supertype" of a list of lists. - pub(super) fn map_to_list_supertype(&self) -> PolarsResult { + pub fn map_to_list_supertype(&self) -> PolarsResult { self.try_map_dtypes(|dts| { let mut super_type_inner = None; @@ -471,7 +385,7 @@ impl<'a> FieldsMapper<'a> { /// Set the timezone of a datetime dtype. 
#[cfg(feature = "timezones")] - pub(super) fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult { + pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult { self.try_map_dtype(|dt| { if let DataType::Datetime(tu, _) = dt { Ok(DataType::Datetime(*tu, tz.cloned())) @@ -481,7 +395,7 @@ impl<'a> FieldsMapper<'a> { }) } - fn nested_sum_type(&self) -> PolarsResult { + pub fn nested_sum_type(&self) -> PolarsResult { let mut first = self.fields[0].clone(); use DataType::*; let dt = first.data_type().inner_dtype().cloned().unwrap_or(Unknown); @@ -495,7 +409,7 @@ impl<'a> FieldsMapper<'a> { } #[cfg(feature = "extract_jsonpath")] - pub(super) fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { + pub fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { let dtype = dtype.unwrap_or(DataType::Unknown); self.with_dtype(dtype) } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 89a37c5dc170..296aefed0ae7 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -12,6 +12,9 @@ use serde::{Deserialize, Serialize}; static TZ_AWARE_RE: Lazy = Lazy::new(|| Regex::new(r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)").unwrap()); +#[cfg(feature = "dtype-struct")] +use polars_utils::format_smartstring; + use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -26,7 +29,7 @@ pub enum StringFunction { literal: bool, strict: bool, }, - CountMatch(String), + CountMatches(bool), EndsWith, Explode, Extract { @@ -41,15 +44,14 @@ pub enum StringFunction { }, #[cfg(feature = "string_from_radix")] FromRadix(u32, bool), - NChars, - Length, + LenBytes, + LenChars, #[cfg(feature = "string_justify")] LJust { width: usize, fillchar: char, }, Lowercase, - LStrip(Option), #[cfg(feature = "extract_jsonpath")] JsonExtract { dtype: Option, @@ -67,12 +69,23 @@ pub enum StringFunction { width: usize, fillchar: char, }, - RStrip(Option), Slice(i64, Option), StartsWith, - Strip(Option), + StripChars, + StripCharsStart, + StripCharsEnd, + StripPrefix, + StripSuffix, + #[cfg(feature = "dtype-struct")] + SplitExact { + n: usize, + inclusive: bool, + }, + #[cfg(feature = "dtype-struct")] + SplitN(usize), #[cfg(feature = "temporal")] Strptime(DataType, StrptimeOptions), + Split(bool), #[cfg(feature = "dtype-decimal")] ToDecimal(usize), #[cfg(feature = "nightly")] @@ -87,10 +100,10 @@ impl StringFunction { use StringFunction::*; match self { #[cfg(feature = "concat_str")] - ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(), + ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_dtype(DataType::Utf8), #[cfg(feature = "regex")] Contains { .. } => mapper.with_dtype(DataType::Boolean), - CountMatch(_) => mapper.with_dtype(DataType::UInt32), + CountMatches(_) => mapper.with_dtype(DataType::UInt32), EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), Explode => mapper.with_same_dtype(), Extract { .. } => mapper.with_same_dtype(), @@ -101,21 +114,39 @@ impl StringFunction { FromRadix { .. } => mapper.with_dtype(DataType::Int32), #[cfg(feature = "extract_jsonpath")] JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), - Length => mapper.with_dtype(DataType::UInt32), - NChars => mapper.with_dtype(DataType::UInt32), + LenBytes => mapper.with_dtype(DataType::UInt32), + LenChars => mapper.with_dtype(DataType::UInt32), #[cfg(feature = "regex")] Replace { .. 
} => mapper.with_same_dtype(), #[cfg(feature = "temporal")] Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), + Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))), #[cfg(feature = "nightly")] Titlecase => mapper.with_same_dtype(), #[cfg(feature = "dtype-decimal")] ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)), - Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => { - mapper.with_same_dtype() - }, + Uppercase + | Lowercase + | StripChars + | StripCharsStart + | StripCharsEnd + | StripPrefix + | StripSuffix + | Slice(_, _) => mapper.with_same_dtype(), #[cfg(feature = "string_justify")] Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(), + #[cfg(feature = "dtype-struct")] + SplitExact { n, .. } => mapper.with_dtype(DataType::Struct( + (0..n + 1) + .map(|i| Field::from_owned(format_smartstring!("field_{i}"), DataType::Utf8)) + .collect(), + )), + #[cfg(feature = "dtype-struct")] + SplitN(n) => mapper.with_dtype(DataType::Struct( + (0..*n) + .map(|i| Field::from_owned(format_smartstring!("field_{i}"), DataType::Utf8)) + .collect(), + )), } } } @@ -125,7 +156,7 @@ impl Display for StringFunction { let s = match self { #[cfg(feature = "regex")] StringFunction::Contains { .. } => "contains", - StringFunction::CountMatch(_) => "count_match", + StringFunction::CountMatches(_) => "count_matches", StringFunction::EndsWith { .. } => "ends_with", StringFunction::Extract { .. } => "extract", #[cfg(feature = "concat_str")] @@ -141,21 +172,40 @@ impl Display for StringFunction { #[cfg(feature = "extract_jsonpath")] StringFunction::JsonExtract { .. } => "json_extract", #[cfg(feature = "string_justify")] - StringFunction::LJust { .. } => "str.ljust", - StringFunction::LStrip(_) => "lstrip", - StringFunction::Length => "str_lengths", + StringFunction::LJust { .. } => "ljust", + StringFunction::LenBytes => "len_bytes", StringFunction::Lowercase => "lowercase", - StringFunction::NChars => "n_chars", + StringFunction::LenChars => "len_chars", #[cfg(feature = "string_justify")] StringFunction::RJust { .. } => "rjust", - StringFunction::RStrip(_) => "rstrip", #[cfg(feature = "regex")] StringFunction::Replace { .. } => "replace", - StringFunction::Slice(_, _) => "str_slice", + StringFunction::Slice(_, _) => "slice", StringFunction::StartsWith { .. } => "starts_with", - StringFunction::Strip(_) => "strip", + StringFunction::StripChars => "strip_chars", + StringFunction::StripCharsStart => "strip_chars_start", + StringFunction::StripCharsEnd => "strip_chars_end", + StringFunction::StripPrefix => "strip_prefix", + StringFunction::StripSuffix => "strip_suffix", + #[cfg(feature = "dtype-struct")] + StringFunction::SplitExact { inclusive, .. 
} => { + if *inclusive { + "split_exact" + } else { + "split_exact_inclusive" + } + }, + #[cfg(feature = "dtype-struct")] + StringFunction::SplitN(_) => "splitn", #[cfg(feature = "temporal")] StringFunction::Strptime(_, _) => "strptime", + StringFunction::Split(inclusive) => { + if *inclusive { + "split" + } else { + "split_inclusive" + } + }, #[cfg(feature = "nightly")] StringFunction::Titlecase => "titlecase", #[cfg(feature = "dtype-decimal")] @@ -184,113 +234,36 @@ pub(super) fn titlecase(s: &Series) -> PolarsResult { Ok(ca.to_titlecase().into_series()) } -pub(super) fn n_chars(s: &Series) -> PolarsResult { +pub(super) fn len_chars(s: &Series) -> PolarsResult { let ca = s.utf8()?; - Ok(ca.str_n_chars().into_series()) + Ok(ca.str_len_chars().into_series()) } -pub(super) fn lengths(s: &Series) -> PolarsResult { +pub(super) fn len_bytes(s: &Series) -> PolarsResult { let ca = s.utf8()?; - Ok(ca.str_lengths().into_series()) + Ok(ca.str_len_bytes().into_series()) } #[cfg(feature = "regex")] pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult { - // TODO! move to polars-ops let ca = s[0].utf8()?; let pat = s[1].utf8()?; - - let mut out: BooleanChunked = match pat.len() { - 1 => match pat.get(0) { - Some(pat) => { - if literal { - ca.contains_literal(pat)? - } else { - ca.contains(pat, strict)? - } - }, - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => { - if literal { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => src.contains(pat), - _ => false, - }) - .collect_trusted() - } else if strict { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => { - let re = Regex::new(pat)?; - Ok(re.is_match(src)) - }, - _ => Ok(false), - }) - .collect::>()? - } else { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => Regex::new(pat).ok().map(|re| re.is_match(src)), - _ => Some(false), - }) - .collect_trusted() - } - }, - }; - - out.rename(ca.name()); - Ok(out.into_series()) + ca.contains_chunked(pat, literal, strict) + .map(|ok| ok.into_series()) } pub(super) fn ends_with(s: &[Series]) -> PolarsResult { - let ca = s[0].utf8()?; - let sub = s[1].utf8()?; - - let mut out: BooleanChunked = match sub.len() { - 1 => match sub.get(0) { - Some(s) => ca.ends_with(s), - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => ca - .into_iter() - .zip(sub) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(val)) => src.ends_with(val), - _ => false, - }) - .collect_trusted(), - }; + let ca = &s[0].utf8()?.as_binary(); + let suffix = &s[1].utf8()?.as_binary(); - out.rename(ca.name()); - Ok(out.into_series()) + Ok(ca.ends_with_chunked(suffix).into_series()) } pub(super) fn starts_with(s: &[Series]) -> PolarsResult { - let ca = s[0].utf8()?; - let sub = s[1].utf8()?; + let ca = &s[0].utf8()?.as_binary(); + let prefix = &s[1].utf8()?.as_binary(); - let mut out: BooleanChunked = match sub.len() { - 1 => match sub.get(0) { - Some(s) => ca.starts_with(s), - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => ca - .into_iter() - .zip(sub) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(val)) => src.starts_with(val), - _ => false, - }) - .collect_trusted(), - }; - - out.rename(ca.name()); - Ok(out.into_series()) + Ok(ca.starts_with_chunked(prefix).into_series()) } /// Extract a regex pattern from the a string value. 
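The rewritten string kernels in this file share a broadcasting convention for the second (pattern) argument: a length-1 pattern column is applied to every row, a single null pattern yields a full-null result, and otherwise values and patterns are zipped row by row. The rule is spelled out explicitly in `extract_all` and `count_matches` in the next hunk, and the `contains_chunked`/`ends_with_chunked` helpers apply the same idea internally. A minimal sketch of that rule in plain std Rust (the helper name `broadcast_apply` is hypothetical, not the polars API):

// Sketch of the length-1 broadcasting rule used by the string kernels:
// one pattern for all rows, a single null pattern -> all nulls,
// otherwise row-wise zipping with null propagation.
fn broadcast_apply<T: Clone>(
    values: &[Option<String>],
    patterns: &[Option<String>],
    f: impl Fn(&str, &str) -> T,
) -> Vec<Option<T>> {
    match patterns {
        [single] => match single.as_deref() {
            // Fast path: apply the single pattern to every value.
            Some(p) => values
                .iter()
                .map(|v| v.as_deref().map(|v| f(v, p)))
                .collect(),
            // A single null pattern produces a full-null column.
            None => vec![None; values.len()],
        },
        // General case: zip values and patterns row by row.
        _ => values
            .iter()
            .zip(patterns)
            .map(|(v, p)| match (v.as_deref(), p.as_deref()) {
                (Some(v), Some(p)) => Some(f(v, p)),
                _ => None,
            })
            .collect(),
    }
}

fn main() {
    let vals = vec![Some("aab".to_string()), Some("b".to_string()), None];
    let pat = vec![Some("a".to_string())];
    let counts = broadcast_apply(&vals, &pat, |v, p| v.matches(p).count());
    assert_eq!(counts, vec![Some(2), Some(0), None]);
}

Keeping the single-pattern fast path separate from the row-wise zip preserves the previous scalar-pattern behaviour while allowing per-row patterns.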
@@ -325,67 +298,34 @@ pub(super) fn rjust(s: &Series, width: usize, fillchar: char) -> PolarsResult) -> PolarsResult { - let ca = s.utf8()?; - if let Some(matches) = matches { - if matches.chars().count() == 1 { - // Fast path for when a single character is passed - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_matches(matches.chars().next().unwrap()))) - .into_series()) - } else { - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_matches(|c| matches.contains(c)))) - .into_series()) - } - } else { - Ok(ca.apply_values(|s| Cow::Borrowed(s.trim())).into_series()) - } +pub(super) fn strip_chars(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let pat_s = &s[1]; + ca.strip_chars(pat_s).map(|ok| ok.into_series()) } -pub(super) fn lstrip(s: &Series, matches: Option<&str>) -> PolarsResult { - let ca = s.utf8()?; +pub(super) fn strip_chars_start(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let pat_s = &s[1]; + ca.strip_chars_start(pat_s).map(|ok| ok.into_series()) +} - if let Some(matches) = matches { - if matches.chars().count() == 1 { - // Fast path for when a single character is passed - Ok(ca - .apply_values(|s| { - Cow::Borrowed(s.trim_start_matches(matches.chars().next().unwrap())) - }) - .into_series()) - } else { - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_start_matches(|c| matches.contains(c)))) - .into_series()) - } - } else { - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_start())) - .into_series()) - } +pub(super) fn strip_chars_end(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let pat_s = &s[1]; + ca.strip_chars_end(pat_s).map(|ok| ok.into_series()) } -pub(super) fn rstrip(s: &Series, matches: Option<&str>) -> PolarsResult { - let ca = s.utf8()?; - if let Some(matches) = matches { - if matches.chars().count() == 1 { - // Fast path for when a single character is passed - Ok(ca - .apply_values(|s| { - Cow::Borrowed(s.trim_end_matches(matches.chars().next().unwrap())) - }) - .into_series()) - } else { - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_end_matches(|c| matches.contains(c)))) - .into_series()) - } - } else { - Ok(ca - .apply_values(|s| Cow::Borrowed(s.trim_end())) - .into_series()) - } +pub(super) fn strip_prefix(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let prefix = s[1].utf8()?; + Ok(ca.strip_prefix(prefix).into_series()) +} + +pub(super) fn strip_suffix(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let suffix = s[1].utf8()?; + Ok(ca.strip_suffix(suffix).into_series()) } pub(super) fn extract_all(args: &[Series]) -> PolarsResult { @@ -396,20 +336,36 @@ pub(super) fn extract_all(args: &[Series]) -> PolarsResult { let pat = pat.utf8()?; if pat.len() == 1 { - let pat = pat - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "expected a pattern, got null"))?; - ca.extract_all(pat).map(|ca| ca.into_series()) + if let Some(pat) = pat.get(0) { + ca.extract_all(pat).map(|ca| ca.into_series()) + } else { + Ok(Series::full_null( + ca.name(), + ca.len(), + &DataType::List(Box::new(DataType::Utf8)), + )) + } } else { ca.extract_all_many(pat).map(|ca| ca.into_series()) } } -pub(super) fn count_match(s: &Series, pat: &str) -> PolarsResult { - let pat = pat.to_string(); +pub(super) fn count_matches(args: &[Series], literal: bool) -> PolarsResult { + let s = &args[0]; + let pat = &args[1]; let ca = s.utf8()?; - ca.count_match(&pat).map(|ca| ca.into_series()) + let pat = pat.utf8()?; + if pat.len() == 1 { + if let Some(pat) = pat.get(0) { + ca.count_matches(pat, literal).map(|ca| ca.into_series()) + } else { + 
Ok(Series::full_null(ca.name(), ca.len(), &DataType::UInt32)) + } + } else { + ca.count_matches_many(pat, literal) + .map(|ca| ca.into_series()) + } } #[cfg(feature = "temporal")] @@ -428,6 +384,37 @@ pub(super) fn strptime( } } +#[cfg(feature = "dtype-struct")] +pub(super) fn split_exact(s: &[Series], n: usize, inclusive: bool) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if inclusive { + ca.split_exact_inclusive(by, n).map(|ca| ca.into_series()) + } else { + ca.split_exact(by, n).map(|ca| ca.into_series()) + } +} + +#[cfg(feature = "dtype-struct")] +pub(super) fn splitn(s: &[Series], n: usize) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + ca.splitn(by, n).map(|ca| ca.into_series()) +} + +pub(super) fn split(s: &[Series], inclusive: bool) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if inclusive { + Ok(ca.split_inclusive(by).into_series()) + } else { + Ok(ca.split(by).into_series()) + } +} + fn handle_temporal_parsing_error( ca: &Utf8Chunked, out: &Series, @@ -561,12 +548,19 @@ fn to_time(s: &Series, options: &StrptimeOptions) -> PolarsResult { #[cfg(feature = "concat_str")] pub(super) fn concat(s: &Series, delimiter: &str) -> PolarsResult { - Ok(s.str_concat(delimiter).into_series()) + let str_s = s.cast(&DataType::Utf8)?; + let concat = polars_ops::chunked_array::str_concat(str_s.utf8()?, delimiter); + Ok(concat.into_series()) } #[cfg(feature = "concat_str")] -pub(super) fn concat_hor(s: &[Series], delimiter: &str) -> PolarsResult { - polars_core::functions::concat_str(s, delimiter).map(|ca| ca.into_series()) +pub(super) fn concat_hor(series: &[Series], delimiter: &str) -> PolarsResult { + let str_series: Vec<_> = series + .iter() + .map(|s| s.cast(&DataType::Utf8)) + .collect::>()?; + let cas: Vec<_> = str_series.iter().map(|s| s.utf8().unwrap()).collect(); + Ok(polars_ops::chunked_array::hor_str_concat(&cas, delimiter)?.into_series()) } impl From for FunctionExpr { @@ -739,7 +733,7 @@ pub(super) fn from_radix(s: &Series, radix: u32, strict: bool) -> PolarsResult) -> PolarsResult { let ca = s.utf8()?; - ca.str_slice(start, length).map(|ca| ca.into_series()) + Ok(ca.str_slice(start, length).into_series()) } pub(super) fn explode(s: &Series) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs index 7d9522d133a9..46c3c5b53108 100644 --- a/crates/polars-plan/src/dsl/function_expr/struct_.rs +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -7,14 +7,67 @@ use super::*; pub enum StructFunction { FieldByIndex(i64), FieldByName(Arc), + RenameFields(Arc>), +} + +impl StructFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use StructFunction::*; + + match self { + FieldByIndex(index) => mapper.try_map_field(|field| { + let (index, _) = slice_offsets(*index, 0, mapper.get_fields_lens()); + if let DataType::Struct(ref fields) = field.dtype { + fields.get(index).cloned().ok_or_else( + || polars_err!(ComputeError: "index out of bounds in `struct.field`"), + ) + } else { + polars_bail!( + ComputeError: "expected struct dtype, got: `{}`", &field.dtype + ) + } + }), + FieldByName(name) => mapper.try_map_field(|field| { + if let DataType::Struct(ref fields) = field.dtype { + let fld = fields + .iter() + .find(|fld| fld.name() == name.as_ref()) + .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name.as_ref()))?; + Ok(fld.clone()) + } else { + polars_bail!(StructFieldNotFound: "{}", 
name.as_ref()); + } + }), + RenameFields(names) => mapper.map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .zip(names.as_ref()) + .map(|(fld, name)| Field::new(name, fld.data_type().clone())) + .collect(); + DataType::Struct(fields) + }, + // The types will be incorrect, but its better than nothing + // we can get an incorrect type with python lambdas, because we only know return type when running + // the query + dt => DataType::Struct( + names + .iter() + .map(|name| Field::new(name, dt.clone())) + .collect(), + ), + }), + } + } } impl Display for StructFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use self::*; + use StructFunction::*; match self { - StructFunction::FieldByIndex(_) => write!(f, "struct.field_by_name"), - StructFunction::FieldByName(_) => write!(f, "struct.field_by_index"), + FieldByIndex(index) => write!(f, "struct.field_by_index({index})"), + FieldByName(name) => write!(f, "struct.field_by_name({name})"), + RenameFields(names) => write!(f, "struct.rename_fields({:?})", names), } } } @@ -31,3 +84,18 @@ pub(super) fn get_by_name(s: &Series, name: Arc) -> PolarsResult { let ca = s.struct_()?; ca.field_by_name(name.as_ref()) } + +pub(super) fn rename_fields(s: &Series, names: Arc>) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields() + .iter() + .zip(names.as_ref()) + .map(|(s, name)| { + let mut s = s.clone(); + s.rename(name); + s + }) + .collect::>(); + StructChunked::new(ca.name(), &fields).map(|ca| ca.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs index 3c75daba6e76..15e601dd7dbb 100644 --- a/crates/polars-plan/src/dsl/function_expr/temporal.rs +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -1,6 +1,7 @@ #[cfg(feature = "date_offset")] use polars_arrow::time_zone::Tz; -use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; +#[cfg(feature = "date_offset")] +use polars_core::chunked_array::ops::arity::try_binary_elementwise; #[cfg(feature = "date_offset")] use polars_time::prelude::*; @@ -89,7 +90,7 @@ pub(super) fn datetime( .map(|ndt| match time_unit { TimeUnit::Milliseconds => ndt.timestamp_millis(), TimeUnit::Microseconds => ndt.timestamp_micros(), - TimeUnit::Nanoseconds => ndt.timestamp_nanos(), + TimeUnit::Nanoseconds => ndt.timestamp_nanos_opt().unwrap(), }) } else { None @@ -119,44 +120,93 @@ pub(super) fn datetime( } #[cfg(feature = "date_offset")] -pub(super) fn date_offset(s: Series, offset: Duration) -> PolarsResult { +fn apply_offsets_to_datetime( + datetime: &Logical, + offsets: &Utf8Chunked, + offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult, + time_zone: Option<&Tz>, +) -> PolarsResult { + match (datetime.len(), offsets.len()) { + (1, _) => match datetime.0.get(0) { + Some(dt) => offsets.try_apply_values_generic(|offset| { + offset_fn(&Duration::parse(offset), dt, time_zone) + }), + _ => Ok(Int64Chunked::full_null(datetime.0.name(), offsets.len())), + }, + (_, 1) => match offsets.get(0) { + Some(offset) => datetime + .0 + .try_apply(|v| offset_fn(&Duration::parse(offset), v, time_zone)), + _ => Ok(datetime.0.apply(|_| None)), + }, + _ => try_binary_elementwise(datetime, offsets, |timestamp_opt, offset_opt| { + match (timestamp_opt, offset_opt) { + (Some(timestamp), Some(offset)) => { + offset_fn(&Duration::parse(offset), timestamp, time_zone).map(Some) + }, + _ => Ok(None), + } + }), + } +} + +#[cfg(feature = "date_offset")] +pub(super) fn 
date_offset(s: &[Series]) -> PolarsResult { + let ts = &s[0]; + let offsets = &s[1].utf8().unwrap(); + let preserve_sortedness: bool; - let out = match s.dtype().clone() { + let out = match ts.dtype() { DataType::Date => { - let s = s + let ts = ts .cast(&DataType::Datetime(TimeUnit::Milliseconds, None)) .unwrap(); - preserve_sortedness = true; - date_offset(s, offset).and_then(|s| s.cast(&DataType::Date)) + let datetime = ts.datetime().unwrap(); + let out = apply_offsets_to_datetime(datetime, offsets, Duration::add_ms, None)?; + // sortedness is only guaranteed to be preserved if a constant offset is being added to every datetime + preserve_sortedness = match offsets.len() { + 1 => offsets.get(0).is_some(), + _ => false, + }; + out.cast(&DataType::Datetime(TimeUnit::Milliseconds, None)) + .unwrap() + .cast(&DataType::Date) }, DataType::Datetime(tu, tz) => { - let ca = s.datetime().unwrap(); + let datetime = ts.datetime().unwrap(); - fn offset_fn(tu: TimeUnit) -> fn(&Duration, i64, Option<&Tz>) -> PolarsResult { - match tu { - TimeUnit::Nanoseconds => Duration::add_ns, - TimeUnit::Microseconds => Duration::add_us, - TimeUnit::Milliseconds => Duration::add_ms, - } - } + let offset_fn = match tu { + TimeUnit::Nanoseconds => Duration::add_ns, + TimeUnit::Microseconds => Duration::add_us, + TimeUnit::Milliseconds => Duration::add_ms, + }; let out = match tz { #[cfg(feature = "timezones")] - Some(ref tz) => { - let offset_fn = offset_fn(tu); - ca.0.try_apply(|v| offset_fn(&offset, v, tz.parse::().ok().as_ref())) - }, - _ => { - let offset_fn = offset_fn(tu); - ca.0.try_apply(|v| offset_fn(&offset, v, None)) - }, - }?; + Some(ref tz) => apply_offsets_to_datetime( + datetime, + offsets, + offset_fn, + tz.parse::().ok().as_ref(), + )?, + _ => apply_offsets_to_datetime(datetime, offsets, offset_fn, None)?, + }; // Sortedness may not be preserved when crossing daylight savings time boundaries // for calendar-aware durations. // Constant durations (e.g. 2 hours) always preserve sortedness. - preserve_sortedness = - tz.is_none() || tz.as_deref() == Some("UTC") || offset.is_constant_duration(); - out.cast(&DataType::Datetime(tu, tz)) + preserve_sortedness = match offsets.len() { + 1 => match offsets.get(0) { + Some(offset) => { + let offset = Duration::parse(offset); + tz.is_none() + || tz.as_deref() == Some("UTC") + || offset.is_constant_duration() + }, + None => false, + }, + _ => false, + }; + out.cast(&DataType::Datetime(*tu, tz.clone())) }, dt => polars_bail!( ComputeError: "cannot use 'date_offset' on Series of datatype {}", dt, @@ -164,11 +214,14 @@ pub(super) fn date_offset(s: Series, offset: Duration) -> PolarsResult { }; if preserve_sortedness { out.map(|mut out| { - out.set_sorted_flag(s.is_sorted_flag()); + out.set_sorted_flag(ts.is_sorted_flag()); out }) } else { - out + out.map(|mut out| { + out.set_sorted_flag(IsSorted::Not); + out + }) } } @@ -200,289 +253,3 @@ pub(super) fn combine(s: &[Series], tu: TimeUnit) -> PolarsResult { _ => Ok(result_naive), } } - -pub(super) fn temporal_range_dispatch( - s: &[Series], - name: &str, - every: Duration, - closed: ClosedWindow, - time_unit: Option, - time_zone: Option, -) -> PolarsResult { - let start = &s[0]; - let stop = &s[1]; - - polars_ensure!( - start.len() == stop.len(), - ComputeError: "'start' and 'stop' should have the same length", - ); - const TO_MS: i64 = SECONDS_IN_DAY * 1000; - - // Note: `start` and `stop` have already been cast to their supertype, - // so only `start`'s dtype needs to be matched against. 
- #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. - let mut dtype = match (start.dtype(), time_unit) { - (DataType::Date, time_unit) => { - let nsecs = every.nanoseconds(); - if nsecs == 0 { - DataType::Date - } else if let Some(tu) = time_unit { - DataType::Datetime(tu, None) - } else if nsecs % 1_000 != 0 { - DataType::Datetime(TimeUnit::Nanoseconds, None) - } else { - DataType::Datetime(TimeUnit::Microseconds, None) - } - }, - (DataType::Time, _) => DataType::Time, - // overwrite nothing, keep as-is - (DataType::Datetime(_, _), None) => start.dtype().clone(), - // overwrite time unit, keep timezone - (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), - _ => unreachable!(), - }; - - let (mut start, mut stop) = match dtype { - #[cfg(feature = "timezones")] - DataType::Datetime(_, Some(_)) => ( - polars_ops::prelude::replace_time_zone( - start.cast(&dtype)?.datetime().unwrap(), - None, - &Utf8Chunked::from_iter(std::iter::once("raise")), - )? - .into_series() - .to_physical_repr() - .cast(&DataType::Int64)?, - polars_ops::prelude::replace_time_zone( - stop.cast(&dtype)?.datetime().unwrap(), - None, - &Utf8Chunked::from_iter(std::iter::once("raise")), - )? - .into_series() - .to_physical_repr() - .cast(&DataType::Int64)?, - ), - _ => ( - start - .cast(&dtype)? - .to_physical_repr() - .cast(&DataType::Int64)?, - stop.cast(&dtype)? - .to_physical_repr() - .cast(&DataType::Int64)?, - ), - }; - - if dtype == DataType::Date { - start = &start * TO_MS; - stop = &stop * TO_MS; - } - - // overwrite time zone, if specified - match (&dtype, &time_zone) { - #[cfg(feature = "timezones")] - (DataType::Datetime(tu, _), Some(tz)) => { - dtype = DataType::Datetime(*tu, Some(tz.clone())); - }, - _ => {}, - }; - - let start = start.get(0).unwrap().extract::().unwrap(); - let stop = stop.get(0).unwrap().extract::().unwrap(); - - let out = match dtype { - DataType::Date => date_range_impl( - name, - start, - stop, - every, - closed, - TimeUnit::Milliseconds, - None, - )?, - DataType::Datetime(tu, ref tz) => { - date_range_impl(name, start, stop, every, closed, tu, tz.as_ref())? - }, - DataType::Time => date_range_impl( - name, - start, - stop, - every, - closed, - TimeUnit::Nanoseconds, - None, - )?, - _ => unimplemented!(), - }; - Ok(out.cast(&dtype).unwrap().into_series()) -} - -pub(super) fn temporal_ranges_dispatch( - s: &[Series], - name: &str, - every: Duration, - closed: ClosedWindow, - time_unit: Option, - time_zone: Option, -) -> PolarsResult { - let start = &s[0]; - let stop = &s[1]; - - polars_ensure!( - start.len() == stop.len(), - ComputeError: "'start' and 'stop' should have the same length", - ); - const TO_MS: i64 = SECONDS_IN_DAY * 1000; - - // Note: `start` and `stop` have already been cast to their supertype, - // so only `start`'s dtype needs to be matched against. - #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. 
- let mut dtype = match (start.dtype(), time_unit) { - (DataType::Date, time_unit) => { - let nsecs = every.nanoseconds(); - if nsecs == 0 { - DataType::Date - } else if let Some(tu) = time_unit { - DataType::Datetime(tu, None) - } else if nsecs % 1_000 != 0 { - DataType::Datetime(TimeUnit::Nanoseconds, None) - } else { - DataType::Datetime(TimeUnit::Microseconds, None) - } - }, - (DataType::Time, _) => DataType::Time, - // overwrite nothing, keep as-is - (DataType::Datetime(_, _), None) => start.dtype().clone(), - // overwrite time unit, keep timezone - (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), - _ => unreachable!(), - }; - - let (mut start, mut stop) = match dtype { - #[cfg(feature = "timezones")] - DataType::Datetime(_, Some(_)) => ( - polars_ops::prelude::replace_time_zone( - start.cast(&dtype)?.datetime().unwrap(), - None, - &Utf8Chunked::from_iter(std::iter::once("raise")), - )? - .into_series() - .to_physical_repr() - .cast(&DataType::Int64)?, - polars_ops::prelude::replace_time_zone( - stop.cast(&dtype)?.datetime().unwrap(), - None, - &Utf8Chunked::from_iter(std::iter::once("raise")), - )? - .into_series() - .to_physical_repr() - .cast(&DataType::Int64)?, - ), - _ => ( - start - .cast(&dtype)? - .to_physical_repr() - .cast(&DataType::Int64)?, - stop.cast(&dtype)? - .to_physical_repr() - .cast(&DataType::Int64)?, - ), - }; - - if dtype == DataType::Date { - start = &start * TO_MS; - stop = &stop * TO_MS; - } - - // overwrite time zone, if specified - match (&dtype, &time_zone) { - #[cfg(feature = "timezones")] - (DataType::Datetime(tu, _), Some(tz)) => { - dtype = DataType::Datetime(*tu, Some(tz.clone())); - }, - _ => {}, - }; - - let start = start.i64().unwrap(); - let stop = stop.i64().unwrap(); - - let list = match dtype { - DataType::Date => { - let mut builder = ListPrimitiveChunkedBuilder::::new( - name, - start.len(), - start.len() * 5, - DataType::Int32, - ); - for (start, stop) in start.into_iter().zip(stop) { - match (start, stop) { - (Some(start), Some(stop)) => { - let rng = date_range_impl( - "", - start, - stop, - every, - closed, - TimeUnit::Milliseconds, - None, - )?; - let rng = rng.cast(&DataType::Date).unwrap(); - let rng = rng.to_physical_repr(); - let rng = rng.i32().unwrap(); - builder.append_slice(rng.cont_slice().unwrap()) - }, - _ => builder.append_null(), - } - } - builder.finish().into_series() - }, - DataType::Datetime(tu, ref tz) => { - let mut builder = ListPrimitiveChunkedBuilder::::new( - name, - start.len(), - start.len() * 5, - DataType::Int64, - ); - for (start, stop) in start.into_iter().zip(stop) { - match (start, stop) { - (Some(start), Some(stop)) => { - let rng = date_range_impl("", start, stop, every, closed, tu, tz.as_ref())?; - builder.append_slice(rng.cont_slice().unwrap()) - }, - _ => builder.append_null(), - } - } - builder.finish().into_series() - }, - DataType::Time => { - let mut builder = ListPrimitiveChunkedBuilder::::new( - name, - start.len(), - start.len() * 5, - DataType::Int64, - ); - for (start, stop) in start.into_iter().zip(stop) { - match (start, stop) { - (Some(start), Some(stop)) => { - let rng = date_range_impl( - "", - start, - stop, - every, - closed, - TimeUnit::Nanoseconds, - None, - )?; - builder.append_slice(rng.cont_slice().unwrap()) - }, - _ => builder.append_null(), - } - } - builder.finish().into_series() - }, - _ => unimplemented!(), - }; - - let to_type = DataType::List(Box::new(dtype)); - list.cast(&to_type) -} diff --git a/crates/polars-plan/src/dsl/functions/arity.rs 
b/crates/polars-plan/src/dsl/functions/arity.rs index e37b158d19ee..a8735462181d 100644 --- a/crates/polars-plan/src/dsl/functions/arity.rs +++ b/crates/polars-plan/src/dsl/functions/arity.rs @@ -11,9 +11,9 @@ macro_rules! prepare_binary_function { }; } -/// Apply a closure on the two columns that are evaluated from `Expr` a and `Expr` b. +/// Apply a closure on the two columns that are evaluated from [`Expr`] a and [`Expr`] b. /// -/// The closure takes two arguments, each a `Series`. `output_type` must be the output dtype of the resulting `Series`. +/// The closure takes two arguments, each a [`Series`]. `output_type` must be the output dtype of the resulting [`Series`]. pub fn map_binary(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr where F: Fn(Series, Series) -> PolarsResult> + Send + Sync, diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs index e28a1697eefd..e009e9e61918 100644 --- a/crates/polars-plan/src/dsl/functions/coerce.rs +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -3,16 +3,15 @@ use super::*; /// Take several expressions and collect them into a [`StructChunked`]. #[cfg(feature = "dtype-struct")] -pub fn as_struct(exprs: &[Expr]) -> Expr { - map_multiple( - |s| StructChunked::new(s[0].name(), s).map(|ca| Some(ca.into_series())), - exprs, - GetOutput::map_fields(|fld| Field::new(fld[0].name(), DataType::Struct(fld.to_vec()))), - ) - .with_function_options(|mut options| { - options.input_wildcard_expansion = true; - options.fmt_str = "as_struct"; - options.pass_name_to_apply = true; - options - }) +pub fn as_struct(exprs: Vec) -> Expr { + Expr::Function { + input: exprs, + function: FunctionExpr::AsStruct, + options: FunctionOptions { + input_wildcard_expansion: true, + pass_name_to_apply: true, + collect_groups: ApplyOptions::ApplyFlat, + ..Default::default() + }, + } } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index cb6728653f26..8517ef115387 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -236,11 +236,18 @@ pub fn max_horizontal>(exprs: E) -> Expr { if exprs.is_empty() { return Expr::Columns(Vec::new()); } - let func = |s1, s2| { - let df = DataFrame::new_no_checks(vec![s1, s2]); - df.hmax() - }; - reduce_exprs(func, exprs).alias("max") + + Expr::Function { + input: exprs, + function: FunctionExpr::MaxHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + input_wildcard_expansion: true, + auto_explode: true, + allow_rename: true, + ..Default::default() + }, + } } /// Create a new column with the the minimum value per row. @@ -251,25 +258,38 @@ pub fn min_horizontal>(exprs: E) -> Expr { if exprs.is_empty() { return Expr::Columns(Vec::new()); } - let func = |s1, s2| { - let df = DataFrame::new_no_checks(vec![s1, s2]); - df.hmin() - }; - reduce_exprs(func, exprs).alias("min") + + Expr::Function { + input: exprs, + function: FunctionExpr::MinHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + input_wildcard_expansion: true, + auto_explode: true, + allow_rename: true, + ..Default::default() + }, + } } /// Create a new column with the the sum of the values in each row. /// /// The name of the resulting column will be `"sum"`; use [`alias`](Expr::alias) to choose a different name. 
pub fn sum_horizontal>(exprs: E) -> Expr { - let mut exprs = exprs.as_ref().to_vec(); - let func = |s1, s2| Ok(Some(&s1 + &s2)); - let init = match exprs.pop() { - Some(e) => e, - // use u32 as that is not cast to float as eagerly - _ => lit(0u32), - }; - fold_exprs(init, func, exprs).alias("sum") + let exprs = exprs.as_ref().to_vec(); + + Expr::Function { + input: exprs, + function: FunctionExpr::SumHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + input_wildcard_expansion: true, + auto_explode: true, + cast_to_supertypes: false, + allow_rename: true, + ..Default::default() + }, + } } /// Folds the expressions from left to right keeping the first non-null values. @@ -281,7 +301,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr { input, function: FunctionExpr::Coalesce, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::ApplyFlat, cast_to_supertypes: true, input_wildcard_expansion: true, ..Default::default() diff --git a/crates/polars-plan/src/dsl/functions/mod.rs b/crates/polars-plan/src/dsl/functions/mod.rs index 56d72d4706ee..f95be89c0af6 100644 --- a/crates/polars-plan/src/dsl/functions/mod.rs +++ b/crates/polars-plan/src/dsl/functions/mod.rs @@ -1,14 +1,15 @@ //! # Functions //! //! Functions on expressions that might be useful. -//! mod arity; mod coerce; mod concat; mod correlation; mod horizontal; mod index; +#[cfg(feature = "range")] mod range; +mod repeat; mod selectors; mod syntactic_sugar; mod temporal; @@ -20,14 +21,18 @@ pub use correlation::*; pub use horizontal::*; pub use index::*; #[cfg(feature = "temporal")] -use polars_core::export::arrow::temporal_conversions::NANOSECONDS; +use polars_core::export::arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; #[cfg(feature = "temporal")] use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; #[cfg(feature = "dtype-struct")] use polars_core::utils::get_supertype; +#[cfg(all(feature = "range", feature = "temporal"))] +pub use range::date_range; // This shouldn't be necessary, but clippy complains about dead code +#[cfg(all(feature = "range", feature = "dtype-time"))] +pub use range::time_range; // This shouldn't be necessary, but clippy complains about dead code +#[cfg(feature = "range")] pub use range::*; -#[cfg(feature = "temporal")] -pub use range::{date_range, time_range}; +pub use repeat::*; pub use selectors::*; pub use syntactic_sugar::*; pub use temporal::*; diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs index cf1093b154a5..89c3e6bf6312 100644 --- a/crates/polars-plan/src/dsl/functions/range.rs +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -3,12 +3,10 @@ use super::*; /// Generate a range of integers. /// /// Alias for `int_range`. -#[cfg(feature = "range")] pub fn arange(start: Expr, end: Expr, step: i64) -> Expr { int_range(start, end, step) } -#[cfg(feature = "range")] /// Generate a range of integers. pub fn int_range(start: Expr, end: Expr, step: i64) -> Expr { let input = vec![start, end]; @@ -23,7 +21,6 @@ pub fn int_range(start: Expr, end: Expr, step: i64) -> Expr { } } -#[cfg(feature = "range")] /// Generate a range of integers for each row of the input columns. pub fn int_ranges(start: Expr, end: Expr, step: i64) -> Expr { let input = vec![start, end]; @@ -38,39 +35,12 @@ pub fn int_ranges(start: Expr, end: Expr, step: i64) -> Expr { } } -pub trait Range { - fn into_range(self, high: T) -> Expr; -} - -macro_rules! 
impl_into_range { - ($dt: ty) => { - impl Range<$dt> for $dt { - fn into_range(self, high: $dt) -> Expr { - Expr::Literal(LiteralValue::Range { - low: self as i64, - high: high as i64, - data_type: DataType::Int32, - }) - } - } - }; -} - -impl_into_range!(i32); -impl_into_range!(i64); -impl_into_range!(u32); - -/// Create a range literal. -pub fn range>(low: T, high: T) -> Expr { - low.into_range(high) -} - /// Create a date range from a `start` and `stop` expression. #[cfg(feature = "temporal")] pub fn date_range( start: Expr, end: Expr, - every: Duration, + interval: Duration, closed: ClosedWindow, time_unit: Option, time_zone: Option, @@ -79,8 +49,8 @@ pub fn date_range( Expr::Function { input, - function: FunctionExpr::TemporalExpr(TemporalFunction::DateRange { - every, + function: FunctionExpr::Range(RangeFunction::DateRange { + interval, closed, time_unit, time_zone, @@ -99,7 +69,7 @@ pub fn date_range( pub fn date_ranges( start: Expr, end: Expr, - every: Duration, + interval: Duration, closed: ClosedWindow, time_unit: Option, time_zone: Option, @@ -108,8 +78,8 @@ pub fn date_ranges( Expr::Function { input, - function: FunctionExpr::TemporalExpr(TemporalFunction::DateRanges { - every, + function: FunctionExpr::Range(RangeFunction::DateRanges { + interval, closed, time_unit, time_zone, @@ -123,54 +93,92 @@ pub fn date_ranges( } } -/// Generate a time range. -#[cfg(feature = "temporal")] -pub fn time_range(start: Expr, end: Expr, every: Duration, closed: ClosedWindow) -> Expr { +/// Create a datetime range from a `start` and `stop` expression. +#[cfg(feature = "dtype-datetime")] +pub fn datetime_range( + start: Expr, + end: Expr, + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> Expr { let input = vec![start, end]; Expr::Function { input, - function: FunctionExpr::TemporalExpr(TemporalFunction::TimeRange { every, closed }), + function: FunctionExpr::Range(RangeFunction::DatetimeRange { + interval, + closed, + time_unit, + time_zone, + }), options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, - cast_to_supertypes: false, + cast_to_supertypes: true, allow_rename: true, ..Default::default() }, } } -/// Create a column of time ranges from a `start` and `stop` expression. -#[cfg(feature = "temporal")] -pub fn time_ranges(start: Expr, end: Expr, every: Duration, closed: ClosedWindow) -> Expr { +/// Create a column of datetime ranges from a `start` and `stop` expression. +#[cfg(feature = "dtype-datetime")] +pub fn datetime_ranges( + start: Expr, + end: Expr, + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> Expr { let input = vec![start, end]; Expr::Function { input, - function: FunctionExpr::TemporalExpr(TemporalFunction::TimeRanges { every, closed }), + function: FunctionExpr::Range(RangeFunction::DatetimeRanges { + interval, + closed, + time_unit, + time_zone, + }), + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyGroups, + cast_to_supertypes: true, + allow_rename: true, + ..Default::default() + }, + } +} + +/// Generate a time range. 
+#[cfg(feature = "dtype-time")] +pub fn time_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + let input = vec![start, end]; + + Expr::Function { + input, + function: FunctionExpr::Range(RangeFunction::TimeRange { interval, closed }), options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, - cast_to_supertypes: false, allow_rename: true, ..Default::default() }, } } -/// Create a column of length `n` containing `n` copies of the literal `value`. Generally you won't need this function, -/// as `lit(value)` already represents a column containing only `value` whose length is automatically set to the correct -/// number of rows. -pub fn repeat>(value: E, n: Expr) -> Expr { - let function = |s: Series, n: Series| { - polars_ensure!( - n.dtype().is_integer(), - SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() - ); - let first_value = n.get(0)?; - let n = first_value.extract::().ok_or_else( - || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), - )?; - Ok(Some(s.new_from_index(0, n))) - }; - apply_binary(value.into(), n, function, GetOutput::same_type()).alias("repeat") +/// Create a column of time ranges from a `start` and `stop` expression. +#[cfg(feature = "dtype-time")] +pub fn time_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + let input = vec![start, end]; + + Expr::Function { + input, + function: FunctionExpr::Range(RangeFunction::TimeRanges { interval, closed }), + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyGroups, + allow_rename: true, + ..Default::default() + }, + } } diff --git a/crates/polars-plan/src/dsl/functions/repeat.rs b/crates/polars-plan/src/dsl/functions/repeat.rs new file mode 100644 index 000000000000..1b32abc97b5a --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/repeat.rs @@ -0,0 +1,19 @@ +use super::*; + +/// Create a column of length `n` containing `n` copies of the literal `value`. Generally you won't need this function, +/// as `lit(value)` already represents a column containing only `value` whose length is automatically set to the correct +/// number of rows. +pub fn repeat>(value: E, n: Expr) -> Expr { + let function = |s: Series, n: Series| { + polars_ensure!( + n.dtype().is_integer(), + SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() + ); + let first_value = n.get(0)?; + let n = first_value.extract::().ok_or_else( + || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), + )?; + Ok(Some(s.new_from_index(0, n))) + }; + apply_binary(value.into(), n, function, GetOutput::same_type()).alias("repeat") +} diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index 4e5211c2d88e..5d702059150b 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -10,9 +10,9 @@ macro_rules! impl_unit_setter { }; } -/// Arguments used by `datetime` in order to produce an `Expr` of `Datetime` +/// Arguments used by `datetime` in order to produce an [`Expr`] of Datetime /// -/// Construct a `DatetimeArgs` with `DatetimeArgs::new(y, m, d)`. This will set the other time units to `lit(0)`. You +/// Construct a [`DatetimeArgs`] with `DatetimeArgs::new(y, m, d)`. This will set the other time units to `lit(0)`. You /// can then set the other fields with the `with_*` methods, or use `with_hms` to set `hour`, `minute`, and `second` all /// at once. 
/// @@ -150,9 +150,9 @@ pub fn datetime(args: DatetimeArgs) -> Expr { } } -/// Arguments used by `duration` in order to produce an `Expr` of `Duration` +/// Arguments used by `duration` in order to produce an [`Expr`] of [`Duration`] /// -/// To construct a `DurationArgs`, use struct literal syntax with `..Default::default()` to leave unspecified fields at +/// To construct a [`DurationArgs`], use struct literal syntax with `..Default::default()` to leave unspecified fields at /// their default value of `lit(0)`, as demonstrated below. /// /// ``` @@ -177,6 +177,7 @@ pub struct DurationArgs { pub milliseconds: Expr, pub microseconds: Expr, pub nanoseconds: Expr, + pub time_unit: TimeUnit, } impl Default for DurationArgs { @@ -190,12 +191,13 @@ impl Default for DurationArgs { milliseconds: lit(0), microseconds: lit(0), nanoseconds: lit(0), + time_unit: TimeUnit::Microseconds, } } } impl DurationArgs { - /// Create a new `DurationArgs` with all fields set to `lit(0)`. Use the `with_*` methods to set the fields. + /// Create a new [`DurationArgs`] with all fields set to `lit(0)`. Use the `with_*` methods to set the fields. pub fn new() -> Self { Self::default() } @@ -250,7 +252,7 @@ impl DurationArgs { impl_unit_setter!(with_nanoseconds(nanoseconds)); } -/// Construct a column of `Duration` from the provided [`DurationArgs`] +/// Construct a column of [`Duration`] from the provided [`DurationArgs`] #[cfg(feature = "temporal")] pub fn duration(args: DurationArgs) -> Expr { let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| { @@ -258,15 +260,15 @@ pub fn duration(args: DurationArgs) -> Expr { if s.iter().any(|s| s.is_empty()) { return Ok(Some(Series::new_empty( s[0].name(), - &DataType::Duration(TimeUnit::Nanoseconds), + &DataType::Duration(args.time_unit), ))); } let days = s[0].cast(&DataType::Int64).unwrap(); let seconds = s[1].cast(&DataType::Int64).unwrap(); let mut nanoseconds = s[2].cast(&DataType::Int64).unwrap(); - let microseconds = s[3].cast(&DataType::Int64).unwrap(); - let milliseconds = s[4].cast(&DataType::Int64).unwrap(); + let mut microseconds = s[3].cast(&DataType::Int64).unwrap(); + let mut milliseconds = s[4].cast(&DataType::Int64).unwrap(); let minutes = s[5].cast(&DataType::Int64).unwrap(); let hours = s[6].cast(&DataType::Int64).unwrap(); let weeks = s[7].cast(&DataType::Int64).unwrap(); @@ -278,34 +280,59 @@ pub fn duration(args: DurationArgs) -> Expr { (s.len() != max_len && s.get(0).unwrap() != AnyValue::Int64(0)) || s.len() == max_len }; - if nanoseconds.len() != max_len { - nanoseconds = nanoseconds.new_from_index(0, max_len); - } - if condition(µseconds) { - nanoseconds = nanoseconds + (microseconds * 1_000); - } - if condition(&milliseconds) { - nanoseconds = nanoseconds + (milliseconds * 1_000_000); - } + let multiplier = match args.time_unit { + TimeUnit::Nanoseconds => NANOSECONDS, + TimeUnit::Microseconds => MICROSECONDS, + TimeUnit::Milliseconds => MILLISECONDS, + }; + + let mut duration = match args.time_unit { + TimeUnit::Nanoseconds => { + if nanoseconds.len() != max_len { + nanoseconds = nanoseconds.new_from_index(0, max_len); + } + if condition(µseconds) { + nanoseconds = nanoseconds + (microseconds * 1_000); + } + if condition(&milliseconds) { + nanoseconds = nanoseconds + (milliseconds * 1_000_000); + } + nanoseconds + }, + TimeUnit::Microseconds => { + if microseconds.len() != max_len { + microseconds = microseconds.new_from_index(0, max_len); + } + if condition(&milliseconds) { + microseconds = microseconds + (milliseconds * 1_000); + } + 
microseconds + }, + TimeUnit::Milliseconds => { + if milliseconds.len() != max_len { + milliseconds = milliseconds.new_from_index(0, max_len); + } + milliseconds + }, + }; + if condition(&seconds) { - nanoseconds = nanoseconds + (seconds * NANOSECONDS); + duration = duration + (seconds * multiplier); } if condition(&days) { - nanoseconds = nanoseconds + (days * NANOSECONDS * SECONDS_IN_DAY); + duration = duration + (days * multiplier * SECONDS_IN_DAY); } if condition(&minutes) { - nanoseconds = nanoseconds + minutes * NANOSECONDS * 60; + duration = duration + minutes * multiplier * 60; } if condition(&hours) { - nanoseconds = nanoseconds + hours * NANOSECONDS * 60 * 60; + duration = duration + hours * multiplier * 60 * 60; } if condition(&weeks) { - nanoseconds = nanoseconds + weeks * NANOSECONDS * SECONDS_IN_DAY * 7; + duration = duration + weeks * multiplier * SECONDS_IN_DAY * 7; } - nanoseconds - .cast(&DataType::Duration(TimeUnit::Nanoseconds)) - .map(Some) + duration.cast(&DataType::Duration(args.time_unit)).map(Some) }) as Arc); Expr::AnonymousFunction { @@ -320,7 +347,7 @@ pub fn duration(args: DurationArgs) -> Expr { args.weeks, ], function, - output_type: GetOutput::from_type(DataType::Duration(TimeUnit::Nanoseconds)), + output_type: GetOutput::from_type(DataType::Duration(args.time_unit)), options: FunctionOptions { collect_groups: ApplyOptions::ApplyFlat, input_wildcard_expansion: true, diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 8665c76c03d5..6e9bde5b68eb 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -28,49 +28,30 @@ impl ListNameSpace { .with_fmt("list.all") } - /// Get lengths of the arrays in the List type. - pub fn lengths(self) -> Expr { - let function = |s: Series| { - let ca = s.list()?; - Ok(Some(ca.lst_lengths().into_series())) - }; + #[cfg(feature = "list_drop_nulls")] + pub fn drop_nulls(self) -> Expr { self.0 - .map(function, GetOutput::from_type(IDX_DTYPE)) - .with_fmt("list.len") + .map_private(FunctionExpr::ListExpr(ListFunction::DropNulls)) + } + + /// Return the number of elements in each list. + /// + /// Null values are treated like regular elements in this context. + pub fn len(self) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Length)) } /// Compute the maximum of the items in every sublist. pub fn max(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_max())), - GetOutput::map_field(|f| { - if let DataType::List(adt) = f.data_type() { - Field::new(f.name(), *adt.clone()) - } else { - // inner type - f.clone() - } - }), - ) - .with_fmt("list.max") + .map_private(FunctionExpr::ListExpr(ListFunction::Max)) } /// Compute the minimum of the items in every sublist. pub fn min(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_min())), - GetOutput::map_field(|f| { - if let DataType::List(adt) = f.data_type() { - Field::new(f.name(), *adt.clone()) - } else { - // inner type - f.clone() - } - }), - ) - .with_fmt("list.min") + .map_private(FunctionExpr::ListExpr(ListFunction::Min)) } /// Compute the sum the items in every sublist. @@ -82,57 +63,41 @@ impl ListNameSpace { /// Compute the mean of every sublist and return a `Series` of dtype `Float64` pub fn mean(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_mean().into_series())), - GetOutput::from_type(DataType::Float64), - ) - .with_fmt("list.mean") + .map_private(FunctionExpr::ListExpr(ListFunction::Mean)) } /// Sort every sublist. 
pub fn sort(self, options: SortOptions) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_sort(options).into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.sort") + .map_private(FunctionExpr::ListExpr(ListFunction::Sort(options))) } /// Reverse every sublist pub fn reverse(self) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_reverse().into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.reverse") + .map_private(FunctionExpr::ListExpr(ListFunction::Reverse)) } /// Keep only the unique values in every sublist. pub fn unique(self) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_unique()?.into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.unique") + .map_private(FunctionExpr::ListExpr(ListFunction::Unique(false))) } /// Keep only the unique values in every sublist. pub fn unique_stable(self) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_unique_stable()?.into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.unique_stable") + .map_private(FunctionExpr::ListExpr(ListFunction::Unique(true))) } /// Get items in every sublist by index. pub fn get(self, index: Expr) -> Expr { - self.0 - .map_many_private(FunctionExpr::ListExpr(ListFunction::Get), &[index], false) + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Get), + &[index], + true, + false, + ) } /// Get items in every sublist by multiple indexes. @@ -145,6 +110,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Take(null_on_oob)), &[index], + true, false, ) } @@ -162,59 +128,45 @@ impl ListNameSpace { /// Join all string items in a sublist and place a separator between them. /// # Error /// This errors if inner type of list `!= DataType::Utf8`. - pub fn join(self, separator: &str) -> Expr { - let separator = separator.to_string(); - self.0 - .map( - move |s| { - s.list()? - .lst_join(&separator) - .map(|ca| Some(ca.into_series())) - }, - GetOutput::from_type(DataType::Utf8), - ) - .with_fmt("list.join") + pub fn join(self, separator: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Join), + &[separator], + false, + false, + ) } /// Return the index of the minimal value of every sublist pub fn arg_min(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_arg_min().into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("list.arg_min") + .map_private(FunctionExpr::ListExpr(ListFunction::ArgMin)) } /// Return the index of the maximum value of every sublist pub fn arg_max(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_arg_max().into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("list.arg_max") + .map_private(FunctionExpr::ListExpr(ListFunction::ArgMax)) } /// Diff every sublist. #[cfg(feature = "diff")] pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_diff(n, null_behavior)?.into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.diff") + .map_private(FunctionExpr::ListExpr(ListFunction::Diff { + n, + null_behavior, + })) } /// Shift every sublist. - pub fn shift(self, periods: i64) -> Expr { - self.0 - .map( - move |s| Ok(Some(s.list()?.lst_shift(periods).into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.shift") + pub fn shift(self, periods: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Shift), + &[periods], + false, + false, + ) } /// Slice every sublist. 
@@ -222,6 +174,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Slice), &[offset, length], + true, false, ) } @@ -297,47 +250,49 @@ impl ListNameSpace { pub fn contains>(self, other: E) -> Expr { let other = other.into(); - Expr::Function { - input: vec![self.0, other], - function: FunctionExpr::ListExpr(ListFunction::Contains), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - input_wildcard_expansion: true, - auto_explode: true, - ..Default::default() - }, - } + self.0 + .map_many_private( + FunctionExpr::ListExpr(ListFunction::Contains), + &[other], + true, + false, + ) + .with_function_options(|mut options| { + options.input_wildcard_expansion = true; + options + }) } #[cfg(feature = "list_count")] /// Count how often the value produced by ``element`` occurs. - pub fn count_match>(self, element: E) -> Expr { + pub fn count_matches>(self, element: E) -> Expr { let other = element.into(); - Expr::Function { - input: vec![self.0, other], - function: FunctionExpr::ListExpr(ListFunction::CountMatch), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - input_wildcard_expansion: true, - auto_explode: true, - ..Default::default() - }, - } + self.0 + .map_many_private( + FunctionExpr::ListExpr(ListFunction::CountMatches), + &[other], + true, + false, + ) + .with_function_options(|mut options| { + options.input_wildcard_expansion = true; + options + }) } #[cfg(feature = "list_sets")] fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr { - Expr::Function { - input: vec![self.0, other], - function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - input_wildcard_expansion: true, - auto_explode: false, - cast_to_supertypes: true, - ..Default::default() - }, - } + self.0 + .map_many_private( + FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), + &[other], + false, + true, + ) + .with_function_options(|mut options| { + options.input_wildcard_expansion = true; + options + }) } /// Return the SET UNION between both list arrays. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 5d93efb895c3..1226c869b829 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1,5 +1,9 @@ #![allow(ambiguous_glob_reexports)] //! Domain specific language for the Lazy API. 
+#[cfg(feature = "rolling_window")] +use polars_core::utils::ensure_sorted_arg; +#[cfg(feature = "mode")] +use polars_ops::chunked_array::mode::mode; #[cfg(feature = "dtype-categorical")] pub mod cat; #[cfg(feature = "dtype-categorical")] @@ -15,7 +19,6 @@ mod expr; mod expr_dyn_fn; mod from; pub(crate) mod function_expr; -#[cfg(feature = "compile")] pub mod functions; mod list; #[cfg(feature = "meta")] @@ -31,7 +34,7 @@ mod selector; pub mod string; #[cfg(feature = "dtype-struct")] mod struct_; - +pub mod udf; use std::fmt::Debug; use std::sync::Arc; @@ -39,6 +42,7 @@ pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; pub use expr::*; +pub use function_expr::schema::FieldsMapper; pub use function_expr::*; pub use functions::*; pub use list::*; @@ -52,10 +56,11 @@ use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; use polars_core::utils::{try_get_supertype, NoNull}; #[cfg(feature = "rolling_window")] -use polars_time::series::SeriesOpsTime; +use polars_time::prelude::SeriesOpsTime; pub(crate) use selector::Selector; #[cfg(feature = "dtype-struct")] pub use struct_::*; +pub use udf::UserDefinedFunction; use crate::constants::MAP_LIST_NAME; pub use crate::logical_plan::lit; @@ -103,8 +108,8 @@ impl Expr { } } - /// Overwrite the function name used for formatting - /// this is not intended to be used + /// Overwrite the function name used for formatting. + /// (this is not intended to be used). #[doc(hidden)] pub fn with_fmt(self, name: &'static str) -> Expr { self.with_function_options(|mut options| { @@ -113,50 +118,50 @@ impl Expr { }) } - /// Compare `Expr` with other `Expr` on equality + /// Compare `Expr` with other `Expr` on equality. pub fn eq>(self, other: E) -> Expr { binary_expr(self, Operator::Eq, other.into()) } - /// Compare `Expr` with other `Expr` on equality where `None == None` + /// Compare `Expr` with other `Expr` on equality where `None == None`. pub fn eq_missing>(self, other: E) -> Expr { binary_expr(self, Operator::EqValidity, other.into()) } - /// Compare `Expr` with other `Expr` on non-equality + /// Compare `Expr` with other `Expr` on non-equality. pub fn neq>(self, other: E) -> Expr { binary_expr(self, Operator::NotEq, other.into()) } - /// Compare `Expr` with other `Expr` on non-equality where `None == None` + /// Compare `Expr` with other `Expr` on non-equality where `None == None`. pub fn neq_missing>(self, other: E) -> Expr { binary_expr(self, Operator::NotEqValidity, other.into()) } - /// Check if `Expr` < `Expr` + /// Check if `Expr` < `Expr`. pub fn lt>(self, other: E) -> Expr { binary_expr(self, Operator::Lt, other.into()) } - /// Check if `Expr` > `Expr` + /// Check if `Expr` > `Expr`. pub fn gt>(self, other: E) -> Expr { binary_expr(self, Operator::Gt, other.into()) } - /// Check if `Expr` >= `Expr` + /// Check if `Expr` >= `Expr`. pub fn gt_eq>(self, other: E) -> Expr { binary_expr(self, Operator::GtEq, other.into()) } - /// Check if `Expr` <= `Expr` + /// Check if `Expr` <= `Expr`. pub fn lt_eq>(self, other: E) -> Expr { binary_expr(self, Operator::LtEq, other.into()) } - /// Negate `Expr` + /// Negate `Expr`. #[allow(clippy::should_implement_trait)] pub fn not(self) -> Expr { - self.map_private(BooleanFunction::IsNot.into()) + self.map_private(BooleanFunction::Not.into()) } /// Rename Column. @@ -176,12 +181,12 @@ impl Expr { self.map_private(BooleanFunction::IsNotNull.into()) } - /// Drop null values + /// Drop null values. 
pub fn drop_nulls(self) -> Self { self.apply(|s| Ok(Some(s.drop_nulls())), GetOutput::same_type()) } - /// Drop NaN values + /// Drop NaN values. pub fn drop_nans(self) -> Self { self.apply_private(FunctionExpr::DropNans) } @@ -252,7 +257,7 @@ impl Expr { AggExpr::Last(Box::new(self)).into() } - /// Aggregate the group to a Series + /// Aggregate the group to a Series. pub fn implode(self) -> Self { AggExpr::Implode(Box::new(self)).into() } @@ -272,12 +277,12 @@ impl Expr { AggExpr::AggGroups(Box::new(self)).into() } - /// Alias for explode + /// Alias for `explode`. pub fn flatten(self) -> Self { self.explode() } - /// Explode the utf8/ list column + /// Explode the utf8/ list column. pub fn explode(self) -> Self { Expr::Explode(Box::new(self)) } @@ -316,12 +321,12 @@ impl Expr { ) } - /// Get the first `n` elements of the Expr result + /// Get the first `n` elements of the Expr result. pub fn head(self, length: Option) -> Self { self.slice(lit(0), lit(length.unwrap_or(10) as u64)) } - /// Get the last `n` elements of the Expr result + /// Get the last `n` elements of the Expr result. pub fn tail(self, length: Option) -> Self { let len = length.unwrap_or(10); self.slice(lit(-(len as i64)), lit(len as u64)) @@ -347,7 +352,7 @@ impl Expr { .with_fmt("arg_unique") } - /// Get the index value that has the minimum value + /// Get the index value that has the minimum value. pub fn arg_min(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, @@ -368,7 +373,7 @@ impl Expr { ) } - /// Get the index value that has the maximum value + /// Get the index value that has the maximum value. pub fn arg_max(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, @@ -422,7 +427,7 @@ impl Expr { } /// Cast expression to another data type. - /// Throws an error if conversion had overflows + /// Throws an error if conversion had overflows. pub fn strict_cast(self, data_type: DataType) -> Self { Expr::Cast { expr: Box::new(self), @@ -471,22 +476,16 @@ impl Expr { /// /// This has time complexity `O(n + k log(n))`. #[cfg(feature = "top_k")] - pub fn top_k(self, k: usize) -> Self { - self.apply_private(FunctionExpr::TopK { - k, - descending: false, - }) + pub fn top_k(self, k: Expr) -> Self { + self.apply_many_private(FunctionExpr::TopK(false), &[k], false, false) } /// Returns the `k` smallest elements. /// /// This has time complexity `O(n + k log(n))`. #[cfg(feature = "top_k")] - pub fn bottom_k(self, k: usize) -> Self { - self.apply_private(FunctionExpr::TopK { - k, - descending: true, - }) + pub fn bottom_k(self, k: Expr) -> Self { + self.apply_many_private(FunctionExpr::TopK(true), &[k], false, false) } /// Reverse column @@ -532,7 +531,7 @@ impl Expr { } } - /// Apply a function/closure once the logical plan get executed with many arguments + /// Apply a function/closure once the logical plan get executed with many arguments. /// /// See the [`Expr::map`] function for the differences between [`map`](Expr::map) and [`apply`](Expr::apply). 
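The `top_k`/`bottom_k` change above takes the number of elements as an expression rather than a `usize`; a minimal sketch:

    // k is now an Expr, so it can also come from another column or a computed value
    let top3 = col("score").top_k(lit(3));
    let bottom3 = col("score").bottom_k(lit(3));
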
pub fn map_many(self, function: F, arguments: &[Expr], output_type: GetOutput) -> Self @@ -687,6 +686,7 @@ impl Expr { self, function_expr: FunctionExpr, arguments: &[Expr], + auto_explode: bool, cast_to_supertypes: bool, ) -> Self { let mut input = Vec::with_capacity(arguments.len() + 1); @@ -698,31 +698,31 @@ impl Expr { function: function_expr, options: FunctionOptions { collect_groups: ApplyOptions::ApplyFlat, - auto_explode: true, + auto_explode, cast_to_supertypes, ..Default::default() }, } } - /// Get mask of finite values if dtype is Float + /// Get mask of finite values if dtype is Float. #[allow(clippy::wrong_self_convention)] pub fn is_finite(self) -> Self { self.map_private(BooleanFunction::IsFinite.into()) } - /// Get mask of infinite values if dtype is Float + /// Get mask of infinite values if dtype is Float. #[allow(clippy::wrong_self_convention)] pub fn is_infinite(self) -> Self { self.map_private(BooleanFunction::IsInfinite.into()) } - /// Get mask of NaN values if dtype is Float + /// Get mask of NaN values if dtype is Float. pub fn is_nan(self) -> Self { self.map_private(BooleanFunction::IsNan.into()) } - /// Get inverse mask of NaN values if dtype is Float + /// Get inverse mask of NaN values if dtype is Float. pub fn is_not_nan(self) -> Self { self.map_private(BooleanFunction::IsNotNan.into()) } @@ -747,27 +747,27 @@ impl Expr { self.apply_private(FunctionExpr::Cumcount { reverse }) } - /// Get an array with the cumulative sum computed at every element + /// Get an array with the cumulative sum computed at every element. pub fn cumsum(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumsum { reverse }) } - /// Get an array with the cumulative product computed at every element + /// Get an array with the cumulative product computed at every element. pub fn cumprod(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumprod { reverse }) } - /// Get an array with the cumulative min computed at every element + /// Get an array with the cumulative min computed at every element. pub fn cummin(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummin { reverse }) } - /// Get an array with the cumulative max computed at every element + /// Get an array with the cumulative max computed at every element. pub fn cummax(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummax { reverse }) } - /// Get the product aggregation of an expression + /// Get the product aggregation of an expression. pub fn product(self) -> Self { let options = FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, @@ -793,20 +793,12 @@ impl Expr { /// Fill missing value with next non-null. pub fn backward_fill(self, limit: FillNullLimit) -> Self { - self.apply( - move |s: Series| s.fill_null(FillNullStrategy::Backward(limit)).map(Some), - GetOutput::same_type(), - ) - .with_fmt("backward_fill") + self.apply_private(FunctionExpr::BackwardFill { limit }) } /// Fill missing value with previous non-null. pub fn forward_fill(self, limit: FillNullLimit) -> Self { - self.apply( - move |s: Series| s.fill_null(FillNullStrategy::Forward(limit)).map(Some), - GetOutput::same_type(), - ) - .with_fmt("forward_fill") + self.apply_private(FunctionExpr::ForwardFill { limit }) } /// Round underlying floating point array to given decimal numbers. @@ -835,29 +827,44 @@ impl Expr { /// Clip underlying values to a set boundary. 
#[cfg(feature = "round_series")] - pub fn clip(self, min: AnyValue<'_>, max: AnyValue<'_>) -> Self { - self.map_private(FunctionExpr::Clip { - min: Some(min.into_static().unwrap()), - max: Some(max.into_static().unwrap()), - }) + pub fn clip(self, min: Expr, max: Expr) -> Self { + self.map_many_private( + FunctionExpr::Clip { + has_min: true, + has_max: true, + }, + &[min, max], + false, + false, + ) } /// Clip underlying values to a set boundary. #[cfg(feature = "round_series")] - pub fn clip_max(self, max: AnyValue<'_>) -> Self { - self.map_private(FunctionExpr::Clip { - min: None, - max: Some(max.into_static().unwrap()), - }) + pub fn clip_max(self, max: Expr) -> Self { + self.map_many_private( + FunctionExpr::Clip { + has_min: false, + has_max: true, + }, + &[max], + false, + false, + ) } /// Clip underlying values to a set boundary. #[cfg(feature = "round_series")] - pub fn clip_min(self, min: AnyValue<'_>) -> Self { - self.map_private(FunctionExpr::Clip { - min: Some(min.into_static().unwrap()), - max: None, - }) + pub fn clip_min(self, min: Expr) -> Self { + self.map_many_private( + FunctionExpr::Clip { + has_min: true, + has_max: false, + }, + &[min], + false, + false, + ) } /// Convert all values to their absolute/positive value. @@ -923,7 +930,7 @@ impl Expr { pub fn over_with_options, IE: Into + Clone>( self, partition_by: E, - options: WindowOptions, + options: WindowMapping, ) -> Self { let partition_by = partition_by .as_ref() @@ -933,8 +940,16 @@ impl Expr { Expr::Window { function: Box::new(self), partition_by, - order_by: None, - options, + options: options.into(), + } + } + + #[cfg(feature = "dynamic_group_by")] + pub fn rolling(self, options: RollingGroupOptions) -> Self { + Expr::Window { + function: Box::new(self), + partition_by: vec![], + options: WindowType::Rolling(options), } } @@ -977,24 +992,24 @@ impl Expr { AggExpr::Count(Box::new(self)).into() } - /// Standard deviation of the values of the Series + /// Standard deviation of the values of the Series. pub fn std(self, ddof: u8) -> Self { AggExpr::Std(Box::new(self), ddof).into() } - /// Variance of the values of the Series + /// Variance of the values of the Series. pub fn var(self, ddof: u8) -> Self { AggExpr::Var(Box::new(self), ddof).into() } - /// Get a mask of duplicated values + /// Get a mask of duplicated values. #[allow(clippy::wrong_self_convention)] #[cfg(feature = "is_unique")] pub fn is_duplicated(self) -> Self { self.apply_private(BooleanFunction::IsDuplicated.into()) } - /// Get a mask of unique values + /// Get a mask of unique values. #[allow(clippy::wrong_self_convention)] #[cfg(feature = "is_unique")] pub fn is_unique(self) -> Self { @@ -1011,17 +1026,17 @@ impl Expr { }) } - /// and operation + /// "and" operation. pub fn and>(self, expr: E) -> Self { binary_expr(self, Operator::And, expr.into()) } - // xor operation + /// "xor" operation. pub fn xor>(self, expr: E) -> Self { binary_expr(self, Operator::Xor, expr.into()) } - /// or operation + /// "or" operation. 
pub fn or>(self, expr: E) -> Self { binary_expr(self, Operator::Or, expr.into()) } @@ -1050,7 +1065,7 @@ impl Expr { let arguments = &[other]; // we don't have to apply on groups, so this is faster if has_literal { - self.map_many_private(BooleanFunction::IsIn.into(), arguments, true) + self.map_many_private(BooleanFunction::IsIn.into(), arguments, true, true) } else { self.apply_many_private(BooleanFunction::IsIn.into(), arguments, true, true) } @@ -1078,7 +1093,7 @@ impl Expr { let by = &s[1]; let s = &s[0]; let by = by.cast(&IDX_DTYPE)?; - Ok(Some(s.repeat_by(by.idx()?)?.into_series())) + Ok(Some(repeat_by(s, by.idx()?)?.into_series())) }; self.apply_many( @@ -1091,22 +1106,30 @@ impl Expr { #[cfg(feature = "repeat_by")] /// Repeat the column `n` times, where `n` is determined by the values in `by`. - /// This yields an `Expr` of dtype `List` + /// This yields an `Expr` of dtype `List`. pub fn repeat_by>(self, by: E) -> Expr { self.repeat_by_impl(by.into()) } - #[cfg(feature = "is_first")] + #[cfg(feature = "is_first_distinct")] #[allow(clippy::wrong_self_convention)] /// Get a mask of the first unique value. - pub fn is_first(self) -> Expr { - self.apply_private(BooleanFunction::IsFirst.into()) + pub fn is_first_distinct(self) -> Expr { + self.apply_private(BooleanFunction::IsFirstDistinct.into()) + } + + #[cfg(feature = "is_last_distinct")] + #[allow(clippy::wrong_self_convention)] + /// Get a mask of the last unique value. + pub fn is_last_distinct(self) -> Expr { + self.apply_private(BooleanFunction::IsLastDistinct.into()) } fn dot_impl(self, other: Expr) -> Expr { (self * other).sum() } + /// Compute the dot/inner product between two expressions. pub fn dot>(self, other: E) -> Expr { self.dot_impl(other.into()) } @@ -1114,11 +1137,8 @@ impl Expr { #[cfg(feature = "mode")] /// Compute the mode(s) of this column. This is the most occurring value. pub fn mode(self) -> Expr { - self.apply( - |s| s.mode().map(|ca| Some(ca.into_series())), - GetOutput::same_type(), - ) - .with_fmt("mode") + self.apply(|s| mode(&s).map(Some), GetOutput::same_type()) + .with_fmt("mode") } /// Keep the original root name @@ -1197,8 +1217,8 @@ impl Expr { Expr::Exclude(Box::new(self), v) } - // Interpolate None values #[cfg(feature = "interpolate")] + /// Fill null values using interpolation. pub fn interpolate(self, method: InterpolationMethod) -> Expr { self.apply_private(FunctionExpr::Interpolate(method)) } @@ -1227,12 +1247,16 @@ impl Expr { ComputeError: "`weights` is not supported in 'rolling by' expression" ); let (by, tz) = match by.dtype() { - DataType::Datetime(_, tz) => ( - by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?, - tz, + DataType::Datetime(tu, tz) => { + (by.cast(&DataType::Datetime(*tu, None))?, tz) + }, + DataType::Date => ( + by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, + &None, ), - _ => (by.clone(), &None), + dt => polars_bail!(opq = expr_name, got = dt, expected = "date/datetime"), }; + ensure_sorted_arg(&by, expr_name)?; let by = by.datetime().unwrap(); let by_values = by.cont_slice().map_err(|_| { polars_err!( @@ -1357,7 +1381,7 @@ impl Expr { ) } - /// Apply a rolling variance + /// Apply a rolling variance. #[cfg(feature = "rolling_window")] pub fn rolling_var(self, options: RollingOptions) -> Expr { self.finish_rolling( @@ -1369,7 +1393,7 @@ impl Expr { ) } - /// Apply a rolling std-dev + /// Apply a rolling std-dev. 
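The `is_first`/`is_last` masks above are renamed to `is_first_distinct`/`is_last_distinct`; a minimal sketch (the corresponding feature flags are assumed to be enabled):

    let first_seen = col("id").is_first_distinct().alias("first_seen");
    let last_seen = col("id").is_last_distinct().alias("last_seen");
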
#[cfg(feature = "rolling_window")] pub fn rolling_std(self, options: RollingOptions) -> Expr { self.finish_rolling( @@ -1381,7 +1405,7 @@ impl Expr { ) } - /// Apply a rolling skew + /// Apply a rolling skew. #[cfg(feature = "rolling_window")] #[cfg(feature = "moment")] pub fn rolling_skew(self, window_size: usize, bias: bool) -> Expr { @@ -1442,7 +1466,18 @@ impl Expr { .with_fmt("rolling_map_float") } + #[cfg(feature = "peaks")] + pub fn peak_min(self) -> Expr { + self.apply_private(FunctionExpr::PeakMin) + } + + #[cfg(feature = "peaks")] + pub fn peak_max(self) -> Expr { + self.apply_private(FunctionExpr::PeakMax) + } + #[cfg(feature = "rank")] + /// Assign ranks to data, dealing with ties appropriately. pub fn rank(self, options: RankOptions, seed: Option) -> Expr { self.apply( move |s| Ok(Some(s.rank(options, seed))), @@ -1455,6 +1490,7 @@ impl Expr { } #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories. pub fn cut( self, breaks: Vec, @@ -1471,6 +1507,7 @@ impl Expr { } #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories based on their quantiles. pub fn qcut( self, probs: Vec, @@ -1489,6 +1526,7 @@ impl Expr { } #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories using uniform quantile probabilities. pub fn qcut_uniform( self, n_bins: usize, @@ -1508,20 +1546,25 @@ impl Expr { } #[cfg(feature = "rle")] + /// Get the lengths of runs of identical values. pub fn rle(self) -> Expr { self.apply_private(FunctionExpr::RLE) } + #[cfg(feature = "rle")] + /// Similar to `rle`, but maps values to run IDs. pub fn rle_id(self) -> Expr { self.apply_private(FunctionExpr::RLEID) } #[cfg(feature = "diff")] + /// Calculate the n-th discrete difference between values. pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr { self.apply_private(FunctionExpr::Diff(n, null_behavior)) } #[cfg(feature = "pct_change")] + /// Computes percentage change between values. pub fn pct_change(self, n: i64) -> Expr { use DataType::*; self.apply( @@ -1561,6 +1604,13 @@ impl Expr { } #[cfg(feature = "moment")] + /// Compute the kurtosis (Fisher or Pearson). + /// + /// Kurtosis is the fourth central moment divided by the square of the + /// variance. If Fisher's definition is used, then 3.0 is subtracted from + /// the result to give 0.0 for a normal distribution. + /// If bias is False then the kurtosis is calculated using k statistics to + /// eliminate bias coming from biased moment estimators. pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { self.apply( move |s| { @@ -1614,6 +1664,7 @@ impl Expr { } #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving average. pub fn ewm_mean(self, options: EWMOptions) -> Self { use DataType::*; self.apply( @@ -1627,6 +1678,7 @@ impl Expr { } #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving standard deviation. pub fn ewm_std(self, options: EWMOptions) -> Self { use DataType::*; self.apply( @@ -1640,6 +1692,7 @@ impl Expr { } #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving variance. pub fn ewm_var(self, options: EWMOptions) -> Self { use DataType::*; self.apply( @@ -1690,54 +1743,38 @@ impl Expr { } #[cfg(feature = "dtype-struct")] - /// Count all unique values and create a struct mapping value to count - /// Note that it is better to turn parallel off in the aggregation context + /// Count all unique values and create a struct mapping value to count. 
+ /// (Note that it is better to turn parallel off in the aggregation context). pub fn value_counts(self, sort: bool, parallel: bool) -> Self { - self.apply( - move |s| { - s.value_counts(sort, parallel) - .map(|df| Some(df.into_struct(s.name()).into_series())) - }, - GetOutput::map_field(|fld| { - Field::new( - fld.name(), - DataType::Struct(vec![fld.clone(), Field::new("counts", IDX_DTYPE)]), - ) - }), - ) - .with_function_options(|mut opts| { - opts.pass_name_to_apply = true; - opts - }) - .with_fmt("value_counts") + self.apply_private(FunctionExpr::ValueCounts { sort, parallel }) + .with_function_options(|mut opts| { + opts.pass_name_to_apply = true; + opts + }) } #[cfg(feature = "unique_counts")] /// Returns a count of the unique values in the order of appearance. /// This method differs from [`Expr::value_counts]` in that it does not return the - /// values, only the counts and might be faster + /// values, only the counts and might be faster. pub fn unique_counts(self) -> Self { - self.apply( - |s| Ok(Some(s.unique_counts().into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("unique_counts") + self.apply_private(FunctionExpr::UniqueCounts) } #[cfg(feature = "log")] - /// Compute the logarithm to a given base + /// Compute the logarithm to a given base. pub fn log(self, base: f64) -> Self { self.map_private(FunctionExpr::Log { base }) } #[cfg(feature = "log")] - /// Compute the natural logarithm of all elements plus one in the input array + /// Compute the natural logarithm of all elements plus one in the input array. pub fn log1p(self) -> Self { self.map_private(FunctionExpr::Log1p) } #[cfg(feature = "log")] - /// Calculate the exponential of all elements in the input array + /// Calculate the exponential of all elements in the input array. pub fn exp(self) -> Self { self.map_private(FunctionExpr::Exp) } @@ -1752,7 +1789,7 @@ impl Expr { options }) } - /// Get the null count of the column/group + /// Get the null count of the column/group. pub fn null_count(self) -> Expr { self.apply_private(FunctionExpr::NullCount) .with_function_options(|mut options| { @@ -1771,7 +1808,7 @@ impl Expr { } #[cfg(feature = "row_hash")] - /// Compute the hash of every element + /// Compute the hash of every element. pub fn hash(self, k0: u64, k1: u64, k2: u64, k3: u64) -> Expr { self.map_private(FunctionExpr::Hash(k0, k1, k2, k3)) } @@ -1798,19 +1835,19 @@ impl Expr { list::ListNameSpace(self) } - /// Get the [`array::ArrayNameSpace`] + /// Get the [`array::ArrayNameSpace`]. #[cfg(feature = "dtype-array")] pub fn arr(self) -> array::ArrayNameSpace { array::ArrayNameSpace(self) } - /// Get the [`CategoricalNameSpace`] + /// Get the [`CategoricalNameSpace`]. #[cfg(feature = "dtype-categorical")] pub fn cat(self) -> cat::CategoricalNameSpace { cat::CategoricalNameSpace(self) } - /// Get the [`struct_::StructNameSpace`] + /// Get the [`struct_::StructNameSpace`]. #[cfg(feature = "dtype-struct")] pub fn struct_(self) -> struct_::StructNameSpace { struct_::StructNameSpace(self) @@ -1914,17 +1951,17 @@ where } } -/// Count expression +/// Count expression. pub fn count() -> Expr { Expr::Count } -/// First column in DataFrame +/// First column in DataFrame. pub fn first() -> Expr { Expr::Nth(0) } -/// Last column in DataFrame +/// Last column in DataFrame. 
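Usage sketch for the `value_counts` and `null_count` expressions shown above (assumes an existing DataFrame `df` inside a function returning `PolarsResult`):

    let out = df
        .lazy()
        .select([
            col("category").value_counts(true, false).alias("counts"),
            col("category").null_count().alias("nulls"),
        ])
        .collect()?;
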
pub fn last() -> Expr { Expr::Nth(-1) } diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index a5f9202151b7..f7c8d355e5a8 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -1,4 +1,6 @@ -use polars_core::prelude::{JoinArgs, JoinType}; +use polars_ops::prelude::{JoinArgs, JoinType}; +#[cfg(feature = "dynamic_group_by")] +use polars_time::RollingGroupOptions; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -60,12 +62,26 @@ impl Default for JoinOptions { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct WindowOptions { +pub enum WindowType { /// Explode the aggregated list and just do a hstack instead of a join /// this requires the groups to be sorted to make any sense - pub mapping: WindowMapping, + Over(WindowMapping), + #[cfg(feature = "dynamic_group_by")] + Rolling(RollingGroupOptions), +} + +impl From for WindowType { + fn from(value: WindowMapping) -> Self { + Self::Over(value) + } +} + +impl Default for WindowType { + fn default() -> Self { + Self::Over(WindowMapping::default()) + } } #[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] diff --git a/crates/polars-plan/src/dsl/random.rs b/crates/polars-plan/src/dsl/random.rs index 8c1f9e1c683b..efd36a15a86c 100644 --- a/crates/polars-plan/src/dsl/random.rs +++ b/crates/polars-plan/src/dsl/random.rs @@ -10,19 +10,23 @@ impl Expr { pub fn sample_n( self, - n: usize, + n: Expr, with_replacement: bool, shuffle: bool, seed: Option, ) -> Self { - self.apply_private(FunctionExpr::Random { - method: RandomMethod::SampleN { - n, - with_replacement, - shuffle, + self.apply_many_private( + FunctionExpr::Random { + method: RandomMethod::SampleN { + with_replacement, + shuffle, + }, + seed, }, - seed, - }) + &[n], + false, + false, + ) } pub fn sample_frac( diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 9b0ba4653fb2..1fdebe23f676 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -1,8 +1,3 @@ -#[cfg(feature = "dtype-struct")] -use polars_arrow::export::arrow::array::{MutableArray, MutableUtf8Array}; -#[cfg(feature = "dtype-struct")] -use polars_utils::format_smartstring; - use super::function_expr::StringFunction; use super::*; /// Specialized expressions for [`Series`] of [`DataType::Utf8`]. @@ -19,6 +14,7 @@ impl StringNameSpace { }), &[pat], true, + true, ) } @@ -33,6 +29,7 @@ impl StringNameSpace { }), &[pat], true, + true, ) } @@ -42,6 +39,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::EndsWith), &[sub], true, + true, ) } @@ -51,6 +49,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StartsWith), &[sub], true, + true, ) } @@ -124,13 +123,17 @@ impl StringNameSpace { /// Extract each successive non-overlapping match in an individual string as an array pub fn extract_all(self, pat: Expr) -> Expr { self.0 - .map_many_private(StringFunction::ExtractAll.into(), &[pat], false) + .map_many_private(StringFunction::ExtractAll.into(), &[pat], false, false) } /// Count all successive non-overlapping regex matches. 
- pub fn count_match(self, pat: &str) -> Expr { - let pat = pat.to_string(); - self.0.map_private(StringFunction::CountMatch(pat).into()) + pub fn count_matches(self, pat: Expr, literal: bool) -> Expr { + self.0.map_many_private( + StringFunction::CountMatches(literal).into(), + &[pat], + true, + false, + ) } /// Convert a Utf8 column into a Date/Datetime/Time column. @@ -139,6 +142,7 @@ impl StringNameSpace { self.0.map_many_private( StringFunction::Strptime(dtype, options).into(), &[ambiguous], + true, false, ) } @@ -199,229 +203,59 @@ impl StringNameSpace { /// * `delimiter` - A string that will act as delimiter between values. #[cfg(feature = "concat_str")] pub fn concat(self, delimiter: &str) -> Expr { - let delimiter = delimiter.to_owned(); - - Expr::Function { - input: vec![self.0], - function: StringFunction::ConcatVertical(delimiter).into(), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - input_wildcard_expansion: false, - auto_explode: true, - ..Default::default() - }, - } + self.0 + .apply_private(StringFunction::ConcatVertical(delimiter.to_owned()).into()) + .with_function_options(|mut options| { + options.auto_explode = true; + options + }) } /// Split the string by a substring. The resulting dtype is `List`. - pub fn split(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split") + .map_many_private(StringFunction::Split(false).into(), &[by], false, false) } /// Split the string by a substring and keep the substring. The resulting dtype is `List`. - pub fn split_inclusive(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split_inclusive(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split_inclusive(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split_inclusive") + .map_many_private(StringFunction::Split(true).into(), &[by], false, false) } #[cfg(feature = "dtype-struct")] /// Split exactly `n` times by a given substring. The resulting dtype is [`DataType::Struct`]. 
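A sketch of the Expr-based string APIs from the hunks above (the `.str()` namespace accessor is assumed; `split` and `count_matches` now take expressions):

    let words = col("sentence").str().split(lit(" ")).alias("words");
    let hits = col("sentence").str().count_matches(lit("polars"), true).alias("hits");
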
- pub fn split_exact(self, by: &str, n: usize) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut arrs = (0..n + 1) - .map(|_| MutableUtf8Array::::with_capacity(ca.len())) - .collect::>(); - - ca.into_iter().for_each(|opt_s| match opt_s { - None => { - for arr in &mut arrs { - arr.push_null() - } - }, - Some(s) => { - let mut arr_iter = arrs.iter_mut(); - let split_iter = s.split(&by); - (split_iter) - .zip(&mut arr_iter) - .for_each(|(splitted, arr)| arr.push(Some(splitted))); - // fill the remaining with null - for arr in arr_iter { - arr.push_null() - } - }, - }); - let fields = arrs - .into_iter() - .enumerate() - .map(|(i, mut arr)| { - Series::try_from((format!("field_{i}").as_str(), arr.as_box())).unwrap() - }) - .collect::>(); - Ok(Some(StructChunked::new(ca.name(), &fields)?.into_series())) - }; - self.0 - .map( - function, - GetOutput::from_type(DataType::Struct( - (0..n + 1) - .map(|i| { - Field::from_owned(format_smartstring!("field_{i}"), DataType::Utf8) - }) - .collect(), - )), - ) - .with_fmt("str.split_exact") + pub fn split_exact(self, by: Expr, n: usize) -> Expr { + self.0.map_many_private( + StringFunction::SplitExact { + n, + inclusive: false, + } + .into(), + &[by], + false, + false, + ) } #[cfg(feature = "dtype-struct")] /// Split exactly `n` times by a given substring and keep the substring. /// The resulting dtype is [`DataType::Struct`]. - pub fn split_exact_inclusive(self, by: &str, n: usize) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut arrs = (0..n + 1) - .map(|_| MutableUtf8Array::::with_capacity(ca.len())) - .collect::>(); - - ca.into_iter().for_each(|opt_s| match opt_s { - None => { - for arr in &mut arrs { - arr.push_null() - } - }, - Some(s) => { - let mut arr_iter = arrs.iter_mut(); - let split_iter = s.split_inclusive(&by); - (split_iter) - .zip(&mut arr_iter) - .for_each(|(splitted, arr)| arr.push(Some(splitted))); - // fill the remaining with null - for arr in arr_iter { - arr.push_null() - } - }, - }); - let fields = arrs - .into_iter() - .enumerate() - .map(|(i, mut arr)| { - Series::try_from((format!("field_{i}").as_str(), arr.as_box())).unwrap() - }) - .collect::>(); - Ok(Some(StructChunked::new(ca.name(), &fields)?.into_series())) - }; - self.0 - .map( - function, - GetOutput::from_type(DataType::Struct( - (0..n + 1) - .map(|i| { - Field::from_owned(format_smartstring!("field_{i}"), DataType::Utf8) - }) - .collect(), - )), - ) - .with_fmt("str.split_exact") + pub fn split_exact_inclusive(self, by: Expr, n: usize) -> Expr { + self.0.map_many_private( + StringFunction::SplitExact { n, inclusive: true }.into(), + &[by], + false, + false, + ) } #[cfg(feature = "dtype-struct")] /// Split by a given substring, returning exactly `n` items. If there are more possible splits, /// keeps the remainder of the string intact. The resulting dtype is [`DataType::Struct`]. 
- pub fn splitn(self, by: &str, n: usize) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut arrs = (0..n) - .map(|_| MutableUtf8Array::::with_capacity(ca.len())) - .collect::>(); - - ca.into_iter().for_each(|opt_s| match opt_s { - None => { - for arr in &mut arrs { - arr.push_null() - } - }, - Some(s) => { - let mut arr_iter = arrs.iter_mut(); - let split_iter = s.splitn(n, &by); - (split_iter) - .zip(&mut arr_iter) - .for_each(|(splitted, arr)| arr.push(Some(splitted))); - // fill the remaining with null - for arr in arr_iter { - arr.push_null() - } - }, - }); - let fields = arrs - .into_iter() - .enumerate() - .map(|(i, mut arr)| { - Series::try_from((format!("field_{i}").as_str(), arr.as_box())).unwrap() - }) - .collect::>(); - Ok(Some(StructChunked::new(ca.name(), &fields)?.into_series())) - }; + pub fn splitn(self, by: Expr, n: usize) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::Struct( - (0..n) - .map(|i| { - Field::from_owned(format_smartstring!("field_{i}"), DataType::Utf8) - }) - .collect(), - )), - ) - .with_fmt("str.splitn") + .map_many_private(StringFunction::SplitN(n).into(), &[by], false, false) } #[cfg(feature = "regex")] @@ -430,6 +264,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::Replace { n: 1, literal }), &[pat, value], + false, true, ) } @@ -440,6 +275,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::Replace { n, literal }), &[pat, value], + false, true, ) } @@ -450,26 +286,59 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::Replace { n: -1, literal }), &[pat, value], + false, true, ) } /// Remove leading and trailing characters, or whitespace if matches is None. - pub fn strip(self, matches: Option) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::Strip(matches))) + pub fn strip_chars(self, matches: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripChars), + &[matches], + false, + false, + ) } /// Remove leading characters, or whitespace if matches is None. - pub fn lstrip(self, matches: Option) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::LStrip(matches))) + pub fn strip_chars_start(self, matches: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripCharsStart), + &[matches], + false, + false, + ) } - /// Remove trailing characters, or whitespace if matches is None.. - pub fn rstrip(self, matches: Option) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::RStrip(matches))) + /// Remove trailing characters, or whitespace if matches is None. + pub fn strip_chars_end(self, matches: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripCharsEnd), + &[matches], + false, + false, + ) + } + + /// Remove prefix. + pub fn strip_prefix(self, prefix: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripPrefix), + &[prefix], + false, + false, + ) + } + + /// Remove suffix. + pub fn strip_suffix(self, suffix: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::StripSuffix), + &[suffix], + false, + false, + ) } /// Convert all characters to lowercase. @@ -500,16 +369,29 @@ impl StringNameSpace { ))) } - /// Return the number of characters in the string (not bytes). 
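A sketch of the renamed strip APIs above, which now take the character set or affix as an expression (`.str()` accessor assumed):

    let trimmed = col("raw").str().strip_chars(lit(" \t")).alias("trimmed");
    let bare_id = col("raw").str().strip_prefix(lit("id_")).alias("bare_id");
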
- pub fn n_chars(self) -> Expr { + /// Return the length of each string as the number of bytes. + /// + /// When working with non-ASCII text, the length in bytes is not the same + /// as the length in characters. You may want to use + /// [`len_chars`] instead. Note that `len_bytes` is much more + /// performant (_O(1)_) than [`len_chars`] (_O(n)_). + /// + /// [`len_chars`]: StringNameSpace::len_chars + pub fn len_bytes(self) -> Expr { self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::NChars)) + .map_private(FunctionExpr::StringExpr(StringFunction::LenBytes)) } - /// Return the number of bytes in the string (not characters). - pub fn lengths(self) -> Expr { + /// Return the length of each string as the number of characters. + /// + /// When working with ASCII text, use [`len_bytes`] instead to achieve + /// equivalent output with much better performance: + /// [`len_bytes`] runs in _O(1)_, while `len_chars` runs in _O(n)_. + /// + /// [`len_bytes`]: StringNameSpace::len_bytes + pub fn len_chars(self) -> Expr { self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::Length)) + .map_private(FunctionExpr::StringExpr(StringFunction::LenChars)) } /// Slice the string values. diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs index ddf01fc8bd4f..331d1d09d3d5 100644 --- a/crates/polars-plan/src/dsl/struct_.rs +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -30,44 +30,9 @@ impl StructNameSpace { /// Rename the fields of the [`StructChunked`]. pub fn rename_fields(self, names: Vec) -> Expr { - let names = Arc::new(names); - let names2 = names.clone(); self.0 - .map( - move |s| { - let ca = s.struct_()?; - let fields = ca - .fields() - .iter() - .zip(names.as_ref()) - .map(|(s, name)| { - let mut s = s.clone(); - s.rename(name); - s - }) - .collect::>(); - StructChunked::new(ca.name(), &fields).map(|ca| Some(ca.into_series())) - }, - GetOutput::map_dtype(move |dt| match dt { - DataType::Struct(fields) => { - let fields = fields - .iter() - .zip(names2.as_ref()) - .map(|(fld, name)| Field::new(name, fld.data_type().clone())) - .collect(); - DataType::Struct(fields) - }, - // The types will be incorrect, but its better than nothing - // we can get an incorrect type with python lambdas, because we only know return type when running - // the query - dt => DataType::Struct( - names2 - .iter() - .map(|name| Field::new(name, dt.clone())) - .collect(), - ), - }), - ) - .with_fmt("struct.rename_fields") + .map_private(FunctionExpr::StructExpr(StructFunction::RenameFields( + Arc::from(names), + ))) } } diff --git a/crates/polars-plan/src/dsl/udf.rs b/crates/polars-plan/src/dsl/udf.rs new file mode 100644 index 000000000000..900cbcc95cb9 --- /dev/null +++ b/crates/polars-plan/src/dsl/udf.rs @@ -0,0 +1,92 @@ +use std::sync::Arc; + +use polars_arrow::error::{polars_bail, PolarsResult}; +use polars_core::prelude::Field; +use polars_core::schema::Schema; + +use super::{Expr, GetOutput, SeriesUdf, SpecialEq}; +use crate::prelude::{Context, FunctionOptions}; + +/// Represents a user-defined function +#[derive(Clone)] +pub struct UserDefinedFunction { + /// name + pub name: String, + /// The function signature. + pub input_fields: Vec, + /// The function output type. + pub return_type: GetOutput, + /// The function implementation. + pub fun: SpecialEq>, + /// Options for the function. 
+ pub options: FunctionOptions, +} + +impl std::fmt::Debug for UserDefinedFunction { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("UserDefinedFunction") + .field("name", &self.name) + .field("signature", &self.input_fields) + .field("fun", &"") + .field("options", &self.options) + .finish() + } +} + +impl UserDefinedFunction { + /// Create a new UserDefinedFunction + pub fn new( + name: &str, + input_fields: Vec, + return_type: GetOutput, + fun: impl SeriesUdf + 'static, + ) -> Self { + Self { + name: name.to_owned(), + input_fields, + return_type, + fun: SpecialEq::new(Arc::new(fun)), + options: FunctionOptions::default(), + } + } + + /// creates a logical expression with a call of the UDF + /// This utility allows using the UDF without requiring access to the registry. + /// The schema is validated and the query will fail if the schema is invalid. + pub fn call(self, args: Vec) -> PolarsResult { + if args.len() != self.input_fields.len() { + polars_bail!(InvalidOperation: "expected {} arguments, got {}", self.input_fields.len(), args.len()) + } + let schema = Schema::from_iter(self.input_fields); + + if args + .iter() + .map(|e| e.to_field(&schema, Context::Default)) + .collect::>>() + .is_err() + { + polars_bail!(InvalidOperation: "unexpected field in UDF \nexpected: {:?}\n received {:?}", schema, args) + }; + + Ok(Expr::AnonymousFunction { + input: args, + function: self.fun, + output_type: self.return_type, + options: self.options, + }) + } + + /// creates a logical expression with a call of the UDF + /// This does not do any schema validation and is therefore faster. + /// + /// Only use this if you are certain that the schema is correct. + /// If the schema is invalid, the query will fail at runtime. + pub fn call_unchecked(self, args: Vec) -> Expr { + Expr::AnonymousFunction { + input: args, + function: self.fun, + output_type: self.return_type.clone(), + options: self.options, + } + } +} diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index 99184c32f1a9..1415ffd66aca 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -12,6 +12,8 @@ pub struct OptState { #[cfg(feature = "cse")] pub comm_subexpr_elim: bool, pub streaming: bool, + pub eager: bool, + pub fast_projection: bool, } impl Default for OptState { @@ -29,6 +31,8 @@ impl Default for OptState { #[cfg(feature = "cse")] comm_subexpr_elim: true, streaming: false, + fast_projection: true, + eager: false, } } } diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 5b68ac6394ee..b4e62e3004aa 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -15,7 +15,6 @@ use crate::dsl::function_expr::FunctionExpr; #[cfg(feature = "cse")] use crate::logical_plan::visitor::AexprNode; use crate::logical_plan::Context; -use crate::prelude::aexpr::NodeInputs::Single; use crate::prelude::names::COUNT; use crate::prelude::*; @@ -179,8 +178,7 @@ pub enum AExpr { Window { function: Node, partition_by: Vec, - order_by: Option, - options: WindowOptions, + options: WindowType, }, #[default] Wildcard, @@ -286,9 +284,10 @@ impl AExpr { // latest, so that it is popped first container.push(*input); }, - Agg(agg_e) => { - let node = agg_e.get_input().first(); - container.push(node); + Agg(agg_e) => match agg_e.get_input() { + NodeInputs::Single(node) => container.push(node), + 
NodeInputs::Many(nodes) => container.extend_from_slice(&nodes), + NodeInputs::Leaf => {}, }, Ternary { truthy, @@ -314,15 +313,11 @@ impl AExpr { Window { function, partition_by, - order_by, options: _, } => { for e in partition_by.iter().rev() { container.push(*e); } - if let Some(e) = order_by { - container.push(*e); - } // latest so that it is popped first container.push(*function); }, @@ -345,7 +340,7 @@ impl AExpr { Column(_) | Literal(_) | Wildcard | Count | Nth(_) => return self, Alias(input, _) => input, Cast { expr, .. } => expr, - Explode(input) | Slice { input, .. } => input, + Explode(input) => input, BinaryExpr { left, right, .. } => { *right = inputs[0]; *left = inputs[1]; @@ -369,7 +364,15 @@ impl AExpr { return self; }, Agg(a) => { - a.set_input(inputs[0]); + match a { + AAggExpr::Quantile { expr, quantile, .. } => { + *expr = inputs[0]; + *quantile = inputs[1]; + }, + _ => { + a.set_input(inputs[0]); + }, + } return self; }, Ternary { @@ -387,17 +390,25 @@ impl AExpr { input.extend(inputs.iter().rev().copied()); return self; }, + Slice { + input, + offset, + length, + } => { + *length = inputs[0]; + *offset = inputs[1]; + *input = inputs[2]; + return self; + }, Window { function, partition_by, - order_by, .. } => { *function = *inputs.last().unwrap(); partition_by.clear(); partition_by.extend_from_slice(&inputs[..inputs.len() - 1]); - assert!(order_by.is_none()); return self; }, }; @@ -416,6 +427,7 @@ impl AExpr { impl AAggExpr { pub fn get_input(&self) -> NodeInputs { use AAggExpr::*; + use NodeInputs::*; match self { Min { input, .. } => Single(*input), Max { input, .. } => Single(*input), @@ -425,7 +437,7 @@ impl AAggExpr { Last(input) => Single(*input), Mean(input) => Single(*input), Implode(input) => Single(*input), - Quantile { expr, .. } => Single(*expr), + Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]), Sum(input) => Single(*input), Count(input) => Single(*input), Std(input, _) => Single(*input), @@ -464,7 +476,7 @@ pub enum NodeInputs { impl NodeInputs { pub fn first(&self) -> Node { match self { - Single(node) => *node, + NodeInputs::Single(node) => *node, NodeInputs::Many(nodes) => nodes[0], NodeInputs::Leaf => panic!(), } diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 782041b6c74e..81eb3502c861 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -185,6 +185,7 @@ impl AExpr { output_type, input, function, + options, .. } => { let tmp = function.get_output(); @@ -194,6 +195,7 @@ impl AExpr { // default context because `col()` would return a list in aggregation context .map(|node| arena.get(*node).to_field(schema, Context::Default, arena)) .collect::>>()?; + polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); Ok(output_type.get_field(schema, ctxt, &fields)) }, Function { @@ -204,6 +206,7 @@ impl AExpr { // default context because `col()` would return a list in aggregation context .map(|node| arena.get(*node).to_field(schema, Context::Default, arena)) .collect::>>()?; + polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); function.get_field(schema, ctxt, &fields) }, Slice { input, .. 
} => arena.get(*input).to_field(schema, ctxt, arena), diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp.rs index 22a64b03f1ff..d6a96e2394a1 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp.rs @@ -12,16 +12,9 @@ use crate::logical_plan::FileScan; use crate::prelude::*; use crate::utils::PushNode; -/// ALogicalPlan is a representation of LogicalPlan with Nodes which are allocated in an Arena +/// [`ALogicalPlan`] is a representation of [`LogicalPlan`] with [`Node`]s which are allocated in an [`Arena`] #[derive(Clone, Debug)] pub enum ALogicalPlan { - AnonymousScan { - function: Arc, - file_info: FileInfo, - output_schema: Option, - predicate: Option, - options: Arc, - }, #[cfg(feature = "python")] PythonScan { options: PythonOptions, @@ -60,11 +53,6 @@ pub enum ALogicalPlan { schema: SchemaRef, options: ProjectionOptions, }, - LocalProjection { - expr: Vec, - input: Node, - schema: SchemaRef, - }, Sort { input: Node, by_column: Vec, @@ -115,9 +103,9 @@ pub enum ALogicalPlan { contexts: Vec, schema: SchemaRef, }, - FileSink { + Sink { input: Node, - payload: FileSinkOptions, + payload: SinkType, }, } @@ -141,7 +129,6 @@ impl ALogicalPlan { Scan { file_info, .. } => &file_info.schema, #[cfg(feature = "python")] PythonScan { options, .. } => &options.schema, - AnonymousScan { file_info, .. } => &file_info.schema, _ => unreachable!(), } } @@ -150,14 +137,12 @@ impl ALogicalPlan { use ALogicalPlan::*; match self { Scan { scan_type, .. } => scan_type.into(), - AnonymousScan { .. } => "anonymous_scan", #[cfg(feature = "python")] PythonScan { .. } => "python_scan", Slice { .. } => "slice", Selection { .. } => "selection", DataFrameScan { .. } => "df", Projection { .. } => "projection", - LocalProjection { .. } => "local_projection", Sort { .. } => "sort", Cache { .. } => "cache", Aggregate { .. } => "aggregate", @@ -167,7 +152,12 @@ impl ALogicalPlan { MapFunction { .. } => "map_function", Union { .. } => "union", ExtContext { .. } => "ext_context", - FileSink { .. } => "file_sink", + Sink { payload, .. } => match payload { + SinkType::Memory => "sink (memory)", + SinkType::File { .. } => "sink (file)", + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => "sink (cloud)", + }, } } @@ -176,7 +166,7 @@ impl ALogicalPlan { use ALogicalPlan::*; let schema = match self { #[cfg(feature = "python")] - PythonScan { options, .. } => &options.schema, + PythonScan { options, .. } => options.output_schema.as_ref().unwrap_or(&options.schema), Union { inputs, .. } => return arena.get(inputs[0]).schema(arena), Cache { input, .. } => return arena.get(*input).schema(arena), Sort { input, .. } => return arena.get(*input).schema(arena), @@ -190,20 +180,12 @@ impl ALogicalPlan { output_schema, .. } => output_schema.as_ref().unwrap_or(schema), - AnonymousScan { - file_info, - output_schema, - .. - } => output_schema.as_ref().unwrap_or(&file_info.schema), Selection { input, .. } => return arena.get(*input).schema(arena), Projection { schema, .. } => schema, - LocalProjection { schema, .. } => schema, Aggregate { schema, .. } => schema, Join { schema, .. } => schema, HStack { schema, .. } => schema, - Distinct { input, .. } | FileSink { input, .. } => { - return arena.get(*input).schema(arena) - }, + Distinct { input, .. } | Sink { input, .. } => return arena.get(*input).schema(arena), Slice { input, .. 
} => return arena.get(*input).schema(arena), MapFunction { input, function } => { let input_schema = arena.get(*input).schema(arena); @@ -249,11 +231,6 @@ impl ALogicalPlan { input: inputs[0], predicate: exprs[0], }, - LocalProjection { schema, .. } => LocalProjection { - input: inputs[0], - expr: exprs, - schema: schema.clone(), - }, Projection { schema, options, .. } => Projection { @@ -356,26 +333,6 @@ impl ALogicalPlan { selection: new_selection, } }, - AnonymousScan { - function, - file_info, - output_schema, - predicate, - options, - } => { - let mut new_predicate = None; - if predicate.is_some() { - new_predicate = exprs.pop() - } - - AnonymousScan { - function: function.clone(), - file_info: file_info.clone(), - output_schema: output_schema.clone(), - predicate: new_predicate, - options: options.clone(), - } - }, MapFunction { function, .. } => MapFunction { input: inputs[0], function: function.clone(), @@ -385,7 +342,7 @@ impl ALogicalPlan { contexts: inputs, schema: schema.clone(), }, - FileSink { payload, .. } => FileSink { + Sink { payload, .. } => Sink { input: inputs.pop().unwrap(), payload: payload.clone(), }, @@ -400,7 +357,6 @@ impl ALogicalPlan { Sort { by_column, .. } => container.extend_from_slice(by_column), Selection { predicate, .. } => container.push(*predicate), Projection { expr, .. } => container.extend_from_slice(expr), - LocalProjection { expr, .. } => container.extend_from_slice(expr), Aggregate { keys, aggs, .. } => { let iter = keys.iter().copied().chain(aggs.iter().copied()); container.extend(iter) @@ -424,12 +380,7 @@ impl ALogicalPlan { }, #[cfg(feature = "python")] PythonScan { .. } => {}, - AnonymousScan { predicate, .. } => { - if let Some(node) = predicate { - container.push(*node) - } - }, - ExtContext { .. } | FileSink { .. } => {}, + ExtContext { .. } | Sink { .. } => {}, } } @@ -458,7 +409,6 @@ impl ALogicalPlan { Slice { input, .. } => *input, Selection { input, .. } => *input, Projection { input, .. } => *input, - LocalProjection { input, .. } => *input, Sort { input, .. } => *input, Cache { input, .. } => *input, Aggregate { input, .. } => *input, @@ -474,7 +424,7 @@ impl ALogicalPlan { HStack { input, .. } => *input, Distinct { input, .. } => *input, MapFunction { input, .. } => *input, - FileSink { input, .. } => *input, + Sink { input, .. } => *input, ExtContext { input, contexts, .. } => { @@ -485,7 +435,6 @@ impl ALogicalPlan { }, Scan { .. } => return, DataFrameScan { .. } => return, - AnonymousScan { .. } => return, #[cfg(feature = "python")] PythonScan { .. } => return, }; diff --git a/crates/polars-plan/src/logical_plan/anonymous_scan.rs b/crates/polars-plan/src/logical_plan/anonymous_scan.rs index 1f6f7d665ea7..8f94b7c27820 100644 --- a/crates/polars-plan/src/logical_plan/anonymous_scan.rs +++ b/crates/polars-plan/src/logical_plan/anonymous_scan.rs @@ -4,15 +4,24 @@ use std::fmt::{Debug, Formatter}; use polars_core::prelude::*; pub use super::options::AnonymousScanOptions; +use crate::dsl::Expr; + +pub struct AnonymousScanArgs { + pub n_rows: Option, + pub with_columns: Option>>, + pub schema: SchemaRef, + pub output_schema: Option, + pub predicate: Option, +} pub trait AnonymousScan: Send + Sync { fn as_any(&self) -> &dyn Any; /// Creates a dataframe from the supplied function & scan options. - fn scan(&self, scan_opts: AnonymousScanOptions) -> PolarsResult; + fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult; /// function to supply the schema. 
    /// Allows for an optional infer schema argument for data sources with dynamic schemas
-    fn schema(&self, _infer_schema_length: Option<usize>) -> PolarsResult<Schema> {
+    fn schema(&self, _infer_schema_length: Option<usize>) -> PolarsResult<SchemaRef> {
        polars_bail!(ComputeError: "must supply either a schema or a schema function");
    }
    /// specify if the scan provider should allow predicate pushdowns
@@ -37,13 +46,13 @@ pub trait AnonymousScan: Send + Sync {
 impl<F> AnonymousScan for F
 where
-    F: Fn(AnonymousScanOptions) -> PolarsResult<DataFrame> + Send + Sync,
+    F: Fn(AnonymousScanArgs) -> PolarsResult<DataFrame> + Send + Sync,
 {
     fn as_any(&self) -> &dyn Any {
         unimplemented!()
     }

-    fn scan(&self, scan_opts: AnonymousScanOptions) -> PolarsResult<DataFrame> {
+    fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult<DataFrame> {
         self(scan_opts)
     }
 }
diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs
index 276f2aeb76d9..ebe5fdff2f02 100644
--- a/crates/polars-plan/src/logical_plan/builder.rs
+++ b/crates/polars-plan/src/logical_plan/builder.rs
@@ -1,16 +1,18 @@
 #[cfg(feature = "csv")]
 use std::io::{Read, Seek};

-#[cfg(feature = "parquet")]
-use polars_core::cloud::CloudOptions;
 use polars_core::frame::explode::MeltArgs;
 use polars_core::prelude::*;
+#[cfg(feature = "parquet")]
+use polars_io::cloud::CloudOptions;
 #[cfg(feature = "ipc")]
 use polars_io::ipc::IpcReader;
 #[cfg(all(feature = "parquet", feature = "async"))]
 use polars_io::parquet::ParquetAsyncReader;
 #[cfg(feature = "parquet")]
 use polars_io::parquet::ParquetReader;
+#[cfg(all(feature = "cloud", feature = "parquet"))]
+use polars_io::pl_async::get_runtime;
 #[cfg(any(
     feature = "parquet",
     feature = "parquet_async",
@@ -20,9 +22,10 @@ use polars_io::parquet::ParquetReader;
 use polars_io::RowCount;
 #[cfg(feature = "csv")]
 use polars_io::{
-    csv::utils::{get_reader_bytes, infer_file_schema, is_compressed},
+    csv::utils::{infer_file_schema, is_compressed},
     csv::CsvEncoding,
     csv::NullValues,
+    utils::get_reader_bytes,
 };

 use super::builder_functions::*;
@@ -81,37 +84,51 @@ macro_rules!
try_delayed { }; } +#[cfg(any(feature = "parquet", feature = "parquet_async",))] +fn prepare_schema(mut schema: Schema, row_count: Option<&RowCount>) -> SchemaRef { + if let Some(rc) = row_count { + let _ = schema.insert_at_index(0, rc.name.as_str().into(), IDX_DTYPE); + } + Arc::new(schema) +} + impl LogicalPlanBuilder { pub fn anonymous_scan( function: Arc, - schema: Option, + schema: Option, infer_schema_length: Option, skip_rows: Option, n_rows: Option, name: &'static str, ) -> PolarsResult { - let schema = Arc::new(match schema { + let schema = match schema { Some(s) => s, None => function.schema(infer_schema_length)?, - }); + }; - let file_info = FileInfo { - schema: schema.clone(), - row_estimation: (n_rows, n_rows.unwrap_or(usize::MAX)), + let file_info = FileInfo::new(schema.clone(), (n_rows, n_rows.unwrap_or(usize::MAX))); + let file_options = FileScanOptions { + n_rows, + with_columns: None, + cache: false, + row_count: None, + rechunk: false, + file_counter: Default::default(), + hive_partitioning: false, }; - Ok(LogicalPlan::AnonymousScan { - function, + + Ok(LogicalPlan::Scan { + path: "".into(), file_info, predicate: None, - options: Arc::new(AnonymousScanOptions { - fmt_str: name, - schema, - skip_rows, - n_rows, - output_schema: None, - with_columns: None, - predicate: None, - }), + file_options, + scan_type: FileScan::Anonymous { + function, + options: Arc::new(AnonymousScanOptions { + fmt_str: name, + skip_rows, + }), + }, } .into()) } @@ -128,35 +145,50 @@ impl LogicalPlanBuilder { low_memory: bool, cloud_options: Option, use_statistics: bool, + hive_partitioning: bool, + // used to prevent multiple cloud calls + known_schema: Option, ) -> PolarsResult { use polars_io::{is_cloud_url, SerReader as _}; let path = path.into(); - let (mut schema, num_rows) = if is_cloud_url(&path) { - #[cfg(not(feature = "async"))] + let (schema, num_rows, metadata) = if is_cloud_url(&path) { + #[cfg(not(feature = "cloud"))] panic!( "One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled." ); - #[cfg(feature = "async")] - { + #[cfg(feature = "cloud")] + if let Some(known_schema) = known_schema { + (known_schema, None, None) + } else { let uri = path.to_string_lossy(); - ParquetAsyncReader::file_info(&uri, cloud_options.as_ref())? + get_runtime().block_on(async { + let mut reader = + ParquetAsyncReader::from_uri(&uri, cloud_options.as_ref(), None, None) + .await?; + let schema = Arc::new(reader.schema().await?); + let num_rows = reader.num_rows().await?; + let metadata = reader.get_metadata().await?.clone(); + + PolarsResult::Ok((schema, Some(num_rows), Some(metadata))) + })? } } else { let file = polars_utils::open_file(&path)?; let mut reader = ParquetReader::new(file); - (reader.schema()?, reader.num_rows()?) 
+ ( + prepare_schema(reader.schema()?, row_count.as_ref()), + Some(reader.num_rows()?), + Some(reader.get_metadata()?.clone()), + ) }; - if let Some(rc) = &row_count { - let _ = schema.insert_at_index(0, rc.name.as_str().into(), IDX_DTYPE); - } + let mut file_info = FileInfo::new(schema, (num_rows, num_rows.unwrap_or(0))); - let file_info = FileInfo { - schema: Arc::new(schema), - row_estimation: (Some(num_rows), num_rows), - }; + if hive_partitioning { + file_info.set_hive_partitions(path.as_path()); + } let options = FileScanOptions { with_columns: None, @@ -165,6 +197,7 @@ impl LogicalPlanBuilder { rechunk, row_count, file_counter: Default::default(), + hive_partitioning, }; Ok(LogicalPlan::Scan { path, @@ -178,6 +211,7 @@ impl LogicalPlanBuilder { use_statistics, }, cloud_options, + metadata, }, } .into()) @@ -205,10 +239,7 @@ impl LogicalPlanBuilder { let schema = Arc::new(schema); let num_rows = reader._num_rows()?; - let file_info = FileInfo { - schema, - row_estimation: (None, num_rows), - }; + let file_info = FileInfo::new(schema, (None, num_rows)); let file_options = FileScanOptions { with_columns: None, @@ -217,6 +248,8 @@ impl LogicalPlanBuilder { rechunk, row_count, file_counter: Default::default(), + // TODO! add + hive_partitioning: false, }; Ok(LogicalPlan::Scan { path, @@ -232,7 +265,7 @@ impl LogicalPlanBuilder { #[cfg(feature = "csv")] pub fn scan_csv>( path: P, - delimiter: u8, + separator: u8, has_header: bool, ignore_errors: bool, mut skip_rows: usize, @@ -281,7 +314,7 @@ impl LogicalPlanBuilder { // this needs a way to estimated bytes/rows. let (mut inferred_schema, rows_read, bytes_read) = infer_file_schema( &reader_bytes, - delimiter, + separator, infer_schema_length, has_header, schema_overwrite, @@ -315,10 +348,7 @@ impl LogicalPlanBuilder { let estimated_n_rows = (rows_read as f64 / bytes_read as f64 * n_bytes as f64) as usize; skip_rows += skip_rows_after_header; - let file_info = FileInfo { - schema, - row_estimation: (None, estimated_n_rows), - }; + let file_info = FileInfo::new(schema, (None, estimated_n_rows)); let options = FileScanOptions { with_columns: None, @@ -327,6 +357,8 @@ impl LogicalPlanBuilder { rechunk, row_count, file_counter: Default::default(), + // TODO! 
add + hive_partitioning: false, }; Ok(LogicalPlan::Scan { path, @@ -336,7 +368,7 @@ impl LogicalPlanBuilder { scan_type: FileScan::Csv { options: CsvParserOptions { has_header, - delimiter, + separator, ignore_errors, skip_rows, low_memory, @@ -365,6 +397,44 @@ impl LogicalPlanBuilder { .into() } + pub fn drop_columns(self, to_drop: PlHashSet) -> Self { + let schema = try_delayed!(self.0.schema(), &self.0, into); + + let mut output_schema = Schema::with_capacity(schema.len() - to_drop.len()); + let columns = schema + .iter() + .filter_map(|(col_name, dtype)| { + if to_drop.contains(col_name.as_str()) { + None + } else { + let out = Some(col(col_name)); + output_schema.with_column(col_name.clone(), dtype.clone()); + out + } + }) + .collect::>(); + + if columns.is_empty() { + self.map( + |_| Ok(DataFrame::new_no_checks(vec![])), + AllowedOptimizations::default(), + Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))), + "EMPTY PROJECTION", + ) + } else { + LogicalPlan::Projection { + expr: columns, + input: Box::new(self.0), + schema: Arc::new(output_schema), + options: ProjectionOptions { + run_parallel: false, + duplicate_check: false, + }, + } + .into() + } + } + pub fn project(self, exprs: Vec, options: ProjectionOptions) -> Self { let schema = try_delayed!(self.0.schema(), &self.0, into); let (exprs, schema) = try_delayed!(prepare_projection(exprs, &schema), &self.0, into); @@ -387,24 +457,13 @@ impl LogicalPlanBuilder { } } - pub fn project_local(self, exprs: Vec) -> Self { - let schema = try_delayed!(self.0.schema(), &self.0, into); - let (exprs, schema) = try_delayed!(prepare_projection(exprs, &schema), &self.0, into); - LogicalPlan::LocalProjection { - expr: exprs, - input: Box::new(self.0), - schema: Arc::new(schema), - } - .into() - } - pub fn fill_null(self, fill_value: Expr) -> Self { let schema = try_delayed!(self.0.schema(), &self.0, into); let exprs = schema .iter_names() .map(|name| col(name).fill_null(fill_value.clone())) .collect(); - self.project_local(exprs) + self.project(exprs, Default::default()) } pub fn fill_nan(self, fill_value: Expr) -> Self { @@ -423,6 +482,7 @@ impl LogicalPlanBuilder { exprs, ProjectionOptions { run_parallel: false, + duplicate_check: false, }, ) } @@ -495,9 +555,11 @@ impl LogicalPlanBuilder { pub fn filter(self, predicate: Expr) -> Self { let predicate = if has_expr(&predicate, |e| match e { Expr::Column(name) => is_regex_projection(name), - Expr::Wildcard | Expr::RenameAlias { .. } | Expr::Columns(_) | Expr::DtypeColumn(_) => { - true - }, + Expr::Wildcard + | Expr::RenameAlias { .. 
} + | Expr::Columns(_) + | Expr::DtypeColumn(_) + | Expr::Nth(_) => true, _ => false, }) { let schema = try_delayed!(self.0.schema(), &self.0, into); diff --git a/crates/polars-plan/src/logical_plan/builder_alp.rs b/crates/polars-plan/src/logical_plan/builder_alp.rs index b5ce0d863b73..0e65060d41cd 100644 --- a/crates/polars-plan/src/logical_plan/builder_alp.rs +++ b/crates/polars-plan/src/logical_plan/builder_alp.rs @@ -41,17 +41,6 @@ impl<'a> ALogicalPlanBuilder<'a> { ALogicalPlanBuilder::new(node, self.expr_arena, self.lp_arena) } - pub fn project_local(self, exprs: Vec) -> Self { - let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena); - let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena); - let lp = ALogicalPlan::LocalProjection { - expr: exprs, - input: self.root, - schema: Arc::new(schema), - }; - self.add_alp(lp) - } - pub fn project(self, exprs: Vec, options: ProjectionOptions) -> Self { let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena); let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena); diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 3a9ef2509850..f1910f2be2a9 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -128,12 +128,10 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { Expr::Window { function, partition_by, - order_by, options, } => AExpr::Window { function: to_aexpr(*function, arena), partition_by: to_aexprs(partition_by, arena), - order_by: order_by.map(|ob| to_aexpr(*ob, arena)), options, }, Expr::Slice { @@ -148,6 +146,7 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { Expr::Wildcard => AExpr::Wildcard, Expr::Count => AExpr::Count, Expr::Nth(i) => AExpr::Nth(i), + Expr::SubPlan { .. } => panic!("no SQLSubquery expected at this point"), Expr::KeepName(_) => panic!("no keep_name expected at this point"), Expr::Exclude(_, _) => panic!("no exclude expected at this point"), Expr::RenameAlias { .. 
} => panic!("no `rename_alias` expected at this point"), @@ -181,18 +180,6 @@ pub fn to_alp( scan_type, file_options: options, }, - LogicalPlan::AnonymousScan { - function, - file_info, - predicate, - options, - } => ALogicalPlan::AnonymousScan { - function, - file_info, - output_schema: None, - predicate: predicate.map(|expr| to_aexpr(expr, expr_arena)), - options, - }, #[cfg(feature = "python")] LogicalPlan::PythonScan { options } => ALogicalPlan::PythonScan { options, @@ -245,19 +232,6 @@ pub fn to_alp( options, } }, - LogicalPlan::LocalProjection { - expr, - input, - schema, - } => { - let exp = expr.into_iter().map(|x| to_aexpr(x, expr_arena)).collect(); - let i = to_alp(*input, expr_arena, lp_arena)?; - ALogicalPlan::LocalProjection { - expr: exp, - input: i, - schema, - } - }, LogicalPlan::Sort { input, by_column, @@ -377,9 +351,9 @@ pub fn to_alp( schema, } }, - LogicalPlan::FileSink { input, payload } => { + LogicalPlan::Sink { input, payload } => { let input = to_alp(*input, expr_arena, lp_arena)?; - ALogicalPlan::FileSink { input, payload } + ALogicalPlan::Sink { input, payload } }, }; Ok(lp_arena.add(v)) @@ -578,16 +552,13 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AExpr::Window { function, partition_by, - order_by, options, } => { let function = Box::new(node_to_expr(function, expr_arena)); let partition_by = nodes_to_exprs(&partition_by, expr_arena); - let order_by = order_by.map(|ob| Box::new(node_to_expr(ob, expr_arena))); Expr::Window { function, partition_by, - order_by, options, } }, @@ -639,18 +610,6 @@ impl ALogicalPlan { scan_type, file_options: options, }, - ALogicalPlan::AnonymousScan { - function, - file_info, - output_schema: _, - predicate, - options, - } => LogicalPlan::AnonymousScan { - function, - file_info, - predicate: predicate.map(|n| node_to_expr(n, expr_arena)), - options, - }, #[cfg(feature = "python")] ALogicalPlan::PythonScan { options, .. 
} => LogicalPlan::PythonScan { options }, ALogicalPlan::Union { inputs, options } => { @@ -704,19 +663,6 @@ impl ALogicalPlan { options, } }, - ALogicalPlan::LocalProjection { - expr, - input, - schema, - } => { - let i = convert_to_lp(input, lp_arena); - - LogicalPlan::LocalProjection { - expr: nodes_to_exprs(&expr, expr_arena), - input: Box::new(i), - schema, - } - }, ALogicalPlan::Sort { input, by_column, @@ -816,9 +762,9 @@ impl ALogicalPlan { schema, } }, - ALogicalPlan::FileSink { input, payload } => { + ALogicalPlan::Sink { input, payload } => { let input = Box::new(convert_to_lp(input, lp_arena)); - LogicalPlan::FileSink { input, payload } + LogicalPlan::Sink { input, payload } }, } } diff --git a/crates/polars-plan/src/logical_plan/file_scan.rs b/crates/polars-plan/src/logical_plan/file_scan.rs index f3d52c115f93..15b0ba13c0bc 100644 --- a/crates/polars-plan/src/logical_plan/file_scan.rs +++ b/crates/polars-plan/src/logical_plan/file_scan.rs @@ -1,6 +1,9 @@ +#[cfg(feature = "parquet")] +use arrow::io::parquet::write::FileMetaData; + use super::*; -#[derive(Clone, Debug, IntoStaticStr, PartialEq)] +#[derive(Clone, Debug, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum FileScan { #[cfg(feature = "csv")] @@ -9,9 +12,41 @@ pub enum FileScan { Parquet { options: ParquetOptions, cloud_options: Option, + #[cfg_attr(feature = "serde", serde(skip))] + metadata: Option>, }, #[cfg(feature = "ipc")] Ipc { options: IpcScanOptions }, + #[cfg_attr(feature = "serde", serde(skip))] + Anonymous { + options: Arc, + function: Arc, + }, +} + +impl PartialEq for FileScan { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + #[cfg(feature = "csv")] + (FileScan::Csv { options: l }, FileScan::Csv { options: r }) => l == r, + #[cfg(feature = "parquet")] + ( + FileScan::Parquet { + options: opt_l, + cloud_options: c_l, + .. + }, + FileScan::Parquet { + options: opt_r, + cloud_options: c_r, + .. + }, + ) => opt_l == opt_r && c_l == c_r, + #[cfg(feature = "ipc")] + (FileScan::Ipc { options: l }, FileScan::Ipc { options: r }) => l == r, + _ => false, + } + } } impl FileScan { diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 7cd1e5c146ed..ae7e4e48efd6 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -66,28 +66,6 @@ impl LogicalPlan { options.n_rows, ) }, - AnonymousScan { - file_info, - predicate, - options, - .. - } => { - let n_columns = options - .with_columns - .as_ref() - .map(|columns| columns.len() as i64) - .unwrap_or(-1); - write_scan( - f, - options.fmt_str, - Path::new(""), - sub_indent, - n_columns, - file_info.schema.len(), - predicate, - options.n_rows, - ) - }, Union { inputs, options } => { let mut name = String::new(); let name = if let Some(slice) = options.slice { @@ -170,10 +148,6 @@ impl LogicalPlan { write!(f, "{:indent$} SELECT {expr:?} FROM", "")?; input._format(f, sub_indent) }, - LocalProjection { expr, input, .. } => { - write!(f, "{:indent$} LOCAL SELECT {expr:?} FROM", "")?; - input._format(f, sub_indent) - }, Sort { input, by_column, .. } => { @@ -228,8 +202,14 @@ impl LogicalPlan { write!(f, "{:indent$}EXTERNAL_CONTEXT", "")?; input._format(f, sub_indent) }, - FileSink { input, .. } => { - write!(f, "{:indent$}FILE_SINK", "")?; + Sink { input, payload, .. } => { + let name = match payload { + SinkType::Memory => "SINK (memory)", + SinkType::File { .. 
} => "SINK (file)", + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => "SINK (cloud)", + }; + write!(f, "{:indent$}{}", "", name)?; input._format(f, sub_indent) }, } @@ -291,6 +271,9 @@ impl Debug for Expr { Take { expr, idx } => { write!(f, "{expr:?}.take({idx:?})") }, + SubPlan(lf, _) => { + write!(f, ".subplan({lf:?})") + }, Agg(agg) => { use AggExpr::*; match agg { diff --git a/crates/polars-plan/src/logical_plan/functions/drop.rs b/crates/polars-plan/src/logical_plan/functions/drop.rs deleted file mode 100644 index 242289329ead..000000000000 --- a/crates/polars-plan/src/logical_plan/functions/drop.rs +++ /dev/null @@ -1,33 +0,0 @@ -use super::*; - -pub(super) fn drop_impl(mut df: DataFrame, names: &[SmartString]) -> PolarsResult { - for name in names { - // ignore names that are not in there - // they might already be removed by projection pushdown - if let Some(idx) = df.find_idx_by_name(name) { - let _ = unsafe { df.get_columns_mut().remove(idx) }; - } - } - - Ok(df) -} - -pub(super) fn drop_schema<'a>( - input_schema: &'a SchemaRef, - names: &[SmartString], -) -> PolarsResult> { - let to_drop = PlHashSet::from_iter(names); - - let new_schema = input_schema - .iter() - .flat_map(|(name, dtype)| { - if to_drop.contains(name) { - None - } else { - Some(Field::new(name, dtype.clone())) - } - }) - .collect::(); - - Ok(Cow::Owned(Arc::new(new_schema))) -} diff --git a/crates/polars-plan/src/logical_plan/functions/mod.rs b/crates/polars-plan/src/logical_plan/functions/mod.rs index 2be3ff57a1e0..72efcf252c39 100644 --- a/crates/polars-plan/src/logical_plan/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/functions/mod.rs @@ -1,4 +1,3 @@ -mod drop; #[cfg(feature = "merge_sorted")] mod merge_sorted; #[cfg(feature = "python")] @@ -11,7 +10,7 @@ use std::sync::Arc; use polars_core::prelude::*; #[cfg(feature = "dtype-categorical")] -use polars_core::IUseStringCache; +use polars_core::StringCacheHolder; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; @@ -59,7 +58,8 @@ pub enum FunctionNode { columns: Arc<[Arc]>, }, FastProjection { - columns: Arc<[Arc]>, + columns: Arc<[SmartString]>, + duplicate_check: bool, }, DropNulls { subset: Arc<[Arc]>, @@ -80,9 +80,6 @@ pub enum FunctionNode { // A column name gets swapped with an existing column swapping: bool, }, - Drop { - names: Arc<[SmartString]>, - }, Explode { columns: Arc<[Arc]>, schema: SchemaRef, @@ -102,7 +99,16 @@ impl PartialEq for FunctionNode { fn eq(&self, other: &Self) -> bool { use FunctionNode::*; match (self, other) { - (FastProjection { columns: l }, FastProjection { columns: r }) => l == r, + ( + FastProjection { + columns: l, + duplicate_check: dl, + }, + FastProjection { + columns: r, + duplicate_check: dr, + }, + ) => l == r && dl == dr, (DropNulls { subset: l }, DropNulls { subset: r }) => l == r, (Rechunk, Rechunk) => true, ( @@ -117,7 +123,6 @@ impl PartialEq for FunctionNode { .. }, ) => existing_l == existing_r && new_l == new_r, - (Drop { names: l }, Drop { names: r }) => l == r, (Explode { columns: l, .. }, Explode { columns: r, .. }) => l == r, (Melt { args: l, .. }, Melt { args: r, .. }) => l == r, (RowCount { name: l, .. }, RowCount { name: r, .. }) => l == r, @@ -138,8 +143,7 @@ impl FunctionNode { | FastProjection { .. } | Unnest { .. } | Rename { .. } - | Explode { .. } - | Drop { .. } => true, + | Explode { .. } => true, Melt { args, .. } => args.streamable, Opaque { streamable, .. 
} => *streamable, #[cfg(feature = "python")] @@ -178,7 +182,7 @@ impl FunctionNode { .map(|schema| Cow::Owned(schema.clone())) .unwrap_or_else(|| Cow::Borrowed(input_schema))), Pipeline { schema, .. } => Ok(Cow::Owned(schema.clone())), - FastProjection { columns } => { + FastProjection { columns, .. } => { let schema = columns .iter() .map(|name| { @@ -229,7 +233,6 @@ impl FunctionNode { #[cfg(feature = "merge_sorted")] MergeSorted { .. } => Ok(Cow::Borrowed(input_schema)), Rename { existing, new, .. } => rename::rename_schema(input_schema, existing, new), - Drop { names } => drop::drop_schema(input_schema, names), Explode { schema, .. } | RowCount { schema, .. } | Melt { schema, .. } => { Ok(Cow::Owned(schema.clone())) }, @@ -248,8 +251,7 @@ impl FunctionNode { | Unnest { .. } | Rename { .. } | Explode { .. } - | Melt { .. } - | Drop { .. } => true, + | Melt { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, RowCount { .. } => false, @@ -269,8 +271,7 @@ impl FunctionNode { | Unnest { .. } | Rename { .. } | Explode { .. } - | Melt { .. } - | Drop { .. } => true, + | Melt { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, RowCount { .. } => true, @@ -300,7 +301,16 @@ impl FunctionNode { schema, .. } => python_udf::call_python_udf(function, df, *validate_output, schema.as_deref()), - FastProjection { columns } => df.select(columns.as_ref()), + FastProjection { + columns, + duplicate_check, + } => { + if *duplicate_check { + df._select_impl(columns.as_ref()) + } else { + df._select_impl_unchecked(columns.as_ref()) + } + }, DropNulls { subset } => df.drop_nulls(Some(subset.as_ref())), Rechunk => { df.as_single_chunk_par(); @@ -322,7 +332,7 @@ impl FunctionNode { // we use a global string cache here as streaming chunks all have different rev maps #[cfg(feature = "dtype-categorical")] { - let _hold = IUseStringCache::hold(); + let _sc = StringCacheHolder::hold(); Arc::get_mut(function).unwrap().call_udf(df) } @@ -332,7 +342,6 @@ impl FunctionNode { } }, Rename { existing, new, .. } => rename::rename_impl(df, existing, new), - Drop { names } => drop::drop_impl(df, names), Explode { columns, .. } => df.explode(columns.as_ref()), Melt { args, .. } => { let args = (**args).clone(); @@ -356,7 +365,7 @@ impl Display for FunctionNode { Opaque { fmt_str, .. } => write!(f, "{fmt_str}"), #[cfg(feature = "python")] OpaquePython { .. } => write!(f, "python dataframe udf"), - FastProjection { columns } => { + FastProjection { columns, .. } => { write!(f, "FAST_PROJECT: ")?; let columns = columns.as_ref(); fmt_column_delimited(f, columns, "[", "]") @@ -385,7 +394,6 @@ impl Display for FunctionNode { } }, Rename { .. } => write!(f, "RENAME"), - Drop { .. } => write!(f, "DROP"), Explode { .. } => write!(f, "EXPLODE"), Melt { .. } => write!(f, "MELT"), RowCount { .. 
} => write!(f, "WITH ROW COUNT"),
diff --git a/crates/polars-plan/src/logical_plan/hive.rs b/crates/polars-plan/src/logical_plan/hive.rs
new file mode 100644
index 000000000000..22745cf6247f
--- /dev/null
+++ b/crates/polars-plan/src/logical_plan/hive.rs
@@ -0,0 +1,101 @@
+use std::path::Path;
+
+use percent_encoding::percent_decode_str;
+use polars_core::prelude::*;
+use polars_io::predicates::{BatchStats, ColumnStats};
+use polars_io::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE};
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct HivePartitions {
+    /// Single value Series that can be used to run the predicate against.
+    /// They are to be broadcasted if the predicates don't filter them out.
+    stats: BatchStats,
+}
+
+#[cfg(target_os = "windows")]
+fn separator(url: &Path) -> char {
+    if polars_io::is_cloud_url(url) {
+        '/'
+    } else {
+        '\\'
+    }
+}
+
+#[cfg(not(target_os = "windows"))]
+fn separator(_url: &Path) -> char {
+    '/'
+}
+
+impl HivePartitions {
+    pub fn get_statistics(&self) -> &BatchStats {
+        &self.stats
+    }
+
+    /// Parse a url and optionally return HivePartitions
+    pub(crate) fn parse_url(url: &Path) -> Option<Self> {
+        let sep = separator(url);
+
+        let partitions = url
+            .display()
+            .to_string()
+            .split(sep)
+            .filter_map(|part| {
+                let mut it = part.split('=');
+                let name = it.next()?;
+                let value = it.next()?;
+
+                // Having multiple '=' doesn't seem like valid hive partition,
+                // continue as url.
+                if it.next().is_some() {
+                    return None;
+                }
+
+                let s = if INTEGER_RE.is_match(value) {
+                    let value = value.parse::<i64>().ok()?;
+                    Series::new(name, &[value])
+                } else if BOOLEAN_RE.is_match(value) {
+                    let value = value.parse::<bool>().ok()?;
+                    Series::new(name, &[value])
+                } else if FLOAT_RE.is_match(value) {
+                    let value = value.parse::<f64>().ok()?;
+                    Series::new(name, &[value])
+                } else if value == "__HIVE_DEFAULT_PARTITION__" {
+                    Series::full_null(name, 1, &DataType::Null)
+                } else {
+                    Series::new(name, &[percent_decode_str(value).decode_utf8().ok()?])
+                };
+                Some(s)
+            })
+            .collect::<Vec<Series>>();
+
+        if partitions.is_empty() {
+            None
+        } else {
+            let schema: Schema = partitions.as_slice().into();
+            let stats = BatchStats::new(
+                schema,
+                partitions
+                    .into_iter()
+                    .map(ColumnStats::from_column_literal)
+                    .collect(),
+            );
+
+            Some(HivePartitions { stats })
+        }
+    }
+
+    pub(crate) fn schema(&self) -> &Schema {
+        self.get_statistics().schema()
+    }
+
+    pub fn materialize_partition_columns(&self) -> Vec<Series> {
+        self.get_statistics()
+            .column_stats()
+            .iter()
+            .map(|cs| cs.get_min_state().unwrap().clone())
+            .collect()
+    }
+}
diff --git a/crates/polars-plan/src/logical_plan/iterator.rs b/crates/polars-plan/src/logical_plan/iterator.rs
index adb333804b35..17ed5350c1d7 100644
--- a/crates/polars-plan/src/logical_plan/iterator.rs
+++ b/crates/polars-plan/src/logical_plan/iterator.rs
@@ -68,15 +68,11 @@ macro_rules! push_expr {
             Window {
                 function,
                 partition_by,
-                order_by,
                 ..
             } => {
                 for e in partition_by.into_iter().rev() {
                     $push(e)
                 }
-                if let Some(e) = order_by {
-                    $push(e);
-                }
                 // latest so that it is popped first
                 $push(function);
             },
@@ -93,6 +89,7 @@ macro_rules! push_expr {
             Exclude(e, _) => $push(e),
             KeepName(e) => $push(e),
             RenameAlias { expr, .. } => $push(expr),
+            SubPlan { ..
} => {}, // pass Selector(_) => {}, } diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 7f8dc6da6446..0937a06ddab8 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -252,7 +252,7 @@ impl Literal for NaiveDateTime { fn lit(self) -> Expr { if in_nanoseconds_window(&self) { Expr::Literal(LiteralValue::DateTime( - self.timestamp_nanos(), + self.timestamp_nanos_opt().unwrap(), TimeUnit::Nanoseconds, None, )) @@ -312,10 +312,26 @@ pub fn lit(t: L) -> Expr { impl Hash for LiteralValue { fn hash(&self, state: &mut H) { - if let Some(v) = self.to_anyvalue() { - v.hash_impl(state, true) - } else { - 0.hash(state) + std::mem::discriminant(self).hash(state); + match self { + LiteralValue::Series(s) => { + s.dtype().hash(state); + s.len().hash(state); + }, + LiteralValue::Range { + low, + high, + data_type, + } => { + low.hash(state); + high.hash(state); + data_type.hash(state) + }, + _ => { + if let Some(v) = self.to_anyvalue() { + v.hash_impl(state, true) + } + }, } } } diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 39d6b0d56511..ecb6d1e8917b 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -2,9 +2,9 @@ use std::fmt::Debug; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -#[cfg(feature = "parquet")] -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +#[cfg(any(feature = "cloud", feature = "parquet"))] +use polars_io::cloud::CloudOptions; use crate::logical_plan::LogicalPlan::DataFrameScan; use crate::prelude::*; @@ -24,6 +24,7 @@ pub(crate) mod debug; mod file_scan; mod format; mod functions; +pub(super) mod hive; pub(crate) mod iterator; mod lit; pub(crate) mod optimizer; @@ -141,13 +142,6 @@ impl From for ErrorStateSync { #[derive(Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum LogicalPlan { - #[cfg_attr(feature = "serde", serde(skip))] - AnonymousScan { - function: Arc, - file_info: FileInfo, - predicate: Option, - options: Arc, - }, #[cfg(feature = "python")] PythonScan { options: PythonOptions }, /// Filter on a boolean mask @@ -178,13 +172,6 @@ pub enum LogicalPlan { projection: Option>>, selection: Option, }, - // a projection that doesn't have to be optimized - // or may drop projected columns if they aren't in current schema (after optimization) - LocalProjection { - expr: Vec, - input: Box, - schema: SchemaRef, - }, /// Column selection Projection { expr: Vec, @@ -257,9 +244,9 @@ pub enum LogicalPlan { contexts: Vec, schema: SchemaRef, }, - FileSink { + Sink { input: Box, - payload: FileSinkOptions, + payload: SinkType, }, } diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index 1bace7072a92..23881fdd4b3c 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -169,9 +169,7 @@ pub(super) fn set_cache_states( let lp = pd.optimize(lp, lp_arena, expr_arena).unwrap(); // remove the projection added by the optimization - let lp = if let ALogicalPlan::Projection { input, .. } - | ALogicalPlan::LocalProjection { input, .. } = lp - { + let lp = if let ALogicalPlan::Projection { input, .. 
} = lp { lp_arena.take(input) } else { lp diff --git a/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs new file mode 100644 index 000000000000..17b051411cde --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs @@ -0,0 +1,28 @@ +use super::*; + +pub(super) struct MemberCollector { + pub(crate) has_joins_or_unions: bool, + pub(crate) has_cache: bool, + pub(crate) has_ext_context: bool, +} + +impl MemberCollector { + pub(super) fn new() -> Self { + Self { + has_joins_or_unions: false, + has_cache: false, + has_ext_context: false, + } + } + pub fn collect(&mut self, root: Node, lp_arena: &Arena) { + use ALogicalPlan::*; + for (_, alp) in lp_arena.iter(root) { + match alp { + Join { .. } | Union { .. } => self.has_joins_or_unions = true, + Cache { .. } => self.has_cache = true, + ExtContext { .. } => self.has_ext_context = true, + _ => {}, + } + } + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs index 6ad2d63ef1e4..11e4c45ca925 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse.rs @@ -1,7 +1,6 @@ //! Common Subplan Elimination use std::collections::{BTreeMap, BTreeSet}; -use std::hash::{BuildHasher, Hash, Hasher}; use polars_core::prelude::*; @@ -310,9 +309,7 @@ pub(crate) fn elim_cmn_subplans( (Some(h), _) => *h, (_, Some(h)) => *h, _ => { - let mut h = hb.build_hasher(); - node1.hash(&mut h); - let hash = h.finish(); + let hash = hb.hash_one(node1); let mut cache_id = lp_cache.wrapping_add(hash as usize); // this ensures we can still add branch ids without overflowing // during the dot representation diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs index 52734df7e02f..3dc2d0f108f9 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs @@ -316,14 +316,20 @@ impl ExprIdentifierVisitor<'_> { // Don't allow this for now, as we can get `null().cast()` in ternary expressions. // TODO! Add a typed null AExpr::Literal(LiteralValue::Null) => REFUSE_NO_MEMBER, - AExpr::Column(_) | AExpr::Literal(_) | AExpr::Count | AExpr::Alias(_, _) => { - REFUSE_ALLOW_MEMBER + AExpr::Column(_) | AExpr::Literal(_) | AExpr::Alias(_, _) => REFUSE_ALLOW_MEMBER, + AExpr::Count => { + if self.is_group_by { + REFUSE_NO_MEMBER + } else { + REFUSE_ALLOW_MEMBER + } }, #[cfg(feature = "random")] AExpr::Function { function: FunctionExpr::Random { .. }, .. } => REFUSE_NO_MEMBER, + AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER, _ => { // During aggregation we only store elementwise operation in the state // other operations we cannot add to the state as they have the output size of the @@ -506,7 +512,12 @@ impl RewritingVisitor for CommonSubExprRewriter<'_> { return Ok(recurse); } - let (_, count) = self.sub_expr_map.get(id).unwrap(); + // Because some expressions don't have hash / equality guarantee (e.g. floats) + // we can get none here. This must be changed later. 
+ let Some((_, count)) = self.sub_expr_map.get(id) else { + self.visited_idx += 1; + return Ok(RewriteRecursion::NoMutateAndContinue); + }; if *count > 1 { self.replaced_identifiers.insert(id.clone()); // rewrite this sub-expression, don't visit its children diff --git a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs index 2de0261caddc..32b30e7761a9 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs @@ -1,6 +1,7 @@ use std::collections::BTreeSet; use polars_core::prelude::*; +use smartstring::SmartString; use super::*; use crate::logical_plan::alp::ALogicalPlan; @@ -19,12 +20,14 @@ pub(super) struct FastProjectionAndCollapse { /// keep track of nodes that are already processed when they /// can be expensive. Schema materialization can be for instance. processed: BTreeSet, + eager: bool, } impl FastProjectionAndCollapse { - pub(super) fn new() -> Self { + pub(super) fn new(eager: bool) -> Self { Self { processed: Default::default(), + eager, } } } @@ -33,11 +36,12 @@ fn impl_fast_projection( input: Node, expr: &[Node], expr_arena: &Arena, + duplicate_check: bool, ) -> Option { let mut columns = Vec::with_capacity(expr.len()); for node in expr.iter() { if let AExpr::Column(name) = expr_arena.get(*node) { - columns.push(name.clone()) + columns.push(SmartString::from(name.as_ref())) } else { break; } @@ -47,6 +51,7 @@ fn impl_fast_projection( input, function: FunctionNode::FastProjection { columns: Arc::from(columns), + duplicate_check, }, }; @@ -67,18 +72,22 @@ impl OptimizationRule for FastProjectionAndCollapse { let lp = lp_arena.get(node); match lp { - Projection { input, expr, .. } => { + Projection { + input, + expr, + options, + .. + } => { if !matches!(lp_arena.get(*input), ExtContext { .. }) { - impl_fast_projection(*input, expr, expr_arena) + impl_fast_projection(*input, expr, expr_arena, options.duplicate_check) } else { None } }, - LocalProjection { input, expr, .. } => impl_fast_projection(*input, expr, expr_arena), MapFunction { input, - function: FunctionNode::FastProjection { columns }, - } => { + function: FunctionNode::FastProjection { columns, .. }, + } if !self.eager => { // if there are 2 subsequent fast projections, flatten them and only take the last match lp_arena.get(*input) { MapFunction { @@ -88,6 +97,7 @@ impl OptimizationRule for FastProjectionAndCollapse { input: *prev_input, function: FunctionNode::FastProjection { columns: columns.clone(), + duplicate_check: true, }, }), // cleanup projections set in projection pushdown just above caches @@ -96,7 +106,7 @@ impl OptimizationRule for FastProjectionAndCollapse { let cache_schema = cache_lp.schema(lp_arena); if cache_schema.len() == columns.len() && cache_schema.iter_names().zip(columns.iter()).all( - |(left_name, right_name)| left_name.as_str() == right_name.as_ref(), + |(left_name, right_name)| left_name.as_str() == right_name.as_str(), ) { Some(cache_lp.clone()) @@ -112,7 +122,7 @@ impl OptimizationRule for FastProjectionAndCollapse { input, count: outer_count, .. 
- } => { + } if !self.eager => { if let Cache { input: prev_input, id, diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index 1b8b08e68a7b..3f0dc062e994 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -9,6 +9,7 @@ mod cse; mod delay_rechunk; mod drop_nulls; +mod collect_members; #[cfg(feature = "cse")] mod cse_expr; mod fast_projection; @@ -30,6 +31,7 @@ use drop_nulls::ReplaceDropNulls; use fast_projection::FastProjectionAndCollapse; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] use file_caching::{find_column_union_and_fingerprints, FileCacher}; +use polars_io::predicates::PhysicalIoExpr; pub use predicate_pushdown::PredicatePushDown; pub use projection_pushdown::ProjectionPushDown; pub use simplify_expr::{SimplifyBooleanRule, SimplifyExprRule}; @@ -41,8 +43,10 @@ use self::flatten_union::FlattenUnionRule; pub use crate::frame::{AllowedOptimizations, OptState}; #[cfg(feature = "cse")] use crate::logical_plan::optimizer::cse_expr::CommonSubExprOptimizer; +use crate::logical_plan::optimizer::predicate_pushdown::HiveEval; #[cfg(feature = "cse")] use crate::logical_plan::visitor::*; +use crate::prelude::optimizer::collect_members::MemberCollector; pub trait Optimize { fn optimize(&self, logical_plan: LogicalPlan) -> PolarsResult; @@ -61,6 +65,7 @@ pub fn optimize( lp_arena: &mut Arena, expr_arena: &mut Arena, scratch: &mut Vec, + hive_partition_eval: HiveEval<'_>, ) -> PolarsResult { // get toggle values let predicate_pushdown = opt_state.predicate_pushdown; @@ -69,13 +74,20 @@ pub fn optimize( let simplify_expr = opt_state.simplify_expr; let slice_pushdown = opt_state.slice_pushdown; let streaming = opt_state.streaming; + let fast_projection = opt_state.fast_projection; + // Don't run optimizations that don't make sense on a single node. + // This keeps eager execution more snappy. + let eager = opt_state.eager; #[cfg(feature = "cse")] - let comm_subplan_elim = opt_state.comm_subplan_elim; + let comm_subplan_elim = opt_state.comm_subplan_elim && !eager; + #[cfg(feature = "cse")] let comm_subexpr_elim = opt_state.comm_subexpr_elim; + #[cfg(not(feature = "cse"))] + let comm_subexpr_elim = false; #[allow(unused_variables)] - let agg_scan_projection = opt_state.file_caching && !streaming; + let agg_scan_projection = opt_state.file_caching && !streaming && !eager; // gradually fill the rules passed to the optimizer let opt = StackOptimizer {}; @@ -87,16 +99,23 @@ pub fn optimize( let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena)?; + // Collect members for optimizations that need it. 
+ let mut members = MemberCollector::new(); + if !eager && (comm_subexpr_elim || projection_pushdown) { + members.collect(lp_top, lp_arena) + } + #[cfg(feature = "cse")] - let cse_changed = if comm_subplan_elim { + let cse_plan_changed = if comm_subplan_elim { let (lp, changed) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena); lp_top = lp; + members.has_cache |= changed; changed } else { false }; #[cfg(not(feature = "cse"))] - let cse_changed = false; + let cse_plan_changed = false; // we do simplification if simplify_expr { @@ -112,23 +131,26 @@ pub fn optimize( let alp = projection_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); - if projection_pushdown_opt.has_joins_or_unions && projection_pushdown_opt.has_cache { - cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_changed); + if members.has_joins_or_unions && members.has_cache { + cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed); } } if predicate_pushdown { - let predicate_pushdown_opt = PredicatePushDown::default(); + let predicate_pushdown_opt = PredicatePushDown::new(hive_partition_eval); let alp = lp_arena.take(lp_top); let alp = predicate_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); } // make sure its before slice pushdown. - if projection_pushdown { - rules.push(Box::new(FastProjectionAndCollapse::new())); + if fast_projection { + rules.push(Box::new(FastProjectionAndCollapse::new(eager))); + } + + if !eager { + rules.push(Box::new(DelayRechunk::new())); } - rules.push(Box::new(DelayRechunk::new())); if slice_pushdown { let slice_pushdown_opt = SlicePushDown::new(streaming); @@ -153,7 +175,7 @@ pub fn optimize( // and predicate pushdown are done. At that moment // the file fingerprints are finished. #[cfg(any(feature = "cse", feature = "parquet", feature = "ipc", feature = "csv"))] - if agg_scan_projection || cse_changed { + if agg_scan_projection || cse_plan_changed { // we do this so that expressions are simplified created by the pushdown optimizations // we must clean up the predicates, because the agg_scan_projection // uses them in the hashtable to determine duplicates. 
@@ -174,20 +196,22 @@ pub fn optimize( file_cacher.assign_unions(lp_top, lp_arena, expr_arena, scratch); #[cfg(feature = "cse")] - if cse_changed { + if cse_plan_changed { // this must run after cse cse::decrement_file_counters_by_cache_hits(lp_top, lp_arena, expr_arena, 0, scratch); } } rules.push(Box::new(ReplaceDropNulls {})); - rules.push(Box::new(FlattenUnionRule {})); + if !eager { + rules.push(Box::new(FlattenUnionRule {})); + } lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?; // This one should run (nearly) last as this modifies the projections #[cfg(feature = "cse")] - if comm_subexpr_elim { + if comm_subexpr_elim && !members.has_ext_context { let mut optimizer = CommonSubExprOptimizer::new(expr_arena); lp_top = ALogicalPlanNode::with_context(lp_top, lp_arena, |alp_node| { alp_node.rewrite(&mut optimizer) diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs new file mode 100644 index 000000000000..3ab423a6640d --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/group_by.rs @@ -0,0 +1,84 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_group_by( + opt: &PredicatePushDown, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + input: Node, + keys: Vec, + aggs: Vec, + schema: SchemaRef, + maintain_order: bool, + apply: Option>, + options: Arc, + acc_predicates: PlHashMap, Node>, +) -> PolarsResult { + use ALogicalPlan::*; + + #[cfg(feature = "dynamic_group_by")] + let no_push = { options.rolling.is_some() || options.dynamic.is_some() }; + + #[cfg(not(feature = "dynamic_group_by"))] + let no_push = false; + + // Don't pushdown predicates on these cases. + if apply.is_some() || no_push || options.slice.is_some() { + let lp = Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena); + } + + // If the predicate only resolves to the keys we can push it down. + // When it filters the aggregations, the predicate should be done after aggregation. + let mut local_predicates = Vec::with_capacity(acc_predicates.len()); + let key_schema = aexprs_to_schema( + &keys, + lp_arena.get(input).schema(lp_arena).as_ref(), + Context::Default, + expr_arena, + ); + + let mut new_acc_predicates = PlHashMap::with_capacity(acc_predicates.len()); + + for (pred_name, predicate) in &acc_predicates { + // Counts change due to groupby's + // TODO! handle aliases, so that the predicate that is pushed down refers to the column before alias. 
+ let mut push_down = !has_aexpr(*predicate, expr_arena, |ae| { + matches!(ae, AExpr::Count | AExpr::Alias(_, _)) + }); + + for name in aexpr_to_leaf_names_iter(*predicate, expr_arena) { + push_down &= key_schema.contains(name.as_ref()); + + if !push_down { + break; + } + } + if !push_down { + local_predicates.push(*predicate) + } else { + new_acc_predicates.insert(pred_name.clone(), *predicate); + } + } + + opt.pushdown_and_assign(input, new_acc_predicates, lp_arena, expr_arena)?; + + let lp = Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + Ok(opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs index 5ea1eb9cec5f..30db76221022 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs @@ -28,9 +28,9 @@ fn should_block_join_specific(ae: &AExpr, how: &JoinType) -> LeftRight { | FunctionExpr::Boolean(BooleanFunction::IsDuplicated), .. } => LeftRight(true, true), - #[cfg(feature = "is_first")] + #[cfg(feature = "is_first_distinct")] Function { - function: FunctionExpr::Boolean(BooleanFunction::IsFirst), + function: FunctionExpr::Boolean(BooleanFunction::IsFirstDistinct), .. } => LeftRight(true, true), // any operation that checks for equality or ordering can be wrong because @@ -66,6 +66,21 @@ fn join_produces_null(how: &JoinType) -> LeftRight { } } +fn all_pred_cols_in_left_on( + predicate: Node, + expr_arena: &mut Arena, + left_on: &[Node], +) -> bool { + let left_on_col_exprs: Vec = left_on + .iter() + .map(|&node| node_to_expr(node, expr_arena)) + .collect(); + let mut col_exprs_in_predicate = aexpr_to_column_nodes_iter(predicate, expr_arena) + .map(|node| node_to_expr(node, expr_arena)); + + col_exprs_in_predicate.all(|expr| left_on_col_exprs.contains(&expr)) +} + #[allow(clippy::too_many_arguments)] pub(super) fn process_join( opt: &PredicatePushDown, @@ -108,12 +123,20 @@ pub(super) fn process_join( insert_and_combine_predicate(&mut pushdown_left, predicate, expr_arena); filter_left = true; } - // this is `else if` because if the predicate is in the left hand side + + // if the predicate is in the left hand side // the right hand side should be renamed with the suffix. // in that case we should not push down as the user wants to filter on `x` // not on `x_rhs`. 
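As a usage-level illustration of the `process_group_by` logic above (a sketch, not part of the patch): a predicate may only move below an aggregation when every column it touches is a group-by key and it contains no `Count`/`Alias` expression. A simplified version of that containment test:

```rust
use std::collections::HashSet;

/// Simplified stand-in for the key-containment check in `process_group_by`:
/// a predicate commutes with the aggregation only if all of its input
/// columns are group-by keys.
fn can_push_below_group_by(predicate_cols: &[&str], key_cols: &HashSet<&str>) -> bool {
    predicate_cols.iter().all(|c| key_cols.contains(c))
}

fn main() {
    let keys: HashSet<&str> = ["region"].into_iter().collect();
    // A filter on `region` only touches a key: safe to apply before aggregating.
    assert!(can_push_below_group_by(&["region"], &keys));
    // A filter on `sales` touches an aggregated column: it must stay above.
    assert!(!can_push_below_group_by(&["sales"], &keys));
}
```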
- else if check_input_node(predicate, &schema_right, expr_arena) + if !filter_left + && check_input_node(predicate, &schema_right, expr_arena) && !block_pushdown_right + // However, if we push down to the left and all predicate columns are also + // join columns, we also push down right + || filter_left + && all_pred_cols_in_left_on(predicate, expr_arena, &left_on) + // TODO: Restricting to Inner and Left Join is probably too conservative + && matches!(&options.args.how, JoinType::Inner | JoinType::Left) { insert_and_combine_predicate(&mut pushdown_right, predicate, expr_arena); filter_right = true; diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 134787fa78a9..eec0ddaff940 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -1,23 +1,37 @@ +mod group_by; mod join; mod keys; mod rename; mod utils; +use polars_core::config::verbose; use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use utils::*; use super::*; use crate::dsl::function_expr::FunctionExpr; -use crate::logical_plan::{optimizer, Context}; +use crate::logical_plan::optimizer; +use crate::prelude::optimizer::predicate_pushdown::group_by::process_group_by; use crate::prelude::optimizer::predicate_pushdown::join::process_join; use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; -use crate::utils::{aexprs_to_schema, check_input_node, has_aexpr}; +use crate::utils::{check_input_node, has_aexpr}; -#[derive(Default)] -pub struct PredicatePushDown {} +pub type HiveEval<'a> = Option<&'a dyn Fn(Node, &Arena) -> Option>>; + +pub struct PredicatePushDown<'a> { + hive_partition_eval: HiveEval<'a>, + verbose: bool, +} + +impl<'a> PredicatePushDown<'a> { + pub fn new(hive_partition_eval: HiveEval<'a>) -> Self { + Self { + hive_partition_eval, + verbose: verbose(), + } + } -impl PredicatePushDown { fn optional_apply_predicate( &self, lp: ALogicalPlan, @@ -210,24 +224,6 @@ impl PredicatePushDown { }; Ok(lp) } - - LocalProjection { expr, input, .. } => { - self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?; - - let schema = lp_arena.get(input).schema(lp_arena); - // projection from a wildcard may be dropped if the schema changes due to the optimization - let expr: Vec<_> = expr - .into_iter() - .filter(|e| check_input_node(*e, &schema, expr_arena)) - .collect(); - - let schema = aexprs_to_schema(&expr, &schema, Context::Default, expr_arena); - Ok(ALogicalPlan::LocalProjection { - expr, - input, - schema: Arc::new(schema), - }) - } Scan { path, file_info, @@ -239,68 +235,68 @@ impl PredicatePushDown { let local_predicates = partition_by_full_context(&mut acc_predicates, expr_arena); let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); - let lp = match (predicate, &scan_type) { + if let (Some(hive_part_stats), Some(predicate)) = (file_info.hive_parts.as_deref(), predicate) { + if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) { + if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { + if !stats_evaluator.should_read(hive_part_stats.get_statistics())? 
{ + if self.verbose { + eprintln!("hive partitioning: skipped: {}", path.display()) + } + let schema = output_schema.as_ref().unwrap_or(&file_info.schema); + let df = DataFrame::from(schema.as_ref()); + + return Ok(DataFrameScan { + df: Arc::new(df), + schema: schema.clone(), + output_schema: None, + projection: None, + selection: None + }) + } + } + } + } + + let mut do_optimization = match &scan_type { #[cfg(feature = "csv")] - (Some(predicate), FileScan::Csv {..}) => { - let lp = Scan { - path, - file_info, - predicate: None, - file_options: options, - output_schema, - scan_type - }; + FileScan::Csv {..} => options.n_rows.is_none(), + FileScan::Anonymous {function, ..} => function.allows_predicate_pushdown(), + _ => true + }; + do_optimization &= predicate.is_some(); + + let lp = if do_optimization { + Scan { + path, + file_info, + predicate, + file_options: options, + output_schema, + scan_type + } + } else { + let lp = Scan { + path, + file_info, + predicate: None, + file_options: options, + output_schema, + scan_type + }; + if let Some(predicate) = predicate { let input = lp_arena.add(lp); Selection { input, predicate } - }, - _ => { - Scan { - path, - file_info, - predicate, - file_options: options, - output_schema, - scan_type - } + } else { + lp } }; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) } - AnonymousScan { - function, - file_info, - output_schema, - options, - predicate, - } => { - if function.allows_predicate_pushdown() { - let local_predicates = partition_by_full_context(&mut acc_predicates, expr_arena); - let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); - let lp = AnonymousScan { - function, - file_info, - output_schema, - options, - predicate, - }; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - } else { - let lp = AnonymousScan { - function, - file_info, - output_schema, - options, - predicate, - }; - self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) - } - } - Distinct { input, options @@ -427,6 +423,10 @@ impl PredicatePushDown { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } } + Aggregate {input, keys, aggs, schema, apply, maintain_order, options, } => { + process_group_by(self, lp_arena, expr_arena, input, keys, aggs, schema, maintain_order, apply, options, acc_predicates) + + }, lp @ Union {..} => { let mut local_predicates = vec![]; @@ -457,7 +457,7 @@ impl PredicatePushDown { } // Pushed down passed these nodes - lp@ FileSink {..} => { + lp@ Sink {..} => { self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) } lp @ HStack {..} | lp @ Projection {..} | lp @ ExtContext {..} => { @@ -468,8 +468,7 @@ impl PredicatePushDown { lp @ Slice { .. } // caches will be different | lp @ Cache { .. } - // dont push down predicates. 
An aggregation needs all rows - | lp @ Aggregate {..} => { + => { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } #[cfg(feature = "python")] diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index 9bd5290edeb9..d5c004ce8bad 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -42,7 +42,6 @@ pub(super) fn process_asof_join( lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); @@ -222,7 +221,6 @@ pub(super) fn process_join( ); } - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 127e9ac14c95..3ff672683211 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -149,17 +149,11 @@ fn update_scan_schema( Ok(new_schema) } -pub struct ProjectionPushDown { - pub(crate) has_joins_or_unions: bool, - pub(crate) has_cache: bool, -} +pub struct ProjectionPushDown {} impl ProjectionPushDown { pub(super) fn new() -> Self { - Self { - has_joins_or_unions: false, - has_cache: false, - } + Self {} } /// Projection will be done at this node, but we continue optimization @@ -337,70 +331,6 @@ impl ProjectionPushDown { lp_arena, expr_arena, ), - LocalProjection { expr, input, .. } => { - self.pushdown_and_assign( - input, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - )?; - let lp = lp_arena.get(input); - let schema = lp.schema(lp_arena); - - // projection from a wildcard may be dropped if the schema changes due to the optimization - let proj = expr - .into_iter() - .filter(|e| check_input_node(*e, &schema, expr_arena)) - .collect(); - Ok(ALogicalPlanBuilder::new(input, expr_arena, lp_arena) - .project_local(proj) - .build()) - }, - AnonymousScan { - function, - file_info, - predicate, - mut options, - output_schema, - } => { - if function.allows_projection_pushdown() { - let mut_options = Arc::make_mut(&mut options); - mut_options.with_columns = - get_scan_columns(&mut acc_projections, expr_arena, None); - - let output_schema = if mut_options.with_columns.is_none() { - None - } else { - Some(Arc::new(update_scan_schema( - &acc_projections, - expr_arena, - &file_info.schema, - true, - )?)) - }; - mut_options.output_schema = output_schema.clone(); - - let lp = AnonymousScan { - function, - file_info, - output_schema, - options, - predicate, - }; - Ok(lp) - } else { - let lp = AnonymousScan { - function, - file_info, - predicate, - options, - output_schema, - }; - Ok(lp) - } - }, DataFrameScan { df, schema, @@ -452,24 +382,43 @@ impl ProjectionPushDown { scan_type, predicate, mut file_options, - .. + mut output_schema, } => { - file_options.with_columns = get_scan_columns( - &mut acc_projections, - expr_arena, - file_options.row_count.as_ref(), - ); + let mut do_optimization = true; + if let FileScan::Anonymous { ref function, .. 
} = scan_type { + do_optimization = function.allows_projection_pushdown(); + } - let output_schema = if file_options.with_columns.is_none() { - None - } else { - Some(Arc::new(update_scan_schema( - &acc_projections, + if do_optimization { + file_options.with_columns = get_scan_columns( + &mut acc_projections, expr_arena, - &file_info.schema, - scan_type.sort_projection(&file_options), - )?)) - }; + file_options.row_count.as_ref(), + ); + + output_schema = if file_options.with_columns.is_none() { + None + } else { + let mut schema = update_scan_schema( + &acc_projections, + expr_arena, + &file_info.schema, + scan_type.sort_projection(&file_options), + )?; + // Hive partitions are created AFTER the projection, so the output + // schema is incorrect. Here we ensure the columns that are projected and hive + // parts are added at the proper place in the schema, which is at the end. + if let Some(parts) = file_info.hive_parts.as_deref() { + let partition_schema = parts.schema(); + for (name, _) in partition_schema.iter() { + if let Some(dt) = schema.shift_remove(name) { + schema.with_column(name.clone(), dt); + } + } + } + Some(Arc::new(schema)) + }; + } let lp = Scan { path, @@ -696,20 +645,17 @@ impl ProjectionPushDown { lp_arena, expr_arena, ), - lp @ Union { .. } => { - self.has_joins_or_unions = true; - process_generic( - self, - lp, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - ) - }, + lp @ Union { .. } => process_generic( + self, + lp, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), // These nodes only have inputs and exprs, so we can use same logic. - lp @ Slice { .. } | lp @ FileSink { .. } => process_generic( + lp @ Slice { .. } | lp @ Sink { .. } => process_generic( self, lp, acc_projections, @@ -719,7 +665,6 @@ impl ProjectionPushDown { expr_arena, ), Cache { .. } => { - self.has_cache = true; // projections above this cache will be accumulated and pushed down // later // the redundant projection will be cleaned in the fast projection optimization diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs index 7e0cee38462e..16b2f5bb073b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs @@ -15,7 +15,6 @@ pub(super) fn process_semi_anti_join( lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index 89a89beba4c0..92d8621090c4 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -254,7 +254,7 @@ impl OptimizationRule for SimplifyBooleanRule { }, AExpr::Function { input, - function: FunctionExpr::Boolean(BooleanFunction::IsNot), + function: FunctionExpr::Boolean(BooleanFunction::Not), .. 
} => { let y = expr_arena.get(input[0]); @@ -263,7 +263,7 @@ impl OptimizationRule for SimplifyBooleanRule { // not(not x) => x AExpr::Function { input, - function: FunctionExpr::Boolean(BooleanFunction::IsNot), + function: FunctionExpr::Boolean(BooleanFunction::Not), .. } => Some(expr_arena.get(input[0]).clone()), // not(lit x) => !x diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index cd6f979c37f3..66887f25ee62 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -104,28 +104,6 @@ impl SlicePushDown { use ALogicalPlan::*; match (lp, state) { - (AnonymousScan { - function, - file_info, - output_schema, - predicate, - mut options, - }, - // TODO! we currently skip slice pushdown if there is a predicate. - // we can modify the readers to only limit after predicates have been applied - Some(state)) if state.offset == 0 && predicate.is_none() => { - let mut_options = Arc::make_mut(&mut options); - mut_options.n_rows = Some(state.len as usize); - let lp = AnonymousScan { - function, - file_info, - output_schema, - predicate, - options, - }; - - Ok(lp) - }, #[cfg(feature = "python")] (PythonScan { mut options, @@ -163,6 +141,7 @@ impl SlicePushDown { }; Ok(lp) }, + // TODO! we currently skip slice pushdown if there is a predicate. (Scan { path, file_info, @@ -309,9 +288,6 @@ impl SlicePushDown { // here we do not pushdown. // we reset the state and then start the optimization again m @ (Selection { .. }, _) - // let's be conservative. projections may do aggregations and a pushed down slice - // will lead to incorrect aggregations - | m @ (LocalProjection {..},_) // other blocking nodes | m @ (DataFrameScan {..}, _) | m @ (Sort {..}, _) diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index ba92141239fe..a69d5e1a9afb 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -331,6 +331,7 @@ impl OptimizationRule for TypeCoercionRule { op, right: node_right, } => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right), + #[cfg(feature = "is_in")] AExpr::Function { function: FunctionExpr::Boolean(BooleanFunction::IsIn), @@ -349,42 +350,62 @@ impl OptimizationRule for TypeCoercionRule { let casted_expr = match (&type_left, &type_other) { // types are equal, do nothing (a, b) if a == b => return Ok(None), + // all-null can represent anything (and/or empty list), so cast to target dtype + (_, DataType::Null) => AExpr::Cast { + expr: other_node, + data_type: type_left, + strict: false, + }, // cast both local and global string cache // note that there might not yet be a rev #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_), DataType::Utf8) => { - AExpr::Cast { - expr: other_node, - data_type: DataType::Categorical(None), - // does not matter - strict: false, + (DataType::Categorical(_), DataType::Utf8) => AExpr::Cast { + expr: other_node, + data_type: DataType::Categorical(None), + strict: false, + }, + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + // can't check for more granular time_unit in 
less-granular time_unit data, + // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) + (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) + } + }, + (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) } }, - (dt, DataType::Utf8) => { - polars_bail!(ComputeError: "cannot compare {:?} to {:?} type in 'is_in' operation", dt, type_other) + (_, DataType::List(other_inner)) => { + if other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); + } + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_left, &type_other) }, - (DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None), #[cfg(feature = "dtype-struct")] (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), - // if right is another type, we cast it to left - // we do not use super-type as an `is_in` operation should not - // cast the whole column implicitly. - (a, b) - if a != b - // For integer/ float comparison we let them use supertypes. - && !(a.is_integer() && b.is_float()) => - { - AExpr::Cast { - expr: other_node, - data_type: type_left, - // does not matter - strict: false, + + // don't attempt to cast between obviously mismatched types, but + // allow integer/float comparison (will use their supertypes). 
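A sketch of the temporal granularity rule applied above for `Datetime`/`Duration` (the `TimeUnit` below is a free-standing stand-in, ordered so that finer units compare as smaller, matching the `lhs_unit <= rhs_unit` test): probing for coarse values inside finer-grained data is allowed, while the reverse would cast away precision and is rejected.

```rust
/// Stand-in for a temporal time unit, ordered from finest to coarsest so that
/// "finer" compares as smaller.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
enum TimeUnit {
    Nanoseconds,
    Microseconds,
    Milliseconds,
}

/// `is_in` comparability check for temporal columns: the column (`data`) must
/// be at least as fine-grained as the probe values (`needles`), otherwise the
/// probe precision cannot be represented in the data.
fn is_in_units_ok(data: TimeUnit, needles: TimeUnit) -> Result<(), String> {
    if data <= needles {
        Ok(())
    } else {
        Err(format!(
            "`is_in` cannot check for {needles:?} precision values in {data:?} data"
        ))
    }
}

fn main() {
    // Millisecond needles in nanosecond data: fine.
    assert!(is_in_units_ok(TimeUnit::Nanoseconds, TimeUnit::Milliseconds).is_ok());
    // Nanosecond needles in millisecond data: precision would be lost.
    assert!(is_in_units_ok(TimeUnit::Milliseconds, TimeUnit::Nanoseconds).is_err());
}
```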
+ (a, b) => { + if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { + return Ok(None); } + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) }, - // do nothing - _ => return Ok(None), }; - let mut input = input.clone(); let other_input = expr_arena.add(casted_expr); input[1] = other_input; @@ -519,28 +540,56 @@ fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { } } -// TODO: Fix this test and re-enable it (currently does not compile) -// #[cfg(test)] -// #[cfg(feature = "dtype-categorical")] -// mod test { -// use polars_core::prelude::*; - -// use super::*; -// use crate::prelude::*; - -// #[test] -// fn test_categorical_utf8() { -// let mut rules: Vec> = vec![Box::new(TypeCoercionRule {})]; -// let schema = Schema::from_iter([Field::new("fruits", DataType::Categorical(None))]); - -// let expr = col("fruits").eq(lit("somestr")); -// let out = optimize_expr(expr.clone(), schema.clone(), &mut rules); -// // we test that the fruits column is not casted to utf8 for the comparison -// assert_eq!(out, expr); - -// let expr = col("fruits") + (lit("somestr")); -// let out = optimize_expr(expr, schema, &mut rules); -// let expected = col("fruits").cast(DataType::Utf8) + lit("somestr"); -// assert_eq!(out, expected); -// } -// } +#[cfg(test)] +#[cfg(feature = "dtype-categorical")] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_categorical_utf8() { + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + let optimizer = StackOptimizer {}; + let rules: &mut [Box] = &mut [Box::new(TypeCoercionRule {})]; + + let df = DataFrame::new(Vec::from([Series::new_empty( + "fruits", + &DataType::Categorical(None), + )])) + .unwrap(); + + let expr_in = vec![col("fruits").eq(lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df.clone()) + .project(expr_in.clone(), Default::default()) + .build(); + + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is not casted to utf8 for the comparison + if let LogicalPlan::Projection { expr, .. } = lp { + assert_eq!(expr, expr_in); + }; + + let expr_in = vec![col("fruits") + (lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df) + .project(expr_in, Default::default()) + .build(); + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is casted to utf8 for the addition + let expected = vec![col("fruits").cast(DataType::Utf8) + lit("somestr")]; + if let LogicalPlan::Projection { expr, .. 
} = lp { + assert_eq!(expr, expected); + }; + } +} diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index 93764d462ad7..795107b49cf0 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -17,7 +17,6 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use crate::prelude::python_udf::PythonFunction; -use crate::prelude::Expr; pub type FileCount = u32; @@ -25,7 +24,7 @@ pub type FileCount = u32; #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct CsvParserOptions { - pub delimiter: u8, + pub separator: u8, pub comment_char: Option, pub quote_char: Option, pub eol_char: u8, @@ -101,6 +100,7 @@ pub struct FileScanOptions { pub row_count: Option, pub rechunk: bool, pub file_counter: FileCount, + pub hive_partitioning: bool, } #[derive(Clone, Debug, Copy, Default, Eq, PartialEq)] @@ -147,7 +147,7 @@ pub struct DistinctOptions { pub enum ApplyOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context. - // e.g. [g1, g1, g2] -> [[g1, g2], g2] + // e.g. [g1, g1, g2] -> [[g1, g1], g2] ApplyGroups, // collect groups to a list and then apply // e.g. [g1, g1, g2] -> list([g1, g1, g2]) @@ -292,15 +292,26 @@ pub struct PythonOptions { #[derive(Clone, PartialEq, Eq, Debug, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct AnonymousScanOptions { - pub schema: SchemaRef, - pub output_schema: Option, pub skip_rows: Option, - pub n_rows: Option, - pub with_columns: Option>>, - pub predicate: Option, pub fmt_str: &'static str, } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Debug)] +pub enum SinkType { + Memory, + File { + path: Arc, + file_type: FileType, + }, + #[cfg(feature = "cloud")] + Cloud { + uri: Arc, + file_type: FileType, + cloud_options: Option, + }, +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Debug)] pub struct FileSinkOptions { @@ -317,17 +328,20 @@ pub enum FileType { Ipc(IpcWriterOptions), #[cfg(feature = "csv")] Csv(CsvWriterOptions), - Memory, } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Copy, Debug)] pub struct ProjectionOptions { pub run_parallel: bool, + pub duplicate_check: bool, } impl Default for ProjectionOptions { fn default() -> Self { - Self { run_parallel: true } + Self { + run_parallel: true, + duplicate_check: true, + } } } diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index fd467fe56496..76dd348884aa 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -580,7 +580,9 @@ fn replace_selector_inner( }, Selector::Add(lhs, rhs) => { replace_selector_inner(*lhs, members, scratch, schema, keys)?; - replace_selector_inner(*rhs, members, scratch, schema, keys)?; + let mut rhs_members: PlIndexSet = Default::default(); + replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; + members.extend(rhs_members) }, Selector::Sub(lhs, rhs) => { // fill lhs diff --git a/crates/polars-plan/src/logical_plan/pyarrow.rs b/crates/polars-plan/src/logical_plan/pyarrow.rs index 5908d7b23f39..a007887be351 100644 --- a/crates/polars-plan/src/logical_plan/pyarrow.rs +++ b/crates/polars-plan/src/logical_plan/pyarrow.rs @@ -1,6 +1,7 @@ use std::fmt::Write; 
use polars_core::datatypes::AnyValue; +use polars_core::prelude::{TimeUnit, TimeZone}; use crate::prelude::*; @@ -11,6 +12,15 @@ pub(super) struct Args { allow_literal_series: bool, } +fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String { + // note: `_to_python_datetime` and the `Datetime` + // dtype have to be in-scope on the python side + match tz { + None => format!("_to_python_datetime({},'{}')", v, tu.to_ascii()), + Some(tz) => format!("_to_python_datetime({},'{}',{})", v, tu.to_ascii(), tz), + } +} + // convert to a pyarrow expression that can be evaluated with pythons eval pub(super) fn predicate_to_pa( predicate: Node, @@ -39,15 +49,18 @@ pub(super) fn predicate_to_pa( if let AnyValue::Boolean(v) = av { let s = if v { "True" } else { "False" }; write!(list_repr, "{},", s).unwrap(); + } else if let AnyValue::Datetime(v, tu, tz) = av { + let dtm = to_py_datetime(v, &tu, tz.as_ref()); + write!(list_repr, "{dtm},").unwrap(); + } else if let AnyValue::Date(v) = av { + write!(list_repr, "_to_python_date({v}),").unwrap(); } else { write!(list_repr, "{av},").unwrap(); } } - // pop last comma list_repr.pop(); list_repr.push(']'); - Some(list_repr) } }, @@ -68,26 +81,10 @@ pub(super) fn predicate_to_pa( AnyValue::Date(v) => { // the function `_to_python_date` and the `Date` // dtype have to be in scope on the python side - Some(format!("_to_python_date(value={v})")) + Some(format!("_to_python_date({v})")) }, #[cfg(feature = "dtype-datetime")] - AnyValue::Datetime(v, tu, tz) => { - // the function `_to_python_datetime` and the `Datetime` - // dtype have to be in scope on the python side - match tz { - None => Some(format!( - "_to_python_datetime(value={}, tu='{}')", - v, - tu.to_ascii() - )), - Some(tz) => Some(format!( - "_to_python_datetime(value={}, tu='{}', tz={})", - v, - tu.to_ascii(), - tz - )), - } - }, + AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz.as_ref())), // Activate once pyarrow supports them // #[cfg(feature = "dtype-time")] // AnyValue::Time(v) => { @@ -119,7 +116,7 @@ pub(super) fn predicate_to_pa( } }, AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsNot), + function: FunctionExpr::Boolean(BooleanFunction::Not), input, .. } => { diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs index 5cf2e59172d8..e97377c41e75 100644 --- a/crates/polars-plan/src/logical_plan/schema.rs +++ b/crates/polars-plan/src/logical_plan/schema.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::path::Path; use polars_core::prelude::*; use polars_utils::format_smartstring; @@ -18,14 +19,12 @@ impl LogicalPlan { Cache { input, .. } => input.schema(), Sort { input, .. } => input.schema(), DataFrameScan { schema, .. } => Ok(Cow::Borrowed(schema)), - AnonymousScan { file_info, .. } => Ok(Cow::Borrowed(&file_info.schema)), Selection { input, .. } => input.schema(), Projection { schema, .. } => Ok(Cow::Borrowed(schema)), - LocalProjection { schema, .. } => Ok(Cow::Borrowed(schema)), Aggregate { schema, .. } => Ok(Cow::Borrowed(schema)), Join { schema, .. } => Ok(Cow::Borrowed(schema)), HStack { schema, .. } => Ok(Cow::Borrowed(schema)), - Distinct { input, .. } | FileSink { input, .. } => input.schema(), + Distinct { input, .. } | Sink { input, .. } => input.schema(), Slice { input, .. } => input.schema(), MapFunction { input, function, .. 
@@ -49,6 +48,26 @@ pub struct FileInfo { // - known size // - estimated size pub row_estimation: (Option, usize), + pub hive_parts: Option>, +} + +impl FileInfo { + pub fn new(schema: SchemaRef, row_estimation: (Option, usize)) -> Self { + Self { + schema, + row_estimation, + hive_parts: None, + } + } + + pub fn set_hive_partitions(&mut self, url: &Path) { + self.hive_parts = hive::HivePartitions::parse_url(url).map(|hive_parts| { + let schema = Arc::make_mut(&mut self.schema); + schema.merge(hive_parts.get_statistics().schema().clone()); + + Arc::new(hive_parts) + }); + } } #[cfg(feature = "streaming")] @@ -73,6 +92,7 @@ pub fn set_estimated_row_counts( lp_arena: &mut Arena, expr_arena: &Arena, mut _filter_count: usize, + scratch: &mut Vec, ) -> (Option, usize, usize) { use ALogicalPlan::*; @@ -90,11 +110,12 @@ pub fn set_estimated_row_counts( .filter(|(_, ae)| matches!(ae, AExpr::BinaryExpr { .. })) .count() + 1; - set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count) + set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch) }, Slice { input, len, .. } => { let len = *len as usize; - let mut out = set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count); + let mut out = + set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch); apply_slice(&mut out, Some((0, len))); out }, @@ -106,7 +127,8 @@ pub fn set_estimated_row_counts( { let mut sum_output = (None, 0); for input in &inputs { - let mut out = set_estimated_row_counts(*input, lp_arena, expr_arena, 0); + let mut out = + set_estimated_row_counts(*input, lp_arena, expr_arena, 0, scratch); if let Some((_offset, len)) = options.slice { apply_slice(&mut out, Some((0, len))) } @@ -133,11 +155,11 @@ pub fn set_estimated_row_counts( { let mut_options = Arc::make_mut(&mut options); let (known_size, estimated_size, filter_count_left) = - set_estimated_row_counts(input_left, lp_arena, expr_arena, 0); + set_estimated_row_counts(input_left, lp_arena, expr_arena, 0, scratch); mut_options.rows_left = estimate_sizes(known_size, estimated_size, filter_count_left); let (known_size, estimated_size, filter_count_right) = - set_estimated_row_counts(input_right, lp_arena, expr_arena, 0); + set_estimated_row_counts(input_right, lp_arena, expr_arena, 0, scratch); mut_options.rows_right = estimate_sizes(known_size, estimated_size, filter_count_right); @@ -196,13 +218,20 @@ pub fn set_estimated_row_counts( // TODO! get row estimation. (None, usize::MAX, _filter_count) }, - AnonymousScan { options, .. 
} => { - let size = options.n_rows; - (size, size.unwrap_or(usize::MAX), _filter_count) - }, lp => { - let input = lp.get_input().unwrap(); - set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count) + lp.copy_inputs(scratch); + let mut sum_output = (None, 0, 0); + while let Some(input) = scratch.pop() { + let out = + set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count, scratch); + sum_output.1 += out.1; + sum_output.2 += out.2; + sum_output.0 = match sum_output.0 { + None => out.0, + p => p, + }; + } + sum_output }, } } @@ -252,7 +281,6 @@ pub(crate) fn det_join_schema( right_on.to_field_amortized(schema_right, Context::Default, &mut arena)?; if field_left.name != field_right.name { if schema_left.contains(&field_right.name) { - use polars_core::frame::hash_join::_join_suffix_name; new_schema.with_column( _join_suffix_name(&field_right.name, options.args.suffix()).into(), field_right.dtype, diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs index 1a8327b0dc35..5f60d0104eec 100644 --- a/crates/polars-plan/src/logical_plan/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/tree_format.rs @@ -118,56 +118,192 @@ fn format_levels(f: &mut Formatter<'_>, levels: &[Vec]) -> std::fmt::Res let mut col_widths = vec![0usize; n_cols]; + let row_idx_width = levels.len().to_string().len() + 1; + let col_idx_width = n_cols.to_string().len(); + let space = " "; + let dash = "─"; + for (i, col_width) in col_widths.iter_mut().enumerate() { *col_width = levels .iter() .map(|row| row.get(i).map(|s| s.as_str()).unwrap_or("").chars().count()) .max() + .map(|n| if n < col_idx_width { col_idx_width } else { n }) .unwrap(); } - const COL_SPACING: usize = 4; + const COL_SPACING: usize = 2; for (row_count, row) in levels.iter().enumerate() { - // write vertical bars - if row_count != 0 { + if row_count == 0 { + // write the col numbers writeln!(f)?; - for ((col_i, col_name), col_width) in row.iter().enumerate().zip(&col_widths) { + write!(f, "{space:>row_idx_width$} ")?; + for (col_i, (_, col_width)) in + levels.last().unwrap().iter().zip(&col_widths).enumerate() + { let mut col_spacing = COL_SPACING; if col_i > 0 { col_spacing *= 2; } - - let mut remaining = col_width + col_spacing; - let half = (*col_width + col_spacing) / 2; + let half = (col_spacing + 4) / 2; + let remaining = col_spacing + 4 - half; // left_half - for _ in 0..half { - remaining -= 1; - write!(f, " ")?; + write!(f, "{space:^half$}")?; + // col num + write!(f, "{col_i:^col_width$}")?; + + write!(f, "{space:^remaining$}")?; + } + writeln!(f)?; + + // write the horizontal line + write!(f, "{space:>row_idx_width$} ┌")?; + for (col_i, (_, col_width)) in + levels.last().unwrap().iter().zip(&col_widths).enumerate() + { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; + } + write!(f, "{dash:─^width$}", width = col_width + col_spacing + 4)?; + } + write!(f, "\n{space:>row_idx_width$} │\n")?; + } else { + // write connecting lines + write!(f, "{space:>row_idx_width$} │")?; + let mut last_empty = true; + let mut before = ""; + for ((col_i, col_name), col_width) in row.iter().enumerate().zip(&col_widths) { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; } - // bar - remaining -= 1; - let val = if col_name.is_empty() { ' ' } else { '|' }; - write!(f, "{}", val)?; - for _ in 0..remaining { - write!(f, " ")? 
+ let half = (*col_width + col_spacing + 4) / 2; + let remaining = col_width + col_spacing + 4 - half - 1; + if last_empty { + // left_half + write!(f, "{space:^half$}")?; + // bar + if col_name.is_empty() { + write!(f, " ")?; + } else { + write!(f, "│")?; + last_empty = false; + before = "│"; + } + } else { + // left_half + write!(f, "{dash:─^half$}")?; + // bar + write!(f, "╮")?; + before = "╮" + } + if (col_i == row.len() - 1) | col_name.is_empty() { + write!(f, "{space:^remaining$}")?; + } else { + if before == "│" { + write!(f, " ╰")?; + } else { + write!(f, "──")?; + } + write!(f, "{dash:─^width$}", width = remaining - 2)?; } } - write!(f, "\n\n")?; + writeln!(f)?; + // write vertical bars x 2 + for _ in 0..2 { + write!(f, "{space:>row_idx_width$} │")?; + for ((col_i, col_name), col_width) in row.iter().enumerate().zip(&col_widths) { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; + } + + let half = (*col_width + col_spacing + 4) / 2; + let remaining = col_width + col_spacing + 4 - half - 1; + + // left_half + write!(f, "{space:^half$}")?; + // bar + let val = if col_name.is_empty() { ' ' } else { '│' }; + write!(f, "{}", val)?; + + write!(f, "{space:^remaining$}")?; + } + writeln!(f)?; + } + } + + // write the top of the boxes + write!(f, "{space:>row_idx_width$} │")?; + for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; + } + let char_count = col_repr.chars().count() + 4; + let half = (*col_width + col_spacing + 4 - char_count) / 2; + let remaining = col_width + col_spacing + 4 - half - char_count; + + write!(f, "{space:^half$}")?; + + if !col_repr.is_empty() { + write!(f, "╭")?; + write!(f, "{dash:─^width$}", width = char_count - 2)?; + write!(f, "╮")?; + } else { + write!(f, " ")?; + } + write!(f, "{space:^remaining$}")?; } + writeln!(f)?; // write column names and spacing - for (col_repr, col_width) in row.iter().zip(&col_widths) { - for _ in 0..COL_SPACING { - write!(f, " ")? + write!(f, "{row_count:>row_idx_width$} │")?; + for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; + } + let char_count = col_repr.chars().count() + 4; + let half = (*col_width + col_spacing + 4 - char_count) / 2; + let remaining = col_width + col_spacing + 4 - half - char_count; + + write!(f, "{space:^half$}")?; + + if !col_repr.is_empty() { + write!(f, "│ {} │", col_repr)?; + } else { + write!(f, " ")?; + } + write!(f, "{space:^remaining$}")?; + } + writeln!(f)?; + + // write the bottom of the boxes + write!(f, "{space:>row_idx_width$} │")?; + for (col_i, (col_repr, col_width)) in row.iter().zip(&col_widths).enumerate() { + let mut col_spacing = COL_SPACING; + if col_i > 0 { + col_spacing *= 2; } - write!(f, "{}", col_repr)?; - let remaining = *col_width - col_repr.chars().count(); - for _ in 0..remaining + COL_SPACING { - write!(f, " ")? 
+ let char_count = col_repr.chars().count() + 4; + let half = (*col_width + col_spacing + 4 - char_count) / 2; + let remaining = col_width + col_spacing + 4 - half - char_count; + + write!(f, "{space:^half$}")?; + + if !col_repr.is_empty() { + write!(f, "╰")?; + write!(f, "{dash:─^width$}", width = char_count - 2)?; + write!(f, "╯")?; + } else { + write!(f, " ")?; } + write!(f, "{space:^remaining$}")?; } writeln!(f)?; } diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index edfa594b3c7b..c7b521167783 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -154,6 +154,7 @@ impl AexprNode { | (Filter { .. }, Filter { .. }) | (Ternary { .. }, Ternary { .. }) | (Count, Count) + | (Slice { .. }, Slice { .. }) | (Explode(_), Explode(_)) => true, (SortBy { descending: l, .. }, SortBy { descending: r, .. }) => l == r, (Agg(l), Agg(r)) => l.equal_nodes(r), @@ -169,12 +170,7 @@ impl AexprNode { .. }, ) => fl == fr && ol == or, - (AnonymousFunction { function: l, .. }, AnonymousFunction { function: r, .. }) => { - // check only data pointer as location - let l = l.as_ref() as *const _ as *const () as usize; - let r = r.as_ref() as *const _ as *const () as usize; - l == r - }, + (AnonymousFunction { .. }, AnonymousFunction { .. }) => false, (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r, _ => false, }; diff --git a/crates/polars-plan/src/logical_plan/visitor/lp.rs b/crates/polars-plan/src/logical_plan/visitor/lp.rs index d231436dcb64..b13457ba1acb 100644 --- a/crates/polars-plan/src/logical_plan/visitor/lp.rs +++ b/crates/polars-plan/src/logical_plan/visitor/lp.rs @@ -81,8 +81,8 @@ impl ALogicalPlanNode { self.with_arena(|arena| arena.get(self.node).schema(arena)) } - /// Take a `Node` and convert it an `ALogicalPlanNode` and call - /// `F` with `self` and the new created `ALogicalPlanNode` + /// Take a [`Node`] and convert it an [`ALogicalPlanNode`] and call + /// `F` with `self` and the new created [`ALogicalPlanNode`] pub fn binary(&self, other: Node, op: F) -> T where F: FnOnce(&ALogicalPlanNode, &ALogicalPlanNode) -> T, diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index f84bf050adb9..85da66d68b61 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -17,10 +17,8 @@ pub(crate) use polars_time::{ pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; -pub(crate) use crate::logical_plan::conversion::*; #[cfg(feature = "debugging")] pub use crate::logical_plan::debug::*; -pub(crate) use crate::logical_plan::iterator::*; pub use crate::logical_plan::options::*; pub use crate::logical_plan::*; pub use crate::utils::*; diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index b8e3cd9d80c8..54644dff2bb8 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -73,9 +73,7 @@ impl PushNode for [Option; 1] { pub(crate) fn is_scan(plan: &ALogicalPlan) -> bool { matches!( plan, - ALogicalPlan::Scan { .. } - | ALogicalPlan::DataFrameScan { .. } - | ALogicalPlan::AnonymousScan { .. } + ALogicalPlan::Scan { .. } | ALogicalPlan::DataFrameScan { .. 
} ) } diff --git a/crates/polars-row/Cargo.toml b/crates/polars-row/Cargo.toml index 7a61189a7d6b..cd764898b2d6 100644 --- a/crates/polars-row/Cargo.toml +++ b/crates/polars-row/Cargo.toml @@ -9,7 +9,7 @@ repository = { workspace = true } description = "Row encodings for the Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-error = { workspace = true } +polars-utils = { workspace = true } arrow = { workspace = true } diff --git a/crates/polars-row/README.md b/crates/polars-row/README.md index a2cc699d6ad8..7e8f59f67620 100644 --- a/crates/polars-row/README.md +++ b/crates/polars-row/README.md @@ -1,5 +1,5 @@ # polars-row -`polars-row` is a sub-crate that provides row encodings for the Polars DataFrame library. +`polars-row` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, that provides row encodings for the Polars DataFrame Library. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index f7ceecb8667c..ad309ab573ad 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -4,6 +4,7 @@ use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::datatypes::DataType; use arrow::types::NativeType; +use arrow::util::total_ord::{canonical_f32, canonical_f64}; use polars_utils::slice::*; use crate::row::{RowsEncoded, SortField}; @@ -107,7 +108,7 @@ impl FixedLengthEncoding for f32 { fn encode(self) -> [u8; 4] { // https://github.com/rust-lang/rust/blob/9c20b2a8cc7588decb6de25ac6a7912dcef24d65/library/core/src/num/f32.rs#L1176-L1260 - let s = self.to_bits() as i32; + let s = canonical_f32(self).to_bits() as i32; let val = s ^ (((s >> 31) as u32) >> 1) as i32; val.encode() } @@ -124,7 +125,7 @@ impl FixedLengthEncoding for f64 { fn encode(self) -> [u8; 8] { // https://github.com/rust-lang/rust/blob/9c20b2a8cc7588decb6de25ac6a7912dcef24d65/library/core/src/num/f32.rs#L1176-L1260 - let s = self.to_bits() as i64; + let s = canonical_f64(self).to_bits() as i64; let val = s ^ (((s >> 63) as u64) >> 1) as i64; val.encode() } diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index cbd079b8c6c2..c11584fb67cc 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -9,11 +9,12 @@ repository = { workspace = true } description = "SQL transpiler for Polars. 
Converts SQL to Polars logical plans" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["like"] } -polars-core = { version = "0.32.0", path = "../polars-core", features = [] } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = ["compile", "strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } -polars-plan = { version = "0.32.0", path = "../polars-plan", features = ["compile"] } +polars-arrow = { workspace = true } +polars-core = { workspace = true } +polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } +polars-plan = { workspace = true } +rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlparser = { workspace = true } @@ -25,3 +26,5 @@ json = ["polars-lazy/json"] default = [] ipc = ["polars-lazy/ipc"] parquet = ["polars-lazy/parquet"] +semi_anti_join = ["polars-lazy/semi_anti_join"] +diagonal_concat = ["polars-lazy/diagonal_concat"] diff --git a/crates/polars-sql/README.md b/crates/polars-sql/README.md index f971310c60cd..869d2e75da88 100644 --- a/crates/polars-sql/README.md +++ b/crates/polars-sql/README.md @@ -1,6 +1,6 @@ -# Polars SQL +# polars-sql -`polars-sql` is a sub-crate that provides a SQL transpiler for Polars. It can convert SQL queries to Polars logical plans. +`polars-sql` is a sub-crate of the [Polars](https://crates.io/crates/polars) library, offering a SQL transpiler. It allows for SQL query conversion to Polars logical plans. ## Usage @@ -17,6 +17,4 @@ You can then import the crate in your Rust code using: use polars_sql::*; ``` -## Features - -Please refer to the parent `polars` crate for a comprehensive list of features. +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 48a9c83cd07b..c03bf30f5957 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -7,26 +7,38 @@ use polars_lazy::prelude::*; use polars_plan::prelude::*; use polars_plan::utils::expressions_to_schema; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SqlExpr, FunctionArg, JoinOperator, ObjectName, - ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, + Distinct, ExcludeSelectItem, Expr as SqlExpr, FunctionArg, GroupByExpr, JoinOperator, + ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, Value as SQLValue, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; -use crate::sql_expr::{parse_sql_expr, process_join_constraint}; +use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; +use crate::sql_expr::{parse_sql_expr, process_join}; use crate::table_functions::PolarsTableFunctions; /// The SQLContext is the main entry point for executing SQL queries. 
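For orientation, a hedged end-to-end sketch of how this entry point is used; the call names follow the doc examples elsewhere in this diff, but exact signatures may differ between versions and the frame contents are made up:

```rust
use polars_core::prelude::*;
use polars_lazy::prelude::*;
use polars_sql::SQLContext;

fn run_query() -> PolarsResult<DataFrame> {
    let df = df! {
        "a" => [1i64, 2, 3],
        "b" => ["x", "y", "z"],
    }?;

    let mut ctx = SQLContext::new();
    // Register a LazyFrame under a table name, then query it with SQL.
    ctx.register("frame", df.lazy());
    ctx.execute("SELECT a, b FROM frame WHERE a > 1")?.collect()
}
```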
-#[derive(Default, Clone)] +#[derive(Clone)] pub struct SQLContext { pub(crate) table_map: PlHashMap, + pub(crate) function_registry: Arc, cte_map: RefCell>, } +impl Default for SQLContext { + fn default() -> Self { + Self { + function_registry: Arc::new(DefaultFunctionRegistry {}), + table_map: Default::default(), + cte_map: Default::default(), + } + } +} + impl SQLContext { - /// Create a new SQLContext + /// Create a new SQLContext. /// ```rust /// # use polars_sql::SQLContext; /// # fn main() { @@ -34,12 +46,8 @@ impl SQLContext { /// # } /// ``` pub fn new() -> Self { - Self { - table_map: PlHashMap::new(), - cte_map: RefCell::new(PlHashMap::new()), - } + Self::default() } - /// Get the names of all registered tables, in sorted order. pub fn get_tables(&self) -> Vec { let mut tables = Vec::from_iter(self.table_map.keys().cloned()); @@ -47,7 +55,7 @@ impl SQLContext { tables } - /// Register a LazyFrame as a table in the SQLContext. + /// Register a [`LazyFrame`] as a table in the SQLContext. /// ```rust /// # use polars_sql::SQLContext; /// # use polars_core::prelude::*; @@ -66,12 +74,12 @@ impl SQLContext { self.table_map.insert(name.to_owned(), lf); } - /// Unregister a LazyFrame table from the SQLContext. + /// Unregister a [`LazyFrame`] table from the [`SQLContext`]. pub fn unregister(&mut self, name: &str) { self.table_map.remove(&name.to_owned()); } - /// Execute a SQL query, returning a LazyFrame. + /// Execute a SQL query, returning a [`LazyFrame`]. /// ```rust /// # use polars_sql::SQLContext; /// # use polars_core::prelude::*; @@ -103,10 +111,27 @@ impl SQLContext { .map_err(to_compute_err)?; polars_ensure!(ast.len() == 1, ComputeError: "One and only one statement at a time please"); let res = self.execute_statement(ast.get(0).unwrap()); - // every execution should clear the cte map + // Every execution should clear the CTE map. 
self.cte_map.borrow_mut().clear(); res } + + /// add a function registry to the SQLContext + /// the registry provides the ability to add custom functions to the SQLContext + pub fn with_function_registry(mut self, function_registry: Arc) -> Self { + self.function_registry = function_registry; + self + } + + /// Get the function registry of the SQLContext + pub fn registry(&self) -> &Arc { + &self.function_registry + } + + /// Get a mutable reference to the function registry of the SQLContext + pub fn registry_mut(&mut self) -> &mut dyn FunctionRegistry { + Arc::get_mut(&mut self.function_registry).unwrap() + } } impl SQLContext { @@ -114,12 +139,9 @@ impl SQLContext { self.cte_map.borrow_mut().insert(name.to_owned(), lf); } - fn get_table_from_current_scope(&mut self, name: &str) -> Option { - if let Some(lf) = self.table_map.get(name) { - Some(lf.clone()) - } else { - self.cte_map.borrow().get(name).cloned() - } + fn get_table_from_current_scope(&self, name: &str) -> Option { + let table_name = self.table_map.get(name).cloned(); + table_name.or_else(|| self.cte_map.borrow().get(name).cloned()) } pub(crate) fn execute_statement(&mut self, stmt: &Statement) -> PolarsResult { @@ -142,6 +164,10 @@ impl SQLContext { pub(crate) fn execute_query(&mut self, query: &Query) -> PolarsResult { self.register_ctes(query)?; + self.execute_query_no_ctes(query) + } + + pub(crate) fn execute_query_no_ctes(&mut self, query: &Query) -> PolarsResult { let lf = self.process_set_expr(&query.body, query)?; self.process_limit_offset(lf, &query.limit, &query.offset) @@ -150,7 +176,7 @@ impl SQLContext { fn process_set_expr(&mut self, expr: &SetExpr, query: &Query) -> PolarsResult { match expr { SetExpr::Select(select_stmt) => self.execute_select(select_stmt, query), - SetExpr::Query(query) => self.execute_query(query), + SetExpr::Query(query) => self.execute_query_no_ctes(query), SetExpr::SetOperation { op: SetOperator::Union, set_quantifier, @@ -173,29 +199,45 @@ impl SQLContext { ) -> PolarsResult { let left = self.process_set_expr(left, query)?; let right = self.process_set_expr(right, query)?; - let concatenated = polars_lazy::dsl::concat( - vec![left, right], - UnionArgs { - parallel: true, - ..Default::default() - }, - ); + let opts = UnionArgs { + parallel: true, + to_supertypes: true, + ..Default::default() + }; match quantifier { // UNION ALL - SetQuantifier::All => concatenated, - // UNION DISTINCT | UNION - _ => concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)), + SetQuantifier::All => polars_lazy::dsl::concat(vec![left, right], opts), + // UNION [DISTINCT] + SetQuantifier::Distinct | SetQuantifier::None => { + let concatenated = polars_lazy::dsl::concat(vec![left, right], opts); + concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) + }, + // UNION ALL BY NAME + // TODO: add recognition for SetQuantifier::DistinctByName + // when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is available + #[cfg(feature = "diagonal_concat")] + SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts), + // UNION [DISTINCT] BY NAME + #[cfg(feature = "diagonal_concat")] + SetQuantifier::ByName => { + let concatenated = concat_lf_diagonal(vec![left, right], opts); + concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) + }, + #[allow(unreachable_patterns)] + _ => polars_bail!(InvalidOperation: "UNION {} is not yet supported", quantifier), } } + // EXPLAIN SELECT * FROM DF fn execute_explain(&mut self, stmt: &Statement) -> PolarsResult { match stmt { 
Statement::Explain { statement, .. } => { let lf = self.execute_statement(statement)?; let plan = lf.describe_optimized_plan()?; - let mut plan = plan.split('\n').collect::(); - plan.rename("Logical Plan"); - + let plan = plan + .split('\n') + .collect::() + .with_name("Logical Plan"); let df = DataFrame::new(vec![plan])?; Ok(df.lazy()) }, @@ -203,7 +245,7 @@ impl SQLContext { } } - /// SHOW TABLES + // SHOW TABLES fn execute_show_tables(&mut self, _: &Statement) -> PolarsResult { let tables = Series::new("name", self.get_tables()); let df = DataFrame::new(vec![tables])?; @@ -238,27 +280,37 @@ impl SQLContext { /// execute the 'FROM' part of the query fn execute_from_statement(&mut self, tbl_expr: &TableWithJoins) -> PolarsResult { - let (tbl_name, mut lf) = self.get_table(&tbl_expr.relation)?; + let (l_name, mut lf) = self.get_table(&tbl_expr.relation)?; if !tbl_expr.joins.is_empty() { for tbl in &tbl_expr.joins { - let (join_tbl_name, join_tbl) = self.get_table(&tbl.relation)?; - match &tbl.join_operator { + let (r_name, rf) = self.get_table(&tbl.relation)?; + lf = match &tbl.join_operator { + JoinOperator::CrossJoin => lf.cross_join(rf), + JoinOperator::FullOuter(constraint) => { + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Outer)? + }, JoinOperator::Inner(constraint) => { - let (left_on, right_on) = - process_join_constraint(constraint, &tbl_name, &join_tbl_name)?; - lf = lf.inner_join(join_tbl, left_on, right_on) + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? }, JoinOperator::LeftOuter(constraint) => { - let (left_on, right_on) = - process_join_constraint(constraint, &tbl_name, &join_tbl_name)?; - lf = lf.left_join(join_tbl, left_on, right_on) + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)? }, - JoinOperator::FullOuter(constraint) => { - let (left_on, right_on) = - process_join_constraint(constraint, &tbl_name, &join_tbl_name)?; - lf = lf.outer_join(join_tbl, left_on, right_on) + #[cfg(feature = "semi_anti_join")] + JoinOperator::LeftAnti(constraint) => { + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)? + }, + #[cfg(feature = "semi_anti_join")] + JoinOperator::LeftSemi(constraint) => { + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)? + }, + #[cfg(feature = "semi_anti_join")] + JoinOperator::RightAnti(constraint) => { + process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)? + }, + #[cfg(feature = "semi_anti_join")] + JoinOperator::RightSemi(constraint) => { + process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)? }, - JoinOperator::CrossJoin => lf = lf.cross_join(join_tbl), join_type => { polars_bail!( InvalidOperation: @@ -268,13 +320,13 @@ impl SQLContext { } } }; - Ok(lf) } - /// execute the 'SELECT' part of the query + + /// Execute the 'SELECT' part of the query. fn execute_select(&mut self, select_stmt: &Select, query: &Query) -> PolarsResult { - // Determine involved dataframe - // Implicit join require some more work in query parsers, Explicit join are preferred for now. + // Determine involved dataframes. + // Implicit joins require some more work in query parsers, explicit joins are preferred for now. 
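// Not part of the upstream change: a hedged sketch of what the join routing
// above accepts at the SQL level. Table and column names are illustrative
// only; SEMI/ANTI variants additionally require the `semi_anti_join` feature.
use polars_core::prelude::*;
use polars_lazy::prelude::*;
use polars_sql::SQLContext;

fn join_sketch() -> PolarsResult<DataFrame> {
    let orders = df! { "customer_id" => &[1, 2, 2], "amount" => &[10, 20, 5] }?.lazy();
    let customers = df! { "customer_id" => &[1, 2], "name" => &["ann", "bob"] }?.lazy();

    let mut ctx = SQLContext::new();
    ctx.register("orders", orders);
    ctx.register("customers", customers);

    // Routed through `process_join` with JoinType::Inner; only a basic `=`
    // constraint in the ON clause is accepted.
    ctx.execute(
        "SELECT customers.name, orders.amount \
         FROM orders INNER JOIN customers \
         ON orders.customer_id = customers.customer_id",
    )?
    .collect()
}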
let sql_tbl: &TableWithJoins = select_stmt .from .get(0) @@ -284,16 +336,14 @@ impl SQLContext { let mut contains_wildcard = false; let mut contains_wildcard_exclude = false; - // Filter Expression - lf = match select_stmt.selection.as_ref() { - Some(expr) => { - let filter_expression = parse_sql_expr(expr, self)?; - lf.filter(filter_expression) - }, - None => lf, - }; + // Filter expression. + if let Some(expr) = select_stmt.selection.as_ref() { + let mut filter_expression = parse_sql_expr(expr, self)?; + lf = self.process_subqueries(lf, vec![&mut filter_expression]); + lf = lf.filter(filter_expression); + } - // Column Projections + // Column projections. let projections: Vec<_> = select_stmt .projection .iter() @@ -323,30 +373,32 @@ impl SQLContext { }) .collect::>()?; - // Check for group by - // After projection since there might be number. - let group_by_keys: Vec = select_stmt - .group_by - .iter() - .map(|e| match e { - SqlExpr::Value(SQLValue::Number(idx, _)) => { - let idx = match idx.parse::() { - Ok(0) | Err(_) => Err(polars_err!( - ComputeError: - "group_by error: a positive number or an expression expected, got {}", - idx - )), - Ok(idx) => Ok(idx), - }?; - Ok(projections[idx].clone()) - }, - SqlExpr::Value(_) => Err(polars_err!( - ComputeError: - "group_by error: a positive number or an expression expected", - )), - _ => parse_sql_expr(e, self), - }) - .collect::>()?; + // Check for group by (after projections since there might be numbers). + let group_by_keys: Vec; + if let GroupByExpr::Expressions(group_by_exprs) = &select_stmt.group_by { + group_by_keys = group_by_exprs.iter() + .map(|e| match e { + SqlExpr::Value(SQLValue::Number(idx, _)) => { + let idx = match idx.parse::() { + Ok(0) | Err(_) => Err(polars_err!( + ComputeError: + "group_by error: a positive number or an expression expected, got {}", + idx + )), + Ok(idx) => Ok(idx), + }?; + Ok(projections[idx].clone()) + }, + SqlExpr::Value(_) => Err(polars_err!( + ComputeError: + "group_by error: a positive number or an expression expected", + )), + _ => parse_sql_expr(e, self), + }) + .collect::>()? + } else { + polars_bail!(ComputeError: "not implemented"); + }; lf = if group_by_keys.is_empty() { if query.order_by.is_empty() { @@ -392,11 +444,11 @@ impl SQLContext { let exclude_expr = projections.iter().find(|expr| { if let Expr::Exclude(_, excludes) = expr { - excludes.iter().for_each(|excluded| { + for excluded in excludes.iter() { if let Excluded::Name(name) = excluded { - dropped_names.push((*name).to_string()); + dropped_names.push(name.to_string()); } - }); + } true } else { false @@ -406,7 +458,6 @@ impl SQLContext { if exclude_expr.is_some() { lf = lf.with_columns(projections); lf = self.process_order_by(lf, &query.order_by)?; - lf.drop_columns(dropped_names) } else { lf = lf.select(projections); @@ -418,17 +469,16 @@ impl SQLContext { } } else { lf = self.process_group_by(lf, contains_wildcard, &group_by_keys, &projections)?; - lf = self.process_order_by(lf, &query.order_by)?; - // Apply optional 'having' clause, post-aggregation + // Apply optional 'having' clause, post-aggregation. match select_stmt.having.as_ref() { Some(expr) => lf.filter(parse_sql_expr(expr, self)?), None => lf, } }; - // Apply optional 'distinct' clause + // Apply optional 'distinct' clause. 
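// Not part of the upstream change: a sketch of the GROUP BY path handled
// above, with made-up column names. Besides plain expressions, a positive
// ordinal is accepted as a group key and is resolved against the projection
// list; zero or any other literal raises a ComputeError.
use polars_core::prelude::*;
use polars_lazy::prelude::*;
use polars_sql::SQLContext;

fn group_by_sketch() -> PolarsResult<DataFrame> {
    let df = df! { "a" => &[1, 2, 3, 4], "b" => &["x", "x", "y", "y"] }?.lazy();
    let mut ctx = SQLContext::new();
    ctx.register("df", df);
    // The aggregation runs as `group_by(keys).agg(...)`, followed by a final
    // projection that restores the requested column order.
    ctx.execute("SELECT b, COUNT(a) AS n FROM df GROUP BY b ORDER BY b")?
        .collect()
}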
lf = match &select_stmt.distinct { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { @@ -460,6 +510,28 @@ impl SQLContext { Ok(lf) } + fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame { + let mut contexts = vec![]; + for expr in exprs { + expr.mutate().apply(|e| { + if let Expr::SubPlan(lp, names) = e { + contexts.push(::from((***lp).clone())); + + if names.len() == 1 { + *e = Expr::Column(names[0].as_str().into()); + } + }; + true + }) + } + + if contexts.is_empty() { + lf + } else { + lf.with_context(contexts) + } + } + fn execute_create_table(&mut self, stmt: &Statement) -> PolarsResult { if let Statement::CreateTable { if_not_exists, @@ -539,11 +611,7 @@ impl SQLContext { for ob in ob { by.push(parse_sql_expr(&ob.expr, self)?); - if let Some(false) = ob.asc { - descending.push(true) - } else { - descending.push(false) - } + descending.push(!ob.asc.unwrap_or(true)); polars_ensure!( ob.nulls_first.is_none(), ComputeError: "nulls first/last is not yet supported", @@ -560,8 +628,8 @@ impl SQLContext { group_by_keys: &[Expr], projections: &[Expr], ) -> PolarsResult { - // check group_by and projection due to difference between SQL and polars - // Return error on wild card, shouldn't process this + // Check group_by and projection due to difference between SQL and polars. + // Return error on wild card, shouldn't process this. polars_ensure!( !contains_wildcard, ComputeError: "group_by error: can't process wildcard in group_by" @@ -571,13 +639,13 @@ impl SQLContext { let group_by_keys_schema = expressions_to_schema(group_by_keys, &schema_before, Context::Default)?; - // remove the group_by keys as polars adds those implicitly + // Remove the group_by keys as polars adds those implicitly. let mut aggregation_projection = Vec::with_capacity(projections.len()); let mut aliases: BTreeSet<&str> = BTreeSet::new(); for mut e in projections { - // if it is a simple expression & has alias, - // we must defer the aliasing until after the group_by + // If it is a simple expression & has alias, + // we must defer the aliasing until after the group_by. if e.clone().meta().is_simple_projection() { if let Expr::Alias(expr, name) = e { aliases.insert(name); @@ -594,7 +662,7 @@ impl SQLContext { let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection); let projection_schema = expressions_to_schema(projections, &schema_before, Context::Default)?; - // a final projection to get the proper order + // A final projection to get the proper order. let final_projection = projection_schema .iter_names() .zip(projections) @@ -611,7 +679,7 @@ impl SQLContext { } fn process_limit_offset( - &mut self, + &self, lf: LazyFrame, limit: &Option, offset: &Option, @@ -714,7 +782,7 @@ impl SQLContext { pub fn new_from_table_map(table_map: PlHashMap) -> Self { Self { table_map, - cte_map: RefCell::new(PlHashMap::new()), + ..Default::default() } } } diff --git a/crates/polars-sql/src/function_registry.rs b/crates/polars-sql/src/function_registry.rs new file mode 100644 index 000000000000..e2d8c90c0b1a --- /dev/null +++ b/crates/polars-sql/src/function_registry.rs @@ -0,0 +1,30 @@ +//! This module defines the function registry and user defined functions. + +use polars_arrow::error::{polars_bail, PolarsResult}; +use polars_plan::prelude::udf::UserDefinedFunction; +pub use polars_plan::prelude::{Context, FunctionOptions}; +/// A registry that holds user defined functions. 
+pub trait FunctionRegistry: Send + Sync { + /// Register a function. + fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()>; + /// Call a user defined function. + fn get_udf(&self, name: &str) -> PolarsResult>; + /// Check if a function is registered. + fn contains(&self, name: &str) -> bool; +} + +/// A default registry that does not support registering or calling functions. +pub struct DefaultFunctionRegistry {} + +impl FunctionRegistry for DefaultFunctionRegistry { + fn register(&mut self, _name: &str, _fun: UserDefinedFunction) -> PolarsResult<()> { + polars_bail!(ComputeError: "'register' not implemented on DefaultFunctionRegistry'") + } + + fn get_udf(&self, _name: &str) -> PolarsResult> { + polars_bail!(ComputeError: "'get_udf' not implemented on DefaultFunctionRegistry'") + } + fn contains(&self, _name: &str) -> bool { + false + } +} diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 9d833f6222b0..5822aa8086bf 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,8 +1,9 @@ -use polars_core::prelude::{polars_bail, polars_err, PolarsError, PolarsResult}; +use polars_core::prelude::{polars_bail, polars_err, PolarsResult}; use polars_lazy::dsl::Expr; -use polars_plan::dsl::count; +use polars_plan::dsl::{coalesce, count, when}; use polars_plan::logical_plan::LiteralValue; use polars_plan::prelude::lit; +use polars_plan::prelude::LiteralValue::Null; use sqlparser::ast::{ Expr as SqlExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SqlValue, WindowSpec, WindowType, @@ -13,7 +14,7 @@ use crate::SQLContext; pub(crate) struct SqlFunctionVisitor<'a> { pub(crate) func: &'a SQLFunction, - pub(crate) ctx: &'a SQLContext, + pub(crate) ctx: &'a mut SQLContext, } /// SQL functions that are supported by Polars @@ -181,6 +182,11 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT column_2 from df WHERE ENDS_WITH(column_1, 'a'); /// ``` EndsWith, + /// SQL 'initcap' function + /// ```sql + /// SELECT INITCAP(column_1) from df; + /// ``` + InitCap, /// SQL 'left' function /// ```sql /// SELECT LEFT(column_1, 3) from df; @@ -232,6 +238,16 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT UPPER(column_1) from df; /// ``` Upper, + /// SQL 'nullif' function + /// ```sql + /// SELECT NULLIF(column_1, column_2) from df; + /// ``` + NullIf, + /// SQL 'coalesce' function + /// ```sql + /// SELECT COALESCE(column_1, ...) 
from df; + /// ``` + Coalesce, // ---- // Aggregate functions @@ -345,6 +361,11 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT unnest(column_1) from df; /// ``` Explode, + /// SQL 'array_to_string' function + /// ```sql + /// SELECT ARRAY_TO_STRING(column_1, ', ') from df; + /// ``` + ArrayToString, /// SQL 'array_get' function /// Returns the value at the given index in the array /// ```sql @@ -357,6 +378,7 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT ARRAY_CONTAINS(column_1, 'foo') from df; /// ``` ArrayContains, + Udf(String), } impl PolarsSqlFunctions { @@ -384,6 +406,7 @@ impl PolarsSqlFunctions { "cbrt", "ceil", "ceiling", + "coalesce", "cos", "cosd", "cot", @@ -406,6 +429,7 @@ impl PolarsSqlFunctions { "ltrim", "max", "min", + "nullif", "octet_length", "pi", "pow", @@ -430,9 +454,8 @@ impl PolarsSqlFunctions { } } -impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions { - type Error = PolarsError; - fn try_from(function: &'_ SQLFunction) -> Result { +impl PolarsSqlFunctions { + fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult { let function_name = function.name.0[0].value.to_lowercase(); Ok(match function_name.as_str() { // ---- @@ -471,10 +494,17 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions { "cbrt" => Self::Cbrt, "round" => Self::Round, + // ---- + // Comparison functions + // ---- + "nullif" => Self::NullIf, + "coalesce" => Self::Coalesce, + // ---- // String functions // ---- "ends_with" => Self::EndsWith, + "initcap" => Self::InitCap, "length" => Self::Length, "left" => Self::Left, "lower" => Self::Lower, @@ -509,20 +539,27 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions { "array_mean" => Self::ArrayMean, "array_reverse" => Self::ArrayReverse, "array_sum" => Self::ArraySum, + "array_to_string" => Self::ArrayToString, "array_unique" => Self::ArrayUnique, "array_upper" => Self::ArrayMax, "unnest" => Self::Explode, - other => polars_bail!(InvalidOperation: "unsupported SQL function: {}", other), + other => { + if ctx.function_registry.contains(other) { + Self::Udf(other.to_string()) + } else { + polars_bail!(InvalidOperation: "unsupported SQL function: {}", other); + } + }, }) } } impl SqlFunctionVisitor<'_> { - pub(crate) fn visit_function(&self) -> PolarsResult { + pub(crate) fn visit_function(&mut self) -> PolarsResult { let function = self.func; + let function_name = PolarsSqlFunctions::try_from_sql(function, self.ctx)?; - let function_name: PolarsSqlFunctions = function.try_into()?; use PolarsSqlFunctions::*; match function_name { @@ -574,10 +611,18 @@ impl SqlFunctionVisitor<'_> { polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}",function.args.len()); }, }, + + // ---- + // Comparison functions + // ---- + NullIf => self.visit_binary(|l, r: Expr| when(l.clone().eq(r)).then(lit(LiteralValue::Null)).otherwise(l)), + Coalesce => self.visit_variadic(coalesce), + // ---- // String functions // ---- EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)), + InitCap => self.visit_unary(|e| e.str().to_titlecase()), Left => self.try_visit_binary(|e, length| { Ok(e.str().str_slice(0, match length { Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64), @@ -586,17 +631,17 @@ impl SqlFunctionVisitor<'_> { } })) }), - Length => self.visit_unary(|e| e.str().n_chars()), + Length => self.visit_unary(|e| e.str().len_chars()), Lower => self.visit_unary(|e| e.str().to_lowercase()), LTrim => match function.args.len() { - 1 => self.visit_unary(|e| e.str().lstrip(None)), - 2 => self.visit_binary(|e, s| 
e.str().lstrip(Some(s))), + 1 => self.visit_unary(|e| e.str().strip_chars_start(lit(Null))), + 2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)), _ => polars_bail!(InvalidOperation: "Invalid number of arguments for LTrim: {}", function.args.len() ), }, - OctetLength => self.visit_unary(|e| e.str().lengths()), + OctetLength => self.visit_unary(|e| e.str().len_bytes()), RegexpLike => match function.args.len() { 2 => self.visit_binary(|e, s| e.str().contains(s, true)), 3 => self.try_visit_ternary(|e, pat, flags| { @@ -615,8 +660,8 @@ impl SqlFunctionVisitor<'_> { _ => polars_bail!(InvalidOperation:"Invalid number of arguments for RegexpLike: {}",function.args.len()), }, RTrim => match function.args.len() { - 1 => self.visit_unary(|e| e.str().rstrip(None)), - 2 => self.visit_binary(|e, s| e.str().rstrip(Some(s))), + 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), + 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), _ => polars_bail!(InvalidOperation: "Invalid number of arguments for RTrim: {}", function.args.len() @@ -669,18 +714,43 @@ impl SqlFunctionVisitor<'_> { // ---- ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), ArrayGet => self.visit_binary(|e, i| e.list().get(i)), - ArrayLength => self.visit_unary(|e| e.list().lengths()), + ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), ArrayMin => self.visit_unary(|e| e.list().min()), ArrayReverse => self.visit_unary(|e| e.list().reverse()), ArraySum => self.visit_unary(|e| e.list().sum()), + ArrayToString => self.try_visit_binary(|e, s| { + Ok(e.list().join(s)) + }), ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), + Udf(func_name) => self.visit_udf(&func_name) } } - fn visit_unary(&self, f: impl Fn(Expr) -> Expr) -> PolarsResult { + fn visit_udf(&mut self, func_name: &str) -> PolarsResult { + let function = self.func; + + let args = extract_args(function); + let args = args + .into_iter() + .map(|arg| { + if let FunctionArgExpr::Expr(e) = arg { + parse_sql_expr(e, self.ctx) + } else { + polars_bail!(ComputeError: "Only expressions are supported in UDFs") + } + }) + .collect::>>()?; + if let Some(expr) = self.ctx.function_registry.get_udf(func_name)? { + expr.call(args) + } else { + polars_bail!(ComputeError: "UDF {} not found", func_name) + } + } + + fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { self.visit_unary_no_window(f) .and_then(|e| self.apply_window_spec(e, &self.func.over)) } @@ -691,7 +761,7 @@ impl SqlFunctionVisitor<'_> { /// if there is a cumulative window spec, it will apply the cumulative function, /// otherwise it will apply the function fn visit_unary_with_opt_cumulative( - &self, + &mut self, f: impl Fn(Expr) -> Expr, cumulative_f: impl Fn(Expr, bool) -> Expr, ) -> PolarsResult { @@ -709,7 +779,7 @@ impl SqlFunctionVisitor<'_> { /// Window specs without partition bys are essentially cumulative functions /// e.g. 
SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false) fn apply_cumulative_window( - &self, + &mut self, f: impl Fn(Expr) -> Expr, cumulative_f: impl Fn(Expr, bool) -> Expr, WindowSpec { @@ -737,7 +807,7 @@ impl SqlFunctionVisitor<'_> { } } - fn visit_unary_no_window(&self, f: impl Fn(Expr) -> Expr) -> PolarsResult { + fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { let function = self.func; let args = extract_args(function); @@ -751,12 +821,15 @@ impl SqlFunctionVisitor<'_> { } } - fn visit_binary(&self, f: impl Fn(Expr, Arg) -> Expr) -> PolarsResult { + fn visit_binary( + &mut self, + f: impl Fn(Expr, Arg) -> Expr, + ) -> PolarsResult { self.try_visit_binary(|e, a| Ok(f(e, a))) } fn try_visit_binary( - &self, + &mut self, f: impl Fn(Expr, Arg) -> PolarsResult, ) -> PolarsResult { let function = self.func; @@ -771,6 +844,27 @@ impl SqlFunctionVisitor<'_> { } } + fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult { + self.try_visit_variadic(|e| Ok(f(e))) + } + + fn try_visit_variadic( + &mut self, + f: impl Fn(&[Expr]) -> PolarsResult, + ) -> PolarsResult { + let function = self.func; + let args = extract_args(function); + let mut expr_args = vec![]; + for arg in args { + if let FunctionArgExpr::Expr(sql_expr) = arg { + expr_args.push(parse_sql_expr(sql_expr, self.ctx)?); + } else { + return self.not_supported_error(); + }; + } + f(&expr_args) + } + // fn visit_ternary( // &self, // f: impl Fn(Expr, Arg, Arg) -> Expr, @@ -779,7 +873,7 @@ impl SqlFunctionVisitor<'_> { // } fn try_visit_ternary( - &self, + &mut self, f: impl Fn(Expr, Arg, Arg) -> PolarsResult, ) -> PolarsResult { let function = self.func; @@ -804,23 +898,23 @@ impl SqlFunctionVisitor<'_> { Ok(f()) } - fn visit_count(&self) -> PolarsResult { + fn visit_count(&mut self) -> PolarsResult { let args = extract_args(self.func); match (self.func.distinct, args.as_slice()) { // count() (false, []) => Ok(count()), // count(column_name) (false, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = - self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; + let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.count()) }, // count(*) (false, [FunctionArgExpr::Wildcard]) => Ok(count()), // count(distinct column_name) (true, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = - self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; + let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.n_unique()) }, _ => self.not_supported_error(), @@ -828,7 +922,7 @@ impl SqlFunctionVisitor<'_> { } fn apply_window_spec( - &self, + &mut self, expr: Expr, window_type: &Option, ) -> PolarsResult { @@ -886,13 +980,13 @@ fn extract_args(sql_function: &SQLFunction) -> Vec<&FunctionArgExpr> { } pub(crate) trait FromSqlExpr { - fn from_sql_expr(expr: &SqlExpr, ctx: &SQLContext) -> PolarsResult + fn from_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResult where Self: Sized; } impl FromSqlExpr for f64 { - fn from_sql_expr(expr: &SqlExpr, _ctx: &SQLContext) -> PolarsResult + fn from_sql_expr(expr: &SqlExpr, _ctx: &mut SQLContext) -> PolarsResult where Self: Sized, { @@ -909,7 +1003,7 @@ impl FromSqlExpr for f64 { } impl FromSqlExpr for String { - fn from_sql_expr(expr: &SqlExpr, _: &SQLContext) -> PolarsResult + fn from_sql_expr(expr: &SqlExpr, _: &mut SQLContext) -> PolarsResult where Self: Sized, { @@ -924,7 +1018,7 @@ impl 
FromSqlExpr for String { } impl FromSqlExpr for Expr { - fn from_sql_expr(expr: &SqlExpr, ctx: &SQLContext) -> PolarsResult + fn from_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResult where Self: Sized, { diff --git a/crates/polars-sql/src/keywords.rs b/crates/polars-sql/src/keywords.rs index 53e1044fd162..aea00fb54152 100644 --- a/crates/polars-sql/src/keywords.rs +++ b/crates/polars-sql/src/keywords.rs @@ -16,9 +16,9 @@ pub fn all_keywords() -> Vec<&'static str> { use sqlparser::keywords; let sql_keywords = &[ keywords::AND, + keywords::ANTI, keywords::ARRAY, keywords::AS, - keywords::AS, keywords::ASC, keywords::BOOLEAN, keywords::BY, @@ -51,6 +51,7 @@ pub fn all_keywords() -> Vec<&'static str> { keywords::OUTER, keywords::RIGHT, keywords::SELECT, + keywords::SEMI, keywords::SHOW, keywords::TABLE, keywords::TABLES, diff --git a/crates/polars-sql/src/lib.rs b/crates/polars-sql/src/lib.rs index 35dd67d76370..a811a4cfad9b 100644 --- a/crates/polars-sql/src/lib.rs +++ b/crates/polars-sql/src/lib.rs @@ -2,6 +2,7 @@ //! This crate provides a SQL interface for Polars DataFrames #![deny(missing_docs)] mod context; +pub mod function_registry; mod functions; pub mod keywords; mod sql_expr; diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index d0ef9f6b9f20..b566ceba0a12 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -2,11 +2,14 @@ use polars_arrow::error::to_compute_err; use polars_core::prelude::*; use polars_lazy::dsl::Expr; use polars_lazy::prelude::*; +use polars_plan::prelude::LiteralValue::Null; use polars_plan::prelude::{col, lit, when}; +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; use sqlparser::ast::{ ArrayAgg, BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType, - Expr as SqlExpr, Function as SQLFunction, JoinConstraint, OrderByExpr, SelectItem, - TrimWhereField, UnaryOperator, Value as SqlValue, + Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint, OrderByExpr, + Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -52,16 +55,33 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { - ctx: &'a SQLContext, + ctx: &'a mut SQLContext, } impl SqlExprVisitor<'_> { - fn visit_expr(&self, expr: &SqlExpr) -> PolarsResult { + fn visit_expr(&mut self, expr: &SqlExpr) -> PolarsResult { match expr { - SqlExpr::AllOp(_) => Ok(self.visit_expr(expr)?.all(true)), - SqlExpr::AnyOp(expr) => Ok(self.visit_expr(expr)?.any(true)), + SqlExpr::AllOp { + left, + compare_op, + right, + } => self.visit_all(left, compare_op, right), + SqlExpr::AnyOp { + left, + compare_op, + right, + } => self.visit_any(left, compare_op, right), SqlExpr::ArrayAgg(expr) => self.visit_arr_agg(expr), SqlExpr::Between { expr, @@ -80,7 +100,13 @@ impl SqlExprVisitor<'_> { expr, list, negated, - } => self.visit_is_in(expr, list, *negated), + } => self.visit_in_list(expr, list, *negated), + SqlExpr::InSubquery { + expr, + subquery, + negated, + } => self.visit_in_subquery(expr, subquery, *negated), + SqlExpr::Subquery(_) => polars_bail!(InvalidOperation: "Unexpected SQL Subquery"), SqlExpr::IsDistinctFrom(e1, e2) => { Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?)) }, @@ -108,10 +134,47 @@ impl SqlExprVisitor<'_> { } } + fn visit_subquery( + &mut self, + subquery: &Subquery, + restriction: SubqueryRestriction, + ) -> 
PolarsResult { + if subquery.with.is_some() { + polars_bail!(InvalidOperation: "SQL subquery cannot be given CTEs"); + } + + let mut lf = self.ctx.execute_query_no_ctes(subquery)?; + + let schema = lf.schema()?; + if restriction == SubqueryRestriction::SingleColumn { + if schema.len() != 1 { + polars_bail!(InvalidOperation: "SQL subquery will return more than one column"); + } + let rand_string: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(16) + .map(char::from) + .collect(); + + let schema_entry = schema.get_at_index(0); + if let Some((old_name, _)) = schema_entry { + let new_name = String::from(old_name.as_str()) + rand_string.as_str(); + lf = lf.rename([old_name.to_string()], [new_name.clone()]); + + return Ok(Expr::SubPlan( + SpecialEq::new(Arc::new(lf.logical_plan)), + vec![new_name], + )); + } + }; + + polars_bail!(InvalidOperation: "SQL subquery type not supported"); + } + /// Visit a compound identifier /// /// e.g. df.column or "df"."column" - fn visit_compound_identifier(&self, idents: &[sqlparser::ast::Ident]) -> PolarsResult { + fn visit_compound_identifier(&self, idents: &[Ident]) -> PolarsResult { match idents { [tbl_name, column_name] => { let lf = self.ctx.table_map.get(&tbl_name.value).ok_or_else(|| { @@ -139,7 +202,7 @@ impl SqlExprVisitor<'_> { } } - fn visit_unary_op(&self, op: &UnaryOperator, expr: &SqlExpr) -> PolarsResult { + fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SqlExpr) -> PolarsResult { let expr = self.visit_expr(expr)?; Ok(match op { UnaryOperator::Plus => lit(0) + expr, @@ -152,7 +215,7 @@ impl SqlExprVisitor<'_> { /// Visit a single identifier /// /// e.g. column - fn visit_identifier(&self, ident: &sqlparser::ast::Ident) -> PolarsResult { + fn visit_identifier(&self, ident: &Ident) -> PolarsResult { Ok(col(&ident.value)) } @@ -160,7 +223,7 @@ impl SqlExprVisitor<'_> { /// /// e.g. column + 1 or column1 / column2 fn visit_binary_op( - &self, + &mut self, left: &SqlExpr, op: &BinaryOperator, right: &SqlExpr, @@ -219,18 +282,64 @@ impl SqlExprVisitor<'_> { /// e.g. SUM(column) or COUNT(*) /// /// See [SqlFunctionVisitor] for more details - fn visit_function(&self, function: &SQLFunction) -> PolarsResult { - let visitor = SqlFunctionVisitor { + fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult { + let mut visitor = SqlFunctionVisitor { func: function, ctx: self.ctx, }; visitor.visit_function() } + /// Visit a SQL ALL + /// + /// e.g. `a > ALL(y)` + fn visit_all( + &mut self, + left: &SqlExpr, + compare_op: &BinaryOperator, + right: &SqlExpr, + ) -> PolarsResult { + let left = self.visit_expr(left)?; + let right = self.visit_expr(right)?; + + match compare_op { + BinaryOperator::Gt => Ok(left.gt(right.max())), + BinaryOperator::Lt => Ok(left.lt(right.min())), + BinaryOperator::GtEq => Ok(left.gt_eq(right.max())), + BinaryOperator::LtEq => Ok(left.lt_eq(right.min())), + BinaryOperator::Eq => polars_bail!(ComputeError: "ALL cannot be used with ="), + BinaryOperator::NotEq => polars_bail!(ComputeError: "ALL cannot be used with !="), + _ => polars_bail!(ComputeError: "Invalid comparison operator"), + } + } + + /// Visit a SQL ANY + /// + /// e.g. 
`a != ANY(y)` + fn visit_any( + &mut self, + left: &SqlExpr, + compare_op: &BinaryOperator, + right: &SqlExpr, + ) -> PolarsResult { + let left = self.visit_expr(left)?; + let right = self.visit_expr(right)?; + + match compare_op { + BinaryOperator::Gt => Ok(left.gt(right.min())), + BinaryOperator::Lt => Ok(left.lt(right.max())), + BinaryOperator::GtEq => Ok(left.gt_eq(right.min())), + BinaryOperator::LtEq => Ok(left.lt_eq(right.max())), + BinaryOperator::Eq => Ok(left.is_in(right)), + BinaryOperator::NotEq => Ok(left.is_in(right).not()), + _ => polars_bail!(ComputeError: "Invalid comparison operator"), + } + } + /// Visit a SQL CAST /// /// e.g. `CAST(column AS INT)` or `column::INT` - fn visit_cast(&self, expr: &SqlExpr, data_type: &SQLDataType) -> PolarsResult { + fn visit_cast(&mut self, expr: &SqlExpr, data_type: &SQLDataType) -> PolarsResult { let polars_type = map_sql_polars_datatype(data_type)?; let expr = self.visit_expr(expr)?; @@ -263,19 +372,34 @@ impl SqlExprVisitor<'_> { }) } - // similar to visit_literal, but returns an AnyValue instead of Expr - fn visit_anyvalue(&self, value: &SqlValue) -> PolarsResult { + /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr + fn visit_anyvalue( + &self, + value: &SqlValue, + op: Option<&UnaryOperator>, + ) -> PolarsResult { Ok(match value { SqlValue::Boolean(b) => AnyValue::Boolean(*b), SqlValue::Null => AnyValue::Null, SqlValue::Number(s, _) => { + let negate = match op { + Some(UnaryOperator::Minus) => true, + Some(UnaryOperator::Plus) => false, + _ => { + polars_bail!(ComputeError: "Unary op {:?} not supported for numeric SQL value", op) + }, + }; // Check for existence of decimal separator dot if s.contains('.') { - s.parse::().map(AnyValue::Float64).map_err(|_| ()) + s.parse::() + .map(|n: f64| AnyValue::Float64(if negate { -n } else { n })) + .map_err(|_| ()) } else { - s.parse::().map(AnyValue::Int64).map_err(|_| ()) + s.parse::() + .map(|n: i64| AnyValue::Int64(if negate { -n } else { n })) + .map_err(|_| ()) } - .map_err(|_| polars_err!(ComputeError: "cannot parse literal: {:?}"))? + .map_err(|_| polars_err!(ComputeError: "cannot parse literal: {s:?}"))? 
}, SqlValue::SingleQuotedString(s) | SqlValue::NationalStringLiteral(s) @@ -288,7 +412,7 @@ impl SqlExprVisitor<'_> { /// Visit a SQL `BETWEEN` expression /// See [sqlparser::ast::Expr::Between] for more details fn visit_between( - &self, + &mut self, expr: &SqlExpr, negated: bool, low: &SqlExpr, @@ -308,7 +432,7 @@ impl SqlExprVisitor<'_> { /// Visit a SQL 'TRIM' function /// See [sqlparser::ast::Expr::Trim] for more details fn visit_trim( - &self, + &mut self, expr: &SqlExpr, trim_where: &Option, trim_what: &Option>, @@ -322,17 +446,17 @@ impl SqlExprVisitor<'_> { }; Ok(match (trim_where, trim_what) { - (None | Some(TrimWhereField::Both), None) => expr.str().strip(None), - (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip(Some(val)), - (Some(TrimWhereField::Leading), None) => expr.str().lstrip(None), - (Some(TrimWhereField::Leading), Some(val)) => expr.str().lstrip(Some(val)), - (Some(TrimWhereField::Trailing), None) => expr.str().rstrip(None), - (Some(TrimWhereField::Trailing), Some(val)) => expr.str().rstrip(Some(val)), + (None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)), + (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)), + (Some(TrimWhereField::Leading), None) => expr.str().strip_chars_start(lit(Null)), + (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)), + (Some(TrimWhereField::Trailing), None) => expr.str().strip_chars_end(lit(Null)), + (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)), }) } /// Visit a SQL `ARRAY_AGG` expression - fn visit_arr_agg(&self, expr: &ArrayAgg) -> PolarsResult { + fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult { let mut base = self.visit_expr(&expr.expr)?; if let Some(order_by) = expr.order_by.as_ref() { @@ -363,15 +487,28 @@ impl SqlExprVisitor<'_> { } /// Visit a SQL `IN` expression - fn visit_is_in(&self, expr: &SqlExpr, list: &[SqlExpr], negated: bool) -> PolarsResult { + fn visit_in_list( + &mut self, + expr: &SqlExpr, + list: &[SqlExpr], + negated: bool, + ) -> PolarsResult { let expr = self.visit_expr(expr)?; let list = list .iter() .map(|e| { if let SqlExpr::Value(v) = e { - let av = self.visit_anyvalue(v)?; + let av = self.visit_anyvalue(v, None)?; Ok(av) - } else { + } else if let SqlExpr::UnaryOp {op, expr} = e { + match expr.as_ref() { + SqlExpr::Value(v) => { + let av = self.visit_anyvalue(v, Some(op))?; + Ok(av) + }, + _ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)) + } + }else{ Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)) } }) @@ -385,7 +522,24 @@ impl SqlExprVisitor<'_> { } } - fn visit_order_by(&self, order_by: &[OrderByExpr]) -> PolarsResult<(Vec, Vec)> { + fn visit_in_subquery( + &mut self, + expr: &SqlExpr, + subquery: &Subquery, + negated: bool, + ) -> PolarsResult { + let expr = self.visit_expr(expr)?; + + let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?; + + if negated { + Ok(expr.is_in(subquery_result).not()) + } else { + Ok(expr.is_in(subquery_result)) + } + } + + fn visit_order_by(&mut self, order_by: &[OrderByExpr]) -> PolarsResult<(Vec, Vec)> { let mut expr = Vec::with_capacity(order_by.len()); let mut descending = Vec::with_capacity(order_by.len()); for order_by_expr in order_by { @@ -398,7 +552,7 @@ impl SqlExprVisitor<'_> { Ok((expr, descending)) } - fn visit_when_then(&self, expr: &SqlExpr) -> PolarsResult { + fn visit_when_then(&mut self, expr: &SqlExpr) -> 
PolarsResult { if let SqlExpr::Case { operand, conditions, @@ -487,47 +641,68 @@ impl SqlExprVisitor<'_> { } } -pub(crate) fn parse_sql_expr(expr: &SqlExpr, ctx: &SQLContext) -> PolarsResult { - let visitor = SqlExprVisitor { ctx }; +pub(crate) fn parse_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResult { + let mut visitor = SqlExprVisitor { ctx }; visitor.visit_expr(expr) } +pub(super) fn process_join( + left_tbl: LazyFrame, + right_tbl: LazyFrame, + constraint: &JoinConstraint, + tbl_name: &str, + join_tbl_name: &str, + join_type: JoinType, +) -> PolarsResult { + let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?; + + Ok(left_tbl + .join_builder() + .with(right_tbl) + .left_on(left_on) + .right_on(right_on) + .how(join_type) + .finish()) +} + pub(super) fn process_join_constraint( constraint: &JoinConstraint, left_name: &str, right_name: &str, -) -> PolarsResult<(Expr, Expr)> { +) -> PolarsResult<(Vec, Vec)> { if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint { + if op != &BinaryOperator::Eq { + polars_bail!(InvalidOperation: + "SQL interface (currently) only supports basic equi-join \ + constraints; found '{:?}' op in\n{:?}", op, constraint) + } match (left.as_ref(), right.as_ref()) { (SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right)) => { if left.len() == 2 && right.len() == 2 { - let tbl_a = &left[0].value; - let col_a = &left[1].value; - let tbl_b = &right[0].value; - let col_b = &right[1].value; - - if let BinaryOperator::Eq = op { - if left_name == tbl_a && right_name == tbl_b { - return Ok((col(col_a), col(col_b))); - } else if left_name == tbl_b && right_name == tbl_a { - return Ok((col(col_b), col(col_a))); - } + let (tbl_a, col_a) = (&left[0].value, &left[1].value); + let (tbl_b, col_b) = (&right[0].value, &right[1].value); + + if left_name == tbl_a && right_name == tbl_b { + return Ok((vec![col(col_a)], vec![col(col_b)])); + } else if left_name == tbl_b && right_name == tbl_a { + return Ok((vec![col(col_b)], vec![col(col_a)])); } } }, (SqlExpr::Identifier(left), SqlExpr::Identifier(right)) => { - return Ok((col(&left.value), col(&right.value))) + return Ok((vec![col(&left.value)], vec![col(&right.value)])) }, _ => {}, } } if let JoinConstraint::Using(idents) = constraint { if !idents.is_empty() { - let cols = &idents[0].value; - return Ok((col(cols), col(cols))); + let mut using = Vec::with_capacity(idents.len()); + using.extend(idents.iter().map(|id| col(&id.value))); + return Ok((using.clone(), using.clone())); } } - polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint); + polars_bail!(InvalidOperation: "Unsupported SQL join constraint:\n{:?}", constraint); } /// parse a SQL expression to a polars expression @@ -548,7 +723,7 @@ pub(super) fn process_join_constraint( /// # } /// ``` pub fn sql_expr>(s: S) -> PolarsResult { - let ctx = SQLContext::new(); + let mut ctx = SQLContext::new(); let mut parser = Parser::new(&GenericDialect); parser = parser.with_options(ParserOptions { @@ -561,10 +736,10 @@ pub fn sql_expr>(s: S) -> PolarsResult { Ok(match &expr { SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, &ctx)?; + let expr = parse_sql_expr(expr, &mut ctx)?; expr.alias(&alias.value) }, - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &ctx)?, + SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx)?, _ => polars_bail!(InvalidOperation: "Unable to parse '{}' as Expr", s.as_ref()), }) } diff --git 
a/crates/polars-sql/tests/functions_io.rs b/crates/polars-sql/tests/functions_io.rs index 823744b389e8..59740a6a792f 100644 --- a/crates/polars-sql/tests/functions_io.rs +++ b/crates/polars-sql/tests/functions_io.rs @@ -1,5 +1,8 @@ +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_core::prelude::*; +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_lazy::prelude::*; +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_sql::*; #[test] diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs index 2c36da6bcd01..32252f0cb505 100644 --- a/crates/polars-sql/tests/functions_string.rs +++ b/crates/polars-sql/tests/functions_string.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; use polars_lazy::prelude::*; +use polars_plan::prelude::LiteralValue::Null; use polars_sql::*; #[test] @@ -48,35 +49,78 @@ fn test_string_functions() { col("a").str().to_uppercase().alias("upper_a_df"), col("a").str().to_uppercase().alias("upper_a_df2"), col("a").str().to_uppercase().alias("upper_a_df3"), - col("a").str().strip(Some("x".into())).alias("trim_a"), + col("a").str().strip_chars(lit("x")).alias("trim_a"), col("a") .str() - .lstrip(Some("x".into())) + .strip_chars_start(lit("x")) .alias("trim_a_leading"), col("a") .str() - .rstrip(Some("x".into())) + .strip_chars_end(lit("x")) .alias("trim_a_trailing"), - col("a").str().lstrip(None).alias("ltrim_a"), - col("a").str().rstrip(None).alias("rtrim_a"), + col("a").str().strip_chars_start(lit(Null)).alias("ltrim_a"), + col("a").str().strip_chars_end(lit(Null)).alias("rtrim_a"), col("a") .str() - .lstrip(Some("-".into())) + .strip_chars_start(lit("-")) .alias("ltrim_a_dash"), col("a") .str() - .rstrip(Some("-".into())) + .strip_chars_end(lit("-")) .alias("rtrim_a_dash"), col("a") .str() - .lstrip(Some("xyz".into())) + .strip_chars_start(lit("xyz")) .alias("ltrim_a_xyz"), col("a") .str() - .rstrip(Some("xyz".into())) + .strip_chars_end(lit("xyz")) .alias("rtrim_a_xyz"), ]) .collect() .unwrap(); assert!(df_sql.frame_equal_missing(&df_pl)); } + +#[test] +fn array_to_string() { + let df = df! 
{ + "a" => &["first", "first", "third"], + "b" => &[1, 1, 42], + } + .unwrap(); + let mut context = SQLContext::new(); + context.register("df", df.clone().lazy()); + let sql = context + .execute( + r#" + SELECT + b, + a + FROM df + GROUP BY + b"#, + ) + .unwrap(); + context.register("df_1", sql.clone()); + let sql = r#" + SELECT + b, + array_to_string(a, ', ') as as, + FROM df_1 + ORDER BY + b, + as"#; + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + + let df_pl = df + .lazy() + .group_by([col("b")]) + .agg([col("a")]) + .select(&[col("b"), col("a").list().join(lit(", ")).alias("as")]) + .sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true) + .collect() + .unwrap(); + + assert!(df_sql.frame_equal_missing(&df_pl)); +} diff --git a/crates/polars-sql/tests/iss_7436.rs b/crates/polars-sql/tests/iss_7436.rs index 34895e3a657f..65b3f1c854ec 100644 --- a/crates/polars-sql/tests/iss_7436.rs +++ b/crates/polars-sql/tests/iss_7436.rs @@ -1,9 +1,9 @@ -use polars_lazy::prelude::*; -use polars_sql::*; - #[test] #[cfg(feature = "csv")] fn iss_7436() { + use polars_lazy::prelude::*; + use polars_sql::*; + let mut context = SQLContext::new(); let sql = r#" CREATE TABLE foods AS diff --git a/crates/polars-sql/tests/iss_7437.rs b/crates/polars-sql/tests/iss_7437.rs index 29229ba5c4c6..9b150ac06244 100644 --- a/crates/polars-sql/tests/iss_7437.rs +++ b/crates/polars-sql/tests/iss_7437.rs @@ -1,5 +1,8 @@ +#[cfg(feature = "csv")] use polars_core::prelude::*; +#[cfg(feature = "csv")] use polars_lazy::prelude::*; +#[cfg(feature = "csv")] use polars_sql::*; #[test] diff --git a/crates/polars-sql/tests/iss_8395.rs b/crates/polars-sql/tests/iss_8395.rs index a54f360a456b..b48c30718771 100644 --- a/crates/polars-sql/tests/iss_8395.rs +++ b/crates/polars-sql/tests/iss_8395.rs @@ -1,4 +1,6 @@ +#[cfg(feature = "csv")] use polars_core::prelude::*; +#[cfg(feature = "csv")] use polars_sql::*; #[test] diff --git a/crates/polars-sql/tests/udf.rs b/crates/polars-sql/tests/udf.rs new file mode 100644 index 000000000000..4a8f4f93bc6a --- /dev/null +++ b/crates/polars-sql/tests/udf.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use polars_arrow::error::PolarsResult; +use polars_core::prelude::{DataType, Field, *}; +use polars_core::series::Series; +use polars_lazy::prelude::IntoLazy; +use polars_plan::prelude::{GetOutput, UserDefinedFunction}; +use polars_sql::function_registry::FunctionRegistry; +use polars_sql::SQLContext; + +struct MyFunctionRegistry { + functions: PlHashMap, +} + +impl MyFunctionRegistry { + fn new(funcs: Vec) -> Self { + let functions = funcs.into_iter().map(|f| (f.name.to_string(), f)).collect(); + MyFunctionRegistry { functions } + } +} + +impl FunctionRegistry for MyFunctionRegistry { + fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()> { + self.functions.insert(name.to_string(), fun); + Ok(()) + } + + fn get_udf(&self, name: &str) -> PolarsResult> { + Ok(self.functions.get(name).cloned()) + } + + fn contains(&self, name: &str) -> bool { + self.functions.contains_key(name) + } +} + +#[test] +fn test_udfs() -> PolarsResult<()> { + let my_custom_sum = UserDefinedFunction::new( + "my_custom_sum", + vec![ + Field::new("a", DataType::Int32), + Field::new("b", DataType::Int32), + ], + GetOutput::same_type(), + move |s: &mut [Series]| { + let first = s[0].clone(); + let second = s[1].clone(); + Ok(Some(first + second)) + }, + ); + + let mut ctx = SQLContext::new() + 
.with_function_registry(Arc::new(MyFunctionRegistry::new(vec![my_custom_sum]))); + + let df = df! { + "a" => &[1, 2, 3], + "b" => &[1, 2, 3], + "c" => &["a", "b", "c"] + } + .unwrap() + .lazy(); + + ctx.register("foo", df); + let res = ctx.execute("SELECT a, b, my_custom_sum(a, b) FROM foo"); + assert!(res.is_ok()); + + // schema is invalid so it will fail + assert!(ctx + .execute("SELECT a, b, my_custom_sum(c) as invalid FROM foo") + .is_err()); + + // create a new UDF to be registered on the context + let my_custom_divide = UserDefinedFunction::new( + "my_custom_divide", + vec![ + Field::new("a", DataType::Int32), + Field::new("b", DataType::Int32), + ], + GetOutput::same_type(), + move |s: &mut [Series]| { + let first = s[0].clone(); + let second = s[1].clone(); + Ok(Some(first / second)) + }, + ); + + // register a new UDF on an existing context + ctx.registry_mut().register("my_div", my_custom_divide)?; + + // execute the query + let res = ctx + .execute("SELECT a, b, my_div(a, b) as my_div FROM foo")? + .collect()?; + let expected = df! { + "a" => &[1, 2, 3], + "b" => &[1, 2, 3], + "my_div" => &[1, 1, 1] + }?; + assert!(expected.frame_equal_missing(&res)); + + Ok(()) +} diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index a06a270c8dd9..88a39d3659f4 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "Time related code for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["compute", "temporal"] } -polars-core = { version = "0.32.0", path = "../polars-core", default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } -polars-ops = { version = "0.32.0", path = "../polars-ops" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true, features = ["compute", "temporal"] } +polars-core = { workspace = true, default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } +polars-ops = { workspace = true } +polars-utils = { workspace = true } arrow = { workspace = true } atoi = { workspace = true } diff --git a/crates/polars-time/README.md b/crates/polars-time/README.md index a3ca3cc6797a..d43adb2abb36 100644 --- a/crates/polars-time/README.md +++ b/crates/polars-time/README.md @@ -1,5 +1,5 @@ # polars-time -`polars-time` is a sub-crate that provides time-related code for the Polars dataframe library. +`polars-time` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, focusing on time-related utilities. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. 
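For reference, the registry plumbing exercised by the new `udf.rs` test condenses to the following sketch. `registry` and `udf` stand in for the test's own `MyFunctionRegistry` and `my_custom_divide` and are assumptions of this example, not part of the public crate; note that `registry_mut` relies on `Arc::get_mut`, so it only works while the context holds the sole reference to the registry.

```rust
use std::sync::Arc;

use polars_arrow::error::PolarsResult;
use polars_plan::prelude::UserDefinedFunction;
use polars_sql::function_registry::FunctionRegistry;
use polars_sql::SQLContext;

// `registry` is any FunctionRegistry implementation (e.g. the test's
// MyFunctionRegistry); `udf` is a UserDefinedFunction such as my_custom_divide.
fn with_udfs(
    registry: Arc<dyn FunctionRegistry>,
    udf: UserDefinedFunction,
) -> PolarsResult<SQLContext> {
    // Attach a registry when building the context...
    let mut ctx = SQLContext::new().with_function_registry(registry);
    // ...or register additional functions later; this goes through
    // Arc::get_mut, so the registry Arc must not be shared elsewhere here.
    ctx.registry_mut().register("my_div", udf)?;
    Ok(ctx)
}
```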
diff --git a/crates/polars-time/src/chunkedarray/mod.rs b/crates/polars-time/src/chunkedarray/mod.rs index 04a20d644cba..d8ab91cfa1b5 100644 --- a/crates/polars-time/src/chunkedarray/mod.rs +++ b/crates/polars-time/src/chunkedarray/mod.rs @@ -6,6 +6,7 @@ mod datetime; #[cfg(feature = "dtype-duration")] mod duration; mod kernels; +#[cfg(feature = "rolling_window")] mod rolling_window; #[cfg(feature = "dtype-time")] mod time; @@ -21,6 +22,7 @@ pub use duration::DurationMethods; use kernels::*; use polars_arrow::utils::CustomIterTools; use polars_core::prelude::*; +#[cfg(feature = "rolling_window")] pub use rolling_window::*; #[cfg(feature = "dtype-time")] pub use time::TimeMethods; diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs new file mode 100644 index 000000000000..8135134dd884 --- /dev/null +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -0,0 +1,252 @@ +use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type}; + +use super::*; +use crate::prelude::*; +use crate::series::AsSeries; + +#[allow(clippy::type_complexity)] +fn rolling_agg( + ca: &ChunkedArray, + options: RollingOptionsImpl, + rolling_agg_fn: &dyn Fn( + &[T::Native], + usize, + usize, + bool, + Option<&[f64]>, + DynArgs, + ) -> PolarsResult, + rolling_agg_fn_nulls: &dyn Fn( + &PrimitiveArray, + usize, + usize, + bool, + Option<&[f64]>, + DynArgs, + ) -> ArrayRef, + rolling_agg_fn_dynamic: Option< + &dyn Fn( + &[T::Native], + Duration, + &[i64], + ClosedWindow, + TimeUnit, + Option<&TimeZone>, + DynArgs, + ) -> PolarsResult, + >, +) -> PolarsResult +where + T: PolarsNumericType, +{ + if ca.is_empty() { + return Ok(Series::new_empty(ca.name(), ca.dtype())); + } + let ca = ca.rechunk(); + + let arr = ca.downcast_iter().next().unwrap(); + // "5i" is a window size of 5, e.g. fixed + let arr = if options.window_size.parsed_int { + let options: RollingOptionsFixedWindow = options.into(); + check_input(options.window_size, options.min_periods)?; + + Ok(match ca.null_count() { + 0 => rolling_agg_fn( + arr.values().as_slice(), + options.window_size, + options.min_periods, + options.center, + options.weights.as_deref(), + options.fn_params, + )?, + _ => rolling_agg_fn_nulls( + arr, + options.window_size, + options.min_periods, + options.center, + options.weights.as_deref(), + options.fn_params, + ), + }) + } else { + if arr.null_count() > 0 { + panic!("'rolling by' not yet supported for series with null values, consider using 'group_by_rolling'") + } + let values = arr.values().as_slice(); + let duration = options.window_size; + polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); + let tu = options.tu.unwrap(); + let by = options.by.unwrap(); + let closed_window = options.closed_window.expect("closed window must be set"); + let func = rolling_agg_fn_dynamic.expect( + "'rolling by' not yet supported for this expression, consider using 'group_by_rolling'", + ); + + func( + values, + duration, + by, + closed_window, + tu, + options.tz, + options.fn_params, + ) + }?; + Series::try_from((ca.name(), arr)) +} + +pub trait SeriesOpsTime: AsSeries { + /// Apply a rolling mean to a Series. 
+ /// + /// See: [`RollingAgg::rolling_mean`] + #[cfg(feature = "rolling_window")] + fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_mean, + &rolling::nulls::rolling_mean, + Some(&super::rolling_kernels::no_nulls::rolling_mean), + ) + }) + } + /// Apply a rolling sum to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { + let mut s = self.as_series().clone(); + if options.weights.is_some() { + s = s.to_float()?; + } + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_sum, + &rolling::nulls::rolling_sum, + Some(&super::rolling_kernels::no_nulls::rolling_sum), + ) + }) + } + /// Apply a rolling median to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult { + let s = self.as_series().to_float()?; + + // At the last possible second, right before we do computations, make sure we're using the + // right quantile parameters to get a median. This also lets us have the convenience of + // calling `rolling_median` from Rust without a bunch of dedicated functions that just call + // out to the `rolling_quantile` anyway. + let mut options = options.clone(); + options.fn_params = Some(Arc::new(RollingQuantileParams { + prob: 0.5, + interpol: QuantileInterpolOptions::Linear, + }) as Arc); + + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_quantile, + &rolling::nulls::rolling_quantile, + Some(&super::rolling_kernels::no_nulls::rolling_quantile), + ) + }) + } + /// Apply a rolling quantile to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_quantile, + &rolling::nulls::rolling_quantile, + Some(&super::rolling_kernels::no_nulls::rolling_quantile), + ) + }) + } + + /// Apply a rolling min to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { + let mut s = self.as_series().clone(); + if options.weights.is_some() { + s = s.to_float()?; + } + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_min, + &rolling::nulls::rolling_min, + Some(&super::rolling_kernels::no_nulls::rolling_min), + ) + }) + } + /// Apply a rolling max to a Series. 
+ #[cfg(feature = "rolling_window")] + fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { + let mut s = self.as_series().clone(); + if options.weights.is_some() { + s = s.to_float()?; + } + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_max, + &rolling::nulls::rolling_max, + Some(&super::rolling_kernels::no_nulls::rolling_max), + ) + }) + } + + /// Apply a rolling variance to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg( + ca, + options, + &rolling::no_nulls::rolling_var, + &rolling::nulls::rolling_var, + Some(&super::rolling_kernels::no_nulls::rolling_var), + ) + }) + } + + /// Apply a rolling std_dev to a Series. + #[cfg(feature = "rolling_window")] + fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { + self.rolling_var(options).map(|mut s| { + match s.dtype().clone() { + DataType::Float32 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + DataType::Float64 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + _ => unreachable!(), + } + s + }) + } +} + +impl SeriesOpsTime for Series {} diff --git a/crates/polars-time/src/chunkedarray/rolling_window/floats.rs b/crates/polars-time/src/chunkedarray/rolling_window/floats.rs deleted file mode 100644 index 55b5a89657f5..000000000000 --- a/crates/polars-time/src/chunkedarray/rolling_window/floats.rs +++ /dev/null @@ -1,151 +0,0 @@ -use num::pow::Pow; -use num::Float; -use polars_core::export::num; - -use super::*; - -#[cfg(not(feature = "rolling_window"))] -impl RollingAgg for WrapFloat> -where - T: PolarsFloatType, - T::Native: Pow + Float, - ChunkedArray: IntoSeries, -{ -} - -#[cfg(feature = "rolling_window")] -impl RollingAgg for WrapFloat> -where - T: PolarsFloatType, - T::Native: Pow + Float, - ChunkedArray: IntoSeries, -{ - /// Apply a rolling mean (moving mean) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their mean. - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_mean, - &rolling::nulls::rolling_mean, - Some(&super::rolling_kernels::no_nulls::rolling_mean), - ) - } - - /// Apply a rolling sum (moving sum) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their sum. - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_sum, - &rolling::nulls::rolling_sum, - Some(&super::rolling_kernels::no_nulls::rolling_sum), - ) - } - - /// Apply a rolling min (moving min) over the values in this array. - /// A window of length `window_size` will traverse the array. 
The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their min. - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_min, - &rolling::nulls::rolling_min, - Some(&super::rolling_kernels::no_nulls::rolling_min), - ) - } - - /// Apply a rolling max (moving max) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their max. - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_max, - &rolling::nulls::rolling_max, - Some(&super::rolling_kernels::no_nulls::rolling_max), - ) - } - - /// Apply a rolling median (moving median) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be weighted according to the `weights` vector. - fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult { - // At the last possible second, right before we do computations, make sure we're using the - // right quantile parameters to get a median. This also lets us have the convenience of - // calling `rolling_median` from Rust without a bunch of dedicated functions that just call - // out to the `rolling_quantile` anyway. - let mut options = options.clone(); - options.fn_params = Some(Arc::new(RollingQuantileParams { - prob: 0.5, - interpol: QuantileInterpolOptions::Linear, - }) as Arc); - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_quantile, - &rolling::nulls::rolling_quantile, - Some(&super::rolling_kernels::no_nulls::rolling_quantile), - ) - } - - /// Apply a rolling quantile (moving quantile) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be weighted according to the `weights` vector. - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_quantile, - &rolling::nulls::rolling_quantile, - Some(&super::rolling_kernels::no_nulls::rolling_quantile), - ) - } - - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_var, - &rolling::nulls::rolling_var, - Some(&super::rolling_kernels::no_nulls::rolling_var), - ) - } - - /// Apply a rolling std (moving std) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their std. 
- fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_var, - &rolling::nulls::rolling_var, - Some(&super::rolling_kernels::no_nulls::rolling_var), - ) - .map(|mut s| { - match s.dtype().clone() { - DataType::Float32 => { - let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); - ca.apply_mut(|v| v.powf(0.5)) - }, - DataType::Float64 => { - let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); - ca.apply_mut(|v| v.powf(0.5)) - }, - _ => unreachable!(), - } - s - }) - } -} diff --git a/crates/polars-time/src/chunkedarray/rolling_window/ints.rs b/crates/polars-time/src/chunkedarray/rolling_window/ints.rs deleted file mode 100644 index a25664f98cae..000000000000 --- a/crates/polars-time/src/chunkedarray/rolling_window/ints.rs +++ /dev/null @@ -1,76 +0,0 @@ -use super::*; -use crate::series::WrapInt; - -#[cfg(not(feature = "rolling_window"))] -impl RollingAgg for WrapInt> -where - T: PolarsIntegerType, - T::Native: IsFloat + SubAssign, -{ -} - -#[cfg(feature = "rolling_window")] -impl RollingAgg for WrapInt> -where - T: PolarsIntegerType, - T::Native: IsFloat + SubAssign, -{ - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { - if options.weights.is_some() { - return self.0.cast(&DataType::Float64)?.rolling_sum(options); - } - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_sum, - &rolling::nulls::rolling_sum, - Some(&super::rolling_kernels::no_nulls::rolling_sum), - ) - } - - fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult { - self.0.cast(&DataType::Float64)?.rolling_median(options) - } - - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { - self.0.cast(&DataType::Float64)?.rolling_quantile(options) - } - - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { - if options.weights.is_some() { - return self.0.cast(&DataType::Float64)?.rolling_min(options); - } - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_min, - &rolling::nulls::rolling_min, - Some(&super::rolling_kernels::no_nulls::rolling_min), - ) - } - - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { - if options.weights.is_some() { - return self.0.cast(&DataType::Float64)?.rolling_max(options); - } - rolling_agg( - &self.0, - options, - &rolling::no_nulls::rolling_max, - &rolling::nulls::rolling_max, - Some(&super::rolling_kernels::no_nulls::rolling_max), - ) - } - - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { - self.0.cast(&DataType::Float64)?.rolling_var(options) - } - - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { - self.0.cast(&DataType::Float64)?.rolling_std(options) - } - - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { - self.0.cast(&DataType::Float64)?.rolling_mean(options) - } -} diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index dbb3e07d18e6..c5eb5afc75ed 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -1,27 +1,17 @@ -mod floats; -mod ints; -#[cfg(feature = "rolling_window")] +mod dispatch; mod rolling_kernels; -#[cfg(feature = "rolling_window")] use std::convert::TryFrom; -use std::ops::SubAssign; -#[cfg(feature = "rolling_window")] use arrow::array::{Array, PrimitiveArray}; -use polars_arrow::data_types::IsFloat; -#[cfg(feature = "rolling_window")] +pub use 
dispatch::*; use polars_arrow::export::arrow; -#[cfg(feature = "rolling_window")] use polars_arrow::kernels::rolling; use polars_core::prelude::*; -#[cfg(feature = "rolling_window")] use crate::prelude::*; -use crate::series::WrapFloat; #[derive(Clone)] -#[cfg(feature = "rolling_window")] pub struct RollingOptions { /// The length of the window. pub window_size: Duration, @@ -40,7 +30,6 @@ pub struct RollingOptions { pub fn_params: DynArgs, } -#[cfg(feature = "rolling_window")] impl Default for RollingOptions { fn default() -> Self { RollingOptions { @@ -56,7 +45,6 @@ impl Default for RollingOptions { } #[derive(Clone)] -#[cfg(feature = "rolling_window")] pub struct RollingOptionsImpl<'a> { /// The length of the window. pub window_size: Duration, @@ -74,7 +62,6 @@ pub struct RollingOptionsImpl<'a> { pub fn_params: DynArgs, } -#[cfg(feature = "rolling_window")] impl From for RollingOptionsImpl<'static> { fn from(options: RollingOptions) -> Self { let window_size = options.window_size; @@ -97,7 +84,6 @@ impl From for RollingOptionsImpl<'static> { } } -#[cfg(feature = "rolling_window")] impl From for RollingOptionsFixedWindow { fn from(options: RollingOptions) -> Self { let window_size = options.window_size; @@ -116,7 +102,6 @@ impl From for RollingOptionsFixedWindow { } } -#[cfg(feature = "rolling_window")] impl Default for RollingOptionsImpl<'static> { fn default() -> Self { RollingOptionsImpl { @@ -133,7 +118,6 @@ impl Default for RollingOptionsImpl<'static> { } } -#[cfg(feature = "rolling_window")] impl<'a> From> for RollingOptionsFixedWindow { fn from(options: RollingOptionsImpl<'a>) -> Self { let window_size = options.window_size; @@ -152,61 +136,7 @@ impl<'a> From> for RollingOptionsFixedWindow { } } -#[cfg(not(feature = "rolling_window"))] -pub trait RollingAgg {} - -#[cfg(feature = "rolling_window")] -pub trait RollingAgg { - /// Apply a rolling mean (moving mean) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their mean. - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling sum (moving sum) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their sum. - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling min (moving min) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their min. - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling max (moving max) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their max. - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling median (moving median) over the values in this array. - /// A window of length `window_size` will traverse the array. 
The values that fill this window - /// will (optionally) be weighted according to the `weights` vector. - fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling quantile (moving quantile) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be weighted according to the `weights` vector. - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling var (moving var) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their var. - #[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult; - - /// Apply a rolling std (moving std) over the values in this array. - /// A window of length `window_size` will traverse the array. The values that fill this window - /// will (optionally) be multiplied with the weights given by the `weights` vector. The resulting - /// values will be aggregated to their std. - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult; -} - /// utility -#[cfg(feature = "rolling_window")] fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { polars_ensure!( min_periods <= window_size, @@ -214,95 +144,3 @@ fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { ); Ok(()) } - -#[cfg(feature = "rolling_window")] -#[allow(clippy::type_complexity)] -fn rolling_agg( - ca: &ChunkedArray, - options: RollingOptionsImpl, - rolling_agg_fn: &dyn Fn( - &[T::Native], - usize, - usize, - bool, - Option<&[f64]>, - DynArgs, - ) -> PolarsResult, - rolling_agg_fn_nulls: &dyn Fn( - &PrimitiveArray, - usize, - usize, - bool, - Option<&[f64]>, - DynArgs, - ) -> ArrayRef, - rolling_agg_fn_dynamic: Option< - &dyn Fn( - &[T::Native], - Duration, - &[i64], - ClosedWindow, - TimeUnit, - Option<&TimeZone>, - DynArgs, - ) -> PolarsResult, - >, -) -> PolarsResult -where - T: PolarsNumericType, -{ - if ca.is_empty() { - return Ok(Series::new_empty(ca.name(), ca.dtype())); - } - let ca = ca.rechunk(); - - let arr = ca.downcast_iter().next().unwrap(); - // "5i" is a window size of 5, e.g. 
fixed - let arr = if options.window_size.parsed_int { - let options: RollingOptionsFixedWindow = options.into(); - check_input(options.window_size, options.min_periods)?; - - Ok(match ca.null_count() { - 0 => rolling_agg_fn( - arr.values().as_slice(), - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - )?, - _ => rolling_agg_fn_nulls( - arr, - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - ), - }) - } else { - if arr.null_count() > 0 { - panic!("'rolling by' not yet supported for series with null values, consider using 'group_by_rolling'") - } - let values = arr.values().as_slice(); - let duration = options.window_size; - polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); - let tu = options.tu.unwrap(); - let by = options.by.unwrap(); - let closed_window = options.closed_window.expect("closed window must be set"); - let func = rolling_agg_fn_dynamic.expect( - "'rolling by' not yet supported for this expression, consider using 'group_by_rolling'", - ); - - func( - values, - duration, - by, - closed_window, - tu, - options.tz, - options.fn_params, - ) - }?; - Series::try_from((ca.name(), arr)) -} diff --git a/crates/polars-time/src/chunkedarray/time.rs b/crates/polars-time/src/chunkedarray/time.rs index fe0b94190cc9..5d2fd2691b0f 100644 --- a/crates/polars-time/src/chunkedarray/time.rs +++ b/crates/polars-time/src/chunkedarray/time.rs @@ -50,16 +50,15 @@ impl TimeMethods for TimeChunked { } fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> TimeChunked { - let mut ca: Int64Chunked = v - .iter() + v.iter() .map(|s| { NaiveTime::parse_from_str(s, fmt) .ok() .as_ref() .map(time_to_time64ns) }) - .collect_trusted(); - ca.rename(name); - ca.into() + .collect_trusted::() + .with_name(name) + .into() } } diff --git a/crates/polars-time/src/chunkedarray/utf8/infer.rs b/crates/polars-time/src/chunkedarray/utf8/infer.rs index f3e559368f74..7ac33e85545d 100644 --- a/crates/polars-time/src/chunkedarray/utf8/infer.rs +++ b/crates/polars-time/src/chunkedarray/utf8/infer.rs @@ -231,90 +231,34 @@ impl TryFromWithUnit for DatetimeInfer { fn try_from_with_unit(value: Pattern, time_unit: Option) -> PolarsResult { let time_unit = time_unit.expect("time_unit must be provided for datetime"); - match (value, time_unit) { - (Pattern::DatetimeDMY, TimeUnit::Milliseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeDMY, - patterns: patterns::DATETIME_D_M_Y, - latest_fmt: patterns::DATETIME_D_M_Y[0], - transform: transform_datetime_ms, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Milliseconds, None), - }), - (Pattern::DatetimeDMY, TimeUnit::Microseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeDMY, - patterns: patterns::DATETIME_D_M_Y, - latest_fmt: patterns::DATETIME_D_M_Y[0], - transform: transform_datetime_us, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Microseconds, None), - }), - (Pattern::DatetimeDMY, TimeUnit::Nanoseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeDMY, - patterns: patterns::DATETIME_D_M_Y, - latest_fmt: patterns::DATETIME_D_M_Y[0], - transform: transform_datetime_ns, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Nanoseconds, None), - }), - (Pattern::DatetimeYMD, TimeUnit::Milliseconds) 
=> Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMD, - patterns: patterns::DATETIME_Y_M_D, - latest_fmt: patterns::DATETIME_Y_M_D[0], - transform: transform_datetime_ms, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Milliseconds, None), - }), - (Pattern::DatetimeYMD, TimeUnit::Microseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMD, - patterns: patterns::DATETIME_Y_M_D, - latest_fmt: patterns::DATETIME_Y_M_D[0], - transform: transform_datetime_us, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Microseconds, None), - }), - (Pattern::DatetimeYMD, TimeUnit::Nanoseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMD, - patterns: patterns::DATETIME_Y_M_D, - latest_fmt: patterns::DATETIME_Y_M_D[0], - transform: transform_datetime_ns, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Nanoseconds, None), - }), - (Pattern::DatetimeYMDZ, TimeUnit::Milliseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMDZ, - patterns: patterns::DATETIME_Y_M_D_Z, - latest_fmt: patterns::DATETIME_Y_M_D_Z[0], - transform: transform_tzaware_datetime_ms, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Milliseconds, None), - }), - (Pattern::DatetimeYMDZ, TimeUnit::Microseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMDZ, - patterns: patterns::DATETIME_Y_M_D_Z, - latest_fmt: patterns::DATETIME_Y_M_D_Z[0], - transform: transform_tzaware_datetime_us, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Microseconds, None), - }), - (Pattern::DatetimeYMDZ, TimeUnit::Nanoseconds) => Ok(DatetimeInfer { - pattern: Pattern::DatetimeYMDZ, - patterns: patterns::DATETIME_Y_M_D_Z, - latest_fmt: patterns::DATETIME_Y_M_D_Z[0], - transform: transform_tzaware_datetime_ns, - transform_bytes: StrpTimeState::default(), - fmt_len: 0, - logical_type: DataType::Datetime(TimeUnit::Nanoseconds, None), - }), - _ => polars_bail!(ComputeError: "could not convert pattern"), - } + + let transform = match (time_unit, value) { + (TimeUnit::Milliseconds, Pattern::DatetimeYMDZ) => transform_tzaware_datetime_ms, + (TimeUnit::Milliseconds, _) => transform_datetime_ms, + (TimeUnit::Microseconds, Pattern::DatetimeYMDZ) => transform_tzaware_datetime_us, + (TimeUnit::Microseconds, _) => transform_datetime_us, + (TimeUnit::Nanoseconds, Pattern::DatetimeYMDZ) => transform_tzaware_datetime_ns, + (TimeUnit::Nanoseconds, _) => transform_datetime_ns, + }; + let (pattern, patterns) = match value { + Pattern::DatetimeDMY | Pattern::DateDMY => { + (Pattern::DatetimeDMY, patterns::DATETIME_D_M_Y) + }, + Pattern::DatetimeYMD | Pattern::DateYMD => { + (Pattern::DatetimeYMD, patterns::DATETIME_Y_M_D) + }, + Pattern::DatetimeYMDZ => (Pattern::DatetimeYMDZ, patterns::DATETIME_Y_M_D_Z), + }; + + Ok(DatetimeInfer { + pattern, + patterns, + latest_fmt: patterns[0], + transform, + transform_bytes: StrpTimeState::default(), + fmt_len: 0, + logical_type: DataType::Datetime(time_unit, None), + }) } } @@ -380,12 +324,11 @@ where .map(|opt_val| opt_val.and_then(|val| self.parse(val))); PrimitiveArray::from_trusted_len_iter(iter) }); - let mut out = ChunkedArray::from_chunk_iter(ca.name(), chunks) + ChunkedArray::from_chunk_iter(ca.name(), chunks) .into_series() .cast(&self.logical_type) - .unwrap(); - out.rename(ca.name()); - out + .unwrap() + .with_name(ca.name()) } } 
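The utf8 parsing refactor in the next hunk replaces the hand-rolled `PlHashMap` caching with `polars_utils::cache::CachedFunc`, calling `convert.eval(value, use_cache)` so repeated strings are only parsed once. Below is a minimal, self-contained sketch of that memoized-closure pattern, assuming a hypothetical `Memoized` type and a toy time parser; it is not the actual `polars_utils` implementation.

use std::collections::HashMap;

// Hypothetical stand-in for the memoizing wrapper; not the real CachedFunc.
struct Memoized<F> {
    func: F,
    cache: HashMap<String, Option<i64>>,
}

impl<F: FnMut(&str) -> Option<i64>> Memoized<F> {
    fn new(func: F) -> Self {
        Self { func, cache: HashMap::new() }
    }

    // With `use_cache` set, look the key up first and store the result of a miss;
    // otherwise call the wrapped closure directly.
    fn eval(&mut self, key: &str, use_cache: bool) -> Option<i64> {
        if !use_cache {
            return (self.func)(key);
        }
        if let Some(hit) = self.cache.get(key) {
            return *hit;
        }
        let value = (self.func)(key);
        self.cache.insert(key.to_owned(), value);
        value
    }
}

fn main() {
    // Toy parser: "HH:MM:SS" -> seconds since midnight, cached per distinct string.
    let mut convert = Memoized::new(|s: &str| {
        let mut parts = s.split(':').map(|p| p.parse::<i64>().ok());
        Some(parts.next()?? * 3600 + parts.next()?? * 60 + parts.next()??)
    });
    assert_eq!(convert.eval("01:02:03", true), Some(3723));
    assert_eq!(convert.eval("01:02:03", true), Some(3723)); // second call hits the cache
    assert_eq!(convert.eval("bad input", true), None);
}

As in the diff, caching only pays off when the same strings repeat, which is why the changed code enables it only for columns longer than 50 values (`use_cache && utf8_ca.len() > 50`).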
diff --git a/crates/polars-time/src/chunkedarray/utf8/mod.rs b/crates/polars-time/src/chunkedarray/utf8/mod.rs index 37dc9a7ab2f6..a15c7069a966 100644 --- a/crates/polars-time/src/chunkedarray/utf8/mod.rs +++ b/crates/polars-time/src/chunkedarray/utf8/mod.rs @@ -7,6 +7,7 @@ use chrono::ParseError; pub use patterns::Pattern; #[cfg(feature = "dtype-time")] use polars_core::chunked_array::temporal::time_to_time64ns; +use polars_utils::cache::CachedFunc; use super::*; #[cfg(feature = "dtype-date")] @@ -29,16 +30,11 @@ fn datetime_pattern(val: &str, convert: F) -> Option<&'static str> where F: Fn(&str, &str) -> chrono::ParseResult, { - let result = patterns::DATETIME_Y_M_D + patterns::DATETIME_Y_M_D .iter() + .chain(patterns::DATETIME_D_M_Y) .find(|fmt| convert(val, fmt).is_ok()) - .copied(); - result.or_else(|| { - patterns::DATETIME_D_M_Y - .iter() - .find(|fmt| convert(val, fmt).is_ok()) - .copied() - }) + .copied() } fn date_pattern(val: &str, convert: F) -> Option<&'static str> @@ -46,24 +42,19 @@ fn date_pattern(val: &str, convert: F) -> Option<&'static str> where F: Fn(&str, &str) -> chrono::ParseResult, { - let result = patterns::DATE_Y_M_D + patterns::DATE_Y_M_D .iter() + .chain(patterns::DATE_D_M_Y) .find(|fmt| convert(val, fmt).is_ok()) - .copied(); - result.or_else(|| { - patterns::DATE_D_M_Y - .iter() - .find(|fmt| convert(val, fmt).is_ok()) - .copied() - }) + .copied() } struct ParseErrorByteCopy(ParseErrorKind); impl From for ParseErrorByteCopy { fn from(e: ParseError) -> Self { - // we need to do this until chrono ParseErrorKind is public - // blocked by https://github.com/chronotope/chrono/pull/588 + // We need to do this until chrono ParseErrorKind is public + // blocked by https://github.com/chronotope/chrono/pull/588. unsafe { std::mem::transmute(e) } } } @@ -92,93 +83,40 @@ fn get_first_val(ca: &Utf8Chunked) -> PolarsResult<&str> { #[cfg(feature = "dtype-datetime")] fn sniff_fmt_datetime(ca_utf8: &Utf8Chunked) -> PolarsResult<&'static str> { let val = get_first_val(ca_utf8)?; - match datetime_pattern(val, NaiveDateTime::parse_from_str) { - Some(pattern) => Ok(pattern), - None => match datetime_pattern(val, NaiveDate::parse_from_str) { - Some(pattern) => Ok(pattern), - None => polars_bail!(parse_fmt_idk = "datetime"), - }, - } + datetime_pattern(val, NaiveDateTime::parse_from_str) + .or_else(|| datetime_pattern(val, NaiveDate::parse_from_str)) + .ok_or_else(|| polars_err!(parse_fmt_idk = "datetime")) } #[cfg(feature = "dtype-date")] fn sniff_fmt_date(ca_utf8: &Utf8Chunked) -> PolarsResult<&'static str> { let val = get_first_val(ca_utf8)?; - if let Some(pattern) = date_pattern(val, NaiveDate::parse_from_str) { - return Ok(pattern); - } - polars_bail!(parse_fmt_idk = "date"); + date_pattern(val, NaiveDate::parse_from_str).ok_or_else(|| polars_err!(parse_fmt_idk = "date")) } #[cfg(feature = "dtype-time")] fn sniff_fmt_time(ca_utf8: &Utf8Chunked) -> PolarsResult<&'static str> { let val = get_first_val(ca_utf8)?; - if let Some(pattern) = time_pattern(val, NaiveTime::parse_from_str) { - return Ok(pattern); - } - polars_bail!(parse_fmt_idk = "time"); + time_pattern(val, NaiveTime::parse_from_str).ok_or_else(|| polars_err!(parse_fmt_idk = "time")) } pub trait Utf8Methods: AsUtf8 { #[cfg(feature = "dtype-time")] /// Parsing string values and return a [`TimeChunked`] - fn as_time(&self, fmt: Option<&str>, cache: bool) -> PolarsResult { + fn as_time(&self, fmt: Option<&str>, use_cache: bool) -> PolarsResult { let utf8_ca = self.as_utf8(); let fmt = match fmt { Some(fmt) => fmt, None => 
sniff_fmt_time(utf8_ca)?, }; - let cache = cache && utf8_ca.len() > 50; - - let mut cache_map = PlHashMap::new(); - - let mut ca: Int64Chunked = match utf8_ca.has_validity() { - false => utf8_ca - .into_no_null_iter() - .map(|s| { - if cache { - *cache_map.entry(s).or_insert_with(|| { - NaiveTime::parse_from_str(s, fmt) - .ok() - .as_ref() - .map(time_to_time64ns) - }) - } else { - NaiveTime::parse_from_str(s, fmt) - .ok() - .as_ref() - .map(time_to_time64ns) - } - }) - .collect_trusted(), - _ => utf8_ca - .into_iter() - .map(|opt_s| { - let opt_nd = opt_s.map(|s| { - if cache { - *cache_map.entry(s).or_insert_with(|| { - NaiveTime::parse_from_str(s, fmt) - .ok() - .as_ref() - .map(time_to_time64ns) - }) - } else { - NaiveTime::parse_from_str(s, fmt) - .ok() - .as_ref() - .map(time_to_time64ns) - } - }); - match opt_nd { - None => None, - Some(None) => None, - Some(Some(nd)) => Some(nd), - } - }) - .collect_trusted(), - }; - ca.rename(utf8_ca.name()); - Ok(ca.into()) + let use_cache = use_cache && utf8_ca.len() > 50; + + let mut convert = CachedFunc::new(|s| { + let naive_time = NaiveTime::parse_from_str(s, fmt).ok()?; + Some(time_to_time64ns(&naive_time)) + }); + let ca = utf8_ca.apply_generic(|opt_s| convert.eval(opt_s?, use_cache)); + Ok(ca.with_name(utf8_ca.name()).into()) } #[cfg(feature = "dtype-date")] @@ -191,38 +129,29 @@ pub trait Utf8Methods: AsUtf8 { Some(fmt) => fmt, None => sniff_fmt_date(utf8_ca)?, }; - let mut ca: Int32Chunked = utf8_ca - .into_iter() - .map(|opt_s| match opt_s { - None => None, - Some(mut s) => { - let fmt_len = fmt.len(); + let ca = utf8_ca.apply_generic(|opt_s| { + let mut s = opt_s?; + let fmt_len = fmt.len(); - for i in 1..(s.len().saturating_sub(fmt_len)) { - if s.is_empty() { - return None; - } - match NaiveDate::parse_from_str(s, fmt).map(naive_date_to_date) { - Ok(nd) => return Some(nd), - Err(e) => { - let e: ParseErrorByteCopy = e.into(); - match e.0 { - ParseErrorKind::TooLong => { - s = &s[..s.len() - 1]; - }, - _ => { - s = &s[i..]; - }, - } - }, - } - } - None - }, - }) - .collect_trusted(); - ca.rename(utf8_ca.name()); - Ok(ca.into()) + for i in 1..(s.len().saturating_sub(fmt_len)) { + if s.is_empty() { + return None; + } + match NaiveDate::parse_from_str(s, fmt).map(naive_date_to_date) { + Ok(nd) => return Some(nd), + Err(e) => match ParseErrorByteCopy::from(e).0 { + ParseErrorKind::TooLong => { + s = &s[..s.len() - 1]; + }, + _ => { + s = &s[i..]; + }, + }, + } + } + None + }); + Ok(ca.with_name(utf8_ca.name()).into()) } #[cfg(feature = "dtype-datetime")] @@ -249,41 +178,38 @@ pub trait Utf8Methods: AsUtf8 { TimeUnit::Milliseconds => datetime_to_timestamp_ms, }; - let mut ca: Int64Chunked = utf8_ca - .into_iter() - .map(|opt_s| match opt_s { - None => None, - Some(mut s) => { - let fmt_len = fmt.len(); + let ca = utf8_ca + .apply_generic(|opt_s| { + let mut s = opt_s?; + let fmt_len = fmt.len(); - for i in 1..(s.len().saturating_sub(fmt_len)) { - if s.is_empty() { - return None; - } - let timestamp = match tz_aware { - true => DateTime::parse_from_str(s, fmt).map(|dt| func(dt.naive_utc())), - false => NaiveDateTime::parse_from_str(s, fmt).map(func), - }; - match timestamp { - Ok(ts) => return Some(ts), - Err(e) => { - let e: ParseErrorByteCopy = e.into(); - match e.0 { - ParseErrorKind::TooLong => { - s = &s[..s.len() - 1]; - }, - _ => { - s = &s[i..]; - }, - } - }, - } + for i in 1..(s.len().saturating_sub(fmt_len)) { + if s.is_empty() { + return None; } - None - }, + let timestamp = if tz_aware { + DateTime::parse_from_str(s, 
fmt).map(|dt| func(dt.naive_utc())) + } else { + NaiveDateTime::parse_from_str(s, fmt).map(func) + }; + match timestamp { + Ok(ts) => return Some(ts), + Err(e) => { + let e: ParseErrorByteCopy = e.into(); + match e.0 { + ParseErrorKind::TooLong => { + s = &s[..s.len() - 1]; + }, + _ => { + s = &s[i..]; + }, + } + }, + } + } + None }) - .collect_trusted(); - ca.rename(utf8_ca.name()); + .with_name(utf8_ca.name()); match (tz_aware, tz) { #[cfg(feature = "timezones")] (false, Some(tz)) => polars_ops::prelude::replace_time_zone( @@ -299,87 +225,46 @@ pub trait Utf8Methods: AsUtf8 { #[cfg(feature = "dtype-date")] /// Parsing string values and return a [`DateChunked`] - fn as_date(&self, fmt: Option<&str>, cache: bool) -> PolarsResult { + fn as_date(&self, fmt: Option<&str>, use_cache: bool) -> PolarsResult { let utf8_ca = self.as_utf8(); let fmt = match fmt { Some(fmt) => fmt, None => return infer::to_date(utf8_ca), }; - let cache = cache && utf8_ca.len() > 50; + let use_cache = use_cache && utf8_ca.len() > 50; let fmt = strptime::compile_fmt(fmt)?; - let mut cache_map = PlHashMap::new(); - // we can use the fast parser - let mut ca: Int32Chunked = if let Some(fmt_len) = strptime::fmt_len(fmt.as_bytes()) { + // We can use the fast parser. + let ca = if let Some(fmt_len) = strptime::fmt_len(fmt.as_bytes()) { let mut strptime_cache = StrpTimeState::default(); - let mut convert = |s: &str| { - // Safety: - // fmt_len is correct, it was computed with this `fmt` str. + let mut convert = CachedFunc::new(|s: &str| { + // SAFETY: fmt_len is correct, it was computed with this `fmt` str. match unsafe { strptime_cache.parse(s.as_bytes(), fmt.as_bytes(), fmt_len) } { - // fallback to chrono + // Fallback to chrono. None => NaiveDate::parse_from_str(s, &fmt).ok(), Some(ndt) => Some(ndt.date()), } .map(naive_date_to_date) - }; - - if utf8_ca.null_count() == 0 { - utf8_ca - .into_no_null_iter() - .map(|val| { - if cache { - *cache_map.entry(val).or_insert_with(|| convert(val)) - } else { - convert(val) - } - }) - .collect_trusted() - } else { - utf8_ca - .into_iter() - .map(|opt_s| { - opt_s.and_then(|val| { - if cache { - *cache_map.entry(val).or_insert_with(|| convert(val)) - } else { - convert(val) - } - }) - }) - .collect_trusted() - } + }); + utf8_ca.apply_generic(|val| convert.eval(val?, use_cache)) } else { - utf8_ca - .into_iter() - .map(|opt_s| { - opt_s.and_then(|s| { - if cache { - *cache_map.entry(s).or_insert_with(|| { - NaiveDate::parse_from_str(s, &fmt) - .ok() - .map(naive_date_to_date) - }) - } else { - NaiveDate::parse_from_str(s, &fmt) - .ok() - .map(naive_date_to_date) - } - }) - }) - .collect_trusted() + let mut convert = CachedFunc::new(|s| { + let naive_date = NaiveDate::parse_from_str(s, &fmt).ok()?; + Some(naive_date_to_date(naive_date)) + }); + utf8_ca.apply_generic(|val| convert.eval(val?, use_cache)) }; - ca.rename(utf8_ca.name()); - Ok(ca.into()) + Ok(ca.with_name(utf8_ca.name()).into()) } #[cfg(feature = "dtype-datetime")] - /// Parsing string values and return a [`DatetimeChunked`] + /// Parsing string values and return a [`DatetimeChunked`]. 
fn as_datetime( &self, fmt: Option<&str>, tu: TimeUnit, - cache: bool, + use_cache: bool, tz_aware: bool, tz: Option<&TimeZone>, ambiguous: &Utf8Chunked, @@ -390,7 +275,7 @@ pub trait Utf8Methods: AsUtf8 { None => return infer::to_datetime(utf8_ca, tu, tz, ambiguous), }; let fmt = strptime::compile_fmt(fmt)?; - let cache = cache && utf8_ca.len() > 50; + let use_cache = use_cache && utf8_ca.len() > 50; let func = match tu { TimeUnit::Nanoseconds => datetime_to_timestamp_ns, @@ -401,115 +286,45 @@ pub trait Utf8Methods: AsUtf8 { if tz_aware { #[cfg(feature = "timezones")] { - use polars_arrow::export::hashbrown::hash_map::Entry; - let mut cache_map = PlHashMap::new(); - - let convert = |s: &str| { - DateTime::parse_from_str(s, &fmt) - .ok() - .map(|dt| func(dt.naive_utc())) - }; - - let mut ca: Int64Chunked = utf8_ca - .into_iter() - .map(|opt_s| { - opt_s - .map(|s| { - let out = if cache { - match cache_map.entry(s) { - Entry::Vacant(entry) => { - let value = convert(s); - entry.insert(value); - value - }, - Entry::Occupied(val) => *val.get(), - } - } else { - convert(s) - }; - Ok(out) - }) - .transpose() - .map(|options| options.flatten()) - }) - .collect::>()?; - - ca.rename(utf8_ca.name()); - Ok(ca.into_datetime(tu, Some("UTC".to_string()))) + let mut convert = CachedFunc::new(|s: &str| { + let dt = DateTime::parse_from_str(s, &fmt).ok()?; + Some(func(dt.naive_utc())) + }); + Ok(utf8_ca + .apply_generic(|opt_s| convert.eval(opt_s?, use_cache)) + .with_name(utf8_ca.name()) + .into_datetime(tu, Some("UTC".to_string()))) } #[cfg(not(feature = "timezones"))] { panic!("activate 'timezones' feature") } } else { - let mut cache_map = PlHashMap::new(); let transform = match tu { TimeUnit::Nanoseconds => infer::transform_datetime_ns, TimeUnit::Microseconds => infer::transform_datetime_us, TimeUnit::Milliseconds => infer::transform_datetime_ms, }; - // we can use the fast parser - let mut ca: Int64Chunked = if let Some(fmt_len) = - self::strptime::fmt_len(fmt.as_bytes()) - { + // We can use the fast parser. + let ca = if let Some(fmt_len) = self::strptime::fmt_len(fmt.as_bytes()) { let mut strptime_cache = StrpTimeState::default(); - let mut convert = |s: &str| { - // Safety: - // fmt_len is correct, it was computed with this `fmt` str. + let mut convert = CachedFunc::new(|s: &str| { + // SAFETY: fmt_len is correct, it was computed with this `fmt` str. 
match unsafe { strptime_cache.parse(s.as_bytes(), fmt.as_bytes(), fmt_len) } { None => transform(s, &fmt), Some(ndt) => Some(func(ndt)), } - }; - if utf8_ca.null_count() == 0 { - utf8_ca - .into_no_null_iter() - .map(|val| { - if cache { - *cache_map.entry(val).or_insert_with(|| convert(val)) - } else { - convert(val) - } - }) - .collect_trusted() - } else { - utf8_ca - .into_iter() - .map(|opt_s| { - opt_s.and_then(|val| { - if cache { - *cache_map.entry(val).or_insert_with(|| convert(val)) - } else { - convert(val) - } - }) - }) - .collect_trusted() - } + }); + utf8_ca.apply_generic(|opt_s| convert.eval(opt_s?, use_cache)) } else { - let mut cache_map = PlHashMap::new(); - utf8_ca - .into_iter() - .map(|opt_s| { - opt_s.and_then(|s| { - if cache { - *cache_map.entry(s).or_insert_with(|| transform(s, &fmt)) - } else { - transform(s, &fmt) - } - }) - }) - .collect_trusted() + let mut convert = CachedFunc::new(|s| transform(s, &fmt)); + utf8_ca.apply_generic(|opt_s| convert.eval(opt_s?, use_cache)) }; - ca.rename(utf8_ca.name()); + let dt = ca.with_name(utf8_ca.name()).into_datetime(tu, None); match tz { #[cfg(feature = "timezones")] - Some(tz) => polars_ops::prelude::replace_time_zone( - &ca.into_datetime(tu, None), - Some(tz), - ambiguous, - ), - _ => Ok(ca.into_datetime(tu, None)), + Some(tz) => polars_ops::prelude::replace_time_zone(&dt, Some(tz), ambiguous), + _ => Ok(dt), } } } diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index b434a5892847..7b81ecbc27dc 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -1,4 +1,5 @@ use chrono::{Datelike, NaiveDateTime, NaiveTime}; +use polars_arrow::time_zone::Tz; use polars_core::chunked_array::temporal::time_to_time64ns; use polars_core::prelude::*; use polars_core::series::IsSorted; @@ -12,31 +13,46 @@ pub fn in_nanoseconds_window(ndt: &NaiveDateTime) -> bool { !(ndt.year() > 2554 || ndt.year() < 1386) } +/// Create a [`DatetimeChunked`] from a given `start` and `end` date and a given `interval`. 
+pub fn date_range( + name: &str, + start: NaiveDateTime, + end: NaiveDateTime, + interval: Duration, + closed: ClosedWindow, + tu: TimeUnit, + tz: Option, +) -> PolarsResult { + let (start, end) = match tu { + TimeUnit::Nanoseconds => ( + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), + ), + TimeUnit::Microseconds => (start.timestamp_micros(), end.timestamp_micros()), + TimeUnit::Milliseconds => (start.timestamp_millis(), end.timestamp_millis()), + }; + datetime_range_impl(name, start, end, interval, closed, tu, tz.as_ref()) +} + #[doc(hidden)] -pub fn date_range_impl( +pub fn datetime_range_impl( name: &str, start: i64, - stop: i64, - every: Duration, + end: i64, + interval: Duration, closed: ClosedWindow, tu: TimeUnit, _tz: Option<&TimeZone>, ) -> PolarsResult { - if start > stop { - polars_bail!(ComputeError: "'start' cannot be greater than 'stop'") - } - if every.negative { - polars_bail!(ComputeError: "'interval' cannot be negative") - } let mut out = match _tz { #[cfg(feature = "timezones")] Some(tz) => match tz.parse::() { Ok(tz) => { let start = localize_timestamp(start, tu, tz); - let stop = localize_timestamp(stop, tu, tz); + let end = localize_timestamp(end, tu, tz); Int64Chunked::new_vec( name, - temporal_range_vec(start?, stop?, every, closed, tu, Some(&tz))?, + datetime_range_i64(start?, end?, interval, closed, tu, Some(&tz))?, ) .into_datetime(tu, _tz.cloned()) }, @@ -44,7 +60,7 @@ pub fn date_range_impl( }, _ => Int64Chunked::new_vec( name, - temporal_range_vec(start, stop, every, closed, tu, None)?, + datetime_range_i64(start, end, interval, closed, tu, None)?, ) .into_datetime(tu, None), }; @@ -53,41 +69,30 @@ pub fn date_range_impl( Ok(out) } -/// Create a [`DatetimeChunked`] from a given `start` and `stop` date and a given `every` interval. -pub fn date_range( +/// Create a [`TimeChunked`] from a given `start` and `end` date and a given `interval`. +pub fn time_range( name: &str, - start: NaiveDateTime, - stop: NaiveDateTime, - every: Duration, + start: NaiveTime, + end: NaiveTime, + interval: Duration, closed: ClosedWindow, - tu: TimeUnit, - tz: Option, -) -> PolarsResult { - let (start, stop) = match tu { - TimeUnit::Nanoseconds => (start.timestamp_nanos(), stop.timestamp_nanos()), - TimeUnit::Microseconds => (start.timestamp_micros(), stop.timestamp_micros()), - TimeUnit::Milliseconds => (start.timestamp_millis(), stop.timestamp_millis()), - }; - date_range_impl(name, start, stop, every, closed, tu, tz.as_ref()) +) -> PolarsResult { + let start = time_to_time64ns(&start); + let end = time_to_time64ns(&end); + time_range_impl(name, start, end, interval, closed) } #[doc(hidden)] pub fn time_range_impl( name: &str, start: i64, - stop: i64, - every: Duration, + end: i64, + interval: Duration, closed: ClosedWindow, ) -> PolarsResult { - if start > stop { - polars_bail!(ComputeError: "'start' cannot be greater than 'stop'") - } - if every.negative { - polars_bail!(ComputeError: "'interval' cannot be negative") - } let mut out = Int64Chunked::new_vec( name, - temporal_range_vec(start, stop, every, closed, TimeUnit::Nanoseconds, None)?, + datetime_range_i64(start, end, interval, closed, TimeUnit::Nanoseconds, None)?, ) .into_time(); @@ -95,15 +100,71 @@ pub fn time_range_impl( Ok(out) } -/// Create a [`TimeChunked`] from a given `start` and `stop` date and a given `every` interval. 
-pub fn time_range( - name: &str, - start: NaiveTime, - stop: NaiveTime, - every: Duration, +/// vector of i64 representing temporal values +pub(crate) fn datetime_range_i64( + start: i64, + end: i64, + interval: Duration, closed: ClosedWindow, -) -> PolarsResult { - let start = time_to_time64ns(&start); - let stop = time_to_time64ns(&stop); - time_range_impl(name, start, stop, every, closed) + tu: TimeUnit, + tz: Option<&Tz>, +) -> PolarsResult> { + check_range_bounds(start, end, interval)?; + + let size: usize; + let offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult; + + match tu { + TimeUnit::Nanoseconds => { + size = ((end - start) / interval.duration_ns() + 1) as usize; + offset_fn = Duration::add_ns; + }, + TimeUnit::Microseconds => { + size = ((end - start) / interval.duration_us() + 1) as usize; + offset_fn = Duration::add_us; + }, + TimeUnit::Milliseconds => { + size = ((end - start) / interval.duration_ms() + 1) as usize; + offset_fn = Duration::add_ms; + }, + } + let mut ts = Vec::with_capacity(size); + + let mut t = start; + match closed { + ClosedWindow::Both => { + while t <= end { + ts.push(t); + t = offset_fn(&interval, t, tz)? + } + }, + ClosedWindow::Left => { + while t < end { + ts.push(t); + t = offset_fn(&interval, t, tz)? + } + }, + ClosedWindow::Right => { + t = offset_fn(&interval, t, tz)?; + while t <= end { + ts.push(t); + t = offset_fn(&interval, t, tz)? + } + }, + ClosedWindow::None => { + t = offset_fn(&interval, t, tz)?; + while t < end { + ts.push(t); + t = offset_fn(&interval, t, tz)? + } + }, + } + debug_assert!(size >= ts.len()); + Ok(ts) +} + +fn check_range_bounds(start: i64, end: i64, interval: Duration) -> PolarsResult<()> { + polars_ensure!(end >= start, ComputeError: "`end` must be equal to or greater than `start`"); + polars_ensure!(!interval.negative && !interval.is_zero(), ComputeError: "`interval` must be positive"); + Ok(()) } diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 0364096189c6..2cae296fcf02 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -20,23 +20,22 @@ struct Wrap(pub T); #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct DynamicGroupOptions { - /// Time or index column + /// Time or index column. pub index_column: SmartString, - /// start a window at this interval + /// Start a window at this interval. pub every: Duration, - /// window duration + /// Window duration. pub period: Duration, - /// offset window boundaries + /// Offset window boundaries. pub offset: Duration, - /// truncate the time column values to the window - pub truncate: bool, - /// add the boundaries to the dataframe + /// Truncate the time column values to the window. + pub label: Label, + /// Add the boundaries to the dataframe. pub include_boundaries: bool, pub closed_window: ClosedWindow, pub start_by: StartBy, - /// In cases sortedness cannot be checked by - /// the sorted flag, traverse the data to - /// check sortedness + /// In cases sortedness cannot be checked by the sorted flag, + /// traverse the data to check sortedness. 
pub check_sorted: bool, } @@ -47,7 +46,7 @@ impl Default for DynamicGroupOptions { every: Duration::new(1), period: Duration::new(1), offset: Duration::new(1), - truncate: true, + label: Label::Left, include_boundaries: false, closed_window: ClosedWindow::Left, start_by: Default::default(), @@ -56,18 +55,17 @@ impl Default for DynamicGroupOptions { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingGroupOptions { - /// Time or index column + /// Time or index column. pub index_column: SmartString, - /// window duration + /// Window duration. pub period: Duration, pub offset: Duration, pub closed_window: ClosedWindow, - /// In cases sortedness cannot be checked by - /// the sorted flag, traverse the data to - /// check sortedness + /// In cases sortedness cannot be checked by the sorted flag, + /// traverse the data to check sortedness. pub check_sorted: bool, } @@ -130,14 +128,14 @@ impl Wrap<&DataFrame> { options: &RollingGroupOptions, ) -> PolarsResult<(Series, Vec, GroupsProxy)> { polars_ensure!( - options.period.duration_ns()>0 && !options.period.negative, + options.period.duration_ns() > 0 && !options.period.negative, ComputeError: "rolling window period should be strictly positive", ); let time = self.0.column(&options.index_column)?.clone(); if by.is_empty() { - // if by is given, the column must be sorted in the 'by' arg, which we can not check now - // this will be checked when the groups are materialized + // If by is given, the column must be sorted in the 'by' arg, which we can not check now + // this will be checked when the groups are materialized. ensure_sorted_arg(&time, "group_by_rolling")?; } let time_type = time.dtype(); @@ -195,7 +193,7 @@ impl Wrap<&DataFrame> { } } - /// Returns: time_keys, keys, groupsproxy + /// Returns: time_keys, keys, groupsproxy. fn group_by_dynamic( &self, by: Vec, @@ -212,8 +210,8 @@ impl Wrap<&DataFrame> { let time = self.0.column(&options.index_column)?.rechunk(); if by.is_empty() { - // if by is given, the column must be sorted in the 'by' arg, which we can not check now - // this will be checked when the groups are materialized + // If by is given, the column must be sorted in the 'by' arg, which we can not check now + // this will be checked when the groups are materialized. ensure_sorted_arg(&time, "group_by_dynamic")?; } let time_type = time.dtype(); @@ -275,8 +273,7 @@ impl Wrap<&DataFrame> { return dt.cast(time_type).map(|s| (s, by, GroupsProxy::default())); } - // a requirement for the index - // so we can set this such that downstream code has this info + // A requirement for the index so we can set this such that downstream code has this info. dt.set_sorted_flag(IsSorted::Ascending); let w = Window::new(options.every, options.period, options.offset); @@ -293,8 +290,10 @@ impl Wrap<&DataFrame> { include_lower_bound = true; include_upper_bound = true; } - if options.truncate { + if options.label == Label::Left { include_lower_bound = true; + } else if options.label == Label::Right { + include_upper_bound = true; } let mut update_bounds = @@ -334,14 +333,14 @@ impl Wrap<&DataFrame> { .group_by_with_series(by.clone(), true, true)? .take_groups(); - // include boundaries cannot be parallel (easily) - if include_lower_bound { + // Include boundaries cannot be parallel (easily). 
+ if include_lower_bound | include_upper_bound { POOL.install(|| match groups { GroupsProxy::Idx(groups) => { let ir = groups .par_iter() .map(|base_g| { - let dt = unsafe { dt.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted @@ -420,7 +419,7 @@ impl Wrap<&DataFrame> { let groupsidx = groups .par_iter() .map(|base_g| { - let dt = unsafe { dt.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted @@ -484,29 +483,30 @@ impl Wrap<&DataFrame> { } let lower = lower_bound.map(|lower| Int64Chunked::new_vec(LB_NAME, lower)); + let upper = upper_bound.map(|upper| Int64Chunked::new_vec(UP_NAME, upper)); - if options.truncate { + if options.label == Label::Left { let mut lower = lower.clone().unwrap(); if by.is_empty() { lower.set_sorted_flag(IsSorted::Ascending) } - lower.rename(dt.name()); - dt = lower; + dt = lower.with_name(dt.name()); + } else if options.label == Label::Right { + let mut upper = upper.clone().unwrap(); + if by.is_empty() { + upper.set_sorted_flag(IsSorted::Ascending) + } + dt = upper.with_name(dt.name()); } - if let (true, Some(mut lower), Some(upper)) = - (options.include_boundaries, lower, upper_bound) + if let (true, Some(mut lower), Some(mut upper)) = (options.include_boundaries, lower, upper) { - let mut upper = Int64Chunked::new_vec(UP_NAME, upper) - .into_datetime(tu, tz.clone()) - .into_series(); - if by.is_empty() { lower.set_sorted_flag(IsSorted::Ascending); upper.set_sorted_flag(IsSorted::Ascending); } by.push(lower.into_datetime(tu, tz.clone()).into_series()); - by.push(upper); + by.push(upper.into_datetime(tu, tz.clone()).into_series()); } dt.into_datetime(tu, None) @@ -564,7 +564,7 @@ impl Wrap<&DataFrame> { let idx = groups .par_iter() .map(|base_g| { - let dt = unsafe { dt_local.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt_local.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted @@ -639,9 +639,8 @@ fn update_subgroups_idx( .iter() .map(|&[first, len]| { let new_first = if len == 0 { - // in case the group is empty - // keep the original first so that the - // group_by keys still point to the original group + // In case the group is empty, keep the original first so that the + // group_by keys still point to the original group. base_g.0 } else { unsafe { *base_g.1.get_unchecked_release(first as usize) } @@ -762,7 +761,7 @@ mod test { let expected = Series::new("", [3, 3, 3, 3, 2, 1]); assert_eq!(min, expected); - // expected for nulls is equal + // Expected for nulls is equality. 
let min = unsafe { nulls.agg_min(&groups) }; assert_eq!(min, expected); @@ -807,7 +806,7 @@ mod test { .and_hms_opt(3, 0, 0) .unwrap() .timestamp_millis(); - let range = date_range_impl( + let range = datetime_range_impl( "date", start, stop, @@ -829,7 +828,7 @@ mod test { every: Duration::parse("1h"), period: Duration::parse("1h"), offset: Duration::parse("0h"), - truncate: true, + label: Label::Left, include_boundaries: true, closed_window: ClosedWindow::Both, start_by: Default::default(), @@ -860,7 +859,7 @@ mod test { .and_hms_opt(3, 0, 0) .unwrap() .timestamp_millis(); - let range = date_range_impl( + let range = datetime_range_impl( "_upper_boundary", start, stop, @@ -883,7 +882,7 @@ mod test { .and_hms_opt(2, 0, 0) .unwrap() .timestamp_millis(); - let range = date_range_impl( + let range = datetime_range_impl( "_lower_boundary", start, stop, @@ -922,7 +921,7 @@ mod test { .and_hms_opt(12, 0, 0) .unwrap() .timestamp_millis(); - let range = date_range_impl( + let range = datetime_range_impl( "date", start, stop, @@ -944,7 +943,7 @@ mod test { every: Duration::parse("6d"), period: Duration::parse("6d"), offset: Duration::parse("0h"), - truncate: true, + label: Label::Left, include_boundaries: true, closed_window: ClosedWindow::Both, start_by: Default::default(), @@ -952,9 +951,8 @@ mod test { }, ) .unwrap(); - let mut lower_bound = keys[1].clone(); time_key.rename(""); - lower_bound.rename(""); + let lower_bound = keys[1].clone().with_name(""); assert!(time_key.series_equal(&lower_bound)); Ok(()) } diff --git a/crates/polars-time/src/lib.rs b/crates/polars-time/src/lib.rs index b2162e26a740..ea9f6373eb85 100644 --- a/crates/polars-time/src/lib.rs +++ b/crates/polars-time/src/lib.rs @@ -26,7 +26,6 @@ pub use month_start::*; pub use round::*; pub use truncate::*; pub use upsample::*; -pub use windows::calendar::temporal_range as temporal_range_vec; pub use windows::duration::Duration; pub use windows::group_by::ClosedWindow; pub use windows::window::Window; diff --git a/crates/polars-time/src/month_start.rs b/crates/polars-time/src/month_start.rs index f317d3852112..ff03f3317f51 100644 --- a/crates/polars-time/src/month_start.rs +++ b/crates/polars-time/src/month_start.rs @@ -30,16 +30,18 @@ pub(crate) fn roll_backward( ts.second(), ts.timestamp_subsec_nanos(), ) - .ok_or(polars_err!( - ComputeError: - format!( - "Could not construct time {}:{}:{}.{}", - ts.hour(), - ts.minute(), - ts.second(), - ts.timestamp_subsec_nanos() - ) - ))?; + .ok_or_else(|| { + polars_err!( + ComputeError: + format!( + "Could not construct time {}:{}:{}.{}", + ts.hour(), + ts.minute(), + ts.second(), + ts.timestamp_subsec_nanos() + ) + ) + })?; let ndt = NaiveDateTime::new(date, time); let t = match tz { #[cfg(feature = "timezones")] diff --git a/crates/polars-time/src/prelude.rs b/crates/polars-time/src/prelude.rs index aa9a2a1d5b21..22df3cfff049 100644 --- a/crates/polars-time/src/prelude.rs +++ b/crates/polars-time/src/prelude.rs @@ -1,7 +1,7 @@ pub use date_range::*; pub use crate::chunkedarray::*; -pub use crate::series::{SeriesOpsTime, TemporalMethods}; +pub use crate::series::TemporalMethods; pub use crate::windows::bounds::*; pub use crate::windows::duration::*; pub use crate::windows::group_by::*; diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index ed1e8a642468..3c8b4b1858c5 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -1,18 +1,31 @@ use polars_arrow::export::arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY}; use 
polars_arrow::time_zone::Tz; +use polars_core::prelude::arity::try_binary_elementwise; use polars_core::prelude::*; use crate::prelude::*; pub trait PolarsRound { - fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult + fn round( + &self, + every: Duration, + offset: Duration, + tz: Option<&Tz>, + ambiguous: &Utf8Chunked, + ) -> PolarsResult where Self: Sized; } #[cfg(feature = "dtype-datetime")] impl PolarsRound for DatetimeChunked { - fn round(&self, every: Duration, offset: Duration, tz: Option<&Tz>) -> PolarsResult { + fn round( + &self, + every: Duration, + offset: Duration, + tz: Option<&Tz>, + ambiguous: &Utf8Chunked, + ) -> PolarsResult { let w = Window::new(every, every, offset); let func = match self.time_unit() { @@ -20,20 +33,37 @@ impl PolarsRound for DatetimeChunked { TimeUnit::Microseconds => Window::round_us, TimeUnit::Milliseconds => Window::round_ms, }; - Ok(self - .try_apply(|t| func(&w, t, tz))? - .into_datetime(self.time_unit(), self.time_zone().clone())) + + let out = match ambiguous.len() { + 1 => match ambiguous.get(0) { + Some(ambiguous) => self.try_apply(|t| func(&w, t, tz, ambiguous)), + None => Ok(Int64Chunked::full_null(self.name(), self.len())), + }, + _ => try_binary_elementwise(self, ambiguous, |opt_t, opt_aambiguous| { + match (opt_t, opt_aambiguous) { + (Some(t), Some(aambiguous)) => func(&w, t, tz, aambiguous).map(Some), + _ => Ok(None), + } + }), + }; + out.map(|ok| ok.into_datetime(self.time_unit(), self.time_zone().clone())) } } #[cfg(feature = "dtype-date")] impl PolarsRound for DateChunked { - fn round(&self, every: Duration, offset: Duration, _tz: Option<&Tz>) -> PolarsResult { + fn round( + &self, + every: Duration, + offset: Duration, + _tz: Option<&Tz>, + _ambiguous: &Utf8Chunked, + ) -> PolarsResult { let w = Window::new(every, every, offset); Ok(self .try_apply(|t| { const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; - Ok((w.round_ms(MSECS_IN_DAY * t as i64, None)? / MSECS_IN_DAY) as i32) + Ok((w.round_ms(MSECS_IN_DAY * t as i64, None, "raise")? / MSECS_IN_DAY) as i32) })? .into_date()) } diff --git a/crates/polars-time/src/series/_trait.rs b/crates/polars-time/src/series/_trait.rs deleted file mode 100644 index a4a913f39131..000000000000 --- a/crates/polars-time/src/series/_trait.rs +++ /dev/null @@ -1,107 +0,0 @@ -use super::*; -#[cfg(feature = "rolling_window")] -use crate::prelude::*; - -#[cfg(feature = "rolling_window")] -macro_rules! invalid_operation { - ($s:expr) => { - Err(polars_err!( - InvalidOperation: "this operation is not implemented/valid for this dtype: {:?}", - $s.ops_time_dtype(), - )) - }; -} - -pub trait SeriesOpsTime { - fn ops_time_dtype(&self) -> &DataType; - - /// Apply a rolling mean to a Series. - /// - /// See: [`RollingAgg::rolling_mean`] - #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - /// Apply a rolling sum to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - /// Apply a rolling median to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_median(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - /// Apply a rolling quantile to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - - /// Apply a rolling min to a Series. 
- #[cfg(feature = "rolling_window")] - fn rolling_min(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - /// Apply a rolling max to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_max(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - - /// Apply a rolling variance to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_var(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } - - /// Apply a rolling std_dev to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_std(&self, _options: RollingOptionsImpl) -> PolarsResult { - invalid_operation!(self) - } -} - -impl SeriesOpsTime for Series { - fn ops_time_dtype(&self) -> &DataType { - self.deref().dtype() - } - #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, _options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_mean(_options) - } - #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, _options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_sum(_options) - } - #[cfg(feature = "rolling_window")] - fn rolling_median(&self, _options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_median(_options) - } - /// Apply a rolling quantile to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_quantile(options) - } - #[cfg(feature = "rolling_window")] - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_min(options) - } - /// Apply a rolling max to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_max(options) - } - - /// Apply a rolling variance to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_var(options) - } - - /// Apply a rolling std_dev to a Series. 
- #[cfg(feature = "rolling_window")] - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { - self.to_ops().rolling_std(options) - } -} diff --git a/crates/polars-time/src/series/implementations/boolean.rs b/crates/polars-time/src/series/implementations/boolean.rs deleted file mode 100644 index 36f85da50e56..000000000000 --- a/crates/polars-time/src/series/implementations/boolean.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/categoricals.rs b/crates/polars-time/src/series/implementations/categoricals.rs deleted file mode 100644 index f68dfec5007a..000000000000 --- a/crates/polars-time/src/series/implementations/categoricals.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/date.rs b/crates/polars-time/src/series/implementations/date.rs deleted file mode 100644 index 035520b80242..000000000000 --- a/crates/polars-time/src/series/implementations/date.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/datetime.rs b/crates/polars-time/src/series/implementations/datetime.rs deleted file mode 100644 index 86f75a8725a4..000000000000 --- a/crates/polars-time/src/series/implementations/datetime.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/duration.rs b/crates/polars-time/src/series/implementations/duration.rs deleted file mode 100644 index b495a2f030e7..000000000000 --- a/crates/polars-time/src/series/implementations/duration.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/floats.rs b/crates/polars-time/src/series/implementations/floats.rs deleted file mode 100644 index d6a8d9377dab..000000000000 --- a/crates/polars-time/src/series/implementations/floats.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::ops::SubAssign; - -use polars_arrow::data_types::IsFloat; - -use super::*; - -impl SeriesOpsTime for WrapFloat> -where - T::Native: IsFloat + SubAssign, - Self: RollingAgg, -{ - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } - - #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_mean(self, options) - } - #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_sum(self, options) - } - #[cfg(feature = "rolling_window")] - fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_median(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_quantile(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_min(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_max(&self, options: RollingOptionsImpl) -> 
PolarsResult { - RollingAgg::rolling_max(self, options) - } - #[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_var(self, options) - } - - /// Apply a rolling std_dev to a Series. - #[cfg(feature = "rolling_window")] - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_std(self, options) - } -} diff --git a/crates/polars-time/src/series/implementations/integers.rs b/crates/polars-time/src/series/implementations/integers.rs deleted file mode 100644 index 8ede537ba771..000000000000 --- a/crates/polars-time/src/series/implementations/integers.rs +++ /dev/null @@ -1,50 +0,0 @@ -use super::*; - -impl SeriesOpsTime for WrapInt> -where - T::Native: NumericNative, - Self: RollingAgg, -{ - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } - - #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_mean(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_sum(self, options) - } - #[cfg(feature = "rolling_window")] - fn rolling_median(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_median(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_quantile(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_min(self, options) - } - - #[cfg(feature = "rolling_window")] - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_max(self, options) - } - #[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_var(self, options) - } - - /// Apply a rolling std_dev to a Series. 
- #[cfg(feature = "rolling_window")] - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { - RollingAgg::rolling_std(self, options) - } -} diff --git a/crates/polars-time/src/series/implementations/list.rs b/crates/polars-time/src/series/implementations/list.rs deleted file mode 100644 index a3caf4b90f05..000000000000 --- a/crates/polars-time/src/series/implementations/list.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/mod.rs b/crates/polars-time/src/series/implementations/mod.rs deleted file mode 100644 index 7a20995299b4..000000000000 --- a/crates/polars-time/src/series/implementations/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod boolean; -#[cfg(feature = "dtype-categorical")] -mod categoricals; -#[cfg(feature = "dtype-date")] -mod date; -#[cfg(feature = "dtype-datetime")] -mod datetime; -#[cfg(feature = "dtype-duration")] -mod duration; -mod floats; -mod integers; -mod list; -#[cfg(feature = "object")] -mod object; -#[cfg(feature = "dtype-struct")] -mod struct_; -#[cfg(feature = "dtype-time")] -mod time; -mod utf8; - -use polars_core::prelude::*; -use polars_core::utils::Wrap; - -use crate::prelude::*; -use crate::series::*; diff --git a/crates/polars-time/src/series/implementations/object.rs b/crates/polars-time/src/series/implementations/object.rs deleted file mode 100644 index 019fedfea500..000000000000 --- a/crates/polars-time/src/series/implementations/object.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap> { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/struct_.rs b/crates/polars-time/src/series/implementations/struct_.rs deleted file mode 100644 index 1c20f20d49ad..000000000000 --- a/crates/polars-time/src/series/implementations/struct_.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/time.rs b/crates/polars-time/src/series/implementations/time.rs deleted file mode 100644 index 787e2446c075..000000000000 --- a/crates/polars-time/src/series/implementations/time.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/implementations/utf8.rs b/crates/polars-time/src/series/implementations/utf8.rs deleted file mode 100644 index e459636c9405..000000000000 --- a/crates/polars-time/src/series/implementations/utf8.rs +++ /dev/null @@ -1,7 +0,0 @@ -use super::*; - -impl SeriesOpsTime for Wrap { - fn ops_time_dtype(&self) -> &DataType { - self.0.dtype() - } -} diff --git a/crates/polars-time/src/series/mod.rs b/crates/polars-time/src/series/mod.rs index 1e263e9b076f..6228e6c9b293 100644 --- a/crates/polars-time/src/series/mod.rs +++ b/crates/polars-time/src/series/mod.rs @@ -1,115 +1,9 @@ -mod _trait; -mod implementations; use std::ops::Deref; -use std::sync::Arc; use polars_core::prelude::*; -use polars_core::utils::Wrap; -pub use SeriesOpsTime; -pub use self::_trait::*; use crate::chunkedarray::*; -type SeriesOpsRef = Arc; - -pub trait IntoSeriesOps { - fn to_ops(&self) -> SeriesOpsRef; -} - -impl IntoSeriesOps for Series { - fn to_ops(&self) -> SeriesOpsRef { - match self.dtype() { - DataType::Int8 => 
self.i8().unwrap().to_ops(), - DataType::Int16 => self.i16().unwrap().to_ops(), - DataType::Int32 => self.i32().unwrap().to_ops(), - DataType::Int64 => self.i64().unwrap().to_ops(), - DataType::UInt8 => self.u8().unwrap().to_ops(), - DataType::UInt16 => self.u16().unwrap().to_ops(), - DataType::UInt32 => self.u32().unwrap().to_ops(), - DataType::UInt64 => self.u64().unwrap().to_ops(), - DataType::Float32 => self.f32().unwrap().to_ops(), - DataType::Float64 => self.f64().unwrap().to_ops(), - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_) => self.categorical().unwrap().to_ops(), - DataType::Boolean => self.bool().unwrap().to_ops(), - DataType::Utf8 => self.utf8().unwrap().to_ops(), - #[cfg(feature = "dtype-date")] - DataType::Date => self.date().unwrap().to_ops(), - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(_, _) => self.datetime().unwrap().to_ops(), - #[cfg(feature = "dtype-duration")] - DataType::Duration(_) => self.duration().unwrap().to_ops(), - #[cfg(feature = "dtype-time")] - DataType::Time => self.time().unwrap().to_ops(), - DataType::List(_) => self.list().unwrap().to_ops(), - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => self.struct_().unwrap().to_ops(), - _ => unimplemented!(), - } - } -} - -impl IntoSeriesOps for &ChunkedArray -where - T::Native: NumericNative, -{ - fn to_ops(&self) -> SeriesOpsRef { - Arc::new(WrapInt((*self).clone())) - } -} - -#[repr(transparent)] -pub(crate) struct WrapFloat(pub T); - -#[repr(transparent)] -pub(crate) struct WrapInt(pub T); - -impl IntoSeriesOps for Float32Chunked { - fn to_ops(&self) -> SeriesOpsRef { - Arc::new(WrapFloat(self.clone())) - } -} - -impl IntoSeriesOps for Float64Chunked { - fn to_ops(&self) -> SeriesOpsRef { - Arc::new(WrapFloat(self.clone())) - } -} - -macro_rules! 
into_ops_impl_wrapped { - ($tp:ty) => { - impl IntoSeriesOps for $tp { - fn to_ops(&self) -> SeriesOpsRef { - Arc::new(Wrap(self.clone())) - } - } - }; -} - -into_ops_impl_wrapped!(Utf8Chunked); -into_ops_impl_wrapped!(BooleanChunked); -#[cfg(feature = "dtype-date")] -into_ops_impl_wrapped!(DateChunked); -#[cfg(feature = "dtype-time")] -into_ops_impl_wrapped!(TimeChunked); -#[cfg(feature = "dtype-duration")] -into_ops_impl_wrapped!(DurationChunked); -#[cfg(feature = "dtype-datetime")] -into_ops_impl_wrapped!(DatetimeChunked); -#[cfg(feature = "dtype-struct")] -into_ops_impl_wrapped!(StructChunked); -into_ops_impl_wrapped!(ListChunked); - -#[cfg(feature = "dtype-categorical")] -into_ops_impl_wrapped!(CategoricalChunked); - -#[cfg(feature = "object")] -impl IntoSeriesOps for ObjectChunked { - fn to_ops(&self) -> SeriesOpsRef { - Arc::new(Wrap(self.clone())) - } -} - pub trait AsSeries { fn as_series(&self) -> &Series; } diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index bc233c51a662..c682e27a9661 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -1,26 +1,17 @@ #[cfg(feature = "dtype-date")] use polars_arrow::export::arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY}; use polars_arrow::time_zone::Tz; -use polars_core::chunked_array::ops::arity::try_binary_elementwise_values; +use polars_core::chunked_array::ops::arity::{try_binary_elementwise, try_ternary_elementwise}; use polars_core::prelude::*; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; use crate::prelude::*; -#[derive(Clone, PartialEq, Debug, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct TruncateOptions { - /// Period length - pub every: String, - /// Offset of the window - pub offset: String, -} pub trait PolarsTruncate { fn truncate( &self, - options: &TruncateOptions, tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, ambiguous: &Utf8Chunked, ) -> PolarsResult where @@ -31,13 +22,12 @@ pub trait PolarsTruncate { impl PolarsTruncate for DatetimeChunked { fn truncate( &self, - options: &TruncateOptions, tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, ambiguous: &Utf8Chunked, ) -> PolarsResult { - let every = Duration::parse(&options.every); - let offset = Duration::parse(&options.offset); - let w = Window::new(every, every, offset); + let offset = Duration::parse(offset); let func = match self.time_unit() { TimeUnit::Nanoseconds => Window::truncate_ns, @@ -45,18 +35,65 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Milliseconds => Window::truncate_ms, }; - let out = match ambiguous.len() { - 1 => match ambiguous.get(0) { - Some(ambiguous) => self - .0 - .try_apply(|timestamp| func(&w, timestamp, tz, ambiguous)), - _ => Ok(self.0.apply(|_| None)), + let out = match (every.len(), ambiguous.len()) { + (1, 1) => match (every.get(0), ambiguous.get(0)) { + (Some(every), Some(ambiguous)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + self.0 + .try_apply(|timestamp| func(&w, timestamp, tz, ambiguous)) + }, + _ => Ok(Int64Chunked::full_null(self.name(), self.len())), + }, + (1, _) => { + if let Some(every) = every.get(0) { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + try_binary_elementwise(self, ambiguous, |opt_timestamp, opt_ambiguous| { + match (opt_timestamp, opt_ambiguous) { + (Some(timestamp), Some(ambiguous)) => { + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + } + 
}) + } else { + Ok(Int64Chunked::full_null(self.name(), self.len())) + } }, - _ => { - try_binary_elementwise_values(self, ambiguous, |timestamp: i64, ambiguous: &str| { - func(&w, timestamp, tz, ambiguous) - }) + (_, 1) => { + if let Some(ambiguous) = ambiguous.get(0) { + try_binary_elementwise(self, every, |opt_timestamp, opt_every| { + match (opt_timestamp, opt_every) { + (Some(timestamp), Some(every)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + } + }) + } else { + Ok(Int64Chunked::full_null(self.name(), self.len())) + } }, + _ => try_ternary_elementwise( + self, + every, + ambiguous, + |opt_timestamp, opt_every, opt_ambiguous| match ( + opt_timestamp, + opt_every, + opt_ambiguous, + ) { + (Some(timestamp), Some(every), Some(ambiguous)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + }, + ), }; Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())) } @@ -66,18 +103,42 @@ impl PolarsTruncate for DatetimeChunked { impl PolarsTruncate for DateChunked { fn truncate( &self, - options: &TruncateOptions, _tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, _ambiguous: &Utf8Chunked, ) -> PolarsResult { - let every = Duration::parse(&options.every); - let offset = Duration::parse(&options.offset); - let w = Window::new(every, every, offset); - Ok(self - .try_apply(|t| { - const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; - Ok((w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? / MSECS_IN_DAY) as i32) - })? - .into_date()) + let offset = Duration::parse(offset); + let out = + match every.len() { + 1 => { + if let Some(every) = every.get(0) { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + self.try_apply(|t| { + const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; + Ok((w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? + / MSECS_IN_DAY) as i32) + }) + } else { + Ok(Int32Chunked::full_null(self.name(), self.len())) + } + }, + _ => try_binary_elementwise(&self.0, every, |opt_t, opt_every| { + match (opt_t, opt_every) { + (Some(t), Some(every)) => { + const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + Ok(Some( + (w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? + / MSECS_IN_DAY) as i32, + )) + }, + _ => Ok(None), + } + }), + }; + Ok(out?.into_date()) } } diff --git a/crates/polars-time/src/upsample.rs b/crates/polars-time/src/upsample.rs index bd170a6bc213..b1804661ed79 100644 --- a/crates/polars-time/src/upsample.rs +++ b/crates/polars-time/src/upsample.rs @@ -1,6 +1,5 @@ #[cfg(feature = "timezones")] use chrono_tz::Tz; -use polars_core::frame::hash_join::JoinArgs; use polars_core::prelude::*; use polars_core::utils::ensure_sorted_arg; use polars_ops::prelude::*; @@ -10,7 +9,7 @@ use crate::prelude::*; use crate::utils::unlocalize_timestamp; pub trait PolarsUpsample { - /// Upsample a DataFrame at a regular frequency. + /// Upsample a [`DataFrame`] at a regular frequency. 
/// /// # Arguments /// * `by` - First group by these columns and then upsample for every group @@ -188,7 +187,7 @@ fn upsample_single_impl( TimeUnit::Microseconds => offset.add_us(first, None)?, TimeUnit::Milliseconds => offset.add_ms(first, None)?, }; - let range = date_range_impl( + let range = datetime_range_impl( index_col_name, first, last, diff --git a/crates/polars-time/src/utils.rs b/crates/polars-time/src/utils.rs index 21edd285941f..a5c781ee66c7 100644 --- a/crates/polars-time/src/utils.rs +++ b/crates/polars-time/src/utils.rs @@ -53,7 +53,8 @@ pub(crate) fn localize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> Polars TimeUnit::Nanoseconds => { Ok( localize_datetime(timestamp_ns_to_datetime(timestamp), &tz, "raise")? - .timestamp_nanos(), + .timestamp_nanos_opt() + .unwrap(), ) }, TimeUnit::Microseconds => { @@ -74,9 +75,9 @@ pub(crate) fn localize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> Polars #[cfg(feature = "timezones")] pub(crate) fn unlocalize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> i64 { match tu { - TimeUnit::Nanoseconds => { - unlocalize_datetime(timestamp_ns_to_datetime(timestamp), &tz).timestamp_nanos() - }, + TimeUnit::Nanoseconds => unlocalize_datetime(timestamp_ns_to_datetime(timestamp), &tz) + .timestamp_nanos_opt() + .unwrap(), TimeUnit::Microseconds => { unlocalize_datetime(timestamp_us_to_datetime(timestamp), &tz).timestamp_micros() }, diff --git a/crates/polars-time/src/windows/bounds.rs b/crates/polars-time/src/windows/bounds.rs index c3699be2b278..eba76ac7fb72 100644 --- a/crates/polars-time/src/windows/bounds.rs +++ b/crates/polars-time/src/windows/bounds.rs @@ -7,7 +7,7 @@ pub struct Bounds { } impl Bounds { - /// Create a new `Bounds` and check the input is correct. + /// Create a new [`Bounds`] and check the input is correct. pub(crate) fn new_checked(start: i64, stop: i64) -> Self { assert!( start <= stop, @@ -17,17 +17,19 @@ impl Bounds { Self::new(start, stop) } - /// Create a new `Bounds` without checking input correctness. + /// Create a new [`Bounds`] without checking input correctness. 
pub(crate) fn new(start: i64, stop: i64) -> Self { Bounds { start, stop } } /// Duration in unit for this Boundary + #[inline] pub(crate) fn duration(&self) -> i64 { self.stop - self.start } // check if unit is within bounds + #[inline] pub(crate) fn is_member(&self, t: i64, closed: ClosedWindow) -> bool { match closed { ClosedWindow::Right => t > self.start && t <= self.stop, @@ -37,6 +39,27 @@ impl Bounds { } } + #[inline] + pub(crate) fn is_member_entry(&self, t: i64, closed: ClosedWindow) -> bool { + match closed { + ClosedWindow::Right => t > self.start, + ClosedWindow::Left => t >= self.start, + ClosedWindow::None => t > self.start, + ClosedWindow::Both => t >= self.start, + } + } + + #[inline] + pub(crate) fn is_member_exit(&self, t: i64, closed: ClosedWindow) -> bool { + match closed { + ClosedWindow::Right => t <= self.stop, + ClosedWindow::Left => t < self.stop, + ClosedWindow::None => t < self.stop, + ClosedWindow::Both => t <= self.stop, + } + } + + #[inline] pub(crate) fn is_future(&self, t: i64, closed: ClosedWindow) -> bool { match closed { ClosedWindow::Left | ClosedWindow::None => self.stop <= t, diff --git a/crates/polars-time/src/windows/calendar.rs b/crates/polars-time/src/windows/calendar.rs index 23c3f657577d..00b0a4b95f11 100644 --- a/crates/polars-time/src/windows/calendar.rs +++ b/crates/polars-time/src/windows/calendar.rs @@ -1,8 +1,3 @@ -use polars_arrow::time_zone::Tz; -use polars_core::prelude::*; - -use crate::prelude::*; - const LAST_DAYS_MONTH: [u32; 12] = [ 31, // January: 31, 28, // February: 28, @@ -34,64 +29,3 @@ pub const NS_MINUTE: i64 = 60 * NS_SECOND; pub const NS_HOUR: i64 = 60 * NS_MINUTE; pub const NS_DAY: i64 = 24 * NS_HOUR; pub const NS_WEEK: i64 = 7 * NS_DAY; - -/// vector of i64 representing temporal values -pub fn temporal_range( - start: i64, - stop: i64, - every: Duration, - closed: ClosedWindow, - tu: TimeUnit, - tz: Option<&Tz>, -) -> PolarsResult> { - let size: usize; - let offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult; - - match tu { - TimeUnit::Nanoseconds => { - size = ((stop - start) / every.duration_ns() + 1) as usize; - offset_fn = Duration::add_ns; - }, - TimeUnit::Microseconds => { - size = ((stop - start) / every.duration_us() + 1) as usize; - offset_fn = Duration::add_us; - }, - TimeUnit::Milliseconds => { - size = ((stop - start) / every.duration_ms() + 1) as usize; - offset_fn = Duration::add_ms; - }, - } - let mut ts = Vec::with_capacity(size); - - let mut t = start; - match closed { - ClosedWindow::Both => { - while t <= stop { - ts.push(t); - t = offset_fn(&every, t, tz)? - } - }, - ClosedWindow::Left => { - while t < stop { - ts.push(t); - t = offset_fn(&every, t, tz)? - } - }, - ClosedWindow::Right => { - t = offset_fn(&every, t, tz)?; - while t <= stop { - ts.push(t); - t = offset_fn(&every, t, tz)? - } - }, - ClosedWindow::None => { - t = offset_fn(&every, t, tz)?; - while t < stop { - ts.push(t); - t = offset_fn(&every, t, tz)? - } - }, - } - debug_assert!(size >= ts.len()); - Ok(ts) -} diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index d87032d03804..46ede4872c8d 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -338,6 +338,14 @@ impl Duration { self.days } + /// Returns whether the duration consists of full days. + /// + /// Note that 24 hours is not considered a full day due to possible + /// daylight savings time transitions. 
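// A quick example of the check below (`self.nsecs == 0`): a duration written
// in calendar days keeps its sub-day part at zero, while the same span
// written in hours is stored as nanoseconds, so only the former counts as
// full days. Sketch only, using `Duration::parse` as elsewhere in this file:
//
//     assert!(Duration::parse("1d").is_full_days());
//     assert!(!Duration::parse("24h").is_full_days());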
+ pub fn is_full_days(&self) -> bool { + self.nsecs == 0 + } + pub fn is_constant_duration(&self) -> bool { self.months == 0 && self.weeks == 0 && self.days == 0 } diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 40d6c0e7bea4..52733f78180b 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -5,6 +5,7 @@ use polars_core::prelude::*; use polars_core::utils::_split_offsets; use polars_core::utils::flatten::flatten_par; use polars_core::POOL; +use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -19,6 +20,14 @@ pub enum ClosedWindow { None, } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Label { + Left, + Right, + DataPoint, +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum StartBy { @@ -58,7 +67,7 @@ impl StartBy { #[allow(clippy::too_many_arguments)] fn update_groups_and_bounds( bounds_iter: BoundsIter<'_>, - mut start_offset: usize, + mut start: usize, time: &[i64], closed_window: ClosedWindow, include_lower_bound: bool, @@ -67,38 +76,25 @@ fn update_groups_and_bounds( upper_bound: &mut Vec, groups: &mut Vec<[IdxSize; 2]>, ) { - for bi in bounds_iter { - let mut skip_window = false; + 'bounds: for bi in bounds_iter { // find starting point of window - while start_offset < time.len() { - let t = time[start_offset]; + for &t in &time[start..time.len().saturating_sub(1)] { + // the window is behind the time values. if bi.is_future(t, closed_window) { - // the window is behind the time values. - skip_window = true; - break; + continue 'bounds; } - if bi.is_member(t, closed_window) { + if bi.is_member_entry(t, closed_window) { break; } - start_offset += 1; - } - if skip_window { - start_offset = start_offset.saturating_sub(1); - continue; - } - if start_offset == time.len() { - start_offset = start_offset.saturating_sub(1); + start += 1; } // find members of this window - let mut i = start_offset; - // start next iteration 1 index back because of boundary conditions. - // e.g. "closed left" could match the next iteration, but did not this one. - start_offset = start_offset.saturating_sub(1); - - // last value - if i == time.len() - 1 { - let t = time[i]; + let mut end = start; + + // last value isn't always added + if end == time.len() - 1 { + let t = time[end]; if bi.is_member(t, closed_window) { if include_lower_bound { lower_bound.push(bi.start); @@ -106,21 +102,17 @@ fn update_groups_and_bounds( if include_upper_bound { upper_bound.push(bi.stop); } - groups.push([i as IdxSize, 1]) + groups.push([end as IdxSize, 1]) } continue; } - - let first = i as IdxSize; - - while i < time.len() { - let t = time[i]; - if !bi.is_member(t, closed_window) { + for &t in &time[end..] 
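// Worked example of the entry/exit scan (sketch, using the `Bounds` helpers
// added in bounds.rs above): with time = [0, 1, 2, 3] and a window with
// start = 1, stop = 3 under ClosedWindow::Left,
//
//   - `start` advances while `is_member_entry(t, Left)`, i.e. `t >= 1`,
//     is false, so it stops at index 1;
//   - `end` then advances while `is_member_exit(t, Left)`, i.e. `t < 3`,
//     holds, so it stops at index 3;
//
// which yields the group [offset = 1, len = 2], covering the values 1 and 2.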
{ + if !bi.is_member_exit(t, closed_window) { break; } - i += 1; + end += 1; } - let len = (i as IdxSize) - first; + let len = end - start; if include_lower_bound { lower_bound.push(bi.start); @@ -128,7 +120,7 @@ fn update_groups_and_bounds( if include_upper_bound { upper_bound.push(bi.stop); } - groups.push([first, len]) + groups.push([start as IdxSize, len as IdxSize]) } } @@ -223,7 +215,11 @@ pub fn group_by_windows( (groups, lower_bound, upper_bound) } -// this assumes that the given time point is the right endpoint of the window +// t is right at the end of the window +// ------t--- +// [------] +#[inline] +#[allow(clippy::too_many_arguments)] pub(crate) fn group_by_values_iter_lookbehind( period: Duration, offset: Duration, @@ -232,6 +228,7 @@ pub(crate) fn group_by_values_iter_lookbehind( tu: TimeUnit, tz: Option, start_offset: usize, + upper_bound: Option, ) -> impl Iterator> + TrustedLen + '_ { debug_assert!(offset.duration_ns() == period.duration_ns()); debug_assert!(offset.negative); @@ -241,46 +238,60 @@ pub(crate) fn group_by_values_iter_lookbehind( TimeUnit::Milliseconds => Duration::add_ms, }; - let mut last_lookbehind_i = 0; - time[start_offset..] + let upper_bound = upper_bound.unwrap_or(time.len()); + // Use binary search to find the initial start as that is behind. + let mut start = if let Some(&t) = time.get(start_offset) { + let lower = add(&offset, t, tz.as_ref()).unwrap(); + let upper = add(&period, lower, tz.as_ref()).unwrap(); + let b = Bounds::new(lower, upper); + let slice = &time[..start_offset]; + slice.partition_point(|v| !b.is_member(*v, closed_window)) + } else { + 0 + }; + let mut end = start; + time[start_offset..upper_bound] .iter() .enumerate() - .map(move |(mut i, lower)| { + .map(move |(mut i, t)| { i += start_offset; - let lower = add(&offset, *lower, tz.as_ref())?; + let lower = add(&offset, *t, tz.as_ref())?; let upper = add(&period, lower, tz.as_ref())?; let b = Bounds::new(lower, upper); - // we have a complete lookbehind so we know that `i` is the upper bound. - // Safety - // we are in bounds - let slice = { - #[cfg(debug_assertions)] - { - &time[last_lookbehind_i..i] - } - #[cfg(not(debug_assertions))] - { - unsafe { time.get_unchecked(last_lookbehind_i..i) } + for &t in unsafe { time.get_unchecked_release(start..i) } { + if b.is_member_entry(t, closed_window) { + break; } - }; - let offset = slice.partition_point(|v| !b.is_member(*v, closed_window)); - - let lookbehind_i = offset + last_lookbehind_i; - // -1 for window boundary effects - last_lookbehind_i = lookbehind_i.saturating_sub(1); + start += 1; + } - let mut len = i - lookbehind_i; - if matches!(closed_window, ClosedWindow::Right | ClosedWindow::Both) { - len += 1; + // faster path, check if `i` is member. + if b.is_member_exit(*t, closed_window) { + end = i; + } else { + end = std::cmp::max(end, start); + } + // we still must loop to consume duplicates + for &t in unsafe { time.get_unchecked_release(end..) 
} { + if !b.is_member_exit(t, closed_window) { + break; + } + end += 1; } - Ok((lookbehind_i as IdxSize, len as IdxSize)) + let len = end - start; + let offset = start as IdxSize; + + Ok((offset, len as IdxSize)) }) } // this one is correct for all lookbehind/lookaheads, but is slower +// window is completely behind t and t itself is not a member +// ---------------t--- +// [---] pub(crate) fn group_by_values_iter_window_behind_t( period: Duration, offset: Duration, @@ -295,8 +306,9 @@ pub(crate) fn group_by_values_iter_window_behind_t( TimeUnit::Milliseconds => Duration::add_ms, }; - let mut lagging_offset = 0; - time.iter().enumerate().map(move |(i, lower)| { + let mut start = 0; + let mut end = start; + time.iter().map(move |lower| { let lower = add(&offset, *lower, tz.as_ref())?; let upper = add(&period, lower, tz.as_ref())?; @@ -304,33 +316,32 @@ pub(crate) fn group_by_values_iter_window_behind_t( if b.is_future(time[0], closed_window) { Ok((0, 0)) } else { - // find starting point of window - // we can start searching from lagging offset as that is the minimum boundary because data is sorted - // and every iteration this boundary shifts right - // we cannot use binary search as a window is not binary, - // it is false left from the window, true inside, and false right of the window - let mut count = 0; - for &t in &time[lagging_offset..] { - if b.is_member(t, closed_window) || lagging_offset + count == i { + for &t in &time[start..] { + if b.is_member_entry(t, closed_window) { break; } - count += 1 + start += 1; } - if lagging_offset + count != i { - lagging_offset += count; + + end = std::cmp::max(start, end); + for &t in &time[end..] { + if !b.is_member_exit(t, closed_window) { + break; + } + end += 1; } - // Safety - // we just iterated over value i. - let slice = unsafe { time.get_unchecked(lagging_offset..) }; - let len = slice.partition_point(|v| b.is_member(*v, closed_window)); + let len = end - start; + let offset = start as IdxSize; - Ok((lagging_offset as IdxSize, len as IdxSize)) + Ok((offset, len as IdxSize)) } }) } -// this one is correct for all lookbehind/lookaheads, but is slower +// window is with -1 periods of t +// ----t--- +// [---] pub(crate) fn group_by_values_iter_partial_lookbehind( period: Duration, offset: Duration, @@ -345,68 +356,41 @@ pub(crate) fn group_by_values_iter_partial_lookbehind( TimeUnit::Milliseconds => Duration::add_ms, }; - let mut lagging_offset = 0; + let mut start = 0; + let mut end = start; time.iter().enumerate().map(move |(i, lower)| { let lower = add(&offset, *lower, tz.as_ref())?; let upper = add(&period, lower, tz.as_ref())?; let b = Bounds::new(lower, upper); - for &t in &time[lagging_offset..] { - if b.is_member(t, closed_window) || lagging_offset == i { + for &t in &time[start..] { + if b.is_member_entry(t, closed_window) || start == i { break; } - lagging_offset += 1; + start += 1; } - // Safety - // we just iterated over value i. - let slice = unsafe { time.get_unchecked(lagging_offset..) }; - let len = slice.partition_point(|v| b.is_member(*v, closed_window)); + end = std::cmp::max(start, end); + for &t in &time[end..] 
{ + if !b.is_member_exit(t, closed_window) { + break; + } + end += 1; + } + + let len = end - start; + let offset = start as IdxSize; - Ok((lagging_offset as IdxSize, len as IdxSize)) + Ok((offset, len as IdxSize)) }) } #[allow(clippy::too_many_arguments)] -pub(crate) fn group_by_values_iter_partial_lookahead( - period: Duration, - offset: Duration, - time: &[i64], - closed_window: ClosedWindow, - tu: TimeUnit, - tz: Option, - start_offset: usize, - upper_bound: Option, -) -> impl Iterator> + TrustedLen + '_ { - let upper_bound = upper_bound.unwrap_or(time.len()); - debug_assert!(!offset.negative); - - let add = match tu { - TimeUnit::Nanoseconds => Duration::add_ns, - TimeUnit::Microseconds => Duration::add_us, - TimeUnit::Milliseconds => Duration::add_ms, - }; - - time[start_offset..upper_bound] - .iter() - .enumerate() - .map(move |(mut i, lower)| { - i += start_offset; - let lower = add(&offset, *lower, tz.as_ref())?; - let upper = add(&period, lower, tz.as_ref())?; - - let b = Bounds::new(lower, upper); - - debug_assert!(i < time.len()); - let slice = unsafe { time.get_unchecked(i..) }; - let len = slice.partition_point(|v| b.is_member(*v, closed_window)); - - Ok((i as IdxSize, len as IdxSize)) - }) -} -#[allow(clippy::too_many_arguments)] -pub(crate) fn group_by_values_iter_full_lookahead( +// window is completely ahead of t and t itself is not a member +// --t----------- +// [---] +pub(crate) fn group_by_values_iter_lookahead( period: Duration, offset: Duration, time: &[i64], @@ -417,55 +401,71 @@ pub(crate) fn group_by_values_iter_full_lookahead( upper_bound: Option, ) -> impl Iterator> + TrustedLen + '_ { let upper_bound = upper_bound.unwrap_or(time.len()); - debug_assert!(!offset.negative); let add = match tu { TimeUnit::Nanoseconds => Duration::add_ns, TimeUnit::Microseconds => Duration::add_us, TimeUnit::Milliseconds => Duration::add_ms, }; + let mut start = start_offset; + let mut end = start; - time[start_offset..upper_bound] - .iter() - .enumerate() - .map(move |(mut i, lower)| { - i += start_offset; - let lower = add(&offset, *lower, tz.as_ref())?; - let upper = add(&period, lower, tz.as_ref())?; + time[start_offset..upper_bound].iter().map(move |lower| { + let lower = add(&offset, *lower, tz.as_ref())?; + let upper = add(&period, lower, tz.as_ref())?; - let b = Bounds::new(lower, upper); + let b = Bounds::new(lower, upper); - // find starting point of window - for &t in &time[i..] { - if b.is_member(t, closed_window) { - break; - } - i += 1; + for &t in &time[start..] { + if b.is_member_entry(t, closed_window) { + break; } - if i >= time.len() { - return Ok((i as IdxSize, 0)); + start += 1; + } + + end = std::cmp::max(start, end); + for &t in &time[end..] { + if !b.is_member_exit(t, closed_window) { + break; } + end += 1; + } - let slice = unsafe { time.get_unchecked(i..) 
}; - let len = slice.partition_point(|v| b.is_member(*v, closed_window)); + let len = end - start; + let offset = start as IdxSize; - Ok((i as IdxSize, len as IdxSize)) - }) + Ok((offset, len as IdxSize)) + }) } #[cfg(feature = "rolling_window")] -pub(crate) fn group_by_values_iter<'a>( +#[inline] +pub(crate) fn group_by_values_iter( period: Duration, - time: &'a [i64], + time: &[i64], closed_window: ClosedWindow, tu: TimeUnit, tz: Option, -) -> Box> + 'a> { +) -> impl Iterator> + TrustedLen + '_ { let mut offset = period; offset.negative = true; // t is at the right endpoint of the window - let iter = group_by_values_iter_lookbehind(period, offset, time, closed_window, tu, tz, 0); - Box::new(iter) + group_by_values_iter_lookbehind(period, offset, time, closed_window, tu, tz, 0, None) +} + +/// Checks if the boundary elements don't split on duplicates +fn check_splits(time: &[i64], thread_offsets: &[(usize, usize)]) -> bool { + if time.is_empty() { + return true; + } + let mut valid = true; + for window in thread_offsets.windows(2) { + let left_block_end = window[0].0 + window[0].1; + let right_block_start = window[1].0; + + valid &= time[left_block_end] != time[right_block_start]; + } + valid } /// Different from `group_by_windows`, where define window buckets and search which values fit that @@ -481,7 +481,11 @@ pub fn group_by_values( tu: TimeUnit, tz: Option, ) -> PolarsResult { - let thread_offsets = _split_offsets(time.len(), POOL.current_num_threads()); + let mut thread_offsets = _split_offsets(time.len(), POOL.current_num_threads()); + // there are duplicates in the splits, so we opt for a single partition + if !check_splits(time, &thread_offsets) { + thread_offsets = _split_offsets(time.len(), 1) + } // we have a (partial) lookbehind window if offset.negative { @@ -499,11 +503,12 @@ pub fn group_by_values( let iter = group_by_values_iter_lookbehind( period, offset, - &time[..upper_bound], + time, closed_window, tu, tz, base_offset, + Some(upper_bound), ); iter.map(|result| result.map(|(offset, len)| [offset, len])) .collect::>>() @@ -556,7 +561,7 @@ pub fn group_by_values( .map(|(base_offset, len)| { let lower_bound = base_offset; let upper_bound = base_offset + len; - let iter = group_by_values_iter_full_lookahead( + let iter = group_by_values_iter_lookahead( period, offset, time, @@ -584,7 +589,7 @@ pub fn group_by_values( .map(|(base_offset, len)| { let lower_bound = base_offset; let upper_bound = base_offset + len; - let iter = group_by_values_iter_partial_lookahead( + let iter = group_by_values_iter_lookahead( period, offset, time, diff --git a/crates/polars-time/src/windows/test.rs b/crates/polars-time/src/windows/test.rs index 652746fbfa93..cc7a775352cb 100644 --- a/crates/polars-time/src/windows/test.rs +++ b/crates/polars-time/src/windows/test.rs @@ -2,6 +2,7 @@ use chrono::prelude::*; use polars_arrow::export::arrow::temporal_conversions::timestamp_ns_to_datetime; use polars_core::prelude::*; +use crate::date_range::datetime_range_i64; use crate::prelude::*; #[test] @@ -15,9 +16,9 @@ fn test_date_range() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap(); - let dates = temporal_range_vec( - start.timestamp_nanos(), - end.timestamp_nanos(), + let dates = datetime_range_i64( + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -31,7 +32,12 @@ fn test_date_range() { NaiveDate::from_ymd_opt(2022, 4, 1).unwrap(), ] .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + 
.map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); assert_eq!(dates, expected); } @@ -46,9 +52,9 @@ fn test_feb_date_range() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap(); - let dates = temporal_range_vec( - start.timestamp_nanos(), - end.timestamp_nanos(), + let dates = datetime_range_i64( + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -60,7 +66,12 @@ fn test_feb_date_range() { NaiveDate::from_ymd_opt(2022, 3, 1).unwrap(), ] .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); assert_eq!(dates, expected); } @@ -88,7 +99,12 @@ fn test_groups_large_interval() { ]; let ts = dates .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); let dur = Duration::parse("2d"); @@ -140,7 +156,8 @@ fn test_offset() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); let w = Window::new( Duration::parse("5m"), Duration::parse("5m"), @@ -152,7 +169,8 @@ fn test_offset() { .unwrap() .and_hms_opt(23, 58, 0) .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); assert_eq!(b.start, start); } @@ -167,9 +185,9 @@ fn test_boundaries() { .and_hms_opt(3, 0, 0) .unwrap(); - let ts = temporal_range_vec( - start.timestamp_nanos(), - stop.timestamp_nanos(), + let ts = datetime_range_i64( + start.timestamp_nanos_opt().unwrap(), + stop.timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -188,7 +206,7 @@ fn test_boundaries() { // earliest bound is first datapoint: 2021-12-16 00:00:00 let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_nanos()); + assert_eq!(b.start, start.timestamp_nanos_opt().unwrap()); // test closed: "both" (includes both ends of the interval) let (groups, lower, higher) = group_by_windows( @@ -225,9 +243,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos(), - t1.timestamp_nanos(), - t2.timestamp_nanos() + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap(), + t2.timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -240,7 +258,10 @@ fn test_boundaries() { .unwrap(); assert_eq!( &[lower[0], higher[0]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); // 2nd group @@ -266,9 +287,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos(), - t1.timestamp_nanos(), - t2.timestamp_nanos() + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap(), + t2.timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -281,7 +302,10 @@ fn test_boundaries() { .unwrap(); assert_eq!( &[lower[1], higher[1]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); assert_eq!(groups[2], [4, 3]); @@ -343,9 +367,9 @@ fn test_boundaries_2() { .and_hms_opt(4, 0, 0) .unwrap(); - let ts = temporal_range_vec( - start.timestamp_nanos(), - stop.timestamp_nanos(), + let ts = datetime_range_i64( + start.timestamp_nanos_opt().unwrap(), + 
stop.timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -365,7 +389,10 @@ fn test_boundaries_2() { // earliest bound is first datapoint: 2021-12-16 00:00:00 + 30m offset: 2021-12-16 00:30:00 let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_nanos() + offset.duration_ns()); + assert_eq!( + b.start, + start.timestamp_nanos_opt().unwrap() + offset.duration_ns() + ); let (groups, lower, higher) = group_by_windows( w, @@ -395,7 +422,13 @@ fn test_boundaries_2() { .unwrap() .and_hms_opt(1, 0, 0) .unwrap(); - assert_eq!(g, &[t0.timestamp_nanos(), t1.timestamp_nanos()]); + assert_eq!( + g, + &[ + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap() + ] + ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(0, 30, 0) @@ -406,7 +439,10 @@ fn test_boundaries_2() { .unwrap(); assert_eq!( &[lower[0], higher[0]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); // 2nd group @@ -425,7 +461,13 @@ fn test_boundaries_2() { .unwrap() .and_hms_opt(3, 0, 0) .unwrap(); - assert_eq!(g, &[t0.timestamp_nanos(), t1.timestamp_nanos()]); + assert_eq!( + g, + &[ + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap() + ] + ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(2, 30, 0) @@ -436,7 +478,10 @@ fn test_boundaries_2() { .unwrap(); assert_eq!( &[lower[1], higher[1]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); } @@ -451,7 +496,7 @@ fn test_boundaries_ms() { .and_hms_opt(3, 0, 0) .unwrap(); - let ts = temporal_range_vec( + let ts = datetime_range_i64( start.timestamp_millis(), stop.timestamp_millis(), Duration::parse("30m"), @@ -627,7 +672,7 @@ fn test_rolling_lookback() { .unwrap() .and_hms_opt(4, 0, 0) .unwrap(); - let dates = temporal_range_vec( + let dates = datetime_range_i64( start.timestamp_millis(), end.timestamp_millis(), Duration::parse("30m"), @@ -709,10 +754,18 @@ fn test_rolling_lookback() { ClosedWindow::None, ] { let offset = Duration::parse("-2h"); - let g0 = - group_by_values_iter_lookbehind(period, offset, &dates, closed_window, tu, None, 0) - .collect::>>() - .unwrap(); + let g0 = group_by_values_iter_lookbehind( + period, + offset, + &dates, + closed_window, + tu, + None, + 0, + None, + ) + .collect::>>() + .unwrap(); let g1 = group_by_values_iter_partial_lookbehind( period, offset, @@ -822,7 +875,12 @@ fn test_group_by_windows_offsets_3776() { ]; let ts = dates .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); let window = Window::new( diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 48e5fc3bcf69..ebfeabb2d6df 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -61,23 +61,23 @@ impl Window { } /// Round the given ns timestamp by the window boundary. - pub fn round_ns(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn round_ns(&self, t: i64, tz: Option<&Tz>, ambiguous: &str) -> PolarsResult { let t = t + self.every.duration_ns() / 2_i64; - self.truncate_ns(t, tz, "raise") + self.truncate_ns(t, tz, ambiguous) } /// Round the given us timestamp by the window boundary. 
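// Worked example of the rounding rule shared by these methods (add half of
// `every`, then truncate to the window): with every = "30m", a timestamp at
// 00:20 becomes 00:35 after adding 15m and truncates to 00:30, while 00:10
// becomes 00:25 and truncates to 00:00. In other words, each value rounds to
// the nearest window boundary, and the `ambiguous` argument is simply
// forwarded to the corresponding truncate method.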
- pub fn round_us(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn round_us(&self, t: i64, tz: Option<&Tz>, ambiguous: &str) -> PolarsResult { let t = t + self.every.duration_ns() / (2 * timeunit_scale(ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Microsecond) as i64); - self.truncate_us(t, tz, "raise") + self.truncate_us(t, tz, ambiguous) } /// Round the given ms timestamp by the window boundary. - pub fn round_ms(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn round_ms(&self, t: i64, tz: Option<&Tz>, ambiguous: &str) -> PolarsResult { let t = t + self.every.duration_ns() / (2 * timeunit_scale(ArrowTimeUnit::Nanosecond, ArrowTimeUnit::Millisecond) as i64); - self.truncate_ms(t, tz, "raise") + self.truncate_ms(t, tz, ambiguous) } /// returns the bounds for the earliest window bounds diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index f48deb4d49d2..c017f731739c 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -9,9 +9,10 @@ repository = { workspace = true } description = "Private utils for the Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } +polars-error = { workspace = true } ahash = { workspace = true } +bytemuck = { workspace = true } hashbrown = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } diff --git a/crates/polars-utils/README.md b/crates/polars-utils/README.md index f0994cc71971..2f200a67f1a6 100644 --- a/crates/polars-utils/README.md +++ b/crates/polars-utils/README.md @@ -1,5 +1,5 @@ # polars-utils -`polars-utils` is a sub-crate that provides private utils for the Polars dataframe library. +`polars-utils` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, supplying private utility functions. -Not intended for external usage +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-utils/src/cache.rs b/crates/polars-utils/src/cache.rs new file mode 100644 index 000000000000..d16982ee39f2 --- /dev/null +++ b/crates/polars-utils/src/cache.rs @@ -0,0 +1,266 @@ +use std::borrow::Borrow; +use std::cell::Cell; +use std::hash::Hash; +use std::mem::MaybeUninit; + +use ahash::RandomState; +use bytemuck::allocation::zeroed_vec; +use bytemuck::Zeroable; + +use crate::aliases::PlHashMap; + +pub struct CachedFunc { + func: F, + cache: PlHashMap, +} + +impl CachedFunc +where + F: FnMut(T) -> R, + T: std::hash::Hash + Eq + Clone, + R: Copy, +{ + pub fn new(func: F) -> Self { + Self { + func, + cache: PlHashMap::with_capacity_and_hasher(0, Default::default()), + } + } + + pub fn eval(&mut self, x: T, use_cache: bool) -> R { + if use_cache { + *self + .cache + .entry(x) + .or_insert_with_key(|xr| (self.func)(xr.clone())) + } else { + (self.func)(x) + } + } +} + +/// A fixed-size cache optimized for access speed. Does not implement LRU or use +/// a full hash table due to cost, instead we assign two pseudorandom slots +/// based on the hash of the key, and if both are full we evict the one that had +/// the older last access. 
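// Illustrative usage sketch for the cache defined below (assumes only the
// public methods shown in this file; hypothetical values):
//
//     let mut cache: FastFixedCache<String, usize> = FastFixedCache::new(16);
//     cache.insert("key".to_string(), 1);
//     assert_eq!(cache.get("key"), Some(&1));
//     // On a miss, compute and insert in one step:
//     let v = *cache.get_or_insert_with("other", |k| k.len());
//
// Each key is hashed once; the top bits of two multiply-shift products select
// the two candidate slots, and an insert overwrites whichever of the two was
// accessed longer ago (empty slots count as infinitely old).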
+const MIN_FAST_FIXED_CACHE_SIZE: usize = 16; + +#[derive(Clone)] +pub struct FastFixedCache { + slots: Vec>, + access_ctr: Cell, + shift: u32, + hash_builder: RandomState, +} + +impl Default for FastFixedCache { + fn default() -> Self { + Self::new(MIN_FAST_FIXED_CACHE_SIZE) + } +} + +impl FastFixedCache { + pub fn new(n: usize) -> Self { + let n = (n.max(MIN_FAST_FIXED_CACHE_SIZE)).next_power_of_two(); + Self { + slots: zeroed_vec(n), + access_ctr: Cell::new(1), + shift: 64 - n.ilog2(), + hash_builder: RandomState::new(), + } + } + + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + { + unsafe { + // SAFETY: slot_idx from raw_get is valid and occupied. + let slot_idx = self.raw_get(self.hash(key), key)?; + let slot = self.slots.get_unchecked(slot_idx); + Some(slot.value.assume_init_ref()) + } + } + + pub fn get_mut(&mut self, key: &Q) -> Option<&mut V> + where + K: Borrow, + { + unsafe { + // SAFETY: slot_idx from raw_get is valid and occupied. + let slot_idx = self.raw_get(self.hash(&key), key)?; + let slot = self.slots.get_unchecked_mut(slot_idx); + Some(slot.value.assume_init_mut()) + } + } + + pub fn insert(&mut self, key: K, value: V) -> &mut V { + unsafe { self.raw_insert(self.hash(&key), key, value) } + } + + pub fn get_or_insert_with(&mut self, key: &Q, f: F) -> &mut V + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + F: FnOnce(&K) -> V, + { + unsafe { + let h = self.hash(key); + if let Some(slot_idx) = self.raw_get(self.hash(&key), key) { + let slot = self.slots.get_unchecked_mut(slot_idx); + return slot.value.assume_init_mut(); + } + + let key = key.to_owned(); + let val = f(&key); + self.raw_insert(h, key, val) + } + } + + pub fn try_get_or_insert_with(&mut self, key: &Q, f: F) -> Result<&mut V, E> + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + F: FnOnce(&K) -> Result, + { + unsafe { + let h = self.hash(key); + if let Some(slot_idx) = self.raw_get(self.hash(&key), key) { + let slot = self.slots.get_unchecked_mut(slot_idx); + return Ok(slot.value.assume_init_mut()); + } + + let key = key.to_owned(); + let val = f(&key)?; + Ok(self.raw_insert(h, key, val)) + } + } + + unsafe fn raw_get(&self, h: HashResult, key: &Q) -> Option + where + K: Borrow, + { + unsafe { + // SAFETY: we assume h is a HashResult from self.hash with valid indices + // and we check slot.last_access != 0 before assuming the slot is initialized. + let slot = self.slots.get_unchecked(h.i1); + if slot.last_access.get() != 0 + && slot.hash_tag == h.tag + && slot.key.assume_init_ref().borrow() == key + { + slot.last_access.set(self.new_access_ctr()); + return Some(h.i1); + } + + let slot = self.slots.get_unchecked(h.i2); + if slot.last_access.get() != 0 + && slot.hash_tag == h.tag + && slot.key.assume_init_ref().borrow() == key + { + slot.last_access.set(self.new_access_ctr()); + return Some(h.i2); + } + } + + None + } + + unsafe fn raw_insert(&mut self, h: HashResult, key: K, value: V) -> &mut V { + let last_access = self.new_access_ctr(); + unsafe { + // SAFETY: i1 and i2 are valid indices and older_idx returns one of them. + let idx = self.older_idx(h.i1, h.i2); + let slot = self.slots.get_unchecked_mut(idx); + + // Drop impl takes care of dropping old value, if occupied. + *slot = CacheSlot { + last_access: Cell::new(last_access), + hash_tag: h.tag, + key: MaybeUninit::new(key), + value: MaybeUninit::new(value), + }; + slot.value.assume_init_mut() + } + } + + /// Returns the older index based on access time, where unoccupied slots + /// are considered infinitely old. 
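// Worked example of the wrap-around comparison below (u32 access counter,
// illustrative values): suppose the counter has wrapped, with
// age2 = 4_294_967_291 written before the wrap and age1 = 7 written after it.
// Then age1.wrapping_sub(age2) == 12, which is less than 1 << 31, so i2 is
// returned: that slot really is the older of the two, even though its raw
// counter value is numerically larger.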
+ unsafe fn older_idx(&mut self, i1: usize, i2: usize) -> usize { + let age1 = self.slots.get_unchecked(i1).last_access.get(); + let age2 = self.slots.get_unchecked(i2).last_access.get(); + match (age1, age2) { + (0, _) => i1, + (_, 0) => i2, + // This takes into account the wrap-around of our access_ctr. + // We assume that the smaller value between age1.wrapping_sub(age2) + // and age2.wrapping_sub(age1) is the true delta. Thus if + // age1.wrapping_sub(age2) is >= 1 << 31, we know that + // age2.wrapping_sub(age1) is smaller than it, and we also + // immediately know that age1 is older. + _ if age1.wrapping_sub(age2) >= (1 << 31) => i1, + _ => i2, + } + } + + fn new_access_ctr(&self) -> u32 { + // This keeps the access_ctr always odd, so we don't hit access_ctr == 0, + // which would leak values. + self.access_ctr.replace(self.access_ctr.get() + 2) + } + + /// Computes the hash tag and two slot indexes for a given key. + fn hash(&self, key: &Q) -> HashResult { + // An instantiation of Dietzfelbinger's multiply-shift, see 2.3 of + // https://arxiv.org/pdf/1504.06804.pdf. + // The magic constants are just two randomly chosen odd 64-bit numbers. + let h = self.hash_builder.hash_one(key); + let tag = h as u32; + let i1 = (h.wrapping_mul(0x2e623b55bc0c9073) >> self.shift) as usize; + let i2 = (h.wrapping_mul(0x921932b06a233d39) >> self.shift) as usize; + HashResult { tag, i1, i2 } + } +} + +struct HashResult { + tag: u32, + i1: usize, + i2: usize, +} + +struct CacheSlot { + // If last_access != 0, the rest is assumed to be initialized. + last_access: Cell, + hash_tag: u32, + key: MaybeUninit, + value: MaybeUninit, +} + +unsafe impl Zeroable for CacheSlot {} + +impl Drop for CacheSlot { + fn drop(&mut self) { + unsafe { + if self.last_access.get() != 0 { + self.key.assume_init_drop(); + self.value.assume_init_drop(); + } + } + } +} + +impl Clone for CacheSlot { + fn clone(&self) -> Self { + unsafe { + if self.last_access.get() != 0 { + Self { + last_access: self.last_access.clone(), + hash_tag: self.hash_tag, + key: MaybeUninit::new(self.key.assume_init_ref().clone()), + value: MaybeUninit::new(self.value.assume_init_ref().clone()), + } + } else { + Self::zeroed() + } + } + } +} diff --git a/crates/polars-utils/src/cell.rs b/crates/polars-utils/src/cell.rs index f9b8907be148..ae6b6ae461fc 100644 --- a/crates/polars-utils/src/cell.rs +++ b/crates/polars-utils/src/cell.rs @@ -4,11 +4,11 @@ use std::cell::UnsafeCell; /// [`UnsafeCell`], but [`Sync`]. /// -/// This is just an `UnsafeCell`, except it implements `Sync` -/// if `T` implements `Sync`. +/// This is just an [`UnsafeCell`], except it implements [`Sync`] +/// if `T` implements [`Sync`]. /// -/// `UnsafeCell` doesn't implement `Sync`, to prevent accidental misuse. -/// You can use `SyncUnsafeCell` instead of `UnsafeCell` to allow it to be +/// [`UnsafeCell`] doesn't implement [`Sync`], to prevent accidental misuse. +/// You can use [`SyncUnsafeCell`] instead of [`UnsafeCell`] to allow it to be /// shared between threads, if that's intentional. /// Providing proper synchronization is still the task of the user, /// making this type just as unsafe to use. @@ -22,7 +22,7 @@ pub struct SyncUnsafeCell { unsafe impl Sync for SyncUnsafeCell {} impl SyncUnsafeCell { - /// Constructs a new instance of `SyncUnsafeCell` which will wrap the specified value. + /// Constructs a new instance of [`SyncUnsafeCell`] which will wrap the specified value. 
#[inline] pub fn new(value: T) -> Self { Self { @@ -51,7 +51,7 @@ impl SyncUnsafeCell { /// Returns a mutable reference to the underlying data. /// - /// This call borrows the `SyncUnsafeCell` mutably (at compile-time) which + /// This call borrows the [`SyncUnsafeCell`] mutably (at compile-time) which /// guarantees that we possess the only reference. #[inline] pub fn get_mut(&mut self) -> &mut T { @@ -78,7 +78,7 @@ impl Default for SyncUnsafeCell { } impl From for SyncUnsafeCell { - /// Creates a new `SyncUnsafeCell` containing the given value. + /// Creates a new [`SyncUnsafeCell`] containing the given value. fn from(t: T) -> SyncUnsafeCell { SyncUnsafeCell::new(t) } diff --git a/crates/polars-utils/src/functions.rs b/crates/polars-utils/src/functions.rs index 88cc128e6a6a..47bece31d73d 100644 --- a/crates/polars-utils/src/functions.rs +++ b/crates/polars-utils/src/functions.rs @@ -1,4 +1,4 @@ -use std::hash::{BuildHasher, Hash, Hasher}; +use std::hash::{BuildHasher, Hash}; // Faster than collecting from a flattened iterator. pub fn flatten>(bufs: &[R], len: Option) -> Vec { @@ -20,7 +20,5 @@ pub fn hash_to_partition(h: u64, n_partitions: usize) -> usize { #[inline] pub fn get_hash(value: T, hb: &B) -> u64 { - let mut hasher = hb.build_hasher(); - value.hash(&mut hasher); - hasher.finish() + hb.hash_one(value) } diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs new file mode 100644 index 000000000000..52b820016db5 --- /dev/null +++ b/crates/polars-utils/src/index.rs @@ -0,0 +1,20 @@ +use polars_error::{polars_bail, polars_ensure, PolarsResult}; + +use crate::IdxSize; + +pub fn check_bounds(idx: &[IdxSize], len: IdxSize) -> PolarsResult<()> { + // We iterate in large uninterrupted chunks to help auto-vectorization. + let mut in_bounds = true; + for chunk in idx.chunks(1024) { + for i in chunk { + if *i >= len { + in_bounds = false; + } + } + if !in_bounds { + break; + } + } + polars_ensure!(in_bounds, ComputeError: "indices are out of bounds"); + Ok(()) +} diff --git a/crates/polars-utils/src/iter/enumerate_idx.rs b/crates/polars-utils/src/iter/enumerate_idx.rs index 997f9333a5ce..8b17b7ef4038 100644 --- a/crates/polars-utils/src/iter/enumerate_idx.rs +++ b/crates/polars-utils/src/iter/enumerate_idx.rs @@ -1,3 +1,5 @@ +use num_traits::{FromPrimitive, One}; + use crate::IdxSize; /// An iterator that yields the current count and the element during iteration. @@ -9,16 +11,17 @@ use crate::IdxSize; /// [`Iterator`]: trait.Iterator.html #[derive(Clone, Debug)] #[must_use = "iterators are lazy and do nothing unless consumed"] -pub struct EnumerateIdx { +pub struct EnumerateIdx { iter: I, - count: IdxSize, + count: IdxType, } -impl Iterator for EnumerateIdx +impl Iterator for EnumerateIdx where I: Iterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, { - type Item = (IdxSize, ::Item); + type Item = (IdxType, ::Item); /// # Overflow Behavior /// @@ -30,10 +33,10 @@ where /// /// Might panic if the index of the element overflows a `idx`. 
#[inline] - fn next(&mut self) -> Option<(IdxSize, ::Item)> { + fn next(&mut self) -> Option { let a = self.iter.next()?; let i = self.count; - self.count += 1; + self.count += IdxType::one(); Some((i, a)) } @@ -43,10 +46,10 @@ where } #[inline] - fn nth(&mut self, n: usize) -> Option<(IdxSize, I::Item)> { + fn nth(&mut self, n: usize) -> Option { let a = self.iter.nth(n)?; - let i = self.count + (n as IdxSize); - self.count = i + 1; + let i = self.count + IdxType::from_usize(n).unwrap(); + self.count = i + IdxType::one(); Some((i, a)) } @@ -56,32 +59,34 @@ where } } -impl DoubleEndedIterator for EnumerateIdx +impl DoubleEndedIterator for EnumerateIdx where I: ExactSizeIterator + DoubleEndedIterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, { #[inline] - fn next_back(&mut self) -> Option<(IdxSize, ::Item)> { + fn next_back(&mut self) -> Option<(IdxType, ::Item)> { let a = self.iter.next_back()?; - let len = self.iter.len(); + let len = IdxType::from_usize(self.iter.len()).unwrap(); // Can safely add, `ExactSizeIterator` promises that the number of // elements fits into a `usize`. - Some((self.count + len as IdxSize, a)) + Some((self.count + len, a)) } #[inline] - fn nth_back(&mut self, n: usize) -> Option<(IdxSize, ::Item)> { + fn nth_back(&mut self, n: usize) -> Option<(IdxType, ::Item)> { let a = self.iter.nth_back(n)?; - let len = self.iter.len(); + let len = IdxType::from_usize(self.iter.len()).unwrap(); // Can safely add, `ExactSizeIterator` promises that the number of // elements fits into a `usize`. - Some((self.count + len as IdxSize, a)) + Some((self.count + len, a)) } } -impl ExactSizeIterator for EnumerateIdx +impl ExactSizeIterator for EnumerateIdx where I: ExactSizeIterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, { fn len(&self) -> usize { self.iter.len() @@ -89,7 +94,17 @@ where } pub trait EnumerateIdxTrait: Iterator { - fn enumerate_idx(self) -> EnumerateIdx + fn enumerate_idx(self) -> EnumerateIdx + where + Self: Sized, + { + EnumerateIdx { + iter: self, + count: 0, + } + } + + fn enumerate_u32(self) -> EnumerateIdx where Self: Sized, { diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 7a4941c4a09f..e5fceb3066bb 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] pub mod arena; pub mod atomic; +pub mod cache; pub mod cell; pub mod contention_pool; mod error; @@ -28,5 +29,7 @@ pub mod vec; #[cfg(target_family = "wasm")] pub mod wasm; +pub mod index; pub mod io; + pub use io::open_file; diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 986190b718d2..19de558f8177 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -11,13 +11,13 @@ repository = { workspace = true } description = "DataFrame library based on Apache Arrow" [dependencies] -polars-algo = { version = "0.32.0", path = "../polars-algo", optional = true } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["docs"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = [], default-features = false, optional = true } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = [], default-features = false, optional = true } -polars-ops = { version = "0.32.0", path = "../polars-ops" } -polars-sql = { version = "0.32.0", path = "../polars-sql", default-features = false, optional = true } -polars-time = { version = 
"0.32.0", path = "../polars-time", default-features = false, optional = true } +polars-algo = { workspace = true, optional = true } +polars-core = { workspace = true } +polars-io = { workspace = true, optional = true } +polars-lazy = { workspace = true, default-features = false, optional = true } +polars-ops = { workspace = true } +polars-sql = { workspace = true, optional = true } +polars-time = { workspace = true, optional = true } [dev-dependencies] ahash = { workspace = true } @@ -27,19 +27,18 @@ rand = { workspace = true } version_check = { workspace = true } # enable js feature for getrandom to work in wasm -[target.'cfg(target_family = "wasm")'.dependencies.getrandom] -version = "0.2" -features = ["js"] +[target.'cfg(target_family = "wasm")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } [features] sql = ["polars-sql"] rows = ["polars-core/rows"] simd = ["polars-core/simd", "polars-io/simd", "polars-ops/simd"] avx512 = ["polars-core/avx512"] -nightly = ["polars-core/nightly", "polars-ops/nightly", "simd", "polars-lazy/nightly"] +nightly = ["polars-core/nightly", "polars-ops/nightly", "simd", "polars-lazy?/nightly"] docs = ["polars-core/docs"] -temporal = ["polars-core/temporal", "polars-lazy/temporal", "polars-io/temporal", "polars-time"] -random = ["polars-core/random", "polars-lazy/random"] +temporal = ["polars-core/temporal", "polars-lazy?/temporal", "polars-io/temporal", "polars-time"] +random = ["polars-core/random", "polars-lazy?/random", "polars-ops/random"] default = [ "docs", "zip_with", @@ -51,36 +50,44 @@ default = [ ndarray = ["polars-core/ndarray"] # serde support for dataframes and series serde = ["polars-core/serde"] -serde-lazy = ["polars-core/serde-lazy", "polars-lazy/serde", "polars-time/serde", "polars-io/serde", "polars-ops/serde"] -parquet = ["polars-io", "polars-core/parquet", "polars-lazy/parquet", "polars-io/parquet", "polars-sql/parquet"] -async = ["polars-lazy/async"] -aws = ["async", "polars-io/aws"] -azure = ["async", "polars-io/azure"] -gcp = ["async", "polars-io/gcp"] -lazy = ["polars-core/lazy", "polars-lazy", "polars-lazy/compile"] +serde-lazy = [ + "polars-core/serde-lazy", + "polars-lazy?/serde", + "polars-time?/serde", + "polars-io/serde", + "polars-ops/serde", +] +parquet = ["polars-io", "polars-core/parquet", "polars-lazy?/parquet", "polars-io/parquet", "polars-sql?/parquet"] +async = ["polars-lazy?/async"] +cloud = ["polars-lazy?/cloud", "polars-io/cloud"] +cloud_write = ["cloud", "polars-lazy?/cloud_write"] +aws = ["async", "cloud", "polars-io/aws"] +azure = ["async", "cloud", "polars-io/azure"] +gcp = ["async", "cloud", "polars-io/gcp"] +lazy = ["polars-core/lazy", "polars-lazy"] # commented out until UB is fixed # parallel = ["polars-core/parallel"] # extra utilities for Utf8Chunked -strings = ["polars-core/strings", "polars-lazy/strings", "polars-ops/strings"] +strings = ["polars-core/strings", "polars-lazy?/strings", "polars-ops/strings"] # support for ObjectChunked (downcastable Series of any type) -object = ["polars-core/object", "polars-lazy/object", "polars-io/object"] +object = ["polars-core/object", "polars-lazy?/object", "polars-io/object"] # support for arrows json parsing -json = ["polars-io", "polars-io/json", "polars-lazy/json", "polars-sql/json", "dtype-struct"] +json = ["polars-io", "polars-io/json", "polars-lazy?/json", "polars-sql?/json", "dtype-struct"] # support for arrows ipc file parsing -ipc = ["polars-io", "polars-io/ipc", "polars-lazy/ipc", "polars-sql/ipc"] +ipc = ["polars-io", "polars-io/ipc", 
"polars-lazy?/ipc", "polars-sql?/ipc"] # support for arrows streaming ipc file parsing -ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy/ipc"] +ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy?/ipc"] # support for apache avro file parsing avro = ["polars-io", "polars-io/avro"] # support for arrows csv file parsing -csv = ["polars-io", "polars-io/csv", "polars-lazy/csv", "polars-sql/csv"] +csv = ["polars-io", "polars-io/csv", "polars-lazy?/csv", "polars-sql?/csv"] # slower builds performant = [ @@ -103,88 +110,90 @@ fmt_no_tty = ["polars-core/fmt_no_tty"] sort_multiple = ["polars-core/sort_multiple"] # extra operations -approx_unique = ["polars-lazy/approx_unique", "polars-ops/approx_unique"] -is_in = ["polars-lazy/is_in"] -zip_with = ["polars-core/zip_with"] -round_series = ["polars-core/round_series", "polars-lazy/round_series", "polars-ops/round_series"] +approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"] +is_in = ["polars-lazy?/is_in"] +zip_with = ["polars-core/zip_with", "polars-ops/zip_with"] +round_series = ["polars-core/round_series", "polars-lazy?/round_series", "polars-ops/round_series"] checked_arithmetic = ["polars-core/checked_arithmetic"] -repeat_by = ["polars-core/repeat_by", "polars-lazy/repeat_by"] -is_first = ["polars-lazy/is_first", "polars-ops/is_first"] -is_unique = ["polars-lazy/is_unique", "polars-ops/is_unique"] -is_last = ["polars-core/is_last"] -asof_join = ["polars-core/asof_join", "polars-lazy/asof_join", "polars-ops/asof_join"] -cross_join = ["polars-core/cross_join", "polars-lazy/cross_join", "polars-ops/cross_join"] +repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"] +is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"] +is_last_distinct = ["polars-lazy?/is_last_distinct", "polars-ops/is_last_distinct"] +is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"] +asof_join = ["polars-core/asof_join", "polars-lazy?/asof_join", "polars-ops/asof_join"] +cross_join = ["polars-lazy?/cross_join", "polars-ops/cross_join"] dot_product = ["polars-core/dot_product"] -concat_str = ["polars-core/concat_str", "polars-lazy/concat_str"] -row_hash = ["polars-core/row_hash", "polars-lazy/row_hash"] +concat_str = ["polars-lazy?/concat_str"] +row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] reinterpret = ["polars-core/reinterpret"] decompress = ["polars-io/decompress"] decompress-fast = ["polars-io/decompress-fast"] -mode = ["polars-core/mode", "polars-lazy/mode"] +mode = ["polars-ops/mode", "polars-lazy?/mode"] take_opt_iter = ["polars-core/take_opt_iter"] extract_jsonpath = [ "polars-core/strings", "polars-ops/extract_jsonpath", "polars-ops/strings", - "polars-lazy/extract_jsonpath", + "polars-lazy?/extract_jsonpath", ] string_encoding = ["polars-ops/string_encoding", "polars-core/strings"] binary_encoding = ["polars-ops/binary_encoding"] group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"] -lazy_regex = ["polars-lazy/regex"] +lazy_regex = ["polars-lazy?/regex"] cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"] -rolling_window = ["polars-core/rolling_window", "polars-lazy/rolling_window", "polars-time/rolling_window"] -interpolate = ["polars-ops/interpolate", "polars-lazy/interpolate"] -rank = ["polars-core/rank", "polars-lazy/rank"] -diff = ["polars-core/diff", "polars-lazy/diff", "polars-ops/diff"] -pct_change = ["polars-core/pct_change", "polars-lazy/pct_change"] -moment = ["polars-core/moment", "polars-lazy/moment", 
"polars-ops/moment"] -range = ["polars-lazy/range"] -true_div = ["polars-lazy/true_div"] -diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy/diagonal_concat"] +rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] +interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] +rank = ["polars-lazy?/rank", "polars-ops/rank"] +diff = ["polars-core/diff", "polars-lazy?/diff", "polars-ops/diff"] +pct_change = ["polars-core/pct_change", "polars-lazy?/pct_change"] +moment = ["polars-core/moment", "polars-lazy?/moment", "polars-ops/moment"] +range = ["polars-lazy?/range"] +true_div = ["polars-lazy?/true_div"] +diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat", "polars-sql?/diagonal_concat"] horizontal_concat = ["polars-core/horizontal_concat"] -abs = ["polars-core/abs", "polars-lazy/abs"] -dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy/dynamic_group_by"] -ewma = ["polars-core/ewma", "polars-lazy/ewma"] -dot_diagram = ["polars-lazy/dot_diagram"] +abs = ["polars-core/abs", "polars-lazy?/abs"] +dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] +ewma = ["polars-core/ewma", "polars-lazy?/ewma"] +dot_diagram = ["polars-lazy?/dot_diagram"] dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] product = ["polars-core/product"] -unique_counts = ["polars-core/unique_counts", "polars-lazy/unique_counts"] -log = ["polars-ops/log", "polars-lazy/log"] +unique_counts = ["polars-core/unique_counts", "polars-lazy?/unique_counts"] +log = ["polars-ops/log", "polars-lazy?/log"] partition_by = ["polars-core/partition_by"] -semi_anti_join = ["polars-core/semi_anti_join", "polars-lazy/semi_anti_join", "polars-ops/semi_anti_join"] -list_eval = ["polars-lazy/list_eval"] -cumulative_eval = ["polars-lazy/cumulative_eval"] -chunked_ids = ["polars-core/chunked_ids", "polars-lazy/chunked_ids", "polars-core/chunked_ids"] +semi_anti_join = ["polars-lazy?/semi_anti_join", "polars-ops/semi_anti_join", "polars-sql?/semi_anti_join"] +list_eval = ["polars-lazy?/list_eval"] +cumulative_eval = ["polars-lazy?/cumulative_eval"] +chunked_ids = ["polars-lazy?/chunked_ids", "polars-core/chunked_ids", "polars-ops/chunked_ids"] to_dummies = ["polars-ops/to_dummies"] -bigidx = ["polars-core/bigidx", "polars-lazy/bigidx", "polars-ops/big_idx"] -list_to_struct = ["polars-ops/list_to_struct", "polars-lazy/list_to_struct"] -list_count = ["polars-ops/list_count", "polars-lazy/list_count"] -list_take = ["polars-ops/list_take", "polars-lazy/list_take"] +bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] +list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +list_count = ["polars-ops/list_count", "polars-lazy?/list_count"] +list_take = ["polars-ops/list_take", "polars-lazy?/list_take"] describe = ["polars-core/describe"] -timezones = ["polars-core/timezones", "polars-lazy/timezones", "polars-io/timezones"] -string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"] -string_from_radix = ["polars-lazy/string_from_radix", "polars-ops/string_from_radix"] -arg_where = ["polars-lazy/arg_where"] -search_sorted = ["polars-lazy/search_sorted"] -merge_sorted = ["polars-lazy/merge_sorted"] -meta = ["polars-lazy/meta"] -date_offset = ["polars-lazy/date_offset"] -trigonometry = ["polars-lazy/trigonometry"] -sign = ["polars-lazy/sign"] -pivot = ["polars-lazy/pivot"] -top_k = ["polars-lazy/top_k"] +timezones = ["polars-core/timezones", 
"polars-lazy?/timezones", "polars-io/timezones"] +string_justify = ["polars-lazy?/string_justify", "polars-ops/string_justify"] +string_from_radix = ["polars-lazy?/string_from_radix", "polars-ops/string_from_radix"] +arg_where = ["polars-lazy?/arg_where"] +search_sorted = ["polars-lazy?/search_sorted"] +merge_sorted = ["polars-lazy?/merge_sorted"] +meta = ["polars-lazy?/meta"] +date_offset = ["polars-lazy?/date_offset"] +trigonometry = ["polars-lazy?/trigonometry"] +sign = ["polars-lazy?/sign"] +pivot = ["polars-lazy?/pivot"] +top_k = ["polars-lazy?/top_k"] algo = ["polars-algo"] -cse = ["polars-lazy/cse"] -propagate_nans = ["polars-lazy/propagate_nans"] -coalesce = ["polars-lazy/coalesce"] -streaming = ["polars-lazy/streaming"] -fused = ["polars-ops/fused", "polars-lazy/fused"] -list_sets = ["polars-lazy/list_sets"] -list_any_all = ["polars-lazy/list_any_all"] -cutqcut = ["polars-lazy/cutqcut"] -rle = ["polars-lazy/rle"] -extract_groups = ["polars-lazy/extract_groups"] +cse = ["polars-lazy?/cse"] +propagate_nans = ["polars-lazy?/propagate_nans"] +coalesce = ["polars-lazy?/coalesce"] +streaming = ["polars-lazy?/streaming"] +fused = ["polars-ops/fused", "polars-lazy?/fused"] +list_sets = ["polars-lazy?/list_sets"] +list_any_all = ["polars-lazy?/list_any_all"] +list_drop_nulls = ["polars-lazy?/list_drop_nulls"] +cutqcut = ["polars-lazy?/cutqcut"] +rle = ["polars-lazy?/rle"] +extract_groups = ["polars-lazy?/extract_groups"] +peaks = ["polars-lazy/peaks"] test = [ "lazy", @@ -229,51 +238,51 @@ dtype-slim = [ # opt-in datatypes for Series dtype-date = [ "polars-core/dtype-date", - "polars-lazy/dtype-date", + "polars-lazy?/dtype-date", "polars-io/dtype-date", - "polars-time/dtype-date", + "polars-time?/dtype-date", "polars-core/dtype-date", "polars-ops/dtype-date", ] dtype-datetime = [ "polars-core/dtype-datetime", - "polars-lazy/dtype-datetime", + "polars-lazy?/dtype-datetime", "polars-io/dtype-datetime", - "polars-time/dtype-datetime", + "polars-time?/dtype-datetime", "polars-ops/dtype-datetime", ] dtype-duration = [ "polars-core/dtype-duration", - "polars-lazy/dtype-duration", - "polars-time/dtype-duration", + "polars-lazy?/dtype-duration", + "polars-time?/dtype-duration", "polars-core/dtype-duration", "polars-ops/dtype-duration", ] -dtype-time = ["polars-core/dtype-time", "polars-io/dtype-time", "polars-time/dtype-time", "polars-ops/dtype-time"] +dtype-time = ["polars-core/dtype-time", "polars-io/dtype-time", "polars-time?/dtype-time", "polars-ops/dtype-time"] dtype-array = [ "polars-core/dtype-array", - "polars-lazy/dtype-array", + "polars-lazy?/dtype-array", "polars-ops/dtype-array", ] -dtype-i8 = ["polars-core/dtype-i8", "polars-lazy/dtype-i8", "polars-ops/dtype-i8"] -dtype-i16 = ["polars-core/dtype-i16", "polars-lazy/dtype-i16", "polars-ops/dtype-i16"] +dtype-i8 = ["polars-core/dtype-i8", "polars-lazy?/dtype-i8", "polars-ops/dtype-i8"] +dtype-i16 = ["polars-core/dtype-i16", "polars-lazy?/dtype-i16", "polars-ops/dtype-i16"] dtype-decimal = [ "polars-core/dtype-decimal", - "polars-lazy/dtype-decimal", + "polars-lazy?/dtype-decimal", "polars-ops/dtype-decimal", "polars-io/dtype-decimal", ] -dtype-u8 = ["polars-core/dtype-u8", "polars-lazy/dtype-u8", "polars-ops/dtype-u8"] -dtype-u16 = ["polars-core/dtype-u16", "polars-lazy/dtype-u16", "polars-ops/dtype-u16"] +dtype-u8 = ["polars-core/dtype-u8", "polars-lazy?/dtype-u8", "polars-ops/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16", "polars-lazy?/dtype-u16", "polars-ops/dtype-u16"] dtype-categorical = [ "polars-core/dtype-categorical", 
"polars-io/dtype-categorical", - "polars-lazy/dtype-categorical", + "polars-lazy?/dtype-categorical", "polars-ops/dtype-categorical", ] dtype-struct = [ "polars-core/dtype-struct", - "polars-lazy/dtype-struct", + "polars-lazy?/dtype-struct", "polars-ops/dtype-struct", "polars-io/dtype-struct", ] @@ -298,8 +307,8 @@ docs-selection = [ "checked_arithmetic", "ndarray", "repeat_by", - "is_first", - "is_last", + "is_first_distinct", + "is_last_distinct", "asof_join", "cross_join", "concat_str", diff --git a/crates/polars/src/docs/eager.rs b/crates/polars/src/docs/eager.rs index e31a5ada79e7..8e2c91f050a5 100644 --- a/crates/polars/src/docs/eager.rs +++ b/crates/polars/src/docs/eager.rs @@ -2,7 +2,11 @@ //! # Polars Eager cookbook //! //! This page should serve a cookbook to quickly get you started with most fundamental operations -//! executed on a `ChunkedArray`, `Series` or `DataFrame`. +//! executed on a [`ChunkedArray`], [`Series`] or [`DataFrame`]. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! [`Series`]: crate::series::Series +//! [`DataFrame`]: crate::frame::DataFrame //! //! ## Tree Of Contents //! @@ -93,8 +97,8 @@ //! ``` //! //! ## Arithmetic -//! Arithmetic can be done on both `Series` and `ChunkedArray`s. The most notable difference is that -//! a `Series` coerces the data to match the underlying data types. +//! Arithmetic can be done on both [`Series`] and [`ChunkedArray`]. The most notable difference is that +//! a [`Series`] coerces the data to match the underlying data types. //! //! ``` //! use polars::prelude::*; @@ -141,7 +145,9 @@ //! let subtract_one_by_s = 1.sub(&series); //! ``` //! -//! For `ChunkedArray`s this left hand side operations can be done with the `apply` method. +//! For [`ChunkedArray`] this left hand side operations can be done with the [`apply_values`] method. +//! +//! [`apply_values`]: crate::chunked_array::ops::ChunkApply::apply_values //! //! ```rust //! # use polars::prelude::*; @@ -153,7 +159,7 @@ //! //! ## Comparisons //! -//! `Series` and `ChunkedArray`s can be used in comparison operations to create `boolean` masks/predicates. +//! [`Series`] and [`ChunkedArray`] can be used in comparison operations to create _boolean_ masks/predicates. //! //! ``` //! use polars::prelude::*; @@ -571,7 +577,7 @@ //! // write DataFrame to file //! CsvWriter::new(&mut file) //! .has_header(true) -//! .with_delimiter(b',') +//! .with_separator(b',') //! .finish(df); //! # Ok(()) //! # } @@ -642,8 +648,11 @@ //! //! ## Replace NaN with Missing. //! The floating point [Not a Number: NaN](https://en.wikipedia.org/wiki/NaN) is conceptually different -//! than missing data in Polars. In the snippet below we show how we can replace `NaN` values with -//! missing values, by setting them to `None`. +//! than missing data in Polars. In the snippet below we show how we can replace [`NaN`] values with +//! missing values, by setting them to [`None`]. +//! +//! [`NaN`]: https://doc.rust-lang.org/std/primitive.f64.html#associatedconstant.NAN +//! //! ``` //! use polars::prelude::*; //! use polars::df; @@ -671,9 +680,11 @@ //! //! ## Extracting data //! -//! To be able to extract data out of `Series`, either by iterating over them or converting them -//! to other datatypes like a `Vec`, we first need to downcast them to a `ChunkedArray`. This -//! is needed because we don't know the data type that is hold by the `Series`. +//! To be able to extract data out of [`Series`], either by iterating over them or converting them +//! 
to other datatypes like a [`Vec`], we first need to downcast them to a [`ChunkedArray`]. This +//! is needed because we don't know the data type that is hold by the [`Series`]. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray //! //! ``` //! use polars::prelude::*; diff --git a/crates/polars/src/docs/lazy.rs b/crates/polars/src/docs/lazy.rs index 0d2404b6a753..4b737fbc027d 100644 --- a/crates/polars/src/docs/lazy.rs +++ b/crates/polars/src/docs/lazy.rs @@ -106,7 +106,7 @@ //! //! ## Groupby //! -//! This example is from the polars [user guide](https://pola-rs.github.io/polars-book/user-guide/concepts/contexts/#group_by-aggregation). +//! This example is from the polars [user guide](https://pola-rs.github.io/polars/user-guide/concepts/contexts/#group_by-aggregation). //! //! ``` //! use polars::prelude::*; @@ -114,7 +114,7 @@ //! //! let df = LazyCsvReader::new("reddit.csv") //! .has_header(true) -//! .with_delimiter(b',') +//! .with_separator(b',') //! .finish()? //! .group_by([col("comment_karma")]) //! .agg([col("name").n_unique().alias("unique_names"), col("link_karma").max()]) @@ -190,11 +190,15 @@ //! ``` //! //! ## Conditionally apply -//! If we want to create a new column based on some condition, we can use the `.when()/.then()/.otherwise()` expressions. +//! If we want to create a new column based on some condition, we can use the [`when`]/[`then`]/[`otherwise`] expressions. //! -//! * `when` - accepts a predicate expression -//! * `then` - expression to use when `predicate == true` -//! * `otherwise` - expression to use when `predicate == false` +//! * [`when`] - accepts a predicate expression +//! * [`then`] - expression to use when `predicate == true` +//! * [`otherwise`] - expression to use when `predicate == false` +//! +//! [`when`]: polars_lazy::dsl::Then::when +//! [`then`]: polars_lazy::dsl::When::then +//! [`otherwise`]: polars_lazy::dsl::Then::otherwise //! //! ``` //! use polars::prelude::*; @@ -239,7 +243,9 @@ //! //! The expression API should be expressive enough for most of what you want to achieve, but it can happen //! that you need to pass the values to an external function you do not control. The snippet below -//! shows how we use the `Struct` datatype to be able to apply a function over multiple inputs. +//! shows how we use the [`Struct`] datatype to be able to apply a function over multiple inputs. +//! +//! [`Struct`]: crate::datatypes::DataType::Struct //! //! ```ignore //! use polars::prelude::*; @@ -248,31 +254,31 @@ //! a //! } //! -//! fn apply_multiples(lf: LazyFrame) -> PolarsResult { +//! fn apply_multiples() -> PolarsResult { //! df![ -//! "a" => [1.0, 2.0, 3.0], -//! "b" => [3.0, 5.1, 0.3] +//! "a" => [1.0f32, 2.0, 3.0], +//! "b" => [3.0f32, 5.1, 0.3] //! ]? //! .lazy() -//! .select([concat_list(["col_a", "col_b"]).map( +//! .select([as_struct(&[col("a"), col("b")]).map( //! |s| { //! let ca = s.struct_()?; //! -//! let b = ca.field_by_name("col_a")?; -//! let a = ca.field_by_name("col_b")?; -//! let a = a.f32()?; -//! let b = b.f32()?; +//! let series_a = ca.field_by_name("a")?; +//! let series_b = ca.field_by_name("b")?; +//! let chunked_a = series_a.f32()?; +//! let chunked_b = series_b.f32()?; //! -//! let out: Float32Chunked = a +//! let out: Float32Chunked = chunked_a //! .into_iter() -//! .zip(b.into_iter()) +//! .zip(chunked_b.into_iter()) //! .map(|(opt_a, opt_b)| match (opt_a, opt_b) { //! (Some(a), Some(b)) => Some(my_black_box_function(a, b)), //! _ => None, //! }) //! .collect(); //! -//! Ok(out.into_series()) +//! 
Ok(Some(out.into_series())) //! }, //! GetOutput::from_type(DataType::Float32), //! )]) diff --git a/crates/polars/src/docs/performance.rs b/crates/polars/src/docs/performance.rs index 466fb54c2ae4..e8f5da30c538 100644 --- a/crates/polars/src/docs/performance.rs +++ b/crates/polars/src/docs/performance.rs @@ -1,6 +1,6 @@ //! # Performance //! -//! Understanding the memory format used by Arrow/ Polars can really increase performance of your +//! Understanding the memory format used by Arrow/Polars can really increase performance of your //! queries. This is especially true for large string data. The figure below shows how an Arrow UTF8 //! array is laid out in memory. //! @@ -13,27 +13,32 @@ //! ![](https://raw.githubusercontent.com/pola-rs/polars-static/master/docs/arrow-string.svg) //! //! This memory structure is very cache efficient if we are to read the string values. Especially if -//! we compare it to a `Vec`. +//! we compare it to a [`Vec`]. //! //! ![](https://raw.githubusercontent.com/pola-rs/polars-static/master/docs/pandas-string.svg) //! //! However, if we need to reorder the Arrow UTF8 array, we need to swap around all the bytes of the //! string values, which can become very expensive when we're dealing with large strings. On the -//! other hand, for the `Vec`, we only need to swap pointers around which is only 8 bytes data +//! other hand, for the [`Vec`], we only need to swap pointers around which is only 8 bytes data //! that have to be moved. //! -//! If you have a [DataFrame](crate::frame::DataFrame) with a large number of -//! [Utf8Chunked](crate::datatypes::Utf8Chunked) columns and you need to reorder them due to an +//! If you have a [`DataFrame`] with a large number of +//! [`Utf8Chunked`] columns and you need to reorder them due to an //! operation like a FILTER, JOIN, GROUPBY, etc. than this can become quite expensive. //! //! ## Categorical type -//! For this reason Polars has a [CategoricalType](https://pola-rs.github.io/polars/polars/datatypes/struct.CategoricalType.html). -//! A `CategoricalChunked` is an array filled with `u32` values that each represent a unique string value. +//! For this reason Polars has a [`CategoricalType`]. +//! A [`CategoricalChunked`] is an array filled with `u32` values that each represent a unique string value. //! Thereby maintaining cache-efficiency, whilst also making it cheap to move values around. //! +//! [`DataFrame`]: crate::frame::DataFrame +//! [`Utf8Chunked`]: crate::datatypes::Utf8Chunked +//! [`CategoricalType`]: crate::datatypes::CategoricalType +//! [`CategoricalChunked`]: crate::datatypes::CategoricalChunked +//! //! ### Example: Single DataFrame //! -//! In the example below we show how you can cast a `Utf8Chunked` column to a `CategoricalChunked`. +//! In the example below we show how you can cast a [`Utf8Chunked`] column to a [`CategoricalChunked`]. //! //! ```rust //! use polars::prelude::*; @@ -49,18 +54,20 @@ //! ``` //! //! ### Example: Eager join multiple DataFrames on a Categorical -//! When the strings of one column need to be joined with the string data from another `DataFrame`. -//! The `Categorical` data needs to be synchronized (Categories in df A need to point to the same +//! When the strings of one column need to be joined with the string data from another [`DataFrame`]. +//! The [`Categorical`] data needs to be synchronized (Categories in df A need to point to the same //! underlying string data as Categories in df B). You can do that by turning the global string cache //! on. //! +//! 
[`Categorical`]: crate::datatypes::CategoricalChunked +//! //! ```rust //! use polars::prelude::*; //! use polars::enable_string_cache; //! //! fn example(mut df_a: DataFrame, mut df_b: DataFrame) -> PolarsResult { //! // Set a global string cache -//! enable_string_cache(true); +//! enable_string_cache(); //! //! df_a.try_apply("a", |s| s.categorical().cloned())?; //! df_b.try_apply("b", |s| s.categorical().cloned())?; @@ -69,8 +76,10 @@ //! ``` //! //! ### Example: Lazy join multiple DataFrames on a Categorical -//! A lazy Query always has a global string cache (unless you opt-out) for the duration of that query (until `collect` is called). -//! The example below shows how you could join two DataFrames with Categorical types. +//! A lazy Query always has a global string cache (unless you opt-out) for the duration of that query (until [`collect`] is called). +//! The example below shows how you could join two [`DataFrame`]s with [`Categorical`] types. +//! +//! [`collect`]: polars_lazy::frame::LazyFrame::collect //! //! ```rust //! # #[cfg(feature = "lazy")] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 01516feff4bf..0eaf13c040ce 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -5,10 +5,12 @@ //! standard for columnar data. //! //! ## Quickstart -//! We recommend to build your queries directly with polars-lazy. This allows you to combine +//! We recommend to build your queries directly with [polars-lazy]. This allows you to combine //! expression into powerful aggregations and column selections. All expressions are evaluated //! in parallel and your queries are optimized just in time. //! +//! [polars-lazy]: polars_lazy +//! //! ```no_run //! use polars::prelude::*; //! # fn example() -> PolarsResult<()> { @@ -64,27 +66,37 @@ //! * [Lazy](crate::docs::lazy) //! //! ## Data Structures -//! The base data structures provided by polars are `DataFrame`, `Series`, and `ChunkedArray`. +//! The base data structures provided by polars are [`DataFrame`], [`Series`], and [`ChunkedArray`]. //! We will provide a short, top-down view of these data structures. //! +//! [`DataFrame`]: crate::frame::DataFrame +//! [`Series`]: crate::series::Series +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! //! ### DataFrame -//! A `DataFrame` is a 2 dimensional data structure that is backed by a `Series`, and it could be -//! seen as an abstraction on `Vec`. Operations that can be executed on `DataFrame`s are very +//! A [`DataFrame`] is a 2 dimensional data structure that is backed by a [`Series`], and it could be +//! seen as an abstraction on [`Vec`]. Operations that can be executed on [`DataFrame`] are very //! similar to what is done in a `SQL` like query. You can `GROUP`, `JOIN`, `PIVOT` etc. //! +//! [`Vec`]: std::vec::Vec +//! //! ### Series -//! `Series` are the type agnostic columnar data representation of Polars. They provide many -//! operations out of the box, many via the [Series struct](crate::prelude::Series) and -//! [SeriesTrait trait](crate::series::SeriesTrait). Whether or not an operation is provided -//! by a `Series` is determined by the operation. If the operation can be done without knowing the -//! underlying columnar type, this operation probably is provided by the `Series`. If not, you must -//! downcast to the typed data structure that is wrapped by the `Series`. That is the `ChunkedArray`. +//! [`Series`] are the type agnostic columnar data representation of Polars. They provide many +//! 
operations out of the box, many via the [`Series`] series and +//! [`SeriesTrait`] trait. Whether or not an operation is provided +//! by a [`Series`] is determined by the operation. If the operation can be done without knowing the +//! underlying columnar type, this operation probably is provided by the [`Series`]. If not, you must +//! downcast to the typed data structure that is wrapped by the [`Series`]. That is the [`ChunkedArray`]. +//! +//! [`SeriesTrait`]: crate::series::SeriesTrait //! //! ### ChunkedArray -//! `ChunkedArray` are wrappers around an arrow array, that can contain multiples chunks, e.g. -//! `Vec`. These are the root data structures of Polars, and implement many operations. -//! Most operations are implemented by traits defined in [chunked_array::ops](crate::chunked_array::ops), -//! or on the [ChunkedArray struct](crate::chunked_array::ChunkedArray). +//! [`ChunkedArray`] are wrappers around an arrow array, that can contain multiples chunks, e.g. +//! [`Vec`]. These are the root data structures of Polars, and implement many operations. +//! Most operations are implemented by traits defined in [chunked_array::ops], +//! or on the [`ChunkedArray`] struct. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray //! //! ## SIMD //! Polars / Arrow uses packed_simd to speed up kernels with SIMD operations. SIMD is an optional @@ -95,15 +107,17 @@ //! more verbose and less capable of building elegant composite queries. We recommend to use the Lazy API //! whenever you can. //! -//! As neither API is async they should be wrapped in `spawn_blocking` when used in an async context +//! As neither API is async they should be wrapped in _spawn_blocking_ when used in an async context //! to avoid blocking the async thread pool of the runtime. //! //! ## Expressions //! Polars has a powerful concept called expressions. //! Polars expressions can be used in various contexts and are a functional mapping of -//! `Fn(Series) -> Series`, meaning that they have Series as input and Series as output. -//! By looking at this functional definition, we can see that the output of an `Expr` also can serve -//! as the input of an `Expr`. +//! `Fn(Series) -> Series`, meaning that they have [`Series`] as input and [`Series`] as output. +//! By looking at this functional definition, we can see that the output of an [`Expr`] also can serve +//! as the input of an [`Expr`]. +//! +//! [`Expr`]: polars_lazy::dsl::Expr //! //! That may sound a bit strange, so lets give an example. The following is an expression: //! @@ -133,7 +147,7 @@ //! (Note that within an expression there may be more parallelization going on). //! //! Understanding polars expressions is most important when starting with the polars library. Read more -//! about them in the [User Guide](https://pola-rs.github.io/polars-book/user-guide/concepts/expressions). +//! about them in the [User Guide](https://pola-rs.github.io/polars/user-guide/concepts/expressions). //! Though the examples given there are in python. The expressions API is almost identical and the //! the read should certainly be valuable to rust users as well. //! @@ -150,7 +164,7 @@ //! Unlock full potential with lazy computation. This allows query optimizations and provides Polars //! the full query context so that the fastest algorithm can be chosen. //! -//! **[Read more in the lazy module.](crate::lazy)** +//! **[Read more in the lazy module.](polars_lazy)** //! //! ## Compile times //! A DataFrame library typically consists of @@ -166,18 +180,17 @@ //! //! 
* `performant` - Longer compile times more fast paths. //! * `lazy` - Lazy API -//! - `lazy_regex` - Use regexes in [column selection](crate::lazy::dsl::col) +//! - `lazy_regex` - Use regexes in [column selection] //! - `dot_diagram` - Create dot diagrams from lazy logical plans. //! * `sql` - Pass SQL queries to polars. //! * `streaming` - Be able to process datasets that are larger than RAM. //! * `random` - Generate arrays with randomly sampled values -//! * `ndarray`- Convert from `DataFrame` to `ndarray` +//! * `ndarray`- Convert from [`DataFrame`] to [ndarray](https://docs.rs/ndarray/) //! * `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types //! * `timezones` - Activate timezone support. -//! * `strings` - Extra string utilities for `Utf8Chunked` -//! - `string_justify` - `zfill`, `ljust`, `rjust` +//! * `strings` - Extra string utilities for [`Utf8Chunked`] //! - `string_justify` - `zfill`, `ljust`, `rjust` //! - `string_from_radix` - `parse_int` -//! * `object` - Support for generic ChunkedArrays called `ObjectChunked` (generic over `T`). +//! * `object` - Support for generic ChunkedArrays called [`ObjectChunked`] (generic over `T`). //! These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. //! * Performance related: //! - `nightly` - Several nightly only features such as SIMD and specialization. @@ -200,36 +213,41 @@ //! * zip //! * gzip //! -//! * `DataFrame` operations: +//! [`Utf8Chunked`]: crate::datatypes::Utf8Chunked +//! [column selection]: polars_lazy::dsl::col +//! [`ObjectChunked`]: polars_core::datatypes::ObjectChunked +//! +//! +//! * [`DataFrame`] operations: //! - `dynamic_group_by` - Groupby based on a time window instead of predefined keys. //! Also activates rolling window group by operations. -//! - `sort_multiple` - Allow sorting a `DataFrame` on multiple columns -//! - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. +//! - `sort_multiple` - Allow sorting a [`DataFrame`] on multiple columns +//! - `rows` - Create [`DataFrame`] from rows and extract rows from [`DataFrame`]s. //! And activates `pivot` and `transpose` operations //! - `asof_join` - Join ASOF, to join on nearest keys instead of exact equality match. -//! - `cross_join` - Create the cartesian product of two DataFrames. +//! - `cross_join` - Create the cartesian product of two [`DataFrame`]s. //! - `semi_anti_join` - SEMI and ANTI joins. //! - `group_by_list` - Allow group_by operation on keys of type List. -//! - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked +//! - `row_hash` - Utility to hash [`DataFrame`] rows to [`UInt64Chunked`] //! - `diagonal_concat` - Concat diagonally thereby combining different schemas. //! - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match -//! - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series) -//! - `partition_by` - Split into multiple DataFrames partitioned by groups. -//! * `Series`/`Expression` operations: -//! - `is_in` - Check for membership in `Series`. +//! - `dataframe_arithmetic` - Arithmetic on ([`Dataframe`] and [`DataFrame`]s) and ([`DataFrame`] on [`Series`]) +//! - `partition_by` - Split into multiple [`DataFrame`]s partitioned by groups. +//! * [`Series`]/[`Expr`] operations: +//! - `is_in` - Check for membership in [`Series`]. //! - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip). -//! 
- `round_series` - round underlying float types of `Series`. +//! - `round_series` - round underlying float types of [`Series`]. //! - `repeat_by` - [Repeat element in an Array N times, where N is given by another array. -//! - `is_first` - Check if element is first unique value. -//! - `is_last` - Check if element is last unique value. -//! - `checked_arithmetic` - checked arithmetic/ returning `None` on invalid operations. -//! - `dot_product` - Dot/inner product on Series and Expressions. +//! - `is_first_distinct` - Check if element is first unique value. +//! - `is_last_distinct` - Check if element is last unique value. +//! - `checked_arithmetic` - checked arithmetic/ returning [`None`] on invalid operations. +//! - `dot_product` - Dot/inner product on [`Series`] and [`Expr`]. //! - `concat_str` - Concat string data in linear time. //! - `reinterpret` - Utility to reinterpret bits to signed/unsigned -//! - `take_opt_iter` - Take from a Series with `Iterator>` -//! - `mode` - [Return the most occurring value(s)](crate::chunked_array::ops::ChunkUnique::mode) -//! - `cum_agg` - cumsum, cummin, cummax aggregation. -//! - `rolling_window` - rolling window functions, like rolling_mean +//! - `take_opt_iter` - Take from a [`Series`] with [`Iterator>`](std::iter::Iterator). +//! - `mode` - [Return the most occurring value(s)](polars_ops::chunked_array::mode) +//! - `cum_agg` - [`cumsum`], [`cummin`], [`cummax`] aggregation. +//! - `rolling_window` - rolling window functions, like [`rolling_mean`] //! - `interpolate` [interpolate None values](polars_ops::chunked_array::interpolate) //! - `extract_jsonpath` - [Run jsonpath queries on Utf8Chunked](https://goessner.net/articles/JsonPath/) //! - `list` - List utils. @@ -237,14 +255,14 @@ //! - `rank` - Ranking algorithms. //! - `moment` - kurtosis and skew statistics //! - `ewma` - Exponential moving average windows -//! - `abs` - Get absolute values of Series -//! - `arange` - Range operation on Series -//! - `product` - Compute the product of a Series. -//! - `diff` - `diff` operation. +//! - `abs` - Get absolute values of [`Series`]. +//! - `arange` - Range operation on [`Series`]. +//! - `product` - Compute the product of a [`Series`]. +//! - `diff` - [`diff`] operation. //! - `pct_change` - Compute change percentages. //! - `unique_counts` - Count unique values in expressions. -//! - `log` - Logarithms for `Series`. -//! - `list_to_struct` - Convert `List` to `Struct` dtypes. +//! - `log` - Logarithms for [`Series`]. +//! - `list_to_struct` - Convert [`List`] to [`Struct`] dtypes. //! - `list_count` - Count elements in lists. //! - `list_eval` - Apply expressions over list elements. //! - `list_sets` - Compute UNION, INTERSECTION, and DIFFERENCE on list types. @@ -253,21 +271,30 @@ //! - `search_sorted` - Find indices where elements should be inserted to maintain order. //! - `date_offset` - Add an offset to dates that take months and leap years into account. //! - `trigonometry` - Trigonometric functions. -//! - `sign` - Compute the element-wise sign of a Series. +//! - `sign` - Compute the element-wise sign of a [`Series`]. //! - `propagate_nans` - NaN propagating min/max aggregations. //! - `extract_groups` - Extract multiple regex groups from strings. -//! * `DataFrame` pretty printing -//! - `fmt` - Activate DataFrame formatting +//! * [`DataFrame`] pretty printing +//! - `fmt` - Activate [`DataFrame`] formatting +//! +//! [`UInt64Chunked`]: crate::datatypes::UInt64Chunked +//! [`cumsum`]: crate::series::Series::cumsum +//! 
[`cummin`]: crate::series::Series::cummin +//! [`cummax`]: crate::series::Series::cummax +//! [`rolling_mean`]: crate::series::Series#method.rolling_mean +//! [`diff`]: crate::series::Series::diff +//! [`List`]: crate::datatypes::DataType::List +//! [`Struct`]: crate::datatypes::DataType::Struct //! //! ## Compile times and opt-in data types -//! As mentioned above, Polars `Series` are wrappers around -//! `ChunkedArray` without the generic parameter `T`. +//! As mentioned above, Polars [`Series`] are wrappers around +//! [`ChunkedArray`] without the generic parameter `T`. //! To get rid of the generic parameter, all the possible value of `T` are compiled -//! for `Series`. This gets more expensive the more types you want for a `Series`. In order to reduce -//! the compile times, we have decided to default to a minimal set of types and make more `Series` types +//! for [`Series`]. This gets more expensive the more types you want for a [`Series`]. In order to reduce +//! the compile times, we have decided to default to a minimal set of types and make more [`Series`] types //! opt-in. //! -//! Note that if you get strange compile time errors, you probably need to opt-in for that `Series` dtype. +//! Note that if you get strange compile time errors, you probably need to opt-in for that [`Series`] dtype. //! The opt-in dtypes are: //! //! | data type | feature flag | @@ -363,15 +390,14 @@ //! * `POLARS_PARTITION_UNIQUE_COUNT` -> at which (estimated) key count a partitioned group_by should run. //! defaults to `1000`, any higher cardinality will run default group_by. //! * `POLARS_FORCE_PARTITION` -> force partitioned group_by if the keys and aggregations allow it. -//! * `POLARS_ALLOW_EXTENSION` -> allows for `[ObjectChunked]` to be used in arrow, opening up possibilities like using +//! * `POLARS_ALLOW_EXTENSION` -> allows for [`ObjectChunked`] to be used in arrow, opening up possibilities like using //! `T` in complex lazy expressions. However this does require `unsafe` code allow this. //! * `POLARS_NO_PARQUET_STATISTICS` -> if set, statistics in parquet files are ignored. //! * `POLARS_PANIC_ON_ERR` -> panic instead of returning an Error. //! * `POLARS_NO_CHUNKED_JOIN` -> force rechunk before joins. //! -//! //! ## User Guide -//! If you want to read more, [check the User Guide](https://pola-rs.github.io/polars-book/). +//! If you want to read more, [check the User Guide](https://pola-rs.github.io/polars/). 
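The module docs above describe expressions as `Fn(Series) -> Series` mappings that are composed and optimized by the lazy engine. A minimal sketch of such a query, assuming the `lazy` feature of the `polars` crate is enabled; the column names and data below are invented for illustration:

```rust
use polars::df;
use polars::prelude::*;

fn example() -> PolarsResult<DataFrame> {
    let df = df![
        "fruits" => ["apple", "apple", "pear"],
        "weight" => [1.5f64, 2.0, 0.5]
    ]?;

    // `col("weight").sum()` is an expression: a Series -> Series mapping
    // evaluated per group in the group_by context.
    df.lazy()
        .group_by([col("fruits")])
        .agg([col("weight").sum().alias("total_weight")])
        .collect()
}
```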
#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow(ambiguous_glob_reexports)] pub mod docs; @@ -382,8 +408,8 @@ pub mod prelude; pub mod sql; pub use polars_core::{ - apply_method_all_arrow_series, chunked_array, datatypes, df, doc, error, frame, functions, - series, testing, + apply_method_all_arrow_series, chunked_array, datatypes, df, error, frame, functions, series, + testing, }; #[cfg(feature = "dtype-categorical")] pub use polars_core::{enable_string_cache, using_string_cache}; diff --git a/crates/polars/src/sql.rs b/crates/polars/src/sql.rs index e0451dc7c505..31feba6cbc3a 100644 --- a/crates/polars/src/sql.rs +++ b/crates/polars/src/sql.rs @@ -1 +1,2 @@ +pub use polars_sql::function_registry::*; pub use polars_sql::{keywords, sql_expr, SQLContext}; diff --git a/crates/polars/tests/it/core/date_like.rs b/crates/polars/tests/it/core/date_like.rs index 051920789c6d..8d97b5e87d1d 100644 --- a/crates/polars/tests/it/core/date_like.rs +++ b/crates/polars/tests/it/core/date_like.rs @@ -157,3 +157,32 @@ fn test_duration() -> PolarsResult<()> { ); Ok(()) } + +#[test] +#[cfg(feature = "dtype-duration")] +fn test_duration_date_arithmetic() { + let date1 = Int32Chunked::new("", &[1, 1, 1]).into_date().into_series(); + let date2 = Int32Chunked::new("", &[2, 3, 4]).into_date().into_series(); + + let diff_ms = &date2 - &date1; + let diff_us = diff_ms + .cast(&DataType::Duration(TimeUnit::Microseconds)) + .unwrap(); + let diff_ns = diff_ms + .cast(&DataType::Duration(TimeUnit::Nanoseconds)) + .unwrap(); + + // `+` is commutative for date and duration + assert_series_eq(&(&diff_ms + &date1), &(&date1 + &diff_ms)); + assert_series_eq(&(&diff_us + &date1), &(&date1 + &diff_us)); + assert_series_eq(&(&diff_ns + &date1), &(&date1 + &diff_ns)); + + // `+` is correct date and duration + assert_series_eq(&(&diff_ms + &date1), &date2); + assert_series_eq(&(&diff_us + &date1), &date2); + assert_series_eq(&(&diff_ns + &date1), &date2); +} + +fn assert_series_eq(s1: &Series, s2: &Series) { + assert!(s1.series_equal(s2)) +} diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 9077d5cff21d..5f3550abfa79 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -1,6 +1,6 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; #[cfg(feature = "dtype-categorical")] -use polars_core::{reset_string_cache, IUseStringCache}; +use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; use super::*; @@ -256,8 +256,9 @@ fn test_join_multiple_columns() { #[cfg_attr(miri, ignore)] #[cfg(feature = "dtype-categorical")] fn test_join_categorical() { - let _lock = IUseStringCache::hold(); - let _lock = polars_core::SINGLE_LOCK.lock(); + let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); let (mut df_a, mut df_b) = get_dfs(); @@ -294,11 +295,10 @@ fn test_join_categorical() { let (mut df_a, mut df_b) = get_dfs(); df_a.try_apply("b", |s| s.cast(&DataType::Categorical(None))) .unwrap(); - // create a new cache - reset_string_cache(); - // _sc is needed to ensure we hold the string cache. 
- let _sc = IUseStringCache::hold(); + // Create a new string cache + drop(_sc); + let _sc = StringCacheHolder::hold(); df_b.try_apply("bar", |s| s.cast(&DataType::Categorical(None))) .unwrap(); diff --git a/crates/polars/tests/it/core/ops/take.rs b/crates/polars/tests/it/core/ops/take.rs index 14acb9338970..a958997954c1 100644 --- a/crates/polars/tests/it/core/ops/take.rs +++ b/crates/polars/tests/it/core/ops/take.rs @@ -2,13 +2,13 @@ use super::*; #[test] fn test_list_take_nulls_and_empty() { - unsafe { - let a: &[i32] = &[]; - let a = Series::new("", a); - let b = Series::new("", &[None, Some(a.clone())]); - let mut iter = [Some(0), Some(1usize), None].iter().copied(); - let out = b.take_opt_iter_unchecked(&mut iter); - let expected = Series::new("", &[None, Some(a), None]); - assert!(out.series_equal_missing(&expected)) - } + let a: &[i32] = &[]; + let a = Series::new("", a); + let b = Series::new("", &[None, Some(a.clone())]); + let indices = [Some(0 as IdxSize), Some(1), None] + .into_iter() + .collect_ca(""); + let out = b.take(&indices).unwrap(); + let expected = Series::new("", &[None, Some(a), None]); + assert!(out.series_equal_missing(&expected)) } diff --git a/crates/polars/tests/it/core/series.rs b/crates/polars/tests/it/core/series.rs index 42b533c78f1a..400895c2cde1 100644 --- a/crates/polars/tests/it/core/series.rs +++ b/crates/polars/tests/it/core/series.rs @@ -40,6 +40,6 @@ fn test_min_max_sorted_desc() { #[test] fn test_construct_list_of_null_series() { let s = Series::new("a", [Series::new_null("a1", 1), Series::new_null("a1", 1)]); - assert_eq!(s.null_count(), s.len()); + assert_eq!(s.null_count(), 0); assert_eq!(s.field().name(), "a"); } diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs index 9df2115ed8d8..5d51a83f5a40 100644 --- a/crates/polars/tests/it/io/csv.rs +++ b/crates/polars/tests/it/io/csv.rs @@ -153,7 +153,7 @@ fn test_tab_sep() { let file = Cursor::new(csv); let df = CsvReader::new(file) .infer_schema(Some(100)) - .with_delimiter(b'\t') + .with_separator(b'\t') .has_header(false) .with_ignore_errors(true) .finish() @@ -472,10 +472,9 @@ fn test_skip_rows() -> PolarsResult<()> { let df = CsvReader::new(file) .has_header(false) .with_skip_rows(3) - .with_delimiter(b' ') + .with_separator(b' ') .finish()?; - dbg!(&df); assert_eq!(df.height(), 3); Ok(()) } @@ -491,7 +490,7 @@ fn test_projection_idx() -> PolarsResult<()> { let df = CsvReader::new(file) .has_header(false) .with_projection(Some(vec![4, 5])) - .with_delimiter(b' ') + .with_separator(b' ') .finish()?; assert_eq!(df.width(), 2); @@ -501,7 +500,7 @@ fn test_projection_idx() -> PolarsResult<()> { let out = CsvReader::new(file) .has_header(false) .with_projection(Some(vec![4, 6])) - .with_delimiter(b' ') + .with_separator(b' ') .finish(); assert!(out.is_err()); @@ -788,7 +787,7 @@ fn test_infer_schema_eol() -> PolarsResult<()> { } #[test] -fn test_whitespace_delimiters() -> PolarsResult<()> { +fn test_whitespace_separators() -> PolarsResult<()> { let tsv = "\ta\tb\tc\n1\ta1\tb1\tc1\n2\ta2\tb2\tc2\n".to_string(); let contents = vec![ @@ -799,7 +798,7 @@ fn test_whitespace_delimiters() -> PolarsResult<()> { for (content, sep) in contents { let file = Cursor::new(&content); - let df = CsvReader::new(file).with_delimiter(sep).finish()?; + let df = CsvReader::new(file).with_separator(sep).finish()?; assert_eq!(df.shape(), (2, 4)); assert_eq!(df.get_column_names(), &["", "a", "b", "c"]); @@ -828,7 +827,7 @@ fn test_tsv_header_offset() -> PolarsResult<()> { let file = 
Cursor::new(csv); let df = CsvReader::new(file) .truncate_ragged_lines(true) - .with_delimiter(b'\t') + .with_separator(b'\t') .finish()?; assert_eq!(df.shape(), (3, 2)); @@ -859,7 +858,7 @@ fn test_null_values_infer_schema() -> PolarsResult<()> { fn test_comma_separated_field_in_tsv() -> PolarsResult<()> { let csv = "first\tsecond\n1\t2.3,2.4\n3\t4.5,4.6\n"; let file = Cursor::new(csv); - let df = CsvReader::new(file).with_delimiter(b'\t').finish()?; + let df = CsvReader::new(file).with_separator(b'\t').finish()?; assert_eq!(df.dtypes(), &[DataType::Int64, DataType::Utf8]); Ok(()) } @@ -1096,7 +1095,7 @@ fn test_try_parse_dates_3380() -> PolarsResult<()> { 46.685;7.953;2022-05-10T08:07:12Z;8.8;0.00"; let file = Cursor::new(csv); let df = CsvReader::new(file) - .with_delimiter(b';') + .with_separator(b';') .with_try_parse_dates(true) .finish()?; assert_eq!(df.column("validdate")?.null_count(), 0); diff --git a/crates/polars/tests/it/lazy/explodes.rs b/crates/polars/tests/it/lazy/explodes.rs index 540af19a1525..01cc6ff69db7 100644 --- a/crates/polars/tests/it/lazy/explodes.rs +++ b/crates/polars/tests/it/lazy/explodes.rs @@ -9,7 +9,7 @@ fn test_explode_row_numbers() -> PolarsResult<()> { "text" => ["one two three four", "uno dos tres cuatro"] ]? .lazy() - .select([col("text").str().split(" ").alias("tokens")]) + .select([col("text").str().split(lit(" ")).alias("tokens")]) .with_row_count("row_nr", None) .explode([col("tokens")]) .select([col("row_nr"), col("tokens")]) diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index 290bd9f3efca..a4cfc7796a66 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -116,7 +116,7 @@ fn includes_null_predicate_3038() -> PolarsResult<()> { #[test] #[cfg(feature = "dtype-categorical")] fn test_when_then_otherwise_cats() -> PolarsResult<()> { - polars::enable_string_cache(true); + polars::enable_string_cache(); let lf = df!["book" => [Some("bookA"), None, diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 19e13dc6dccb..4b9ca8123593 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -47,12 +47,9 @@ fn test_shift_and_fill_window_function() -> PolarsResult<()> { .lazy() .select([ col("fruits"), - col("B").shift_and_fill(-1, lit(-1)).over_with_options( - [col("fruits")], - WindowOptions { - mapping: WindowMapping::Join, - }, - ), + col("B") + .shift_and_fill(-1, lit(-1)) + .over_with_options([col("fruits")], WindowMapping::Join), ]) .collect()?; @@ -61,12 +58,9 @@ fn test_shift_and_fill_window_function() -> PolarsResult<()> { .lazy() .select([ col("fruits"), - col("B").shift_and_fill(-1, lit(-1)).over_with_options( - [col("fruits")], - WindowOptions { - mapping: WindowMapping::Join, - }, - ), + col("B") + .shift_and_fill(-1, lit(-1)) + .over_with_options([col("fruits")], WindowMapping::Join), ]) .collect()?; @@ -87,12 +81,7 @@ fn test_exploded_window_function() -> PolarsResult<()> { col("fruits"), col("B") .shift(1) - .over_with_options( - [col("fruits")], - WindowOptions { - mapping: WindowMapping::Explode, - }, - ) + .over_with_options([col("fruits")], WindowMapping::Explode) .alias("shifted"), ]) .collect()?; @@ -111,12 +100,7 @@ fn test_exploded_window_function() -> PolarsResult<()> { col("fruits"), col("B") .shift_and_fill(1, lit(-1.0f32)) - .over_with_options( - [col("fruits")], - 
WindowOptions { - mapping: WindowMapping::Explode, - }, - ) + .over_with_options([col("fruits")], WindowMapping::Explode) .alias("shifted"), ]) .collect()?; @@ -185,12 +169,7 @@ fn test_literal_window_fn() -> PolarsResult<()> { .lazy() .select([repeat(1, count()) .cumsum(false) - .over_with_options( - [col("chars")], - WindowOptions { - mapping: WindowMapping::Join, - }, - ) + .over_with_options([col("chars")], WindowMapping::Join) .alias("foo")]) .collect()?; diff --git a/crates/polars/tests/it/lazy/folds.rs b/crates/polars/tests/it/lazy/folds.rs index 3a13b814b2ff..1b3f908fa914 100644 --- a/crates/polars/tests/it/lazy/folds.rs +++ b/crates/polars/tests/it/lazy/folds.rs @@ -21,7 +21,7 @@ fn test_fold_wildcard() -> PolarsResult<()> { // test if we don't panic due to wildcard let _out = df1 .lazy() - .select([all_horizontal([col("*").is_not_null()])]) + .select([polars_lazy::dsl::all_horizontal([col("*").is_not_null()])]) .collect()?; Ok(()) } diff --git a/crates/polars/tests/it/lazy/group_by_dynamic.rs b/crates/polars/tests/it/lazy/group_by_dynamic.rs index 1fa5ec6a396f..6c65a4041ec8 100644 --- a/crates/polars/tests/it/lazy/group_by_dynamic.rs +++ b/crates/polars/tests/it/lazy/group_by_dynamic.rs @@ -48,7 +48,7 @@ fn test_group_by_dynamic_week_bounds() -> PolarsResult<()> { period: Duration::parse("1w"), offset: Duration::parse("0w"), closed_window: ClosedWindow::Left, - truncate: false, + label: Label::DataPoint, include_boundaries: true, start_by: StartBy::DataPoint, ..Default::default() diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 36e63d64773d..d9aa60870e58 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -1,6 +1,6 @@ // used only if feature="is_in", feature="dtype-categorical" #[allow(unused_imports)] -use polars_core::{with_string_cache, SINGLE_LOCK}; +use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; use super::*; @@ -132,24 +132,22 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { ]?; let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); - let _: PolarsResult<_> = with_string_cache(|| { - let s = Series::new("x", ["a", "b", "c"]).strict_cast(&DataType::Categorical(None))?; - let out = df - .lazy() - .with_column(col("a").strict_cast(DataType::Categorical(None))) - .filter(col("a").is_in(lit(s).alias("x"))) - .collect()?; - - let mut expected = df![ - "a" => ["a", "b", "c"], - "b" => [1, 2, 3] - ]?; - expected.try_apply("a", |s| s.cast(&DataType::Categorical(None)))?; - assert!(out.frame_equal(&expected)); + let s = Series::new("x", ["a", "b", "c"]).strict_cast(&DataType::Categorical(None))?; + let out = df + .lazy() + .with_column(col("a").strict_cast(DataType::Categorical(None))) + .filter(col("a").is_in(lit(s).alias("x"))) + .collect()?; - Ok(()) - }); + let mut expected = df![ + "a" => ["a", "b", "c"], + "b" => [1, 2, 3] + ]?; + expected.try_apply("a", |s| s.cast(&DataType::Categorical(None)))?; + assert!(out.frame_equal(&expected)); Ok(()) } diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs index d0af51efaab3..90a576720a14 100644 --- a/crates/polars/tests/it/lazy/queries.rs +++ b/crates/polars/tests/it/lazy/queries.rs @@ -76,7 +76,7 @@ fn test_special_group_by_schemas() -> PolarsResult<()> { every: Duration::parse("2i"), period: Duration::parse("2i"), offset: Duration::parse("0i"), - truncate: false, + label: 
Label::DataPoint, include_boundaries: false, closed_window: ClosedWindow::Left, ..Default::default() @@ -225,7 +225,7 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { .collect()?; let out = out.column("A")?; - let out = out.list()?.get(1).unwrap(); + let out = out.list()?.get_as_series(1).unwrap(); let out = out.i32()?; assert_eq!(Vec::from(out), &[Some(16)]); diff --git a/docs/_build/API_REFERENCE_LINKS.yml b/docs/_build/API_REFERENCE_LINKS.yml new file mode 100644 index 000000000000..73b437cc9324 --- /dev/null +++ b/docs/_build/API_REFERENCE_LINKS.yml @@ -0,0 +1,374 @@ +python: + DataFrame: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/index.html + Categorical: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.Categorical.html + Series: https://pola-rs.github.io/polars/py-polars/html/reference/series/index.html + select: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.select.html + filter: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.filter.html + with_columns: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.with_columns.html + group_by: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by.html + join: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.join.html + hstack: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.hstack.html + read_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html + write_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_csv.html + read_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html + write_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html + read_ipc: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ipc.html + min: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.min.html + max: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.max.html + value_counts: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.value_counts.html + unnest: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.unnest.html + struct: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.struct.html + is_duplicated: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_duplicated.html + sample: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.sample.html + head: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.head.html + tail: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.tail.html + describe: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.describe.html + col: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.col.html + sort: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.sort.html + scan_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_csv.html + collect: 
https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.collect.html + fold: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.fold.html + concat_str: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.concat_str.html + str.split: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.split.html + Expr.list: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/list.html + element: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.element.html + all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.all.html + exclude: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.exclude.html + alias: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.alias.html + prefix: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.prefix.html + suffix: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.suffix.html + map_alias: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.map_alias.html + n_unique: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.n_unique.html + approx_n_unique: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.approx_n_unique.html + when: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.when.html + concat_list: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.concat_list.html + list.eval: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.list.eval.html + null_count: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.null_count.html + is_null: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_null.html + fill_null: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.fill_null.html + interpolate: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.interpolate.html + fill_nan: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.fill_nan.html + operators: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/operators.html + map: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.map.html + apply: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.apply.html + over: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.over.html + implode: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.implode.html + DataFrame.explode: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.explode.html + read_database_connectorx: + name: read_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_database.html + feature_flags: ['connectorx'] + read_database: + name: read_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_database.html + write_database: + name: write_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_database.html + read_database_uri: 
https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_database_uri.html + read_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html + write_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html + scan_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_parquet.html + read_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html + read_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ndjson.html + write_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_ndjson.html + write_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html + scan_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_ndjson.html + scan_pyarrow_dataset: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_pyarrow_dataset.html + from_arrow: + name: from_arrow + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.from_arrow.html + feature_flags: ['fsspec','pyarrow'] + show_graph: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.show_graph.html + lazy: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.lazy.html + explain: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.explain.html + fetch: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.fetch.html + SQLContext: https://pola-rs.github.io/polars/py-polars/html/reference/sql + SQLregister: + name: register + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.register.html#polars.SQLContext.register + SQLregister_many: + name: register_many + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.register_many.html + SQLquery: + name: query + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.query.html + SQLexecute: + name: execute + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.execute.html + join_asof: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.join_asof.html + concat: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.concat.html + pivot: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.pivot.html + melt: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html + is_between: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_between.html + + date_range: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.date_range.html + upsample: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.upsample.html + group_by_dynamic: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html + cast: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.cast.html + np.log: + name: log + link: https://numpy.org/doc/stable/reference/generated/numpy.log.html + feature_flags: ['numpy'] + Array: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.Array.html + Series.arr: 
https://pola-rs.github.io/polars/py-polars/html/reference/series/array.html + Series.dt.day: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.dt.day.html + + selectors: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html + cs.numeric: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.numeric + cs.by_name: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.by_name + cs.first: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.first + cs.temporal: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.temporal + cs.contains: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.contains + cs.matches: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.matches + is_selector: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.is_selector + selector_column_names: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.selector_column_names + + dt.convert_time_zone: + name: dt.convert_time_zone + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.convert_time_zone.html + feature_flags: ['timezone'] + dt.replace_time_zone: + name: dt.replace_time_zone + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.replace_time_zone.html + feature_flags: ['timezone'] + dt.to_string: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.to_string.html + dt.year: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.year.html + + str.starts_with: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.starts_with.html + str.ends_with: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.ends_with.html + str.extract: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.extract.html + str.extract_all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.extract_all.html + str.contains: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.contains.html + str.replace: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.replace.html + str.replace_all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.replace_all.html + str.to_datetime: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.to_datetime.html + str.to_date: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.to_date.html + str.len_chars: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.len_chars.html + str.len_bytes: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.len_bytes.html + + struct.field: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.struct.field.html + struct.rename_fields: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.struct.rename_fields.html + +rust: + DataFrame: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html + 
Series: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html + Categorical: + name: Categorical + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/enum.DataType.html#variant.Categorical + feature_flags: ['dtype-categorical'] + select: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.select + filter: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.filter + with_columns: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.with_columns + group_by: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by + group_by_dynamic: + name: group_by_dynamic + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by_dynamic + feature_flags: [dynamic_group_by] + join: https://pola-rs.github.io/polars/docs/rust/dev/polars_core/frame/hash_join/index.html + hstack: https://pola-rs.github.io/polars/docs/rust/dev/polars_core/frame/struct.DataFrame.html#method.hstack + concat: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/functions/fn.concat.html + SQLContext: https://pola-rs.github.io/polars/py-polars/html/reference/sql.html + + operators: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Operator.html + + Array: https://pola-rs.github.io/polars/docs/rust/dev/polars/datatypes/enum.DataType.html#variant.Array + + DataFrame.explode: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.explode + pivot: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/pivot/fn.pivot.html + melt: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.melt + upsample: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.upsample + join_asof: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/trait.AsofJoin.html#method.join_asof + unnest: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.unnest + + read_csv: + name: CsvReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/csv/struct.CsvReader.html + feature_flags: ['csv'] + scan_csv: + name: LazyCsvReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyCsvReader.html + feature_flags: ['csv'] + write_csv: + name: CsvWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/csv/struct.CsvWriter.html + feature_flags: ['csv'] + read_json: + name: JsonReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonReader.html + feature_flags: ['json'] + read_ndjson: + name: JsonLineReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/ndjson/core/struct.JsonLineReader.html + feature_flags: ['json'] + write_json: + name: JsonWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: ['json'] + write_ndjson: + name: JsonWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: ['json'] + scan_ndjson: + name: LazyJsonLineReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyJsonLineReader.html + feature_flags: ['json'] + read_parquet: + name: ParquetReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/parquet/struct.ParquetReader.html + 
feature_flags: ['parquet'] + write_parquet: + name: ParquetWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/parquet/struct.ParquetWriter.html + feature_flags: ['parquet'] + scan_parquet: + name: scan_parquet + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyFrame.html#method.scan_parquet + feature_flags: ['parquet'] + read_ipc: + name: IpcReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/prelude/struct.IpcReader.html + feature_flags: ['ipc'] + scan_pyarrow_dataset: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_pyarrow_dataset.html + + min: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html#method.min + max: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html#method.max + struct: + name: Struct + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/datatypes/enum.DataType.html#variant.Struct + feature_flags: ['dtype-struct'] + implode: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.implode + sample: + name: sample_n + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.sample_n + head: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.head + tail: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.tail + describe: + name: describe + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.describe + feature_flags: ['describe'] + collect: + name: collect + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyFrame.html#method.collect + feature_flags: ['streaming'] + + col: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.col.html + element: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.col.html + all: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/functions/fn.all.html + when: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.when.html + + sort: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.sort + arr.eval: + name: arr + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.arr + feature_flags: ['list_eval','rank'] + fold: + name: fold_exprs + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.fold_exprs.html + concat_str: + name: concat_str + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.concat_str.html + feature_flags: ['concat_str'] + concat_list: + name: concat_lst + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.concat_lst.html + map: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.map + apply: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.apply + over: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.over + + alias: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.alias + approx_n_unique: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.approx_n_unique + cast: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.cast + exclude: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.exclude + fill_nan: 
https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.fill_nan + fill_null: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.fill_null + n_unique: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.n_unique + null_count: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.null_count + interpolate: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.interpolate + is_between: https://github.com/pola-rs/polars/issues/11285 + is_duplicated: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.is_duplicated + is_null: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/enum.Expr.html#method.is_null + value_counts: + name: value_counts + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.value_counts + feature_flags: [dtype-struct] + + Expr.list: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/struct.ListNameSpace.html + Series.arr: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/struct.ArrayNameSpace.html + + date_range: + name: date_range + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/functions/fn.date_range.html + feature_flags: [range, dtype-date] + + selectors: https://github.com/pola-rs/polars/issues/10594 + cs.numeric: https://github.com/pola-rs/polars/issues/10594 + cs.by_name: https://github.com/pola-rs/polars/issues/10594 + cs.first: https://github.com/pola-rs/polars/issues/10594 + cs.temporal: https://github.com/pola-rs/polars/issues/10594 + cs.contains: https://github.com/pola-rs/polars/issues/10594 + cs.matches: https://github.com/pola-rs/polars/issues/10594 + is_selector: https://github.com/pola-rs/polars/issues/10594 + selector_column_names: https://github.com/pola-rs/polars/issues/10594 + + dt.convert_time_zone: + name: dt.convert_time_zone + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.convert_time_zone + feature_flags: [timezones] + dt.replace_time_zone: + name: dt.replace_time_zone + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.replace_time_zone + feature_flags: [timezones] + dt.to_string: + name: dt.to_string + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.to_string + feature_flags: [temporal] + dt.year: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.year + Series.dt.day: + name: dt.day + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.day + feature_flags: [temporal] + + list.eval: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/trait.ListNameSpaceExtension.html#method.eval + + str.contains: + name: str.contains + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.contains + feature_flags: [regex] + str.extract: + name: str.extract + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.extract + str.extract_all: + name: str.extract_all + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.extract_all + str.replace: + name: str.replace + link: 
https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.replace + feature_flags: [regex] + str.replace_all: + name: str.replace_all + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.replace_all + feature_flags: [regex] + str.starts_with: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.starts_with + str.ends_with: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.ends_with + str.split: + name: str.split + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.split + str.to_date: + name: str.replace_all + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.to_date + feature_flags: [dtype-date] + str.to_datetime: + name: str.replace_all + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.to_datetime + feature_flags: [dtype-datetime] + str.len_chars: + name: str.len_chars + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.len_chars + str.len_bytes: + name: str.len_bytes + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.len_bytes + + struct.rename_fields: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/struct.StructNameSpace.html#method.rename_fields + struct.field: + name: struct.field_by_name + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/struct.StructNameSpace.html#method.field_by_name diff --git a/docs/_build/assets/logo.png b/docs/_build/assets/logo.png new file mode 100644 index 000000000000..9b5486edce3b Binary files /dev/null and b/docs/_build/assets/logo.png differ diff --git a/docs/_build/css/extra.css b/docs/_build/css/extra.css new file mode 100644 index 000000000000..420db3966780 --- /dev/null +++ b/docs/_build/css/extra.css @@ -0,0 +1,64 @@ +:root { + --md-primary-fg-color: #0B7189 ; + --md-primary-fg-color--light: #C2CCD6; + --md-primary-fg-color--dark: #103547; + --md-text-font: 'Proxima Nova', sans-serif; +} + + +span .md-typeset .emojione, .md-typeset .gemoji, .md-typeset .twemoji { + vertical-align: text-bottom; +} + +@font-face { + font-family: 'Proxima Nova', sans-serif; + src: 'https://fonts.cdnfonts.com/css/proxima-nova-2' +} + +:root { + --md-code-font: "Source Code Pro" !important; +} + +.contributor_icon { + height:40px; + width:40px; + border-radius: 20px; + margin: 0 5px; +} + +.feature-flag{ + background-color: rgba(255, 245, 214,.5); + border: none; + padding: 0px 5px; + text-align: center; + text-decoration: none; + display: inline-block; + margin: 4px 2px; + cursor: pointer; + font-size: .85em; +} + +[data-md-color-scheme=slate] .feature-flag{ + background-color:var(--md-code-bg-color); +} +.md-typeset ol li, .md-typeset ul li{ + margin-bottom: 0em !important; +} + +:root { + --md-admonition-icon--rust: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 512 512'%3E%3C!--! 
Font Awesome Free 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2023 Fonticons, Inc.--%3E%3Cpath d='m508.52 249.75-21.82-13.51c-.17-2-.34-3.93-.55-5.88l18.72-17.5a7.35 7.35 0 0 0-2.44-12.25l-24-9c-.54-1.88-1.08-3.78-1.67-5.64l15-20.83a7.35 7.35 0 0 0-4.79-11.54l-25.42-4.15c-.9-1.73-1.79-3.45-2.73-5.15l10.68-23.42a7.35 7.35 0 0 0-6.95-10.39l-25.82.91q-1.79-2.22-3.61-4.4L439 81.84a7.36 7.36 0 0 0-8.84-8.84L405 78.93q-2.17-1.83-4.4-3.61l.91-25.82a7.35 7.35 0 0 0-10.39-7L367.7 53.23c-1.7-.94-3.43-1.84-5.15-2.73l-4.15-25.42a7.35 7.35 0 0 0-11.54-4.79L326 35.26c-1.86-.59-3.75-1.13-5.64-1.67l-9-24a7.35 7.35 0 0 0-12.25-2.44l-17.5 18.72c-1.95-.21-3.91-.38-5.88-.55L262.25 3.48a7.35 7.35 0 0 0-12.5 0L236.24 25.3c-2 .17-3.93.34-5.88.55l-17.5-18.72a7.35 7.35 0 0 0-12.25 2.44l-9 24c-1.89.55-3.79 1.08-5.66 1.68l-20.82-15a7.35 7.35 0 0 0-11.54 4.79l-4.15 25.41c-1.73.9-3.45 1.79-5.16 2.73l-23.4-10.63a7.35 7.35 0 0 0-10.39 7l.92 25.81c-1.49 1.19-3 2.39-4.42 3.61L81.84 73A7.36 7.36 0 0 0 73 81.84L78.93 107c-1.23 1.45-2.43 2.93-3.62 4.41l-25.81-.91a7.42 7.42 0 0 0-6.37 3.26 7.35 7.35 0 0 0-.57 7.13l10.66 23.41c-.94 1.7-1.83 3.43-2.73 5.16l-25.41 4.14a7.35 7.35 0 0 0-4.79 11.54l15 20.82c-.59 1.87-1.13 3.77-1.68 5.66l-24 9a7.35 7.35 0 0 0-2.44 12.25l18.72 17.5c-.21 1.95-.38 3.91-.55 5.88l-21.86 13.5a7.35 7.35 0 0 0 0 12.5l21.82 13.51c.17 2 .34 3.92.55 5.87l-18.72 17.5a7.35 7.35 0 0 0 2.44 12.25l24 9c.55 1.89 1.08 3.78 1.68 5.65l-15 20.83a7.35 7.35 0 0 0 4.79 11.54l25.42 4.15c.9 1.72 1.79 3.45 2.73 5.14l-10.63 23.43a7.35 7.35 0 0 0 .57 7.13 7.13 7.13 0 0 0 6.37 3.26l25.83-.91q1.77 2.22 3.6 4.4L73 430.16a7.36 7.36 0 0 0 8.84 8.84l25.16-5.93q2.18 1.83 4.41 3.61l-.92 25.82a7.35 7.35 0 0 0 10.39 6.95l23.43-10.68c1.69.94 3.42 1.83 5.14 2.73l4.15 25.42a7.34 7.34 0 0 0 11.54 4.78l20.83-15c1.86.6 3.76 1.13 5.65 1.68l9 24a7.36 7.36 0 0 0 12.25 2.44l17.5-18.72c1.95.21 3.92.38 5.88.55l13.51 21.82a7.35 7.35 0 0 0 12.5 0l13.51-21.82c2-.17 3.93-.34 5.88-.56l17.5 18.73a7.36 7.36 0 0 0 12.25-2.44l9-24c1.89-.55 3.78-1.08 5.65-1.68l20.82 15a7.34 7.34 0 0 0 11.54-4.78l4.15-25.42c1.72-.9 3.45-1.79 5.15-2.73l23.42 10.68a7.35 7.35 0 0 0 10.39-6.95l-.91-25.82q2.22-1.79 4.4-3.61l25.15 5.93a7.36 7.36 0 0 0 8.84-8.84L433.07 405q1.83-2.17 3.61-4.4l25.82.91a7.23 7.23 0 0 0 6.37-3.26 7.35 7.35 0 0 0 .58-7.13l-10.68-23.42c.94-1.7 1.83-3.43 2.73-5.15l25.42-4.15a7.35 7.35 0 0 0 4.79-11.54l-15-20.83c.59-1.87 1.13-3.76 1.67-5.65l24-9a7.35 7.35 0 0 0 2.44-12.25l-18.72-17.5c.21-1.95.38-3.91.55-5.87l21.82-13.51a7.35 7.35 0 0 0 0-12.5Zm-151 129.08A13.91 13.91 0 0 0 341 389.51l-7.64 35.67a187.51 187.51 0 0 1-156.36-.74l-7.64-35.66a13.87 13.87 0 0 0-16.46-10.68l-31.51 6.76a187.38 187.38 0 0 1-16.26-19.21H258.3c1.72 0 2.89-.29 2.89-1.91v-54.19c0-1.57-1.17-1.91-2.89-1.91h-44.83l.05-34.35H262c4.41 0 23.66 1.28 29.79 25.87 1.91 7.55 6.17 32.14 9.06 40 2.89 8.82 14.6 26.46 27.1 26.46H407a187.3 187.3 0 0 1-17.34 20.09Zm25.77 34.49A15.24 15.24 0 1 1 368 398.08h.44a15.23 15.23 0 0 1 14.8 15.24Zm-225.62-.68a15.24 15.24 0 1 1-15.25-15.25h.45a15.25 15.25 0 0 1 14.75 15.25Zm-88.1-178.49 32.83-14.6a13.88 13.88 0 0 0 7.06-18.33L102.69 186h26.56v119.73h-53.6a187.65 187.65 0 0 1-6.08-71.58Zm-11.26-36.06a15.24 15.24 0 0 1 15.23-15.25H74a15.24 15.24 0 1 1-15.67 15.24Zm155.16 24.49.05-35.32h63.26c3.28 0 23.07 3.77 23.07 18.62 0 12.29-15.19 16.7-27.68 16.7ZM399 306.71c-9.8 1.13-20.63-4.12-22-10.09-5.78-32.49-15.39-39.4-30.57-51.4 18.86-11.95 38.46-29.64 
38.46-53.26 0-25.52-17.49-41.59-29.4-49.48-16.76-11-35.28-13.23-40.27-13.23h-198.9a187.49 187.49 0 0 1 104.89-59.19l23.47 24.6a13.82 13.82 0 0 0 19.6.44l26.26-25a187.51 187.51 0 0 1 128.37 91.43l-18 40.57a14 14 0 0 0 7.09 18.33l34.59 15.33a187.12 187.12 0 0 1 .4 32.54h-19.28c-1.91 0-2.69 1.27-2.69 3.13v8.82C421 301 409.31 305.58 399 306.71ZM240 60.21A15.24 15.24 0 0 1 255.21 45h.45A15.24 15.24 0 1 1 240 60.21ZM436.84 214a15.24 15.24 0 1 1 0-30.48h.44a15.24 15.24 0 0 1-.44 30.48Z'/%3E%3C/svg%3E"); + } + .md-typeset .admonition.rust, + .md-typeset details.rust { + border-color: rgb(205, 121, 44); + } + .md-typeset .rust > .admonition-title, + .md-typeset .rust > summary { + background-color: rgb(205, 121, 44,.1); + } + .md-typeset .rust > .admonition-title::before, + .md-typeset .rust > summary::before { + background-color:rgb(205, 121, 44); + -webkit-mask-image: var(--md-admonition-icon--rust); + mask-image: var(--md-admonition-icon--rust); + } \ No newline at end of file diff --git a/docs/_build/overrides/404.html b/docs/_build/overrides/404.html new file mode 100644 index 000000000000..ee9b8faa2aba --- /dev/null +++ b/docs/_build/overrides/404.html @@ -0,0 +1,222 @@ +{% extends "main.html" %} +{% block content %} +

+ +{% endblock %} diff --git a/docs/_build/scripts/macro.py b/docs/_build/scripts/macro.py new file mode 100644 index 000000000000..d93d5170adec --- /dev/null +++ b/docs/_build/scripts/macro.py @@ -0,0 +1,156 @@ +from collections import OrderedDict +import os +from typing import List, Optional, Set +import yaml +import logging + + +# Supported Languages and their metadata +LANGUAGES = OrderedDict( + python={ + "extension": ".py", + "display_name": "Python", + "icon_name": "python", + "code_name": "python", + }, + rust={ + "extension": ".rs", + "display_name": "Rust", + "icon_name": "rust", + "code_name": "rust", + }, +) + +# Load all links to reference docs +with open("docs/_build/API_REFERENCE_LINKS.yml", "r") as f: + API_REFERENCE_LINKS = yaml.load(f, Loader=yaml.CLoader) + + +def create_feature_flag_link(feature_name: str) -> str: + """Create a feature flag warning telling the user to activate a certain feature before running the code + + Args: + feature_name (str): name of the feature + + Returns: + str: Markdown formatted string with a link and the feature flag message + """ + return f'[:material-flag-plus: Available on feature {feature_name}](/polars/user-guide/installation/#feature-flags "To use this functionality enable the feature flag {feature_name}"){{.feature-flag}}' + + +def create_feature_flag_links(language: str, api_functions: List[str]) -> List[str]: + """Generate markdown feature flags for the code tas based on the api_functions. + It checks for the key feature_flag in the configuration yaml for the function and if it exists print out markdown + + Args: + language (str): programming languages + api_functions (List[str]): Api functions that are called + + Returns: + List[str]: Per unique feature flag a markdown formatted string for the feature flag + """ + api_functions_info = [ + info + for f in api_functions + if (info := API_REFERENCE_LINKS.get(language).get(f)) + ] + feature_flags: Set[str] = { + flag + for info in api_functions_info + if type(info) == dict and info.get("feature_flags") + for flag in info.get("feature_flags") + } + + return [create_feature_flag_link(flag) for flag in feature_flags] + + +def create_api_function_link(language: str, function_key: str) -> Optional[str]: + """Create an API link in markdown with an icon of the YAML file + + Args: + language (str): programming language + function_key (str): Key to the specific function + + Returns: + str: If the function is found than the link else None + """ + info = API_REFERENCE_LINKS.get(language, {}).get(function_key) + + if info is None: + logging.warning(f"Could not find {function_key} for language {language}") + return None + else: + # Either be a direct link + if type(info) == str: + return f"[:material-api: `{function_key}`]({info})" + else: + function_name = info["name"] + link = info["link"] + return f"[:material-api: `{function_name}`]({link})" + + +def code_tab( + base_path: str, + section: Optional[str], + language_info: dict, + api_functions: List[str], +) -> str: + """Generate a single tab for the code block corresponding to a specific language. + It gets the code at base_path and possible section and pretty prints markdown for it + + Args: + base_path (str): path where the code is located + section (str, optional): section in the code that should be displayed + language_info (dict): Language specific information (icon name, display name, ...) 
+ api_functions (List[str]): List of api functions which should be linked + + Returns: + str: A markdown formatted string represented a single tab + """ + language = language_info["code_name"] + + # Create feature flags + feature_flags_links = create_feature_flag_links(language, api_functions) + + # Create API Links if they are defined in the YAML + api_functions = [ + link for f in api_functions if (link := create_api_function_link(language, f)) + ] + language_headers = " ·".join(api_functions + feature_flags_links) + + # Create path for Snippets extension + snippets_file_name = f"{base_path}:{section}" if section else f"{base_path}" + + # See Content Tabs for details https://squidfunk.github.io/mkdocs-material/reference/content-tabs/ + return f"""=== \":fontawesome-brands-{language_info['icon_name']}: {language_info['display_name']}\" + {language_headers} + ```{language} + --8<-- \"{snippets_file_name}\" + ``` + """ + + +def define_env(env): + @env.macro + def code_block( + path: str, section: str = None, api_functions: List[str] = None + ) -> str: + """Dynamically generate a code block for the code located under {language}/path + + Args: + path (str): base_path for each language + section (str, optional): Optional segment within the code file. Defaults to None. + api_functions (List[str], optional): API functions that should be linked. Defaults to None. + Returns: + str: Markdown tabbed code block with possible links to api functions and feature flags + """ + result = [] + + for language, info in LANGUAGES.items(): + base_path = f"{language}/{path}{info['extension']}" + full_path = "docs/src/" + base_path + # Check if file exists for the language + if os.path.exists(full_path): + result.append(code_tab(base_path, section, info, api_functions)) + + return "\n".join(result) diff --git a/docs/_build/scripts/people.py b/docs/_build/scripts/people.py new file mode 100644 index 000000000000..10186549d4d8 --- /dev/null +++ b/docs/_build/scripts/people.py @@ -0,0 +1,41 @@ +import itertools +from github import Github, Auth +import os + +token = os.getenv("GITHUB_TOKEN") +auth = Auth.Token(token) if token else None +g = Github(auth=auth) + +ICON_TEMPLATE = "[![{login}]({avatar_url}){{.contributor_icon}}]({html_url})" + + +def get_people_md(): + repo = g.get_repo("pola-rs/polars") + contributors = repo.get_contributors() + with open("./docs/people.md", "w") as f: + for c in itertools.islice(contributors, 50): + # We love dependabot, but he doesn't need a spot on our website + if c.login == "dependabot[bot]": + continue + + f.write( + ICON_TEMPLATE.format( + login=c.login, + avatar_url=c.avatar_url, + html_url=c.html_url, + ) + + "\n" + ) + + +def on_startup(command, dirty): + """Mkdocs hook to autogenerate docs/people.md on startup""" + try: + get_people_md() + except Exception as e: + msg = f"WARNING:{__file__}: Could not generate docs/people.md. Got error: {str(e)}" + print(msg) + + +if __name__ == "__main__": + get_people_md() diff --git a/docs/_build/snippets/under_construction.md b/docs/_build/snippets/under_construction.md new file mode 100644 index 000000000000..c4bb56a735af --- /dev/null +++ b/docs/_build/snippets/under_construction.md @@ -0,0 +1,4 @@ +!!! warning ":construction: Under Construction :construction: " + + This section is still under development. Want to help out? Consider contributing and making a [pull request](https://github.com/pola-rs/polars) to our repository. 
+ Please read our [Contribution Guidelines](https://github.com/pola-rs/polars/blob/main/CONTRIBUTING.md) on how to proceed. diff --git a/docs/data/apple_stock.csv b/docs/data/apple_stock.csv new file mode 100644 index 000000000000..6c3f9752d587 --- /dev/null +++ b/docs/data/apple_stock.csv @@ -0,0 +1,101 @@ +Date,Close +1981-02-23,24.62 +1981-05-06,27.38 +1981-05-18,28.0 +1981-09-25,14.25 +1982-07-08,11.0 +1983-01-03,28.5 +1983-04-06,40.0 +1983-10-03,23.13 +1984-07-27,27.13 +1984-08-17,27.5 +1984-08-24,28.12 +1985-05-07,20.0 +1985-09-03,14.75 +1985-12-06,19.75 +1986-03-12,24.75 +1986-04-09,27.13 +1986-04-17,29.0 +1986-09-17,34.25 +1986-11-26,40.5 +1987-02-25,69.13 +1987-04-15,71.0 +1988-02-23,42.75 +1988-03-07,46.88 +1988-03-23,42.5 +1988-12-12,38.5 +1988-12-19,40.75 +1989-04-17,39.25 +1989-11-13,46.5 +1990-11-23,36.38 +1991-03-22,63.25 +1991-05-17,47.0 +1991-06-03,49.25 +1991-06-18,42.12 +1992-06-25,45.62 +1992-10-12,44.0 +1993-07-06,37.75 +1993-09-15,24.5 +1993-09-30,23.37 +1993-11-09,30.12 +1994-01-24,35.0 +1994-03-15,37.62 +1994-06-27,26.25 +1994-07-08,27.06 +1994-12-21,38.38 +1995-07-06,47.0 +1995-10-16,36.13 +1995-11-17,40.13 +1995-12-12,38.0 +1996-01-31,27.63 +1996-02-05,29.25 +1996-07-15,17.19 +1996-09-20,22.87 +1996-12-23,23.25 +1997-03-17,16.5 +1997-05-09,17.06 +1997-08-06,26.31 +1997-09-30,21.69 +1998-02-09,19.19 +1998-03-12,27.0 +1998-05-07,30.19 +1998-05-12,30.12 +1999-07-09,55.63 +1999-12-08,110.06 +2000-01-14,100.44 +2000-06-27,51.75 +2000-07-05,51.62 +2000-07-19,52.69 +2000-08-07,47.94 +2000-08-28,58.06 +2000-09-26,51.44 +2001-03-02,19.25 +2001-12-10,22.54 +2002-01-25,23.25 +2002-03-07,24.38 +2002-08-16,15.81 +2002-10-03,14.3 +2003-11-18,20.41 +2004-02-26,23.04 +2004-03-08,26.0 +2004-09-22,36.92 +2005-06-24,37.76 +2005-12-07,73.95 +2005-12-22,74.02 +2006-06-22,59.58 +2006-11-28,91.81 +2007-08-13,127.79 +2007-12-04,179.81 +2007-12-31,198.08 +2008-05-09,183.45 +2008-06-27,170.09 +2009-08-03,166.43 +2010-04-01,235.97 +2010-12-10,320.56 +2011-04-28,346.75 +2011-12-02,389.7 +2012-05-16,546.08 +2012-12-04,575.85 +2013-07-05,417.42 +2013-11-07,512.49 +2014-02-25,522.06 \ No newline at end of file diff --git a/docs/data/iris.csv b/docs/data/iris.csv new file mode 100644 index 000000000000..d6b466b31892 --- /dev/null +++ b/docs/data/iris.csv @@ -0,0 +1,151 @@ +sepal_length,sepal_width,petal_length,petal_width,species +5.1,3.5,1.4,.2,Setosa +4.9,3,1.4,.2,Setosa +4.7,3.2,1.3,.2,Setosa +4.6,3.1,1.5,.2,Setosa +5,3.6,1.4,.2,Setosa +5.4,3.9,1.7,.4,Setosa +4.6,3.4,1.4,.3,Setosa +5,3.4,1.5,.2,Setosa +4.4,2.9,1.4,.2,Setosa +4.9,3.1,1.5,.1,Setosa +5.4,3.7,1.5,.2,Setosa +4.8,3.4,1.6,.2,Setosa +4.8,3,1.4,.1,Setosa +4.3,3,1.1,.1,Setosa +5.8,4,1.2,.2,Setosa +5.7,4.4,1.5,.4,Setosa +5.4,3.9,1.3,.4,Setosa +5.1,3.5,1.4,.3,Setosa +5.7,3.8,1.7,.3,Setosa +5.1,3.8,1.5,.3,Setosa +5.4,3.4,1.7,.2,Setosa +5.1,3.7,1.5,.4,Setosa +4.6,3.6,1,.2,Setosa +5.1,3.3,1.7,.5,Setosa +4.8,3.4,1.9,.2,Setosa +5,3,1.6,.2,Setosa +5,3.4,1.6,.4,Setosa +5.2,3.5,1.5,.2,Setosa +5.2,3.4,1.4,.2,Setosa +4.7,3.2,1.6,.2,Setosa +4.8,3.1,1.6,.2,Setosa +5.4,3.4,1.5,.4,Setosa +5.2,4.1,1.5,.1,Setosa +5.5,4.2,1.4,.2,Setosa +4.9,3.1,1.5,.2,Setosa +5,3.2,1.2,.2,Setosa +5.5,3.5,1.3,.2,Setosa +4.9,3.6,1.4,.1,Setosa +4.4,3,1.3,.2,Setosa +5.1,3.4,1.5,.2,Setosa +5,3.5,1.3,.3,Setosa +4.5,2.3,1.3,.3,Setosa +4.4,3.2,1.3,.2,Setosa +5,3.5,1.6,.6,Setosa +5.1,3.8,1.9,.4,Setosa +4.8,3,1.4,.3,Setosa +5.1,3.8,1.6,.2,Setosa +4.6,3.2,1.4,.2,Setosa +5.3,3.7,1.5,.2,Setosa +5,3.3,1.4,.2,Setosa +7,3.2,4.7,1.4,Versicolor +6.4,3.2,4.5,1.5,Versicolor 
+6.9,3.1,4.9,1.5,Versicolor +5.5,2.3,4,1.3,Versicolor +6.5,2.8,4.6,1.5,Versicolor +5.7,2.8,4.5,1.3,Versicolor +6.3,3.3,4.7,1.6,Versicolor +4.9,2.4,3.3,1,Versicolor +6.6,2.9,4.6,1.3,Versicolor +5.2,2.7,3.9,1.4,Versicolor +5,2,3.5,1,Versicolor +5.9,3,4.2,1.5,Versicolor +6,2.2,4,1,Versicolor +6.1,2.9,4.7,1.4,Versicolor +5.6,2.9,3.6,1.3,Versicolor +6.7,3.1,4.4,1.4,Versicolor +5.6,3,4.5,1.5,Versicolor +5.8,2.7,4.1,1,Versicolor +6.2,2.2,4.5,1.5,Versicolor +5.6,2.5,3.9,1.1,Versicolor +5.9,3.2,4.8,1.8,Versicolor +6.1,2.8,4,1.3,Versicolor +6.3,2.5,4.9,1.5,Versicolor +6.1,2.8,4.7,1.2,Versicolor +6.4,2.9,4.3,1.3,Versicolor +6.6,3,4.4,1.4,Versicolor +6.8,2.8,4.8,1.4,Versicolor +6.7,3,5,1.7,Versicolor +6,2.9,4.5,1.5,Versicolor +5.7,2.6,3.5,1,Versicolor +5.5,2.4,3.8,1.1,Versicolor +5.5,2.4,3.7,1,Versicolor +5.8,2.7,3.9,1.2,Versicolor +6,2.7,5.1,1.6,Versicolor +5.4,3,4.5,1.5,Versicolor +6,3.4,4.5,1.6,Versicolor +6.7,3.1,4.7,1.5,Versicolor +6.3,2.3,4.4,1.3,Versicolor +5.6,3,4.1,1.3,Versicolor +5.5,2.5,4,1.3,Versicolor +5.5,2.6,4.4,1.2,Versicolor +6.1,3,4.6,1.4,Versicolor +5.8,2.6,4,1.2,Versicolor +5,2.3,3.3,1,Versicolor +5.6,2.7,4.2,1.3,Versicolor +5.7,3,4.2,1.2,Versicolor +5.7,2.9,4.2,1.3,Versicolor +6.2,2.9,4.3,1.3,Versicolor +5.1,2.5,3,1.1,Versicolor +5.7,2.8,4.1,1.3,Versicolor +6.3,3.3,6,2.5,Virginica +5.8,2.7,5.1,1.9,Virginica +7.1,3,5.9,2.1,Virginica +6.3,2.9,5.6,1.8,Virginica +6.5,3,5.8,2.2,Virginica +7.6,3,6.6,2.1,Virginica +4.9,2.5,4.5,1.7,Virginica +7.3,2.9,6.3,1.8,Virginica +6.7,2.5,5.8,1.8,Virginica +7.2,3.6,6.1,2.5,Virginica +6.5,3.2,5.1,2,Virginica +6.4,2.7,5.3,1.9,Virginica +6.8,3,5.5,2.1,Virginica +5.7,2.5,5,2,Virginica +5.8,2.8,5.1,2.4,Virginica +6.4,3.2,5.3,2.3,Virginica +6.5,3,5.5,1.8,Virginica +7.7,3.8,6.7,2.2,Virginica +7.7,2.6,6.9,2.3,Virginica +6,2.2,5,1.5,Virginica +6.9,3.2,5.7,2.3,Virginica +5.6,2.8,4.9,2,Virginica +7.7,2.8,6.7,2,Virginica +6.3,2.7,4.9,1.8,Virginica +6.7,3.3,5.7,2.1,Virginica +7.2,3.2,6,1.8,Virginica +6.2,2.8,4.8,1.8,Virginica +6.1,3,4.9,1.8,Virginica +6.4,2.8,5.6,2.1,Virginica +7.2,3,5.8,1.6,Virginica +7.4,2.8,6.1,1.9,Virginica +7.9,3.8,6.4,2,Virginica +6.4,2.8,5.6,2.2,Virginica +6.3,2.8,5.1,1.5,Virginica +6.1,2.6,5.6,1.4,Virginica +7.7,3,6.1,2.3,Virginica +6.3,3.4,5.6,2.4,Virginica +6.4,3.1,5.5,1.8,Virginica +6,3,4.8,1.8,Virginica +6.9,3.1,5.4,2.1,Virginica +6.7,3.1,5.6,2.4,Virginica +6.9,3.1,5.1,2.3,Virginica +5.8,2.7,5.1,1.9,Virginica +6.8,3.2,5.9,2.3,Virginica +6.7,3.3,5.7,2.5,Virginica +6.7,3,5.2,2.3,Virginica +6.3,2.5,5,1.9,Virginica +6.5,3,5.2,2,Virginica +6.2,3.4,5.4,2.3,Virginica +5.9,3,5.1,1.8,Virginica \ No newline at end of file diff --git a/docs/data/reddit.csv b/docs/data/reddit.csv new file mode 100644 index 000000000000..88f91e3df7db --- /dev/null +++ b/docs/data/reddit.csv @@ -0,0 +1,100 @@ +id,name,created_utc,updated_on,comment_karma,link_karma +1,truman48lamb_jasonbroken,1397113470,1536527864,0,0 +2,johnethen06_jasonbroken,1397113483,1536527864,0,0 +3,yaseinrez_jasonbroken,1397113483,1536527864,0,1 +4,Valve92_jasonbroken,1397113503,1536527864,0,0 +5,srbhuyan_jasonbroken,1397113506,1536527864,0,0 +6,taojianlong_jasonbroken,1397113510,1536527864,4,0 +7,YourPalGrant92_jasonbroken,1397113513,1536527864,0,0 +8,Lucki87_jasonbroken,1397113515,1536527864,0,0 +9,punkstock_jasonbroken,1397113517,1536527864,0,0 +10,duder_con_chile_jasonbroken,1397113519,1536527864,0,2 +11,IHaveBigBalls_jasonbroken,1397113520,1536527864,0,0 +12,Foggybanana_jasonbroken,1397113523,1536527864,0,0 +13,Thedrinkdriver_jasonbroken,1397113527,1536527864,-9,0 
+14,littlemissd_jasonbroken,1397113530,1536527864,0,-3 +15,phonethaway_jasonbroken,1397113537,1536527864,0,0 +16,DreamingOfWinterfell_jasonbroken,1397113538,1536527864,0,0 +17,ssaig_jasonbroken,1397113544,1536527864,1,0 +18,divinetribe_jasonbroken,1397113549,1536527864,0,0 +19,fdbvfdssdgfds_jasonbroken,1397113552,1536527864,3,0 +20,hjtrsh54yh43_jasonbroken,1397113559,1536527864,-1,-1 +21,Dalin86_jasonbroken,1397113561,1536527864,0,0 +22,sgalex_jasonbroken,1397113561,1536527864,0,0 +23,beszhthw_jasonbroken,1397113566,1536527864,0,0 +24,WojkeN_jasonbroken,1397113572,1536527864,-8,0 +25,LixksHD_jasonbroken,1397113572,1536527864,0,0 +26,bradhrvf78_jasonbroken,1397113574,1536527864,0,0 +27,ravenfeathers_jasonbroken,1397113576,1536527864,0,0 +28,jayne101_jasonbroken,1397113583,1536527864,0,0 +29,jdennis6701_jasonbroken,1397113585,1536527864,0,0 +30,Puppy243_jasonbroken,1397113592,1536527864,0,0 +31,sissyt_jasonbroken,1397113609,1536527864,0,0 +32,fengye78_jasonbroken,1397113613,1536527864,0,0 +33,bigspender1988_jasonbroken,1397113614,1536527864,0,21 +34,bitdownworld_jasonbroken,1397113618,1536527864,0,0 +35,adhyufsdtha12_jasonbroken,1397113619,1536527864,0,0 +36,Haydenac_jasonbroken,1397113635,1536527864,0,0 +37,ihatewhoweare_jasonbroken,1397113636,1536527864,61,0 +38,HungDaddy69__jasonbroken,1397113641,1536527864,0,0 +39,FSUJohnny24_jasonbroken,1397113646,1536527864,0,0 +40,Toejimon_jasonbroken,1397113650,1536527864,0,0 +41,mine69flesh_jasonbroken,1397113651,1536527864,0,0 +42,brycentkt_jasonbroken,1397113653,1536527864,0,0 +43,hmmmitsbig,1397113655,1536527864,0,0 +77714,hockeyschtick,1137474000,1536497404,11104,451 +77715,kbmunkholm,1137474000,1536528267,0,0 +77716,dickb,1137588452,1536528267,0,0 +77717,stephenjcole,1137474000,1536528267,0,2 +77718,rosetree,1137474000,1536528267,0,0 +77719,benhawK,1138180921,1536528267,0,0 +77720,joenowak,1137474000,1536528268,0,0 +77721,constant,1137474000,1536528268,1,0 +77722,jpscott,1137474000,1536528268,0,1 +77723,meryn,1137474000,1536528268,0,2 +77724,momerath,1128916800,1536528268,2490,101 +77725,inuse,1137474000,1536528269,0,0 +77726,dubert11,1137474000,1536528269,38,59 +77727,CaliMark,1137474000,1536528269,0,0 +77728,Maniac,1137474000,1536528269,0,0 +77729,earlpearl,1137474000,1536528269,0,0 +77730,ghost,1137474000,1536497404,767,0 +77731,paulzg,1137474000,1536528270,0,0 +77732,rshawgo,1137474000,1536497404,707,6883 +77733,spage,1137474000,1536528270,0,0 +77734,HrothgarReborn,1137474000,1536528270,0,0 +77735,darknessvisible,1137474000,1536528270,26133,139 +77736,finleyt,1137714898,1536528270,0,0 +77737,Dalton,1137474000,1536528271,118,2 +77738,graemes,1137474000,1536528271,0,0 +77739,lettuce,1137780958,1536497404,4546,724 +77740,mudkicker,1137474000,1536528271,0,0 +77741,mydignet,1139649149,1536528271,0,0 +77742,markbo,1137474000,1536528271,0,0 +77743,mrfrostee,1137474000,1536528272,227,43 +77744,parappayo,1136350800,1536528272,53,164 +77745,danastasi,1137474000,1536528272,2335,146 +77747,AltherrWeb,1137474000,1536528272,1387,1605 +77748,dtpetty,1137474000,1536528273,0,0 +77749,jamesluke4,1137474000,1536528273,0,0 +77750,sankeld,1137474000,1536528273,9,45 +77751,iampivot,1139479524,1536497404,2640,31 +77752,mcaamano,1137474000,1536528273,0,0 +77753,wonsungi,1137596632,1536528273,0,0 +77754,naotakem,1137474000,1536528274,0,0 +77755,bis,1137474000,1536497404,2191,285 +77756,imeinzen,1137474000,1536528274,0,0 +77757,zrenneh,1137474000,1536528274,79,0 +77758,onclephilippe,1137474000,1536528274,0,0 +77759,Mokzaio415,1139422169,1536528274,0,0 
+77761,-brisse,1137474000,1536528275,14,1 +77762,coolin86,1138303196,1536528275,40,7 +77763,Lunchy,1137599510,1536528275,65,0 +77764,jannemans,1137474000,1536528275,0,0 +77765,compostellas,1137474000,1536528276,6,0 +77766,genericbob,1137474000,1536528276,291,14 +77767,domlexch,1139482978,1536528276,0,0 +77768,TinheadNed,1139665457,1536497404,4434,103 +77769,patopurifik,1137474000,1536528276,0,0 +77770,PoPPo,1139057558,1536528276,0,0 +77771,tandrews,1137474000,1536528277,0,0 diff --git a/docs/getting-started/expressions.md b/docs/getting-started/expressions.md new file mode 100644 index 000000000000..692806d75de9 --- /dev/null +++ b/docs/getting-started/expressions.md @@ -0,0 +1,130 @@ +# Expressions + +`Expressions` are the core strength of `Polars`. The `expressions` offer a versatile structure that both solves easy queries and is easily extended to complex ones. Below we will cover the basic components that serve as building block (or in `Polars` terminology contexts) for all your queries: + +- `select` +- `filter` +- `with_columns` +- `group_by` + +To learn more about expressions and the context in which they operate, see the User Guide sections: [Contexts](../user-guide/concepts/contexts.md) and [Expressions](../user-guide/concepts/expressions.md). + +### Select statement + +To select a column we need to do two things. Define the `DataFrame` we want the data from. And second, select the data that we need. In the example below you see that we select `col('*')`. The asterisk stands for all columns. + +{{code_block('getting-started/expressions','select',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:setup" +print( + --8<-- "python/getting-started/expressions.py:select" +) +``` + +You can also specify the specific columns that you want to return. There are two ways to do this. The first option is to create a `list` of column names, as seen below. + +{{code_block('getting-started/expressions','select2',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:select2" +) +``` + +The second option is to specify each column within a `list` in the `select` statement. This option is shown below. + +{{code_block('getting-started/expressions','select3',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:select3" +) +``` + +If you want to exclude an entire column from your view, you can simply use `exclude` in your `select` statement. + +{{code_block('getting-started/expressions','exclude',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:exclude" +) +``` + +### Filter + +The `filter` option allows us to create a subset of the `DataFrame`. We use the same `DataFrame` as earlier and we filter between two specified dates. + +{{code_block('getting-started/expressions','filter',['filter'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:filter" +) +``` + +With `filter` you can also create more complex filters that include multiple columns. 
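For illustration, a minimal sketch of such a multi-column filter is shown here; the frame and its column names (`c`, `d`) are assumptions made for this sketch only, and the guide's own runnable snippet follows.

```python
import polars as pl
from datetime import datetime

# Hypothetical frame: a datetime column "c" and a float column "d"
df = pl.DataFrame(
    {
        "c": [datetime(2022, 12, 1), datetime(2022, 12, 2), datetime(2022, 12, 3)],
        "d": [1.0, float("nan"), 3.5],
    }
)

# Combine several predicates with `&` (and) or `|` (or); each condition is an expression
out = df.filter(
    (pl.col("c") >= datetime(2022, 12, 2)) & (pl.col("d").is_not_nan())
)
print(out)
```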
+ +{{code_block('getting-started/expressions','filter2',['filter'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:filter2" +) +``` + +### With_columns + +`with_columns` allows you to create new columns for your analyses. We create two new columns `e` and `b+42`. First we sum all values from column `b` and store the results in column `e`. After that we add `42` to the values of `b`. Creating a new column `b+42` to store these results. + +{{code_block('getting-started/expressions','with_columns',['with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:with_columns" +) +``` + +### Group by + +We will create a new `DataFrame` for the Group by functionality. This new `DataFrame` will include several 'groups' that we want to group by. + +{{code_block('getting-started/expressions','dataframe2',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:dataframe2" +print(df2) +``` + +{{code_block('getting-started/expressions','group_by',['group_by'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:group_by" +) +``` + +{{code_block('getting-started/expressions','group_by2',['group_by'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:group_by2" +) +``` + +### Combining operations + +Below are some examples on how to combine operations to create the `DataFrame` you require. + +{{code_block('getting-started/expressions','combine',['select','with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:combine" +``` + +{{code_block('getting-started/expressions','combine2',['select','with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:combine2" +``` diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md new file mode 100644 index 000000000000..b8b8d18441e6 --- /dev/null +++ b/docs/getting-started/installation.md @@ -0,0 +1,31 @@ +# Installation + +Polars is a library and installation is as simple as invoking the package manager of the corresponding programming language. + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars + ``` + +## Importing + +To use the library import it into your project + +=== ":fontawesome-brands-python: Python" + + ``` python + import polars as pl + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` rust + use polars::prelude::*; + ``` diff --git a/docs/getting-started/intro.md b/docs/getting-started/intro.md new file mode 100644 index 000000000000..81d4ac110efc --- /dev/null +++ b/docs/getting-started/intro.md @@ -0,0 +1,16 @@ +# Introduction + +This getting started guide is written for new users of Polars. The goal is to provide a quick overview of the most common functionality. For a more detailed explanation, please go to the [User Guide](../user-guide/index.md) + +!!! rust "Rust Users Only" + + Due to historical reasons the eager API in Rust is outdated. 
In the future we would like to redesign it as a small wrapper around the lazy API (as is the design in Python / NodeJS). In the examples we will use the lazy API instead with `.lazy()` and `.collect()`. For now you can ignore these two functions. If you want to know more about the lazy and eager API go [here](../user-guide/concepts/lazy-vs-eager.md). + + To enable the Lazy API ensure you have the feature flag `lazy` configured when installing Polars + ``` + # Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + + Because of the ownership ruling in Rust we can not reuse the same `DataFrame` multiple times in the examples. For simplicity reasons we call `clone()` to overcome this issue. Note that this does not duplicate the data but just increments a pointer (`Arc`). diff --git a/docs/getting-started/joins.md b/docs/getting-started/joins.md new file mode 100644 index 000000000000..42d875d79144 --- /dev/null +++ b/docs/getting-started/joins.md @@ -0,0 +1,26 @@ +# Combining DataFrames + +There are two ways `DataFrame`s can be combined depending on the use case: join and concat. + +## Join + +Polars supports all types of join (e.g. left, right, inner, outer). Let's have a closer look on how to `join` two `DataFrames` into a single `DataFrame`. Our two `DataFrames` both have an 'id'-like column: `a` and `x`. We can use those columns to `join` the `DataFrames` in this example. + +{{code_block('getting-started/joins','join',['join'])}} + +```python exec="on" result="text" session="getting-started/joins" +--8<-- "python/getting-started/joins.py:setup" +--8<-- "python/getting-started/joins.py:join" +``` + +To see more examples with other types of joins, go the [User Guide](../user-guide/transformations/joins.md). + +## Concat + +We can also `concatenate` two `DataFrames`. Vertical concatenation will make the `DataFrame` longer. Horizontal concatenation will make the `DataFrame` wider. Below you can see the result of an horizontal concatenation of our two `DataFrames`. + +{{code_block('getting-started/joins','hstack',['hstack'])}} + +```python exec="on" result="text" session="getting-started/joins" +--8<-- "python/getting-started/joins.py:hstack" +``` diff --git a/docs/getting-started/reading-writing.md b/docs/getting-started/reading-writing.md new file mode 100644 index 000000000000..ad91be50f0f6 --- /dev/null +++ b/docs/getting-started/reading-writing.md @@ -0,0 +1,45 @@ +# Reading & writing + +Polars supports reading and writing to all common files (e.g. csv, json, parquet), cloud storage (S3, Azure Blob, BigQuery) and databases (e.g. postgres, mysql). In the following examples we will show how to operate on most common file formats. For the following dataframe + +{{code_block('getting-started/reading-writing','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:dataframe" +``` + +#### CSV + +Polars has its own fast implementation for csv reading with many flexible configuration options. + +{{code_block('getting-started/reading-writing','csv',['read_csv','write_csv'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:csv" +``` + +As we can see above, Polars made the datetimes a `string`. We can tell Polars to parse dates, when reading the csv, to ensure the date becomes a datetime. 
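As a quick, self-contained sketch of that idea (the file name and column names here are placeholders, not the guide's dataset):

```python
import polars as pl

# Write a frame whose dates are plain strings, then read it back with
# date parsing enabled so the column is loaded as a temporal dtype.
df = pl.DataFrame({"date": ["2022-01-01", "2022-01-02"], "value": [1, 2]})
df.write_csv("dates.csv")

parsed = pl.read_csv("dates.csv", try_parse_dates=True)
print(parsed.dtypes)  # the "date" column comes back as Date instead of a string
```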
The example can be found below: + +{{code_block('getting-started/reading-writing','csv2',['read_csv'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:csv2" +``` + +#### JSON + +{{code_block('getting-started/reading-writing','json',['read_json','write_json'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:json" +``` + +#### Parquet + +{{code_block('getting-started/reading-writing','parquet',['read_parquet','write_parquet'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:parquet" +``` + +To see more examples and other data formats go to the [User Guide](../user-guide/io/csv.md), section IO. diff --git a/docs/getting-started/series-dataframes.md b/docs/getting-started/series-dataframes.md new file mode 100644 index 000000000000..d0a6e957fc2c --- /dev/null +++ b/docs/getting-started/series-dataframes.md @@ -0,0 +1,102 @@ +# Series & DataFrames + +The core base data structures provided by Polars are `Series` and `DataFrames`. + +## Series + +Series are a 1-dimensional data structure. Within a series all elements have the same data type (e.g. int, string). +The snippet below shows how to create a simple named `Series` object. In a later section of this getting started guide we will learn how to read data from external sources (e.g. files, database), for now lets keep it simple. + +{{code_block('getting-started/series-dataframes','series',['Series'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:series" +``` + +### Methods + +Although it is more common to work directly on a `DataFrame` object, `Series` implement a number of base methods which make it easy to perform transformations. Below are some examples of common operations you might want to perform. Note that these are for illustration purposes and only show a small subset of what is available. + +##### Aggregations + +`Series` out of the box supports all basic aggregations (e.g. min, max, mean, mode, ...). + +{{code_block('getting-started/series-dataframes','minmax',['min','max'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:minmax" +``` + +##### String + +There are a number of methods related to string operations in the `StringNamespace`. These only work on `Series` with the Datatype `Utf8`. + +{{code_block('getting-started/series-dataframes','string',['str.replace'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:string" +``` + +##### Datetime + +Similar to strings, there is a separate namespace for datetime related operations in the `DateLikeNameSpace`. These only work on `Series`with DataTypes related to dates. + +{{code_block('getting-started/series-dataframes','dt',['Series.dt.day'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:dt" +``` + +## DataFrame + +A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it could be seen as an abstraction of on collection (e.g. list) of `Series`. Operations that can be executed on `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. 
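To give a flavour of that SQL-like style before moving on, here is a minimal sketch; the column names are hypothetical and chosen only for illustration:

```python
import polars as pl

# A tiny frame with a key column and a value column
df = pl.DataFrame({"group": ["a", "a", "b"], "value": [1, 2, 3]})

# Roughly the equivalent of: SELECT group, SUM(value) FROM df GROUP BY group
out = df.group_by("group").agg(pl.col("value").sum())
print(out)
```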
On the following pages we will cover how to perform these transformations. + +{{code_block('getting-started/series-dataframes','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:dataframe" +``` + +### Viewing data + +This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` from the previous example as a starting point. + +#### Head + +The `head` function shows the first 5 rows of a `DataFrame` by default. You can specify the number of rows you want to see (e.g. `df.head(10)`). + +{{code_block('getting-started/series-dataframes','head',['head'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:head" +``` + +#### Tail + +The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`. + +{{code_block('getting-started/series-dataframes','tail',['tail'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:tail" +``` + +#### Sample + +If you want to get an impression of the data in your `DataFrame`, you can also use `sample`. With `sample` you get _n_ random rows from the `DataFrame`. + +{{code_block('getting-started/series-dataframes','sample',['sample'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:sample" +``` + +#### Describe + +`Describe` returns summary statistics for your `DataFrame`. It provides several quick statistics where possible. + +{{code_block('getting-started/series-dataframes','describe',['describe'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:describe" +``` diff --git a/docs/images/.gitignore b/docs/images/.gitignore new file mode 100644 index 000000000000..72e8ffc0db8a --- /dev/null +++ b/docs/images/.gitignore @@ -0,0 +1 @@ +* diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000000..2621ba4ee11d --- /dev/null +++ b/docs/index.md @@ -0,0 +1,71 @@ +--- +hide: + - navigation +--- + +# Polars + +![logo](https://raw.githubusercontent.com/pola-rs/polars-static/master/logos/polars_github_logo_rect_dark_name.svg) + +

Blazingly Fast DataFrame Library

+ + +Polars is a highly performant DataFrame library for manipulating structured data. The core is written in Rust, but the library is also available in Python. Its key features are: + +- **Fast**: Polars is written from the ground up, designed close to the machine and without external dependencies. +- **I/O**: First class support for all common data storage layers: local, cloud storage & databases. +- **Easy to use**: Write your queries the way they were intended. Polars, internally, will determine the most efficient way to execute using its query optimizer. +- **Out of Core**: Polars supports out-of-core data transformation with its streaming API, allowing you to process your results without requiring all your data to be in memory at the same time (see the short sketch at the end of this page). +- **Parallel**: Polars fully utilises the power of your machine by dividing the workload among the available CPU cores without any additional configuration. +- **Vectorized Query Engine**: Polars uses [Apache Arrow](https://arrow.apache.org/), a columnar data format, to process your queries in a vectorized manner. It uses [SIMD](https://en.wikipedia.org/wiki/Single_instruction,_multiple_data) to optimize CPU usage. + +## About this guide + +The `Polars` user guide is intended to live alongside the API documentation. Its purpose is to explain to (new) users how to use `Polars` and to provide meaningful examples. The guide is split into two parts: + +- [Getting Started](getting-started/intro.md): A 10-minute helicopter view of the library and its primary function. +- [User Guide](user-guide/index.md): A detailed explanation of how the library is set up and how to use it most effectively. + +If you are looking for details on a specific level / object, it is probably best to go to the API documentation: [Python](https://pola-rs.github.io/polars/py-polars/html/reference/index.html) | [Rust](https://docs.rs/polars/latest/polars/). + +## Performance :rocket: :rocket: + +`Polars` is very fast, and in fact is one of the best performing solutions available. +See the results in h2oai's [db-benchmark](https://duckdblabs.github.io/db-benchmark/), revived by the DuckDB project. + +`Polars` [TPC-H benchmark results](https://www.pola.rs/benchmarks.html) are now available on the official website. + +## Example + +{{code_block('home/example','example',['scan_csv','filter','group_by','collect'])}} + +## Sponsors + +[](https://www.xomnia.com/)   [](https://www.jetbrains.com) + +## Community + +`Polars` has a very active community with frequent releases (approximately weekly). Below are some of the top contributors to the project: + +--8<-- "docs/people.md" + +## Contribute + +Thanks for taking the time to contribute! We appreciate all contributions, from reporting bugs to implementing new features. If you're unclear on how to proceed, read our [contribution guide](https://github.com/pola-rs/polars/blob/main/CONTRIBUTING.md) or contact us on [discord](https://discord.com/invite/4UfP5cfBE7). + +## License + +This project is licensed under the terms of the MIT license.
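To make the "Out of Core" point above concrete, here is a minimal sketch of what a streaming query looks like. It assumes the `docs/data/iris.csv` sample file used elsewhere in this guide and mirrors the lazy query from the example section; it is an illustration, not part of the rendered page macros.

```python
import polars as pl

# A minimal sketch of an out-of-core (streaming) query.
# "docs/data/iris.csv" is the sample dataset used elsewhere in this guide.
q = (
    pl.scan_csv("docs/data/iris.csv")
    .filter(pl.col("sepal_length") > 5)
    .group_by("species")
    .agg(pl.col("sepal_width").mean())
)

# streaming=True executes the plan in batches, so the full dataset
# never has to fit in memory at once.
df = q.collect(streaming=True)
print(df)
```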
diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000000..2c317b06415b --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,9 @@ +pandas +pyarrow +graphviz +matplotlib + +mkdocs-material==9.2.5 +mkdocs-macros-plugin==1.0.4 +markdown-exec[ansi]==1.6.0 +PyGithub==1.59.1 diff --git a/docs/src/python/getting-started/expressions.py b/docs/src/python/getting-started/expressions.py new file mode 100644 index 000000000000..ea73e0819a90 --- /dev/null +++ b/docs/src/python/getting-started/expressions.py @@ -0,0 +1,91 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np +from datetime import datetime + +df = pl.DataFrame( + { + "a": np.arange(0, 8), + "b": np.random.rand(8), + "c": [ + datetime(2022, 12, 1), + datetime(2022, 12, 2), + datetime(2022, 12, 3), + datetime(2022, 12, 4), + datetime(2022, 12, 5), + datetime(2022, 12, 6), + datetime(2022, 12, 7), + datetime(2022, 12, 8), + ], + "d": [1, 2.0, np.NaN, np.NaN, 0, -5, -42, None], + } +) +# --8<-- [end:setup] + +# --8<-- [start:select] +df.select(pl.col("*")) +# --8<-- [end:select] + +# --8<-- [start:select2] +df.select(pl.col(["a", "b"])) +# --8<-- [end:select2] + +# --8<-- [start:select3] +df.select([pl.col("a"), pl.col("b")]).limit(3) +# --8<-- [end:select3] + +# --8<-- [start:exclude] +df.select([pl.exclude("a")]) +# --8<-- [end:exclude] + +# --8<-- [start:filter] +df.filter( + pl.col("c").is_between(datetime(2022, 12, 2), datetime(2022, 12, 8)), +) +# --8<-- [end:filter] + +# --8<-- [start:filter2] +df.filter((pl.col("a") <= 3) & (pl.col("d").is_not_nan())) +# --8<-- [end:filter2] + +# --8<-- [start:with_columns] +df.with_columns([pl.col("b").sum().alias("e"), (pl.col("b") + 42).alias("b+42")]) +# --8<-- [end:with_columns] + +# --8<-- [start:dataframe2] +df2 = pl.DataFrame( + { + "x": np.arange(0, 8), + "y": ["A", "A", "A", "B", "B", "C", "X", "X"], + } +) +# --8<-- [end:dataframe2] + +# --8<-- [start:group_by] +df2.group_by("y", maintain_order=True).count() +# --8<-- [end:group_by] + +# --8<-- [start:group_by2] +df2.group_by("y", maintain_order=True).agg( + [ + pl.col("*").count().alias("count"), + pl.col("*").sum().alias("sum"), + ] +) +# --8<-- [end:group_by2] + +# --8<-- [start:combine] +df_x = df.with_columns((pl.col("a") * pl.col("b")).alias("a * b")).select( + [pl.all().exclude(["c", "d"])] +) + +print(df_x) +# --8<-- [end:combine] + +# --8<-- [start:combine2] +df_y = df.with_columns([(pl.col("a") * pl.col("b")).alias("a * b")]).select( + [pl.all().exclude("d")] +) + +print(df_y) +# --8<-- [end:combine2] diff --git a/docs/src/python/getting-started/joins.py b/docs/src/python/getting-started/joins.py new file mode 100644 index 000000000000..e5a52416eef1 --- /dev/null +++ b/docs/src/python/getting-started/joins.py @@ -0,0 +1,29 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +# --8<-- [end:setup] + +# --8<-- [start:join] +df = pl.DataFrame( + { + "a": np.arange(0, 8), + "b": np.random.rand(8), + "d": [1, 2.0, np.NaN, np.NaN, 0, -5, -42, None], + } +) + +df2 = pl.DataFrame( + { + "x": np.arange(0, 8), + "y": ["A", "A", "A", "B", "B", "C", "X", "X"], + } +) +joined = df.join(df2, left_on="a", right_on="x") +print(joined) +# --8<-- [end:join] + +# --8<-- [start:hstack] +stacked = df.hstack(df2) +print(stacked) +# --8<-- [end:hstack] diff --git a/docs/src/python/getting-started/reading-writing.py b/docs/src/python/getting-started/reading-writing.py new file mode 100644 index 000000000000..dc8a54ebd18f --- /dev/null +++ 
b/docs/src/python/getting-started/reading-writing.py @@ -0,0 +1,41 @@ +# --8<-- [start:dataframe] +import polars as pl +from datetime import datetime + +df = pl.DataFrame( + { + "integer": [1, 2, 3], + "date": [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + ], + "float": [4.0, 5.0, 6.0], + } +) + +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:csv] +df.write_csv("docs/data/output.csv") +df_csv = pl.read_csv("docs/data/output.csv") +print(df_csv) +# --8<-- [end:csv] + +# --8<-- [start:csv2] +df_csv = pl.read_csv("docs/data/output.csv", try_parse_dates=True) +print(df_csv) +# --8<-- [end:csv2] + +# --8<-- [start:json] +df.write_json("docs/data/output.json") +df_json = pl.read_json("docs/data/output.json") +print(df_json) +# --8<-- [end:json] + +# --8<-- [start:parquet] +df.write_parquet("docs/data/output.parquet") +df_parquet = pl.read_parquet("docs/data/output.parquet") +print(df_parquet) +# --8<-- [end:parquet] diff --git a/docs/src/python/getting-started/series-dataframes.py b/docs/src/python/getting-started/series-dataframes.py new file mode 100644 index 000000000000..3171da06adbc --- /dev/null +++ b/docs/src/python/getting-started/series-dataframes.py @@ -0,0 +1,63 @@ +# --8<-- [start:series] +import polars as pl + +s = pl.Series("a", [1, 2, 3, 4, 5]) +print(s) +# --8<-- [end:series] + +# --8<-- [start:minmax] +s = pl.Series("a", [1, 2, 3, 4, 5]) +print(s.min()) +print(s.max()) +# --8<-- [end:minmax] + +# --8<-- [start:string] +s = pl.Series("a", ["polar", "bear", "arctic", "polar fox", "polar bear"]) +s2 = s.str.replace("polar", "pola") +print(s2) +# --8<-- [end:string] + +# --8<-- [start:dt] +from datetime import date + +start = date(2001, 1, 1) +stop = date(2001, 1, 9) +s = pl.date_range(start, stop, interval="2d", eager=True) +print(s.dt.day()) +# --8<-- [end:dt] + +# --8<-- [start:dataframe] +from datetime import datetime + +df = pl.DataFrame( + { + "integer": [1, 2, 3, 4, 5], + "date": [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + datetime(2022, 1, 4), + datetime(2022, 1, 5), + ], + "float": [4.0, 5.0, 6.0, 7.0, 8.0], + } +) + +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:head] +print(df.head(3)) +# --8<-- [end:head] + +# --8<-- [start:tail] +print(df.tail(3)) +# --8<-- [end:tail] + +# --8<-- [start:sample] +print(df.sample(2)) +# --8<-- [end:sample] + +# --8<-- [start:describe] +print(df.describe()) +# --8<-- [end:describe] diff --git a/docs/src/python/home/example.py b/docs/src/python/home/example.py new file mode 100644 index 000000000000..5f675f4e82e4 --- /dev/null +++ b/docs/src/python/home/example.py @@ -0,0 +1,12 @@ +# --8<-- [start:example] +import polars as pl + +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.all().sum()) +) + +df = q.collect() +# --8<-- [end:example] diff --git a/docs/src/python/user-guide/concepts/contexts.py b/docs/src/python/user-guide/concepts/contexts.py new file mode 100644 index 000000000000..ea3baf965b52 --- /dev/null +++ b/docs/src/python/user-guide/concepts/contexts.py @@ -0,0 +1,55 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:select] + +out = df.select( + pl.sum("nrs"), + 
pl.col("names").sort(), + pl.col("names").first().alias("first name"), + (pl.mean("nrs") * 10).alias("10xnrs"), +) +print(out) +# --8<-- [end:select] + +# --8<-- [start:filter] +out = df.filter(pl.col("nrs") > 2) +print(out) +# --8<-- [end:filter] + +# --8<-- [start:with_columns] + +df = df.with_columns( + pl.sum("nrs").alias("nrs_sum"), + pl.col("random").count().alias("count"), +) +print(df) +# --8<-- [end:with_columns] + + +# --8<-- [start:group_by] +out = df.group_by("groups").agg( + pl.sum("nrs"), # sum nrs by groups + pl.col("random").count().alias("count"), # count group members + # sum random where name != null + pl.col("random").filter(pl.col("names").is_not_null()).sum().suffix("_sum"), + pl.col("names").reverse().alias("reversed names"), +) +print(out) +# --8<-- [end:group_by] diff --git a/docs/src/python/user-guide/concepts/expressions.py b/docs/src/python/user-guide/concepts/expressions.py new file mode 100644 index 000000000000..83e6c4514c23 --- /dev/null +++ b/docs/src/python/user-guide/concepts/expressions.py @@ -0,0 +1,16 @@ +import polars as pl + +df = pl.DataFrame( + { + "foo": [1, 2, 3, None, 5], + "bar": ["foo", "ham", "spam", "egg", None], + } +) + +# --8<-- [start:example1] +pl.col("foo").sort().head(2) +# --8<-- [end:example1] + +# --8<-- [start:example2] +df.select(pl.col("foo").sort().head(2), pl.col("bar").filter(pl.col("foo") == 1).sum()) +# --8<-- [end:example2] diff --git a/docs/src/python/user-guide/concepts/lazy-vs-eager.py b/docs/src/python/user-guide/concepts/lazy-vs-eager.py new file mode 100644 index 000000000000..1327bac6357a --- /dev/null +++ b/docs/src/python/user-guide/concepts/lazy-vs-eager.py @@ -0,0 +1,20 @@ +import polars as pl + +# --8<-- [start:eager] + +df = pl.read_csv("docs/data/iris.csv") +df_small = df.filter(pl.col("sepal_length") > 5) +df_agg = df_small.group_by("species").agg(pl.col("sepal_width").mean()) +print(df_agg) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) + +df = q.collect() +# --8<-- [end:lazy] diff --git a/docs/src/python/user-guide/concepts/streaming.py b/docs/src/python/user-guide/concepts/streaming.py new file mode 100644 index 000000000000..955750bf6c30 --- /dev/null +++ b/docs/src/python/user-guide/concepts/streaming.py @@ -0,0 +1,12 @@ +import polars as pl + +# --8<-- [start:streaming] +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) + +df = q.collect(streaming=True) +# --8<-- [end:streaming] diff --git a/docs/src/python/user-guide/expressions/aggregation.py b/docs/src/python/user-guide/expressions/aggregation.py new file mode 100644 index 000000000000..79120d79547f --- /dev/null +++ b/docs/src/python/user-guide/expressions/aggregation.py @@ -0,0 +1,169 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import date + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv" + +dtypes = { + "first_name": pl.Categorical, + "gender": pl.Categorical, + "type": pl.Categorical, + "state": pl.Categorical, + "party": pl.Categorical, +} + +dataset = pl.read_csv(url, dtypes=dtypes).with_columns( + pl.col("birthday").str.to_date(strict=False) +) +# --8<-- [end:dataframe] + +# --8<-- [start:basic] +q = ( + dataset.lazy() + .group_by("first_name") + .agg( + pl.count(), + pl.col("gender"), + 
pl.first("last_name"), + ) + .sort("count", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:basic] + +# --8<-- [start:conditional] +q = ( + dataset.lazy() + .group_by("state") + .agg( + (pl.col("party") == "Anti-Administration").sum().alias("anti"), + (pl.col("party") == "Pro-Administration").sum().alias("pro"), + ) + .sort("pro", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:conditional] + +# --8<-- [start:nested] +q = ( + dataset.lazy() + .group_by("state", "party") + .agg(pl.count("party").alias("count")) + .filter( + (pl.col("party") == "Anti-Administration") + | (pl.col("party") == "Pro-Administration") + ) + .sort("count", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:nested] + + +# --8<-- [start:filter] +def compute_age() -> pl.Expr: + return date(2021, 1, 1).year - pl.col("birthday").dt.year() + + +def avg_birthday(gender: str) -> pl.Expr: + return ( + compute_age() + .filter(pl.col("gender") == gender) + .mean() + .alias(f"avg {gender} birthday") + ) + + +q = ( + dataset.lazy() + .group_by("state") + .agg( + avg_birthday("M"), + avg_birthday("F"), + (pl.col("gender") == "M").sum().alias("# male"), + (pl.col("gender") == "F").sum().alias("# female"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:filter] + + +# --8<-- [start:sort] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort] + + +# --8<-- [start:sort2] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort().first().alias("alphabetical_first"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort2] + + +# --8<-- [start:sort3] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort().first().alias("alphabetical_first"), + pl.col("gender").sort_by("first_name").first().alias("gender"), + ) + .sort("state") + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort3] diff --git a/docs/src/python/user-guide/expressions/casting.py b/docs/src/python/user-guide/expressions/casting.py new file mode 100644 index 000000000000..5f248937743e --- /dev/null +++ b/docs/src/python/user-guide/expressions/casting.py @@ -0,0 +1,129 @@ +# --8<-- [start:setup] + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:dfnum] +df = pl.DataFrame( + { + "integers": [1, 2, 3, 4, 5], + "big_integers": [1, 10000002, 3, 10000004, 10000005], + "floats": [4.0, 5.0, 6.0, 7.0, 8.0], + "floats_with_decimal": [4.532, 5.5, 6.5, 7.5, 8.5], + } +) + +print(df) +# --8<-- [end:dfnum] + +# --8<-- [start:castnum] +out = df.select( + pl.col("integers").cast(pl.Float32).alias("integers_as_floats"), + pl.col("floats").cast(pl.Int32).alias("floats_as_integers"), + pl.col("floats_with_decimal") + .cast(pl.Int32) + .alias("floats_with_decimal_as_integers"), +) +print(out) +# --8<-- 
[end:castnum] + + +# --8<-- [start:downcast] +out = df.select( + pl.col("integers").cast(pl.Int16).alias("integers_smallfootprint"), + pl.col("floats").cast(pl.Float32).alias("floats_smallfootprint"), +) +print(out) +# --8<-- [end:downcast] + +# --8<-- [start:overflow] +try: + out = df.select(pl.col("big_integers").cast(pl.Int8)) + print(out) +except Exception as e: + print(e) +# --8<-- [end:overflow] + +# --8<-- [start:overflow2] +out = df.select(pl.col("big_integers").cast(pl.Int8, strict=False)) +print(out) +# --8<-- [end:overflow2] + + +# --8<-- [start:strings] +df = pl.DataFrame( + { + "integers": [1, 2, 3, 4, 5], + "float": [4.0, 5.03, 6.0, 7.0, 8.0], + "floats_as_string": ["4.0", "5.0", "6.0", "7.0", "8.0"], + } +) + +out = df.select( + pl.col("integers").cast(pl.Utf8), + pl.col("float").cast(pl.Utf8), + pl.col("floats_as_string").cast(pl.Float64), +) +print(out) +# --8<-- [end:strings] + + +# --8<-- [start:strings2] +df = pl.DataFrame({"strings_not_float": ["4.0", "not_a_number", "6.0", "7.0", "8.0"]}) +try: + out = df.select(pl.col("strings_not_float").cast(pl.Float64)) + print(out) +except Exception as e: + print(e) +# --8<-- [end:strings2] + +# --8<-- [start:bool] +df = pl.DataFrame( + { + "integers": [-1, 0, 2, 3, 4], + "floats": [0.0, 1.0, 2.0, 3.0, 4.0], + "bools": [True, False, True, False, True], + } +) + +out = df.select(pl.col("integers").cast(pl.Boolean), pl.col("floats").cast(pl.Boolean)) +print(out) +# --8<-- [end:bool] + +# --8<-- [start:dates] +from datetime import date, datetime + +df = pl.DataFrame( + { + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 5), eager=True), + "datetime": pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 1, 5), eager=True + ), + } +) + +out = df.select(pl.col("date").cast(pl.Int64), pl.col("datetime").cast(pl.Int64)) +print(out) +# --8<-- [end:dates] + +# --8<-- [start:dates2] +df = pl.DataFrame( + { + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 5), eager=True), + "string": [ + "2022-01-01", + "2022-01-02", + "2022-01-03", + "2022-01-04", + "2022-01-05", + ], + } +) + +out = df.select( + pl.col("date").dt.to_string("%Y-%m-%d"), + pl.col("string").str.to_datetime("%Y-%m-%d"), +) +print(out) +# --8<-- [end:dates2] diff --git a/docs/src/python/user-guide/expressions/column-selections.py b/docs/src/python/user-guide/expressions/column-selections.py new file mode 100644 index 000000000000..88951eaee831 --- /dev/null +++ b/docs/src/python/user-guide/expressions/column-selections.py @@ -0,0 +1,91 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:selectors_df] +from datetime import date, datetime + +df = pl.DataFrame( + { + "id": [9, 4, 2], + "place": ["Mars", "Earth", "Saturn"], + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 3), "1d", eager=True), + "sales": [33.4, 2142134.1, 44.7], + "has_people": [False, True, False], + "logged_at": pl.datetime_range( + datetime(2022, 12, 1), datetime(2022, 12, 1, 0, 0, 2), "1s", eager=True + ), + } +).with_row_count("rn") +print(df) +# --8<-- [end:selectors_df] + +# --8<-- [start:all] +out = df.select(pl.col("*")) + +# Is equivalent to +out = df.select(pl.all()) +print(out) +# --8<-- [end:all] + +# --8<-- [start:exclude] +out = df.select(pl.col("*").exclude("logged_at", "rn")) +print(out) +# --8<-- [end:exclude] + +# --8<-- [start:expansion_by_names] +out = df.select(pl.col("date", "logged_at").dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:expansion_by_names] + +# --8<-- [start:expansion_by_regex] +out = 
df.select(pl.col("^.*(as|sa).*$")) +print(out) +# --8<-- [end:expansion_by_regex] + +# --8<-- [start:expansion_by_dtype] +out = df.select(pl.col(pl.Int64, pl.UInt32, pl.Boolean).n_unique()) +print(out) +# --8<-- [end:expansion_by_dtype] + +# --8<-- [start:selectors_intro] +import polars.selectors as cs + +out = df.select(cs.integer(), cs.string()) +print(out) +# --8<-- [end:selectors_intro] + +# --8<-- [start:selectors_diff] +out = df.select(cs.numeric() - cs.first()) +print(out) +# --8<-- [end:selectors_diff] + +# --8<-- [start:selectors_union] +out = df.select(cs.by_name("rn") | ~cs.numeric()) +print(out) +# --8<-- [end:selectors_union] + +# --8<-- [start:selectors_by_name] +out = df.select(cs.contains("rn"), cs.matches(".*_.*")) +print(out) +# --8<-- [end:selectors_by_name] + +# --8<-- [start:selectors_to_expr] +out = df.select(cs.temporal().as_expr().dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:selectors_to_expr] + +# --8<-- [start:selectors_is_selector_utility] +from polars.selectors import is_selector + +out = cs.temporal() +print(is_selector(out)) +# --8<-- [end:selectors_is_selector_utility] + +# --8<-- [start:selectors_colnames_utility] +from polars.selectors import expand_selector + +out = cs.temporal().as_expr().dt.to_string("%Y-%h-%d") +print(expand_selector(df, out)) +# --8<-- [end:selectors_colnames_utility] diff --git a/docs/src/python/user-guide/expressions/folds.py b/docs/src/python/user-guide/expressions/folds.py new file mode 100644 index 000000000000..803591b5b581 --- /dev/null +++ b/docs/src/python/user-guide/expressions/folds.py @@ -0,0 +1,50 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:mansum] +df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [10, 20, 30], + } +) + +out = df.select( + pl.fold(acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.all()).alias( + "sum" + ), +) +print(out) +# --8<-- [end:mansum] + +# --8<-- [start:conditional] +df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [0, 1, 2], + } +) + +out = df.filter( + pl.fold( + acc=pl.lit(True), + function=lambda acc, x: acc & x, + exprs=pl.col("*") > 1, + ) +) +print(out) +# --8<-- [end:conditional] + +# --8<-- [start:string] +df = pl.DataFrame( + { + "a": ["a", "b", "c"], + "b": [1, 2, 3], + } +) + +out = df.select(pl.concat_str(["a", "b"])) +print(out) +# --8<-- [end:string] diff --git a/docs/src/python/user-guide/expressions/functions.py b/docs/src/python/user-guide/expressions/functions.py new file mode 100644 index 000000000000..5f9bbd5bb1da --- /dev/null +++ b/docs/src/python/user-guide/expressions/functions.py @@ -0,0 +1,60 @@ +# --8<-- [start:setup] + +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", "spam"], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:samename] +df_samename = df.select(pl.col("nrs") + 5) +print(df_samename) +# --8<-- [end:samename] + + +# --8<-- [start:samenametwice] +try: + df_samename2 = df.select(pl.col("nrs") + 5, pl.col("nrs") - 5) + print(df_samename2) +except Exception as e: + print(e) +# --8<-- [end:samenametwice] + +# --8<-- [start:samenamealias] +df_alias = df.select( + (pl.col("nrs") + 5).alias("nrs + 5"), + (pl.col("nrs") - 5).alias("nrs - 5"), +) +print(df_alias) +# --8<-- [end:samenamealias] + +# --8<-- [start:countunique] +df_alias = df.select( + 
pl.col("names").n_unique().alias("unique"), + pl.approx_n_unique("names").alias("unique_approx"), +) +print(df_alias) +# --8<-- [end:countunique] + +# --8<-- [start:conditional] +df_conditional = df.select( + pl.col("nrs"), + pl.when(pl.col("nrs") > 2) + .then(pl.lit(True)) + .otherwise(pl.lit(False)) + .alias("conditional"), +) +print(df_conditional) +# --8<-- [end:conditional] diff --git a/docs/src/python/user-guide/expressions/lists.py b/docs/src/python/user-guide/expressions/lists.py new file mode 100644 index 000000000000..5703a01a5518 --- /dev/null +++ b/docs/src/python/user-guide/expressions/lists.py @@ -0,0 +1,111 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:weather_df] +weather = pl.DataFrame( + { + "station": ["Station " + str(x) for x in range(1, 6)], + "temperatures": [ + "20 5 5 E1 7 13 19 9 6 20", + "18 8 16 11 23 E2 8 E2 E2 E2 90 70 40", + "19 24 E9 16 6 12 10 22", + "E2 E0 15 7 8 10 E1 24 17 13 6", + "14 8 E0 16 22 24 E1", + ], + } +) +print(weather) +# --8<-- [end:weather_df] + +# --8<-- [start:string_to_list] +out = weather.with_columns(pl.col("temperatures").str.split(" ")) +print(out) +# --8<-- [end:string_to_list] + +# --8<-- [start:explode_to_atomic] +out = weather.with_columns(pl.col("temperatures").str.split(" ")).explode( + "temperatures" +) +print(out) +# --8<-- [end:explode_to_atomic] + +# --8<-- [start:list_ops] +out = weather.with_columns(pl.col("temperatures").str.split(" ")).with_columns( + pl.col("temperatures").list.head(3).alias("top3"), + pl.col("temperatures").list.slice(-3, 3).alias("bottom_3"), + pl.col("temperatures").list.len().alias("obs"), +) +print(out) +# --8<-- [end:list_ops] + + +# --8<-- [start:count_errors] +out = weather.with_columns( + pl.col("temperatures") + .str.split(" ") + .list.eval(pl.element().cast(pl.Int64, strict=False).is_null()) + .list.sum() + .alias("errors") +) +print(out) +# --8<-- [end:count_errors] + +# --8<-- [start:count_errors_regex] +out = weather.with_columns( + pl.col("temperatures") + .str.split(" ") + .list.eval(pl.element().str.contains("(?i)[a-z]")) + .list.sum() + .alias("errors") +) +print(out) +# --8<-- [end:count_errors_regex] + +# --8<-- [start:weather_by_day] +weather_by_day = pl.DataFrame( + { + "station": ["Station " + str(x) for x in range(1, 11)], + "day_1": [17, 11, 8, 22, 9, 21, 20, 8, 8, 17], + "day_2": [15, 11, 10, 8, 7, 14, 18, 21, 15, 13], + "day_3": [16, 15, 24, 24, 8, 23, 19, 23, 16, 10], + } +) +print(weather_by_day) +# --8<-- [end:weather_by_day] + +# --8<-- [start:weather_by_day_rank] +rank_pct = (pl.element().rank(descending=True) / pl.col("*").count()).round(2) + +out = weather_by_day.with_columns( + # create the list of homogeneous data + pl.concat_list(pl.all().exclude("station")).alias("all_temps") +).select( + # select all columns except the intermediate list + pl.all().exclude("all_temps"), + # compute the rank by calling `list.eval` + pl.col("all_temps").list.eval(rank_pct, parallel=True).alias("temps_rank"), +) + +print(out) +# --8<-- [end:weather_by_day_rank] + +# --8<-- [start:array_df] +array_df = pl.DataFrame( + [ + pl.Series("Array_1", [[1, 3], [2, 5]]), + pl.Series("Array_2", [[1, 7, 3], [8, 1, 0]]), + ], + schema={"Array_1": pl.Array(2, pl.Int64), "Array_2": pl.Array(3, pl.Int64)}, +) +print(array_df) +# --8<-- [end:array_df] + +# --8<-- [start:array_ops] +out = array_df.select( + pl.col("Array_1").arr.min().suffix("_min"), + pl.col("Array_2").arr.sum().suffix("_sum"), +) +print(out) +# --8<-- [end:array_ops] diff --git 
a/docs/src/python/user-guide/expressions/null.py b/docs/src/python/user-guide/expressions/null.py new file mode 100644 index 000000000000..4641773bbb85 --- /dev/null +++ b/docs/src/python/user-guide/expressions/null.py @@ -0,0 +1,88 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "value": [1, None], + }, +) +print(df) +# --8<-- [end:dataframe] + + +# --8<-- [start:count] +null_count_df = df.null_count() +print(null_count_df) +# --8<-- [end:count] + + +# --8<-- [start:isnull] +is_null_series = df.select( + pl.col("value").is_null(), +) +print(is_null_series) +# --8<-- [end:isnull] + + +# --8<-- [start:dataframe2] +df = pl.DataFrame( + { + "col1": [1, 2, 3], + "col2": [1, None, 3], + }, +) +print(df) +# --8<-- [end:dataframe2] + + +# --8<-- [start:fill] +fill_literal_df = ( + df.with_columns( + pl.col("col2").fill_null( + pl.lit(2), + ), + ), +) +print(fill_literal_df) +# --8<-- [end:fill] + +# --8<-- [start:fillstrategy] +fill_forward_df = df.with_columns( + pl.col("col2").fill_null(strategy="forward"), +) +print(fill_forward_df) +# --8<-- [end:fillstrategy] + +# --8<-- [start:fillexpr] +fill_median_df = df.with_columns( + pl.col("col2").fill_null(pl.median("col2")), +) +print(fill_median_df) +# --8<-- [end:fillexpr] + +# --8<-- [start:fillinterpolate] +fill_interpolation_df = df.with_columns( + pl.col("col2").interpolate(), +) +print(fill_interpolation_df) +# --8<-- [end:fillinterpolate] + +# --8<-- [start:nan] +nan_df = pl.DataFrame( + { + "value": [1.0, np.NaN, float("nan"), 3.0], + }, +) +print(nan_df) +# --8<-- [end:nan] + +# --8<-- [start:nanfill] +mean_nan_df = nan_df.with_columns( + pl.col("value").fill_nan(None).alias("value"), +).mean() +print(mean_nan_df) +# --8<-- [end:nanfill] diff --git a/docs/src/python/user-guide/expressions/numpy-example.py b/docs/src/python/user-guide/expressions/numpy-example.py new file mode 100644 index 000000000000..d3300591c4d6 --- /dev/null +++ b/docs/src/python/user-guide/expressions/numpy-example.py @@ -0,0 +1,7 @@ +import polars as pl +import numpy as np + +df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + +out = df.select(np.log(pl.all()).suffix("_log")) +print(out) diff --git a/docs/src/python/user-guide/expressions/operators.py b/docs/src/python/user-guide/expressions/operators.py new file mode 100644 index 000000000000..6f617487c81e --- /dev/null +++ b/docs/src/python/user-guide/expressions/operators.py @@ -0,0 +1,44 @@ +# --8<-- [start:setup] + +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:numerical] + +df_numerical = df.select( + (pl.col("nrs") + 5).alias("nrs + 5"), + (pl.col("nrs") - 5).alias("nrs - 5"), + (pl.col("nrs") * pl.col("random")).alias("nrs * random"), + (pl.col("nrs") / pl.col("random")).alias("nrs / random"), +) +print(df_numerical) + +# --8<-- [end:numerical] + +# --8<-- [start:logical] +df_logical = df.select( + (pl.col("nrs") > 1).alias("nrs > 1"), + (pl.col("random") <= 0.5).alias("random < .5"), + (pl.col("nrs") != 1).alias("nrs != 1"), + (pl.col("nrs") == 1).alias("nrs == 1"), + ((pl.col("random") <= 0.5) & (pl.col("nrs") > 1)).alias("and_expr"), # and + ((pl.col("random") <= 0.5) | (pl.col("nrs") > 
1)).alias("or_expr"), # or +) +print(df_logical) +# --8<-- [end:logical] diff --git a/docs/src/python/user-guide/expressions/strings.py b/docs/src/python/user-guide/expressions/strings.py new file mode 100644 index 000000000000..379c20358feb --- /dev/null +++ b/docs/src/python/user-guide/expressions/strings.py @@ -0,0 +1,61 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:df] +df = pl.DataFrame({"animal": ["Crab", "cat and dog", "rab$bit", None]}) + +out = df.select( + pl.col("animal").str.len_bytes().alias("byte_count"), + pl.col("animal").str.len_chars().alias("letter_count"), +) +print(out) +# --8<-- [end:df] + +# --8<-- [start:existence] +out = df.select( + pl.col("animal"), + pl.col("animal").str.contains("cat|bit").alias("regex"), + pl.col("animal").str.contains("rab$", literal=True).alias("literal"), + pl.col("animal").str.starts_with("rab").alias("starts_with"), + pl.col("animal").str.ends_with("dog").alias("ends_with"), +) +print(out) +# --8<-- [end:existence] + +# --8<-- [start:extract] +df = pl.DataFrame( + { + "a": [ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + } +) +out = df.select( + pl.col("a").str.extract(r"candidate=(\w+)", group_index=1), +) +print(out) +# --8<-- [end:extract] + + +# --8<-- [start:extract_all] +df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"]}) +out = df.select( + pl.col("foo").str.extract_all(r"(\d+)").alias("extracted_nrs"), +) +print(out) +# --8<-- [end:extract_all] + + +# --8<-- [start:replace] +df = pl.DataFrame({"id": [1, 2], "text": ["123abc", "abc456"]}) +out = df.with_columns( + pl.col("text").str.replace(r"abc\b", "ABC"), + pl.col("text").str.replace_all("a", "-", literal=True).alias("text_replace_all"), +) +print(out) +# --8<-- [end:replace] diff --git a/docs/src/python/user-guide/expressions/structs.py b/docs/src/python/user-guide/expressions/structs.py new file mode 100644 index 000000000000..ee034a362bc6 --- /dev/null +++ b/docs/src/python/user-guide/expressions/structs.py @@ -0,0 +1,66 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:ratings_df] +ratings = pl.DataFrame( + { + "Movie": ["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "ET"], + "Theatre": ["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "SD"], + "Avg_Rating": [4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.7, 4.9, 4.7, 4.6], + "Count": [30, 27, 26, 29, 31, 28, 28, 26, 33, 26], + } +) +print(ratings) +# --8<-- [end:ratings_df] + +# --8<-- [start:state_value_counts] +out = ratings.select(pl.col("Theatre").value_counts(sort=True)) +print(out) +# --8<-- [end:state_value_counts] + +# --8<-- [start:struct_unnest] +out = ratings.select(pl.col("Theatre").value_counts(sort=True)).unnest("Theatre") +print(out) +# --8<-- [end:struct_unnest] + +# --8<-- [start:series_struct] +rating_series = pl.Series( + "ratings", + [ + {"Movie": "Cars", "Theatre": "NE", "Avg_Rating": 4.5}, + {"Movie": "Toy Story", "Theatre": "ME", "Avg_Rating": 4.9}, + ], +) +print(rating_series) +# --8<-- [end:series_struct] + +# --8<-- [start:series_struct_extract] +out = rating_series.struct.field("Movie") +print(out) +# --8<-- [end:series_struct_extract] + +# --8<-- [start:series_struct_rename] +out = ( + rating_series.to_frame() + .select(pl.col("ratings").struct.rename_fields(["Film", "State", "Value"])) + .unnest("ratings") +) +print(out) +# --8<-- [end:series_struct_rename] 
+ +# --8<-- [start:struct_duplicates] +out = ratings.filter(pl.struct("Movie", "Theatre").is_duplicated()) +print(out) +# --8<-- [end:struct_duplicates] + +# --8<-- [start:struct_ranking] +out = ratings.with_columns( + pl.struct("Count", "Avg_Rating") + .rank("dense", descending=True) + .over("Movie", "Theatre") + .alias("Rank") +).filter(pl.struct("Movie", "Theatre").is_duplicated()) +print(out) +# --8<-- [end:struct_ranking] diff --git a/docs/src/python/user-guide/expressions/user-defined-functions.py b/docs/src/python/user-guide/expressions/user-defined-functions.py new file mode 100644 index 000000000000..920812babd93 --- /dev/null +++ b/docs/src/python/user-guide/expressions/user-defined-functions.py @@ -0,0 +1,56 @@ +# --8<-- [start:setup] + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "keys": ["a", "a", "b"], + "values": [10, 7, 1], + } +) + +out = df.group_by("keys", maintain_order=True).agg( + pl.col("values").map_batches(lambda s: s.shift()).alias("shift_map"), + pl.col("values").shift().alias("shift_expression"), +) +print(df) +# --8<-- [end:dataframe] + + +# --8<-- [start:apply] +out = df.group_by("keys", maintain_order=True).agg( + pl.col("values").map_elements(lambda s: s.shift()).alias("shift_map"), + pl.col("values").shift().alias("shift_expression"), +) +print(out) +# --8<-- [end:apply] + +# --8<-- [start:counter] +counter = 0 + + +def add_counter(val: int) -> int: + global counter + counter += 1 + return counter + val + + +out = df.select( + pl.col("values").map_elements(add_counter).alias("solution_apply"), + (pl.col("values") + pl.int_range(1, pl.count() + 1)).alias("solution_expr"), +) +print(out) +# --8<-- [end:counter] + +# --8<-- [start:combine] +out = df.select( + pl.struct(["keys", "values"]) + .map_elements(lambda x: len(x["keys"]) + x["values"]) + .alias("solution_apply"), + (pl.col("keys").str.len_bytes() + pl.col("values")).alias("solution_expr"), +) +print(out) +# --8<-- [end:combine] diff --git a/docs/src/python/user-guide/expressions/window.py b/docs/src/python/user-guide/expressions/window.py new file mode 100644 index 000000000000..bd2adda867f5 --- /dev/null +++ b/docs/src/python/user-guide/expressions/window.py @@ -0,0 +1,84 @@ +# --8<-- [start:pokemon] +import polars as pl + +# then let's load some csv data with information about pokemon +df = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv" +) +print(df.head()) +# --8<-- [end:pokemon] + + +# --8<-- [start:group_by] +out = df.select( + "Type 1", + "Type 2", + pl.col("Attack").mean().over("Type 1").alias("avg_attack_by_type"), + pl.col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + pl.col("Attack").mean().alias("avg_attack"), +) +print(out) +# --8<-- [end:group_by] + +# --8<-- [start:operations] +filtered = df.filter(pl.col("Type 2") == "Psychic").select( + "Name", + "Type 1", + "Speed", +) +print(filtered) +# --8<-- [end:operations] + +# --8<-- [start:sort] +out = filtered.with_columns( + pl.col(["Name", "Speed"]).sort_by("Speed", descending=True).over("Type 1"), +) +print(out) +# --8<-- [end:sort] + +# --8<-- [start:rules] +# aggregate and broadcast within a group +# output type: -> Int32 +pl.sum("foo").over("groups") + +# sum within a group and multiply with group elements +# output type: -> Int32 +(pl.col("x").sum() * pl.col("y")).over("groups") + +# sum within a group and multiply with group 
elements +# and aggregate the group to a list +# output type: -> List(Int32) +(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="join") + +# sum within a group and multiply with group elements +# and aggregate the group to a list +# then explode the list to multiple rows + +# This is the fastest method to do things over groups when the groups are sorted +(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="explode") +# --8<-- [end:rules] + +# --8<-- [start:examples] +out = df.sort("Type 1").select( + pl.col("Type 1").head(3).over("Type 1", mapping_strategy="explode"), + pl.col("Name") + .sort_by(pl.col("Speed"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("fastest/group"), + pl.col("Name") + .sort_by(pl.col("Attack"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("strongest/group"), + pl.col("Name") + .sort() + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("sorted_by_alphabet"), +) +print(out) +# --8<-- [end:examples] diff --git a/docs/src/python/user-guide/io/bigquery.py b/docs/src/python/user-guide/io/bigquery.py new file mode 100644 index 000000000000..678ed70200b4 --- /dev/null +++ b/docs/src/python/user-guide/io/bigquery.py @@ -0,0 +1,38 @@ +""" +# --8<-- [start:read] +import polars as pl +from google.cloud import bigquery + +client = bigquery.Client() + +# Perform a query. +QUERY = ( + 'SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` ' + 'WHERE state = "TX" ' + 'LIMIT 100') +query_job = client.query(QUERY) # API request +rows = query_job.result() # Waits for query to finish + +df = pl.from_arrow(rows.to_arrow()) +# --8<-- [end:read] + +# --8<-- [start:write] +from google.cloud import bigquery + +client = bigquery.Client() + +# Write dataframe to stream as parquet file; does not hit disk +with io.BytesIO() as stream: + df.write_parquet(stream) + stream.seek(0) + job = client.load_table_from_file( + stream, + destination='tablename', + project='projectname', + job_config=bigquery.LoadJobConfig( + source_format=bigquery.SourceFormat.PARQUET, + ), + ) +job.result() # Waits for the job to complete +# --8<-- [end:write] +""" diff --git a/docs/src/python/user-guide/io/cloud-storage.py b/docs/src/python/user-guide/io/cloud-storage.py new file mode 100644 index 000000000000..0f968e15f97b --- /dev/null +++ b/docs/src/python/user-guide/io/cloud-storage.py @@ -0,0 +1,63 @@ +""" +# --8<-- [start:read_parquet] +import polars as pl + +source = "s3://bucket/*.parquet" + +df = pl.read_parquet(source) +# --8<-- [end:read_parquet] + +# --8<-- [start:scan_parquet] +import polars as pl + +source = "s3://bucket/*.parquet" + +storage_options = { + "aws_access_key_id": "", + "aws_secret_access_key": "", + "aws_region": "us-east-1", +} +df = pl.scan_parquet(source, storage_options=storage_options) +# --8<-- [end:scan_parquet] + +# --8<-- [start:scan_parquet_query] +import polars as pl + +source = "s3://bucket/*.parquet" + + +df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() +# --8<-- [end:scan_parquet_query] + +# --8<-- [start:scan_pyarrow_dataset] +import polars as pl +import pyarrow.dataset as ds + +dset = ds.dataset("s3://my-partitioned-folder/", format="parquet") +( + pl.scan_pyarrow_dataset(dset) + .filter("foo" == "a") + .select(["foo", "bar"]) + .collect() +) +# --8<-- [end:scan_pyarrow_dataset] + +# --8<-- [start:write_parquet] + +import polars as pl +import s3fs + +df = pl.DataFrame({ + "foo": ["a", "b", "c", "d", "d"], + 
"bar": [1, 2, 3, 4, 5], +}) + +fs = s3fs.S3FileSystem() +destination = "s3://bucket/my_file.parquet" + +# write parquet +with fs.open(destination, mode='wb') as f: + df.write_parquet(f) +# --8<-- [end:write_parquet] + +""" diff --git a/docs/src/python/user-guide/io/csv.py b/docs/src/python/user-guide/io/csv.py new file mode 100644 index 000000000000..d4039a43ce35 --- /dev/null +++ b/docs/src/python/user-guide/io/csv.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_csv("docs/data/path.csv") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_csv("docs/data/path.csv") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_csv("docs/data/path.csv") +# --8<-- [end:scan] diff --git a/docs/src/python/user-guide/io/database.py b/docs/src/python/user-guide/io/database.py new file mode 100644 index 000000000000..b37045719995 --- /dev/null +++ b/docs/src/python/user-guide/io/database.py @@ -0,0 +1,44 @@ +""" +# --8<-- [start:read_uri] +import polars as pl + +uri = "postgres://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database_uri(query=query, uri=uri) +# --8<-- [end:read_uri] + +# --8<-- [start:read_cursor] +import polars as pl +from sqlalchemy import create_engine + +conn = create_engine(f"sqlite:///test.db") + +query = "SELECT * FROM foo" + +pl.read_database(query=query, connection=conn.connect()) +# --8<-- [end:read_cursor] + + +# --8<-- [start:adbc] +uri = "postgres://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database_uri(query=query, uri=uri, engine="adbc") +# --8<-- [end:adbc] + +# --8<-- [start:write] +uri = "postgres://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", uri=uri) +# --8<-- [end:write] + +# --8<-- [start:write_adbc] +uri = "postgres://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", uri=uri, engine="adbc") +# --8<-- [end:write_adbc] + +""" diff --git a/docs/src/python/user-guide/io/json.py b/docs/src/python/user-guide/io/json.py new file mode 100644 index 000000000000..8e6ba3955dc4 --- /dev/null +++ b/docs/src/python/user-guide/io/json.py @@ -0,0 +1,24 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_json("docs/data/path.json") +# --8<-- [end:read] + +# --8<-- [start:readnd] +df = pl.read_ndjson("docs/data/path.json") +# --8<-- [end:readnd] + +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_json("docs/data/path.json") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_ndjson("docs/data/path.json") +# --8<-- [end:scan] diff --git a/docs/src/python/user-guide/io/multiple.py b/docs/src/python/user-guide/io/multiple.py new file mode 100644 index 000000000000..f7500b6b6684 --- /dev/null +++ b/docs/src/python/user-guide/io/multiple.py @@ -0,0 +1,41 @@ +# --8<-- [start:create] +import polars as pl + +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "ham", "spam"]}) + +for i in range(5): + df.write_csv(f"docs/data/my_many_files_{i}.csv") +# --8<-- [end:create] + +# --8<-- [start:read] +df = pl.read_csv("docs/data/my_many_files_*.csv") +print(df) +# --8<-- [end:read] + +# --8<-- [start:creategraph] +import base64 + 
+pl.scan_csv("docs/data/my_many_files_*.csv").show_graph( + output_path="docs/images/multiple.png", show=False +) +with open("docs/images/multiple.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:creategraph] + +# --8<-- [start:graph] +pl.scan_csv("docs/data/my_many_files_*.csv").show_graph() +# --8<-- [end:graph] + +# --8<-- [start:glob] +import polars as pl +import glob + +queries = [] +for file in glob.glob("docs/data/my_many_files_*.csv"): + q = pl.scan_csv(file).group_by("bar").agg([pl.count(), pl.sum("foo")]) + queries.append(q) + +dataframes = pl.collect_all(queries) +print(dataframes) +# --8<-- [end:glob] diff --git a/docs/src/python/user-guide/io/parquet.py b/docs/src/python/user-guide/io/parquet.py new file mode 100644 index 000000000000..feba73df9a19 --- /dev/null +++ b/docs/src/python/user-guide/io/parquet.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_parquet("docs/data/path.parquet") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_parquet("docs/data/path.parquet") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_parquet("docs/data/path.parquet") +# --8<-- [end:scan] diff --git a/docs/src/python/user-guide/lazy/execution.py b/docs/src/python/user-guide/lazy/execution.py new file mode 100644 index 000000000000..110fb0105500 --- /dev/null +++ b/docs/src/python/user-guide/lazy/execution.py @@ -0,0 +1,36 @@ +import polars as pl + +""" +# --8<-- [start:df] +q1 = ( + pl.scan_csv("docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:df] + +# --8<-- [start:collect] +q4 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect() +) +# --8<-- [end:collect] +# --8<-- [start:stream] +q5 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect(streaming=True) +) +# --8<-- [end:stream] +# --8<-- [start:partial] +q9 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .fetch(n_rows=int(100)) +) +# --8<-- [end:partial] +""" diff --git a/docs/src/python/user-guide/lazy/query-plan.py b/docs/src/python/user-guide/lazy/query-plan.py new file mode 100644 index 000000000000..ed2c3f4bac45 --- /dev/null +++ b/docs/src/python/user-guide/lazy/query-plan.py @@ -0,0 +1,48 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:plan] +q1 = ( + pl.scan_csv("docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:plan] + +# --8<-- [start:createplan] +import base64 + +q1.show_graph(optimized=False, show=False, output_path="docs/images/query_plan.png") +with open("docs/images/query_plan.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan] + +""" +# --8<-- [start:showplan] +q1.show_graph(optimized=False) +# --8<-- [end:showplan] +""" + +# --8<-- [start:describe] +q1.explain(optimized=False) +# --8<-- [end:describe] + +# --8<-- [start:createplan2] +q1.show_graph(show=False, output_path="docs/images/query_plan_optimized.png") +with open("docs/images/query_plan_optimized.png", "rb") as f: + png = 
base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan2] + +""" +# --8<-- [start:show] +q1.show_graph() +# --8<-- [end:show] +""" + +# --8<-- [start:optimized] +q1.explain() +# --8<-- [end:optimized] diff --git a/docs/src/python/user-guide/lazy/schema.py b/docs/src/python/user-guide/lazy/schema.py new file mode 100644 index 000000000000..e621718307ee --- /dev/null +++ b/docs/src/python/user-guide/lazy/schema.py @@ -0,0 +1,38 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:schema] +q3 = pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy() + +print(q3.schema) +# --8<-- [end:schema] + +# --8<-- [start:typecheck] +pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy().with_columns( + pl.col("bar").round(0) +) +# --8<-- [end:typecheck] + +# --8<-- [start:lazyeager] +lazy_eager_query = ( + pl.DataFrame( + { + "id": ["a", "b", "c"], + "month": ["jan", "feb", "mar"], + "values": [0, 1, 2], + } + ) + .lazy() + .with_columns((2 * pl.col("values")).alias("double_values")) + .collect() + .pivot( + index="id", columns="month", values="double_values", aggregate_function="first" + ) + .lazy() + .filter(pl.col("mar").is_null()) + .collect() +) +print(lazy_eager_query) +# --8<-- [end:lazyeager] diff --git a/docs/src/python/user-guide/lazy/using.py b/docs/src/python/user-guide/lazy/using.py new file mode 100644 index 000000000000..1a10abb189d2 --- /dev/null +++ b/docs/src/python/user-guide/lazy/using.py @@ -0,0 +1,15 @@ +import polars as pl + +""" +# --8<-- [start:dataframe] +q1 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:dataframe] + +# --8<-- [start:fromdf] +q3 = pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy() +# --8<-- [end:fromdf] +""" diff --git a/docs/src/python/user-guide/misc/multiprocess.py b/docs/src/python/user-guide/misc/multiprocess.py new file mode 100644 index 000000000000..55aec52d6b9f --- /dev/null +++ b/docs/src/python/user-guide/misc/multiprocess.py @@ -0,0 +1,84 @@ +""" +# --8<-- [start:recommendation] +from multiprocessing import get_context + + +def my_fun(s): + print(s) + + +with get_context("spawn").Pool() as pool: + pool.map(my_fun, ["input1", "input2", ...]) + +# --8<-- [end:recommendation] + +# --8<-- [start:example1] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def setup(): + # some setup work + df = create_dataset() + df.write_parquet("/tmp/test.parquet") + + +def main(): + test_df = pl.read_parquet("/tmp/test.parquet") + + for i in range(0, 5): + proc = multiprocessing.get_context("spawn").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + setup() + main() + +# --8<-- [end:example1] +""" +# --8<-- [start:example2] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def main(): + test_df = create_dataset() + + for i in range(0, 5): + proc = 
multiprocessing.get_context("fork").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + main() + +# --8<-- [end:example2] diff --git a/docs/src/python/user-guide/sql/create.py b/docs/src/python/user-guide/sql/create.py new file mode 100644 index 000000000000..e26ffd0a31f1 --- /dev/null +++ b/docs/src/python/user-guide/sql/create.py @@ -0,0 +1,21 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:create] +data = {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +df = pl.LazyFrame(data) + +ctx = pl.SQLContext(my_table=df, eager_execution=True) + +result = ctx.execute( + """ + CREATE TABLE older_people + AS + SELECT * FROM my_table WHERE age > 30 +""" +) + +print(ctx.execute("SELECT * FROM older_people")) +# --8<-- [end:create] diff --git a/docs/src/python/user-guide/sql/cte.py b/docs/src/python/user-guide/sql/cte.py new file mode 100644 index 000000000000..c44b906cf3ad --- /dev/null +++ b/docs/src/python/user-guide/sql/cte.py @@ -0,0 +1,24 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:cte] +ctx = pl.SQLContext() +df = pl.LazyFrame( + {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +) +ctx.register("my_table", df) + +result = ctx.execute( + """ + WITH older_people AS ( + SELECT * FROM my_table WHERE age > 30 + ) + SELECT * FROM older_people WHERE STARTS_WITH(name,'C') +""", + eager=True, +) + +print(result) +# --8<-- [end:cte] diff --git a/docs/src/python/user-guide/sql/intro.py b/docs/src/python/user-guide/sql/intro.py new file mode 100644 index 000000000000..3b59ac9e70d1 --- /dev/null +++ b/docs/src/python/user-guide/sql/intro.py @@ -0,0 +1,100 @@ +# --8<-- [start:setup] +import os + +import pandas as pd +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:context] +ctx = pl.SQLContext() +# --8<-- [end:context] + +# --8<-- [start:register_context] +df = pl.DataFrame({"a": [1, 2, 3]}) +lf = pl.LazyFrame({"b": [4, 5, 6]}) + +# Register all dataframes in the global namespace: registers both df and lf +ctx = pl.SQLContext(register_globals=True) + +# Other option: register dataframe df as "df" and lazyframe lf as "lf" +ctx = pl.SQLContext(df=df, lf=lf) +# --8<-- [end:register_context] + +# --8<-- [start:register_pandas] +import pandas as pd + +df_pandas = pd.DataFrame({"c": [7, 8, 9]}) +ctx = pl.SQLContext(df_pandas=pl.from_pandas(df_pandas)) +# --8<-- [end:register_pandas] + +# --8<-- [start:execute] +# For local files use scan_csv instead +pokemon = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv" +) +ctx = pl.SQLContext(register_globals=True, eager_execution=True) +df_small = ctx.execute("SELECT * from pokemon LIMIT 5") +print(df_small) +# --8<-- [end:execute] + +# --8<-- [start:prepare_multiple_sources] +with open("products_categories.json", "w") as temp_file: + json_data = """{"product_id": 1, "category": "Category 1"} +{"product_id": 2, "category": "Category 1"} +{"product_id": 3, "category": "Category 2"} +{"product_id": 4, "category": "Category 2"} +{"product_id": 5, "category": "Category 3"}""" + + temp_file.write(json_data) + +with open("products_masterdata.csv", "w") as temp_file: + csv_data = """product_id,product_name +1,Product A +2,Product B +3,Product C +4,Product D +5,Product E""" + + temp_file.write(csv_data) + +sales_data = 
pd.DataFrame( + { + "product_id": [1, 2, 3, 4, 5], + "sales": [100, 200, 150, 250, 300], + } +) +# --8<-- [end:prepare_multiple_sources] + +# --8<-- [start:execute_multiple_sources] +# Input data: +# products_masterdata.csv with schema {'product_id': Int64, 'product_name': Utf8} +# products_categories.json with schema {'product_id': Int64, 'category': Utf8} +# sales_data is a Pandas DataFrame with schema {'product_id': Int64, 'sales': Int64} + +ctx = pl.SQLContext( + products_masterdata=pl.scan_csv("products_masterdata.csv"), + products_categories=pl.scan_ndjson("products_categories.json"), + sales_data=pl.from_pandas(sales_data), + eager_execution=True, +) + +query = """ +SELECT + product_id, + product_name, + category, + sales +FROM + products_masterdata +LEFT JOIN products_categories USING (product_id) +LEFT JOIN sales_data USING (product_id) +""" + +print(ctx.execute(query)) +# --8<-- [end:execute_multiple_sources] + +# --8<-- [start:clean_multiple_sources] +os.remove("products_categories.json") +os.remove("products_masterdata.csv") +# --8<-- [end:clean_multiple_sources] diff --git a/docs/src/python/user-guide/sql/select.py b/docs/src/python/user-guide/sql/select.py new file mode 100644 index 000000000000..1e040c739b99 --- /dev/null +++ b/docs/src/python/user-guide/sql/select.py @@ -0,0 +1,106 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:df] +df = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Phoenix", + "Amsterdam", + ], + "country": ["USA", "USA", "USA", "USA", "USA", "Netherlands"], + "population": [8399000, 3997000, 2705000, 2320000, 1680000, 900000], + } +) + +ctx = pl.SQLContext(population=df, eager_execution=True) + +print(ctx.execute("SELECT * FROM population")) +# --8<-- [end:df] + +# --8<-- [start:group_by] +result = ctx.execute( + """ + SELECT country, AVG(population) as avg_population + FROM population + GROUP BY country + """ +) +print(result) +# --8<-- [end:group_by] + + +# --8<-- [start:orderby] +result = ctx.execute( + """ + SELECT city, population + FROM population + ORDER BY population + """ +) +print(result) +# --8<-- [end:orderby] + +# --8<-- [start:join] +income = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Amsterdam", + "Rotterdam", + "Utrecht", + ], + "country": [ + "USA", + "USA", + "USA", + "USA", + "Netherlands", + "Netherlands", + "Netherlands", + ], + "income": [55000, 62000, 48000, 52000, 42000, 38000, 41000], + } +) +ctx.register_many(income=income) +result = ctx.execute( + """ + SELECT country, city, income, population + FROM population + LEFT JOIN income on population.city = income.city + """ +) +print(result) +# --8<-- [end:join] + + +# --8<-- [start:functions] +result = ctx.execute( + """ + SELECT city, population + FROM population + WHERE STARTS_WITH(country,'U') + """ +) +print(result) +# --8<-- [end:functions] + +# --8<-- [start:tablefunctions] +result = ctx.execute( + """ + SELECT * + FROM read_csv('docs/data/iris.csv') + """ +) +print(result) +# --8<-- [end:tablefunctions] diff --git a/docs/src/python/user-guide/sql/show.py b/docs/src/python/user-guide/sql/show.py new file mode 100644 index 000000000000..cedf425dc54b --- /dev/null +++ b/docs/src/python/user-guide/sql/show.py @@ -0,0 +1,26 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:show] +# Create some DataFrames and register them with the SQLContext +df1 = pl.LazyFrame( + { + "name": ["Alice", "Bob", "Charlie", 
"David"], + "age": [25, 30, 35, 40], + } +) +df2 = pl.LazyFrame( + { + "name": ["Ellen", "Frank", "Gina", "Henry"], + "age": [45, 50, 55, 60], + } +) +ctx = pl.SQLContext(mytable1=df1, mytable2=df2) + +tables = ctx.execute("SHOW TABLES", eager=True) + +print(tables) +# --8<-- [end:show] diff --git a/docs/src/python/user-guide/transformations/concatenation.py b/docs/src/python/user-guide/transformations/concatenation.py new file mode 100644 index 000000000000..65b5c8239e83 --- /dev/null +++ b/docs/src/python/user-guide/transformations/concatenation.py @@ -0,0 +1,76 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import datetime + +# --8<-- [end:setup] + +# --8<-- [start:vertical] +df_v1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_v2 = pl.DataFrame( + { + "a": [2], + "b": [4], + } +) +df_vertical_concat = pl.concat( + [ + df_v1, + df_v2, + ], + how="vertical", +) +print(df_vertical_concat) +# --8<-- [end:vertical] + +# --8<-- [start:horizontal] +df_h1 = pl.DataFrame( + { + "l1": [1, 2], + "l2": [3, 4], + } +) +df_h2 = pl.DataFrame( + { + "r1": [5, 6], + "r2": [7, 8], + "r3": [9, 10], + } +) +df_horizontal_concat = pl.concat( + [ + df_h1, + df_h2, + ], + how="horizontal", +) +print(df_horizontal_concat) +# --8<-- [end:horizontal] + +# --8<-- [start:cross] +df_d1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_d2 = pl.DataFrame( + { + "a": [2], + "d": [4], + } +) + +df_diagonal_concat = pl.concat( + [ + df_d1, + df_d2, + ], + how="diagonal", +) +print(df_diagonal_concat) +# --8<-- [end:cross] diff --git a/docs/src/python/user-guide/transformations/joins.py b/docs/src/python/user-guide/transformations/joins.py new file mode 100644 index 000000000000..98828020820d --- /dev/null +++ b/docs/src/python/user-guide/transformations/joins.py @@ -0,0 +1,150 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import datetime + +# --8<-- [end:setup] + +# --8<-- [start:innerdf] +df_customers = pl.DataFrame( + { + "customer_id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + } +) +print(df_customers) +# --8<-- [end:innerdf] + +# --8<-- [start:innerdf2] +df_orders = pl.DataFrame( + { + "order_id": ["a", "b", "c"], + "customer_id": [1, 2, 2], + "amount": [100, 200, 300], + } +) +print(df_orders) +# --8<-- [end:innerdf2] + + +# --8<-- [start:inner] +df_inner_customer_join = df_customers.join(df_orders, on="customer_id", how="inner") +print(df_inner_customer_join) +# --8<-- [end:inner] + +# --8<-- [start:left] +df_left_join = df_customers.join(df_orders, on="customer_id", how="left") +print(df_left_join) +# --8<-- [end:left] + +# --8<-- [start:outer] +df_outer_join = df_customers.join(df_orders, on="customer_id", how="outer") +print(df_outer_join) +# --8<-- [end:outer] + +# --8<-- [start:df3] +df_colors = pl.DataFrame( + { + "color": ["red", "blue", "green"], + } +) +print(df_colors) +# --8<-- [end:df3] + +# --8<-- [start:df4] +df_sizes = pl.DataFrame( + { + "size": ["S", "M", "L"], + } +) +print(df_sizes) +# --8<-- [end:df4] + +# --8<-- [start:cross] +df_cross_join = df_colors.join(df_sizes, how="cross") +print(df_cross_join) +# --8<-- [end:cross] + +# --8<-- [start:df5] +df_cars = pl.DataFrame( + { + "id": ["a", "b", "c"], + "make": ["ford", "toyota", "bmw"], + } +) +print(df_cars) +# --8<-- [end:df5] + +# --8<-- [start:df6] +df_repairs = pl.DataFrame( + { + "id": ["c", "c"], + "cost": [100, 200], + } +) +print(df_repairs) +# --8<-- [end:df6] + +# --8<-- [start:inner2] +df_inner_join = df_cars.join(df_repairs, on="id", how="inner") +print(df_inner_join) +# 
--8<-- [end:inner2] + +# --8<-- [start:semi] +df_semi_join = df_cars.join(df_repairs, on="id", how="semi") +print(df_semi_join) +# --8<-- [end:semi] + +# --8<-- [start:anti] +df_anti_join = df_cars.join(df_repairs, on="id", how="anti") +print(df_anti_join) +# --8<-- [end:anti] + +# --8<-- [start:df7] +df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 3, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } +) +print(df_trades) +# --8<-- [end:df7] + +# --8<-- [start:df8] +df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 2, 0), + datetime(2020, 1, 1, 9, 4, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } +) + +print(df_quotes) +# --8<-- [end:df8] + +# --8<-- [start:asofpre] +df_trades = df_trades.sort("time") +df_quotes = df_quotes.sort("time") # Set column as sorted +# --8<-- [end:asofpre] + +# --8<-- [start:asof] +df_asof_join = df_trades.join_asof(df_quotes, on="time", by="stock") +print(df_asof_join) +# --8<-- [end:asof] + +# --8<-- [start:asof2] +df_asof_tolerance_join = df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="1m" +) +print(df_asof_tolerance_join) +# --8<-- [end:asof2] diff --git a/docs/src/python/user-guide/transformations/melt.py b/docs/src/python/user-guide/transformations/melt.py new file mode 100644 index 000000000000..e9bf53a96ec7 --- /dev/null +++ b/docs/src/python/user-guide/transformations/melt.py @@ -0,0 +1,18 @@ +# --8<-- [start:df] +import polars as pl + +df = pl.DataFrame( + { + "A": ["a", "b", "a"], + "B": [1, 3, 5], + "C": [10, 11, 12], + "D": [2, 4, 6], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:melt] +out = df.melt(id_vars=["A", "B"], value_vars=["C", "D"]) +print(out) +# --8<-- [end:melt] diff --git a/docs/src/python/user-guide/transformations/pivot.py b/docs/src/python/user-guide/transformations/pivot.py new file mode 100644 index 000000000000..d80b26ee0c34 --- /dev/null +++ b/docs/src/python/user-guide/transformations/pivot.py @@ -0,0 +1,31 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:eager] +out = df.pivot(index="foo", columns="bar", values="N", aggregate_function="first") +print(out) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + df.lazy() + .collect() + .pivot(index="foo", columns="bar", values="N", aggregate_function="first") + .lazy() +) +out = q.collect() +print(out) +# --8<-- [end:lazy] diff --git a/docs/src/python/user-guide/transformations/time-series/filter.py b/docs/src/python/user-guide/transformations/time-series/filter.py new file mode 100644 index 000000000000..e720c9ae8ef5 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/filter.py @@ -0,0 +1,30 @@ +# --8<-- [start:df] +import polars as pl +from datetime import datetime + +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + +# --8<-- [start:filter] +filtered_df = df.filter( + pl.col("Date") == datetime(1995, 10, 16), +) +print(filtered_df) +# --8<-- [end:filter] + +# --8<-- [start:range] +filtered_range_df = df.filter( + pl.col("Date").is_between(datetime(1995, 7, 1), datetime(1995, 11, 1)), +) 
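+# Note: is_between is inclusive of both endpoints by default (closed="both");
+# pass closed="left", "right" or "none" for a half-open or open range.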
+print(filtered_range_df) +# --8<-- [end:range] + +# --8<-- [start:negative] +ts = pl.Series(["-1300-05-23", "-1400-03-02"]).str.to_date() + +negative_dates_df = pl.DataFrame({"ts": ts, "values": [3, 4]}) + +negative_dates_filtered_df = negative_dates_df.filter(pl.col("ts").dt.year() < -1300) +print(negative_dates_filtered_df) +# --8<-- [end:negative] diff --git a/docs/src/python/user-guide/transformations/time-series/parsing.py b/docs/src/python/user-guide/transformations/time-series/parsing.py new file mode 100644 index 000000000000..0a7a05842cd1 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/parsing.py @@ -0,0 +1,43 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + + +# --8<-- [start:cast] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=False) + +df = df.with_columns(pl.col("Date").str.to_date("%Y-%m-%d")) +print(df) +# --8<-- [end:cast] + + +# --8<-- [start:df3] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:df3] + +# --8<-- [start:extract] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:extract] + +# --8<-- [start:mixed] +data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", +] +mixed_parsed = ( + pl.Series(data) + .str.to_datetime("%Y-%m-%dT%H:%M:%S%z") + .dt.convert_time_zone("Europe/Brussels") +) +print(mixed_parsed) +# --8<-- [end:mixed] diff --git a/docs/src/python/user-guide/transformations/time-series/resampling.py b/docs/src/python/user-guide/transformations/time-series/resampling.py new file mode 100644 index 000000000000..80a7b2597a67 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/resampling.py @@ -0,0 +1,36 @@ +# --8<-- [start:setup] +from datetime import datetime + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + "values": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:upsample] +out1 = df.upsample(time_column="time", every="15m").fill_null(strategy="forward") +print(out1) +# --8<-- [end:upsample] + +# --8<-- [start:upsample2] +out2 = ( + df.upsample(time_column="time", every="15m") + .interpolate() + .fill_null(strategy="forward") +) +print(out2) +# --8<-- [end:upsample2] diff --git a/docs/src/python/user-guide/transformations/time-series/rolling.py b/docs/src/python/user-guide/transformations/time-series/rolling.py new file mode 100644 index 000000000000..16f751523ade --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/rolling.py @@ -0,0 +1,75 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import date, datetime + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +df = df.sort("Date") +print(df) +# --8<-- [end:df] + +# --8<-- [start:group_by] +annual_average_df = df.group_by_dynamic("Date", every="1y").agg(pl.col("Close").mean()) + +df_with_year = annual_average_df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:group_by] + +# --8<-- [start:group_by_dyn] +df = 
( + pl.date_range( + start=date(2021, 1, 1), + end=date(2021, 12, 31), + interval="1d", + eager=True, + ) + .alias("time") + .to_frame() +) + +out = ( + df.group_by_dynamic("time", every="1mo", period="1mo", closed="left") + .agg( + [ + pl.col("time").cumcount().reverse().head(3).alias("day/eom"), + ((pl.col("time") - pl.col("time").first()).last().dt.days() + 1).alias( + "days_in_month" + ), + ] + ) + .explode("day/eom") +) +print(out) +# --8<-- [end:group_by_dyn] + +# --8<-- [start:group_by_roll] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + } +) +print(df) +# --8<-- [end:group_by_roll] + +# --8<-- [start:group_by_dyn2] +out = df.group_by_dynamic( + "time", + every="1h", + closed="both", + by="groups", + include_boundaries=True, +).agg( + [ + pl.count(), + ] +) +print(out) +# --8<-- [end:group_by_dyn2] diff --git a/docs/src/python/user-guide/transformations/time-series/timezones.py b/docs/src/python/user-guide/transformations/time-series/timezones.py new file mode 100644 index 000000000000..0f5470b08e30 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/timezones.py @@ -0,0 +1,27 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:example] +ts = ["2021-03-27 03:00", "2021-03-28 03:00"] +tz_naive = pl.Series("tz_naive", ts).str.to_datetime() +tz_aware = tz_naive.dt.replace_time_zone("UTC").rename("tz_aware") +time_zones_df = pl.DataFrame([tz_naive, tz_aware]) +print(time_zones_df) +# --8<-- [end:example] + +# --8<-- [start:example2] +time_zones_operations = time_zones_df.select( + [ + pl.col("tz_aware") + .dt.replace_time_zone("Europe/Brussels") + .alias("replace time zone"), + pl.col("tz_aware") + .dt.convert_time_zone("Asia/Kathmandu") + .alias("convert time zone"), + pl.col("tz_aware").dt.replace_time_zone(None).alias("unset time zone"), + ] +) +print(time_zones_operations) +# --8<-- [end:example2] diff --git a/docs/src/rust/getting-started/expressions.rs b/docs/src/rust/getting-started/expressions.rs new file mode 100644 index 000000000000..e8d031ebd1f7 --- /dev/null +++ b/docs/src/rust/getting-started/expressions.rs @@ -0,0 +1,144 @@ +use chrono::prelude::*; +use polars::prelude::*; +use rand::Rng; + +fn main() -> Result<(), Box> { + let mut rng = rand::thread_rng(); + + let df: DataFrame = df!("a" => 0..8, + "b"=> (0..8).map(|_| rng.gen::()).collect::>(), + "c"=> [ + NaiveDate::from_ymd_opt(2022, 12, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 6).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 7).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 8).unwrap().and_hms_opt(0, 0, 0).unwrap(), + ], + "d"=> [Some(1.0), Some(2.0), None, None, Some(0.0), Some(-5.0), Some(-42.), None] + ) + .expect("should not fail"); + + // --8<-- [start:select] + let out = df.clone().lazy().select([col("*")]).collect()?; + println!("{}", out); + // --8<-- [end:select] + + // --8<-- [start:select2] + let out = df.clone().lazy().select([col("a"), col("b")]).collect()?; + println!("{}", out); + 
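+ // Selecting columns by name returns only the listed columns, in the given order.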
// --8<-- [end:select2] + + // --8<-- [start:select3] + let out = df + .clone() + .lazy() + .select([col("a"), col("b")]) + .limit(3) + .collect()?; + println!("{}", out); + // --8<-- [end:select3] + + // --8<-- [start:exclude] + let out = df + .clone() + .lazy() + .select([col("*").exclude(["a"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:exclude] + + // --8<-- [start:filter] + let start_date = NaiveDate::from_ymd_opt(2022, 12, 2) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let end_date = NaiveDate::from_ymd_opt(2022, 12, 8) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let out = df + .clone() + .lazy() + .filter( + col("c") + .gt_eq(lit(start_date)) + .and(col("c").lt_eq(lit(end_date))), + ) + .collect()?; + println!("{}", out); + // --8<-- [end:filter] + + // --8<-- [start:filter2] + let out = df + .clone() + .lazy() + .filter(col("a").lt_eq(3).and(col("d").is_not_null())) + .collect()?; + println!("{}", out); + // --8<-- [end:filter2] + + // --8<-- [start:with_columns] + let out = df + .clone() + .lazy() + .with_columns([ + col("b").sum().alias("e"), + (col("b") + lit(42)).alias("b+42"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:with_columns] + + // --8<-- [start:dataframe2] + let df2: DataFrame = df!("x" => 0..8, + "y"=> &["A", "A", "A", "B", "B", "C", "X", "X"], + ) + .expect("should not fail"); + println!("{}", df2); + // --8<-- [end:dataframe2] + + // --8<-- [start:group_by] + let out = df2 + .clone() + .lazy() + .group_by(["y"]) + .agg([count()]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by] + + // --8<-- [start:group_by2] + let out = df2 + .clone() + .lazy() + .group_by(["y"]) + .agg([col("*").count().alias("count"), col("*").sum().alias("sum")]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by2] + + // --8<-- [start:combine] + let out = df + .clone() + .lazy() + .with_columns([(col("a") * col("b")).alias("a * b")]) + .select([col("*").exclude(["c", "d"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine] + + // --8<-- [start:combine2] + let out = df + .clone() + .lazy() + .with_columns([(col("a") * col("b")).alias("a * b")]) + .select([col("*").exclude(["d"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine2] + + Ok(()) +} diff --git a/docs/src/rust/getting-started/joins.rs b/docs/src/rust/getting-started/joins.rs new file mode 100644 index 000000000000..1f583dc0e4f9 --- /dev/null +++ b/docs/src/rust/getting-started/joins.rs @@ -0,0 +1,29 @@ +use polars::prelude::*; + + +fn main() -> Result<(), Box>{ + + + // --8<-- [start:join] + use rand::Rng; + let mut rng = rand::thread_rng(); + + let df: DataFrame = df!("a" => 0..8, + "b"=> (0..8).map(|_| rng.gen::()).collect::>(), + "d"=> [Some(1.0), Some(2.0), None, None, Some(0.0), Some(-5.0), Some(-42.), None] + ).expect("should not fail"); + let df2: DataFrame = df!("x" => 0..8, + "y"=> &["A", "A", "A", "B", "B", "C", "X", "X"], + ).expect("should not fail"); + let joined = df.join(&df2,["a"],["x"],JoinType::Left,None)?; + println!("{}",joined); + // --8<-- [end:join] + + // --8<-- [start:hstack] + let stacked = df.hstack(df2.get_columns())?; + println!("{}",stacked); + // --8<-- [end:hstack] + + Ok(()) + +} diff --git a/docs/src/rust/getting-started/reading-writing.rs b/docs/src/rust/getting-started/reading-writing.rs new file mode 100644 index 000000000000..54b538ad93d0 --- /dev/null +++ b/docs/src/rust/getting-started/reading-writing.rs @@ -0,0 +1,67 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- 
[start:dataframe] + use chrono::prelude::*; + use std::fs::File; + + let mut df: DataFrame = df!( + "integer" => &[1, 2, 3], + "date" => &[ + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + ], + "float" => &[4.0, 5.0, 6.0] + ) + .expect("should not fail"); + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:csv] + let mut file = File::create("docs/data/output.csv").expect("could not create file"); + CsvWriter::new(&mut file) + .has_header(true) + .with_separator(b',') + .finish(&mut df); + let df_csv = CsvReader::from_path("docs/data/output.csv")? + .infer_schema(None) + .has_header(true) + .finish()?; + println!("{}", df_csv); + // --8<-- [end:csv] + + // --8<-- [start:csv2] + let mut file = File::create("docs/data/output.csv").expect("could not create file"); + CsvWriter::new(&mut file) + .has_header(true) + .with_separator(b',') + .finish(&mut df); + let df_csv = CsvReader::from_path("docs/data/output.csv")? + .infer_schema(None) + .has_header(true) + .with_parse_dates(true) + .finish()?; + println!("{}", df_csv); + // --8<-- [end:csv2] + + // --8<-- [start:json] + let mut file = File::create("docs/data/output.json").expect("could not create file"); + JsonWriter::new(&mut file).finish(&mut df); + let mut f = File::open("docs/data/output.json")?; + let df_json = JsonReader::new(f) + .with_json_format(JsonFormat::JsonLines) + .finish()?; + println!("{}", df_json); + // --8<-- [end:json] + + // --8<-- [start:parquet] + let mut file = File::create("docs/data/output.parquet").expect("could not create file"); + ParquetWriter::new(&mut file).finish(&mut df); + let mut f = File::open("docs/data/output.parquet")?; + let df_parquet = ParquetReader::new(f).finish()?; + println!("{}", df_parquet); + // --8<-- [end:parquet] + + Ok(()) +} diff --git a/docs/src/rust/getting-started/series-dataframes.rs b/docs/src/rust/getting-started/series-dataframes.rs new file mode 100644 index 000000000000..f156784e2bbc --- /dev/null +++ b/docs/src/rust/getting-started/series-dataframes.rs @@ -0,0 +1,59 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:series] + use polars::prelude::*; + + let s = Series::new("a", [1, 2, 3, 4, 5]); + println!("{}", s); + // --8<-- [end:series] + + // --8<-- [start:minmax] + let s = Series::new("a", [1, 2, 3, 4, 5]); + // The use of generics is necessary for the type system + println!("{}", s.min::().unwrap()); + println!("{}", s.max::().unwrap()); + // --8<-- [end:minmax] + + // --8<-- [start:string] + // This operation is not directly available on the Series object yet, only on the DataFrame + // --8<-- [end:string] + + // --8<-- [start:dt] + // This operation is not directly available on the Series object yet, only as an Expression + // --8<-- [end:dt] + + // --8<-- [start:dataframe] + use chrono::prelude::*; + + let df: DataFrame = df!( + "integer" => &[1, 2, 3, 4, 5], + "date" => &[ + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap() + ], + "float" => &[4.0, 5.0, 6.0, 7.0, 8.0], + ) + .unwrap(); + + println!("{}", df); + // --8<-- 
[end:dataframe] + + // --8<-- [start:head] + println!("{}", df.head(Some(3))); + // --8<-- [end:head] + + // --8<-- [start:tail] + println!("{}", df.tail(Some(3))); + // --8<-- [end:tail] + + // --8<-- [start:sample] + println!("{}", df.sample_n(2, false, true, None)?); + // --8<-- [end:sample] + + // --8<-- [start:describe] + println!("{:?}", df.describe(None)); + // --8<-- [end:describe] + Ok(()) +} diff --git a/docs/src/rust/home/example.rs b/docs/src/rust/home/example.rs new file mode 100644 index 000000000000..00cf7de67bfb --- /dev/null +++ b/docs/src/rust/home/example.rs @@ -0,0 +1,16 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:example] + use polars::prelude::*; + + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("*").sum()]); + + let df = q.collect(); + // --8<-- [end:example] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/contexts.rs b/docs/src/rust/user-guide/concepts/contexts.rs new file mode 100644 index 000000000000..b911faa8fd6d --- /dev/null +++ b/docs/src/rust/user-guide/concepts/contexts.rs @@ -0,0 +1,69 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! ( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &[Some("foo"), Some("ham"), Some("spam"), Some("eggs"), None], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:select] + let out = df + .clone() + .lazy() + .select([ + sum("nrs"), + col("names").sort(false), + col("names").first().alias("first name"), + (mean("nrs") * lit(10)).alias("10xnrs"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:select] + + // --8<-- [start:filter] + let out = df.clone().lazy().filter(col("nrs").gt(lit(2))).collect()?; + println!("{}", out); + // --8<-- [end:filter] + + // --8<-- [start:with_columns] + let out = df + .clone() + .lazy() + .with_columns([ + sum("nrs").alias("nrs_sum"), + col("random").count().alias("count"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:with_columns] + + // --8<-- [start:group_by] + let out = df + .lazy() + .group_by([col("groups")]) + .agg([ + sum("nrs"), // sum nrs by groups + col("random").count().alias("count"), // count group members + // sum random where name != null + col("random") + .filter(col("names").is_not_null()) + .sum() + .suffix("_sum"), + col("names").reverse().alias("reversed names"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by] + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/expressions.rs b/docs/src/rust/user-guide/concepts/expressions.rs new file mode 100644 index 000000000000..9c76fc6642e8 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/expressions.rs @@ -0,0 +1,24 @@ +use polars::prelude::*; +use rand::Rng; +use chrono::prelude::*; + +fn main() -> Result<(), Box>{ + + let df = df! 
( + "foo" => &[Some(1), Some(2), Some(3), None, Some(5)], + "bar" => &[Some("foo"), Some("ham"), Some("spam"), Some("egg"), None], + )?; + + // --8<-- [start:example1] + df.column("foo")?.sort(false).head(Some(2)); + // --8<-- [end:example1] + + // --8<-- [start:example2] + df.clone().lazy().select([ + col("foo").sort(Default::default()).head(Some(2)), + col("bar").filter(col("foo").eq(lit(1))).sum(), + ]).collect()?; + // --8<-- [end:example2] + + Ok(()) +} \ No newline at end of file diff --git a/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs b/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs new file mode 100644 index 000000000000..910235fbbf65 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs @@ -0,0 +1,30 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:eager] + let df = CsvReader::from_path("docs/data/iris.csv") + .unwrap() + .finish() + .unwrap(); + let mask = df.column("sepal_length")?.f64()?.gt(5.0); + let df_small = df.filter(&mask)?; + let df_agg = df_small + .group_by(["species"])? + .select(["sepal_width"]) + .mean()?; + println!("{}", df_agg); + // --8<-- [end:eager] + + // --8<-- [start:lazy] + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + let df = q.collect()?; + println!("{}", df); + // --8<-- [end:lazy] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/streaming.rs b/docs/src/rust/user-guide/concepts/streaming.rs new file mode 100644 index 000000000000..f00b5e92ca99 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/streaming.rs @@ -0,0 +1,19 @@ +use chrono::prelude::*; +use polars::prelude::*; +use rand::Rng; + +fn main() -> Result<(), Box> { + // --8<-- [start:streaming] + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? 
+ .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + + let df = q.with_streaming(true).collect()?; + println!("{}", df); + // --8<-- [end:streaming] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/aggregation.rs b/docs/src/rust/user-guide/expressions/aggregation.rs new file mode 100644 index 000000000000..205ec2f01bf7 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/aggregation.rs @@ -0,0 +1,204 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use reqwest::blocking::Client; + use std::io::Cursor; + + let url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv"; + + let mut schema = Schema::new(); + schema.with_column("first_name".to_string(), DataType::Categorical(None)); + schema.with_column("gender".to_string(), DataType::Categorical(None)); + schema.with_column("type".to_string(), DataType::Categorical(None)); + schema.with_column("state".to_string(), DataType::Categorical(None)); + schema.with_column("party".to_string(), DataType::Categorical(None)); + schema.with_column("birthday".to_string(), DataType::Date); + + let data: Vec = Client::new().get(url).send()?.text()?.bytes().collect(); + + let dataset = CsvReader::new(Cursor::new(data)) + .has_header(true) + .with_dtypes(Some(&schema)) + .with_parse_dates(true) + .finish()?; + + println!("{}", &dataset); + // --8<-- [end:dataframe] + + // --8<-- [start:basic] + let df = dataset + .clone() + .lazy() + .group_by(["first_name"]) + .agg([count(), col("gender").list(), col("last_name").first()]) + .sort( + "count", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:basic] + + // --8<-- [start:conditional] + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + (col("party").eq(lit("Anti-Administration"))) + .sum() + .alias("anti"), + (col("party").eq(lit("Pro-Administration"))) + .sum() + .alias("pro"), + ]) + .sort( + "pro", + SortOptions { + descending: true, + nulls_last: false, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:conditional] + + // --8<-- [start:nested] + let df = dataset + .clone() + .lazy() + .group_by(["state", "party"]) + .agg([col("party").count().alias("count")]) + .filter( + col("party") + .eq(lit("Anti-Administration")) + .or(col("party").eq(lit("Pro-Administration"))), + ) + .sort( + "count", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:nested] + + // --8<-- [start:filter] + fn compute_age() -> Expr { + lit(2022) - col("birthday").dt().year() + } + + fn avg_birthday(gender: &str) -> Expr { + compute_age() + .filter(col("gender").eq(lit(gender))) + .mean() + .alias(&format!("avg {} birthday", gender)) + } + + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + avg_birthday("M"), + avg_birthday("F"), + (col("gender").eq(lit("M"))).sum().alias("# male"), + (col("gender").eq(lit("F"))).sum().alias("# female"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:filter] + + // --8<-- [start:sort] + fn get_person() -> Expr { + col("first_name") + lit(" ") + col("last_name") + } + + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + 
get_person().last().alias("oldest"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort] + + // --8<-- [start:sort2] + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort(false).first().alias("alphabetical_first"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort2] + + // --8<-- [start:sort3] + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort(false).first().alias("alphabetical_first"), + col("gender") + .sort_by(["first_name"], [false]) + .first() + .alias("gender"), + ]) + .sort("state", SortOptions::default()) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort3] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/casting.rs b/docs/src/rust/user-guide/expressions/casting.rs new file mode 100644 index 000000000000..2c4938897b8a --- /dev/null +++ b/docs/src/rust/user-guide/expressions/casting.rs @@ -0,0 +1,203 @@ +// --8<-- [start:setup] +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:dfnum] + let df = df! ( + "integers"=> &[1, 2, 3, 4, 5], + "big_integers"=> &[1, 10000002, 3, 10000004, 10000005], + "floats"=> &[4.0, 5.0, 6.0, 7.0, 8.0], + "floats_with_decimal"=> &[4.532, 5.5, 6.5, 7.5, 8.5], + )?; + + println!("{}", &df); + // --8<-- [end:dfnum] + + // --8<-- [start:castnum] + let out = df + .clone() + .lazy() + .select([ + col("integers") + .cast(DataType::Float32) + .alias("integers_as_floats"), + col("floats") + .cast(DataType::Int32) + .alias("floats_as_integers"), + col("floats_with_decimal") + .cast(DataType::Int32) + .alias("floats_with_decimal_as_integers"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:castnum] + + // --8<-- [start:downcast] + let out = df + .clone() + .lazy() + .select([ + col("integers") + .cast(DataType::Int16) + .alias("integers_smallfootprint"), + col("floats") + .cast(DataType::Float32) + .alias("floats_smallfootprint"), + ]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:downcast] + + // --8<-- [start:overflow] + + let out = df + .clone() + .lazy() + .select([col("big_integers").strict_cast(DataType::Int8)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:overflow] + + // --8<-- [start:overflow2] + let out = df + .clone() + .lazy() + .select([col("big_integers").cast(DataType::Int8)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:overflow2] + + // --8<-- [start:strings] + + let df = df! ( + "integers" => &[1, 2, 3, 4, 5], + "float" => &[4.0, 5.03, 6.0, 7.0, 8.0], + "floats_as_string" => &["4.0", "5.0", "6.0", "7.0", "8.0"], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("integers").cast(DataType::Utf8), + col("float").cast(DataType::Utf8), + col("floats_as_string").cast(DataType::Float64), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:strings] + + // --8<-- [start:strings2] + + let df = df! 
("strings_not_float"=> ["4.0", "not_a_number", "6.0", "7.0", "8.0"])?; + + let out = df + .clone() + .lazy() + .select([col("strings_not_float").cast(DataType::Float64)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:strings2] + + // --8<-- [start:bool] + + let df = df! ( + "integers"=> &[-1, 0, 2, 3, 4], + "floats"=> &[0.0, 1.0, 2.0, 3.0, 4.0], + "bools"=> &[true, false, true, false, true], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("integers").cast(DataType::Boolean), + col("floats").cast(DataType::Boolean), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:bool] + + // --8<-- [start:dates] + + use chrono::prelude::*; + use polars::time::*; + + let df = df! ( + "date" => date_range( + "date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None + )?.cast(&DataType::Date)?, + "datetime" => datetime_range( + "datetime", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None + )?, + )?; + + let out = df + .clone() + .lazy() + .select([ + col("date").cast(DataType::Int64), + col("datetime").cast(DataType::Int64), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:dates] + + // --8<-- [start:dates2] + + let df = df! ( + "date" => date_range("date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + "string" => &[ + "2022-01-01", + "2022-01-02", + "2022-01-03", + "2022-01-04", + "2022-01-05", + ], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("date").dt().to_string("%Y-%m-%d"), + col("string").str().to_datetime( + TimeUnit::Microseconds, + None, + StrptimeOptions::default(), + lit("raise"), + ), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:dates2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/column-selections.rs b/docs/src/rust/user-guide/expressions/column-selections.rs new file mode 100644 index 000000000000..105cc6f102df --- /dev/null +++ b/docs/src/rust/user-guide/expressions/column-selections.rs @@ -0,0 +1,99 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:selectors_df] + + use chrono::prelude::*; + use polars::time::*; + + let df = df!( + "id" => &[9, 4, 2], + "place" => &["Mars", "Earth", "Saturn"], + "date" => date_range("date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + "sales" => &[33.4, 2142134.1, 44.7], + "has_people" => &[false, true, false], + "logged_at" => date_range("logged_at", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 2).unwrap(), Duration::parse("1s"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + )? 
+ .with_row_count("rn", None)?; + println!("{}", &df); + // --8<-- [end:selectors_df] + + // --8<-- [start:all] + let out = df.clone().lazy().select([col("*")]).collect()?; + + // Is equivalent to + let out = df.clone().lazy().select([all()]).collect()?; + println!("{}", &out); + // --8<-- [end:all] + + // --8<-- [start:exclude] + let out = df + .clone() + .lazy() + .select([col("*").exclude(["logged_at", "rn"])]) + .collect()?; + println!("{}", &out); + // --8<-- [end:exclude] + + // --8<-- [start:expansion_by_names] + let out = df + .clone() + .lazy() + .select([cols(["date", "logged_at"]).dt().to_string("%Y-%h-%d")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_names] + + // --8<-- [start:expansion_by_regex] + let out = df.clone().lazy().select([col("^.*(as|sa).*$")]).collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_regex] + + // --8<-- [start:expansion_by_dtype] + let out = df + .clone() + .lazy() + .select([dtype_cols([DataType::Int64, DataType::UInt32, DataType::Boolean]).n_unique()]) + .collect()?; + // gives different result than python as the id col is i32 in rust + println!("{}", &out); + // --8<-- [end:expansion_by_dtype] + + // --8<-- [start:selectors_intro] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_intro] + + // --8<-- [start:selectors_diff] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_diff] + + // --8<-- [start:selectors_union] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_union] + + // --8<-- [start:selectors_by_name] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/1059 + // --8<-- [end:selectors_by_name] + + // --8<-- [start:selectors_to_expr] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_to_expr] + + // --8<-- [start:selectors_is_selector_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_is_selector_utility] + + // --8<-- [start:selectors_colnames_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_colnames_utility] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/folds.rs b/docs/src/rust/user-guide/expressions/folds.rs new file mode 100644 index 000000000000..9312735b9284 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/folds.rs @@ -0,0 +1,49 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + + // --8<-- [start:mansum] + let df = df!( + "a" => &[1, 2, 3], + "b" => &[10, 20, 30], + )?; + + let out = df + .lazy() + .select([fold_exprs(lit(0), |acc, x| Ok(Some(acc + x)), [col("*")]).alias("sum")]) + .collect()?; + println!("{}", out); + // --8<-- [end:mansum] + + // --8<-- [start:conditional] + let df = df!( + "a" => &[1, 2, 3], + "b" => &[0, 1, 2], + )?; + + let out = df + .lazy() + .filter(fold_exprs( + lit(true), + |acc, x| acc.bitand(&x).map(Some), + [col("*").gt(1)], + )) + .collect()?; + println!("{}", out); + // --8<-- [end:conditional] + + // --8<-- [start:string] + let df = df!( + "a" => &["a", "b", "c"], + "b" => &[1, 2, 3], + )?; + + let out = df + .lazy() + .select([concat_str([col("a"), col("b")], "")]) + .collect()?; + 
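+ // concat_str concatenates the string values of the given columns row-wise
+ // (here with an empty separator), e.g. "a" and 1 become "a1".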
println!("{:?}", out); + // --8<-- [end:string] + + Ok(()) +} \ No newline at end of file diff --git a/docs/src/rust/user-guide/expressions/functions.rs b/docs/src/rust/user-guide/expressions/functions.rs new file mode 100644 index 000000000000..490809b75557 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/functions.rs @@ -0,0 +1,79 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! ( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &["foo", "ham", "spam", "egg", "spam"], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:samename] + let df_samename = df.clone().lazy().select([col("nrs") + lit(5)]).collect()?; + println!("{}", &df_samename); + // --8<-- [end:samename] + + // --8<-- [start:samenametwice] + let df_samename2 = df + .clone() + .lazy() + .select([col("nrs") + lit(5), col("nrs") - lit(5)]) + .collect(); + match df_samename2 { + Ok(df) => println!("{}", &df), + Err(e) => println!("{:?}", &e), + }; + // --8<-- [end:samenametwice] + + // --8<-- [start:samenamealias] + let df_alias = df + .clone() + .lazy() + .select([ + (col("nrs") + lit(5)).alias("nrs + 5"), + (col("nrs") - lit(5)).alias("nrs - 5"), + ]) + .collect()?; + println!("{}", &df_alias); + // --8<-- [end:samenamealias] + + // --8<-- [start:countunique] + let df_alias = df + .clone() + .lazy() + .select([ + col("names").n_unique().alias("unique"), + // Following query shows there isn't anything in Rust API + // https://docs.rs/polars/latest/polars/?search=approx_n_unique + // col("names").approx_n_unique().alias("unique_approx"), + ]) + .collect()?; + println!("{}", &df_alias); + // --8<-- [end:countunique] + + // --8<-- [start:conditional] + let df_conditional = df + .clone() + .lazy() + .select([ + col("nrs"), + when(col("nrs").gt(2)) + .then(lit(true)) + .otherwise(lit(false)) + .alias("conditional"), + ]) + .collect()?; + println!("{}", &df_conditional); + // --8<-- [end:conditional] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/lists.rs b/docs/src/rust/user-guide/expressions/lists.rs new file mode 100644 index 000000000000..257649e0cc7d --- /dev/null +++ b/docs/src/rust/user-guide/expressions/lists.rs @@ -0,0 +1,162 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] +fn main() -> Result<(), Box> { + // --8<-- [start:weather_df] + let stns: Vec = (1..6).map(|i| format!("Station {i}")).collect(); + let weather = df!( + "station"=> &stns, + "temperatures"=> &[ + "20 5 5 E1 7 13 19 9 6 20", + "18 8 16 11 23 E2 8 E2 E2 E2 90 70 40", + "19 24 E9 16 6 12 10 22", + "E2 E0 15 7 8 10 E1 24 17 13 6", + "14 8 E0 16 22 24 E1", + ], + )?; + println!("{}", &weather); + // --8<-- [end:weather_df] + + // --8<-- [start:string_to_list] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .collect()?; + println!("{}", &out); + // --8<-- [end:string_to_list] + + // --8<-- [start:explode_to_atomic] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .explode(["temperatures"]) + .collect()?; + println!("{}", &out); + // --8<-- [end:explode_to_atomic] + + // --8<-- [start:list_ops] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .with_columns([ + 
col("temperatures").list().head(lit(3)).alias("top3"), + col("temperatures") + .list() + .slice(lit(-3), lit(3)) + .alias("bottom_3"), + col("temperatures").list().lengths().alias("obs"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:list_ops] + + // --8<-- [start:count_errors] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures") + .str() + .split(lit(" ")) + .list() + .eval(col("").cast(DataType::Int64).is_null(), false) + .list() + .sum() + .alias("errors")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:count_errors] + + // --8<-- [start:count_errors_regex] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures") + .str() + .split(lit(" ")) + .list() + .eval(col("").str().contains(lit("(?i)[a-z]"), false), false) + .list() + .sum() + .alias("errors")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:count_errors_regex] + + // --8<-- [start:weather_by_day] + let stns: Vec = (1..11).map(|i| format!("Station {i}")).collect(); + let weather_by_day = df!( + "station" => &stns, + "day_1" => &[17, 11, 8, 22, 9, 21, 20, 8, 8, 17], + "day_2" => &[15, 11, 10, 8, 7, 14, 18, 21, 15, 13], + "day_3" => &[16, 15, 24, 24, 8, 23, 19, 23, 16, 10], + )?; + println!("{}", &weather_by_day); + // --8<-- [end:weather_by_day] + + // --8<-- [start:weather_by_day_rank] + let rank_pct = (col("") + .rank( + RankOptions { + method: RankMethod::Average, + descending: true, + }, + None, + ) + .cast(DataType::Float32) + / col("*").count().cast(DataType::Float32)) + .round(2); + + let out = weather_by_day + .clone() + .lazy() + .with_columns( + // create the list of homogeneous data + [concat_list([all().exclude(["station"])])?.alias("all_temps")], + ) + .select( + // select all columns except the intermediate list + [ + all().exclude(["all_temps"]), + // compute the rank by calling `list.eval` + col("all_temps") + .list() + .eval(rank_pct, true) + .alias("temps_rank"), + ], + ) + .collect()?; + + println!("{}", &out); + // --8<-- [end:weather_by_day_rank] + + // --8<-- [start:array_df] + let mut col1: ListPrimitiveChunkedBuilder = + ListPrimitiveChunkedBuilder::new("Array_1", 8, 8, DataType::Int32); + col1.append_slice(&[1, 3]); + col1.append_slice(&[2, 5]); + let mut col2: ListPrimitiveChunkedBuilder = + ListPrimitiveChunkedBuilder::new("Array_2", 8, 8, DataType::Int32); + col2.append_slice(&[1, 7, 3]); + col2.append_slice(&[8, 1, 0]); + let array_df = DataFrame::new([col1.finish(), col2.finish()].into())?; + + println!("{}", &array_df); + // --8<-- [end:array_df] + + // --8<-- [start:array_ops] + let out = array_df + .clone() + .lazy() + .select([ + col("Array_1").list().min().suffix("_min"), + col("Array_2").list().sum().suffix("_sum"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:array_ops] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/null.rs b/docs/src/rust/user-guide/expressions/null.rs new file mode 100644 index 000000000000..8d78310cb0a9 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/null.rs @@ -0,0 +1,89 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + + let df = df! 
( + "value" => &[Some(1), None], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:count] + let null_count_df = df.null_count(); + println!("{}", &null_count_df); + // --8<-- [end:count] + + // --8<-- [start:isnull] + let is_null_series = df + .clone() + .lazy() + .select([col("value").is_null()]) + .collect()?; + println!("{}", &is_null_series); + // --8<-- [end:isnull] + + // --8<-- [start:dataframe2] + let df = df!( + "col1" => &[Some(1), Some(2), Some(3)], + "col2" => &[Some(1), None, Some(3)], + + )?; + println!("{}", &df); + // --8<-- [end:dataframe2] + + // --8<-- [start:fill] + let fill_literal_df = df + .clone() + .lazy() + .with_columns([col("col2").fill_null(lit(2))]) + .collect()?; + println!("{}", &fill_literal_df); + // --8<-- [end:fill] + + // --8<-- [start:fillstrategy] + let fill_forward_df = df + .clone() + .lazy() + .with_columns([col("col2").forward_fill(None)]) + .collect()?; + println!("{}", &fill_forward_df); + // --8<-- [end:fillstrategy] + + // --8<-- [start:fillexpr] + let fill_median_df = df + .clone() + .lazy() + .with_columns([col("col2").fill_null(median("col2"))]) + .collect()?; + println!("{}", &fill_median_df); + // --8<-- [end:fillexpr] + + // --8<-- [start:fillinterpolate] + let fill_interpolation_df = df + .clone() + .lazy() + .with_columns([col("col2").interpolate(InterpolationMethod::Linear)]) + .collect()?; + println!("{}", &fill_interpolation_df); + // --8<-- [end:fillinterpolate] + + // --8<-- [start:nan] + let nan_df = df!( + "value" => [1.0, f64::NAN, f64::NAN, 3.0], + )?; + println!("{}", &nan_df); + // --8<-- [end:nan] + + // --8<-- [start:nanfill] + let mean_nan_df = nan_df + .clone() + .lazy() + .with_columns([col("value").fill_nan(lit(NULL)).alias("value")]) + .mean() + .collect()?; + println!("{}", &mean_nan_df); + // --8<-- [end:nanfill] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/operators.rs b/docs/src/rust/user-guide/expressions/operators.rs new file mode 100644 index 000000000000..868d301c2182 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/operators.rs @@ -0,0 +1,54 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! 
( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &[Some("foo"), Some("ham"), Some("spam"), Some("eggs"), None], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:numerical] + let df_numerical = df + .clone() + .lazy() + .select([ + (col("nrs") + lit(5)).alias("nrs + 5"), + (col("nrs") - lit(5)).alias("nrs - 5"), + (col("nrs") * col("random")).alias("nrs * random"), + (col("nrs") / col("random")).alias("nrs / random"), + ]) + .collect()?; + println!("{}", &df_numerical); + // --8<-- [end:numerical] + + // --8<-- [start:logical] + let df_logical = df + .clone() + .lazy() + .select([ + col("nrs").gt(1).alias("nrs > 1"), + col("random").lt_eq(0.5).alias("random < .5"), + col("nrs").neq(1).alias("nrs != 1"), + col("nrs").eq(1).alias("nrs == 1"), + (col("random").lt_eq(0.5)) + .and(col("nrs").gt(1)) + .alias("and_expr"), // and + (col("random").lt_eq(0.5)) + .or(col("nrs").gt(1)) + .alias("or_expr"), // or + ]) + .collect()?; + println!("{}", &df_logical); + // --8<-- [end:logical] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/strings.rs b/docs/src/rust/user-guide/expressions/strings.rs new file mode 100644 index 000000000000..0b606095ca92 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/strings.rs @@ -0,0 +1,93 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df! ( + "animal" => &[Some("Crab"), Some("cat and dog"), Some("rab$bit"), None], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("animal").str().len_bytes().alias("byte_count"), + col("animal").str().len_chars().alias("letter_count"), + ]) + .collect()?; + + println!("{}", &out); + // --8<-- [end:df] + + // --8<-- [start:existence] + let out = df + .clone() + .lazy() + .select([ + col("animal"), + col("animal") + .str() + .contains(lit("cat|bit"), false) + .alias("regex"), + col("animal") + .str() + .contains_literal(lit("rab$")) + .alias("literal"), + col("animal") + .str() + .starts_with(lit("rab")) + .alias("starts_with"), + col("animal").str().ends_with(lit("dog")).alias("ends_with"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:existence] + + // --8<-- [start:extract] + let df = df!( + "a" => &[ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + )?; + let out = df + .clone() + .lazy() + .select([col("a").str().extract(r"candidate=(\w+)", 1)]) + .collect()?; + println!("{}", &out); + // --8<-- [end:extract] + + // --8<-- [start:extract_all] + let df = df!("foo"=> &["123 bla 45 asd", "xyz 678 910t"])?; + let out = df + .clone() + .lazy() + .select([col("foo") + .str() + .extract_all(lit(r"(\d+)")) + .alias("extracted_nrs")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:extract_all] + + // --8<-- [start:replace] + let df = df!("id"=> &[1, 2], "text"=> &["123abc", "abc456"])?; + let out = df + .clone() + .lazy() + .with_columns([ + col("text").str().replace(lit(r"abc\b"), lit("ABC"), false), + col("text") + .str() + .replace_all(lit("a"), lit("-"), false) + .alias("text_replace_all"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:replace] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/structs.rs b/docs/src/rust/user-guide/expressions/structs.rs new file mode 100644 index 000000000000..662e264222a6 --- 
/dev/null +++ b/docs/src/rust/user-guide/expressions/structs.rs @@ -0,0 +1,99 @@ +// --8<-- [start:setup] +use polars::{lazy::dsl::count, prelude::*}; +// --8<-- [end:setup] +fn main() -> Result<(), Box> { + // --8<-- [start:ratings_df] + let ratings = df!( + "Movie"=> &["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "ET"], + "Theatre"=> &["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "SD"], + "Avg_Rating"=> &[4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.7, 4.9, 4.7, 4.6], + "Count"=> &[30, 27, 26, 29, 31, 28, 28, 26, 33, 26], + + )?; + println!("{}", &ratings); + // --8<-- [end:ratings_df] + + // --8<-- [start:state_value_counts] + let out = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true)]) + .collect()?; + println!("{}", &out); + // --8<-- [end:state_value_counts] + + // --8<-- [start:struct_unnest] + let out = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true)]) + .unnest(["Theatre"]) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_unnest] + + // --8<-- [start:series_struct] + // Don't think we can make it the same way in rust, but this works + let rating_series = df!( + "Movie" => &["Cars", "Toy Story"], + "Theatre" => &["NE", "ME"], + "Avg_Rating" => &[4.5, 4.9], + )? + .into_struct("ratings") + .into_series(); + println!("{}", &rating_series); + // // --8<-- [end:series_struct] + + // --8<-- [start:series_struct_extract] + let out = rating_series.struct_()?.field_by_name("Movie")?; + println!("{}", &out); + // --8<-- [end:series_struct_extract] + + // --8<-- [start:series_struct_rename] + let out = DataFrame::new([rating_series].into())? + .lazy() + .select([col("ratings") + .struct_() + .rename_fields(["Film".into(), "State".into(), "Value".into()].to_vec())]) + .unnest(["ratings"]) + .collect()?; + + println!("{}", &out); + // --8<-- [end:series_struct_rename] + + // --8<-- [start:struct_duplicates] + let out = ratings + .clone() + .lazy() + // .filter(as_struct(&[col("Movie"), col("Theatre")]).is_duplicated()) + // Error: .is_duplicated() not available if you try that + // https://github.com/pola-rs/polars/issues/3803 + .filter(count().over([col("Movie"), col("Theatre")]).gt(lit(1))) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_duplicates] + + // --8<-- [start:struct_ranking] + let out = ratings + .clone() + .lazy() + .with_columns([as_struct(&[col("Count"), col("Avg_Rating")]) + .rank( + RankOptions { + method: RankMethod::Dense, + descending: false, + }, + None, + ) + .over([col("Movie"), col("Theatre")]) + .alias("Rank")]) + // .filter(as_struct(&[col("Movie"), col("Theatre")]).is_duplicated()) + // Error: .is_duplicated() not available if you try that + // https://github.com/pola-rs/polars/issues/3803 + .filter(count().over([col("Movie"), col("Theatre")]).gt(lit(1))) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_ranking] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/user-defined-functions.rs b/docs/src/rust/user-guide/expressions/user-defined-functions.rs new file mode 100644 index 000000000000..7cbe1605f3e3 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/user-defined-functions.rs @@ -0,0 +1,84 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + let df = df!( + "keys" => &["a", "a", "b"], + "values" => &[10, 7, 1], + )?; + + let out = df + .lazy() + .group_by(["keys"]) + .agg([ + col("values") + .map(|s| Ok(s.shift(1)), GetOutput::default()) + .alias("shift_map"), + 
col("values").shift(1).alias("shift_expression"), + ]) + .collect()?; + + println!("{}", out); + // --8<-- [end:dataframe] + + // --8<-- [start:apply] + let out = df + .clone() + .lazy() + .group_by([col("keys")]) + .agg([ + col("values") + .apply(|s| Ok(s.shift(1)), GetOutput::default()) + .alias("shift_map"), + col("values").shift(1).alias("shift_expression"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:apply] + + // --8<-- [start:counter] + + // --8<-- [end:counter] + + // --8<-- [start:combine] + let out = df + .lazy() + .select([ + // pack to struct to get access to multiple fields in a custom `apply/map` + as_struct(&[col("keys"), col("values")]) + // we will compute the len(a) + b + .apply( + |s| { + // downcast to struct + let ca = s.struct_()?; + + // get the fields as Series + let s_a = &ca.fields()[0]; + let s_b = &ca.fields()[1]; + + // downcast the `Series` to their known type + let ca_a = s_a.utf8()?; + let ca_b = s_b.i32()?; + + // iterate both `ChunkedArrays` + let out: Int32Chunked = ca_a + .into_iter() + .zip(ca_b) + .map(|(opt_a, opt_b)| match (opt_a, opt_b) { + (Some(a), Some(b)) => Some(a.len() as i32 + b), + _ => None, + }) + .collect(); + + Ok(out.into_series()) + }, + GetOutput::from_type(DataType::Int32), + ) + .alias("solution_apply"), + (col("keys").str().count_match(".") + col("values")).alias("solution_expr"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/window.rs b/docs/src/rust/user-guide/expressions/window.rs new file mode 100644 index 000000000000..2fcc32cdc309 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/window.rs @@ -0,0 +1,131 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:pokemon] + use polars::prelude::*; + use reqwest::blocking::Client; + + let data: Vec = Client::new() + .get("https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv") + .send()? + .text()? 
+ .bytes() + .collect(); + + let df = CsvReader::new(std::io::Cursor::new(data)) + .has_header(true) + .finish()?; + + println!("{}", df); + // --8<-- [end:pokemon] + + // --8<-- [start:group_by] + let out = df + .clone() + .lazy() + .select([ + col("Type 1"), + col("Type 2"), + col("Attack") + .mean() + .over(["Type 1"]) + .alias("avg_attack_by_type"), + col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + col("Attack").mean().alias("avg_attack"), + ]) + .collect()?; + + println!("{}", out); + // --8<-- [end:group_by] + + // --8<-- [start:operations] + let filtered = df + .clone() + .lazy() + .filter(col("Type 2").eq(lit("Psychic"))) + .select([col("Name"), col("Type 1"), col("Speed")]) + .collect()?; + + println!("{}", filtered); + // --8<-- [end:operations] + + // --8<-- [start:sort] + let out = filtered + .lazy() + .with_columns([cols(["Name", "Speed"]) + .sort_by(["Speed"], [true]) + .over(["Type 1"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:sort] + + // --8<-- [start:rules] + // aggregate and broadcast within a group + // output type: -> i32 + sum("foo").over([col("groups")]) + // sum within a group and multiply with group elements + // output type: -> i32 + (col("x").sum() * col("y")) + .over([col("groups")]) + .alias("x1") + // sum within a group and multiply with group elements + // and aggregate the group to a list + // output type: -> ChunkedArray + (col("x").sum() * col("y")) + .list() + .over([col("groups")]) + .alias("x2") + // note that it will require an explicit `list()` call + // sum within a group and multiply with group elements + // and aggregate the group to a list + // the flatten call explodes that list + + // This is the fastest method to do things over groups when the groups are sorted + (col("x").sum() * col("y")) + .list() + .over([col("groups")]) + .flatten() + .alias("x3"); + // --8<-- [end:rules] + + // --8<-- [start:examples] + let out = df + .clone() + .lazy() + .select([ + col("Type 1") + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten(), + col("Name") + .sort_by(["Speed"], [true]) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("fastest/group"), + col("Name") + .sort_by(["Attack"], [true]) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("strongest/group"), + col("Name") + .sort(false) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("sorted_by_alphabet"), + ]) + .collect()?; + println!("{:?}", out); + // --8<-- [end:examples] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/cloud-storage.rs b/docs/src/rust/user-guide/io/cloud-storage.rs new file mode 100644 index 000000000000..4118e520628d --- /dev/null +++ b/docs/src/rust/user-guide/io/cloud-storage.rs @@ -0,0 +1,45 @@ +""" +# --8<-- [start:read_parquet] +use aws_sdk_s3::Region; + +use aws_config::meta::region::RegionProviderChain; +use aws_sdk_s3::Client; +use std::borrow::Cow; + +use polars::prelude::*; + +#[tokio::main] +async fn main() { + let bucket = ""; + let path = ""; + + let config = aws_config::from_env().load().await; + let client = Client::new(&config); + + let req = client.get_object().bucket(bucket).key(path); + + let res = req.clone().send().await.unwrap(); + let bytes = res.body.collect().await.unwrap(); + let bytes = bytes.into_bytes(); + + let cursor = std::io::Cursor::new(bytes); + + let df = CsvReader::new(cursor).finish().unwrap(); + + println!("{:?}", df); +} +# --8<-- [end:read_parquet] + +# --8<-- [start:scan_parquet] +# --8<-- 
[end:scan_parquet] + +# --8<-- [start:scan_parquet_query] +# --8<-- [end:scan_parquet_query] + +# --8<-- [start:scan_pyarrow_dataset] +# --8<-- [end:scan_pyarrow_dataset] + +# --8<-- [start:write_parquet] +# --8<-- [end:write_parquet] + +""" diff --git a/docs/src/rust/user-guide/io/csv.rs b/docs/src/rust/user-guide/io/csv.rs new file mode 100644 index 000000000000..7c56d813e626 --- /dev/null +++ b/docs/src/rust/user-guide/io/csv.rs @@ -0,0 +1,29 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + use polars::prelude::*; + + let df = CsvReader::from_path("docs/data/path.csv").unwrap().finish().unwrap(); + // --8<-- [end:read] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.csv").unwrap(); + CsvWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let df = LazyCsvReader::new("./test.csv").finish().unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/json.rs b/docs/src/rust/user-guide/io/json.rs new file mode 100644 index 000000000000..ab4df729c955 --- /dev/null +++ b/docs/src/rust/user-guide/io/json.rs @@ -0,0 +1,47 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + use polars::prelude::*; + + let mut file = std::fs::File::open("docs/data/path.json").unwrap(); + let df = JsonReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + + + // --8<-- [start:readnd] + let mut file = std::fs::File::open("docs/data/path.json").unwrap(); + let df = JsonLineReader::new(&mut file).finish().unwrap(); + // --8<-- [end:readnd] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.json").unwrap(); + + // json + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::Json) + .finish(&mut df) + .unwrap(); + + // ndjson + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::JsonLines) + .finish(&mut df) + .unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let df = LazyJsonLineReader::new("docs/data/path.json".to_string()).finish().unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/parquet.rs b/docs/src/rust/user-guide/io/parquet.rs new file mode 100644 index 000000000000..f3469ffd4e2c --- /dev/null +++ b/docs/src/rust/user-guide/io/parquet.rs @@ -0,0 +1,30 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + let mut file = std::fs::File::open("docs/data/path.parquet").unwrap(); + + let df = ParquetReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.parquet").unwrap(); + ParquetWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let args = ScanArgsParquet::default(); + let df = LazyFrame::scan_parquet("./file.parquet",args).unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/concatenation.rs b/docs/src/rust/user-guide/transformations/concatenation.rs new file mode 100644 index 000000000000..95db7b4749ea --- /dev/null +++ 
b/docs/src/rust/user-guide/transformations/concatenation.rs @@ -0,0 +1,49 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:vertical] + let df_v1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_v2 = df!( + "a"=> &[2], + "b"=> &[4], + )?; + let df_vertical_concat = concat( + [df_v1.clone().lazy(), df_v2.clone().lazy()], + UnionArgs::default(), + )? + .collect()?; + println!("{}", &df_vertical_concat); + // --8<-- [end:vertical] + + // --8<-- [start:horizontal] + let df_h1 = df!( + "l1"=> &[1, 2], + "l2"=> &[3, 4], + )?; + let df_h2 = df!( + "r1"=> &[5, 6], + "r2"=> &[7, 8], + "r3"=> &[9, 10], + )?; + let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2])?; + println!("{}", &df_horizontal_concat); + // --8<-- [end:horizontal] + + // --8<-- [start:cross] + let df_d1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_d2 = df!( + "a"=> &[2], + "d"=> &[4],)?; + let df_diagonal_concat = polars::functions::concat_df_diagonal(&[df_d1, df_d2])?; + println!("{}", &df_diagonal_concat); + // --8<-- [end:cross] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs new file mode 100644 index 000000000000..aa444c5d9a1a --- /dev/null +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -0,0 +1,205 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:innerdf] + let df_customers = df! ( + + "customer_id" => &[1, 2, 3], + "name" => &["Alice", "Bob", "Charlie"], + )?; + + println!("{}", &df_customers); + // --8<-- [end:innerdf] + + // --8<-- [start:innerdf2] + let df_orders = df!( + "order_id"=> &["a", "b", "c"], + "customer_id"=> &[1, 2, 2], + "amount"=> &[100, 200, 300], + )?; + println!("{}", &df_orders); + // --8<-- [end:innerdf2] + + // --8<-- [start:inner] + let df_inner_customer_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + println!("{}", &df_inner_customer_join); + // --8<-- [end:inner] + + // --8<-- [start:left] + let df_left_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Left), + ) + .collect()?; + println!("{}", &df_left_join); + // --8<-- [end:left] + + // --8<-- [start:outer] + let df_outer_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Outer), + ) + .collect()?; + println!("{}", &df_outer_join); + // --8<-- [end:outer] + + // --8<-- [start:df3] + let df_colors = df!( + "color"=> &["red", "blue", "green"], + )?; + println!("{}", &df_colors); + // --8<-- [end:df3] + + // --8<-- [start:df4] + let df_sizes = df!( + "size"=> &["S", "M", "L"], + )?; + println!("{}", &df_sizes); + // --8<-- [end:df4] + + // --8<-- [start:cross] + let df_cross_join = df_colors + .clone() + .lazy() + .cross_join(df_sizes.clone().lazy()) + .collect()?; + println!("{}", &df_cross_join); + // --8<-- [end:cross] + + // --8<-- [start:df5] + let df_cars = df!( + "id"=> &["a", "b", "c"], + "make"=> &["ford", "toyota", "bmw"], + )?; + println!("{}", &df_cars); + // --8<-- [end:df5] + + // --8<-- [start:df6] + let df_repairs = df!( + "id"=> &["c", "c"], + "cost"=> &[100, 200], + )?; + println!("{}", &df_repairs); + // --8<-- 
[end:df6] + + // --8<-- [start:inner2] + let df_inner_join = df_cars + .clone() + .lazy() + .inner_join(df_repairs.clone().lazy(), col("id"), col("id")) + .collect()?; + println!("{}", &df_inner_join); + // --8<-- [end:inner2] + + // --8<-- [start:semi] + let df_semi_join = df_cars + .clone() + .lazy() + .join( + df_repairs.clone().lazy(), + [col("id")], + [col("id")], + JoinArgs::new(JoinType::Semi), + ) + .collect()?; + println!("{}", &df_semi_join); + // --8<-- [end:semi] + + // --8<-- [start:anti] + let df_anti_join = df_cars + .clone() + .lazy() + .join( + df_repairs.clone().lazy(), + [col("id")], + [col("id")], + JoinArgs::new(JoinType::Anti), + ) + .collect()?; + println!("{}", &df_anti_join); + // --8<-- [end:anti] + + // --8<-- [start:df7] + use chrono::prelude::*; + let df_trades = df!( + "time"=> &[ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 3, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock"=> &["A", "B", "B", "C"], + "trade"=> &[101, 299, 301, 500], + )?; + println!("{}", &df_trades); + // --8<-- [end:df7] + + // --8<-- [start:df8] + let df_quotes = df!( + "time"=> &[ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 2, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 4, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock"=> &["A", "B", "C", "A"], + "quote"=> &[100, 300, 501, 102], + )?; + + println!("{}", &df_quotes); + // --8<-- [end:df8] + + // --8<-- [start:asofpre] + let df_trades = df_trades.sort(["time"], false, true).unwrap(); + let df_quotes = df_quotes.sort(["time"], false, true).unwrap(); + // --8<-- [end:asofpre] + + // --8<-- [start:asof] + let df_asof_join = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + None, + )?; + println!("{}", &df_asof_join); + // --8<-- [end:asof] + + // --8<-- [start:asof2] + let df_asof_tolerance_join = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + Some(AnyValue::Duration(60000, TimeUnit::Milliseconds)), + )?; + println!("{}", &df_asof_tolerance_join); + // --8<-- [end:asof2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/melt.rs b/docs/src/rust/user-guide/transformations/melt.rs new file mode 100644 index 000000000000..ff797423d293 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/melt.rs @@ -0,0 +1,21 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "A"=> &["a", "b", "a"], + "B"=> &[1, 3, 5], + "C"=> &[10, 11, 12], + "D"=> &[2, 4, 6], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:melt] + let out = df.melt(["A", "B"], ["C", "D"])?; + println!("{}", &out); + // --8<-- [end:melt] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/pivot.rs b/docs/src/rust/user-guide/transformations/pivot.rs new file mode 100644 index 000000000000..e632f095f31b --- /dev/null +++ b/docs/src/rust/user-guide/transformations/pivot.rs @@ -0,0 +1,28 @@ +// --8<-- [start:setup] +use polars::prelude::{pivot::pivot, *}; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> 
{ + // --8<-- [start:df] + let df = df!( + "foo"=> ["A", "A", "B", "B", "C"], + "N"=> [1, 2, 2, 4, 2], + "bar"=> ["k", "l", "m", "n", "o"], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:eager] + let out = pivot(&df, ["N"], ["foo"], ["bar"], false, None, None)?; + println!("{}", &out); + // --8<-- [end:eager] + + // --8<-- [start:lazy] + let q = df.lazy(); + let q2 = pivot(&q.collect()?, ["N"], ["foo"], ["bar"], false, None, None)?.lazy(); + let out = q2.collect()?; + println!("{}", &out); + // --8<-- [end:lazy] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/filter.rs b/docs/src/rust/user-guide/transformations/time-series/filter.rs new file mode 100644 index 000000000000..da00effb30d0 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/filter.rs @@ -0,0 +1,61 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:filter] + let filtered_df = df + .clone() + .lazy() + .filter(col("Date").eq(lit(NaiveDate::from_ymd_opt(1995, 10, 16).unwrap()))) + .collect()?; + println!("{}", &filtered_df); + // --8<-- [end:filter] + + // --8<-- [start:range] + let filtered_range_df = df + .clone() + .lazy() + .filter( + col("Date") + .gt(lit(NaiveDate::from_ymd_opt(1995, 7, 1).unwrap())) + .and(col("Date").lt(lit(NaiveDate::from_ymd_opt(1995, 11, 1).unwrap()))), + ) + .collect()?; + println!("{}", &filtered_range_df); + // --8<-- [end:range] + + // --8<-- [start:negative] + let negative_dates_df = df!( + "ts"=> &["-1300-05-23", "-1400-03-02"], + "values"=> &[3, 4])? 
+ .lazy() + .with_column( + col("ts") + .str() + .to_date(StrptimeOptions::default(), lit("raise")), + ) + .collect()?; + + let negative_dates_filtered_df = negative_dates_df + .clone() + .lazy() + .filter(col("ts").dt().year().lt(-1300)) + .collect()?; + println!("{}", &negative_dates_filtered_df); + // --8<-- [end:negative] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/src/rust/user-guide/transformations/time-series/parsing.rs new file mode 100644 index 000000000000..0f22761d371c --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/parsing.rs @@ -0,0 +1,77 @@ +// --8<-- [start:setup] +use polars::io::prelude::*; +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:cast] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(false) + .finish() + .unwrap(); + let df = df + .clone() + .lazy() + .with_columns([col("Date") + .str() + .to_date(StrptimeOptions::default(), lit("raise"))]) + .collect()?; + println!("{}", &df); + // --8<-- [end:cast] + + // --8<-- [start:df3] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:df3] + + // --8<-- [start:extract] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:extract] + + // --8<-- [start:mixed] + let data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", + ]; + let q = col("date") + .str() + .to_datetime( + TimeUnit::Microseconds, + None, + StrptimeOptions { + format: Some("%Y-%m-%dT%H:%M:%S%z".to_string()), + ..Default::default() + }, + lit("raise"), + ) + .dt() + .convert_time_zone("Europe/Brussels".to_string()); + let mixed_parsed = df!("date" => &data)?.lazy().select([q]).collect()?; + + println!("{}", &mixed_parsed); + // --8<-- [end:mixed] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/resampling.rs b/docs/src/rust/user-guide/transformations/time-series/resampling.rs new file mode 100644 index 000000000000..60888c264e12 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/resampling.rs @@ -0,0 +1,43 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::prelude::*; +use polars::time::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(3, 0, 0).unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?, + "groups" => &["a", "a", "a", "b", "b", "a", "a"], + "values" => &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:upsample] + let out1 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"), Duration::parse("0"))? 
+ .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out1); + // --8<-- [end:upsample] + + // --8<-- [start:upsample2] + let out2 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"), Duration::parse("0"))? + .lazy() + .with_columns([col("values").interpolate(InterpolationMethod::Linear)]) + .collect()? + .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out2); + // --8<-- [end:upsample2] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs new file mode 100644 index 000000000000..6458eb69bdfc --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -0,0 +1,130 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::lazy::dsl::GetOutput; +use polars::prelude::*; +use polars::time::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap() + .sort(["Date"], false, true)?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:group_by] + let annual_average_df = df + .clone() + .lazy() + .groupby_dynamic( + col("Date"), + [], + DynamicGroupOptions { + every: Duration::parse("1y"), + period: Duration::parse("1y"), + offset: Duration::parse("0"), + ..Default::default() + }, + ) + .agg([col("Close").mean()]) + .collect()?; + + let df_with_year = annual_average_df + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:group_by] + + // --8<-- [start:group_by_dyn] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 31).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?.cast(&DataType::Date)?)?; + + let out = df + .clone() + .lazy() + .groupby_dynamic( + col("time"), + [], + DynamicGroupOptions { + every: Duration::parse("1mo"), + period: Duration::parse("1mo"), + offset: Duration::parse("0"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([ + col("time") + .cumcount(true) // python example has false + .reverse() + .head(Some(3)) + .alias("day/eom"), + ((col("time").last() - col("time").first()).map( + // had to use map as .duration().days() is not available + |s| { + Ok(Some( + s.duration()? 
+ .into_iter() + .map(|d| d.map(|v| v / 1000 / 24 / 60 / 60)) + .collect::() + .into_series(), + )) + }, + GetOutput::from_type(DataType::Int64), + ) + lit(1)) + .alias("days_in_month"), + ]) + .explode([col("day/eom")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn] + + // --8<-- [start:group_by_roll] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(3, 0, 0).unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?, + "groups"=> ["a", "a", "a", "b", "b", "a", "a"], + )?; + println!("{}", &df); + // --8<-- [end:group_by_roll] + + // --8<-- [start:group_by_dyn2] + let out = df + .clone() + .lazy() + .groupby_dynamic( + col("time"), + [col("groups")], + DynamicGroupOptions { + every: Duration::parse("1h"), + period: Duration::parse("1h"), + offset: Duration::parse("0"), + include_boundaries: true, + closed_window: ClosedWindow::Both, + ..Default::default() + }, + ) + .agg([count()]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/timezones.rs b/docs/src/rust/user-guide/transformations/time-series/timezones.rs new file mode 100644 index 000000000000..09865a428586 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/timezones.rs @@ -0,0 +1,48 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:example] + let ts = ["2021-03-27 03:00", "2021-03-28 03:00"]; + let tz_naive = Series::new("tz_naive", &ts); + let time_zones_df = DataFrame::new(vec![tz_naive])? + .lazy() + .select([col("tz_naive").str().to_datetime( + TimeUnit::Milliseconds, + None, + StrptimeOptions::default(), + lit("raise"), + )]) + .with_columns([col("tz_naive") + .dt() + .replace_time_zone(Some("UTC".to_string()), None) + .alias("tz_aware")]) + .collect()?; + + println!("{}", &time_zones_df); + // --8<-- [end:example] + + // --8<-- [start:example2] + let time_zones_operations = time_zones_df + .lazy() + .select([ + col("tz_aware") + .dt() + .replace_time_zone(Some("Europe/Brussels".to_string()), None) + .alias("replace time zone"), + col("tz_aware") + .dt() + .convert_time_zone("Asia/Kathmandu".to_string()) + .alias("convert time zone"), + col("tz_aware") + .dt() + .replace_time_zone(None, None) + .alias("unset time zone"), + ]) + .collect()?; + println!("{}", &time_zones_operations); + // --8<-- [end:example2] + + Ok(()) +} diff --git a/docs/user-guide/concepts/contexts.md b/docs/user-guide/concepts/contexts.md new file mode 100644 index 000000000000..604ff311ca63 --- /dev/null +++ b/docs/user-guide/concepts/contexts.md @@ -0,0 +1,64 @@ +# Contexts + +Polars has developed its own Domain Specific Language (DSL) for transforming data. The language is very easy to use and allows for complex queries that remain human readable. The two core components of the language are Contexts and Expressions, the latter we will cover in the next section. + +A context, as implied by the name, refers to the context in which an expression needs to be evaluated. There are three main contexts [^1]: + +1. Selection: `df.select([..])`, `df.with_columns([..])` +1. Filtering: `df.filter()` +1. 
Group by / Aggregation: `df.group_by(..).agg([..])` + +The examples below are performed on the following `DataFrame`: + +{{code_block('user-guide/concepts/contexts','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:setup" +--8<-- "python/user-guide/concepts/contexts.py:dataframe" +``` + +## Select + +In the `select` context the selection applies expressions over columns. The expressions in this context must produce `Series` that are all the same length or have a length of 1. + +A `Series` of a length of 1 will be broadcasted to match the height of the `DataFrame`. Note that a select may produce new columns that are aggregations, combinations of expressions, or literals. + +{{code_block('user-guide/concepts/contexts','select',['select'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:select" +``` + +As you can see from the query the `select` context is very powerful and allows you to perform arbitrary expressions independent (and in parallel) of each other. + +Similarly to the `select` statement there is the `with_columns` statement which also is an entrance to the selection context. The main difference is that `with_columns` retains the original columns and adds new ones while `select` drops the original columns. + +{{code_block('user-guide/concepts/contexts','with_columns',['with_columns'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:with_columns" +``` + +## Filter + +In the `filter` context you filter the existing dataframe based on arbitrary expression which evaluates to the `Boolean` data type. + +{{code_block('user-guide/concepts/contexts','filter',['filter'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:filter" +``` + +## Group by / aggregation + +In the `group_by` context, expressions work on groups and thus may yield results of any length (a group may have many members). + +{{code_block('user-guide/concepts/contexts','group_by',['group_by'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:group_by" +``` + +As you can see from the result all expressions are applied to the group defined by the `group_by` context. Besides the standard `group_by`, `group_by_dynamic`, and `group_by_rolling` are also entrances to the group by context. + +[^1]: There are additional List and SQL contexts which are covered later in this guide. But for simplicity, we leave them out of scope for now. diff --git a/docs/user-guide/concepts/data-structures.md b/docs/user-guide/concepts/data-structures.md new file mode 100644 index 000000000000..1825f8bbc892 --- /dev/null +++ b/docs/user-guide/concepts/data-structures.md @@ -0,0 +1,68 @@ +# Data structures + +The core base data structures provided by Polars are `Series` and `DataFrames`. + +## Series + +Series are a 1-dimensional data structure. Within a series all elements have the same [Data Type](data-types.md) . +The snippet below shows how to create a simple named `Series` object. 
+ +{{code_block('getting-started/series-dataframes','series',['Series'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:series" +``` + +## DataFrame + +A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it can be seen as an abstraction of a collection (e.g. list) of `Series`. Operations that can be executed on a `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. + +{{code_block('getting-started/series-dataframes','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:dataframe" +``` + +### Viewing data + +This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` from the previous example as a starting point. + +#### Head + +The `head` function shows by default the first 5 rows of a `DataFrame`. You can specify the number of rows you want to see (e.g. `df.head(10)`). + +{{code_block('getting-started/series-dataframes','head',['head'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:head" +``` + +#### Tail + +The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`. + +{{code_block('getting-started/series-dataframes','tail',['tail'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:tail" +``` + +#### Sample + +If you want to get an impression of the data of your `DataFrame`, you can also use `sample`. With `sample` you get an _n_ number of random rows from the `DataFrame`. + +{{code_block('getting-started/series-dataframes','sample',['sample'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:sample" +``` + +#### Describe + +`Describe` returns summary statistics of your `DataFrame`. It will provide several quick statistics if possible. + +{{code_block('getting-started/series-dataframes','describe',['describe'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:describe" +``` diff --git a/docs/user-guide/concepts/data-types.md b/docs/user-guide/concepts/data-types.md new file mode 100644 index 000000000000..77bed2d51024 --- /dev/null +++ b/docs/user-guide/concepts/data-types.md @@ -0,0 +1,45 @@ +# Data types + +`Polars` is entirely based on `Arrow` data types and backed by `Arrow` memory arrays. This makes data processing +cache-efficient and well-supported for Inter Process Communication. Most data types follow the exact implementation +from `Arrow`, with the exception of `Utf8` (this is actually `LargeUtf8`), `Categorical`, and `Object` (support is limited). The data types are: + +| Group | Type | Details | +| -------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| Numeric | `Int8` | 8-bit signed integer. | +| | `Int16` | 16-bit signed integer. | +| | `Int32` | 32-bit signed integer. | +| | `Int64` | 64-bit signed integer. | +| | `UInt8` | 8-bit unsigned integer. | +| | `UInt16` | 16-bit unsigned integer. | +| | `UInt32` | 32-bit unsigned integer. | +| | `UInt64` | 64-bit unsigned integer. 
| +| | `Float32` | 32-bit floating point. | +| | `Float64` | 64-bit floating point. | +| Nested | `Struct` | A struct array is represented as a `Vec` and is useful to pack multiple/heterogenous values in a single column. | +| | `List` | A list array contains a child array containing the list values and an offset array. (this is actually `Arrow` `LargeList` internally). | +| Temporal | `Date` | Date representation, internally represented as days since UNIX epoch encoded by a 32-bit signed integer. | +| | `Datetime` | Datetime representation, internally represented as microseconds since UNIX epoch encoded by a 64-bit signed integer. | +| | `Duration` | A timedelta type, internally represented as microseconds. Created when subtracting `Date/Datetime`. | +| | `Time` | Time representation, internally represented as nanoseconds since midnight. | +| Other | `Boolean` | Boolean type effectively bit packed. | +| | `Utf8` | String data (this is actually `Arrow` `LargeUtf8` internally). | +| | `Binary` | Store data as bytes. | +| | `Object` | A limited supported data type that can be any value. | +| | `Categorical` | A categorical encoding of a set of strings. | + +To learn more about the internal representation of these data types, check the [`Arrow` columnar format](https://arrow.apache.org/docs/format/Columnar.html). + +## Floating Point + +`Polars` generally follows the IEEE 754 floating point standard for `Float32` and `Float64`, with some exceptions: + +- Any NaN compares equal to any other NaN, and greater than any non-NaN value. +- Operations do not guarantee any particular behavior on the sign of zero or NaN, + nor on the payload of NaN values. This is not just limited to arithmetic operations, + e.g. a sort or group by operation may canonicalize all zeroes to +0 and all NaNs + to a positive NaN without payload for efficient equality checks. + +`Polars` always attempts to provide reasonably accurate results for floating point computations, but does not provide guarantees +on the error unless mentioned otherwise. Generally speaking 100% accurate results are infeasibly expensive to acquire (requiring +much larger internal representations than 64-bit floats), and thus some error is always to be expected. diff --git a/docs/user-guide/concepts/expressions.md b/docs/user-guide/concepts/expressions.md new file mode 100644 index 000000000000..b276c494a4a3 --- /dev/null +++ b/docs/user-guide/concepts/expressions.md @@ -0,0 +1,49 @@ +# Expressions + +`Polars` has a powerful concept called expressions that is central to its very fast performance. + +Expressions are at the core of many data science operations: + +- taking a sample of rows from a column +- multiplying values in a column +- extracting a column of years from dates +- convert a column of strings to lowercase +- and so on! + +However, expressions are also used within other operations: + +- taking the mean of a group in a `group_by` operation +- calculating the size of groups in a `group_by` operation +- taking the sum horizontally across columns + +`Polars` performs these core data transformations very quickly by: + +- automatic query optimization on each expression +- automatic parallelization of expressions on many columns + +Polars expressions are a mapping from a series to a series (or mathematically `Fn(Series) -> Series`). As expressions have a `Series` as an input and a `Series` as an output then it is straightforward to do a sequence of expressions (similar to method chaining in `Pandas`). 
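As a minimal sketch of that `Fn(Series) -> Series` idea (the `DataFrame` and the column name `foo` below are invented purely for illustration), note how every method call yields a new expression, and how the chained result only runs once it is handed to a context:

```python
import polars as pl

df = pl.DataFrame({"foo": [3, 1, 2, 5, 4]})

# every call maps a Series to a new Series, so the calls chain naturally
doubled_top = (pl.col("foo").sort(descending=True).head(3) * 2).alias("doubled_top")

# the expression is only a description of the computation;
# a context such as `select` actually executes it
print(df.select(doubled_top))
```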
+ +## Examples + +The following is an expression: + +{{code_block('user-guide/concepts/expressions','example1',['col','sort','head'])}} + +The snippet above says: + +1. Select column "foo" +1. Then sort the column (not in reversed order) +1. Then take the first two values of the sorted output + +The power of expressions is that every expression produces a new expression, and that they +can be _piped_ together. You can run an expression by passing them to one of `Polars` execution contexts. + +Here we run two expressions by running `df.select`: + +{{code_block('user-guide/concepts/expressions','example2',['select'])}} + +All expressions are run in parallel, meaning that separate `Polars` expressions are **embarrassingly parallel**. Note that within an expression there may be more parallelization going on. + +## Conclusion + +This is the tip of the iceberg in terms of possible expressions. There are a ton more, and they can be combined in a variety of ways. This page is intended to get you familiar with the concept of expressions, in the section on [expressions](../expressions/operators.md) we will dive deeper. diff --git a/docs/user-guide/concepts/lazy-vs-eager.md b/docs/user-guide/concepts/lazy-vs-eager.md new file mode 100644 index 000000000000..1b84a0272aa5 --- /dev/null +++ b/docs/user-guide/concepts/lazy-vs-eager.md @@ -0,0 +1,28 @@ +# Lazy / eager API + +`Polars` supports two modes of operation: lazy and eager. In the eager API the query is executed immediately while in the lazy API the query is only evaluated once it is 'needed'. Deferring the execution to the last minute can have significant performance advantages that is why the Lazy API is preferred in most cases. Let us demonstrate this with an example: + +{{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} + +In this example we use the eager API to: + +1. Read the iris [dataset](https://archive.ics.uci.edu/ml/datasets/iris). +1. Filter the dataset based on sepal length +1. Calculate the mean of the sepal width per species + +Every step is executed immediately returning the intermediate results. This can be very wasteful as we might do work or load extra data that is not being used. If we instead used the lazy API and waited on execution until all the steps are defined then the query planner could perform various optimizations. In this case: + +- Predicate pushdown: Apply filters as early as possible while reading the dataset, thus only reading rows with sepal length greater than 5. +- Projection pushdown: Select only the columns that are needed while reading the dataset, thus removing the need to load additional columns (e.g. petal length & petal width) + +{{code_block('user-guide/concepts/lazy-vs-eager','lazy',['scan_csv'])}} + +These will significantly lower the load on memory & CPU thus allowing you to fit bigger datasets in memory and process faster. Once the query is defined you call `collect` to inform `Polars` that you want to execute it. In the section on Lazy API we will go into more details on its implementation. + +!!! info "Eager API" + + In many cases the eager API is actually calling the lazy API under the hood and immediately collecting the result. This has the benefit that within the query itself optimization(s) made by the query planner can still take place. + +### When to use which + +In general the lazy API should be preferred unless you are either interested in the intermediate results or are doing exploratory work and don't know yet what your query is going to look like. 
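To make the contrast concrete, here is a rough sketch of both modes for the query described above. The file path and column names are assumptions based on the iris dataset, so they may differ from the snippets referenced by the code blocks:

```python
import polars as pl

# eager: every step executes immediately and materializes a DataFrame
df = pl.read_csv("docs/data/iris.csv")  # hypothetical path
df_small = df.filter(pl.col("sepal_length") > 5)
df_agg = df_small.group_by("species").agg(pl.col("sepal_width").mean())

# lazy: describe the whole query first so the optimizer can push the filter
# and the column selection down into the CSV scan, then collect the result
df_agg_lazy = (
    pl.scan_csv("docs/data/iris.csv")
    .filter(pl.col("sepal_length") > 5)
    .group_by("species")
    .agg(pl.col("sepal_width").mean())
    .collect()
)
```

The two results are the same; only the lazy version gives the query planner a chance to avoid reading rows and columns it does not need.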
diff --git a/docs/user-guide/concepts/streaming.md b/docs/user-guide/concepts/streaming.md new file mode 100644 index 000000000000..e52e28bf2cfe --- /dev/null +++ b/docs/user-guide/concepts/streaming.md @@ -0,0 +1,21 @@ +# Streaming API + +One additional benefit of the lazy API is that it allows queries to be executed in a streaming manner. Instead of processing the data all-at-once `Polars` can execute the query in batches allowing you to process datasets that are larger-than-memory. + +To tell Polars we want to execute a query in streaming mode we pass the `streaming=True` argument to `collect` + +{{code_block('user-guide/concepts/streaming','streaming',['collect'])}} + +## When is streaming available? + +Streaming is still in development. We can ask Polars to execute any lazy query in streaming mode. However, not all lazy operations support streaming. If there is an operation for which streaming is not supported Polars will run the query in non-streaming mode. + +Streaming is supported for many operations including: + +- `filter`,`slice`,`head`,`tail` +- `with_columns`,`select` +- `group_by` +- `join` +- `sort` +- `explode`,`melt` +- `scan_csv`,`scan_parquet`,`scan_ipc` diff --git a/docs/user-guide/expressions/aggregation.md b/docs/user-guide/expressions/aggregation.md new file mode 100644 index 000000000000..6b5fb8bcaf48 --- /dev/null +++ b/docs/user-guide/expressions/aggregation.md @@ -0,0 +1,122 @@ +# Aggregation + +`Polars` implements a powerful syntax defined not only in its lazy API, but also in its eager API. Let's take a look at what that means. + +We can start with the simple [US congress `dataset`](https://github.com/unitedstates/congress-legislators). + +{{code_block('user-guide/expressions/aggregation','dataframe',['DataFrame','Categorical'])}} + +#### Basic aggregations + +You can easily combine different aggregations by adding multiple expressions in a +`list`. There is no upper bound on the number of aggregations you can do, and you can +make any combination you want. In the snippet below we do the following aggregations: + +Per GROUP `"first_name"` we + +- count the number of rows in the group: + - short form: `pl.count("party")` + - full form: `pl.col("party").count()` +- aggregate the gender values groups: + - full form: `pl.col("gender")` +- get the first value of column `"last_name"` in the group: + - short form: `pl.first("last_name")` (not available in Rust) + - full form: `pl.col("last_name").first()` + +Besides the aggregation, we immediately sort the result and limit to the top `5` so that +we have a nice summary overview. + +{{code_block('user-guide/expressions/aggregation','basic',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:setup" +--8<-- "python/user-guide/expressions/aggregation.py:dataframe" +--8<-- "python/user-guide/expressions/aggregation.py:basic" +``` + +#### Conditionals + +It's that easy! Let's turn it up a notch. Let's say we want to know how +many delegates of a "state" are "Pro" or "Anti" administration. We could directly query +that in the aggregation without the need of a `lambda` or grooming the `DataFrame`. + +{{code_block('user-guide/expressions/aggregation','conditional',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:conditional" +``` + +Similarly, this could also be done with a nested GROUP BY, but that doesn't help show off some of these nice features. 
😉 + +{{code_block('user-guide/expressions/aggregation','nested',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:nested" +``` + +#### Filtering + +We can also filter the groups. Let's say we want to compute a mean per group, but we +don't want to include all values from that group, and we also don't want to filter the +rows from the `DataFrame` (because we need those rows for another aggregation). + +In the example below we show how this can be done. + +!!! note + + Note that we can make `Python` functions for clarity. These functions don't cost us anything. That is because we only create `Polars` expressions, we don't apply a custom function over a `Series` during runtime of the query. Of course, you can make functions that return expressions in Rust, too. + +{{code_block('user-guide/expressions/aggregation','filter',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:filter" +``` + +#### Sorting + +It's common to see a `DataFrame` being sorted for the sole purpose of managing the ordering during a GROUP BY operation. Let's say that we want to get the names of the oldest and youngest politicians per state. We could SORT and GROUP BY. + +{{code_block('user-guide/expressions/aggregation','sort',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort" +``` + +However, **if** we also want to sort the names alphabetically, this breaks. Luckily we can sort in a `group_by` context separate from the `DataFrame`. + +{{code_block('user-guide/expressions/aggregation','sort2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort2" +``` + +We can even sort by another column in the `group_by` context. If we want to know if the alphabetically sorted name is male or female we could add: `pl.col("gender").sort_by("first_name").first().alias("gender")` + +{{code_block('user-guide/expressions/aggregation','sort3',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort3" +``` + +### Do not kill parallelization + +!!! warning "Python Users Only" + + The following section is specific to `Python`, and doesn't apply to `Rust`. Within `Rust`, blocks and closures (lambdas) can, and will, be executed concurrently. + +We have all heard that `Python` is slow, and does "not scale." Besides the overhead of +running "slow" bytecode, `Python` has to remain within the constraints of the Global +Interpreter Lock (GIL). This means that if you were to use a `lambda` or a custom `Python` +function to apply during a parallelized phase, `Polars` speed is capped running `Python` +code preventing any multiple threads from executing the function. + +This all feels terribly limiting, especially because we often need those `lambda` functions in a +`.group_by()` step, for example. This approach is still supported by `Polars`, but +keeping in mind bytecode **and** the GIL costs have to be paid. It is recommended to try to solve your queries using the expression syntax before moving to `lambdas`. If you want to learn more about using `lambdas`, go to the [user defined functions section](./user-defined-functions.md). 
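As a rough illustration of the trade-off, the sketch below compares an expression with a user-defined Python function on a made-up column; `map_elements` is the hook recent versions of Python Polars expose for custom element-wise functions (older releases called it `apply`), shown here in a plain `select` for brevity:

```python
import polars as pl

df = pl.DataFrame({"values": [1, 2, 3, 4]})

# expression syntax: vectorized and executed by the native engine, no GIL involved
out_expr = df.select((pl.col("values") * 2).alias("doubled"))

# custom Python function: every element passes through the interpreter,
# so the GIL prevents the work from being parallelized across threads
out_udf = df.select(
    pl.col("values").map_elements(lambda x: x * 2).alias("doubled")
)
```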
+ +### Conclusion + +In the examples above we've seen that we can do a lot by combining expressions. By doing so we delay the use of custom `Python` functions that slow down the queries (by the slow nature of Python AND the GIL). + +If we are missing a type expression let us know by opening a +[feature request](https://github.com/pola-rs/polars/issues/new/choose)! diff --git a/docs/user-guide/expressions/casting.md b/docs/user-guide/expressions/casting.md new file mode 100644 index 000000000000..88b9d3fcbbd6 --- /dev/null +++ b/docs/user-guide/expressions/casting.md @@ -0,0 +1,100 @@ +# Casting + +Casting converts the underlying [`DataType`](../concepts/data-types.md) of a column to a new one. Polars uses Arrow to manage the data in memory and relies on the compute kernels in the [rust implementation](https://github.com/jorgecarleitao/arrow2) to do the conversion. Casting is available with the `cast()` method. + +The `cast` method includes a `strict` parameter that determines how Polars behaves when it encounters a value that can't be converted from the source `DataType` to the target `DataType`. By default, `strict=True`, which means that Polars will throw an error to notify the user of the failed conversion and provide details on the values that couldn't be cast. On the other hand, if `strict=False`, any values that can't be converted to the target `DataType` will be quietly converted to `null`. + +## Numerics + +Let's take a look at the following `DataFrame` which contains both integers and floating point numbers. + +{{code_block('user-guide/expressions/casting','dfnum',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:setup" +--8<-- "python/user-guide/expressions/casting.py:dfnum" +``` + +To perform casting operations between floats and integers, or vice versa, we can invoke the `cast()` function. + +{{code_block('user-guide/expressions/casting','castnum',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:castnum" +``` + +Note that in the case of decimal values these are rounded downwards when casting to an integer. + +##### Downcast + +Reducing the memory footprint is also achievable by modifying the number of bits allocated to an element. As an illustration, the code below demonstrates how casting from `Int64` to `Int16` and from `Float64` to `Float32` can be used to lower memory usage. + +{{code_block('user-guide/expressions/casting','downcast',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:downcast" +``` + +#### Overflow + +When performing downcasting, it is crucial to ensure that the chosen number of bits (such as 64, 32, or 16) is sufficient to accommodate the largest and smallest numbers in the column. For example, using a 32-bit signed integer (`Int32`) allows handling integers within the range of -2147483648 to +2147483647, while using `Int8` covers integers between -128 to 127. Attempting to cast to a `DataType` that is too small will result in a `ComputeError` thrown by Polars, as the operation is not supported. + +{{code_block('user-guide/expressions/casting','overflow',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:overflow" +``` + +You can set the `strict` parameter to `False`, this converts values that are overflowing to null values. 
+ +{{code_block('user-guide/expressions/casting','overflow2',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:overflow2" +``` + +## Strings + +Strings can be casted to numerical data types and vice versa: + +{{code_block('user-guide/expressions/casting','strings',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:strings" +``` + +In case the column contains a non-numerical value, Polars will throw a `ComputeError` detailing the conversion error. Setting `strict=False` will convert the non float value to `null`. + +{{code_block('user-guide/expressions/casting','strings2',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:strings2" +``` + +## Booleans + +Booleans can be expressed as either 1 (`True`) or 0 (`False`). It's possible to perform casting operations between a numerical `DataType` and a boolean, and vice versa. However, keep in mind that casting from a string (`Utf8`) to a boolean is not permitted. + +{{code_block('user-guide/expressions/casting','bool',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:bool" +``` + +## Dates + +Temporal data types such as `Date` or `Datetime` are represented as the number of days (`Date`) and microseconds (`Datetime`) since epoch. Therefore, casting between the numerical types and the temporal data types is allowed. + +{{code_block('user-guide/expressions/casting','dates',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:dates" +``` + +To convert between strings and `Dates`/`Datetimes`, `dt.to_string` and `str.to_datetime` are utilized. Polars adopts the [chrono format syntax](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) for formatting. It's worth noting that `str.to_datetime` features additional options that support timezone functionality. Refer to the API documentation for further information. + +{{code_block('user-guide/expressions/casting','dates2',['dt.to_string','str.to_date'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:dates2" +``` diff --git a/docs/user-guide/expressions/column-selections.md b/docs/user-guide/expressions/column-selections.md new file mode 100644 index 000000000000..9c9579411ba4 --- /dev/null +++ b/docs/user-guide/expressions/column-selections.md @@ -0,0 +1,134 @@ +# Column selections + +Let's create a dataset to use in this section: + +{{code_block('user-guide/expressions/column-selections','selectors_df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:setup" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_df" +``` + +## Expression expansion + +As we've seen in the previous section, we can select specific columns using the `pl.col` method. It can also select multiple columns - both as a means of convenience, and to _expand_ the expression. + +This kind of convenience feature isn't just decorative or syntactic sugar. 
It allows for a very powerful application of [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself) principles in your code: a single expression that specifies multiple columns expands into a list of expressions (depending on the DataFrame schema), resulting in being able to select multiple columns + run computation on them! + +### Select all, or all but some + +We can select all columns in the `DataFrame` object by providing the argument `*`: + +{{code_block('user-guide/expressions/column-selections', 'all',['all'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:all" +``` + +Often, we don't just want to include all columns, but include all _while_ excluding a few. This can be done easily as well: + +{{code_block('user-guide/expressions/column-selections','exclude',['exclude'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:exclude" +``` + +### By multiple strings + +Specifying multiple strings allows expressions to _expand_ to all matching columns: + +{{code_block('user-guide/expressions/column-selections','expansion_by_names',['dt.to_string'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_names" +``` + +### By regular expressions + +Multiple column selection is possible by regular expressions also, by making sure to wrap the regex by `^` and `$` to let `pl.col` know that a regex selection is expected: + +{{code_block('user-guide/expressions/column-selections','expansion_by_regex',[])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_regex" +``` + +### By data type + +`pl.col` can select multiple columns using Polars data types: + +{{code_block('user-guide/expressions/column-selections','expansion_by_dtype',['n_unique'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_dtype" +``` + +## Using `selectors` + +Polars also allows for the use of intuitive selections for columns based on their name, `dtype` or other properties; and this is built on top of existing functionality outlined in `col` used above. It is recommended to use them by importing and aliasing `polars.selectors` as `cs`. + +### By `dtype` + +To select just the integer and string columns, we can do: + +{{code_block('user-guide/expressions/column-selections','selectors_intro',['selectors'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_intro" +``` + +### Applying set operations + +These _selectors_ also allow for set based selection operations. 
For instance, to select the **numeric** columns **except** the **first** column that indicates row numbers: + +{{code_block('user-guide/expressions/column-selections','selectors_diff',['cs.first', 'cs.numeric'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_diff" +``` + +We can also select the row number by name **and** any **non**-numeric columns: + +{{code_block('user-guide/expressions/column-selections','selectors_union',['cs.by_name', 'cs.numeric'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_union" +``` + +### By patterns and substrings + +_Selectors_ can also be matched by substring and regex patterns: + +{{code_block('user-guide/expressions/column-selections','selectors_by_name',['cs.contains', 'cs.matches'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_by_name" +``` + +### Converting to expressions + +What if we want to apply a specific operation on the selected columns (i.e. get back to representing them as **expressions** to operate upon)? We can simply convert them using `as_expr` and then proceed as normal: + +{{code_block('user-guide/expressions/column-selections','selectors_to_expr',['cs.temporal'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_to_expr" +``` + +### Debugging `selectors` + +Polars also provides two helpful utility functions to aid with using selectors: `is_selector` and `selector_column_names`: + +{{code_block('user-guide/expressions/column-selections','selectors_is_selector_utility',['is_selector'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_is_selector_utility" +``` + +To predetermine the column names that are selected, which is especially useful for a LazyFrame object: + +{{code_block('user-guide/expressions/column-selections','selectors_colnames_utility',['selector_column_names'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_colnames_utility" +``` diff --git a/docs/user-guide/expressions/folds.md b/docs/user-guide/expressions/folds.md new file mode 100644 index 000000000000..2339f8f114e5 --- /dev/null +++ b/docs/user-guide/expressions/folds.md @@ -0,0 +1,43 @@ +# Folds + +`Polars` provides expressions/methods for horizontal aggregations like `sum`,`min`, `mean`, +etc. However, when you need a more complex aggregation the default methods `Polars` supplies may not be sufficient. That's when `folds` come in handy. + +The `fold` expression operates on columns for maximum speed. It utilizes the data layout very efficiently and often has vectorized execution. + +### Manual sum + +Let's start with an example by implementing the `sum` operation ourselves, with a `fold`. + +{{code_block('user-guide/expressions/folds','mansum',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:setup" +--8<-- "python/user-guide/expressions/folds.py:mansum" +``` + +The snippet above recursively applies the function `f(acc, x) -> acc` to an accumulator `acc` and a new column `x`. 
The function operates on columns individually and can take advantage of cache efficiency and vectorization. + +### Conditional + +In the case where you'd want to apply a condition/predicate on all columns in a `DataFrame` a `fold` operation can be a very concise way to express this. + +{{code_block('user-guide/expressions/folds','conditional',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:conditional" +``` + +In the snippet we filter all rows where **each** column value is `> 1`. + +### Folds and string data + +Folds could be used to concatenate string data. However, due to the materialization of intermediate columns, this operation will have squared complexity. + +Therefore, we recommend using the `concat_str` expression for this. + +{{code_block('user-guide/expressions/folds','string',['concat_str'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:string" +``` diff --git a/docs/user-guide/expressions/functions.md b/docs/user-guide/expressions/functions.md new file mode 100644 index 000000000000..fde219cb25dd --- /dev/null +++ b/docs/user-guide/expressions/functions.md @@ -0,0 +1,65 @@ +# Functions + +`Polars` expressions have a large number of built in functions. These allow you to create complex queries without the need for [user defined functions](user-defined-functions.md). There are too many to go through here, but we will cover some of the more popular use cases. If you want to view all the functions go to the API Reference for your programming language. + +In the examples below we will use the following `DataFrame`: + +{{code_block('user-guide/expressions/functions','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:setup" +--8<-- "python/user-guide/expressions/functions.py:dataframe" +``` + +## Column naming + +By default if you perform an expression it will keep the same name as the original column. In the example below we perform an expression on the `nrs` column. Note that the output `DataFrame` still has the same name. + +{{code_block('user-guide/expressions/functions','samename',[])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samename" +``` + +This might get problematic in the case you use the same column multiple times in your expression as the output columns will get duplicated. For example, the following query will fail. + +{{code_block('user-guide/expressions/functions','samenametwice',[])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samenametwice" +``` + +You can change the output name of an expression by using the `alias` function + +{{code_block('user-guide/expressions/functions','samenamealias',['alias'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samenamealias" +``` + +In case of multiple columns for example when using `all()` or `col(*)` you can apply a mapping function `map_alias` to change the original column name into something else. In case you want to add a suffix (`suffix()`) or prefix (`prefix()`) these are also built in. 
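+
+To make the naming behaviour concrete, here is a minimal sketch. The toy column values are made up for illustration; note that selecting `pl.col("nrs") + 5` twice without an `alias` would raise a duplicate column name error:
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"nrs": [1, 2, 3], "names": ["foo", "ham", "spam"]})
+
+out = df.select(
+    (pl.col("nrs") + 5).alias("nrs plus 5"),  # rename a single expression
+    pl.all().suffix("_original"),             # keep every original column with a suffix
+)
+print(out)
+```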
+ +=== ":fontawesome-brands-python: Python" +[:material-api: `prefix`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.prefix.html) +[:material-api: `suffix`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.suffix.html) +[:material-api: `map_alias`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.map_alias.html) + +## Count unique values + +There are two ways to count unique values in `Polars`: an exact methodology and an approximation. The approximation uses the [HyperLogLog++](https://en.wikipedia.org/wiki/HyperLogLog) algorithm to approximate the cardinality and is especially useful for very large datasets where an approximation is good enough. + +{{code_block('user-guide/expressions/functions','countunique',['n_unique','approx_n_unique'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:countunique" +``` + +## Conditionals + +`Polars` supports if-else like conditions in expressions with the `when`, `then`, `otherwise` syntax. The predicate is placed in the `when` clause and when this evaluates to `true` the `then` expression is applied otherwise the `otherwise` expression is applied (row-wise). + +{{code_block('user-guide/expressions/functions','conditional',['when'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:conditional" +``` diff --git a/docs/user-guide/expressions/lists.md b/docs/user-guide/expressions/lists.md new file mode 100644 index 000000000000..467c663aafd5 --- /dev/null +++ b/docs/user-guide/expressions/lists.md @@ -0,0 +1,119 @@ +# Lists and Arrays + +`Polars` has first-class support for `List` columns: that is, columns where each row is a list of homogeneous elements, of varying lengths. `Polars` also has an `Array` datatype, which is analogous to `numpy`'s `ndarray` objects, where the length is identical across rows. + +Note: this is different from Python's `list` object, where the elements can be of any type. Polars can store these within columns, but as a generic `Object` datatype that doesn't have the special list manipulation features that we're about to discuss. + +## Powerful `List` manipulation + +Let's say we had the following data from different weather stations across a state. When the weather station is unable to get a result, an error code is recorded instead of the actual temperature at that time. + +{{code_block('user-guide/expressions/lists','weather_df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:setup" +--8<-- "python/user-guide/expressions/lists.py:weather_df" +``` + +### Creating a `List` column + +For the `weather` `DataFrame` created above, it's very likely we need to run some analysis on the temperatures that are captured by each station. To make this happen, we need to first be able to get individual temperature measurements. 
This is done by:
+
+{{code_block('user-guide/expressions/lists','string_to_list',['str.split'])}}
+
+```python exec="on" result="text" session="user-guide/lists"
+--8<-- "python/user-guide/expressions/lists.py:string_to_list"
+```
+
+One way we could go about this would be to convert each temperature measurement into its own row:
+
+{{code_block('user-guide/expressions/lists','explode_to_atomic',['DataFrame.explode'])}}
+
+```python exec="on" result="text" session="user-guide/lists"
+--8<-- "python/user-guide/expressions/lists.py:explode_to_atomic"
+```
+
+However, in Polars, we often do not need to do this to operate on the `List` elements.
+
+### Operating on `List` columns
+
+Polars provides several standard operations on `List` columns. If we want the first three measurements, we can do a `head(3)`. The last three can be obtained via a `tail(3)`, or alternatively via `slice` (negative indexing is supported). We can also identify the number of observations via `lengths`. Let's see them in action:
+
+{{code_block('user-guide/expressions/lists','list_ops',['Expr.list'])}}
+
+```python exec="on" result="text" session="user-guide/lists"
+--8<-- "python/user-guide/expressions/lists.py:list_ops"
+```
+
+!!! warning "`arr` then, `list` now"
+
+    If you find references to the `arr` API on Stackoverflow or other sources, just replace `arr` with `list`; `arr` was the old accessor for the `List` datatype and now refers to the newly introduced `Array` datatype (see below).
+
+### Element-wise computation within `List`s
+
+If we need to identify the stations that are giving the highest number of errors from the starting `DataFrame`, we need to:
+
+1. Parse the string input as a `List` of string values (already done).
+2. Identify those strings that can be converted to numbers.
+3. Identify the number of non-numeric values (i.e. `null` values) in the list, by row.
+4. Rename this output as `errors` so that we can easily identify the stations.
+
+The third step requires a cast (or, alternatively, a regex pattern search) operation to be performed on each element of the list. We can do this by first referencing each element in the `pl.element()` context, and then calling a suitable Polars expression on it. Let's see how:
+
+{{code_block('user-guide/expressions/lists','count_errors',['Expr.list', 'element'])}}
+
+```python exec="on" result="text" session="user-guide/lists"
+--8<-- "python/user-guide/expressions/lists.py:count_errors"
+```
+
+What if we chose the regex route (i.e. recognizing the presence of _any_ alphabetical character)?
+
+{{code_block('user-guide/expressions/lists','count_errors_regex',['str.contains'])}}
+
+```python exec="on" result="text" session="user-guide/lists"
+--8<-- "python/user-guide/expressions/lists.py:count_errors_regex"
+```
+
+If you're unfamiliar with `(?i)`, it's a good time to look at the documentation for the `str.contains` function in Polars! The Rust regex crate provides a lot of additional regex flags that might come in handy.
+
+## Row-wise computations
+
+This context is ideal for computing in row orientation.
+
+We can apply **any** Polars operations on the elements of the list with the `list.eval` (`list().eval` in Rust) expression! These expressions run entirely on Polars' query engine and can run in parallel, so they are well optimized.
Let's say we have another set of weather data across three days, for different stations: + +{{code_block('user-guide/expressions/lists','weather_by_day',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:weather_by_day" +``` + +Let's do something interesting, where we calculate the percentage rank of the temperatures by day, measured across stations. Pandas allows you to compute the percentages of the `rank` values. `Polars` doesn't provide a special function to do this directly, but because expressions are so versatile we can create our own percentage rank expression for highest temperature. Let's try that! + +{{code_block('user-guide/expressions/lists','weather_by_day_rank',['list.eval'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:weather_by_day_rank" +``` + +## Polars `Array`s + +`Array`s are a new data type that was recently introduced, and are still pretty nascent in features that it offers. The major difference between a `List` and an `Array` is that the latter is limited to having the same number of elements per row, while a `List` can have a variable number of elements. Both still require that each element's data type is the same. + +We can define `Array` columns in this manner: + +{{code_block('user-guide/expressions/lists','array_df',['Array'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:array_df" +``` + +Basic operations are available on it: + +{{code_block('user-guide/expressions/lists','array_ops',['Series.arr'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:array_ops" +``` + +Polars `Array`s are still being actively developed, so this section will likely change in the future. diff --git a/docs/user-guide/expressions/null.md b/docs/user-guide/expressions/null.md new file mode 100644 index 000000000000..5ded317ac2b5 --- /dev/null +++ b/docs/user-guide/expressions/null.md @@ -0,0 +1,140 @@ +# Missing data + +This page sets out how missing data is represented in `Polars` and how missing data can be filled. + +## `null` and `NaN` values + +Each column in a `DataFrame` (or equivalently a `Series`) is an Arrow array or a collection of Arrow arrays [based on the Apache Arrow format](https://arrow.apache.org/docs/format/Columnar.html#null-count). Missing data is represented in Arrow and `Polars` with a `null` value. This `null` missing value applies for all data types including numerical values. + +`Polars` also allows `NotaNumber` or `NaN` values for float columns. These `NaN` values are considered to be a type of floating point data rather than missing data. We discuss `NaN` values separately below. + +You can manually define a missing value with the python `None` value: + +{{code_block('user-guide/expressions/null','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:setup" +--8<-- "python/user-guide/expressions/null.py:dataframe" +``` + +!!! info + + In `Pandas` the value for missing data depends on the dtype of the column. In `Polars` missing data is always represented as a `null` value. + +## Missing data metadata + +Each Arrow array used by `Polars` stores two kinds of metadata related to missing data. This metadata allows `Polars` to quickly show how many missing values there are and which values are missing. 
+ +The first piece of metadata is the `null_count` - this is the number of rows with `null` values in the column: + +{{code_block('user-guide/expressions/null','count',['null_count'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:count" +``` + +The `null_count` method can be called on a `DataFrame`, a column from a `DataFrame` or a `Series`. The `null_count` method is a cheap operation as `null_count` is already calculated for the underlying Arrow array. + +The second piece of metadata is an array called a _validity bitmap_ that indicates whether each data value is valid or missing. +The validity bitmap is memory efficient as it is bit encoded - each value is either a 0 or a 1. This bit encoding means the memory overhead per array is only (array length / 8) bytes. The validity bitmap is used by the `is_null` method in `Polars`. + +You can return a `Series` based on the validity bitmap for a column in a `DataFrame` or a `Series` with the `is_null` method: + +{{code_block('user-guide/expressions/null','isnull',['is_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:isnull" +``` + +The `is_null` method is a cheap operation that does not require scanning the full column for `null` values. This is because the validity bitmap already exists and can be returned as a Boolean array. + +## Filling missing data + +Missing data in a `Series` can be filled with the `fill_null` method. You have to specify how you want the `fill_null` method to fill the missing data. The main ways to do this are filling with: + +- a literal such as 0 or "0" +- a strategy such as filling forwards +- an expression such as replacing with values from another column +- interpolation + +We illustrate each way to fill nulls by defining a simple `DataFrame` with a missing value in `col2`: + +{{code_block('user-guide/expressions/null','dataframe2',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:dataframe2" +``` + +### Fill with specified literal value + +We can fill the missing data with a specified literal value with `pl.lit`: + +{{code_block('user-guide/expressions/null','fill',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fill" +``` + +### Fill with a strategy + +We can fill the missing data with a strategy such as filling forward: + +{{code_block('user-guide/expressions/null','fillstrategy',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillstrategy" +``` + +You can find other fill strategies in the API docs. + +### Fill with an expression + +For more flexibility we can fill the missing data with an expression. For example, +to fill nulls with the median value from that column: + +{{code_block('user-guide/expressions/null','fillexpr',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillexpr" +``` + +In this case the column is cast from integer to float because the median is a float statistic. 
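+
+As a compact recap of the three approaches shown so far, here is a hedged sketch on a toy frame that mirrors the `col1`/`col2` layout used above (the actual values are made up for illustration):
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"col1": [1, 2, 3], "col2": [1, None, 3]})
+
+out = df.with_columns(
+    pl.col("col2").fill_null(pl.lit(2)).alias("filled_literal"),
+    pl.col("col2").fill_null(strategy="forward").alias("filled_forward"),
+    # the median is a float statistic, so this column is cast from integer to float
+    pl.col("col2").fill_null(pl.col("col2").median()).alias("filled_median"),
+)
+print(out)
+```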
+ +### Fill with interpolation + +In addition, we can fill nulls with interpolation (without using the `fill_null` function): + +{{code_block('user-guide/expressions/null','fillinterpolate',['interpolate'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillinterpolate" +``` + +## `NotaNumber` or `NaN` values + +Missing data in a `Series` has a `null` value. However, you can use `NotaNumber` or `NaN` values in columns with float datatypes. These `NaN` values can be created from Numpy's `np.nan` or the native python `float('nan')`: + +{{code_block('user-guide/expressions/null','nan',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:nan" +``` + +!!! info + + In `Pandas` by default a `NaN` value in an integer column causes the column to be cast to float. This does not happen in `Polars` - instead an exception is raised. + +`NaN` values are considered to be a type of floating point data and are **not considered to be missing data** in `Polars`. This means: + +- `NaN` values are **not** counted with the `null_count` method +- `NaN` values are filled when you use `fill_nan` method but are **not** filled with the `fill_null` method + +`Polars` has `is_nan` and `fill_nan` methods which work in a similar way to the `is_null` and `fill_null` methods. The underlying Arrow arrays do not have a pre-computed validity bitmask for `NaN` values so this has to be computed for the `is_nan` method. + +One further difference between `null` and `NaN` values is that taking the `mean` of a column with `null` values excludes the `null` values from the calculation but with `NaN` values taking the mean results in a `NaN`. This behaviour can be avoided by replacing the `NaN` values with `null` values; + +{{code_block('user-guide/expressions/null','nanfill',['fill_nan'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:nanfill" +``` diff --git a/docs/user-guide/expressions/numpy.md b/docs/user-guide/expressions/numpy.md new file mode 100644 index 000000000000..6449ffd634bf --- /dev/null +++ b/docs/user-guide/expressions/numpy.md @@ -0,0 +1,22 @@ +# Numpy + +`Polars` expressions support `NumPy` [ufuncs](https://numpy.org/doc/stable/reference/ufuncs.html). See [here](https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs) +for a list on all supported numpy functions. + +This means that if a function is not provided by `Polars`, we can use `NumPy` and we still have fast columnar operation through the `NumPy` API. + +### Example + +{{code_block('user-guide/expressions/numpy-example',api_functions=['DataFrame','np.log'])}} + +```python exec="on" result="text" session="user-guide/numpy" +--8<-- "python/user-guide/expressions/numpy-example.py" +``` + +### Interoperability + +Polars `Series` have support for NumPy universal functions (ufuncs). Element-wise functions such as `np.exp()`, `np.cos()`, `np.div()`, etc. all work with almost zero overhead. + +However, as a Polars-specific remark: missing values are a separate bitmask and are not visible by NumPy. This can lead to a window function or a `np.convolve()` giving flawed or incomplete results. + +Convert a Polars `Series` to a NumPy array with the `.to_numpy()` method. Missing values will be replaced by `np.nan` during the conversion. 
If the `Series` does not include missing values, or those values are not desired anymore, the `.view()` method can be used instead, providing a zero-copy NumPy array of the data. diff --git a/docs/user-guide/expressions/operators.md b/docs/user-guide/expressions/operators.md new file mode 100644 index 000000000000..24cb4e6834b8 --- /dev/null +++ b/docs/user-guide/expressions/operators.md @@ -0,0 +1,30 @@ +# Basic operators + +This section describes how to use basic operators (e.g. addition, subtraction) in conjunction with Expressions. We will provide various examples using different themes in the context of the following dataframe. + +!!! note Operator Overloading + + In Rust and Python it is possible to use the operators directly (as in `+ - * / < > `) as the language allows operator overloading. For instance, the operator `+` translates to the `.add()` method. You can choose the one you prefer. + +{{code_block('user-guide/expressions/operators','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:setup" +--8<-- "python/user-guide/expressions/operators.py:dataframe" +``` + +### Numerical + +{{code_block('user-guide/expressions/operators','numerical',['operators'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:numerical" +``` + +### Logical + +{{code_block('user-guide/expressions/operators','logical',['operators'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:logical" +``` diff --git a/docs/user-guide/expressions/strings.md b/docs/user-guide/expressions/strings.md new file mode 100644 index 000000000000..93b1c4de93f7 --- /dev/null +++ b/docs/user-guide/expressions/strings.md @@ -0,0 +1,62 @@ +# Strings + +The following section discusses operations performed on `Utf8` strings, which are a frequently used `DataType` when working with `DataFrames`. However, processing strings can often be inefficient due to their unpredictable memory size, causing the CPU to access many random memory locations. To address this issue, Polars utilizes `Arrow` as its backend, which stores all strings in a contiguous block of memory. As a result, string traversal is cache-optimal and predictable for the CPU. + +String processing functions are available in the `str` namespace. + +##### Accessing the string namespace + +The `str` namespace can be accessed through the `.str` attribute of a column with `Utf8` data type. In the following example, we create a column named `animal` and compute the length of each element in the column in terms of the number of bytes and the number of characters. If you are working with ASCII text, then the results of these two computations will be the same, and using `lengths` is recommended since it is faster. + +{{code_block('user-guide/expressions/strings','df',['str.len_bytes','str.len_chars'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:setup" +--8<-- "python/user-guide/expressions/strings.py:df" +``` + +#### String parsing + +`Polars` offers multiple methods for checking and parsing elements of a string. Firstly, we can use the `contains` method to check whether a given pattern exists within a substring. Subsequently, we can extract these patterns and replace them using other methods, which will be demonstrated in upcoming examples. 
+ +##### Check for existence of a pattern + +To check for the presence of a pattern within a string, we can use the contains method. The `contains` method accepts either a regular substring or a regex pattern, depending on the value of the `literal` parameter. If the pattern we're searching for is a simple substring located either at the beginning or end of the string, we can alternatively use the `starts_with` and `ends_with` functions. + +{{code_block('user-guide/expressions/strings','existence',['str.contains', 'str.starts_with','str.ends_with'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:existence" +``` + +##### Extract a pattern + +The `extract` method allows us to extract a pattern from a specified string. This method takes a regex pattern containing one or more capture groups, which are defined by parentheses `()` in the pattern. The group index indicates which capture group to output. + +{{code_block('user-guide/expressions/strings','extract',['str.extract'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:extract" +``` + +To extract all occurrences of a pattern within a string, we can use the `extract_all` method. In the example below, we extract all numbers from a string using the regex pattern `(\d+)`, which matches one or more digits. The resulting output of the `extract_all` method is a list containing all instances of the matched pattern within the string. + +{{code_block('user-guide/expressions/strings','extract_all',['str.extract_all'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:extract_all" +``` + +##### Replace a pattern + +We have discussed two methods for pattern matching and extraction thus far, and now we will explore how to replace a pattern within a string. Similar to `extract` and `extract_all`, Polars provides the `replace` and `replace_all` methods for this purpose. In the example below we replace one match of `abc` at the end of a word (`\b`) by `ABC` and we replace all occurrence of `a` with `-`. + +{{code_block('user-guide/expressions/strings','replace',['str.replace','str.replace_all'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:replace" +``` + +#### API documentation + +In addition to the examples covered above, Polars offers various other string manipulation methods for tasks such as formatting, stripping, splitting, and more. To explore these additional methods, you can go to the API documentation of your chosen programming language for Polars. diff --git a/docs/user-guide/expressions/structs.md b/docs/user-guide/expressions/structs.md new file mode 100644 index 000000000000..61978bbc25e7 --- /dev/null +++ b/docs/user-guide/expressions/structs.md @@ -0,0 +1,99 @@ +# The Struct datatype + +Polars `Struct`s are the idiomatic way of working with multiple columns. It is also a free operation i.e. moving columns into `Struct`s does not copy any data! 
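+
+As a small sketch of what packing columns into a `Struct` looks like (the column names and values here are made up for illustration):
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+
+# Combining existing columns into a single Struct column only rearranges
+# metadata; the underlying buffers are reused, so no data is copied.
+out = df.select(pl.struct(["a", "b"]).alias("a_and_b"))
+print(out.schema)
+```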
+
+For this section, let's start with a `DataFrame` that captures the average rating of a few movies across some states in the U.S.:
+
+{{code_block('user-guide/expressions/structs','ratings_df',['DataFrame'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:setup"
+--8<-- "python/user-guide/expressions/structs.py:ratings_df"
+```
+
+## Encountering the `Struct` type
+
+A common operation that will lead to a `Struct` column is the ever so popular `value_counts` function that is commonly used in exploratory data analysis. Checking the number of times a state appears in the data can be done as follows:
+
+{{code_block('user-guide/expressions/structs','state_value_counts',['value_counts'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:state_value_counts"
+```
+
+Quite an unexpected output, especially if coming from tools that do not have such a data type. We're not in peril, though: to get back to a more familiar output, all we need to do is `unnest` the `Struct` column into its constituent columns:
+
+{{code_block('user-guide/expressions/structs','struct_unnest',['unnest'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:struct_unnest"
+```
+
+!!! note "Why `value_counts` returns a `Struct`"
+
+    Polars expressions always have a `Fn(Series) -> Series` signature and `Struct` is thus the data type that allows us to provide multiple columns as input/output of an expression. In other words, all expressions have to return a `Series` object, and `Struct` allows us to stay consistent with that requirement.
+
+## Structs as `dict`s
+
+Polars will interpret a `dict` sent to the `Series` constructor as a `Struct`:
+
+{{code_block('user-guide/expressions/structs','series_struct',['Series'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:series_struct"
+```
+
+!!! note "Constructing `Series` objects"
+
+    Note that `Series` here was constructed with the `name` of the series first, followed by the `values`. Providing the latter first is considered an anti-pattern in Polars, and must be avoided.
+
+### Extracting individual values of a `Struct`
+
+Let's say that we needed to obtain just the `movie` value in the `Series` that we created above. We can use the `field` method to do so:
+
+{{code_block('user-guide/expressions/structs','series_struct_extract',['struct.field'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:series_struct_extract"
+```
+
+### Renaming individual keys of a `Struct`
+
+What if we need to rename individual `field`s of a `Struct` column? We first convert the `rating_series` object to a `DataFrame` so that we can view the changes easily, and then use the `rename_fields` method:
+
+{{code_block('user-guide/expressions/structs','series_struct_rename',['struct.rename_fields'])}}
+
+```python exec="on" result="text" session="user-guide/structs"
+--8<-- "python/user-guide/expressions/structs.py:series_struct_rename"
+```
+
+## Practical use-cases of `Struct` columns
+
+### Identifying duplicate rows
+
+Let's get back to the `ratings` data. We want to identify cases where there are duplicates at a `Movie` and `Theatre` level.
This is where the `Struct` datatype shines: + +{{code_block('user-guide/expressions/structs','struct_duplicates',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_duplicates" +``` + +We can identify the unique cases at this level also with `is_unique`! + +### Multi-column ranking + +Suppose, given that we know there are duplicates, we want to choose which rank gets a higher priority. We define `Count` of ratings to be more important than the actual `Avg_Rating` themselves, and only use it to break a tie. We can then do: + +{{code_block('user-guide/expressions/structs','struct_ranking',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_ranking" +``` + +That's a pretty complex set of requirements done very elegantly in Polars! + +### Using multi-column apply + +This was discussed in the previous section on _User Defined Functions_. diff --git a/docs/user-guide/expressions/user-defined-functions.md b/docs/user-guide/expressions/user-defined-functions.md new file mode 100644 index 000000000000..dd83cb13c382 --- /dev/null +++ b/docs/user-guide/expressions/user-defined-functions.md @@ -0,0 +1,187 @@ +# User-defined functions + +!!! warning "Not updated for Python Polars `0.19.0`" + + This section of the user guide still needs to be updated for the latest Polars release. + +You should be convinced by now that Polars expressions are so powerful and flexible that there is much less need for custom Python functions +than in other libraries. + +Still, you need to have the power to be able to pass an expression's state to a third party library or apply your black box function +over data in Polars. + +For this we provide the following expressions: + +- `map` +- `apply` + +## To `map` or to `apply`. + +These functions have an important distinction in how they operate and consequently what data they will pass to the user. + +A `map` passes the `Series` backed by the `expression` as is. + +`map` follows the same rules in both the `select` and the `group_by` context, this will +mean that the `Series` represents a column in a `DataFrame`. Note that in the `group_by` context, that column is not yet +aggregated! + +Use cases for `map` are for instance passing the `Series` in an expression to a third party library. Below we show how +we could use `map` to pass an expression column to a neural network model. + +=== ":fontawesome-brands-python: Python" +[:material-api: `map`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.map.html) + +```python +df.with_columns([ + pl.col("features").map(lambda s: MyNeuralNetwork.forward(s.to_numpy())).alias("activations") +]) +``` + +=== ":fontawesome-brands-rust: Rust" + +```rust +df.with_columns([ + col("features").map(|s| Ok(my_nn.forward(s))).alias("activations") +]) +``` + +Use cases for `map` in the `group_by` context are slim. They are only used for performance reasons, but can quite easily lead to incorrect results. Let me explain why. + +{{code_block('user-guide/expressions/user-defined-functions','dataframe',['map'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:setup" +--8<-- "python/user-guide/expressions/user-defined-functions.py:dataframe" +``` + +In the snippet above we group by the `"keys"` column. 
That means we have the following groups:
+
+```c
+"a" -> [10, 7]
+"b" -> [1]
+```
+
+If we then apply a `shift` operation to the right, we'd expect:
+
+```c
+"a" -> [null, 10]
+"b" -> [null]
+```
+
+Now, let's print and see what we've got.
+
+```python
+print(out)
+```
+
+```
+shape: (2, 3)
+┌──────┬────────────┬──────────────────┐
+│ keys ┆ shift_map  ┆ shift_expression │
+│ ---  ┆ ---        ┆ ---              │
+│ str  ┆ list[i64]  ┆ list[i64]        │
+╞══════╪════════════╪══════════════════╡
+│ a    ┆ [null, 10] ┆ [null, 10]       │
+│ b    ┆ [7]        ┆ [null]           │
+└──────┴────────────┴──────────────────┘
+```
+
+Ouch... we clearly get the wrong results here. Group `"b"` even got a value from group `"a"` 😵.
+
+This went horribly wrong because `map` applies the function before we aggregate! That means the whole column `[10, 7, 1]` got shifted to `[null, 10, 7]` and was then aggregated.
+
+So my advice is to never use `map` in the `group_by` context unless you know you need it and know what you are doing.
+
+## To `apply`
+
+Luckily we can fix the previous example with `apply`. `apply` works on the smallest logical elements for that operation.
+
+That is:
+
+- `select` context -> single elements
+- `group_by` context -> single groups
+
+So with `apply` we should be able to fix our example:
+
+{{code_block('user-guide/expressions/user-defined-functions','apply',['apply'])}}
+
+```python exec="on" result="text" session="user-guide/udf"
+--8<-- "python/user-guide/expressions/user-defined-functions.py:apply"
+```
+
+And observe, a valid result! 🎉
+
+## `apply` in the `select` context
+
+In the `select` context, the `apply` expression passes elements of the column to the Python function.
+
+_Note that you are now running Python; this will be slow._
+
+Let's go through some examples to see what to expect. We will continue with the `DataFrame` we defined at the start of this section and show an example with the `apply` function and a counter-example where we use the expression API to achieve the same goals.
+
+### Adding a counter
+
+In this example we create a global `counter` and then add the integer `1` to the global state for every element processed. The result of each increment is then added to the element value.
+
+> Note, this example isn't provided in Rust. The reason is that the global `counter` value would lead to data races when this `apply` is evaluated in parallel. It would be possible to wrap it in a `Mutex` to protect the variable, but that would be obscuring the point of the example. This is a case where the Python Global Interpreter Lock's performance tradeoff provides some safety guarantees.
+
+{{code_block('user-guide/expressions/user-defined-functions','counter',['apply'])}}
+
+```python exec="on" result="text" session="user-guide/udf"
+--8<-- "python/user-guide/expressions/user-defined-functions.py:counter"
+```
+
+### Combining multiple column values
+
+If we want to have access to values of different columns in a single `apply` function call, we can create a `struct` data type. This data type collects those columns as fields in the `struct`. So if we'd create a struct from the columns `"keys"` and `"values"`, we would get the following struct elements:
+
+```python
+[
+    {"keys": "a", "values": 10},
+    {"keys": "a", "values": 7},
+    {"keys": "b", "values": 1},
+]
+```
+
+In Python, those would be passed as a `dict` to the calling Python function and can thus be indexed by `field: str`. In Rust, you'll get a `Series` with the `Struct` type. The fields of the struct can then be indexed and downcast.
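+
+As an illustrative sketch, assuming the `DataFrame` from the start of this section; the combining function here is an arbitrary example:
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"keys": ["a", "a", "b"], "values": [10, 7, 1]})
+
+out = df.select(
+    pl.struct(["keys", "values"])
+    .apply(lambda x: len(x["keys"]) + x["values"])  # each x is a dict, e.g. {"keys": "a", "values": 10}
+    .alias("combined")
+)
+print(out)
+```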
+ +{{code_block('user-guide/expressions/user-defined-functions','combine',['apply','struct'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:combine" +``` + +`Structs` are covered in detail in the next section. + +### Return types? + +Custom python functions are black boxes for polars. We really don't know what kind of black arts you are doing, so we have +to infer and try our best to understand what you meant. + +As a user it helps to understand what we do to better utilize custom functions. + +The data type is automatically inferred. We do that by waiting for the first non-null value. That value will then be used +to determine the type of the `Series`. + +The mapping of python types to polars data types is as follows: + +- `int` -> `Int64` +- `float` -> `Float64` +- `bool` -> `Boolean` +- `str` -> `Utf8` +- `list[tp]` -> `List[tp]` (where the inner type is inferred with the same rules) +- `dict[str, [tp]]` -> `struct` +- `Any` -> `object` (Prevent this at all times) + +Rust types map as follows: + +- `i32` or `i64` -> `Int64` +- `f32` or `f64` -> `Float64` +- `bool` -> `Boolean` +- `String` or `str` -> `Utf8` +- `Vec` -> `List[tp]` (where the inner type is inferred with the same rules) diff --git a/docs/user-guide/expressions/window.md b/docs/user-guide/expressions/window.md new file mode 100644 index 000000000000..7ea426ccb1b9 --- /dev/null +++ b/docs/user-guide/expressions/window.md @@ -0,0 +1,91 @@ +# Window functions + +Window functions are expressions with superpowers. They allow you to perform aggregations on groups in the +`select` context. Let's get a feel for what that means. First we create a dataset. The dataset loaded in the +snippet below contains information about pokemon: + +{{code_block('user-guide/expressions/window','pokemon',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:pokemon" +``` + +## Group by aggregations in selection + +Below we show how to use window functions to group over different columns and perform an aggregation on them. +Doing so allows us to use multiple group by operations in parallel, using a single query. The results of the aggregation +are projected back to the original rows. Therefore, a window function will almost always lead to a `DataFrame` with the same size as the original. + +We will discuss later the cases where a window function can change the numbers of rows in a `DataFrame`. + +Note how we call `.over("Type 1")` and `.over(["Type 1", "Type 2"])`. Using window functions we can aggregate over different groups in a single `select` call! Note that, in Rust, the type of the argument to `over()` must be a collection, so even when you're only using one column, you must provided it in an array. + +The best part is, this won't cost you anything. The computed groups are cached and shared between different `window` expressions. + +{{code_block('user-guide/expressions/window','group_by',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:group_by" +``` + +## Operations per group + +Window functions can do more than aggregation. They can also be viewed as an operation within a group. If, for instance, you +want to `sort` the values within a `group`, you can write `col("value").sort().over("group")` and voilà! We sorted by group! + +Let's filter out some rows to make this more clear. 
+ +{{code_block('user-guide/expressions/window','operations',['filter'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:operations" +``` + +Observe that the group `Water` of column `Type 1` is not contiguous. There are two rows of `Grass` in between. Also note +that each pokemon within a group are sorted by `Speed` in `ascending` order. Unfortunately, for this example we want them sorted in +`descending` speed order. Luckily with window functions this is easy to accomplish. + +{{code_block('user-guide/expressions/window','sort',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:sort" +``` + +`Polars` keeps track of each group's location and maps the expressions to the proper row locations. This will also work over different groups in a single `select`. + +The power of window expressions is that you often don't need a `group_by -> explode` combination, but you can put the logic in a single expression. It also makes the API cleaner. If properly used a: + +- `group_by` -> marks that groups are aggregated and we expect a `DataFrame` of size `n_groups` +- `over` -> marks that we want to compute something within a group, and doesn't modify the original size of the `DataFrame` except in specific cases + +## Map the expression result to the DataFrame rows + +In cases where the expression results in multiple values per group, the Window function has 3 strategies for linking the values back to the `DataFrame` rows: + +- `mapping_strategy = 'group_to_rows'` -> each value is assigned back to one row. The number of values returned should match the number of rows. + +- `mapping_strategy = 'join'` -> the values are imploded in a list, and the list is repeated on all rows. This can be memory intensive. + +- `mapping_strategy = 'explode'` -> the values are exploded to new rows. This operation changes the number of rows. + +## Window expression rules + +The evaluations of window expressions are as follows (assuming we apply it to a `pl.Int32` column): + +{{code_block('user-guide/expressions/window','rules',['over'])}} + +## More examples + +For more exercise, below are some window functions for us to compute: + +- sort all pokemon by type +- select the first `3` pokemon per type as `"Type 1"` +- sort the pokemon within a type by speed in descending order and select the first `3` as `"fastest/group"` +- sort the pokemon within a type by attack in descending order and select the first `3` as `"strongest/group"` +- sort the pokemon within a type by name and select the first `3` as `"sorted_by_alphabet"` + +{{code_block('user-guide/expressions/window','examples',['over','implode'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:examples" +``` diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md new file mode 100644 index 000000000000..8fb27a98c743 --- /dev/null +++ b/docs/user-guide/index.md @@ -0,0 +1,31 @@ +# Introduction + +This User Guide is an introduction to the [`Polars` DataFrame library](https://github.com/pola-rs/polars). Its goal is to introduce you to `Polars` by going through examples and comparing it to other +solutions. Some design choices are introduced here. The guide will also introduce you to optimal usage of `Polars`. + +Even though `Polars` is completely written in [`Rust`](https://www.rust-lang.org/) (no runtime overhead!) 
and uses [`Arrow`](https://arrow.apache.org/) -- the +[native arrow2 `Rust` implementation](https://github.com/jorgecarleitao/arrow2) -- as its foundation, the examples presented in this guide will be mostly using its higher-level language +bindings. Higher-level bindings only serve as a thin wrapper for functionality implemented in the core library. + +For [`Pandas`](https://pandas.pydata.org/) users, our [Python package](https://pypi.org/project/polars/) will offer the easiest way to get started with `Polars`. + +### Philosophy + +The goal of `Polars` is to provide a lightning fast `DataFrame` library that: + +- Utilizes all available cores on your machine. +- Optimizes queries to reduce unneeded work/memory allocations. +- Handles datasets much larger than your available RAM. +- Has an API that is consistent and predictable. +- Has a strict schema (data-types should be known before running the query). + +Polars is written in Rust which gives it C/C++ performance and allows it to fully control performance critical parts +in a query engine. + +As such `Polars` goes to great lengths to: + +- Reduce redundant copies. +- Traverse memory cache efficiently. +- Minimize contention in parallelism. +- Process data in chunks. +- Reuse memory allocations. diff --git a/docs/user-guide/installation.md b/docs/user-guide/installation.md new file mode 100644 index 000000000000..e461e44ba692 --- /dev/null +++ b/docs/user-guide/installation.md @@ -0,0 +1,173 @@ +# Installation + +Polars is a library and installation is as simple as invoking the package manager of the corresponding programming language. + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars -F lazy + + # Or Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + +## Importing + +To use the library import it into your project + +=== ":fontawesome-brands-python: Python" + + ``` python + import polars as pl + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` rust + use polars::prelude::*; + ``` + +## Feature Flags + +By using the above command you install the core of `Polars` onto your system. However depending on your use case you might want to install the optional dependencies as well. These are made optional to minimize the footprint. The flags are different depending on the programming language. Throughout the user guide we will mention when a functionality is used that requires an additional dependency. + +### Python + +```text +# For example +pip install polars[numpy, fsspec] +``` + +| Tag | Description | +| ---------- | ------------------------------------------------------------------------------------------------------------------------------------- | +| all | Install all optional dependencies (all of the following) | +| pandas | Install with Pandas for converting data to and from Pandas Dataframes/Series | +| numpy | Install with numpy for converting data to and from numpy arrays | +| pyarrow | Reading data formats using PyArrow | +| fsspec | Support for reading from remote file systems | +| connectorx | Support for reading from SQL databases | +| xlsx2csv | Support for reading from Excel files | +| deltalake | Support for reading from Delta Lake Tables | +| timezone | Timezone support, only needed if 1. you are on Python < 3.9 and/or 2. 
you are on Windows, otherwise no dependencies will be installed | + +### Rust + +```toml +# Cargo.toml +[dependencies] +polars = { version = "0.26.1", features = ["lazy", "temporal", "describe", "json", "parquet", "dtype-datetime"] } +``` + +The opt-in features are: + +- Additional data types: + - `dtype-date` + - `dtype-datetime` + - `dtype-time` + - `dtype-duration` + - `dtype-i8` + - `dtype-i16` + - `dtype-u8` + - `dtype-u16` + - `dtype-categorical` + - `dtype-struct` +- `lazy` - Lazy API + - `lazy_regex` - Use regexes in [column selection](crate::lazy::dsl::col) + - `dot_diagram` - Create dot diagrams from lazy logical plans. +- `sql` - Pass SQL queries to polars. +- `streaming` - Be able to process datasets that are larger than RAM. +- `random` - Generate arrays with randomly sampled values +- `ndarray`- Convert from `DataFrame` to `ndarray` +- `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types +- `timezones` - Activate timezone support. +- `strings` - Extra string utilities for `Utf8Chunked` + - `string_justify` - `zfill`, `ljust`, `rjust` + - `string_from_radix` - `parse_int` +- `object` - Support for generic ChunkedArrays called `ObjectChunked` (generic over `T`). + These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. +- Performance related: + - `nightly` - Several nightly only features such as SIMD and specialization. + - `performant` - more fast paths, slower compile times. + - `bigidx` - Activate this feature if you expect >> 2^32 rows. This has not been needed by anyone. + This allows polars to scale up way beyond that by using `u64` as an index. + Polars will be a bit slower with this feature activated as many data structures + are less cache efficient. + - `cse` - Activate common subplan elimination optimization +- IO related: + + - `serde` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + - `serde-lazy` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + + - `parquet` - Read Apache Parquet format + - `json` - JSON serialization + - `ipc` - Arrow's IPC format serialization + - `decompress` - Automatically infer compression of csvs and decompress them. + Supported compressions: + - zip + - gzip + +- `DataFrame` operations: + - `dynamic_group_by` - Group by based on a time window instead of predefined keys. + Also activates rolling window group by operations. + - `sort_multiple` - Allow sorting a `DataFrame` on multiple columns + - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. + And activates `pivot` and `transpose` operations + - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. + - `cross_join` - Create the cartesian product of two DataFrames. + - `semi_anti_join` - SEMI and ANTI joins. + - `group_by_list` - Allow group by operation on keys of type List. + - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked + - `diagonal_concat` - Concat diagonally thereby combining different schemas. + - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match + - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series) + - `partition_by` - Split into multiple DataFrames partitioned by groups. 
+- `Series`/`Expression` operations: + - `is_in` - [Check for membership in `Series`](crate::chunked_array::ops::IsIn) + - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip) + - `round_series` - round underlying float types of `Series`. + - `repeat_by` - [Repeat element in an Array N times, where N is given by another array. + - `is_first_distinct` - Check if element is first unique value. + - `is_last_distinct` - Check if element is last unique value. + - `checked_arithmetic` - checked arithmetic/ returning `None` on invalid operations. + - `dot_product` - Dot/inner product on Series and Expressions. + - `concat_str` - Concat string data in linear time. + - `reinterpret` - Utility to reinterpret bits to signed/unsigned + - `take_opt_iter` - Take from a Series with `Iterator>` + - `mode` - [Return the most occurring value(s)](crate::chunked_array::ops::ChunkUnique::mode) + - `cum_agg` - cumsum, cummin, cummax aggregation. + - `rolling_window` - rolling window functions, like rolling_mean + - `interpolate` [interpolate None values](crate::chunked_array::ops::Interpolate) + - `extract_jsonpath` - [Run jsonpath queries on Utf8Chunked](https://goessner.net/articles/JsonPath/) + - `list` - List utils. + - `list_take` take sublist by multiple indices + - `rank` - Ranking algorithms. + - `moment` - kurtosis and skew statistics + - `ewma` - Exponential moving average windows + - `abs` - Get absolute values of Series + - `arange` - Range operation on Series + - `product` - Compute the product of a Series. + - `diff` - `diff` operation. + - `pct_change` - Compute change percentages. + - `unique_counts` - Count unique values in expressions. + - `log` - Logarithms for `Series`. + - `list_to_struct` - Convert `List` to `Struct` dtypes. + - `list_count` - Count elements in lists. + - `list_eval` - Apply expressions over list elements. + - `cumulative_eval` - Apply expressions over cumulatively increasing windows. + - `arg_where` - Get indices where condition holds. + - `search_sorted` - Find indices where elements should be inserted to maintain order. + - `date_offset` Add an offset to dates that take months and leap years into account. + - `trigonometry` Trigonometric functions. + - `sign` Compute the element-wise sign of a Series. + - `propagate_nans` NaN propagating min/max aggregations. +- `DataFrame` pretty printing + - `fmt` - Activate DataFrame formatting diff --git a/docs/user-guide/io/bigquery.md b/docs/user-guide/io/bigquery.md new file mode 100644 index 000000000000..21287cd448d2 --- /dev/null +++ b/docs/user-guide/io/bigquery.md @@ -0,0 +1,19 @@ +# Google BigQuery + +To read or write from GBQ, additional dependencies are needed: + +=== ":fontawesome-brands-python: Python" + +```shell +$ pip install google-cloud-bigquery +``` + +## Read + +We can load a query into a `DataFrame` like this: + +{{code_block('user-guide/io/bigquery','read',['from_arrow'])}} + +## Write + +{{code_block('user-guide/io/bigquery','write',[])}} diff --git a/docs/user-guide/io/cloud-storage.md b/docs/user-guide/io/cloud-storage.md new file mode 100644 index 000000000000..a10226a99e65 --- /dev/null +++ b/docs/user-guide/io/cloud-storage.md @@ -0,0 +1,51 @@ +# Cloud storage + +Polars can read and write to AWS S3, Azure Blob Storage and Google Cloud Storage. The API is the same for all three storage providers. 
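+
+As a brief sketch of that uniformity (the bucket and container names below are placeholders, and the filesystem packages listed in the next paragraph must be installed):
+
+```python
+import polars as pl
+
+# The same eager call works for each provider; only the URL scheme changes.
+df_s3 = pl.read_parquet("s3://bucket/file.parquet")
+df_azure = pl.read_parquet("az://container/file.parquet")
+df_gcs = pl.read_parquet("gs://bucket/file.parquet")
+```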
+
+To read from cloud storage, additional dependencies may be needed depending on the use case and cloud storage provider:
+
+=== ":fontawesome-brands-python: Python"
+
+    ```shell
+    $ pip install fsspec s3fs adlfs gcsfs
+    ```
+
+=== ":fontawesome-brands-rust: Rust"
+
+    ```shell
+    $ cargo add aws_sdk_s3 aws_config tokio --features tokio/full
+    ```
+
+## Reading from cloud storage
+
+Polars can read a CSV, IPC or Parquet file in eager mode from cloud storage.
+
+{{code_block('user-guide/io/cloud-storage','read_parquet',['read_parquet','read_csv','read_ipc'])}}
+
+This eager query downloads the file to a buffer in memory and creates a `DataFrame` from there. Polars uses `fsspec` to manage this download internally for all cloud storage providers.
+
+## Scanning from cloud storage with query optimisation
+
+Polars can scan a Parquet file in lazy mode from cloud storage. We may need to provide further details beyond the source URL, such as authentication details or the storage region. Polars looks for these as environment variables, but we can also pass them manually as a `dict` via the `storage_options` argument.
+
+{{code_block('user-guide/io/cloud-storage','scan_parquet',['scan_parquet'])}}
+
+This query creates a `LazyFrame` without downloading the file. In the `LazyFrame` we have access to file metadata such as the schema. Polars uses the Rust `object_store` library internally to manage the interface with the cloud storage providers, so no extra dependencies are required in Python to scan a cloud Parquet file.
+
+If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`.
+
+{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}}
+
+## Scanning with PyArrow
+
+We can also scan from cloud storage using PyArrow. This is particularly useful for partitioned datasets, such as those using Hive partitioning.
+
+We first create a PyArrow dataset and then create a `LazyFrame` from the dataset.
+
+{{code_block('user-guide/io/cloud-storage','scan_pyarrow_dataset',['scan_pyarrow_dataset'])}}
+
+## Writing to cloud storage
+
+We can write a `DataFrame` to cloud storage in Python using s3fs for S3, adlfs for Azure Blob Storage and gcsfs for Google Cloud Storage. In this example, we write a Parquet file to S3.
+
+{{code_block('user-guide/io/cloud-storage','write_parquet',['write_parquet'])}}
diff --git a/docs/user-guide/io/csv.md b/docs/user-guide/io/csv.md
new file mode 100644
index 000000000000..eeb209dfb34e
--- /dev/null
+++ b/docs/user-guide/io/csv.md
@@ -0,0 +1,21 @@
+# CSV
+
+## Read & write
+
+Reading a CSV file should look familiar:
+
+{{code_block('user-guide/io/csv','read',['read_csv'])}}
+
+Writing a CSV file is similar with the `write_csv` function:
+
+{{code_block('user-guide/io/csv','write',['write_csv'])}}
+
+## Scan
+
+`Polars` allows you to _scan_ a CSV input. Scanning delays the actual parsing of the
+file and instead returns a lazy computation holder called a `LazyFrame`.
+
+{{code_block('user-guide/io/csv','scan',['scan_csv'])}}
+
+If you want to know why this is desirable, you can read more about these `Polars`
+optimizations [here](../concepts/lazy-vs-eager.md).
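+
+As a minimal sketch of the difference between reading and scanning (the file path and column names here are hypothetical), the scan only parses the data once `collect` is called:
+
+```python
+import polars as pl
+
+# Build a lazy query; nothing is read from disk yet.
+lazy_query = (
+    pl.scan_csv("path/to/file.csv")  # hypothetical path
+    .filter(pl.col("value") > 0)     # hypothetical column
+    .select(["id", "value"])
+)
+
+# Parsing and execution happen here.
+df = lazy_query.collect()
+```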
diff --git a/docs/user-guide/io/database.md b/docs/user-guide/io/database.md new file mode 100644 index 000000000000..f83706e5e79a --- /dev/null +++ b/docs/user-guide/io/database.md @@ -0,0 +1,83 @@ +# Databases + +## Read from a database + +Polars can read from a database using the `pl.read_database_uri` and `pl.read_database` functions. + +### Difference between `read_database_uri` and `read_database` + +Use `pl.read_database_uri` if you want to specify the database connection with a connection string called a `uri`. For example, the following snippet shows a query to read all columns from the `foo` table in a Postgres database where we use the `uri` to connect: + +{{code_block('user-guide/io/database','read_uri',['read_database_uri'])}} + +On the other hand, use `pl.read_database` if you want to connect via a connection engine created with a library like SQLAlchemy. + +{{code_block('user-guide/io/database','read_cursor',['read_database'])}} + +Note that `pl.read_database_uri` is likely to be faster than `pl.read_database` if you are using a SQLAlchemy or DBAPI2 connection as these connections may load the data row-wise into Python before copying the data again to the column-wise Apache Arrow format. + +### Engines + +Polars doesn't manage connections and data transfer from databases by itself. Instead, external libraries (known as _engines_) handle this. + +When using `pl.read_database`, you specify the engine when you create the connection object. When using `pl.read_database_uri`, you can specify one of two engines to read from the database: + +- [ConnectorX](https://github.com/sfu-db/connector-x) and +- [ADBC](https://arrow.apache.org/docs/format/ADBC.html) + +Both engines have native support for Apache Arrow and so can read data directly into a Polars `DataFrame` without copying the data. + +#### ConnectorX + +ConnectorX is the default engine and [supports numerous databases](https://github.com/sfu-db/connector-x#sources) including Postgres, Mysql, SQL Server and Redshift. ConnectorX is written in Rust and stores data in Arrow format to allow for zero-copy to Polars. + +To read from one of the supported databases with `ConnectorX` you need to activate the additional dependency `ConnectorX` when installing Polars or install it manually with + +```shell +$ pip install connectorx +``` + +#### ADBC + +ADBC (Arrow Database Connectivity) is an engine supported by the Apache Arrow project. ADBC aims to be both an API standard for connecting to databases and libraries implementing this standard in a range of languages. + +It is still early days for ADBC so support for different databases is still limited. At present drivers for ADBC are only available for [Postgres and SQLite](https://arrow.apache.org/adbc/0.1.0/driver/cpp/index.html). To install ADBC you need to install the driver for your database. For example to install the driver for SQLite you run + +```shell +$ pip install adbc-driver-sqlite +``` + +As ADBC is not the default engine you must specify the engine as an argument to `pl.read_database_uri` + +{{code_block('user-guide/io/database','adbc',['read_database_uri'])}} + +## Write to a database + +We can write to a database with Polars using the `pl.write_database` function. + +### Engines + +As with reading from a database above Polars uses an _engine_ to write to a database. 
The currently supported engines are: + +- [SQLAlchemy](https://www.sqlalchemy.org/) and +- Arrow Database Connectivity (ADBC) + +#### SQLAlchemy + +With the default engine SQLAlchemy you can write to any database supported by SQLAlchemy. To use this engine you need to install SQLAlchemy and Pandas + +```shell +$ pip install SQLAlchemy pandas +``` + +In this example, we write the `DataFrame` to a table called `records` in the database + +{{code_block('user-guide/io/database','write',['write_database'])}} + +In the SQLAlchemy approach, Polars converts the `DataFrame` to a Pandas `DataFrame` backed by PyArrow and then uses SQLAlchemy methods on a Pandas `DataFrame` to write to the database. + +#### ADBC + +As with reading from a database, you can also use ADBC to write to a SQLite or Posgres database. As shown above, you need to install the appropriate ADBC driver for your database. + +{{code_block('user-guide/io/database','write_adbc',['write_database'])}} diff --git a/docs/user-guide/io/json.md b/docs/user-guide/io/json.md new file mode 100644 index 000000000000..c203d278ee87 --- /dev/null +++ b/docs/user-guide/io/json.md @@ -0,0 +1,30 @@ +# JSON files + +Polars can read and write both standard JSON and newline-delimited JSON (NDJSON). + +## Read + +### JSON + +Reading a JSON file should look familiar: + +{{code_block('user-guide/io/json','read',['read_json'])}} + +### Newline Delimited JSON + +JSON objects that are delimited by newlines can be read into polars in a much more performant way than standard json. + +Polars can read an NDJSON file into a `DataFrame` using the `read_ndjson` function: + +{{code_block('user-guide/io/json','readnd',['read_ndjson'])}} + +## Write + +{{code_block('user-guide/io/json','write',['write_json','write_ndjson'])}} + +## Scan + +`Polars` allows you to _scan_ a JSON input **only for newline delimited json**. Scanning delays the actual parsing of the +file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/json','scan',['scan_ndjson'])}} diff --git a/docs/user-guide/io/multiple.md b/docs/user-guide/io/multiple.md new file mode 100644 index 000000000000..c5a66b03940f --- /dev/null +++ b/docs/user-guide/io/multiple.md @@ -0,0 +1,40 @@ +## Dealing with multiple files. + +Polars can deal with multiple files differently depending on your needs and memory strain. + +Let's create some files to give us some context: + +{{code_block('user-guide/io/multiple','create',['write_csv'])}} + +## Reading into a single `DataFrame` + +To read multiple files into a single `DataFrame`, we can use globbing patterns: + +{{code_block('user-guide/io/multiple','read',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:create" +--8<-- "python/user-guide/io/multiple.py:read" +``` + +To see how this works we can take a look at the query plan. Below we see that all files are read separately and +concatenated into a single `DataFrame`. `Polars` will try to parallelize the reading. + +{{code_block('user-guide/io/multiple','graph',['show_graph'])}} + +```python exec="on" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:creategraph" +``` + +## Reading and processing in parallel + +If your files don't have to be in a single table you can also build a query plan for each file and execute them in parallel +on the `Polars` thread pool. + +All query plan execution is embarrassingly parallel and doesn't require any communication. 
+ +{{code_block('user-guide/io/multiple','glob',['scan_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:glob" +``` diff --git a/docs/user-guide/io/parquet.md b/docs/user-guide/io/parquet.md new file mode 100644 index 000000000000..c08efc2e1b9b --- /dev/null +++ b/docs/user-guide/io/parquet.md @@ -0,0 +1,25 @@ +# Parquet + +Loading or writing [`Parquet` files](https://parquet.apache.org/) is lightning fast as the layout of data in a Polars `DataFrame` in memory mirrors the layout of a Parquet file on disk in many respects. + +Unlike CSV, Parquet is a columnar format. This means that the data is stored in columns rather than rows. This is a more efficient way of storing data as it allows for better compression and faster access to data. + +## Read + +We can read a `Parquet` file into a `DataFrame` using the `read_parquet` function: + +{{code_block('user-guide/io/parquet','read',['read_parquet'])}} + +## Write + +{{code_block('user-guide/io/parquet','write',['write_parquet'])}} + +## Scan + +`Polars` allows you to _scan_ a `Parquet` input. Scanning delays the actual parsing of the file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/parquet','scan',['scan_parquet'])}} + +If you want to know why this is desirable, you can read more about those `Polars` optimizations [here](../concepts/lazy-vs-eager.md). + +When we scan a `Parquet` file stored in the cloud, we can also apply predicate and projection pushdowns. This can significantly reduce the amount of data that needs to be downloaded. For scanning a Parquet file in the cloud, see [Cloud storage](cloud-storage.md/#scanning-from-cloud-storage-with-query-optimisation). diff --git a/docs/user-guide/lazy/execution.md b/docs/user-guide/lazy/execution.md new file mode 100644 index 000000000000..975f52a0ac4a --- /dev/null +++ b/docs/user-guide/lazy/execution.md @@ -0,0 +1,79 @@ +# Query execution + +Our example query on the Reddit dataset is: + +{{code_block('user-guide/lazy/execution','df',['scan_csv'])}} + +If we were to run the code above on the Reddit CSV the query would not be evaluated. Instead Polars takes each line of code, adds it to the internal query graph and optimizes the query graph. + +When we execute the code Polars executes the optimized query graph by default. + +### Execution on the full dataset + +We can execute our query on the full dataset by calling the `.collect` method on the query. 
+ +{{code_block('user-guide/lazy/execution','collect',['scan_csv','collect'])}} + +```text +shape: (14_029, 6) +┌─────────┬───────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═════════╪═══════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +│ 17 ┆ SSAIG_JASONBROKEN ┆ 1397113544 ┆ 1536527864 ┆ 1 ┆ 0 │ +│ 19 ┆ FDBVFDSSDGFDS_JASONBROKEN ┆ 1397113552 ┆ 1536527864 ┆ 3 ┆ 0 │ +│ 37 ┆ IHATEWHOWEARE_JASONBROKEN ┆ 1397113636 ┆ 1536527864 ┆ 61 ┆ 0 │ +│ … ┆ … ┆ … ┆ … ┆ … ┆ … │ +│ 1229384 ┆ DSFOX ┆ 1163177415 ┆ 1536497412 ┆ 44411 ┆ 7917 │ +│ 1229459 ┆ NEOCARTY ┆ 1163177859 ┆ 1536533090 ┆ 40 ┆ 0 │ +│ 1229587 ┆ TEHSMA ┆ 1163178847 ┆ 1536497412 ┆ 14794 ┆ 5707 │ +│ 1229621 ┆ JEREMYLOW ┆ 1163179075 ┆ 1536497412 ┆ 411 ┆ 1063 │ +└─────────┴───────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` + +Above we see that from the 10 million rows there are 14,029 rows that match our predicate. + +With the default `collect` method Polars processes all of your data as one batch. This means that all the data has to fit into your available memory at the point of peak memory usage in your query. + +!!! warning "Reusing `LazyFrame` objects" + + Remember that `LazyFrame`s are query plans i.e. a promise on computation and is not guaranteed to cache common subplans. This means that every time you reuse it in separate downstream queries after it is defined, it is computed all over again. If you define an operation on a `LazyFrame` that doesn't maintain row order (such as a `group_by`), then the order will also change every time it is run. To avoid this, use `maintain_order=True` arguments for such operations. + +### Execution on larger-than-memory data + +If your data requires more memory than you have available Polars may be able to process the data in batches using _streaming_ mode. To use streaming mode you simply pass the `streaming=True` argument to `collect` + +{{code_block('user-guide/lazy/execution','stream',['scan_csv','collect'])}} + +We look at [streaming in more detail here](streaming.md). + +### Execution on a partial dataset + +While you're writing, optimizing or checking your query on a large dataset, querying all available data may lead to a slow development process. + +You can instead execute the query with the `.fetch` method. The `.fetch` method takes a parameter `n_rows` and tries to 'fetch' that number of rows at the data source. The number of rows cannot be guaranteed, however, as the lazy API does not count how many rows there are at each stage of the query. + +Here we "fetch" 100 rows from the source file and apply the predicates. 
+ +{{code_block('user-guide/lazy/execution','partial',['scan_csv','collect','fetch'])}} + +```text +shape: (27, 6) +┌───────┬───────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═══════╪═══════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +│ 17 ┆ SSAIG_JASONBROKEN ┆ 1397113544 ┆ 1536527864 ┆ 1 ┆ 0 │ +│ 19 ┆ FDBVFDSSDGFDS_JASONBROKEN ┆ 1397113552 ┆ 1536527864 ┆ 3 ┆ 0 │ +│ 37 ┆ IHATEWHOWEARE_JASONBROKEN ┆ 1397113636 ┆ 1536527864 ┆ 61 ┆ 0 │ +│ … ┆ … ┆ … ┆ … ┆ … ┆ … │ +│ 77763 ┆ LUNCHY ┆ 1137599510 ┆ 1536528275 ┆ 65 ┆ 0 │ +│ 77765 ┆ COMPOSTELLAS ┆ 1137474000 ┆ 1536528276 ┆ 6 ┆ 0 │ +│ 77766 ┆ GENERICBOB ┆ 1137474000 ┆ 1536528276 ┆ 291 ┆ 14 │ +│ 77768 ┆ TINHEADNED ┆ 1139665457 ┆ 1536497404 ┆ 4434 ┆ 103 │ +└───────┴───────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` diff --git a/docs/user-guide/lazy/optimizations.md b/docs/user-guide/lazy/optimizations.md new file mode 100644 index 000000000000..576413833a3a --- /dev/null +++ b/docs/user-guide/lazy/optimizations.md @@ -0,0 +1,17 @@ +# Optimizations + +If you use `Polars`' lazy API, `Polars` will run several optimizations on your query. Some of them are executed up front, +others are determined just in time as the materialized data comes in. + +Here is a non-complete overview of optimizations done by polars, what they do and how often they run. + +| Optimization | Explanation | runs | +| -------------------------- | ------------------------------------------------------------------------------------------------------------ | ----------------------------- | +| Predicate pushdown | Applies filters as early as possible/ at scan level. | 1 time | +| Projection pushdown | Select only the columns that are needed at the scan level. | 1 time | +| Slice pushdown | Only load the required slice from the scan level. Don't materialize sliced outputs (e.g. join.head(10)). | 1 time | +| Common subplan elimination | Cache subtrees/file scans that are used by multiple subtrees in the query plan. | 1 time | +| Simplify expressions | Various optimizations, such as constant folding and replacing expensive operations with faster alternatives. | until fixed point | +| Join ordering | Estimates the branches of joins that should be executed first in order to reduce memory pressure. | 1 time | +| Type coercion | Coerce types such that operations succeed and run on minimal required memory. | until fixed point | +| Cardinality estimation | Estimates cardinality in order to determine optimal group by strategy. | 0/n times; dependent on query | diff --git a/docs/user-guide/lazy/query-plan.md b/docs/user-guide/lazy/query-plan.md new file mode 100644 index 000000000000..c48a3f8a099c --- /dev/null +++ b/docs/user-guide/lazy/query-plan.md @@ -0,0 +1,96 @@ +# Query plan + +For any lazy query `Polars` has both: + +- a non-optimized plan with the set of steps code as we provided it and +- an optimized plan with changes made by the query optimizer + +We can understand both the non-optimized and optimized query plans with visualization and by printing them as text. + +
+```python exec="on" result="text" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:setup" +``` +
+ +Below we consider the following query: + +{{code_block('user-guide/lazy/query-plan','plan',[])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:plan" +``` + +## Non-optimized query plan + +### Graphviz visualization + +First we visualise the non-optimized plan by setting `optimized=False`. + +{{code_block('user-guide/lazy/query-plan','showplan',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:createplan" +``` + +The query plan visualization should be read from bottom to top. In the visualization: + +- each box corresponds to a stage in the query plan +- the `sigma` stands for `SELECTION` and indicates any filter conditions +- the `pi` stands for `PROJECTION` and indicates choosing a subset of columns + +### Printed query plan + +We can also print the non-optimized plan with `explain(optimized=False)` + +{{code_block('user-guide/lazy/query-plan','describe',['explain'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:describe" +``` + +```text +FILTER [(col("comment_karma")) > (0)] FROM WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS +``` + +The printed plan should also be read from bottom to top. This non-optimized plan is roughly equal to: + +- read from the `data/reddit.csv` file +- read all 6 columns (where the * wildcard in PROJECT \*/6 COLUMNS means take all columns) +- transform the `name` column to uppercase +- apply a filter on the `comment_karma` column + +## Optimized query plan + +Now we visualize the optimized plan with `show_graph`. + +{{code_block('user-guide/lazy/query-plan','show',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:createplan2" +``` + +We can also print the optimized plan with `explain` + +{{code_block('user-guide/lazy/query-plan','optimized',['explain'])}} + +```text + WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS + SELECTION: [(col("comment_karma")) > (0)] +``` + +The optimized plan is to: + +- read the data from the Reddit CSV +- apply the filter on the `comment_karma` column while the CSV is being read line-by-line +- transform the `name` column to uppercase + +In this case the query optimizer has identified that the `filter` can be applied while the CSV is read from disk rather than reading the whole file into memory and then applying the filter. This optimization is called _Predicate Pushdown_. diff --git a/docs/user-guide/lazy/schemas.md b/docs/user-guide/lazy/schemas.md new file mode 100644 index 000000000000..77d2be54b722 --- /dev/null +++ b/docs/user-guide/lazy/schemas.md @@ -0,0 +1,60 @@ +# Schema + +The schema of a Polars `DataFrame` or `LazyFrame` sets out the names of the columns and their datatypes. You can see the schema with the `.schema` method on a `DataFrame` or `LazyFrame` + +{{code_block('user-guide/lazy/schema','schema',['DataFrame','lazy'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:setup" +--8<-- "python/user-guide/lazy/schema.py:schema" +``` + +The schema plays an important role in the lazy API. + +## Type checking in the lazy API + +One advantage of the lazy API is that Polars will check the schema before any data is processed. This check happens when you execute your lazy query. 
+ +We see how this works in the following simple example where we call the `.round` expression on the integer `bar` column. + +{{code_block('user-guide/lazy/schema','typecheck',['lazy','with_columns'])}} + +The `.round` expression is only valid for columns with a floating point dtype. Calling `.round` on an integer column means the operation will raise an `InvalidOperationError` when we evaluate the query with `collect`. This schema check happens before the data is processed when we call `collect`. + +`python exec="on" result="text" session="user-guide/lazy/schemas"` + +If we executed this query in eager mode the error would only be found once the data had been processed in all earlier steps. + +When we execute a lazy query Polars checks for any potential `InvalidOperationError` before the time-consuming step of actually processing the data in the pipeline. + +## The lazy API must know the schema + +In the lazy API the Polars query optimizer must be able to infer the schema at every step of a query plan. This means that operations where the schema is not knowable in advance cannot be used with the lazy API. + +The classic example of an operation where the schema is not knowable in advance is a `.pivot` operation. In a `.pivot` the new column names come from data in one of the columns. As these column names cannot be known in advance a `.pivot` is not available in the lazy API. + +## Dealing with operations not available in the lazy API + +If your pipeline includes an operation that is not available in the lazy API it is normally best to: + +- run the pipeline in lazy mode up until that point +- execute the pipeline with `.collect` to materialize a `DataFrame` +- do the non-lazy operation on the `DataFrame` +- convert the output back to a `LazyFrame` with `.lazy` and continue in lazy mode + +We show how to deal with a non-lazy operation in this example where we: + +- create a simple `DataFrame` +- convert it to a `LazyFrame` with `.lazy` +- do a transformation using `.with_columns` +- execute the query before the pivot with `.collect` to get a `DataFrame` +- do the `.pivot` on the `DataFrame` +- convert back in lazy mode +- do a `.filter` +- finish by executing the query with `.collect` to get a `DataFrame` + +{{code_block('user-guide/lazy/schema','lazyeager',['collect','pivot','filter'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:lazyeager" +``` diff --git a/docs/user-guide/lazy/streaming.md b/docs/user-guide/lazy/streaming.md new file mode 100644 index 000000000000..3f9d268443ca --- /dev/null +++ b/docs/user-guide/lazy/streaming.md @@ -0,0 +1,3 @@ +# Streaming + +--8<-- "docs/_build/snippets/under_construction.md" diff --git a/docs/user-guide/lazy/using.md b/docs/user-guide/lazy/using.md new file mode 100644 index 000000000000..d777557da550 --- /dev/null +++ b/docs/user-guide/lazy/using.md @@ -0,0 +1,37 @@ +# Usage + +With the lazy API, Polars doesn't run each query line-by-line but instead processes the full query end-to-end. To get the most out of Polars it is important that you use the lazy API because: + +- the lazy API allows Polars to apply automatic query optimization with the query optimizer +- the lazy API allows you to work with larger than memory datasets using streaming +- the lazy API can catch schema errors before processing the data + +Here we see how to use the lazy API starting from either a file or an existing `DataFrame`. 
+
+## Using the lazy API from a file
+
+In the ideal case we would use the lazy API right from a file, as the query optimizer may help us to reduce the amount of data we read from the file.
+
+We create a lazy query from the Reddit CSV data and apply some transformations.
+
+By starting the query with `pl.scan_csv` we are using the lazy API.
+
+{{code_block('user-guide/lazy/using','dataframe',['scan_csv','with_columns','filter','col'])}}
+
+A `pl.scan_` function is available for a number of file types including CSV, IPC, Parquet and JSON.
+
+In this query we tell Polars that we want to:
+
+- load data from the Reddit CSV file
+- convert the `name` column to uppercase
+- apply a filter to the `comment_karma` column
+
+The lazy query will not be executed at this point. See this page on [executing lazy queries](execution.md) for more on running lazy queries.
+
+## Using the lazy API from a `DataFrame`
+
+An alternative way to access the lazy API is to call `.lazy` on a `DataFrame` that has already been created in memory.
+
+{{code_block('user-guide/lazy/using','fromdf',['lazy'])}}
+
+By calling `.lazy` we convert the `DataFrame` to a `LazyFrame`.
diff --git a/docs/user-guide/migration/pandas.md b/docs/user-guide/migration/pandas.md
new file mode 100644
index 000000000000..d781ae290f96
--- /dev/null
+++ b/docs/user-guide/migration/pandas.md
@@ -0,0 +1,328 @@
+# Coming from Pandas
+
+Here we set out the key points that anyone who has experience with `Pandas` and wants to
+try `Polars` should know. We include both differences in the concepts the libraries are
+built on and differences in how you should write `Polars` code compared to `Pandas`
+code.
+
+## Differences in concepts between `Polars` and `Pandas`
+
+### `Polars` does not have a multi-index/index
+
+`Pandas` gives a label to each row with an index. `Polars` does not use an index and
+each row is indexed by its integer position in the table.
+
+Polars aims to have predictable results and readable queries; as such, we think an index does not help us reach that
+objective. We believe the semantics of a query should not change with the state of an index or a `reset_index` call.
+
+In Polars a DataFrame will always be a 2D table with heterogeneous data-types. The data-types may have nesting, but the
+table itself will not.
+Operations like resampling will be done by specialized functions or methods that act like 'verbs' on a table, explicitly
+stating the columns that the 'verb' operates on. As such, it is our conviction that not having indices makes things simpler,
+more explicit, more readable and less error-prone.
+
+Note that an 'index' data structure as known in databases will be used by Polars as an optimization technique.
+
+### `Polars` uses Apache Arrow arrays to represent data in memory while `Pandas` uses `Numpy` arrays
+
+`Polars` represents data in memory with Arrow arrays while `Pandas` represents data in
+memory with `Numpy` arrays. Apache Arrow is an emerging standard for in-memory columnar
+analytics that can accelerate data load times, reduce memory usage and accelerate
+calculations.
+
+`Polars` can convert data to `Numpy` format with the `to_numpy` method.
+
+### `Polars` has more support for parallel operations than `Pandas`
+
+`Polars` exploits the strong support for concurrency in Rust to run many operations in
+parallel. While some operations in `Pandas` are multi-threaded, the core of the library
+is single-threaded and an additional library such as `Dask` must be used to parallelize
+operations.
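+
+As a small illustrative sketch (with made-up data), handing several expressions to a single context such as `select` gives Polars the chance to evaluate them on separate threads, whereas the equivalent `Pandas` statements would run one after the other:
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
+
+# Both aggregations are part of one `select`, so Polars is free to
+# compute them in parallel before combining the results.
+out = df.select(
+    pl.col("a").sum().alias("a_sum"),
+    pl.col("b").mean().alias("b_mean"),
+)
+print(out)
+```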
+ +### `Polars` can lazily evaluate queries and apply query optimization + +Eager evaluation is when code is evaluated as soon as you run the code. Lazy evaluation +is when running a line of code means that the underlying logic is added to a query plan +rather than being evaluated. + +`Polars` supports eager evaluation and lazy evaluation whereas `Pandas` only supports +eager evaluation. The lazy evaluation mode is powerful because `Polars` carries out +automatic query optimization when it examines the query plan and looks for ways to +accelerate the query or reduce memory usage. + +`Dask` also supports lazy evaluation when it generates a query plan. However, `Dask` +does not carry out query optimization on the query plan. + +## Key syntax differences + +Users coming from `Pandas` generally need to know one thing... + +``` +polars != pandas +``` + +If your `Polars` code looks like it could be `Pandas` code, it might run, but it likely +runs slower than it should. + +Let's go through some typical `Pandas` code and see how we might rewrite it in `Polars`. + +### Selecting data + +As there is no index in `Polars` there is no `.loc` or `iloc` method in `Polars` - and +there is also no `SettingWithCopyWarning` in `Polars`. + +However, the best way to select data in `Polars` is to use the expression API. For +example, if you want to select a column in `Pandas` you can do one of the following: + +```python +df['a'] +df.loc[:,'a'] +``` + +but in `Polars` you would use the `.select` method: + +```python +df.select('a') +``` + +If you want to select rows based on the values then in `Polars` you use the `.filter` +method: + +```python +df.filter(pl.col('a') < 10) +``` + +As noted in the section on expressions below, `Polars` can run operations in `.select` +and `filter` in parallel and `Polars` can carry out query optimization on the full set +of data selection criteria. + +### Be lazy + +Working in lazy evaluation mode is straightforward and should be your default in +`Polars` as the lazy mode allows `Polars` to do query optimization. + +We can run in lazy mode by either using an implicitly lazy function (such as `scan_csv`) +or explicitly using the `lazy` method. + +Take the following simple example where we read a CSV file from disk and do a group by. +The CSV file has numerous columns but we just want to do a group by on one of the id +columns (`id1`) and then sum by a value column (`v1`). In `Pandas` this would be: + +```python +df = pd.read_csv(csv_file, usecols=['id1','v1']) +grouped_df = df.loc[:,['id1','v1']].groupby('id1').sum('v1') +``` + +In `Polars` you can build this query in lazy mode with query optimization and evaluate +it by replacing the eager `Pandas` function `read_csv` with the implicitly lazy `Polars` +function `scan_csv`: + +```python +df = pl.scan_csv(csv_file) +grouped_df = df.group_by('id1').agg(pl.col('v1').sum()).collect() +``` + +`Polars` optimizes this query by identifying that only the `id1` and `v1` columns are +relevant and so will only read these columns from the CSV. By calling the `.collect` +method at the end of the second line we instruct `Polars` to eagerly evaluate the query. + +If you do want to run this query in eager mode you can just replace `scan_csv` with +`read_csv` in the `Polars` code. + +Read more about working with lazy evaluation in the +[lazy API](../lazy/using.md) section. + +### Express yourself + +A typical `Pandas` script consists of multiple data transformations that are executed +sequentially. 
However, in `Polars` these transformations can be executed in parallel +using expressions. + +#### Column assignment + +We have a dataframe `df` with a column called `value`. We want to add two new columns, a +column called `tenXValue` where the `value` column is multiplied by 10 and a column +called `hundredXValue` where the `value` column is multiplied by 100. + +In `Pandas` this would be: + +```python +df["tenXValue"] = df["value"] * 10 +df["hundredXValue"] = df["value"] * 100 +``` + +These column assignments are executed sequentially. + +In `Polars` we add columns to `df` using the `.with_columns` method and name them with +the `.alias` method: + +```python +df.with_columns( + (pl.col("value") * 10).alias("tenXValue"), + (pl.col("value") * 100).alias("hundredXValue"), +) +``` + +These column assignments are executed in parallel. + +#### Column assignment based on predicate + +In this case we have a dataframe `df` with columns `a`,`b` and `c`. We want to re-assign +the values in column `a` based on a condition. When the value in column `c` is equal to +2 then we replace the value in `a` with the value in `b`. + +In `Pandas` this would be: + +```python +df.loc[df["c"] == 2, "a"] = df.loc[df["c"] == 2, "b"] +``` + +while in `Polars` this would be: + +```python +df.with_columns( + pl.when(pl.col("c") == 2) + .then(pl.col("b")) + .otherwise(pl.col("a")).alias("a") +) +``` + +The `Polars` way is pure in that the original `DataFrame` is not modified. The `mask` is +also not computed twice as in `Pandas` (you could prevent this in `Pandas`, but that +would require setting a temporary variable). + +Additionally `Polars` can compute every branch of an `if -> then -> otherwise` in +parallel. This is valuable, when the branches get more expensive to compute. + +#### Filtering + +We want to filter the dataframe `df` with housing data based on some criteria. + +In `Pandas` you filter the dataframe by passing Boolean expressions to the `loc` method: + +```python +df.loc[(df['sqft_living'] > 2500) & (df['price'] < 300000)] +``` + +while in `Polars` you call the `filter` method: + +```python +df.filter( + (pl.col("m2_living") > 2500) & (pl.col("price") < 300000) +) +``` + +The query optimizer in `Polars` can also detect if you write multiple filters separately +and combine them into a single filter in the optimized plan. + +## `Pandas` transform + +The `Pandas` documentation demonstrates an operation on a group by called `transform`. In +this case we have a dataframe `df` and we want a new column showing the number of rows +in each group. 
+ +In `Pandas` we have: + +```python +df = pd.DataFrame({ + "type": ["m", "n", "o", "m", "m", "n", "n"], + "c": [1, 1, 1, 2, 2, 2, 2], +}) + +df["size"] = df.groupby("c")["type"].transform(len) +``` + +Here `Pandas` does a group by on `"c"`, takes column `"type"`, computes the group length +and then joins the result back to the original `DataFrame` producing: + +``` + c type size +0 1 m 3 +1 1 n 3 +2 1 o 3 +3 2 m 4 +4 2 m 4 +5 2 n 4 +6 2 n 4 +``` + +In `Polars` the same can be achieved with `window` functions: + +```python +df.select( + pl.all(), + pl.col("type").count().over("c").alias("size") +) +``` + +``` +shape: (7, 3) +┌─────┬──────┬──────┐ +│ c ┆ type ┆ size │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 │ +╞═════╪══════╪══════╡ +│ 1 ┆ m ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 1 ┆ n ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 1 ┆ o ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 │ +└─────┴──────┴──────┘ +``` + +Because we can store the whole operation in a single expression, we can combine several +`window` functions and even combine different groups! + +`Polars` will cache window expressions that are applied over the same group, so storing +them in a single `select` is both convenient **and** optimal. In the following example +we look at a case where we are calculating group statistics over `"c"` twice: + +```python +df.select( + pl.all(), + pl.col("c").count().over("c").alias("size"), + pl.col("c").sum().over("type").alias("sum"), + pl.col("c").reverse().over("c").flatten().alias("reverse_type") +) +``` + +``` +shape: (7, 5) +┌─────┬──────┬──────┬─────┬──────────────┐ +│ c ┆ type ┆ size ┆ sum ┆ reverse_type │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 ┆ i64 ┆ i64 │ +╞═════╪══════╪══════╪═════╪══════════════╡ +│ 1 ┆ m ┆ 3 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 1 ┆ n ┆ 3 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 1 ┆ o ┆ 3 ┆ 1 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ 1 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ 1 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ 1 │ +└─────┴──────┴──────┴─────┴──────────────┘ +``` + +## Missing data + +`Pandas` uses `NaN` and/or `None` values to indicate missing values depending on the dtype of the column. In addition the behaviour in `Pandas` varies depending on whether the default dtypes or optional nullable arrays are used. In `Polars` missing data corresponds to a `null` value for all data types. + +For float columns `Polars` permits the use of `NaN` values. These `NaN` values are not considered to be missing data but instead a special floating point value. + +In `Pandas` an integer column with missing values is cast to be a float column with `NaN` values for the missing values (unless using optional nullable integer dtypes). In `Polars` any missing values in an integer column are simply `null` values and the column remains an integer column. + +See the [missing data](../expressions/null.md) section for more details. diff --git a/docs/user-guide/migration/spark.md b/docs/user-guide/migration/spark.md new file mode 100644 index 000000000000..ea1a41abbd71 --- /dev/null +++ b/docs/user-guide/migration/spark.md @@ -0,0 +1,158 @@ +# Coming from Apache Spark + +## Column-based API vs. 
Row-based API + +Whereas the `Spark` `DataFrame` is analogous to a collection of rows, a `Polars` `DataFrame` is closer to a collection of columns. This means that you can combine columns in `Polars` in ways that are not possible in `Spark`, because `Spark` preserves the relationship of the data in each row. + +Consider this sample dataset: + +```python +import polars as pl + +df = pl.DataFrame({ + "foo": ["a", "b", "c", "d", "d"], + "bar": [1, 2, 3, 4, 5], +}) + +dfs = spark.createDataFrame( + [ + ("a", 1), + ("b", 2), + ("c", 3), + ("d", 4), + ("d", 5), + ], + schema=["foo", "bar"], +) +``` + +### Example 1: Combining `head` and `sum` + +In `Polars` you can write something like this: + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").filter(pl.col("foo") == "d").sum() +) +``` + +Output: + +``` +shape: (2, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 9 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 9 │ +└─────┴─────┘ +``` + +The expressions on columns `foo` and `bar` are completely independent. Since the expression on `bar` returns a single value, that value is repeated for each value output by the expression on `foo`. But `a` and `b` have no relation to the data that produced the sum of `9`. + +To do something similar in `Spark`, you'd need to compute the sum separately and provide it as a literal: + +```python +from pyspark.sql.functions import col, sum, lit + +bar_sum = ( + dfs + .where(col("foo") == "d") + .groupBy() + .agg(sum(col("bar"))) + .take(1)[0][0] +) + +( + dfs + .orderBy("foo") + .limit(2) + .withColumn("bar", lit(bar_sum)) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 9| +| b| 9| ++---+---+ +``` + +### Example 2: Combining Two `head`s + +In `Polars` you can combine two different `head` expressions on the same DataFrame, provided that they return the same number of values. + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").sort(descending=True).head(2), +) +``` + +Output: + +``` +shape: (3, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 5 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 4 │ +└─────┴─────┘ +``` + +Again, the two `head` expressions here are completely independent, and the pairing of `a` to `5` and `b` to `4` results purely from the juxtaposition of the two columns output by the expressions. + +To accomplish something similar in `Spark`, you would need to generate an artificial key that enables you to join the values in this way. + +```python +from pyspark.sql import Window +from pyspark.sql.functions import row_number + +foo_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy("foo")) + ) +) + +bar_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy(col("bar").desc())) + ) +) + +( + foo_dfs.alias("foo") + .join(bar_dfs.alias("bar"), on="rownum") + .select("foo.foo", "bar.bar") + .limit(2) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 5| +| b| 4| ++---+---+ +``` diff --git a/docs/user-guide/misc/alternatives.md b/docs/user-guide/misc/alternatives.md new file mode 100644 index 000000000000..a5544e7db354 --- /dev/null +++ b/docs/user-guide/misc/alternatives.md @@ -0,0 +1,66 @@ +# Alternatives + +These are some tools that share similar functionality to what polars does. + +- Pandas + + A very versatile tool for small data. Read [10 things I hate about pandas](https://wesmckinney.com/blog/apache-arrow-pandas-internals/) + written by the author himself. 
Polars has solved all those 10 things.
+  Polars is a versatile tool for small and large data with a more predictable, less ambiguous, and stricter API.
+
+- Pandas the API
+
+  The API of pandas was designed for in-memory data. This makes it a poor fit for performant analysis on large data
+  (read: anything that does not fit into RAM). Any tool that tries to distribute that API will likely have a
+  suboptimal query plan compared to plans that follow from a declarative API like SQL or Polars' API.
+
+- Dask
+
+  Parallelizes existing single-threaded libraries like `NumPy` and `Pandas`. As a consumer of those libraries, Dask
+  therefore has less control over low level performance and semantics.
+  Those libraries are treated like a black box.
+  On a single machine the parallelization effort can also be seriously stalled by pandas strings.
+  Pandas strings, by default, are stored as Python objects in
+  numpy arrays, meaning that any operation on them is GIL bound and therefore single threaded. This can be circumvented
+  by multi-processing, but that has a non-trivial cost.
+
+- Modin
+
+  Similar to Dask.
+
+- Vaex
+
+  Vaex's method of out-of-core analysis is memory mapping files. This works until it doesn't. For instance, parquet
+  or csv files first need to be read and converted to a file format that can be memory mapped. Another downside is
+  that the OS determines when pages will be swapped. Operations that need a full data shuffle, such as
+  sorts, have terrible performance on memory mapped data.
+  Polars' out of core processing is not based on memory mapping, but on streaming data in batches (and spilling to disk
+  if needed); we control which data must be held in memory, not the OS, meaning that we don't have unexpected IO stalls.
+
+- DuckDB
+
+  Polars and DuckDB have many similarities. DuckDB is focused on providing an in-process OLAP SQLite alternative,
+  while Polars is focused on providing a scalable `DataFrame` interface to many languages. Those different front-ends lead to
+  different optimization strategies and different algorithm prioritization. The interoperability between both is zero-copy.
+  See more: https://duckdb.org/docs/guides/python/polars
+
+- Spark
+
+  Spark is designed for distributed workloads and uses the JVM. The setup for Spark is complicated and the startup time
+  is slow. On a single machine Polars has much better performance characteristics. If you need to process TBs of data,
+  Spark is a better choice.
+
+- CuDF
+
+  GPUs and CuDF are fast!
+  However, GPUs are not readily available and are expensive in production. The amount of memory available on a GPU
+  is often a fraction of the available RAM.
+  This (together with out-of-core processing) means that Polars can handle much larger data-sets.
+  Next to that, Polars can be close in [performance to CuDF](https://zakopilo.hatenablog.jp/entry/2023/02/04/220552).
+  CuDF doesn't optimize your query, so it is not uncommon that on ETL jobs Polars will be faster because it can elide
+  unneeded work and materializations.
+
+- Any
+
+  Polars is written in Rust. This gives it strong safety, performance and concurrency guarantees.
+  Polars is written in a modular manner. Parts of Polars can be used in other query programs and can be added as a library.
diff --git a/docs/user-guide/misc/contributing.md b/docs/user-guide/misc/contributing.md new file mode 100644 index 000000000000..abd4d4d229be --- /dev/null +++ b/docs/user-guide/misc/contributing.md @@ -0,0 +1,11 @@ +# Contributing + +See the [`CONTRIBUTING.md`](https://github.com/pola-rs/polars/blob/master/CONTRIBUTING.md) if you would like to contribute to the `Polars` project. + +If you're new to this we recommend starting out with contributing examples to the Python API documentation. The Python API docs are generated from the docstrings of the Python wrapper located in `polars/py-polars`. + +Here is an example [commit](https://github.com/pola-rs/polars/pull/3567/commits/5db9e335f3f2777dd1d6f80df765c6bca8f307b0) that adds a docstring. + +If you spot any gaps in this User Guide you can submit fixes to the [`pola-rs/polars`](https://github.com/pola-rs/polars) repo. + +Happy hunting! diff --git a/docs/user-guide/misc/multiprocessing.md b/docs/user-guide/misc/multiprocessing.md new file mode 100644 index 000000000000..4973da8c0155 --- /dev/null +++ b/docs/user-guide/misc/multiprocessing.md @@ -0,0 +1,104 @@ +# Multiprocessing + +TLDR: if you find that using Python's built-in `multiprocessing` module together with Polars results in a Polars error about multiprocessing methods, you should make sure you are using `spawn`, not `fork`, as the starting method: + +{{code_block('user-guide/misc/multiprocess','recommendation',[])}} + +## When not to use multiprocessing + +Before we dive into the details, it is important to emphasize that Polars has been built from the start to use all your CPU cores. +It does this by executing computations which can be done in parallel in separate threads. +For example, requesting two expressions in a `select` statement can be done in parallel, with the results only being combined at the end. +Another example is aggregating a value within groups using `group_by().agg()`, each group can be evaluated separately. +It is very unlikely that the `multiprocessing` module can improve your code performance in these cases. + +See [the optimizations section](../lazy/optimizations.md) for more optimizations. + +## When to use multiprocessing + +Although Polars is multithreaded, other libraries may be single-threaded. +When the other library is the bottleneck, and the problem at hand is parallelizable, it makes sense to use multiprocessing to gain a speed up. + +## The problem with the default multiprocessing config + +### Summary + +The [Python multiprocessing documentation](https://docs.python.org/3/library/multiprocessing.html) lists the three methods to create a process pool: + +1. spawn +1. fork +1. forkserver + +The description of fork is (as of 2022-10-15): + +> The parent process uses os.fork() to fork the Python interpreter. The child process, when it begins, is effectively identical to the parent process. All resources of the parent are inherited by the child process. Note that safely forking a multithreaded process is problematic. + +> Available on Unix only. The default on Unix. + +The short summary is: Polars is multithreaded as to provide strong performance out-of-the-box. +Thus, it cannot be combined with `fork`. +If you are on Unix (Linux, BSD, etc), you are using `fork`, unless you explicitly override it. + +The reason you may not have encountered this before is that pure Python code, and most Python libraries, are (mostly) single threaded. 
+Alternatively, you are on Windows or MacOS, on which `fork` is not even available as a method (for MacOS it was up to Python 3.7).
+
+Thus one should use `spawn`, or `forkserver`, instead. `spawn` is available on all platforms and is the safest choice, and hence the recommended method.
+
+### Example
+
+The problem with `fork` is in the copying of the parent's process.
+Consider the example below, which is a slightly modified example posted on the [Polars issue tracker](https://github.com/pola-rs/polars/issues/3144):
+
+{{code_block('user-guide/misc/multiprocess','example1',[])}}
+
+Using `fork` as the method, instead of `spawn`, will cause a deadlock.
+Please note: Polars will not even start; it raises an error about the multiprocessing method being set wrong. But if that check had not been there, the deadlock would exist.
+
+The fork method is equivalent to calling `os.fork()`, which is a system call as defined in [the POSIX standard](https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html):
+
+> A process shall be created with a single thread. If a multi-threaded process calls fork(), the new process shall contain a replica of the calling thread and its entire address space, possibly including the states of mutexes and other resources. Consequently, to avoid errors, the child process may only execute async-signal-safe operations until such time as one of the exec functions is called.
+
+In contrast, `spawn` will create a completely fresh Python interpreter, and not inherit the state of mutexes.
+
+So what happens in the code example?
+For reading the file with `pl.read_parquet` the file has to be locked.
+Then `os.fork()` is called, copying the state of the parent process, including mutexes.
+Thus all child processes will copy the file lock in an acquired state, leaving them hanging indefinitely waiting for the file lock to be released, which never happens.
+
+What makes debugging these issues tricky is that `fork` can work.
+Change the example so that it does not call `pl.read_parquet`:
+
+{{code_block('user-guide/misc/multiprocess','example2',[])}}
+
+This works fine.
+Therefore debugging these issues in larger code bases, i.e. not the small toy examples here, can be a real pain, as a seemingly unrelated change can break your multiprocessing code.
+In general, one should therefore never use the `fork` start method with multithreaded libraries unless there are very specific requirements that cannot be met otherwise.
+
+### Pros and cons of fork
+
+Based on the example, you may wonder why `fork` is available in Python to begin with.
+
+First, probably because of historical reasons: `spawn` was added to Python in version 3.4, whilst `fork` has been part of Python from the 2.x series.
+
+Second, there are several limitations for `spawn` and `forkserver` that do not apply to `fork`, in particular all arguments should be picklable.
+See the [Python multiprocessing docs](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods) for more information.
+
+Third, because `fork` is faster at creating new processes than `spawn`, as `spawn` is effectively `fork` + creating a brand new Python process without the locks by calling [execv](https://pubs.opengroup.org/onlinepubs/9699919799/functions/exec.html).
+Hence the warning in the Python docs that it is slower: there is more overhead to `spawn`.
+However, in almost all cases, one would like to use multiple processes to speed up computations that take multiple minutes or even hours, meaning the overhead is negligible in the grand scheme of things. +And more importantly, it actually works in combination with multithreaded libraries. + +Fourth, `spawn` starts a new process, and therefore it requires code to be importable, in contrast to `fork`. +In particular, this means that when using `spawn` the relevant code should not be in the global scope, such as in Jupyter notebooks or in plain scripts. +Hence in the examples above, we define functions where we spawn within, and run those functions from a `__main__` clause. +This is not an issue for typical projects, but during quick experimentation in notebooks it could fail. + +## References + +1. https://docs.python.org/3/library/multiprocessing.html + +1. https://pythonspeed.com/articles/python-multiprocessing/ + +1. https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html + +1. https://bnikolic.co.uk/blog/python/parallelism/2019/11/13/python-forkserver-preload.html diff --git a/docs/user-guide/misc/reference-guides.md b/docs/user-guide/misc/reference-guides.md new file mode 100644 index 000000000000..c0e082d08447 --- /dev/null +++ b/docs/user-guide/misc/reference-guides.md @@ -0,0 +1,6 @@ +# Reference guides + +The api documentations with details on function / object signatures can be found here: + +- [Python](https://pola-rs.github.io/polars/py-polars/html/reference/index.html) +- [Rust](https://docs.rs/polars/latest/polars/) diff --git a/docs/user-guide/sql/create.md b/docs/user-guide/sql/create.md new file mode 100644 index 000000000000..a5a1922b7f23 --- /dev/null +++ b/docs/user-guide/sql/create.md @@ -0,0 +1,28 @@ +# CREATE + +In Polars, the `SQLContext` provides a way to execute SQL statements against `LazyFrames` and `DataFrames` using SQL syntax. One of the SQL statements that can be executed using `SQLContext` is the `CREATE TABLE` statement, which is used to create a new table. + +The syntax for the `CREATE TABLE` statement in Polars is as follows: + +``` +CREATE TABLE table_name +AS +SELECT ... +``` + +In this syntax, `table_name` is the name of the new table that will be created, and `SELECT ...` is a SELECT statement that defines the data that will be inserted into the table. + +Here's an example of how to use the `CREATE TABLE` statement in Polars: + +{{code_block('user-guide/sql/create','create',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/create.py:setup" +--8<-- "python/user-guide/sql/create.py:create" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a `CREATE TABLE` statement that creates a new table called `older_people` based on a SELECT statement that selects all rows from the `my_table` DataFrame where the `age` column is greater than 30. + +!!! note Result + + Note that the result of a `CREATE TABLE` statement is not the table itself. The table is registered in the `SQLContext`. 
In case you want to turn the table back to a `DataFrame` you can use a `SELECT * FROM ...` statement diff --git a/docs/user-guide/sql/cte.md b/docs/user-guide/sql/cte.md new file mode 100644 index 000000000000..1129f6d19230 --- /dev/null +++ b/docs/user-guide/sql/cte.md @@ -0,0 +1,27 @@ +# Common Table Expressions + +Common Table Expressions (CTEs) are a feature of SQL that allow you to define a temporary named result set that can be referenced within a SQL statement. CTEs provide a way to break down complex SQL queries into smaller, more manageable pieces, making them easier to read, write, and maintain. + +A CTE is defined using the `WITH` keyword followed by a comma-separated list of subqueries, each of which defines a named result set that can be used in subsequent queries. The syntax for a CTE is as follows: + +``` +WITH cte_name AS ( + subquery +) +SELECT ... +``` + +In this syntax, `cte_name` is the name of the CTE, and `subquery` is the subquery that defines the result set. The CTE can then be referenced in subsequent queries as if it were a table or view. + +CTEs are particularly useful when working with complex queries that involve multiple levels of subqueries, as they allow you to break down the query into smaller, more manageable pieces that are easier to understand and debug. Additionally, CTEs can help improve query performance by allowing the database to optimize and cache the results of subqueries, reducing the number of times they need to be executed. + +Polars supports Common Table Expressions (CTEs) using the WITH clause in SQL syntax. Below is an example + +{{code_block('user-guide/sql/cte','cte',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/cte" +--8<-- "python/user-guide/sql/cte.py:setup" +--8<-- "python/user-guide/sql/cte.py:cte" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a SQL query that includes a CTE. The CTE selects all rows from the `my_table` LazyFrame where the `age` column is greater than 30 and gives it the alias `older_people`. We then execute a second SQL query that selects all rows from the `older_people` CTE where the `name` column starts with the letter 'C'. diff --git a/docs/user-guide/sql/intro.md b/docs/user-guide/sql/intro.md new file mode 100644 index 000000000000..815231e3d59c --- /dev/null +++ b/docs/user-guide/sql/intro.md @@ -0,0 +1,106 @@ +# Introduction + +While Polars does support writing queries in SQL, it's recommended that users familiarize themselves with the [expression syntax](../concepts/expressions.md) for more readable and expressive code. As a primarily DataFrame library, new features will typically be added to the expression API first. However, if you already have an existing SQL codebase or prefer to use SQL, Polars also offers support for SQL queries. + +!!! note Execution + + In Polars, there is no separate SQL engine because Polars translates SQL queries into [expressions](../concepts/expressions.md), which are then executed using its built-in execution engine. This approach ensures that Polars maintains its performance and scalability advantages as a native DataFrame library while still providing users with the ability to work with SQL queries. + +## Context + +Polars uses the `SQLContext` to manage SQL queries . The context contains a dictionary mapping `DataFrames` and `LazyFrames` names to their corresponding datasets[^1]. 
The example below starts a `SQLContext`: + +{{code_block('user-guide/sql/intro','context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:setup" +--8<-- "python/user-guide/sql/intro.py:context" +``` + +## Register Dataframes + +There are 2 ways to register DataFrames in the `SQLContext`: + +- register all `LazyFrames` and `DataFrames` in the global namespace +- register them one by one + +{{code_block('user-guide/sql/intro','register_context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_context" +``` + +We can also register Pandas DataFrames by converting them to Polars first. + +{{code_block('user-guide/sql/intro','register_pandas',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_pandas" +``` + +!!! note Pandas + + Converting a Pandas DataFrame backed by Numpy to Polars triggers a conversion to the Arrow format. This conversion has a computation cost. Converting a Pandas DataFrame backed by Arrow on the other hand will be free or almost free. + +Once the `SQLContext` is initialized, we can register additional Dataframes or unregister existing Dataframes with: + +- `register` +- `register_globals` +- `register_many` +- `unregister` + +## Execute queries and collect results + +SQL queries are always executed in lazy mode to benefit from lazy optimizations, so we have 2 options to collect the result: + +- Set the parameter `eager_execution` to True in `SQLContext`. With this parameter, Polars will automatically collect SQL results +- Set the parameter `eager` to True when executing a query with `execute`, or collect the result with `collect`. + +We execute SQL queries by calling `execute` on a `SQLContext`. + +{{code_block('user-guide/sql/intro','execute',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:execute" +``` + +## Execute queries from multiple sources + +SQL queries can be executed just as easily from multiple sources. +In the example below, we register : + +- a CSV file loaded lazily +- a NDJSON file loaded lazily +- a Pandas DataFrame + +And we join them together with SQL. +Lazy reading allows to only load the necessary rows and columns from the files. + +In the same way, it's possible to register cloud datalakes (S3, Azure Data Lake). A PyArrow dataset can point to the datalake, then Polars can read it with `scan_pyarrow_dataset`. + +{{code_block('user-guide/sql/intro','execute_multiple_sources',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:prepare_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:execute_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:clean_multiple_sources" +``` + +[^1]: Additionally it also tracks the [common table expressions](./cte.md) as well. + +## Compatibility + +Polars does not support the full SQL language, in Polars you are allowed to: + +- Write a `CREATE` statements `CREATE TABLE xxx AS ...` +- Write a `SELECT` statements with all generic elements (`GROUP BY`, `WHERE`,`ORDER`,`LIMIT`,`JOIN`, ...) +- Write Common Table Expressions (CTE's) (`WITH tablename AS`) +- Show an overview of all tables `SHOW TABLES` + +The following is not yet supported: + +- `INSERT`, `UPDATE` or `DELETE` statements +- Table aliasing (e.g. 
`SELECT p.Name from pokemon AS p`)
+- Meta queries such as `ANALYZE`, `EXPLAIN`
+
+In the upcoming sections we will cover each of these statements in more detail.
diff --git a/docs/user-guide/sql/select.md b/docs/user-guide/sql/select.md
new file mode 100644
index 000000000000..d994191c5068
--- /dev/null
+++ b/docs/user-guide/sql/select.md
@@ -0,0 +1,72 @@
+# SELECT
+
+In Polars SQL, the `SELECT` statement is used to retrieve data from a table into a `DataFrame`. The basic syntax of a `SELECT` statement in Polars SQL is as follows:
+
+```sql
+SELECT column1, column2, ...
+FROM table_name;
+```
+
+Here, `column1`, `column2`, etc. are the columns that you want to select from the table. You can also use the wildcard `*` to select all columns. `table_name` is the name of the table that you want to retrieve data from. In the sections below we will cover some of the more common `SELECT` variants.
+
+{{code_block('user-guide/sql/select','df',['SQLregister','SQLexecute'])}}
+
+```python exec="on" result="text" session="user-guide/sql/select"
+--8<-- "python/user-guide/sql/select.py:setup"
+--8<-- "python/user-guide/sql/select.py:df"
+```
+
+### GROUP BY
+
+The `GROUP BY` statement is used to group rows in a table by one or more columns and compute aggregate functions on each group.
+
+{{code_block('user-guide/sql/select','group_by',['SQLexecute'])}}
+
+```python exec="on" result="text" session="user-guide/sql/select"
+--8<-- "python/user-guide/sql/select.py:group_by"
+```
+
+### ORDER BY
+
+The `ORDER BY` statement is used to sort the result set of a query by one or more columns in ascending or descending order.
+
+{{code_block('user-guide/sql/select','orderby',['SQLexecute'])}}
+
+```python exec="on" result="text" session="user-guide/sql/select"
+--8<-- "python/user-guide/sql/select.py:orderby"
+```
+
+### JOIN
+
+{{code_block('user-guide/sql/select','join',['SQLregister_many','SQLexecute'])}}
+
+```python exec="on" result="text" session="user-guide/sql/select"
+--8<-- "python/user-guide/sql/select.py:join"
+```
+
+### Functions
+
+Polars provides a wide range of SQL functions, including:
+
+- Mathematical functions: `ABS`, `EXP`, `LOG`, `ASIN`, `ACOS`, `ATAN`, etc.
+- String functions: `LOWER`, `UPPER`, `LTRIM`, `RTRIM`, `STARTS_WITH`, `ENDS_WITH`, etc.
+- Aggregation functions: `SUM`, `AVG`, `MIN`, `MAX`, `COUNT`, `STDDEV`, `FIRST`, etc.
+- Array functions: `EXPLODE`, `UNNEST`, `ARRAY_SUM`, `ARRAY_REVERSE`, etc.
+
+For a full list of supported functions go to the [API documentation](https://docs.rs/polars-sql/latest/src/polars_sql/keywords.rs.html). The example below demonstrates how to use a function in a query.
+
+{{code_block('user-guide/sql/select','functions',['SQLquery'])}}
+
+```python exec="on" result="text" session="user-guide/sql/select"
+--8<-- "python/user-guide/sql/select.py:functions"
+```
+
+### Table Functions
+
+In the earlier examples we first created a DataFrame and registered it in the `SQLContext`. Polars also supports reading directly from CSV, Parquet, JSON and IPC in your SQL query using the `read_xxx` table functions.
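+
+For a rough idea of what such a query looks like, here is a minimal sketch using a hypothetical file name (`my_data.csv` is not part of this guide's dataset); the full example from the repository is rendered below:
+
+```python
+import polars as pl
+
+ctx = pl.SQLContext()
+
+# read_csv is a table function, so the file is scanned as part of the query itself.
+result = ctx.execute("SELECT * FROM read_csv('my_data.csv') LIMIT 3", eager=True)
+print(result)
+```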
+ +{{code_block('user-guide/sql/select','tablefunctions',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:tablefunctions" +``` diff --git a/docs/user-guide/sql/show.md b/docs/user-guide/sql/show.md new file mode 100644 index 000000000000..70453ebcb6dd --- /dev/null +++ b/docs/user-guide/sql/show.md @@ -0,0 +1,22 @@ +# SHOW TABLES + +In Polars, the `SHOW TABLES` statement is used to list all the tables that have been registered in the current `SQLContext`. When you register a DataFrame with the `SQLContext`, you give it a name that can be used to refer to the DataFrame in subsequent SQL statements. The `SHOW TABLES` statement allows you to see a list of all the registered tables, along with their names. + +The syntax for the `SHOW TABLES` statement in Polars is as follows: + +``` +SHOW TABLES +``` + +Here's an example of how to use the `SHOW TABLES` statement in Polars: + +{{code_block('user-guide/sql/show','show',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/show" +--8<-- "python/user-guide/sql/show.py:setup" +--8<-- "python/user-guide/sql/show.py:show" +``` + +In this example, we create two DataFrames and register them with the `SQLContext` using different names. We then execute a `SHOW TABLES` statement using the `execute()` method of the `SQLContext` object, which returns a DataFrame containing a list of all the registered tables and their names. The resulting DataFrame is then printed using the `print()` function. + +Note that the `SHOW TABLES` statement only lists tables that have been registered with the current `SQLContext`. If you register a DataFrame with a different `SQLContext` or in a different Python session, it will not appear in the list of tables returned by `SHOW TABLES`. diff --git a/docs/user-guide/transformations/concatenation.md b/docs/user-guide/transformations/concatenation.md new file mode 100644 index 000000000000..8deff923acee --- /dev/null +++ b/docs/user-guide/transformations/concatenation.md @@ -0,0 +1,51 @@ +# Concatenation + +There are a number of ways to concatenate data from separate DataFrames: + +- two dataframes with **the same columns** can be **vertically** concatenated to make a **longer** dataframe +- two dataframes with the **same number of rows** and **non-overlapping columns** can be **horizontally** concatenated to make a **wider** dataframe +- two dataframes with **different numbers of rows and columns** can be **diagonally** concatenated to make a dataframe which might be longer and/ or wider. Where column names overlap values will be vertically concatenated. Where column names do not overlap new rows and columns will be added. Missing values will be set as `null` + +## Vertical concatenation - getting longer + +In a vertical concatenation you combine all of the rows from a list of `DataFrames` into a single longer `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','vertical',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:setup" +--8<-- "python/user-guide/transformations/concatenation.py:vertical" +``` + +Vertical concatenation fails when the dataframes do not have the same column names. + +## Horizontal concatenation - getting wider + +In a horizontal concatenation you combine all of the columns from a list of `DataFrames` into a single wider `DataFrame`. 
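+
+As a minimal sketch with two small hand-made frames (the column names here are invented; the repository example is rendered below):
+
+```python
+import polars as pl
+
+df_left = pl.DataFrame({"l1": [1, 2], "l2": [3, 4]})
+df_right = pl.DataFrame({"r1": [5, 6], "r2": [7, 8]})
+
+# Same number of rows and non-overlapping column names, so this widens the frame.
+df_wide = pl.concat([df_left, df_right], how="horizontal")
+print(df_wide)
+```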
+ +{{code_block('user-guide/transformations/concatenation','horizontal',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:horizontal" +``` + +Horizontal concatenation fails when dataframes have overlapping columns or a different number of rows. + +## Diagonal concatenation - getting longer, wider and `null`ier + +In a diagonal concatenation you combine all of the row and columns from a list of `DataFrames` into a single longer and/or wider `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','cross',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:cross" +``` + +Diagonal concatenation generates nulls when the column names do not overlap. + +When the dataframe shapes do not match and we have an overlapping semantic key then [we can join the dataframes](joins.md) instead of concatenating them. + +## Rechunking + +Before a concatenation we have two dataframes `df1` and `df2`. Each column in `df1` and `df2` is in one or more chunks in memory. By default, during concatenation the chunks in each column are copied to a single new chunk - this is known as **rechunking**. Rechunking is an expensive operation, but is often worth it because future operations will be faster. +If you do not want Polars to rechunk the concatenated `DataFrame` you specify `rechunk = False` when doing the concatenation. diff --git a/docs/user-guide/transformations/joins.md b/docs/user-guide/transformations/joins.md new file mode 100644 index 000000000000..ad233cf060fb --- /dev/null +++ b/docs/user-guide/transformations/joins.md @@ -0,0 +1,183 @@ +# Joins + +## Join strategies + +`Polars` supports the following join strategies by specifying the `strategy` argument: + +| Strategy | Description | +| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `inner` | Returns row with matching keys in _both_ frames. Non-matching rows in either the left or right frame are discarded. | +| `left` | Returns all rows in the left dataframe, whether or not a match in the right-frame is found. Non-matching rows have their right columns null-filled. | +| `outer` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. | +| `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicates rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. | +| `asof` | A left-join in which the match is performed on the _nearest_ key rather than on equal keys. | +| `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. | +| `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. | + +### Inner join + +An `inner` join produces a `DataFrame` that contains only the rows where the join key exists in both `DataFrames`. 
Let's take for example the following two `DataFrames`: + +{{code_block('user-guide/transformations/joins','innerdf',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:setup" +--8<-- "python/user-guide/transformations/joins.py:innerdf" +``` + +

+ +{{code_block('user-guide/transformations/joins','innerdf2',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:innerdf2" +``` + +To get a `DataFrame` with the orders and their associated customer we can do an `inner` join on the `customer_id` column: + +{{code_block('user-guide/transformations/joins','inner',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner" +``` + +### Left join + +The `left` join produces a `DataFrame` that contains all the rows from the left `DataFrame` and only the rows from the right `DataFrame` where the join key exists in the left `DataFrame`. If we now take the example from above and want to have a `DataFrame` with all the customers and their associated orders (regardless of whether they have placed an order or not) we can do a `left` join: + +{{code_block('user-guide/transformations/joins','left',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left" +``` + +Notice, that the fields for the customer with the `customer_id` of `3` are null, as there are no orders for this customer. + +### Outer join + +The `outer` join produces a `DataFrame` that contains all the rows from both `DataFrames`. Columns are null, if the join key does not exist in the source `DataFrame`. Doing an `outer` join on the two `DataFrames` from above produces a similar `DataFrame` to the `left` join: + +{{code_block('user-guide/transformations/joins','outer',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:outer" +``` + +### Cross join + +A `cross` join is a cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. + +{{code_block('user-guide/transformations/joins','df3',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df3" +``` + +

+ +{{code_block('user-guide/transformations/joins','df4',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df4" +``` + +We can now create a `DataFrame` containing all possible combinations of the colors and sizes with a `cross` join: + +{{code_block('user-guide/transformations/joins','cross',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:cross" +``` + +
+ +The `inner`, `left`, `outer` and `cross` join strategies are standard amongst dataframe libraries. We provide more details on the less familiar `semi`, `anti` and `asof` join strategies below. + +### Semi join + +The `semi` join returns all rows from the left frame in which the join key is also present in the right frame. Consider the following scenario: a car rental company has a `DataFrame` showing the cars that it owns with each car having a unique `id`. + +{{code_block('user-guide/transformations/joins','df5',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df5" +``` + +The company has another `DataFrame` showing each repair job carried out on a vehicle. + +{{code_block('user-guide/transformations/joins','df6',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df6" +``` + +You want to answer this question: which of the cars have had repairs carried out? + +An inner join does not answer this question directly as it produces a `DataFrame` with multiple rows for each car that has had multiple repair jobs: + +{{code_block('user-guide/transformations/joins','inner2',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner2" +``` + +However, a semi join produces a single row for each car that has had a repair job carried out. + +{{code_block('user-guide/transformations/joins','semi',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:semi" +``` + +### Anti join + +Continuing this example, an alternative question might be: which of the cars have **not** had a repair job carried out? An anti join produces a `DataFrame` showing all the cars from `df_cars` where the `id` is not present in the `df_repairs` `DataFrame`. + +{{code_block('user-guide/transformations/joins','anti',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:anti" +``` + +### Asof join + +An `asof` join is like a left join except that we match on nearest key rather than equal keys. +In `Polars` we can do an asof join with the `join` method and specifying `strategy="asof"`. However, for more flexibility we can use the `join_asof` method. + +Consider the following scenario: a stock market broker has a `DataFrame` called `df_trades` showing transactions it has made for different stocks. + +{{code_block('user-guide/transformations/joins','df7',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df7" +``` + +The broker has another `DataFrame` called `df_quotes` showing prices it has quoted for these stocks. + +{{code_block('user-guide/transformations/joins','df8',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df8" +``` + +You want to produce a `DataFrame` showing for each trade the most recent quote provided _before_ the trade. You do this with `join_asof` (using the default `strategy = "backward"`). +To avoid joining between trades on one stock with a quote on another you must specify an exact preliminary join on the stock column with `by="stock"`. 
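+
+As a minimal sketch of the call itself (the two frames below are invented stand-ins for `df_trades` and `df_quotes`; the full example from the repository is rendered next):
+
+```python
+import polars as pl
+from datetime import datetime
+
+df_trades = pl.DataFrame(
+    {
+        "time": [datetime(2020, 1, 1, 9, 1), datetime(2020, 1, 1, 9, 3)],
+        "stock": ["A", "A"],
+        "trade": [101, 299],
+    }
+).sort("time")
+df_quotes = pl.DataFrame(
+    {
+        "time": [datetime(2020, 1, 1, 9, 0), datetime(2020, 1, 1, 9, 2)],
+        "stock": ["A", "A"],
+        "quote": [100, 300],
+    }
+).sort("time")
+
+# For each trade, take the most recent quote at or before the trade time,
+# after matching exactly on the stock column.
+result = df_trades.join_asof(df_quotes, on="time", by="stock", strategy="backward")
+print(result)
+```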
+
+{{code_block('user-guide/transformations/joins','asof',['join_asof'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/joins"
+--8<-- "python/user-guide/transformations/joins.py:asofpre"
+--8<-- "python/user-guide/transformations/joins.py:asof"
+```
+
+If you want to make sure that only quotes within a certain time range are joined to the trades you can specify the `tolerance` argument. In this case we want to make sure that the last preceding quote is within 1 minute of the trade so we set `tolerance = "1m"`.
+
+=== ":fontawesome-brands-python: Python"
+
+```python
+--8<-- "python/user-guide/transformations/joins.py:asof2"
+```
+
+```python exec="on" result="text" session="user-guide/transformations/joins"
+--8<-- "python/user-guide/transformations/joins.py:asof2"
+```
diff --git a/docs/user-guide/transformations/melt.md b/docs/user-guide/transformations/melt.md
new file mode 100644
index 000000000000..3e6efe35723e
--- /dev/null
+++ b/docs/user-guide/transformations/melt.md
@@ -0,0 +1,21 @@
+# Melts
+
+Melt operations unpivot a DataFrame from wide format to long format.
+
+## Dataset
+
+{{code_block('user-guide/transformations/melt','df',['DataFrame'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/melt"
+--8<-- "python/user-guide/transformations/melt.py:df"
+```
+
+## Eager + lazy
+
+`Eager` and `lazy` have the same API.
+
+{{code_block('user-guide/transformations/melt','melt',['melt'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/melt"
+--8<-- "python/user-guide/transformations/melt.py:melt"
+```
diff --git a/docs/user-guide/transformations/pivot.md b/docs/user-guide/transformations/pivot.md
new file mode 100644
index 000000000000..9850dbed0330
--- /dev/null
+++ b/docs/user-guide/transformations/pivot.md
@@ -0,0 +1,46 @@
+# Pivots
+
+Pivot a column in a `DataFrame` and perform one of the following aggregations:
+
+- first
+- sum
+- min
+- max
+- mean
+- median
+
+The pivot operation consists of a group by on one or multiple columns (these will be the
+new y-axis), the column that will be pivoted (this will be the new x-axis) and an
+aggregation.
+
+## Dataset
+
+{{code_block('user-guide/transformations/pivot','df',['DataFrame'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/pivot"
+--8<-- "python/user-guide/transformations/pivot.py:setup"
+--8<-- "python/user-guide/transformations/pivot.py:df"
+```
+
+## Eager
+
+{{code_block('user-guide/transformations/pivot','eager',['pivot'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/pivot"
+--8<-- "python/user-guide/transformations/pivot.py:eager"
+```
+
+## Lazy
+
+A Polars `LazyFrame` always needs to know the schema of a computation statically (before collecting the query).
+A pivot's output schema depends on the data, so it is impossible to determine the schema without
+running the query.
+
+Polars could have abstracted this fact for you just like Spark does, but we don't want you to shoot yourself in the foot
+with a shotgun. The cost should be clear upfront.
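+
+In practice this means materialising the data before pivoting. The sketch below shows the idea with an invented frame (the repository example rendered next does the same with this page's dataset); the `values`/`index`/`columns` parameter names follow the Polars version this guide targets:
+
+```python
+import polars as pl
+
+lf = pl.LazyFrame(
+    {
+        "foo": ["A", "A", "B", "B", "C"],
+        "bar": ["k", "l", "m", "n", "o"],
+        "N": [1, 2, 2, 4, 2],
+    }
+)
+
+# Collect first, pivot eagerly, then (optionally) continue in the lazy API.
+out = (
+    lf.collect()
+    .pivot(values="N", index="foo", columns="bar", aggregate_function="first")
+    .lazy()
+)
+print(out.collect())
+```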
+ +{{code_block('user-guide/transformations/pivot','lazy',['pivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:lazy" +``` diff --git a/docs/user-guide/transformations/time-series/filter.md b/docs/user-guide/transformations/time-series/filter.md new file mode 100644 index 000000000000..1f57d8866fbd --- /dev/null +++ b/docs/user-guide/transformations/time-series/filter.md @@ -0,0 +1,48 @@ +# Filtering + +Filtering date columns works in the same way as with other types of columns using the `.filter` method. + +Polars uses Python's native `datetime`, `date` and `timedelta` for equality comparisons between the datatypes `pl.Datetime`, `pl.Date` and `pl.Duration`. + +In the following example we use a time series of Apple stock prices. + +{{code_block('user-guide/transformations/time-series/filter','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:df" +``` + +## Filtering by single dates + +We can filter by a single date by casting the desired date string to a `Date` object +in a filter expression: + +{{code_block('user-guide/transformations/time-series/filter','filter',['filter'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:filter" +``` + +Note we are using the lowercase `datetime` method rather than the uppercase `Datetime` data type. + +## Filtering by a date range + +We can filter by a range of dates using the `is_between` method in a filter expression with the start and end dates: + +{{code_block('user-guide/transformations/time-series/filter','range',['filter','is_between'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:range" +``` + +## Filtering with negative dates + +Say you are working with an archeologist and are dealing in negative dates. +Polars can parse and store them just fine, but the Python `datetime` library +does not. So for filtering, you should use attributes in the `.dt` namespace: + +{{code_block('user-guide/transformations/time-series/filter','negative',['str.to_date'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:negative" +``` diff --git a/docs/user-guide/transformations/time-series/parsing.md b/docs/user-guide/transformations/time-series/parsing.md new file mode 100644 index 000000000000..62bdb0a44b8f --- /dev/null +++ b/docs/user-guide/transformations/time-series/parsing.md @@ -0,0 +1,58 @@ +# Parsing + +Polars has native support for parsing time series data and doing more sophisticated operations such as temporal grouping and resampling. + +## Datatypes + +`Polars` has the following datetime datatypes: + +- `Date`: Date representation e.g. 2014-07-08. It is internally represented as days since UNIX epoch encoded by a 32-bit signed integer. +- `Datetime`: Datetime representation e.g. 2014-07-08 07:00:00. It is internally represented as a 64 bit integer since the Unix epoch and can have different units such as ns, us, ms. +- `Duration`: A time delta type that is created when subtracting `Date/Datetime`. Similar to `timedelta` in python. +- `Time`: Time representation, internally represented as nanoseconds since midnight. 
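+
+As a small illustration of these dtypes (a minimal sketch with arbitrary values):
+
+```python
+import polars as pl
+from datetime import date, datetime
+
+df = pl.DataFrame(
+    {
+        "date": [date(2014, 7, 8), date(2014, 7, 9)],
+        "datetime": [datetime(2014, 7, 8, 7, 0), datetime(2014, 7, 9, 8, 30)],
+    }
+)
+
+# Subtracting two temporal columns produces a Duration column.
+df = df.with_columns(
+    (pl.col("datetime") - pl.col("datetime").shift()).alias("duration")
+)
+print(df.schema)  # date -> Date, datetime -> Datetime, duration -> Duration
+```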
+ +## Parsing dates from a file + +When loading from a CSV file `Polars` attempts to parse dates and times if the `try_parse_dates` flag is set to `True`: + +{{code_block('user-guide/transformations/time-series/parsing','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:setup" +--8<-- "python/user-guide/transformations/time-series/parsing.py:df" +``` + +On the other hand binary formats such as parquet have a schema that is respected by `Polars`. + +## Casting strings to dates + +You can also cast a column of datetimes encoded as strings to a datetime type. You do this by calling the string `str.to_date` method and passing the format of the date string: + +{{code_block('user-guide/transformations/time-series/parsing','cast',['read_csv','str.to_date'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:cast" +``` + +[The format string specification can be found here.](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). + +## Extracting date features from a date column + +You can extract data features such as the year or day from a date column using the `.dt` namespace on a date column: + +{{code_block('user-guide/transformations/time-series/parsing','extract',['dt.year'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:extract" +``` + +## Mixed offsets + +If you have mixed offsets (say, due to crossing daylight saving time), +then you can use `utc=True` and then convert to your time zone: + +{{code_block('user-guide/transformations/time-series/parsing','mixed',['str.to_datetime','dt.convert_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:mixed" +``` diff --git a/docs/user-guide/transformations/time-series/resampling.md b/docs/user-guide/transformations/time-series/resampling.md new file mode 100644 index 000000000000..63ad583a9bec --- /dev/null +++ b/docs/user-guide/transformations/time-series/resampling.md @@ -0,0 +1,42 @@ +# Resampling + +We can resample by either: + +- upsampling (moving data to a higher frequency) +- downsampling (moving data to a lower frequency) +- combinations of these e.g. first upsample and then downsample + +## Downsampling to a lower frequency + +`Polars` views downsampling as a special case of the **group_by** operation and you can do this with `group_by_dynamic` and `group_by_rolling` - [see the temporal group by page for examples](rolling.md). + +## Upsampling to a higher frequency + +Let's go through an example where we generate data at 30 minute intervals: + +{{code_block('user-guide/transformations/time-series/resampling','df',['DataFrame','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:setup" +--8<-- "python/user-guide/transformations/time-series/resampling.py:df" +``` + +Upsampling can be done by defining the new sampling interval. By upsampling we are adding in extra rows where we do not have data. As such upsampling by itself gives a DataFrame with nulls. These nulls can then be filled with a fill strategy or interpolation. 
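+
+For instance, a bare `upsample` call along these lines (a minimal sketch with an invented frame; the filled variants are shown in the next section) yields the extra rows as nulls:
+
+```python
+import polars as pl
+from datetime import datetime
+
+df = pl.DataFrame(
+    {
+        "time": [
+            datetime(2021, 12, 16, 0, 0),
+            datetime(2021, 12, 16, 0, 30),
+            datetime(2021, 12, 16, 1, 0),
+        ],
+        "values": [1.0, 2.0, 3.0],
+    }
+)
+
+# upsample requires the frame to be sorted by the time column.
+print(df.sort("time").upsample(time_column="time", every="15m"))
+```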
+ +### Upsampling strategies + +In this example we upsample from the original 30 minutes to 15 minutes and then use a `forward` strategy to replace the nulls with the previous non-null value: + +{{code_block('user-guide/transformations/time-series/resampling','upsample',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample" +``` + +In this example we instead fill the nulls by linear interpolation: + +{{code_block('user-guide/transformations/time-series/resampling','upsample2',['upsample','interpolate','fill_null'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample2" +``` diff --git a/docs/user-guide/transformations/time-series/rolling.md b/docs/user-guide/transformations/time-series/rolling.md new file mode 100644 index 000000000000..f3e009f99b76 --- /dev/null +++ b/docs/user-guide/transformations/time-series/rolling.md @@ -0,0 +1,148 @@ +# Grouping + +## Grouping by fixed windows + +We can calculate temporal statistics using `group_by_dynamic` to group rows into days/months/years etc. + +### Annual average example + +In following simple example we calculate the annual average closing price of Apple stock prices. We first load the data from CSV: + +{{code_block('user-guide/transformations/time-series/rolling','df',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:setup" +--8<-- "python/user-guide/transformations/time-series/rolling.py:df" +``` + +!!! info + + The dates are sorted in ascending order - if they are not sorted in this way the `group_by_dynamic` output will not be correct! + +To get the annual average closing price we tell `group_by_dynamic` that we want to: + +- group by the `Date` column on an annual (`1y`) basis +- take the mean values of the `Close` column for each year: + +{{code_block('user-guide/transformations/time-series/rolling','group_by',['group_by_dynamic'])}} + +The annual average closing price is then: + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by" +``` + +### Parameters for `group_by_dynamic` + +A dynamic window is defined by a: + +- **every**: indicates the interval of the window +- **period**: indicates the duration of the window +- **offset**: can be used to offset the start of the windows + +The value for `every` sets how often the groups start. The time period values are flexible - for example we could take: + +- the average over 2 year intervals by replacing `1y` with `2y` +- the average over 18 month periods by replacing `1y` with `1y6mo` + +We can also use the `period` parameter to set how long the time period for each group is. For example, if we set the `every` parameter to be `1y` and the `period` parameter to be `2y` then we would get groups at one year intervals where each groups spanned two years. + +If the `period` parameter is not specified then it is set equal to the `every` parameter so that if the `every` parameter is set to be `1y` then each group spans `1y` as well. + +Because _**every**_ does not have to be equal to _**period**_, we can create many groups in a very flexible way. They may overlap +or leave boundaries between them. 
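+
+In code, varying `every` and `period` independently looks roughly like this (a minimal sketch with an invented frame and column names):
+
+```python
+import polars as pl
+from datetime import date
+
+df = pl.DataFrame(
+    {
+        "day": [date(2021, 1, d) for d in range(1, 11)],
+        "value": list(range(10)),
+    }
+).sort("day")
+
+# Windows start every 2 days but each window spans 3 days, so the windows overlap.
+out = df.group_by_dynamic("day", every="2d", period="3d").agg(
+    pl.col("value").mean().alias("mean_value")
+)
+print(out)
+```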
+ +Let's see how the windows for some parameter combinations would look. Let's start out boring. 🥱 + +- every: 1 day -> `"1d"` +- period: 1 day -> `"1d"` + +```text +this creates adjacent windows of the same size +|--| + |--| + |--| +``` + +- every: 1 day -> `"1d"` +- period: 2 days -> `"2d"` + +```text +these windows have an overlap of 1 day +|----| + |----| + |----| +``` + +- every: 2 days -> `"2d"` +- period: 1 day -> `"1d"` + +```text +this would leave gaps between the windows +data points that in these gaps will not be a member of any group +|--| + |--| + |--| +``` + +#### `truncate` + +The `truncate` parameter is a Boolean variable that determines what datetime value is associated with each group in the output. In the example above the first data point is on 23rd February 1981. If `truncate = True` (the default) then the date for the first year in the annual average is 1st January 1981. However, if `truncate = False` then the date for the first year in the annual average is the date of the first data point on 23rd February 1981. Note that `truncate` only affects what's shown in the +`Date` column and does not affect the window boundaries. + +### Using expressions in `group_by_dynamic` + +We aren't restricted to using simple aggregations like `mean` in a group by operation - we can use the full range of expressions available in Polars. + +In the snippet below we create a `date range` with every **day** (`"1d"`) in 2021 and turn this into a `DataFrame`. + +Then in the `group_by_dynamic` we create dynamic windows that start every **month** (`"1mo"`) and have a window length of `1` month. The values that match these dynamic windows are then assigned to that group and can be aggregated with the powerful expression API. + +Below we show an example where we use **group_by_dynamic** to compute: + +- the number of days until the end of the month +- the number of days in a month + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn',['group_by_dynamic','DataFrame.explode','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn" +``` + +## Grouping by rolling windows + +The rolling group by, `group_by_rolling`, is another entrance to the `group_by` context. But different from the `group_by_dynamic` the windows are +not fixed by a parameter `every` and `period`. In a rolling group by, the windows are not fixed at all! They are determined +by the values in the `index_column`. + +So imagine having a time column with the values `{2021-01-06, 2021-01-10}` and a `period="5d"` this would create the following +windows: + +```text +2021-01-01 2021-01-06 + |----------| + + 2021-01-05 2021-01-10 + |----------| +``` + +Because the windows of a rolling group by are always determined by the values in the `DataFrame` column, the number of +groups is always equal to the original `DataFrame`. + +## Combining group by operations + +Rolling and dynamic group by operations can be combined with normal group by operations. + +Below is an example with a dynamic group by. 
+ +{{code_block('user-guide/transformations/time-series/rolling','group_by_roll',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_roll" +``` + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn2',['group_by_dynamic'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn2" +``` diff --git a/docs/user-guide/transformations/time-series/timezones.md b/docs/user-guide/transformations/time-series/timezones.md new file mode 100644 index 000000000000..a12b97c68dd9 --- /dev/null +++ b/docs/user-guide/transformations/time-series/timezones.md @@ -0,0 +1,46 @@ +--- +hide: + - toc +--- + +# Time zones + +!!! quote "Tom Scott" + + You really should never, ever deal with time zones if you can help it. + +The `Datetime` datatype can have a time zone associated with it. +Examples of valid time zones are: + +- `None`: no time zone, also known as "time zone naive"; +- `UTC`: Coordinated Universal Time; +- `Asia/Kathmandu`: time zone in "area/location" format. + See the [list of tz database time zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) + to see what's available; +- `+01:00`: fixed offsets. May be useful when parsing, but you almost certainly want the "Area/Location" + format above instead as it will deal with irregularities such as DST (Daylight Saving Time) for you. + +Note that, because a `Datetime` can only have a single time zone, it is +impossible to have a column with multiple time zones. If you are parsing data +with multiple offsets, you may want to pass `utc=True` to convert +them all to a common time zone (`UTC`), see [parsing dates and times](parsing.md). 
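+
+For example, following the `utc=True` approach mentioned above, strings with mixed fixed offsets can be parsed into a single UTC column and then converted (a minimal sketch; the timestamps and format string are invented):
+
+```python
+import polars as pl
+
+s = pl.Series(["2021-03-27T00:00:00+0100", "2021-03-28T00:00:00+0200"])
+
+# Mixed offsets cannot live in one Datetime column, so parse everything to UTC first...
+parsed = s.str.to_datetime("%Y-%m-%dT%H:%M:%S%z", utc=True)
+
+# ...and then convert to a single time zone of your choice.
+print(parsed.dt.convert_time_zone("Asia/Kathmandu"))
+```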
+ +The main methods for setting and converting between time zones are: + +- `dt.convert_time_zone`: convert from one time zone to another; +- `dt.replace_time_zone`: set/unset/change time zone; + +Let's look at some examples of common operations: + +{{code_block('user-guide/transformations/time-series/timezones','example',['str.to_datetime','dt.replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:setup" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example" +``` + +{{code_block('user-guide/transformations/time-series/timezones','example2',['dt.convert_time_zone','dt.replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example2" +``` diff --git a/examples/read_csv/src/main.rs b/examples/read_csv/src/main.rs index 7d1f555c8397..ca4fbbd7730c 100644 --- a/examples/read_csv/src/main.rs +++ b/examples/read_csv/src/main.rs @@ -6,7 +6,7 @@ fn main() -> PolarsResult<()> { .unwrap(); let file = Box::new(file) as Box; let _df = CsvReader::new(file) - .with_delimiter(b'|') + .with_separator(b'|') .has_header(false) .with_chunk_size(10) .batched_mmap(None) diff --git a/examples/read_parquet/src/main.rs b/examples/read_parquet/src/main.rs index fa886a9f3262..6ea2fe22a9fa 100644 --- a/examples/read_parquet/src/main.rs +++ b/examples/read_parquet/src/main.rs @@ -10,6 +10,6 @@ fn main() -> PolarsResult<()> { ]) .collect()?; - dbg!(df); + println!("{}", df); Ok(()) } diff --git a/examples/read_parquet_cloud/Cargo.toml b/examples/read_parquet_cloud/Cargo.toml index 634311904846..f6f5b56eb430 100644 --- a/examples/read_parquet_cloud/Cargo.toml +++ b/examples/read_parquet_cloud/Cargo.toml @@ -4,6 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "aws"] } +polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet"] } aws-creds = "0.35.0" diff --git a/examples/read_parquet_cloud/src/main.rs b/examples/read_parquet_cloud/src/main.rs index e179266e1de3..367575bbdd30 100644 --- a/examples/read_parquet_cloud/src/main.rs +++ b/examples/read_parquet_cloud/src/main.rs @@ -25,6 +25,6 @@ fn main() -> PolarsResult<()> { ]) .collect()?; - dbg!(df); + println!("{}", df); Ok(()) } diff --git a/examples/write_parquet_cloud/Cargo.toml b/examples/write_parquet_cloud/Cargo.toml new file mode 100644 index 000000000000..7bf6a24e46d3 --- /dev/null +++ b/examples/write_parquet_cloud/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "write_parquet_cloud" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +aws-creds = "0.35.0" +polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet", "cloud_write"] } diff --git a/examples/write_parquet_cloud/src/main.rs b/examples/write_parquet_cloud/src/main.rs new file mode 100644 index 000000000000..8ed29bc402cf --- /dev/null +++ b/examples/write_parquet_cloud/src/main.rs @@ -0,0 +1,63 @@ +use awscreds::Credentials; +use cloud::AmazonS3ConfigKey as Key; +use polars::prelude::*; + +// Login to your aws account and then copy the ../datasets/foods1.parquet file to your own bucket. +// Adjust the link below. 
+const TEST_S3_LOCATION: &str = "s3://polarstesting/polars_write_example_cloud.parquet"; + +fn main() -> PolarsResult<()> { + sink_file(); + sink_cloud_local(); + sink_aws(); + + Ok(()) +} + +fn sink_file() { + let df = example_dataframe(); + + // Writing to a local file: + let path = "/tmp/polars_write_example.parquet".into(); + df.lazy().sink_parquet(path, Default::default()).unwrap(); +} + +fn sink_cloud_local() { + let df = example_dataframe(); + + // Writing to a location that might be in the cloud: + let uri = "file:///tmp/polars_write_example_cloud.parquet".to_string(); + df.lazy() + .sink_parquet_cloud(uri, None, Default::default()) + .unwrap(); +} + +fn sink_aws() { + let cred = Credentials::default().unwrap(); + + // Propagate the credentials and other cloud options. + let cloud_options = cloud::CloudOptions::default().with_aws([ + (Key::AccessKeyId, &cred.access_key.unwrap()), + (Key::SecretAccessKey, &cred.secret_key.unwrap()), + (Key::Region, &"eu-central-1".into()), + ]); + let cloud_options = Some(cloud_options); + + let df = example_dataframe(); + + df.lazy() + .sink_parquet_cloud( + TEST_S3_LOCATION.to_string(), + cloud_options, + Default::default(), + ) + .unwrap(); +} + +fn example_dataframe() -> DataFrame { + df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap() +} diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000000..501d047b35e5 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,162 @@ +# https://www.mkdocs.org/user-guide/configuration/ + +# Project information +site_name: Polars documentation +site_url: https://pola-rs.github.io/polars +repo_url: https://github.com/pola-rs/polars +repo_name: pola-rs/polars + +# Documentation layout +nav: + - Home: index.md + - Getting started: + - getting-started/intro.md + - getting-started/installation.md + - getting-started/series-dataframes.md + - getting-started/reading-writing.md + - getting-started/expressions.md + - getting-started/joins.md + - User guide: + - user-guide/index.md + - user-guide/installation.md + - Concepts: + - user-guide/concepts/data-types.md + - user-guide/concepts/data-structures.md + - user-guide/concepts/contexts.md + - user-guide/concepts/expressions.md + - user-guide/concepts/lazy-vs-eager.md + - user-guide/concepts/streaming.md + - Expressions: + - user-guide/expressions/operators.md + - user-guide/expressions/column-selections.md + - user-guide/expressions/functions.md + - user-guide/expressions/casting.md + - user-guide/expressions/strings.md + - user-guide/expressions/aggregation.md + - user-guide/expressions/null.md + - user-guide/expressions/window.md + - user-guide/expressions/folds.md + - user-guide/expressions/lists.md + - user-guide/expressions/user-defined-functions.md + - user-guide/expressions/structs.md + - user-guide/expressions/numpy.md + - Transformations: + - user-guide/transformations/joins.md + - user-guide/transformations/concatenation.md + - user-guide/transformations/pivot.md + - user-guide/transformations/melt.md + - Time series: + - user-guide/transformations/time-series/parsing.md + - user-guide/transformations/time-series/filter.md + - user-guide/transformations/time-series/rolling.md + - user-guide/transformations/time-series/resampling.md + - user-guide/transformations/time-series/timezones.md + - Lazy API: + - user-guide/lazy/using.md + - user-guide/lazy/optimizations.md + - user-guide/lazy/schemas.md + - user-guide/lazy/query-plan.md + - user-guide/lazy/execution.md + - user-guide/lazy/streaming.md + - IO: + - 
user-guide/io/csv.md + - user-guide/io/parquet.md + - user-guide/io/json.md + - user-guide/io/multiple.md + - user-guide/io/database.md + - user-guide/io/cloud-storage.md + - user-guide/io/bigquery.md + - SQL: + - user-guide/sql/intro.md + - user-guide/sql/show.md + - user-guide/sql/select.md + - user-guide/sql/create.md + - user-guide/sql/cte.md + - Migrating: + - user-guide/migration/pandas.md + - user-guide/migration/spark.md + - Misc: + - user-guide/misc/multiprocessing.md + - user-guide/misc/alternatives.md + - user-guide/misc/reference-guides.md + - user-guide/misc/contributing.md +not_in_nav: | + /_build/ + people.md +validation: + links: + # Allow an absolute link to the features page for our code snippets + absolute_links: ignore + +# Build directories +theme: + name: material + locale: en + custom_dir: docs/_build/overrides + palette: + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to light mode + logo: _build/assets/logo.png + features: + - navigation.tracking + - navigation.instant + - navigation.tabs + - navigation.tabs.sticky + - navigation.footer + - content.tabs.link + icon: + repo: fontawesome/brands/github + +extra_css: + - _build/css/extra.css +extra: + consent: + title: Cookie consent + description: >- + We use cookies to recognize your repeated visits and preferences, as well + as to measure the effectiveness of our documentation and whether users + find what they're searching for. With your consent, you're helping us to + make our documentation better. + analytics: + provider: google + property: G-LKNVFWD3T5 + +# Preview controls +strict: true + +# Formatting options +markdown_extensions: + - admonition + - pymdownx.details + - attr_list + - pymdownx.emoji: + emoji_index: !!python/name:materialx.emoji.twemoji + emoji_generator: !!python/name:materialx.emoji.to_svg + - pymdownx.superfences + - pymdownx.tabbed: + alternate_style: true + - pymdownx.snippets: + base_path: ['.','docs/src/'] + check_paths: true + dedent_subsections: true + - footnotes + +hooks: + - docs/_build/scripts/people.py + +plugins: + - search: + lang: en + - markdown-exec + - macros: + module_name: docs/_build/scripts/macro diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 26f33e0d5c9b..32b759f2b46d 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -96,40 +105,6 @@ dependencies = [ "serde", ] -[[package]] -name = "arrow2" -version = "0.17.4" -source = "git+https://github.com/jorgecarleitao/arrow2?rev=ba6a882bc1542b0b899774b696ebea77482b5c31#ba6a882bc1542b0b899774b696ebea77482b5c31" -dependencies = [ - "ahash", - "arrow-format", - "avro-schema", - "base64", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "fallible-streaming-iterator", - "foreign_vec", - "futures", - "getrandom", - "hash_hasher", - "lexical-core", - "lz4", - "multiversion", - "num-traits", - "parquet2", - "regex", - "regex-syntax", - "rustc_version", - "simdutf8", - "streaming-iterator", - "strength_reduce", - "zstd", -] - [[package]] name = "async-stream" version = "0.3.5" @@ -149,7 +124,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -160,7 +135,7 @@ checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -192,11 +167,26 @@ dependencies = [ "snap", ] +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "bitflags" @@ -247,35 +237,35 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "bytemuck" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cargo-lock" @@ -307,9 +297,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name 
= "chrono" -version = "0.4.27" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f56b4c72906975ca04becb8a30e102dfecddd0c06181e3e95ddc444be28881f8" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", "iana-time-zone", @@ -496,6 +486,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dyn-clone" version = "1.0.13" @@ -508,6 +504,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "encoding_rs" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum_dispatch" version = "0.3.12" @@ -517,7 +522,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -564,6 +569,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foreign_vec" version = "0.1.0" @@ -635,7 +646,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -681,6 +692,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" + [[package]] name = "git2" version = "0.17.2" @@ -700,6 +717,25 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "h2" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap 1.9.3", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "1.8.2" @@ -717,10 +753,10 @@ dependencies = [ ] [[package]] -name = "hash_hasher" -version = "2.0.3" +name = "hashbrown" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" @@ -750,9 +786,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "hex" @@ -769,6 +805,84 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "http" +version = "0.2.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "hyper" +version = "0.14.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.4.9", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" +dependencies = [ + "futures-util", + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "iana-time-zone" version = "0.1.57" @@ -802,6 +916,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + [[package]] name = "indexmap" version = "2.0.0" @@ -815,9 +939,9 @@ dependencies = [ [[package]] name = "indoc" -version = "1.0.9" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "inventory" @@ -825,6 +949,21 @@ version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1be380c410bf0595e94992a648ea89db4dd3f3354ba54af206fd2a68cf5ac8e" +[[package]] +name = "ipnet" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.9" @@ -960,9 +1099,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = 
"9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "libflate" @@ -996,6 +1135,16 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb" +dependencies = [ + "cfg-if", + "windows-sys", +] + [[package]] name = "libm" version = "0.2.7" @@ -1004,9 +1153,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "libmimalloc-sys" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25d058a81af0d1c22d7a1c948576bee6d673f7af3c0f35564abd6c81122f513d" +checksum = "3979b5c37ece694f1f5e51e7ecc871fdb0f517ed04ee45f88d15d6d553cb9664" dependencies = [ "cc", "libc", @@ -1082,9 +1231,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.1" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f478948fd84d9f8e86967bf432640e46adfb5a4bd4f14ef7e864ab38220534ae" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memmap2" @@ -1106,13 +1255,19 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.38" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "972e5f23f6716f62665760b0f4cbf592576a80c7b879ba9beaafc0e558894127" +checksum = "fa01922b5ea280a911e323e4d2fd24b7fe5cc4042e0d2cda3c40775cdc4bdc9c" dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -1156,6 +1311,37 @@ dependencies = [ "target-features", ] +[[package]] +name = "nano-arrow" +version = "0.1.0" +dependencies = [ + "ahash", + "arrow-format", + "avro-schema", + "base64", + "bytemuck", + "chrono", + "chrono-tz", + "dyn-clone", + "either", + "ethnum", + "fallible-streaming-iterator", + "foreign_vec", + "futures", + "getrandom", + "hashbrown 0.14.0", + "lexical-core", + "lz4", + "multiversion", + "num-traits", + "parquet2", + "rustc_version", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "zstd", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -1228,9 +1414,9 @@ dependencies = [ [[package]] name = "numpy" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "437213adf41bbccf4aeae535fbfcdad0f6fed241e1ae182ebe97fa1f3ce19389" +checksum = "bef41cbb417ea83b30525259e30ccef6af39b31c240bda578889494c5392d331" dependencies = [ "libc", "ndarray", @@ -1241,6 +1427,45 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "object_store" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d359e231e5451f4f9fa889d56e3ce34f8724f1a61db2107739359717cf2bbf08" +dependencies = [ + "async-trait", + "base64", + "bytes", + "chrono", + "futures", + "humantime", + "hyper", + "itertools", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand", + "reqwest", + "ring", + "rustls-pemfile", + "serde", + "serde_json", + 
"snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.18.0" @@ -1380,7 +1605,7 @@ dependencies = [ [[package]] name = "polars" -version = "0.32.0" +version = "0.33.2" dependencies = [ "getrandom", "polars-core", @@ -1394,7 +1619,7 @@ dependencies = [ [[package]] name = "polars-algo" -version = "0.32.0" +version = "0.33.2" dependencies = [ "polars-core", "polars-lazy", @@ -1403,15 +1628,15 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", "atoi", "chrono", "chrono-tz", "ethnum", "hashbrown 0.14.0", "multiversion", + "nano-arrow", "num-traits", "polars-error", "serde", @@ -1421,18 +1646,19 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", "bitflags 2.4.0", + "bytemuck", "chrono", "chrono-tz", "comfy-table", "either", "hashbrown 0.14.0", - "indexmap", + "indexmap 2.0.0", "itoap", + "nano-arrow", "ndarray", "num-traits", "once_cell", @@ -1454,31 +1680,45 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", + "nano-arrow", + "object_store", "regex", "thiserror", ] +[[package]] +name = "polars-ffi" +version = "0.33.2" +dependencies = [ + "nano-arrow", + "polars-core", +] + [[package]] name = "polars-io" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", + "async-trait", "bytes", "chrono", "chrono-tz", "fast-float", "flate2", + "futures", "home", + "itoa", "lexical", "lexical-core", "memchr", "memmap2", + "nano-arrow", "num-traits", + "object_store", "once_cell", + "percent-encoding", "polars-arrow", "polars-core", "polars-error", @@ -1487,31 +1727,41 @@ dependencies = [ "polars-utils", "rayon", "regex", + "reqwest", + "ryu", "serde", "serde_json", "simd-json", "simdutf8", + "smartstring", + "tokio", + "tokio-util", + "url", ] [[package]] name = "polars-json" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", + "chrono", "fallible-streaming-iterator", "hashbrown 0.14.0", - "indexmap", + "indexmap 2.0.0", + "itoa", + "nano-arrow", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "ryu", "simd-json", + "streaming-iterator", ] [[package]] name = "polars-lazy" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", "bitflags 2.4.0", @@ -1529,27 +1779,37 @@ dependencies = [ "pyo3", "rayon", "smartstring", + "tokio", "version_check", ] [[package]] name = "polars-ops" -version = "0.32.0" +version = "0.33.2" dependencies = [ + "ahash", "argminmax", - "arrow2", "base64", + "bytemuck", "chrono", "chrono-tz", "either", + "hashbrown 0.14.0", "hex", - "indexmap", + "indexmap 2.0.0", "jsonpath_lib", "memchr", + "nano-arrow", + "num-traits", "polars-arrow", "polars-core", + "polars-error", "polars-json", "polars-utils", + "rand", + "rand_distr", + "rayon", + "regex", "serde", "serde_json", "smartstring", @@ -1558,7 +1818,7 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.32.0" +version = "0.33.2" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -1574,21 +1834,25 @@ dependencies = [ "polars-utils", "rayon", "smartstring", + "tokio", "version_check", ] [[package]] name = "polars-plan" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", "chrono", "chrono-tz", "ciborium", + "libloading", + "nano-arrow", "once_cell", + "percent-encoding", "polars-arrow", "polars-core", + "polars-ffi", "polars-io", "polars-ops", 
"polars-time", @@ -1604,21 +1868,22 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", + "nano-arrow", "polars-error", "polars-utils", ] [[package]] name = "polars-sql" -version = "0.32.0" +version = "0.33.2" dependencies = [ "polars-arrow", "polars-core", "polars-lazy", "polars-plan", + "rand", "serde", "serde_json", "sqlparser", @@ -1626,12 +1891,12 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", "atoi", "chrono", "chrono-tz", + "nano-arrow", "now", "once_cell", "polars-arrow", @@ -1645,9 +1910,10 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", + "bytemuck", "hashbrown 0.14.0", "num-traits", "once_cell", @@ -1666,16 +1932,16 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] [[package]] name = "py-polars" -version = "0.19.0" +version = "0.19.8" dependencies = [ "ahash", "built", @@ -1693,6 +1959,8 @@ dependencies = [ "polars-core", "polars-error", "polars-lazy", + "polars-ops", + "polars-plan", "pyo3", "pyo3-built", "serde_json", @@ -1702,9 +1970,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" dependencies = [ "cfg-if", "indoc", @@ -1720,9 +1988,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" dependencies = [ "once_cell", "target-lexicon", @@ -1736,9 +2004,9 @@ checksum = "be6d574e0f8cab2cdd1eeeb640cbf845c974519fa9e9b62fa9c08ecece0ca5de" [[package]] name = "pyo3-ffi" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" dependencies = [ "libc", "pyo3-build-config", @@ -1746,25 +2014,36 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 1.0.109", + "syn 2.0.36", ] [[package]] name = "pyo3-macros-backend" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" dependencies = [ + "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.36", +] + 
+[[package]] +name = "quick-xml" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce5e73202a820a31f8a0ee32ada5e21029c81fd9e3ebf668a40832e4219d9d1" +dependencies = [ + "memchr", + "serde", ] [[package]] @@ -1824,9 +2103,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -1834,14 +2113,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] @@ -1855,9 +2132,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -1867,9 +2144,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -1882,12 +2159,74 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "reqwest" +version = "0.11.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-rustls", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "webpki-roots", + "winreg", +] + +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + [[package]] name = "rle-decode-fast" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1903,6 +2242,37 @@ 
dependencies = [ "semver", ] +[[package]] +name = "rustls" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +dependencies = [ + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -1915,12 +2285,31 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "semver" version = "1.0.18" @@ -1953,16 +2342,16 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ - "indexmap", + "indexmap 2.0.0", "itoa", "ryu", "serde", @@ -1977,6 +2366,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "signal-hook" version = "0.3.17" @@ -2009,11 +2410,12 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.10.6" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de7f1293f0e4e11d52e588766fe9de8caa2857ff63809d40de83245452ca7c5c" +checksum = "474b451aaac1828ed12f6454a80fe58b940ae2998d10389d41533940a6f641bf" dependencies = [ "ahash", + "getrandom", "halfbrown", "lexical-core", "once_cell", @@ -2062,17 +2464,65 @@ dependencies = [ "version_check", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "snap" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831" +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "socket2" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "sqlparser" -version = "0.36.1" +version = "0.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eaa1e88e78d2c2460d78b7dc3f0c08dbb606ab4222f9aff36f420d36e307d87" +checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" dependencies = [ "log", ] @@ -2133,7 +2583,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -2149,9 +2599,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.29" +version = "2.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "91e02e55d62894af2a08aca894c6577281f76769ba47c94d5756bec8ac6e7373" dependencies = [ "proc-macro2", "quote", @@ -2160,9 +2610,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.9" +version = "0.29.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d0e9cc2273cc8d31377bdd638d72e3ac3e5607b18621062b169d02787f1bab" +checksum = "0a18d114d420ada3a891e6bc8e96a2023402203296a47cdd65083377dad18ba5" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2186,22 +2636,22 @@ checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" [[package]] name = "thiserror" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -2219,11 +2669,63 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "pin-project-lite", + "socket2 0.5.4", + "tokio-macros", + 
"windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.36", +] + +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "toml" -version = "0.7.6" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ "serde", "serde_spanned", @@ -2242,17 +2744,61 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap", + "indexmap 2.0.0", "serde", "serde_spanned", "toml_datetime", "winnow", ] +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.36", +] + +[[package]] +name = "tracing-core" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -2261,9 +2807,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2282,9 +2828,15 @@ checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" [[package]] name = "unindent" -version = "0.1.11" +version = "0.2.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "untrusted" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" @@ -2321,6 +2873,25 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2348,10 +2919,22 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.87" @@ -2370,7 +2953,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2381,6 +2964,35 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wasm-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" + [[package]] name = "winapi" version = "0.3.9" @@ -2397,6 +3009,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2487,11 +3108,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys", +] + [[package]] name = "xxhash-rust" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "735a71d46c4d68d71d4b24d03fdc2b98e38cea81730595801db779c04fe80d70" +checksum = "9828b178da53440fa9c766a3d2f73f7cf5d0ac1fe3980c1e5018d899fd19e07b" [[package]] name = "zstd" diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index b0d2e8ed3df0..66e0c9db1854 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.19.0" +version = "0.19.8" edition = "2021" [lib] @@ -12,6 +12,8 @@ polars-algo = { path = "../crates/polars-algo", default-features = false } polars-core = { path = "../crates/polars-core", default-features = false, features = ["python"] } polars-error = { path = "../crates/polars-error" } polars-lazy = { path = "../crates/polars-lazy", default-features = false, features = ["python"] } +polars-ops = { path = "../crates/polars-ops", default-features = false, features = ["convert_index"] } +polars-plan = { path = "../crates/polars-plan", default-features = false } ahash = "0.8" ciborium = "0.2" @@ -19,9 +21,9 @@ either = "1.8" lexical-core = "0.8" libc = "0.2" ndarray = "0.15" -numpy = "0.19" +numpy = "0.20" once_cell = "1" -pyo3 = { version = "0.19", features = ["abi3-py38", "extension-module", "multiple-pymethods"] } +pyo3 = { version = "0.20", features = ["abi3-py38", "extension-module", "multiple-pymethods"] } pyo3-built = { version = "0.4", optional = true } serde_json = { version = "1", optional = true } smartstring = "1" @@ -50,7 +52,8 @@ features = [ "fmt", "horizontal_concat", "interpolate", - "is_first", + "is_first_distinct", + "is_last_distinct", "is_unique", "lazy", "list_eval", @@ -134,54 +137,78 @@ list_count = ["polars/list_count"] binary_encoding = ["polars/binary_encoding"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] +list_drop_nulls = ["polars/list_drop_nulls"] cutqcut = ["polars/cutqcut"] rle = ["polars/rle"] extract_groups = ["polars/extract_groups"] +ffi_plugin = ["polars-plan/ffi_plugin"] +cloud = ["polars/cloud", "polars/aws", "polars/gcp", "polars/azure"] +peaks = ["polars/peaks"] -all = [ +dtypes = [ "dtype-i8", "dtype-i16", "dtype-u8", "dtype-u16", - "json", - "parquet", - "ipc", - "ipc_streaming", - "avro", + "polars/group_by_list", + "object", +] + +operations = [ "is_in", "repeat_by", "trigonometry", "sign", + "performant", + "list_take", + "list_count", + "list_sets", + "list_any_all", + "list_drop_nulls", + "cutqcut", + "rle", + "extract_groups", + "pivot", + "extract_jsonpath", "asof_join", "cross_join", "pct_change", "search_sorted", "merge_sorted", + "top_k", + "propagate_nans", + "timezones", + "peaks", +] + +io = [ + "json", + "parquet", + "ipc", + "ipc_streaming", + "avro", + "csv", + "cloud", +] + +optimizations = [ + "cse", + "polars/fused", + "streaming", +] + +all = [ + "optimizations", + "io", + "operations", + "dtypes", "meta", "decompress", "lazy_regex", - "csv", - "extract_jsonpath", - "timezones", - "object", - "pivot", - "top_k", "build_info", - "cse", - "propagate_nans", - "polars/group_by_list", - "polars/fused", "sql", "binary_encoding", - "streaming", - "performant", - "list_take", - "list_count", - "list_sets", - "list_any_all", - "cutqcut", - "rle", - "extract_groups", + "ffi_plugin", ] # we cannot conditionally activate simd diff --git a/py-polars/Makefile 
b/py-polars/Makefile index f5fd7c4404cd..be3f1f8da5be 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -23,12 +23,15 @@ build: .venv ## Compile and install Polars for development @unset CONDA_PREFIX && source $(VENV_BIN)/activate && maturin develop .PHONY: build-debug-opt -build-debug-opt: .venv ## Compile and install Polars for development, with minimal optimizations turned on +build-debug-opt: .venv ## Compile and install Polars with minimal optimizations turned on @unset CONDA_PREFIX && source $(VENV_BIN)/activate && maturin develop --profile opt-dev +.PHONY: build-debug-opt-subset +build-debug-opt-subset: .venv ## Compile and install Polars with minimal optimizations turned on and no default features + @unset CONDA_PREFIX && source $(VENV_BIN)/activate && maturin develop --no-default-features --profile opt-dev + .PHONY: build-opt -build-opt: .venv ## Compile and install Polars for performance-sensitive development, with nearly \ - full optimization on and debug assertions turned off, but with debug symbols on +build-opt: .venv ## Compile and install Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on @unset CONDA_PREFIX && source $(VENV_BIN)/activate && maturin develop --profile debug-release .PHONY: build-release @@ -76,6 +79,7 @@ test: .venv build ## Run fast unittests .PHONY: doctest doctest: .venv build ## Run doctests $(VENV_BIN)/python tests/docs/run_doctest.py + $(VENV_BIN)/pytest tests/docs/test_user_guide.py -m docs .PHONY: test-all test-all: .venv build ## Run all tests @@ -104,4 +108,4 @@ clean: ## Clean up caches and build artifacts .PHONY: help help: ## Display this help screen @echo -e "\033[1mAvailable commands:\033[0m" - @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index 1ce2c7c19ba9..b755ae5dc02f 100644 --- a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -4,14 +4,14 @@ numpy pandas pyarrow -hypothesis==6.82.6 +hypothesis==6.87.1 sphinx==7.2.4 # Third-party Sphinx extensions autodocsumm==0.2.11 -numpydoc==1.5.0 -pydata-sphinx-theme==0.13.3 +numpydoc==1.6.0 +pydata-sphinx-theme==0.14.1 sphinx-autosummary-accessors==2023.4.0 sphinx-copybutton==0.5.2 sphinx-design==0.5.0 diff --git a/py-polars/docs/source/_static/css/custom.css b/py-polars/docs/source/_static/css/custom.css index 029703f6ec91..9cdd3b3591d8 100644 --- a/py-polars/docs/source/_static/css/custom.css +++ b/py-polars/docs/source/_static/css/custom.css @@ -43,3 +43,7 @@ div.bd-sidebar-secondary { label.sidebar-toggle.secondary-toggle { display: none !important; } + +a:visited { + color: var(--pst-color-link); +} diff --git a/py-polars/docs/source/_static/version_switcher.json b/py-polars/docs/source/_static/version_switcher.json index 9e59d14848fc..8f4d300d81de 100644 --- a/py-polars/docs/source/_static/version_switcher.json +++ b/py-polars/docs/source/_static/version_switcher.json @@ -7,7 +7,8 @@ { "name": "0.19 (stable)", "version": "0.19", - "url": "https://pola-rs.github.io/polars/py-polars/html/" + "url": "https://pola-rs.github.io/polars/py-polars/html/", + "preferred": true }, { "name": "0.18", diff --git a/py-polars/docs/source/conf.py b/py-polars/docs/source/conf.py index 
dc81c5fca524..643e7b4f07df 100644 --- a/py-polars/docs/source/conf.py +++ b/py-polars/docs/source/conf.py @@ -1,12 +1,15 @@ # Configuration file for the Sphinx documentation builder. # https://www.sphinx-doc.org/en/master/usage/configuration.html +from __future__ import annotations + import inspect import os import re import sys import warnings from pathlib import Path +from typing import Any import sphinx_autosummary_accessors @@ -103,7 +106,7 @@ "external_links": [ { "name": "User Guide", - "url": f"{web_root}/polars-book/user-guide/index.html", + "url": f"{web_root}/polars/user-guide/index.html", }, { "name": "Powered by Xomnia", @@ -135,6 +138,7 @@ "json_url": f"{web_root}/polars/docs/python/dev/_static/version_switcher.json", "version_match": switcher_version, }, + "show_version_warning_banner": False, "navbar_end": ["version-switcher", "navbar-icon-links"], "check_switcher": False, } @@ -157,7 +161,7 @@ # sphinx-ext-linkcode - Add external links to source code # https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html -def linkcode_resolve(domain, info): +def linkcode_resolve(domain: str, info: dict[str, Any]) -> str | None: """ Determine the URL corresponding to Python object. diff --git a/py-polars/docs/source/reference/config.rst b/py-polars/docs/source/reference/config.rst index 42c9ef3b4653..21266d911d6c 100644 --- a/py-polars/docs/source/reference/config.rst +++ b/py-polars/docs/source/reference/config.rst @@ -33,10 +33,21 @@ Config load, save, and current state :toctree: api/ Config.load + Config.load_from_file Config.save + Config.save_to_file Config.state Config.restore_defaults +While it is easy to restore *all* configuration options to their default +value using ``restore_defaults``, it can also be useful to reset *individual* +options. This can be done by setting the related value to ``None``, eg: + +.. code-block:: python + + pl.Config.set_tbl_rows(None) + + Use as a context manager ------------------------ diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst index 7feb84ffbc6c..c3d1f1b91c8c 100644 --- a/py-polars/docs/source/reference/dataframe/modify_select.rst +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -7,6 +7,7 @@ Manipulation/selection :toctree: api/ DataFrame.bottom_k + DataFrame.cast DataFrame.clear DataFrame.clone DataFrame.drop diff --git a/py-polars/docs/source/reference/expressions/boolean.rst b/py-polars/docs/source/reference/expressions/boolean.rst index 590758b75ef2..73c68917d515 100644 --- a/py-polars/docs/source/reference/expressions/boolean.rst +++ b/py-polars/docs/source/reference/expressions/boolean.rst @@ -12,11 +12,15 @@ Boolean Expr.is_duplicated Expr.is_finite Expr.is_first + Expr.is_first_distinct Expr.is_in Expr.is_infinite + Expr.is_last + Expr.is_last_distinct Expr.is_nan Expr.is_not Expr.is_not_nan Expr.is_not_null Expr.is_null Expr.is_unique + Expr.not_ diff --git a/py-polars/docs/source/reference/expressions/col.rst b/py-polars/docs/source/reference/expressions/col.rst new file mode 100644 index 000000000000..584d6461b953 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/col.rst @@ -0,0 +1,18 @@ +========== +polars.col +========== + +Create an expression representing column(s) in a dataframe. + +``col`` is technically not a function, but it can be used like one. + +See the class documentation below for examples and further documentation. + +----- + +.. currentmodule:: polars.functions.col +.. 
autoclass:: ColumnFactory + :members: __call__, __getattr__ + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/expressions/columns.rst b/py-polars/docs/source/reference/expressions/columns.rst index 8bb5c24d0bb7..c97e3cf13f25 100644 --- a/py-polars/docs/source/reference/expressions/columns.rst +++ b/py-polars/docs/source/reference/expressions/columns.rst @@ -6,10 +6,15 @@ Columns / names .. autosummary:: :toctree: api/ - col Expr.alias Expr.exclude Expr.keep_name Expr.map_alias Expr.prefix Expr.suffix + +.. toctree:: + :maxdepth: 2 + :hidden: + + col diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst index 2964212bfffa..298631c7dfdf 100644 --- a/py-polars/docs/source/reference/expressions/computation.rst +++ b/py-polars/docs/source/reference/expressions/computation.rst @@ -41,6 +41,8 @@ Computation Expr.n_unique Expr.null_count Expr.pct_change + Expr.peak_max + Expr.peak_min Expr.radians Expr.rank Expr.rolling_apply diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index 66ac9ff9b662..bc813547f63f 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -38,6 +38,8 @@ These functions are available from the polars module root and can be used as exp datetime date_range date_ranges + datetime_range + datetime_ranges duration element exclude diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index 2710b8d56c80..d56b44abcc30 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -11,11 +11,13 @@ The following methods are available under the `expr.list` attribute. Expr.list.all Expr.list.any + Expr.list.drop_nulls Expr.list.arg_max Expr.list.arg_min Expr.list.concat Expr.list.contains Expr.list.count_match + Expr.list.count_matches Expr.list.diff Expr.list.difference Expr.list.eval @@ -26,6 +28,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.intersection Expr.list.join Expr.list.last + Expr.list.len Expr.list.lengths Expr.list.max Expr.list.mean diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index f2e5ad2e9945..2412f2444499 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -12,6 +12,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.concat Expr.str.contains Expr.str.count_match + Expr.str.count_matches Expr.str.decode Expr.str.encode Expr.str.ends_with @@ -21,6 +22,8 @@ The following methods are available under the `expr.str` attribute. Expr.str.extract_groups Expr.str.json_extract Expr.str.json_path_match + Expr.str.len_bytes + Expr.str.len_chars Expr.str.lengths Expr.str.ljust Expr.str.lstrip @@ -35,6 +38,11 @@ The following methods are available under the `expr.str` attribute. 
Expr.str.splitn Expr.str.starts_with Expr.str.strip + Expr.str.strip_chars + Expr.str.strip_chars_start + Expr.str.strip_chars_end + Expr.str.strip_prefix + Expr.str.strip_suffix Expr.str.strptime Expr.str.to_date Expr.str.to_datetime diff --git a/py-polars/docs/source/reference/expressions/window.rst b/py-polars/docs/source/reference/expressions/window.rst index 24ff8f2dbf3f..7c6c045b9e27 100644 --- a/py-polars/docs/source/reference/expressions/window.rst +++ b/py-polars/docs/source/reference/expressions/window.rst @@ -7,3 +7,4 @@ Window :toctree: api/ Expr.over + Expr.rolling diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst index 83d72fcf0b89..1200c2d94d74 100644 --- a/py-polars/docs/source/reference/functions.rst +++ b/py-polars/docs/source/reference/functions.rst @@ -31,6 +31,7 @@ Parallelization :toctree: api/ collect_all + collect_all_async Random ~~~~~~ @@ -50,4 +51,5 @@ and a decorator, in order to explicitly scope cache lifetime. StringCache enable_string_cache + disable_string_cache using_string_cache diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index a5bf2f91916a..9b0b91335c09 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ -66,14 +66,22 @@ AVRO read_avro DataFrame.write_avro -Excel -~~~~~ +Spreadsheet +~~~~~~~~~~~ .. autosummary:: :toctree: api/ read_excel + read_ods DataFrame.write_excel +Apache Iceberg +~~~~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + scan_iceberg + Delta Lake ~~~~~~~~~~ .. autosummary:: diff --git a/py-polars/docs/source/reference/lazyframe/miscellaneous.rst b/py-polars/docs/source/reference/lazyframe/miscellaneous.rst index d677b34a43b1..7dba1aa7d61d 100644 --- a/py-polars/docs/source/reference/lazyframe/miscellaneous.rst +++ b/py-polars/docs/source/reference/lazyframe/miscellaneous.rst @@ -8,6 +8,7 @@ Miscellaneous LazyFrame.cache LazyFrame.collect + LazyFrame.collect_async LazyFrame.fetch LazyFrame.lazy LazyFrame.map diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst index 5fa34dea8ad0..1a1482ec4623 100644 --- a/py-polars/docs/source/reference/lazyframe/modify_select.rst +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -9,6 +9,7 @@ Manipulation/selection LazyFrame.approx_n_unique LazyFrame.approx_unique LazyFrame.bottom_k + LazyFrame.cast LazyFrame.clear LazyFrame.clone LazyFrame.drop diff --git a/py-polars/docs/source/reference/series/boolean.rst b/py-polars/docs/source/reference/series/boolean.rst index 21dd38cdaeba..33da829f1e77 100644 --- a/py-polars/docs/source/reference/series/boolean.rst +++ b/py-polars/docs/source/reference/series/boolean.rst @@ -8,3 +8,4 @@ Boolean Series.all Series.any + Series.not_ diff --git a/py-polars/docs/source/reference/series/descriptive.rst b/py-polars/docs/source/reference/series/descriptive.rst index 6f1558ec363d..6ec39e326b9f 100644 --- a/py-polars/docs/source/reference/series/descriptive.rst +++ b/py-polars/docs/source/reference/series/descriptive.rst @@ -15,10 +15,13 @@ Descriptive Series.is_empty Series.is_finite Series.is_first + Series.is_first_distinct Series.is_float Series.is_in Series.is_infinite Series.is_integer + Series.is_last + Series.is_last_distinct Series.is_nan Series.is_not_nan Series.is_not_null diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index 
46942ab076b9..7f3b709e80db 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -11,11 +11,13 @@ The following methods are available under the `Series.list` attribute. Series.list.all Series.list.any + Series.list.drop_nulls Series.list.arg_max Series.list.arg_min Series.list.concat Series.list.contains Series.list.count_match + Series.list.count_matches Series.list.diff Series.list.difference Series.list.eval @@ -26,6 +28,7 @@ The following methods are available under the `Series.list` attribute. Series.list.join Series.list.intersection Series.list.last + Series.list.len Series.list.lengths Series.list.max Series.list.mean diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index 59489085588c..5fa2efc643a8 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -12,6 +12,7 @@ The following methods are available under the `Series.str` attribute. Series.str.concat Series.str.contains Series.str.count_match + Series.str.count_matches Series.str.decode Series.str.encode Series.str.ends_with @@ -21,6 +22,8 @@ The following methods are available under the `Series.str` attribute. Series.str.extract_groups Series.str.json_extract Series.str.json_path_match + Series.str.len_bytes + Series.str.len_chars Series.str.lengths Series.str.ljust Series.str.lstrip @@ -35,6 +38,11 @@ The following methods are available under the `Series.str` attribute. Series.str.splitn Series.str.starts_with Series.str.strip + Series.str.strip_chars + Series.str.strip_chars_start + Series.str.strip_chars_end + Series.str.strip_prefix + Series.str.strip_suffix Series.str.strptime Series.str.to_date Series.str.to_datetime diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 825823c6e43f..9abe028cbe25 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -67,6 +67,7 @@ DuplicateError, InvalidOperationError, NoDataError, + OutOfBoundsError, PolarsPanicError, SchemaError, SchemaFieldNotFoundError, @@ -106,6 +107,8 @@ date_range, date_ranges, datetime, + datetime_range, + datetime_ranges, duration, element, exclude, @@ -165,10 +168,12 @@ read_ipc_stream, read_json, read_ndjson, + read_ods, read_parquet, read_parquet_schema, scan_csv, scan_delta, + scan_iceberg, scan_ipc, scan_ndjson, scan_parquet, @@ -177,7 +182,12 @@ from polars.lazyframe import LazyFrame from polars.series import Series from polars.sql import SQLContext -from polars.string_cache import StringCache, enable_string_cache, using_string_cache +from polars.string_cache import ( + StringCache, + disable_string_cache, + enable_string_cache, + using_string_cache, +) from polars.type_aliases import PolarsDataType from polars.utils import build_info, get_index_type, show_versions, threadpool_size @@ -199,6 +209,7 @@ "DuplicateError", "InvalidOperationError", "NoDataError", + "OutOfBoundsError", "PolarsPanicError", "SchemaError", "SchemaFieldNotFoundError", @@ -259,16 +270,19 @@ "read_ipc_stream", "read_json", "read_ndjson", + "read_ods", "read_parquet", "read_parquet_schema", "scan_csv", "scan_delta", + "scan_iceberg", "scan_ipc", "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", # polars.stringcache "StringCache", + "disable_string_cache", "enable_string_cache", "using_string_cache", # polars.config @@ -281,6 +295,8 @@ "concat", "date_range", "date_ranges", + "datetime_range", + "datetime_ranges", "element", "ones", 
"repeat", diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 6ada4300c22d..2ca55f2bd123 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -2,11 +2,13 @@ import contextlib import os +import sys from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, get_args from polars.dependencies import json -from polars.utils.various import normalise_filepath +from polars.utils.deprecation import deprecate_nonkeyword_arguments +from polars.utils.various import normalize_filepath # dummy func required (so docs build) @@ -19,12 +21,35 @@ def _get_float_fmt() -> str: # pragma: no cover from polars.polars import get_float_fmt as _get_float_fmt # type: ignore[no-redef] from polars.polars import set_float_fmt as _set_float_fmt + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + pass + if TYPE_CHECKING: from types import TracebackType - from typing import Literal + + from typing_extensions import TypeAlias from polars.type_aliases import FloatFmt +TableFormatNames: TypeAlias = Literal[ + "ASCII_FULL", + "ASCII_FULL_CONDENSED", + "ASCII_NO_BORDERS", + "ASCII_BORDERS_ONLY", + "ASCII_BORDERS_ONLY_CONDENSED", + "ASCII_HORIZONTAL_ONLY", + "ASCII_MARKDOWN", + "UTF8_FULL", + "UTF8_FULL_CONDENSED", + "UTF8_NO_BORDERS", + "UTF8_BORDERS_ONLY", + "UTF8_HORIZONTAL_ONLY", + "NOTHING", +] + # note: register all Config-specific environment variable names here; need to constrain # which 'POLARS_' environment variables are recognised, as there are other lower-level @@ -93,26 +118,35 @@ def __init__(self, *, restore_defaults: bool = False, **options: Any) -> None: """ Initialise a Config object instance for context manager usage. - Any `options` kwargs should correspond to the available named "set_" - methods, but can optionally to omit the "set_" prefix for brevity. + Any ``options`` kwargs should correspond to the available named "set_*" + methods, but are allowed to omit the "set_" prefix for brevity. Parameters ---------- restore_defaults set all options to their default values (this is applied before setting any other options). - options + **options keyword args that will set the option; equivalent to calling the named "set_