diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 95545adb4c4e..a99a2770491c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,6 +4,7 @@ /crates/polars-sql/ @ritchie46 @orlp @c-peters @alexander-beedie /crates/polars-parquet/ @ritchie46 @orlp @c-peters @coastalwhite /crates/polars-time/ @ritchie46 @orlp @c-peters @MarcoGorelli +/crates/polars-python/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa +/crates/polars-python/src/lazyframe/visit.rs @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- +/crates/polars-python/src/lazyframe/visitor/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- /py-polars/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa -/py-polars/src/lazyframe/visit.rs @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- -/py-polars/src/lazyframe/visitor/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- diff --git a/.github/workflows/benchmark-remote.yml b/.github/workflows/benchmark-remote.yml index f21898f9689f..22a1a1c12a61 100644 --- a/.github/workflows/benchmark-remote.yml +++ b/.github/workflows/benchmark-remote.yml @@ -2,19 +2,22 @@ name: Remote Benchmark on: workflow_dispatch: + push: + branches: + - 'main' pull_request: types: [ labeled ] concurrency: group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + cancel-in-progress: ${{ github.event.label.name == 'needs-bench' }} env: SCALE_FACTOR: '10.0' jobs: main: - if: ${{ github.event.label.name == 'needs-bench' }} + if: ${{ github.ref == 'refs/heads/main' || github.event.label.name == 'needs-bench' }} runs-on: self-hosted steps: - uses: actions/checkout@v4 @@ -62,4 +65,11 @@ jobs: - name: Run benchmark working-directory: polars-benchmark run: | - make run-polars-no-env + make run-polars-no-env | tee ../py-polars/benchmark-results + + - name: Cache the Polars build + if: ${{ github.ref == 'refs/heads/main' }} + working-directory: py-polars + run: | + "$HOME/py-polars-cache/save_benchmark_data.py" "$PWD/polars" < ./benchmark-results + "$HOME/py-polars-cache/cache-build.sh" "$PWD/polars" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4da3d3bad8e3..afa9219231a3 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -71,7 +71,7 @@ jobs: env: RUSTFLAGS: -C embed-bitcode -D warnings working-directory: py-polars - run: maturin develop --features new_streaming --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + run: maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native - name: Run benchmark tests uses: CodSpeedHQ/action@v3 diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index c44242a2a374..4981710d4773 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -137,7 +137,6 @@ jobs: if: matrix.architecture == 'x86-64' env: IS_LTS_CPU: ${{ matrix.package == 'polars-lts-cpu' }} - IS_MACOS: ${{ matrix.os == 'macos-13' }} # IMPORTANT: All features enabled here should also be included in py-polars/polars/_cpu_check.py run: | if [[ "$IS_LTS_CPU" = true ]]; then @@ -175,7 +174,7 @@ jobs: - name: Set variables in CPU check module - LTS_CPU if: matrix.package == 'polars-lts-cpu' run: | - sed $SED_INPLACE 's/^_LTS_CPU = False$/_LTS_CPU = True/g' $CPU_CHECK_MODULE + sed $SED_INPLACE 's/^_POLARS_LTS_CPU = False$/_POLARS_LTS_CPU = True/g' $CPU_CHECK_MODULE - name: Set Rust target for aarch64 if: matrix.architecture == 'aarch64' @@ -195,7 +194,7 @@ jobs: command: build target: ${{ steps.target.outputs.target }} args: > - --release + --profile dist-release --manifest-path py-polars/Cargo.toml --out dist manylinux: ${{ matrix.architecture == 'aarch64' && '2_24' || 'auto' }} diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 6185b905f7c7..a117f5b7fe96 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -82,7 +82,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Install Polars - run: maturin develop --features new_streaming + run: maturin develop - name: Run doctests if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' diff --git a/Cargo.lock b/Cargo.lock index b889fde7503f..967c4f3562cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 4 [[package]] name = "addr2line" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -30,7 +30,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "const-random", "getrandom", "once_cell", "version_check", @@ -90,15 +89,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "apache-avro" @@ -168,51 +167,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "arrow-array" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd2bf348cf9f02a5975c5962c7fa6dee107a2009a7b41ac5fb1a027e12dc033f" -dependencies = [ - "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "hashbrown 0.14.5", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3092e37715f168976012ce52273c3989b5793b0db5f06cbaa246be25e5f0924d" -dependencies = [ - "bytes", - "half", - "num", -] - -[[package]] -name = "arrow-data" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4ac0c4ee79150afe067dc4857154b3ee9c1cd52b5f40d59a77306d0ed18d65" -dependencies = [ - "arrow-buffer", - "arrow-schema", - "half", - "num", -] - -[[package]] -name = "arrow-schema" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85320a3a2facf2b2822b57aa9d6d9d55edb8aee0b6b5d3b8df158e503d10858" - [[package]] name = "arrow2" version = "0.17.4" @@ -252,7 +206,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -263,7 +217,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -311,9 +265,9 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.5.7" +version = "1.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8191fb3091fa0561d1379ef80333c3c7191c6f0435d986e85821bcf7acbd1126" +checksum = "2d6448cfb224dd6a9b9ac734f58622dd0d4751f3589f3b777345745f46b2eb14" dependencies = [ "aws-credential-types", "aws-runtime", @@ -414,9 +368,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.44.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b90cfe6504115e13c41d3ea90286ede5aa14da294f3fe077027a6e83850843c" +checksum = "a8776850becacbd3a82a4737a9375ddb5c6832a51379f24443a98e61513f852c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -436,9 +390,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.45.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "167c0fad1f212952084137308359e8e4c4724d1c643038ce163f06de9662c1d0" +checksum = "0007b5b8004547133319b6c4e87193eee2a0bcb3e4c18c75d09febe9dab7b383" dependencies = [ "aws-credential-types", "aws-runtime", @@ -458,9 +412,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.44.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb5f98188ec1435b68097daa2a37d74b9d17c9caa799466338a8d1544e71b9d" +checksum = "9fffaa356e7f1c725908b75136d53207fa714e348f365671df14e95a60530ad3" dependencies = [ "aws-credential-types", "aws-runtime", @@ -481,9 +435,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -593,9 +547,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -608,7 +562,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -637,9 +591,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" dependencies = [ "base64-simd", "bytes", @@ -735,9 +689,9 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bigdecimal" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" +checksum = "8f850665a0385e070b64c38d2354e6c104c8479c59868d1e48a0c13ee2c7a1c1" dependencies = [ "autocfg", "libm", @@ -825,22 +779,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -851,9 +805,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "bytes-utils" @@ -891,9 +845,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.24" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812acba72f0a070b003d3697490d2b55b837230ae7c6c6497f05cc2ddbb8d938" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", @@ -970,18 +924,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstyle", "clap_lex", @@ -1044,26 +998,6 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" -[[package]] -name = "const-random" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" -dependencies = [ - "const-random-macro", -] - -[[package]] -name = "const-random-macro" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom", - "once_cell", - "tiny-keccak", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -1278,9 +1212,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "der" @@ -1374,7 +1308,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -1456,9 +1390,9 @@ dependencies = [ [[package]] name = "float-cmp" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" dependencies = [ "num-traits", ] @@ -1502,9 +1436,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1517,9 +1451,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1527,15 +1461,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1544,38 +1478,38 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1624,9 +1558,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1691,7 +1625,6 @@ checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ "cfg-if", "crunchy", - "num-traits", ] [[package]] @@ -1859,9 +1792,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -1883,9 +1816,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -1909,7 +1842,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -1925,9 +1858,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", - "rustls 0.23.13", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -1946,7 +1879,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", @@ -2012,9 +1945,9 @@ checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" [[package]] name = "ipnet" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is-terminal" @@ -2088,9 +2021,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2114,9 +2047,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" @@ -2174,9 +2107,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "a00419de735aac21d53b0de5ce2c03bd3627277cf471300f27ebc89f7d828047" [[package]] name = "libmimalloc-sys" @@ -2190,9 +2123,9 @@ dependencies = [ [[package]] name = "libz-ng-sys" -version = "1.1.16" +version = "1.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4436751a01da56f1277f323c80d584ffad94a3d14aecd959dd0dff75aa73a438" +checksum = "8f0f7295a34685977acb2e8cc8b08ee4a8dffd6cf278eeccddbe1ed55ba815d5" dependencies = [ "cmake", "libc", @@ -2222,11 +2155,11 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lru" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.0", ] [[package]] @@ -2390,20 +2323,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - [[package]] name = "num-bigint" version = "0.4.6" @@ -2439,28 +2358,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -2586,9 +2483,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.4" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] @@ -2605,7 +2502,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.0", "itertools 0.13.0", "md-5", "parking_lot", @@ -2626,12 +2523,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.1" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" -dependencies = [ - "portable-atomic", -] +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -2685,16 +2579,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "parquet-format-safe" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" -dependencies = [ - "async-trait", - "futures", -] - [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -2750,9 +2634,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2815,11 +2699,10 @@ dependencies = [ [[package]] name = "polars" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "apache-avro", - "arrow-buffer", "avro-schema", "either", "ethnum", @@ -2845,13 +2728,9 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", "async-stream", "atoi", "atoi_simd", @@ -2894,6 +2773,7 @@ dependencies = [ "simdutf8", "streaming-iterator", "strength_reduce", + "strum_macros", "tokio", "tokio-util", "version_check", @@ -2914,7 +2794,7 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.43.1" +version = "0.44.1" dependencies = [ "bytemuck", "either", @@ -2929,10 +2809,9 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", - "arrow-array", "bincode", "bitflags", "bytemuck", @@ -2958,6 +2837,7 @@ dependencies = [ "regex", "serde", "serde_json", + "strum_macros", "thiserror", "version_check", "xxhash-rust", @@ -2965,7 +2845,7 @@ dependencies = [ [[package]] name = "polars-doc-examples" -version = "0.43.1" +version = "0.44.1" dependencies = [ "aws-config", "aws-sdk-s3", @@ -2979,7 +2859,7 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.43.1" +version = "0.44.1" dependencies = [ "avro-schema", "object_store", @@ -2991,10 +2871,12 @@ dependencies = [ [[package]] name = "polars-expr" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", + "hashbrown 0.15.0", + "num-traits", "once_cell", "polars-arrow", "polars-compute", @@ -3003,14 +2885,16 @@ dependencies = [ "polars-json", "polars-ops", "polars-plan", + "polars-row", "polars-time", "polars-utils", + "rand", "rayon", ] [[package]] name = "polars-ffi" -version = "0.43.1" +version = "0.44.1" dependencies = [ "polars-arrow", "polars-core", @@ -3018,7 +2902,7 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "async-trait", @@ -3049,6 +2933,7 @@ dependencies = [ "polars-schema", "polars-time", "polars-utils", + "pyo3", "rayon", "regex", "reqwest", @@ -3066,7 +2951,7 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "chrono", @@ -3086,7 +2971,7 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", @@ -3114,7 +2999,7 @@ dependencies = [ [[package]] name = "polars-mem-engine" -version = "0.43.1" +version = "0.44.1" dependencies = [ "futures", "memmap2", @@ -3135,7 +3020,7 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "aho-corasick", @@ -3162,15 +3047,17 @@ dependencies = [ "rand_distr", "rayon", "regex", + "regex-syntax 0.8.5", "serde", "serde_json", + "strum_macros", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "async-stream", @@ -3185,10 +3072,10 @@ dependencies = [ "lz4", "lz4_flex", "num-traits", - "parquet-format-safe", "polars-arrow", "polars-compute", "polars-error", + "polars-parquet-format", "polars-utils", "rand", "serde", @@ -3199,9 +3086,19 @@ dependencies = [ "zstd", ] +[[package]] +name = "polars-parquet-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" +dependencies = [ + "async-trait", + "futures", +] + [[package]] name = "polars-pipe" -version = "0.43.1" +version = "0.44.1" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -3226,7 +3123,7 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", @@ -3240,6 +3137,7 @@ dependencies = [ "hashbrown 0.15.0", "libloading", "memmap2", + "num-traits", "once_cell", "percent-encoding", "polars-arrow", @@ -3263,7 +3161,7 @@ dependencies = [ [[package]] name = "polars-python" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "arboard", @@ -3280,8 +3178,10 @@ dependencies = [ "polars", "polars-core", "polars-error", + "polars-expr", "polars-io", "polars-lazy", + "polars-mem-engine", "polars-ops", "polars-parquet", "polars-plan", @@ -3297,7 +3197,7 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.43.1" +version = "0.44.1" dependencies = [ "bytemuck", "polars-arrow", @@ -3307,7 +3207,7 @@ dependencies = [ [[package]] name = "polars-schema" -version = "0.43.1" +version = "0.44.1" dependencies = [ "indexmap", "polars-error", @@ -3318,7 +3218,7 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.43.1" +version = "0.44.1" dependencies = [ "hex", "once_cell", @@ -3338,7 +3238,7 @@ dependencies = [ [[package]] name = "polars-stream" -version = "0.43.1" +version = "0.44.1" dependencies = [ "atomic-waker", "crossbeam-deque", @@ -3365,7 +3265,7 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.43.1" +version = "0.44.1" dependencies = [ "atoi", "bytemuck", @@ -3380,11 +3280,12 @@ dependencies = [ "polars-utils", "regex", "serde", + "strum_macros", ] [[package]] name = "polars-utils" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bytemuck", @@ -3397,6 +3298,7 @@ dependencies = [ "num-traits", "once_cell", "polars-error", + "pyo3", "rand", "raw-cpuid", "rayon", @@ -3429,9 +3331,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -3486,7 +3388,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "1.9.0" +version = "1.12.0" dependencies = [ "jemallocator", "libc", @@ -3545,7 +3447,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -3558,7 +3460,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -3599,7 +3501,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.15", "socket2", "thiserror", "tokio", @@ -3616,7 +3518,7 @@ dependencies = [ "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.15", "slab", "thiserror", "tinyvec", @@ -3756,7 +3658,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -3785,14 +3687,14 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -3844,7 +3746,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3855,7 +3757,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.13", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3936,9 +3838,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags", "errno", @@ -3961,9 +3863,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "once_cell", "ring", @@ -4018,9 +3920,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -4045,9 +3947,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -4112,9 +4014,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ "windows-sys 0.59.0", ] @@ -4180,9 +4082,9 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] @@ -4198,20 +4100,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "indexmap", "itoa", @@ -4281,9 +4183,9 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f0b376aada35f30a0012f5790e50aed62f91804a0682669aefdbe81c7fcb91" +checksum = "b1df0290e9bfe79ddd5ff8798ca887cd107b75353d2957efe9777296e17f26b5" dependencies = [ "ahash", "getrandom", @@ -4451,7 +4353,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4473,9 +4375,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -4531,22 +4433,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4579,15 +4481,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -4615,9 +4508,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", @@ -4638,7 +4531,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4657,7 +4550,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.13", + "rustls 0.23.15", "rustls-pki-types", "tokio", ] @@ -4701,7 +4594,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4746,7 +4639,7 @@ checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4834,9 +4727,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -4844,9 +4737,9 @@ dependencies = [ [[package]] name = "value-trait" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcaa56177466248ba59d693a048c0959ddb67f1151b963f904306312548cf392" +checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" dependencies = [ "float-cmp", "halfbrown", @@ -4893,9 +4786,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4904,24 +4797,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4931,9 +4824,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4941,28 +4834,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -4973,9 +4866,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -5051,7 +4944,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -5062,7 +4955,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -5290,7 +5183,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 39204515e5af..a57add1faf0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,18 +3,14 @@ resolver = "2" members = [ "crates/*", "docs/source/src/rust", - # "examples/*", "py-polars", ] default-members = [ "crates/*", ] -# exclude = [ -# "examples/datasets", -# ] [workspace.package] -version = "0.43.1" +version = "0.44.1" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -25,10 +21,6 @@ repository = "https://github.com/pola-rs/polars" ahash = ">=0.8.5" aho-corasick = "1.1" arboard = { version = "3.4.0", default-features = false } -arrow-array = { version = ">=41", default-features = false } -arrow-buffer = { version = ">=41", default-features = false } -arrow-data = { version = ">=41", default-features = false } -arrow-schema = { version = ">=41", default-features = false } atoi = "2" atoi_simd = "0.15.5" atomic-waker = "1" @@ -76,6 +68,7 @@ raw-cpuid = "11" rayon = "1.9" recursive = "0.1" regex = "1.9" +regex-syntax = "0.8.5" reqwest = { version = "0.12", default-features = false } ryu = "1.0.13" serde = { version = "1.0.188", features = ["derive", "rc"] } @@ -98,27 +91,27 @@ version_check = "0.9.4" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" -polars = { version = "0.43.1", path = "crates/polars", default-features = false } -polars-compute = { version = "0.43.1", path = "crates/polars-compute", default-features = false } -polars-core = { version = "0.43.1", path = "crates/polars-core", default-features = false } -polars-error = { version = "0.43.1", path = "crates/polars-error", default-features = false } -polars-expr = { version = "0.43.1", path = "crates/polars-expr", default-features = false } -polars-ffi = { version = "0.43.1", path = "crates/polars-ffi", default-features = false } -polars-io = { version = "0.43.1", path = "crates/polars-io", default-features = false } -polars-json = { version = "0.43.1", path = "crates/polars-json", default-features = false } -polars-lazy = { version = "0.43.1", path = "crates/polars-lazy", default-features = false } -polars-mem-engine = { version = "0.43.1", path = "crates/polars-mem-engine", default-features = false } -polars-ops = { version = "0.43.1", path = "crates/polars-ops", default-features = false } -polars-parquet = { version = "0.43.1", path = "crates/polars-parquet", default-features = false } -polars-pipe = { version = "0.43.1", path = "crates/polars-pipe", default-features = false } -polars-plan = { version = "0.43.1", path = "crates/polars-plan", default-features = false } -polars-python = { version = "0.43.1", path = "crates/polars-python", default-features = false } -polars-row = { version = "0.43.1", path = "crates/polars-row", default-features = false } -polars-schema = { version = "0.43.1", path = "crates/polars-schema", default-features = false } -polars-sql = { version = "0.43.1", path = "crates/polars-sql", default-features = false } -polars-stream = { version = "0.43.1", path = "crates/polars-stream", default-features = false } -polars-time = { version = "0.43.1", path = "crates/polars-time", default-features = false } -polars-utils = { version = "0.43.1", path = "crates/polars-utils", default-features = false } +polars = { version = "0.44.1", path = "crates/polars", default-features = false } +polars-compute = { version = "0.44.1", path = "crates/polars-compute", default-features = false } +polars-core = { version = "0.44.1", path = "crates/polars-core", default-features = false } +polars-error = { version = "0.44.1", path = "crates/polars-error", default-features = false } +polars-expr = { version = "0.44.1", path = "crates/polars-expr", default-features = false } +polars-ffi = { version = "0.44.1", path = "crates/polars-ffi", default-features = false } +polars-io = { version = "0.44.1", path = "crates/polars-io", default-features = false } +polars-json = { version = "0.44.1", path = "crates/polars-json", default-features = false } +polars-lazy = { version = "0.44.1", path = "crates/polars-lazy", default-features = false } +polars-mem-engine = { version = "0.44.1", path = "crates/polars-mem-engine", default-features = false } +polars-ops = { version = "0.44.1", path = "crates/polars-ops", default-features = false } +polars-parquet = { version = "0.44.1", path = "crates/polars-parquet", default-features = false } +polars-pipe = { version = "0.44.1", path = "crates/polars-pipe", default-features = false } +polars-plan = { version = "0.44.1", path = "crates/polars-plan", default-features = false } +polars-python = { version = "0.44.1", path = "crates/polars-python", default-features = false } +polars-row = { version = "0.44.1", path = "crates/polars-row", default-features = false } +polars-schema = { version = "0.44.1", path = "crates/polars-schema", default-features = false } +polars-sql = { version = "0.44.1", path = "crates/polars-sql", default-features = false } +polars-stream = { version = "0.44.1", path = "crates/polars-stream", default-features = false } +polars-time = { version = "0.44.1", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.44.1", path = "crates/polars-utils", default-features = false } [workspace.dependencies.arrow-format] package = "polars-arrow-format" @@ -126,7 +119,7 @@ version = "0.1.0" [workspace.dependencies.arrow] package = "polars-arrow" -version = "0.43.1" +version = "0.44.1" path = "crates/polars-arrow" default-features = false features = [ @@ -143,17 +136,24 @@ features = [ # packed_simd_2 = { git = "https://github.com/rust-lang/packed_simd", rev = "e57c7ba11386147e6d2cbad7c88f376aab4bdc86" } # simd-json = { git = "https://github.com/ritchie46/simd-json", branch = "alignment" } -[profile.opt-dev] +[profile.mindebug-dev] inherits = "dev" -opt-level = 1 +debug = "line-tables-only" + +[profile.release] +lto = "thin" +debug = "line-tables-only" + +[profile.nodebug-release] +inherits = "release" +debug = false [profile.debug-release] inherits = "release" debug = true -incremental = true -codegen-units = 16 -lto = "thin" -[profile.release] +[profile.dist-release] +inherits = "release" codegen-units = 1 +debug = false lto = "fat" diff --git a/Makefile b/Makefile index 5dd746aa5b7a..534e14076b73 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,56 @@ else VENV_BIN=$(VENV)/bin endif +# Detect CPU architecture. +ifeq ($(OS),Windows_NT) + ifeq ($(PROCESSOR_ARCHITECTURE),AMD64) + ARCH := amd64 + else ifeq ($(PROCESSOR_ARCHITECTURE),x86) + ARCH := x86 + else ifeq ($(PROCESSOR_ARCHITECTURE),ARM64) + ARCH := arm64 + else + ARCH := unknown + endif +else + UNAME_P := $(shell uname -p) + ifeq ($(UNAME_P),x86_64) + ARCH := amd64 + else ifneq ($(filter %86,$(UNAME_P)),) + ARCH := x86 + else ifneq ($(filter arm%,$(UNAME_P)),) + ARCH := arm64 + else + ARCH := unknown + endif +endif + +# Ensure boolean arguments are normalized to 1/0 to prevent surprises. +ifdef LTS_CPU + ifeq ($(LTS_CPU),0) + else ifeq ($(LTS_CPU),1) + else +$(error LTS_CPU must be 0 or 1 (or undefined, default to 0)) + endif +endif + +# Define RUSTFLAGS and CFLAGS appropriate for the architecture. +# Keep synchronized with .github/workflows/release-python.yml. +ifeq ($(ARCH),amd64) + ifeq ($(LTS_CPU),1) + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 + else + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq,+movbe -Z tune-cpu=skylake + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -mavx -mavx2 -mfma -mbmi -mbmi2 -mlzcnt -mpclmul -mmovbe -mtune=skylake + endif +endif + +override RUSTFLAGS+=$(FEAT_RUSTFLAGS) +override CFLAGS+=$(FEAT_CFLAGS) +export RUSTFLAGS +export CFLAGS + # Define command to filter pip warnings when running maturin FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS[0]} -eq 0 @@ -35,55 +85,37 @@ requirements-all: .venv ## Install/refresh all Python requirements (including t .PHONY: build build: .venv ## Compile and install Python Polars for development @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-debug-opt -build-debug-opt: .venv ## Compile and install Python Polars with minimal optimizations turned on +.PHONY: build-mindebug +build-mindebug: .venv ## Same as build, but don't include full debug information @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile mindebug-dev $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-debug-opt-subset -build-debug-opt-subset: .venv ## Compile and install Python Polars with minimal optimizations turned on and no default features +.PHONY: build-release +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --no-default-features --profile opt-dev \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-opt -build-opt: .venv ## Compile and install Python Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile nodebug-release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-release -build-release: .venv ## Compile and install a faster Python Polars binary with full optimizations +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-native -build-native: .venv ## Same as build, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-debug-opt-native -build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-opt-native -build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-release-native -build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile dist-release $(ARGS) \ $(FILTER_PIP_WARNINGS) .PHONY: check @@ -121,3 +153,6 @@ clean: ## Clean up caches, build artifacts, and the venv help: ## Display this help screen @echo -e "\033[1mAvailable commands:\033[0m" @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort + @echo + @echo The build commands support LTS_CPU=1 for building for older CPUs, and ARGS which is passed through to maturin. + @echo 'For example to build without default features use: make build ARGS="--no-default-features".' diff --git a/README.md b/README.md index d1885794ee55..43ac43596813 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@

- Polars logo -
+ + Polars logo +

@@ -232,13 +233,14 @@ This can be done by going through the following steps in sequence: 1. Install the latest [Rust compiler](https://www.rust-lang.org/tools/install) 2. Install [maturin](https://maturin.rs/): `pip install maturin` 3. `cd py-polars` and choose one of the following: - - `make build-release`, fastest binary, very long compile times - - `make build-opt`, fast binary with debug symbols, long compile times - - `make build-debug-opt`, medium-speed binary with debug assertions and symbols, medium compile times - `make build`, slow binary with debug assertions and symbols, fast compile times + - `make build-release`, fast binary without debug assertions, minimal debug symbols, long compile times + - `make build-nodebug-release`, same as build-release but without any debug symbols, slightly faster to compile + - `make build-debug-release`, same as build-release but with full debug symbols, slightly slower to compile + - `make build-dist-release`, fastest binary, extreme compile times - Append `-native` (e.g. `make build-release-native`) to enable further optimizations specific to - your CPU. This produces a non-portable binary/wheel however. +By default the binary is compiled with optimizations turned on for a modern CPU. Specify `LTS_CPU=1` +with the command if your CPU is older and does not support e.g. AVX2. Note that the Rust crate implementing the Python bindings is called `py-polars` to distinguish from the wrapped Rust crate `polars` itself. However, both the Python package and the Python module are named `polars`, so you diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 5e7e1eebff0a..5ce6d5deb9e7 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -14,7 +14,7 @@ description = "Minimal implementation of the Arrow specification forked from arr [dependencies] atoi = { workspace = true, optional = true } -bytemuck = { workspace = true } +bytemuck = { workspace = true, features = ["must_cast"] } chrono = { workspace = true } # for timezone support chrono-tz = { workspace = true, optional = true } @@ -55,9 +55,6 @@ zstd = { workspace = true, optional = true } # to write to parquet as a stream futures = { workspace = true, optional = true } -# to read IPC as a stream -async-stream = { version = "0.3.2", optional = true } - # avro support avro-schema = { workspace = true, optional = true } @@ -70,11 +67,11 @@ multiversion = { workspace = true, optional = true } # Faster hashing ahash = { workspace = true } -# Support conversion to/from arrow-rs -arrow-array = { workspace = true, optional = true } -arrow-buffer = { workspace = true, optional = true } -arrow-data = { workspace = true, optional = true } -arrow-schema = { workspace = true, optional = true } +# For async arrow flight conversion +async-stream = { version = "0.3", optional = true } +tokio = { workspace = true, optional = true, features = ["io-util"] } + +strum_macros = { workspace = true } [dev-dependencies] criterion = "0.5" @@ -102,11 +99,8 @@ getrandom = { version = "0.2", features = ["js"] } [features] default = [] full = [ - "arrow_rs", "io_ipc", "io_flight", - "io_ipc_write_async", - "io_ipc_read_async", "io_ipc_compression", "io_avro", "io_avro_compression", @@ -117,12 +111,9 @@ full = [ # parses timezones used in timestamp conversions "chrono-tz", ] -arrow_rs = ["arrow-buffer", "arrow-schema", "arrow-data", "arrow-array"] io_ipc = ["arrow-format", "polars-error/arrow-format"] -io_ipc_write_async = ["io_ipc", "futures"] -io_ipc_read_async = ["io_ipc", "futures", "async-stream"] io_ipc_compression = ["lz4", "zstd", "io_ipc"] -io_flight = ["io_ipc", "arrow-format/flight-data"] +io_flight = ["io_ipc", "arrow-format/flight-data", "async-stream", "futures", "tokio"] io_avro = ["avro-schema", "polars-error/avro-schema"] io_avro_compression = [ @@ -163,7 +154,7 @@ timezones = [ ] dtype-array = [] dtype-decimal = ["atoi", "itoap"] -bigidx = [] +bigidx = ["polars-utils/bigidx"] nightly = [] performant = [] strings = [] diff --git a/crates/polars-arrow/src/array/binary/data.rs b/crates/polars-arrow/src/array/binary/data.rs deleted file mode 100644 index 2c08d94eb1b0..000000000000 --- a/crates/polars-arrow/src/array/binary/data.rs +++ /dev/null @@ -1,43 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, BinaryArray}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for BinaryArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.offsets().len_proxy()) - .buffers(vec![ - self.offsets.clone().into_inner().into(), - self.values.clone().into(), - ]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let buffers = data.buffers(); - - // SAFETY: ArrayData is valid - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: buffers[1].clone().into(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/binary/mod.rs b/crates/polars-arrow/src/array/binary/mod.rs index 9cd06adaaabf..b590a4554597 100644 --- a/crates/polars-arrow/src/array/binary/mod.rs +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -21,9 +21,6 @@ mod mutable; pub use mutable::*; use polars_error::{polars_bail, PolarsResult}; -#[cfg(feature = "arrow_rs")] -mod data; - /// A [`BinaryArray`] is Arrow's semantically equivalent of an immutable `Vec>>`. /// It implements [`Array`]. /// diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 67334a53aa17..15a744f804e9 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -11,7 +11,7 @@ use polars_utils::total_ord::{TotalEq, TotalOrd}; use crate::buffer::Buffer; use crate::datatypes::PrimitiveType; -use crate::types::NativeType; +use crate::types::{Bytes16Alignment4, NativeType}; // We use this instead of u128 because we want alignment of <= 8 bytes. /// A reference to a set of bytes. @@ -346,7 +346,9 @@ impl MinMax for View { impl NativeType for View { const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128; + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment4; #[inline] fn to_le_bytes(&self) -> Self::Bytes { diff --git a/crates/polars-arrow/src/array/boolean/data.rs b/crates/polars-arrow/src/array/boolean/data.rs deleted file mode 100644 index 6c497896775c..000000000000 --- a/crates/polars-arrow/src/array/boolean/data.rs +++ /dev/null @@ -1,36 +0,0 @@ -use arrow_buffer::{BooleanBuffer, NullBuffer}; -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, BooleanArray}; -use crate::bitmap::Bitmap; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for BooleanArray { - fn to_data(&self) -> ArrayData { - let buffer = NullBuffer::from(self.values.clone()); - - let builder = ArrayDataBuilder::new(arrow_schema::DataType::Boolean) - .len(buffer.len()) - .offset(buffer.offset()) - .buffers(vec![buffer.into_inner().into_inner()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - assert_eq!(data.data_type(), &arrow_schema::DataType::Boolean); - - let buffers = data.buffers(); - let buffer = BooleanBuffer::new(buffers[0].clone(), data.offset(), data.len()); - // Use NullBuffer to compute set count - let values = Bitmap::from_null_buffer(NullBuffer::new(buffer)); - - Self { - dtype: ArrowDataType::Boolean, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs index 5cd9870fdbf4..c1a17c0f27f3 100644 --- a/crates/polars-arrow/src/array/boolean/mod.rs +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -7,8 +7,6 @@ use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::{ArrowDataType, PhysicalType}; use crate::trusted_len::TrustedLen; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from; diff --git a/crates/polars-arrow/src/array/dictionary/data.rs b/crates/polars-arrow/src/array/dictionary/data.rs deleted file mode 100644 index a5eda5a0fd73..000000000000 --- a/crates/polars-arrow/src/array/dictionary/data.rs +++ /dev/null @@ -1,49 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{ - from_data, to_data, Arrow2Arrow, DictionaryArray, DictionaryKey, PrimitiveArray, -}; -use crate::datatypes::{ArrowDataType, PhysicalType}; - -impl Arrow2Arrow for DictionaryArray { - fn to_data(&self) -> ArrayData { - let keys = self.keys.to_data(); - let builder = keys - .into_builder() - .data_type(self.dtype.clone().into()) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Dictionary is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let key = match data.data_type() { - arrow_schema::DataType::Dictionary(k, _) => k.as_ref(), - d => panic!("unsupported dictionary type {d}"), - }; - - let dtype = ArrowDataType::from(data.data_type().clone()); - assert_eq!( - dtype.to_physical_type(), - PhysicalType::Dictionary(K::KEY_TYPE) - ); - - let key_builder = ArrayDataBuilder::new(key.clone()) - .buffers(vec![data.buffers()[0].clone()]) - .offset(data.offset()) - .len(data.len()) - .nulls(data.nulls().cloned()); - - // SAFETY: Dictionary is valid - let key_data = unsafe { key_builder.build_unchecked() }; - let keys = PrimitiveArray::from_data(&key_data); - let values = from_data(&data.child_data()[0]); - - Self { - dtype, - keys, - values, - } - } -} diff --git a/crates/polars-arrow/src/array/dictionary/iterator.rs b/crates/polars-arrow/src/array/dictionary/iterator.rs index 68e95ca86fed..af6ef539572d 100644 --- a/crates/polars-arrow/src/array/dictionary/iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/iterator.rs @@ -21,7 +21,7 @@ impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> { } } -impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { +impl Iterator for DictionaryValuesIter<'_, K> { type Item = Box; #[inline] @@ -40,9 +40,9 @@ impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { } } -unsafe impl<'a, K: DictionaryKey> TrustedLen for DictionaryValuesIter<'a, K> {} +unsafe impl TrustedLen for DictionaryValuesIter<'_, K> {} -impl<'a, K: DictionaryKey> DoubleEndedIterator for DictionaryValuesIter<'a, K> { +impl DoubleEndedIterator for DictionaryValuesIter<'_, K> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs index d53970dacd98..f23c409c48a9 100644 --- a/crates/polars-arrow/src/array/dictionary/mod.rs +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -8,8 +8,6 @@ use crate::scalar::{new_scalar, Scalar}; use crate::trusted_len::TrustedLen; use crate::types::NativeType; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs index 5257bde2cae0..d7e7637bf28d 100644 --- a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs @@ -117,11 +117,9 @@ impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped< } } -unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {} +unsafe impl TrustedLen for DictionaryValuesIterTyped<'_, K, V> {} -impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator - for DictionaryValuesIterTyped<'a, K, V> -{ +impl DoubleEndedIterator for DictionaryValuesIterTyped<'_, K, V> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { @@ -181,9 +179,9 @@ impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryIterTyped<'a, K, } } -unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryIterTyped<'a, K, V> {} +unsafe impl TrustedLen for DictionaryIterTyped<'_, K, V> {} -impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator for DictionaryIterTyped<'a, K, V> { +impl DoubleEndedIterator for DictionaryIterTyped<'_, K, V> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/fixed_size_binary/data.rs b/crates/polars-arrow/src/array/fixed_size_binary/data.rs deleted file mode 100644 index f04be9883f64..000000000000 --- a/crates/polars-arrow/src/array/fixed_size_binary/data.rs +++ /dev/null @@ -1,37 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, FixedSizeBinaryArray}; -use crate::bitmap::Bitmap; -use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for FixedSizeBinaryArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.values.clone().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - let size = match dtype { - ArrowDataType::FixedSizeBinary(size) => size, - _ => unreachable!("must be FixedSizeBinary"), - }; - - let mut values: Buffer = data.buffers()[0].clone().into(); - values.slice(data.offset() * size, data.len() * size); - - Self { - size, - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs index ec3f96626c14..f8f5a1760d45 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -3,8 +3,6 @@ use crate::bitmap::Bitmap; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs deleted file mode 100644 index c1f353db691a..000000000000 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, FixedSizeListArray}; -use crate::bitmap::Bitmap; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for FixedSizeListArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - let length = data.len() - data.offset(); - let size = match dtype { - ArrowDataType::FixedSizeList(_, size) => size, - _ => unreachable!("must be FixedSizeList type"), - }; - - let mut values = from_data(&data.child_data()[0]); - values.slice(data.offset() * size, data.len() * size); - - Self { - size, - length, - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 4f1622819813..32267cc5a4b7 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -1,9 +1,7 @@ -use super::{new_empty_array, new_null_array, Array, Splitable}; +use super::{new_empty_array, new_null_array, Array, ArrayRef, Splitable}; use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; @@ -11,8 +9,11 @@ mod iterator; mod mutable; pub use mutable::*; use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_utils::format_tuple; use polars_utils::pl_str::PlSmallStr; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. /// Cloning and slicing this struct is `O(1)`. #[derive(Clone)] @@ -122,6 +123,108 @@ impl FixedSizeListArray { let values = new_null_array(field.dtype().clone(), length * size); Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } + + pub fn from_shape( + leaf_array: ArrayRef, + dimensions: &[ReshapeDimension], + ) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + let size = leaf_array.len(); + + let mut total_dim_size = 1; + let mut num_infers = 0; + for &dim in dimensions { + match dim { + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, + } + } + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + if size == 0 { + polars_ensure!( + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(prev_array); + } + + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // We pop the outer dimension as that is the height of the series. + for dim in dimensions[1..].iter().rev() { + // Infer dimension if needed + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = FixedSizeListArray::new( + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, + prev_array, + None, + ) + .boxed(); + } + Ok(prev_array) + } + + pub fn get_dims(&self) -> Vec { + let mut dims = vec![ + Dimension::new(self.length as _), + Dimension::new(self.size as _), + ]; + + let mut prev_array = &self.values; + + while let Some(a) = prev_array.as_any().downcast_ref::() { + dims.push(Dimension::new(a.size as _)); + prev_array = &a.values; + } + dims + } } // must use @@ -146,6 +249,7 @@ impl FixedSizeListArray { /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); self.validity = self .validity .take() diff --git a/crates/polars-arrow/src/array/growable/list.rs b/crates/polars-arrow/src/array/growable/list.rs index 90e4f15020a6..095f39522da4 100644 --- a/crates/polars-arrow/src/array/growable/list.rs +++ b/crates/polars-arrow/src/array/growable/list.rs @@ -14,7 +14,7 @@ unsafe fn extend_offset_values( start: usize, len: usize, ) { - let array = growable.arrays[index]; + let array = growable.arrays.get_unchecked_release(index); let offsets = array.offsets(); growable diff --git a/crates/polars-arrow/src/array/growable/null.rs b/crates/polars-arrow/src/array/growable/null.rs index c0b92e132819..e663fc31b8b4 100644 --- a/crates/polars-arrow/src/array/growable/null.rs +++ b/crates/polars-arrow/src/array/growable/null.rs @@ -23,7 +23,7 @@ impl GrowableNull { } } -impl<'a> Growable<'a> for GrowableNull { +impl Growable<'_> for GrowableNull { unsafe fn extend(&mut self, _: usize, _: usize, len: usize) { self.length += len; } diff --git a/crates/polars-arrow/src/array/growable/structure.rs b/crates/polars-arrow/src/array/growable/structure.rs index 5f3d0c107c62..79f922d318fa 100644 --- a/crates/polars-arrow/src/array/growable/structure.rs +++ b/crates/polars-arrow/src/array/growable/structure.rs @@ -10,6 +10,7 @@ use crate::bitmap::MutableBitmap; /// Concrete [`Growable`] for the [`StructArray`]. pub struct GrowableStruct<'a> { arrays: Vec<&'a StructArray>, + length: usize, validity: Option, values: Vec + 'a>>, } @@ -48,6 +49,7 @@ impl<'a> GrowableStruct<'a> { Self { arrays, + length: 0, values, validity: prepare_validity(use_validity, capacity), } @@ -60,6 +62,7 @@ impl<'a> GrowableStruct<'a> { StructArray::new( self.arrays[0].dtype().clone(), + self.length, values, validity.map(|v| v.into()), ) @@ -71,6 +74,8 @@ impl<'a> Growable<'a> for GrowableStruct<'a> { let array = *self.arrays.get_unchecked_release(index); extend_validity(&mut self.validity, array, start, len); + self.length += len; + if array.null_count() == 0 { self.values .iter_mut() @@ -97,6 +102,7 @@ impl<'a> Growable<'a> for GrowableStruct<'a> { if let Some(validity) = &mut self.validity { validity.extend_constant(additional, false); } + self.length += additional; } #[inline] @@ -123,6 +129,7 @@ impl<'a> From> for StructArray { StructArray::new( val.arrays[0].dtype().clone(), + val.length, values, val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/iterator.rs b/crates/polars-arrow/src/array/iterator.rs index 46b585ef2a36..5009442d5718 100644 --- a/crates/polars-arrow/src/array/iterator.rs +++ b/crates/polars-arrow/src/array/iterator.rs @@ -117,3 +117,12 @@ impl<'a, A: ArrayAccessor<'a> + ?Sized> Iterator for NonNullValuesIter<'a, A> { } unsafe impl<'a, A: ArrayAccessor<'a> + ?Sized> TrustedLen for NonNullValuesIter<'a, A> {} + +impl Clone for NonNullValuesIter<'_, A> { + fn clone(&self) -> Self { + Self { + accessor: self.accessor, + idxs: self.idxs.clone(), + } + } +} diff --git a/crates/polars-arrow/src/array/list/data.rs b/crates/polars-arrow/src/array/list/data.rs deleted file mode 100644 index 0d28583df125..000000000000 --- a/crates/polars-arrow/src/array/list/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, ListArray}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for ListArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.offsets.clone().into_inner().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: from_data(&data.child_data()[0]), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs index 3c2bb6b41f98..87f7b709f14b 100644 --- a/crates/polars-arrow/src/array/list/mod.rs +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -4,8 +4,6 @@ use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; use crate::offset::{Offset, Offsets, OffsetsBuffer}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; @@ -130,6 +128,49 @@ impl ListArray { impl_sliced!(); impl_mut_validity!(); impl_into_array!(); + + pub fn trim_to_normalized_offsets_recursive(&self) -> Self { + let offsets = self.offsets(); + let values = self.values(); + + let first_idx = *offsets.first(); + let len = offsets.range().to_usize(); + + if first_idx.to_usize() == 0 && values.len() == len { + return self.clone(); + } + + let offsets = if first_idx.to_usize() == 0 { + offsets.clone() + } else { + let v = offsets.iter().map(|x| *x - first_idx).collect::>(); + unsafe { OffsetsBuffer::::new_unchecked(v.into()) } + }; + + let values = values.sliced(first_idx.to_usize(), len); + + let values = match values.dtype() { + ArrowDataType::List(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + ArrowDataType::LargeList(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + _ => values, + }; + + assert_eq!(offsets.first().to_usize(), 0); + assert_eq!(values.len(), offsets.range().to_usize()); + + Self::new( + self.dtype().clone(), + offsets, + values, + self.validity().cloned(), + ) + } } // Accessors diff --git a/crates/polars-arrow/src/array/map/data.rs b/crates/polars-arrow/src/array/map/data.rs deleted file mode 100644 index b5530886d817..000000000000 --- a/crates/polars-arrow/src/array/map/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, MapArray}; -use crate::bitmap::Bitmap; -use crate::offset::OffsetsBuffer; - -impl Arrow2Arrow for MapArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.offsets.clone().into_inner().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.field.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype: data.data_type().clone().into(), - offsets, - field: from_data(&data.child_data()[0]), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/map/iterator.rs b/crates/polars-arrow/src/array/map/iterator.rs index 558405ddc8de..79fc630cc520 100644 --- a/crates/polars-arrow/src/array/map/iterator.rs +++ b/crates/polars-arrow/src/array/map/iterator.rs @@ -22,7 +22,7 @@ impl<'a> MapValuesIter<'a> { } } -impl<'a> Iterator for MapValuesIter<'a> { +impl Iterator for MapValuesIter<'_> { type Item = Box; #[inline] @@ -43,9 +43,9 @@ impl<'a> Iterator for MapValuesIter<'a> { } } -unsafe impl<'a> TrustedLen for MapValuesIter<'a> {} +unsafe impl TrustedLen for MapValuesIter<'_> {} -impl<'a> DoubleEndedIterator for MapValuesIter<'a> { +impl DoubleEndedIterator for MapValuesIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs index 5497c1d7342b..1018c21c830a 100644 --- a/crates/polars-arrow/src/array/map/mod.rs +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -4,8 +4,6 @@ use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; use crate::offset::OffsetsBuffer; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index 08702e8021d3..a2acd7164f6a 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -189,7 +189,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { new } - /// Clones this [`Array`] with a new new assigned bitmap. + /// Clones this [`Array`] with a new assigned bitmap. /// # Panic /// This function panics iff `validity.len() != self.len()`. fn with_validity(&self, validity: Option) -> Box; @@ -409,115 +409,6 @@ pub fn new_null_array(dtype: ArrowDataType, length: usize) -> Box { } } -/// Trait providing bi-directional conversion between polars_arrow [`Array`] and arrow-rs [`ArrayData`] -/// -/// [`ArrayData`]: arrow_data::ArrayData -#[cfg(feature = "arrow_rs")] -pub trait Arrow2Arrow: Array { - /// Convert this [`Array`] into [`ArrayData`] - fn to_data(&self) -> arrow_data::ArrayData; - - /// Create this [`Array`] from [`ArrayData`] - fn from_data(data: &arrow_data::ArrayData) -> Self; -} - -#[cfg(feature = "arrow_rs")] -macro_rules! to_data_dyn { - ($array:expr, $ty:ty) => {{ - let f = |x: &$ty| x.to_data(); - general_dyn!($array, $ty, f) - }}; -} - -#[cfg(feature = "arrow_rs")] -impl From> for arrow_array::ArrayRef { - fn from(value: Box) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&dyn Array> for arrow_array::ArrayRef { - fn from(value: &dyn Array) -> Self { - arrow_array::make_array(to_data(value)) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Box { - fn from(value: arrow_array::ArrayRef) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&dyn arrow_array::Array> for Box { - fn from(value: &dyn arrow_array::Array) -> Self { - from_data(&value.to_data()) - } -} - -/// Convert an polars_arrow [`Array`] to [`arrow_data::ArrayData`] -#[cfg(feature = "arrow_rs")] -pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { - use crate::datatypes::PhysicalType::*; - match array.dtype().to_physical_type() { - Null => to_data_dyn!(array, NullArray), - Boolean => to_data_dyn!(array, BooleanArray), - Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - to_data_dyn!(array, PrimitiveArray<$T>) - }), - Binary => to_data_dyn!(array, BinaryArray), - LargeBinary => to_data_dyn!(array, BinaryArray), - FixedSizeBinary => to_data_dyn!(array, FixedSizeBinaryArray), - Utf8 => to_data_dyn!(array, Utf8Array::), - LargeUtf8 => to_data_dyn!(array, Utf8Array::), - List => to_data_dyn!(array, ListArray::), - LargeList => to_data_dyn!(array, ListArray::), - FixedSizeList => to_data_dyn!(array, FixedSizeListArray), - Struct => to_data_dyn!(array, StructArray), - Union => to_data_dyn!(array, UnionArray), - Dictionary(key_type) => { - match_integer_type!(key_type, |$T| { - to_data_dyn!(array, DictionaryArray::<$T>) - }) - }, - Map => to_data_dyn!(array, MapArray), - BinaryView | Utf8View => todo!(), - } -} - -/// Convert an [`arrow_data::ArrayData`] to polars_arrow [`Array`] -#[cfg(feature = "arrow_rs")] -pub fn from_data(data: &arrow_data::ArrayData) -> Box { - use crate::datatypes::PhysicalType::*; - let dtype: ArrowDataType = data.data_type().clone().into(); - match dtype.to_physical_type() { - Null => Box::new(NullArray::from_data(data)), - Boolean => Box::new(BooleanArray::from_data(data)), - Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - Box::new(PrimitiveArray::<$T>::from_data(data)) - }), - Binary => Box::new(BinaryArray::::from_data(data)), - LargeBinary => Box::new(BinaryArray::::from_data(data)), - FixedSizeBinary => Box::new(FixedSizeBinaryArray::from_data(data)), - Utf8 => Box::new(Utf8Array::::from_data(data)), - LargeUtf8 => Box::new(Utf8Array::::from_data(data)), - List => Box::new(ListArray::::from_data(data)), - LargeList => Box::new(ListArray::::from_data(data)), - FixedSizeList => Box::new(FixedSizeListArray::from_data(data)), - Struct => Box::new(StructArray::from_data(data)), - Union => Box::new(UnionArray::from_data(data)), - Dictionary(key_type) => { - match_integer_type!(key_type, |$T| { - Box::new(DictionaryArray::<$T>::from_data(data)) - }) - }, - Map => Box::new(MapArray::from_data(data)), - BinaryView | Utf8View => todo!(), - } -} - macro_rules! clone_dyn { ($array:expr, $ty:ty) => {{ let f = |x: &$ty| Box::new(x.clone()); diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 4960b263667c..e6e840d86860 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -213,24 +213,3 @@ impl FromFfi for NullArray { Self::try_new(dtype, array.array().len()) } } - -#[cfg(feature = "arrow_rs")] -mod arrow { - use arrow_data::{ArrayData, ArrayDataBuilder}; - - use super::*; - impl NullArray { - /// Convert this array into [`arrow_data::ArrayData`] - pub fn to_data(&self) -> ArrayData { - let builder = ArrayDataBuilder::new(arrow_schema::DataType::Null).len(self.len()); - - // SAFETY: safe by construction - unsafe { builder.build_unchecked() } - } - - /// Create this array from [`ArrayData`] - pub fn from_data(data: &ArrayData) -> Self { - Self::new(ArrowDataType::Null, data.len()) - } - } -} diff --git a/crates/polars-arrow/src/array/primitive/data.rs b/crates/polars-arrow/src/array/primitive/data.rs deleted file mode 100644 index 56a94107cb89..000000000000 --- a/crates/polars-arrow/src/array/primitive/data.rs +++ /dev/null @@ -1,33 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, PrimitiveArray}; -use crate::bitmap::Bitmap; -use crate::buffer::Buffer; -use crate::types::NativeType; - -impl Arrow2Arrow for PrimitiveArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.values.clone().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - let mut values: Buffer = data.buffers()[0].clone().into(); - values.slice(data.offset(), data.len()); - - Self { - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs index 831a10c372cc..8accc161faf2 100644 --- a/crates/polars-arrow/src/array/primitive/mod.rs +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -11,8 +11,6 @@ use crate::datatypes::*; use crate::trusted_len::TrustedLen; use crate::types::{days_ms, f16, i256, months_days_ns, NativeType}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from_natural; @@ -459,8 +457,8 @@ impl PrimitiveArray { // SAFETY: this is fine, we checked size and alignment, and NativeType // is always Pod. - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - assert_eq!(std::mem::align_of::(), std::mem::align_of::()); + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); let new_values = unsafe { std::mem::transmute::, Buffer>(values) }; PrimitiveArray::new(U::PRIMITIVE.into(), new_values, validity) } diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs index 296d93502abe..9ff5ceb39361 100644 --- a/crates/polars-arrow/src/array/static_array_collect.rs +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -417,10 +417,10 @@ impl IntoBytes for T { } } impl TrivialIntoBytes for Vec {} -impl<'a> TrivialIntoBytes for Cow<'a, [u8]> {} -impl<'a> TrivialIntoBytes for &'a [u8] {} +impl TrivialIntoBytes for Cow<'_, [u8]> {} +impl TrivialIntoBytes for &[u8] {} impl TrivialIntoBytes for String {} -impl<'a> TrivialIntoBytes for &'a str {} +impl TrivialIntoBytes for &str {} impl<'a> IntoBytes for Cow<'a, str> { type AsRefT = Cow<'a, [u8]>; fn into_bytes(self) -> Cow<'a, [u8]> { @@ -590,8 +590,8 @@ unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { trait StrIntoBytes: IntoBytes {} impl StrIntoBytes for String {} -impl<'a> StrIntoBytes for &'a str {} -impl<'a> StrIntoBytes for Cow<'a, str> {} +impl StrIntoBytes for &str {} +impl StrIntoBytes for Cow<'_, str> {} impl ArrayFromIter for Utf8ViewArray { #[inline] diff --git a/crates/polars-arrow/src/array/struct_/data.rs b/crates/polars-arrow/src/array/struct_/data.rs deleted file mode 100644 index ca8c5b0c6ec3..000000000000 --- a/crates/polars-arrow/src/array/struct_/data.rs +++ /dev/null @@ -1,28 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, StructArray}; -use crate::bitmap::Bitmap; - -impl Arrow2Arrow for StructArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - Self { - dtype, - values: data.child_data().iter().map(from_data).collect(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/struct_/ffi.rs b/crates/polars-arrow/src/array/struct_/ffi.rs index 3bfb9a1a7d7f..cc56f0f12cf3 100644 --- a/crates/polars-arrow/src/array/struct_/ffi.rs +++ b/crates/polars-arrow/src/array/struct_/ffi.rs @@ -68,6 +68,6 @@ impl FromFfi for StructArray { }) .collect::>>>()?; - Self::try_new(dtype, values, validity) + Self::try_new(dtype, len, values, validity) } } diff --git a/crates/polars-arrow/src/array/struct_/iterator.rs b/crates/polars-arrow/src/array/struct_/iterator.rs index 4e89af3a6a7f..38a49f274cde 100644 --- a/crates/polars-arrow/src/array/struct_/iterator.rs +++ b/crates/polars-arrow/src/array/struct_/iterator.rs @@ -20,7 +20,7 @@ impl<'a> StructValueIter<'a> { } } -impl<'a> Iterator for StructValueIter<'a> { +impl Iterator for StructValueIter<'_> { type Item = Vec>; #[inline] @@ -48,9 +48,9 @@ impl<'a> Iterator for StructValueIter<'a> { } } -unsafe impl<'a> TrustedLen for StructValueIter<'a> {} +unsafe impl TrustedLen for StructValueIter<'_> {} -impl<'a> DoubleEndedIterator for StructValueIter<'a> { +impl DoubleEndedIterator for StructValueIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index decc95a2627a..eeaac519bb0d 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -2,14 +2,12 @@ use super::{new_empty_array, new_null_array, Array, Splitable}; use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; mod mutable; pub use mutable::*; -use polars_error::{polars_bail, PolarsResult}; +use polars_error::{polars_bail, polars_ensure, PolarsResult}; use crate::compute::utils::combine_validities_and; @@ -27,13 +25,15 @@ use crate::compute::utils::combine_validities_and; /// Field::new("c".into(), ArrowDataType::Int32, false), /// ]; /// -/// let array = StructArray::new(ArrowDataType::Struct(fields), vec![boolean, int], None); +/// let array = StructArray::new(ArrowDataType::Struct(fields), 4, vec![boolean, int], None); /// ``` #[derive(Clone)] pub struct StructArray { dtype: ArrowDataType, // invariant: each array has the same length values: Vec>, + // invariant: for each v in values: length == v.len() + length: usize, validity: Option, } @@ -49,22 +49,17 @@ impl StructArray { /// * the validity's length is not equal to the length of the first element pub fn try_new( dtype: ArrowDataType, + length: usize, values: Vec>, validity: Option, ) -> PolarsResult { let fields = Self::try_get_fields(&dtype)?; - if fields.is_empty() { - assert!(values.is_empty(), "invalid struct"); - assert_eq!(validity.map(|v| v.len()).unwrap_or(0), 0, "invalid struct"); - return Ok(Self { - dtype, - values, - validity: None, - }); - } - if fields.len() != values.len() { - polars_bail!(ComputeError:"a StructArray must have a number of fields in its DataType equal to the number of child values") - } + + polars_ensure!( + fields.len() == values.len(), + ComputeError: + "a StructArray must have a number of fields in its DataType equal to the number of child values" + ); fields .iter().map(|a| &a.dtype) @@ -81,15 +76,14 @@ impl StructArray { } })?; - let len = values[0].len(); values .iter() - .map(|a| a.len()) + .map(|f| f.len()) .enumerate() - .try_for_each(|(index, a_len)| { - if a_len != len { - polars_bail!(ComputeError: "The children must have an equal number of values. - However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}.") + .try_for_each(|(index, f_length)| { + if f_length != length { + polars_bail!(ComputeError: "The children must have the given number of values. + However, the values at index {index} have a length of {f_length}, which is different from given length {length}.") } else { Ok(()) } @@ -97,13 +91,14 @@ impl StructArray { if validity .as_ref() - .map_or(false, |validity| validity.len() != len) + .map_or(false, |validity| validity.len() != length) { polars_bail!(ComputeError:"The validity length of a StructArray must match its number of elements") } Ok(Self { dtype, + length, values, validity, }) @@ -120,10 +115,11 @@ impl StructArray { /// * the validity's length is not equal to the length of the first element pub fn new( dtype: ArrowDataType, + length: usize, values: Vec>, validity: Option, ) -> Self { - Self::try_new(dtype, values, validity).unwrap() + Self::try_new(dtype, length, values, validity).unwrap() } /// Creates an empty [`StructArray`]. @@ -133,7 +129,7 @@ impl StructArray { .iter() .map(|field| new_empty_array(field.dtype().clone())) .collect(); - Self::new(dtype, values, None) + Self::new(dtype, 0, values, None) } else { panic!("StructArray must be initialized with DataType::Struct"); } @@ -146,7 +142,7 @@ impl StructArray { .iter() .map(|field| new_null_array(field.dtype().clone(), length)) .collect(); - Self::new(dtype, values, Some(Bitmap::new_zeroed(length))) + Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } else { panic!("StructArray must be initialized with DataType::Struct"); } @@ -157,9 +153,10 @@ impl StructArray { impl StructArray { /// Deconstructs the [`StructArray`] into its individual components. #[must_use] - pub fn into_data(self) -> (Vec, Vec>, Option) { + pub fn into_data(self) -> (Vec, usize, Vec>, Option) { let Self { dtype, + length, values, validity, } = self; @@ -168,7 +165,7 @@ impl StructArray { } else { unreachable!() }; - (fields, values, validity) + (fields, length, values, validity) } /// Slices this [`StructArray`]. @@ -199,6 +196,7 @@ impl StructArray { self.values .iter_mut() .for_each(|x| x.slice_unchecked(offset, length)); + self.length = length; } /// Set the outer nulls into the inner arrays. @@ -227,18 +225,17 @@ impl StructArray { impl StructArray { #[inline] fn len(&self) -> usize { - #[cfg(debug_assertions)] - if let Some(fst) = self.values.first() { - for arr in self.values.iter().skip(1) { + if cfg!(debug_assertions) { + for arr in self.values.iter() { assert_eq!( arr.len(), - fst.len(), + self.length, "StructArray invariant: each array has same length" ); } } - self.values.first().map(|arr| arr.len()).unwrap_or(0) + self.length } /// The optional validity. @@ -310,11 +307,13 @@ impl Splitable for StructArray { ( Self { dtype: self.dtype.clone(), + length: offset, values: lhs_values, validity: lhs_validity, }, Self { dtype: self.dtype.clone(), + length: self.length - offset, values: rhs_values, validity: rhs_validity, }, diff --git a/crates/polars-arrow/src/array/struct_/mutable.rs b/crates/polars-arrow/src/array/struct_/mutable.rs index 286db07e2f97..e066d7b6aef2 100644 --- a/crates/polars-arrow/src/array/struct_/mutable.rs +++ b/crates/polars-arrow/src/array/struct_/mutable.rs @@ -11,19 +11,19 @@ use crate::datatypes::ArrowDataType; #[derive(Debug)] pub struct MutableStructArray { dtype: ArrowDataType, + length: usize, values: Vec>, validity: Option, } fn check( dtype: &ArrowDataType, + length: usize, values: &[Box], validity: Option, ) -> PolarsResult<()> { let fields = StructArray::try_get_fields(dtype)?; - if fields.is_empty() { - polars_bail!(ComputeError: "a StructArray must contain at least one field") - } + if fields.len() != values.len() { polars_bail!(ComputeError: "a StructArray must have a number of fields in its DataType equal to the number of child values") } @@ -34,32 +34,25 @@ fn check( .enumerate() .try_for_each(|(index, (dtype, child))| { if dtype != child { - polars_bail!(ComputeError: - "The children DataTypes of a StructArray must equal the children data types. - However, the field {index} has data type {dtype:?} but the value has data type {child:?}" - ) + polars_bail!(ComputeError: "The children DataTypes of a StructArray must equal the children data types.\nHowever, the field {index} has data type {dtype:?} but the value has data type {child:?}") } else { Ok(()) } })?; - let len = values[0].len(); values .iter() - .map(|a| a.len()) + .map(|f| f.len()) .enumerate() - .try_for_each(|(index, a_len)| { - if a_len != len { - polars_bail!(ComputeError: - "The children must have an equal number of values. - However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}." - ) + .try_for_each(|(index, f_length)| { + if f_length != length { + polars_bail!(ComputeError: "The children must have the given number of values.\nHowever, the values at index {index} have a length of {f_length}, which is different from given length {length}.") } else { Ok(()) } })?; - if validity.map_or(false, |validity| validity != len) { + if validity.map_or(false, |validity| validity != length) { polars_bail!(ComputeError: "the validity length of a StructArray must match its number of elements", ) @@ -77,6 +70,7 @@ impl From for StructArray { StructArray::new( other.dtype, + other.length, other.values.into_iter().map(|mut v| v.as_box()).collect(), validity, ) @@ -85,8 +79,8 @@ impl From for StructArray { impl MutableStructArray { /// Creates a new [`MutableStructArray`]. - pub fn new(dtype: ArrowDataType, values: Vec>) -> Self { - Self::try_new(dtype, values, None).unwrap() + pub fn new(dtype: ArrowDataType, length: usize, values: Vec>) -> Self { + Self::try_new(dtype, length, values, None).unwrap() } /// Create a [`MutableStructArray`] out of low-end APIs. @@ -97,12 +91,14 @@ impl MutableStructArray { /// * `validity` is not `None` and its length is different from the `values`'s length pub fn try_new( dtype: ArrowDataType, + length: usize, values: Vec>, validity: Option, ) -> PolarsResult { - check(&dtype, &values, validity.as_ref().map(|x| x.len()))?; + check(&dtype, length, &values, validity.as_ref().map(|x| x.len()))?; Ok(Self { dtype, + length, values, validity, }) @@ -113,26 +109,17 @@ impl MutableStructArray { self, ) -> ( ArrowDataType, + usize, Vec>, Option, ) { - (self.dtype, self.values, self.validity) - } - - /// The mutable values - pub fn mut_values(&mut self) -> &mut Vec> { - &mut self.values + (self.dtype, self.length, self.values, self.validity) } /// The values pub fn values(&self) -> &Vec> { &self.values } - - /// Return the `i`th child array. - pub fn value(&mut self, i: usize) -> Option<&mut A> { - self.values[i].as_mut_any().downcast_mut::() - } } impl MutableStructArray { @@ -155,6 +142,7 @@ impl MutableStructArray { false => self.init_validity(), }, }; + self.length += 1; } fn push_null(&mut self) { @@ -193,7 +181,7 @@ impl MutableStructArray { impl MutableArray for MutableStructArray { fn len(&self) -> usize { - self.values.first().map(|v| v.len()).unwrap_or(0) + self.length } fn validity(&self) -> Option<&MutableBitmap> { @@ -203,6 +191,7 @@ impl MutableArray for MutableStructArray { fn as_box(&mut self) -> Box { StructArray::new( self.dtype.clone(), + self.length, std::mem::take(&mut self.values) .into_iter() .map(|mut v| v.as_box()) @@ -215,6 +204,7 @@ impl MutableArray for MutableStructArray { fn as_arc(&mut self) -> Arc { StructArray::new( self.dtype.clone(), + self.length, std::mem::take(&mut self.values) .into_iter() .map(|mut v| v.as_box()) diff --git a/crates/polars-arrow/src/array/union/data.rs b/crates/polars-arrow/src/array/union/data.rs deleted file mode 100644 index 869fdcfc248d..000000000000 --- a/crates/polars-arrow/src/array/union/data.rs +++ /dev/null @@ -1,70 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, UnionArray}; -use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for UnionArray { - fn to_data(&self) -> ArrayData { - let dtype = arrow_schema::DataType::from(self.dtype.clone()); - let len = self.len(); - - let builder = match self.offsets.clone() { - Some(offsets) => ArrayDataBuilder::new(dtype) - .len(len) - .buffers(vec![self.types.clone().into(), offsets.into()]) - .child_data(self.fields.iter().map(|x| to_data(x.as_ref())).collect()), - None => ArrayDataBuilder::new(dtype) - .len(len) - .buffers(vec![self.types.clone().into()]) - .child_data( - self.fields - .iter() - .map(|x| to_data(x.as_ref()).slice(self.offset, len)) - .collect(), - ), - }; - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - - let fields = data.child_data().iter().map(from_data).collect(); - let buffers = data.buffers(); - let mut types: Buffer = buffers[0].clone().into(); - types.slice(data.offset(), data.len()); - let offsets = match buffers.len() == 2 { - true => { - let mut offsets: Buffer = buffers[1].clone().into(); - offsets.slice(data.offset(), data.len()); - Some(offsets) - }, - false => None, - }; - - // Map from type id to array index - let map = match &dtype { - ArrowDataType::Union(_, Some(ids), _) => { - let mut map = [0; 127]; - for (pos, &id) in ids.iter().enumerate() { - map[id as usize] = pos; - } - Some(map) - }, - ArrowDataType::Union(_, None, _) => None, - _ => unreachable!("must be Union type"), - }; - - Self { - types, - map, - fields, - offsets, - dtype, - offset: data.offset(), - } - } -} diff --git a/crates/polars-arrow/src/array/union/iterator.rs b/crates/polars-arrow/src/array/union/iterator.rs index bdcf5825af6c..e93223e46c43 100644 --- a/crates/polars-arrow/src/array/union/iterator.rs +++ b/crates/polars-arrow/src/array/union/iterator.rs @@ -15,7 +15,7 @@ impl<'a> UnionIter<'a> { } } -impl<'a> Iterator for UnionIter<'a> { +impl Iterator for UnionIter<'_> { type Item = Box; #[inline] @@ -54,6 +54,6 @@ impl<'a> UnionArray { } } -impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {} +impl std::iter::ExactSizeIterator for UnionIter<'_> {} -unsafe impl<'a> TrustedLen for UnionIter<'a> {} +unsafe impl TrustedLen for UnionIter<'_> {} diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs index e42d268f5c06..f8007a485ed5 100644 --- a/crates/polars-arrow/src/array/union/mod.rs +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -6,8 +6,6 @@ use crate::buffer::Buffer; use crate::datatypes::{ArrowDataType, Field, UnionMode}; use crate::scalar::{new_scalar, Scalar}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/utf8/data.rs b/crates/polars-arrow/src/array/utf8/data.rs deleted file mode 100644 index 37f73a089aa6..000000000000 --- a/crates/polars-arrow/src/array/utf8/data.rs +++ /dev/null @@ -1,42 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, Utf8Array}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for Utf8Array { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype().clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.offsets().len_proxy()) - .buffers(vec![ - self.offsets.clone().into_inner().into(), - self.values.clone().into(), - ]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let buffers = data.buffers(); - - // SAFETY: ArrayData is valid - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: buffers[1].clone().into(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs index ebec52b78d28..fffa36ba2f8f 100644 --- a/crates/polars-arrow/src/array/utf8/mod.rs +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -11,8 +11,6 @@ use crate::datatypes::ArrowDataType; use crate::offset::{Offset, Offsets, OffsetsBuffer}; use crate::trusted_len::TrustedLen; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from; diff --git a/crates/polars-arrow/src/bitmap/aligned.rs b/crates/polars-arrow/src/bitmap/aligned.rs index b5eab49432ca..ad1baf06631e 100644 --- a/crates/polars-arrow/src/bitmap/aligned.rs +++ b/crates/polars-arrow/src/bitmap/aligned.rs @@ -55,7 +55,7 @@ impl<'a, T: BitChunk> AlignedBitmapSlice<'a, T> { /// The length (in bits) of the portion of the bitmap found in bulk. #[inline(always)] pub fn bulk_bitlen(&self) -> usize { - 8 * std::mem::size_of::() * self.bulk.len() + 8 * size_of::() * self.bulk.len() } /// The length (in bits) of the portion of the bitmap found in suffix. @@ -77,7 +77,7 @@ impl<'a, T: BitChunk> AlignedBitmapSlice<'a, T> { offset %= 8; // Fast-path: fits entirely in one chunk. - let chunk_len = std::mem::size_of::(); + let chunk_len = size_of::(); let chunk_len_bits = 8 * chunk_len; if offset + len <= chunk_len_bits { let mut prefix = load_chunk_le::(bytes) >> offset; diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index a3edb658be4e..ec9814e4fad3 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -12,7 +12,7 @@ pub(crate) fn push_bitchunk(buffer: &mut Vec, value: T) { /// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. pub fn chunk_iter_to_vec>(iter: I) -> Vec { - let cap = iter.size_hint().0 * std::mem::size_of::(); + let cap = iter.size_hint().0 * size_of::(); let mut buffer = Vec::with_capacity(cap); for v in iter { push_bitchunk(&mut buffer, v) @@ -24,7 +24,7 @@ fn chunk_iter_to_vec_and_remainder>( iter: I, remainder: T, ) -> Vec { - let cap = (iter.size_hint().0 + 1) * std::mem::size_of::(); + let cap = (iter.size_hint().0 + 1) * size_of::(); let mut buffer = Vec::with_capacity(cap); for v in iter { push_bitchunk(&mut buffer, v) @@ -338,7 +338,7 @@ impl PartialEq for Bitmap { } } -impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { +impl<'b> BitOr<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitor(self, rhs: &'b Bitmap) -> Bitmap { @@ -346,7 +346,7 @@ impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { } } -impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { +impl<'b> BitAnd<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitand(self, rhs: &'b Bitmap) -> Bitmap { @@ -354,7 +354,7 @@ impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { } } -impl<'a, 'b> BitXor<&'b Bitmap> for &'a Bitmap { +impl<'b> BitXor<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitxor(self, rhs: &'b Bitmap) -> Bitmap { diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs new file mode 100644 index 000000000000..c507df97c5ba --- /dev/null +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -0,0 +1,100 @@ +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::storage::SharedStorage; + +/// Used to build bitmaps bool-by-bool in sequential order. +#[derive(Default, Clone)] +pub struct BitmapBuilder { + buf: u64, + len: usize, + cap: usize, + set_bits: usize, + bytes: Vec, +} + +impl BitmapBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn capacity(&self) -> usize { + self.cap + } + + pub fn with_capacity(bits: usize) -> Self { + let bytes = Vec::with_capacity(bits.div_ceil(64) * 8); + let words_available = bytes.capacity() / 8; + Self { + buf: 0, + len: 0, + cap: words_available * 64, + set_bits: 0, + bytes, + } + } + + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + if self.len + additional > self.cap { + self.reserve_slow(additional) + } + } + + #[cold] + #[inline(never)] + fn reserve_slow(&mut self, additional: usize) { + let bytes_needed = (self.len + additional).div_ceil(64) * 8; + self.bytes.reserve(bytes_needed - self.bytes.capacity()); + let words_available = self.bytes.capacity() / 8; + self.cap = words_available * 64; + } + + #[inline(always)] + pub fn push(&mut self, x: bool) { + self.reserve(1); + unsafe { self.push_unchecked(x) } + } + + /// # Safety + /// self.len() < self.capacity() must hold. + #[inline(always)] + pub unsafe fn push_unchecked(&mut self, x: bool) { + debug_assert!(self.len < self.cap); + self.buf |= (x as u64) << (self.len % 64); + self.len += 1; + if self.len % 64 == 0 { + let p = self.bytes.as_mut_ptr().add(self.bytes.len()).cast::(); + p.write_unaligned(self.buf.to_le()); + self.bytes.set_len(self.bytes.len() + 8); + self.set_bits += self.buf.count_ones() as usize; + self.buf = 0; + } + } + + /// # Safety + /// May only be called once at the end. + unsafe fn finish(&mut self) { + if self.len % 64 != 0 { + self.bytes.extend_from_slice(&self.buf.to_le_bytes()); + self.set_bits += self.buf.count_ones() as usize; + } + } + + pub fn into_mut(mut self) -> MutableBitmap { + unsafe { + self.finish(); + MutableBitmap::from_vec(self.bytes, self.len) + } + } + + pub fn freeze(mut self) -> Bitmap { + unsafe { + self.finish(); + let storage = SharedStorage::from_vec(self.bytes); + Bitmap::from_inner_unchecked(storage, 0, self.len, Some(self.len - self.set_bits)) + } + } +} diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index c9aa0b681b4a..a896651467d2 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -5,7 +5,7 @@ use std::sync::LazyLock; use either::Either; use polars_error::{polars_bail, PolarsResult}; -use super::utils::{count_zeros, fmt, get_bit, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; +use super::utils::{count_zeros, fmt, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; use super::{chunk_iter_to_vec, intersects_with, num_intersections_with, IntoIter, MutableBitmap}; use crate::array::Splitable; use crate::bitmap::aligned::AlignedBitmapSlice; @@ -334,7 +334,8 @@ impl Bitmap { /// Panics iff `i >= self.len()`. #[inline] pub fn get_bit(&self, i: usize) -> bool { - get_bit(&self.storage, self.offset + i) + assert!(i < self.len()); + unsafe { self.get_bit_unchecked(i) } } /// Unsafely returns whether the bit at position `i` is set. @@ -343,6 +344,7 @@ impl Bitmap { /// Unsound iff `i >= self.len()`. #[inline] pub unsafe fn get_bit_unchecked(&self, i: usize) -> bool { + debug_assert!(i < self.len()); get_bit_unchecked(&self.storage, self.offset + i) } @@ -593,22 +595,6 @@ impl Bitmap { ) -> std::result::Result { Ok(MutableBitmap::try_from_trusted_len_iter_unchecked(iterator)?.into()) } - - /// Create a new [`Bitmap`] from an arrow [`NullBuffer`] - /// - /// [`NullBuffer`]: arrow_buffer::buffer::NullBuffer - #[cfg(feature = "arrow_rs")] - pub fn from_null_buffer(value: arrow_buffer::buffer::NullBuffer) -> Self { - let offset = value.offset(); - let length = value.len(); - let unset_bits = value.null_count(); - Self { - storage: SharedStorage::from_arrow_buffer(value.buffer().clone()), - offset, - length, - unset_bit_count_cache: AtomicU64::new(unset_bits as u64), - } - } } impl<'a> IntoIterator for &'a Bitmap { @@ -629,17 +615,6 @@ impl IntoIterator for Bitmap { } } -#[cfg(feature = "arrow_rs")] -impl From for arrow_buffer::buffer::NullBuffer { - fn from(value: Bitmap) -> Self { - let null_count = value.unset_bits(); - let buffer = value.storage.into_arrow_buffer(); - let buffer = arrow_buffer::buffer::BooleanBuffer::new(buffer, value.offset, value.length); - // SAFETY: null count is accurate - unsafe { arrow_buffer::buffer::NullBuffer::new_unchecked(buffer, null_count) } - } -} - impl Splitable for Bitmap { #[inline(always)] fn check_bound(&self, offset: usize) -> bool { diff --git a/crates/polars-arrow/src/bitmap/iterator.rs b/crates/polars-arrow/src/bitmap/iterator.rs index 63c61afd83f7..84e0a2d7a985 100644 --- a/crates/polars-arrow/src/bitmap/iterator.rs +++ b/crates/polars-arrow/src/bitmap/iterator.rs @@ -25,6 +25,7 @@ fn calc_iters_remaining(length: usize, min_length_for_iter: usize, consume: usiz 1 + obvious_iters // Thus always exactly 1 more iter. } +#[derive(Clone)] pub struct TrueIdxIter<'a> { mask: BitMask<'a>, first_unknown: usize, @@ -57,7 +58,7 @@ impl<'a> TrueIdxIter<'a> { } } -impl<'a> Iterator for TrueIdxIter<'a> { +impl Iterator for TrueIdxIter<'_> { type Item = usize; #[inline] @@ -92,7 +93,7 @@ impl<'a> Iterator for TrueIdxIter<'a> { } } -unsafe impl<'a> TrustedLen for TrueIdxIter<'a> {} +unsafe impl TrustedLen for TrueIdxIter<'_> {} pub struct FastU32BitmapIter<'a> { bytes: &'a [u8], @@ -142,7 +143,7 @@ impl<'a> FastU32BitmapIter<'a> { } } -impl<'a> Iterator for FastU32BitmapIter<'a> { +impl Iterator for FastU32BitmapIter<'_> { type Item = u32; #[inline] @@ -170,7 +171,7 @@ impl<'a> Iterator for FastU32BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU32BitmapIter<'a> {} +unsafe impl TrustedLen for FastU32BitmapIter<'_> {} pub struct FastU56BitmapIter<'a> { bytes: &'a [u8], @@ -221,7 +222,7 @@ impl<'a> FastU56BitmapIter<'a> { } } -impl<'a> Iterator for FastU56BitmapIter<'a> { +impl Iterator for FastU56BitmapIter<'_> { type Item = u64; #[inline] @@ -251,7 +252,7 @@ impl<'a> Iterator for FastU56BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU56BitmapIter<'a> {} +unsafe impl TrustedLen for FastU56BitmapIter<'_> {} pub struct FastU64BitmapIter<'a> { bytes: &'a [u8], @@ -316,7 +317,7 @@ impl<'a> FastU64BitmapIter<'a> { } } -impl<'a> Iterator for FastU64BitmapIter<'a> { +impl Iterator for FastU64BitmapIter<'_> { type Item = u64; #[inline] @@ -348,7 +349,7 @@ impl<'a> Iterator for FastU64BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU64BitmapIter<'a> {} +unsafe impl TrustedLen for FastU64BitmapIter<'_> {} /// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. #[derive(Debug, Clone)] diff --git a/crates/polars-arrow/src/bitmap/mod.rs b/crates/polars-arrow/src/bitmap/mod.rs index e7ed5fa363e8..6d518bf596b4 100644 --- a/crates/polars-arrow/src/bitmap/mod.rs +++ b/crates/polars-arrow/src/bitmap/mod.rs @@ -19,3 +19,6 @@ pub use assign_ops::*; pub mod utils; pub mod bitmask; + +mod builder; +pub use builder::*; diff --git a/crates/polars-arrow/src/bitmap/mutable.rs b/crates/polars-arrow/src/bitmap/mutable.rs index d030682a63a7..05bbbe5dd976 100644 --- a/crates/polars-arrow/src/bitmap/mutable.rs +++ b/crates/polars-arrow/src/bitmap/mutable.rs @@ -1,12 +1,11 @@ use std::hint::unreachable_unchecked; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::vec::PushUnchecked; -use super::utils::{ - count_zeros, fmt, get_bit, set, set_bit, BitChunk, BitChunks, BitChunksExactMut, BitmapIter, -}; +use super::utils::{count_zeros, fmt, BitChunk, BitChunks, BitChunksExactMut, BitmapIter}; use super::{intersects_with_mut, Bitmap}; -use crate::bitmap::utils::{get_bit_unchecked, merge_reversed, set_bit_unchecked}; +use crate::bitmap::utils::{get_bit_unchecked, merge_reversed, set_bit_in_byte}; use crate::storage::SharedStorage; use crate::trusted_len::TrustedLen; @@ -118,8 +117,8 @@ impl MutableBitmap { if self.length % 8 == 0 { self.buffer.push(0); } - let byte = unsafe { self.buffer.as_mut_slice().last_mut().unwrap_unchecked() }; - *byte = set(*byte, self.length % 8, value); + let byte = unsafe { self.buffer.last_mut().unwrap_unchecked() }; + *byte = set_bit_in_byte(*byte, self.length % 8, value); self.length += 1; } @@ -144,7 +143,8 @@ impl MutableBitmap { /// Panics iff `index >= self.len()`. #[inline] pub fn get(&self, index: usize) -> bool { - get_bit(&self.buffer, index) + assert!(index < self.len()); + unsafe { self.get_unchecked(index) } } /// Returns whether the position `index` is set. @@ -161,7 +161,28 @@ impl MutableBitmap { /// Panics iff `index >= self.len()`. #[inline] pub fn set(&mut self, index: usize, value: bool) { - set_bit(self.buffer.as_mut_slice(), index, value) + assert!(index < self.len()); + unsafe { + self.set_unchecked(index, value); + } + } + + /// Sets the position `index` to the OR of its original value and `value`. + /// + /// # Safety + /// It's undefined behavior if index >= self.len(). + #[inline] + pub unsafe fn or_pos_unchecked(&mut self, index: usize, value: bool) { + *self.buffer.get_unchecked_mut(index / 8) |= (value as u8) << (index % 8); + } + + /// Sets the position `index` to the AND of its original value and `value`. + /// + /// # Safety + /// It's undefined behavior if index >= self.len(). + #[inline] + pub unsafe fn and_pos_unchecked(&mut self, index: usize, value: bool) { + *self.buffer.get_unchecked_mut(index / 8) &= (value as u8) << (index % 8); } /// constructs a new iterator over the bits of [`MutableBitmap`]. @@ -192,6 +213,17 @@ impl MutableBitmap { } } + /// Resizes the [`MutableBitmap`] to the specified length, inserting value + /// if the length is bigger than the current length. + pub fn resize(&mut self, length: usize, value: bool) { + if let Some(additional) = length.checked_sub(self.len()) { + self.extend_constant(additional, value); + } else { + self.buffer.truncate(length.saturating_add(7) / 8); + self.length = length; + } + } + /// Initializes a zeroed [`MutableBitmap`]. #[inline] pub fn from_len_zeroed(length: usize) -> Self { @@ -230,10 +262,10 @@ impl MutableBitmap { #[inline] pub unsafe fn push_unchecked(&mut self, value: bool) { if self.length % 8 == 0 { - self.buffer.push(0); + self.buffer.push_unchecked(0); } - let byte = self.buffer.as_mut_slice().last_mut().unwrap(); - *byte = set(*byte, self.length % 8, value); + let byte = self.buffer.last_mut().unwrap_unchecked(); + *byte = set_bit_in_byte(*byte, self.length % 8, value); self.length += 1; } @@ -330,7 +362,9 @@ impl MutableBitmap { /// Caller must ensure that `index < self.len()` #[inline] pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { - set_bit_unchecked(self.buffer.as_mut_slice(), index, value) + debug_assert!(index < self.len()); + let byte = self.buffer.get_unchecked_mut(index / 8); + *byte = set_bit_in_byte(*byte, index % 8, value); } /// Shrinks the capacity of the [`MutableBitmap`] to fit its current length. @@ -566,10 +600,10 @@ impl MutableBitmap { self.buffer.push(0); } // the iterator will not fill the last byte - let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + let byte = self.buffer.last_mut().unwrap(); let mut i = bit_offset; for value in iterator { - *byte = set(*byte, i, value); + *byte = set_bit_in_byte(*byte, i, value); i += 1; } self.length += length; @@ -581,9 +615,9 @@ impl MutableBitmap { if bit_offset != 0 { // we are in the middle of a byte; lets finish it - let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + let byte = self.buffer.last_mut().unwrap(); (bit_offset..8).for_each(|i| { - *byte = set(*byte, i, iterator.next().unwrap()); + *byte = set_bit_in_byte(*byte, i, iterator.next().unwrap()); }); self.length += 8 - bit_offset; length -= 8 - bit_offset; @@ -650,7 +684,7 @@ impl MutableBitmap { let data = buffer.as_mut_slice(); data[..chunks].iter_mut().try_for_each(|byte| { (0..8).try_for_each(|i| { - *byte = set(*byte, i, iterator.next().unwrap()?); + *byte = set_bit_in_byte(*byte, i, iterator.next().unwrap()?); Ok(()) }) })?; @@ -658,7 +692,7 @@ impl MutableBitmap { if reminder != 0 { let last = &mut data[chunks]; iterator.enumerate().try_for_each(|(i, value)| { - *last = set(*last, i, value?); + *last = set_bit_in_byte(*last, i, value?); Ok(()) })?; } diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs index 7bc12e22898e..50f4e023ac1f 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -17,7 +17,7 @@ impl<'a, T: BitChunk> BitChunksExact<'a, T> { #[inline] pub fn new(bitmap: &'a [u8], length: usize) -> Self { assert!(length <= bitmap.len() * 8); - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bitmap = &bitmap[..length.saturating_add(7) / 8]; diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs index 81e08df0059e..680d3bf96fa4 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs @@ -23,7 +23,7 @@ where // expected = [n5, n6, n7, c0, c1, c2, c3, c4] // 1. unset most significants of `next` up to `offset` - let inverse_offset = std::mem::size_of::() * 8 - offset; + let inverse_offset = size_of::() * 8 - offset; next <<= inverse_offset; // next = [n5, n6, n7, 0 , 0 , 0 , 0 , 0 ] diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs index 8a1668a37d1f..7c387d65b10d 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -44,7 +44,7 @@ fn copy_with_merge(dst: &mut T::Bytes, bytes: &[u8], bit_offset: us bytes .windows(2) .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref())) - .take(std::mem::size_of::()) + .take(size_of::()) .enumerate() .for_each(|(i, w)| { let val = merge_reversed(w[0], w[1], bit_offset); @@ -59,7 +59,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { let slice = &slice[offset / 8..]; let bit_offset = offset % 8; - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bytes_len = len / 8; let bytes_upper_len = (len + bit_offset + 7) / 8; @@ -120,7 +120,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { // all remaining bytes self.remainder_bytes .iter() - .take(std::mem::size_of::()) + .take(size_of::()) .enumerate() .for_each(|(i, val)| remainder[i] = *val); @@ -137,7 +137,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { /// Returns the remainder bits in [`BitChunks::remainder`]. pub fn remainder_len(&self) -> usize { - self.len - (std::mem::size_of::() * ((self.len / 8) / std::mem::size_of::()) * 8) + self.len - (size_of::() * ((self.len / 8) / size_of::()) * 8) } } diff --git a/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs index 7a5a91a12805..04a9f8b661a7 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs @@ -4,7 +4,7 @@ use super::BitChunk; /// /// # Safety /// The slices returned by this iterator are guaranteed to have length equal to -/// `std::mem::size_of::()`. +/// `size_of::()`. #[derive(Debug)] pub struct BitChunksExactMut<'a, T: BitChunk> { chunks: std::slice::ChunksExactMut<'a, u8>, @@ -18,7 +18,7 @@ impl<'a, T: BitChunk> BitChunksExactMut<'a, T> { #[inline] pub fn new(bitmap: &'a mut [u8], length: usize) -> Self { assert!(length <= bitmap.len() * 8); - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bitmap = &mut bitmap[..length.saturating_add(7) / 8]; diff --git a/crates/polars-arrow/src/bitmap/utils/iterator.rs b/crates/polars-arrow/src/bitmap/utils/iterator.rs index c386b2d02db4..243372599687 100644 --- a/crates/polars-arrow/src/bitmap/utils/iterator.rs +++ b/crates/polars-arrow/src/bitmap/utils/iterator.rs @@ -176,7 +176,7 @@ impl<'a> BitmapIter<'a> { let num_words = n / 64; if num_words > 0 { - assert!(self.bytes.len() >= num_words * std::mem::size_of::()); + assert!(self.bytes.len() >= num_words * size_of::()); bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize); @@ -189,7 +189,7 @@ impl<'a> BitmapIter<'a> { return; } - assert!(self.bytes.len() >= std::mem::size_of::()); + assert!(self.bytes.len() >= size_of::()); self.word_len = usize::min(self.rest_len, 64); self.rest_len -= self.word_len; @@ -205,7 +205,7 @@ impl<'a> BitmapIter<'a> { } } -impl<'a> Iterator for BitmapIter<'a> { +impl Iterator for BitmapIter<'_> { type Item = bool; #[inline] @@ -238,7 +238,7 @@ impl<'a> Iterator for BitmapIter<'a> { } } -impl<'a> DoubleEndedIterator for BitmapIter<'a> { +impl DoubleEndedIterator for BitmapIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.rest_len > 0 { diff --git a/crates/polars-arrow/src/bitmap/utils/mod.rs b/crates/polars-arrow/src/bitmap/utils/mod.rs index ebd47c2530dc..01cccd81bd68 100644 --- a/crates/polars-arrow/src/bitmap/utils/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/mod.rs @@ -25,53 +25,34 @@ pub fn is_set(byte: u8, i: usize) -> bool { } /// Sets bit at position `i` in `byte`. -#[inline] -pub fn set(byte: u8, i: usize, value: bool) -> u8 { +#[inline(always)] +pub fn set_bit_in_byte(byte: u8, i: usize, value: bool) -> u8 { debug_assert!(i < 8); - let mask = !(1 << i); let insert = (value as u8) << i; (byte & mask) | insert } -/// Sets bit at position `i` in `bytes`. -/// # Panics -/// This function panics iff `i >= bytes.len() * 8`. -#[inline] -pub fn set_bit(bytes: &mut [u8], i: usize, value: bool) { - bytes[i / 8] = set(bytes[i / 8], i % 8, value); -} - -/// Sets bit at position `i` in `bytes` without doing bound checks -/// # Safety -/// `i >= bytes.len() * 8` results in undefined behavior. -#[inline] -pub unsafe fn set_bit_unchecked(bytes: &mut [u8], i: usize, value: bool) { - let byte = bytes.get_unchecked_mut(i / 8); - *byte = set(*byte, i % 8, value); -} - -/// Returns whether bit at position `i` in `bytes` is set. -/// # Panic -/// This function panics iff `i >= bytes.len() * 8`. -#[inline] -pub fn get_bit(bytes: &[u8], i: usize) -> bool { - let byte = bytes[i / 8]; - let bit = (byte >> (i % 8)) & 1; - bit != 0 -} - /// Returns whether bit at position `i` in `bytes` is set or not. /// /// # Safety /// `i >= bytes.len() * 8` results in undefined behavior. -#[inline] +#[inline(always)] pub unsafe fn get_bit_unchecked(bytes: &[u8], i: usize) -> bool { let byte = *bytes.get_unchecked_release(i / 8); let bit = (byte >> (i % 8)) & 1; bit != 0 } +/// Sets bit at position `i` in `bytes` without doing bound checks. +/// # Safety +/// `i >= bytes.len() * 8` results in undefined behavior. +#[inline(always)] +pub unsafe fn set_bit_unchecked(bytes: &mut [u8], i: usize, value: bool) { + let byte = bytes.get_unchecked_mut(i / 8); + *byte = set_bit_in_byte(*byte, i % 8, value); +} + /// Returns the number of bytes required to hold `bits` bits. #[inline] pub fn bytes_for(bits: usize) -> usize { diff --git a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs index f3083ad0b141..9f43a3dfe89a 100644 --- a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs +++ b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs @@ -74,7 +74,7 @@ impl<'a> SlicesIterator<'a> { } } -impl<'a> Iterator for SlicesIterator<'a> { +impl Iterator for SlicesIterator<'_> { type Item = (usize, usize); #[inline] diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs index 1dfe805ffc57..1c6e5b5aa4ff 100644 --- a/crates/polars-arrow/src/buffer/immutable.rs +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -288,24 +288,6 @@ impl IntoIterator for Buffer { } } -#[cfg(feature = "arrow_rs")] -impl From for Buffer { - fn from(value: arrow_buffer::Buffer) -> Self { - Self::from_storage(SharedStorage::from_arrow_buffer(value)) - } -} - -#[cfg(feature = "arrow_rs")] -impl From> for arrow_buffer::Buffer { - fn from(value: Buffer) -> Self { - let offset = value.offset(); - value.storage.into_arrow_buffer().slice_with_length( - offset * std::mem::size_of::(), - value.length * std::mem::size_of::(), - ) - } -} - unsafe impl<'a, T: 'a> ArrayAccessor<'a> for Buffer { type Item = &'a T; diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs index bd4ba7ab6384..1cf7512bbbef 100644 --- a/crates/polars-arrow/src/compute/aggregate/memory.rs +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -18,7 +18,7 @@ macro_rules! dyn_binary { let values_end = offsets[offsets.len() - 1] as usize; values_end - values_start - + offsets.len() * std::mem::size_of::<$o>() + + offsets.len() * size_of::<$o>() + validity_size(array.validity()) }}; } @@ -50,7 +50,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { }, Primitive(PrimitiveType::DaysMs) => { let array = array.as_any().downcast_ref::().unwrap(); - array.values().len() * std::mem::size_of::() * 2 + validity_size(array.validity()) + array.values().len() * size_of::() * 2 + validity_size(array.validity()) }, Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array = array @@ -58,7 +58,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { .downcast_ref::>() .unwrap(); - array.values().len() * std::mem::size_of::<$T>() + validity_size(array.validity()) + array.values().len() * size_of::<$T>() + validity_size(array.validity()) }), Binary => dyn_binary!(array, BinaryArray, i32), FixedSizeBinary => { @@ -74,7 +74,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { List => { let array = array.as_any().downcast_ref::>().unwrap(); estimated_bytes_size(array.values().as_ref()) - + array.offsets().len_proxy() * std::mem::size_of::() + + array.offsets().len_proxy() * size_of::() + validity_size(array.validity()) }, FixedSizeList => { @@ -84,7 +84,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { LargeList => { let array = array.as_any().downcast_ref::>().unwrap(); estimated_bytes_size(array.values().as_ref()) - + array.offsets().len_proxy() * std::mem::size_of::() + + array.offsets().len_proxy() * size_of::() + validity_size(array.validity()) }, Struct => { @@ -99,11 +99,11 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { }, Union => { let array = array.as_any().downcast_ref::().unwrap(); - let types = array.types().len() * std::mem::size_of::(); + let types = array.types().len() * size_of::(); let offsets = array .offsets() .as_ref() - .map(|x| x.len() * std::mem::size_of::()) + .map(|x| x.len() * size_of::()) .unwrap_or_default(); let fields = array .fields() @@ -124,7 +124,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { BinaryView => binview_size::<[u8]>(array.as_any().downcast_ref().unwrap()), Map => { let array = array.as_any().downcast_ref::().unwrap(); - let offsets = array.offsets().len_proxy() * std::mem::size_of::(); + let offsets = array.offsets().len_proxy() * size_of::(); offsets + estimated_bytes_size(array.field().as_ref()) + validity_size(array.validity()) }, } diff --git a/crates/polars-arrow/src/compute/aggregate/mod.rs b/crates/polars-arrow/src/compute/aggregate/mod.rs index 9528f833a67e..481194c1551c 100644 --- a/crates/polars-arrow/src/compute/aggregate/mod.rs +++ b/crates/polars-arrow/src/compute/aggregate/mod.rs @@ -1,4 +1,4 @@ -//! Contains different aggregation functions +/// ! Contains different aggregation functions #[cfg(feature = "compute_aggregate")] mod sum; #[cfg(feature = "compute_aggregate")] diff --git a/crates/polars-arrow/src/compute/aggregate/sum.rs b/crates/polars-arrow/src/compute/aggregate/sum.rs index 9fbed5f8b1b6..e2098d969e03 100644 --- a/crates/polars-arrow/src/compute/aggregate/sum.rs +++ b/crates/polars-arrow/src/compute/aggregate/sum.rs @@ -6,7 +6,7 @@ use polars_error::PolarsResult; use crate::array::{Array, PrimitiveArray}; use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; use crate::bitmap::Bitmap; -use crate::datatypes::{ArrowDataType, PhysicalType, PrimitiveType}; +use crate::datatypes::PhysicalType; use crate::scalar::*; use crate::types::simd::*; use crate::types::NativeType; @@ -102,19 +102,6 @@ where } } -/// Whether [`sum`] supports `dtype` -pub fn can_sum(dtype: &ArrowDataType) -> bool { - if let PhysicalType::Primitive(primitive) = dtype.to_physical_type() { - use PrimitiveType::*; - matches!( - primitive, - Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 - ) - } else { - false - } -} - /// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical /// and logical types as `array`. /// # Error diff --git a/crates/polars-arrow/src/compute/arity.rs b/crates/polars-arrow/src/compute/arity.rs index c5b397f6faac..2670cfb4d031 100644 --- a/crates/polars-arrow/src/compute/arity.rs +++ b/crates/polars-arrow/src/compute/arity.rs @@ -1,10 +1,7 @@ //! Defines kernels suitable to perform operations to primitive arrays. -use polars_error::PolarsResult; - use super::utils::{check_same_len, combine_validities_and}; use crate::array::PrimitiveArray; -use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::ArrowDataType; use crate::types::NativeType; @@ -29,104 +26,6 @@ where PrimitiveArray::::new(dtype, values.into(), array.validity().cloned()) } -/// Version of unary that checks for errors in the closure used to create the -/// buffer -pub fn try_unary( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> PolarsResult> -where - I: NativeType, - O: NativeType, - F: Fn(I) -> PolarsResult, -{ - let values = array - .values() - .iter() - .map(|v| op(*v)) - .collect::>>()? - .into(); - - Ok(PrimitiveArray::::new( - dtype, - values, - array.validity().cloned(), - )) -} - -/// Version of unary that returns an array and bitmap. Used when working with -/// overflowing operations -pub fn unary_with_bitmap( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> (PrimitiveArray, Bitmap) -where - I: NativeType, - O: NativeType, - F: Fn(I) -> (O, bool), -{ - let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); - - let values = array - .values() - .iter() - .map(|v| { - let (res, over) = op(*v); - mut_bitmap.push(over); - res - }) - .collect::>() - .into(); - - ( - PrimitiveArray::::new(dtype, values, array.validity().cloned()), - mut_bitmap.into(), - ) -} - -/// Version of unary that creates a mutable bitmap that is used to keep track -/// of checked operations. The resulting bitmap is compared with the array -/// bitmap to create the final validity array. -pub fn unary_checked( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> PrimitiveArray -where - I: NativeType, - O: NativeType, - F: Fn(I) -> Option, -{ - let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); - - let values = array - .values() - .iter() - .map(|v| match op(*v) { - Some(val) => { - mut_bitmap.push(true); - val - }, - None => { - mut_bitmap.push(false); - O::default() - }, - }) - .collect::>() - .into(); - - // The validity has to be checked against the bitmap created during the - // creation of the values with the iterator. If an error was found during - // the iteration, then the validity is changed to None to mark the value - // as Null - let bitmap: Bitmap = mut_bitmap.into(); - let validity = combine_validities_and(array.validity(), Some(&bitmap)); - - PrimitiveArray::::new(dtype, values, validity) -} - /// Applies a binary operations to two primitive arrays. /// /// This is the fastest way to perform an operation on two primitive array when the benefits of a @@ -169,115 +68,3 @@ where PrimitiveArray::::new(dtype, values, validity) } - -/// Version of binary that checks for errors in the closure used to create the -/// buffer -pub fn try_binary( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> PolarsResult> -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> PolarsResult, -{ - check_same_len(lhs, rhs)?; - - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| op(*l, *r)) - .collect::>>()? - .into(); - - Ok(PrimitiveArray::::new(dtype, values, validity)) -} - -/// Version of binary that returns an array and bitmap. Used when working with -/// overflowing operations -pub fn binary_with_bitmap( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> (PrimitiveArray, Bitmap) -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> (T, bool), -{ - check_same_len(lhs, rhs).unwrap(); - - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - let (res, over) = op(*l, *r); - mut_bitmap.push(over); - res - }) - .collect::>() - .into(); - - ( - PrimitiveArray::::new(dtype, values, validity), - mut_bitmap.into(), - ) -} - -/// Version of binary that creates a mutable bitmap that is used to keep track -/// of checked operations. The resulting bitmap is compared with the array -/// bitmap to create the final validity array. -pub fn binary_checked( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> PrimitiveArray -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> Option, -{ - check_same_len(lhs, rhs).unwrap(); - - let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| match op(*l, *r) { - Some(val) => { - mut_bitmap.push(true); - val - }, - None => { - mut_bitmap.push(false); - T::default() - }, - }) - .collect::>() - .into(); - - let bitmap: Bitmap = mut_bitmap.into(); - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - // The validity has to be checked against the bitmap created during the - // creation of the values with the iterator. If an error was found during - // the iteration, then the validity is changed to None to mark the value - // as Null - let validity = combine_validities_and(validity.as_ref(), Some(&bitmap)); - - PrimitiveArray::::new(dtype, values, validity) -} diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index e14a03040522..5d2bd3e0b6d9 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -92,19 +92,6 @@ pub fn binary_to_utf8( ) } -/// Conversion to utf8 -/// # Errors -/// This function errors if the values are not valid utf8 -pub fn binary_to_large_utf8( - from: &BinaryArray, - to_dtype: ArrowDataType, -) -> PolarsResult> { - let values = from.values().clone(); - let offsets = from.offsets().into(); - - Utf8Array::::try_new(to_dtype, offsets, values, from.validity().cloned()) -} - /// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. pub(super) fn binary_to_primitive( from: &BinaryArray, @@ -212,7 +199,7 @@ pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewAr // This is NOT equal to MAX_BYTES_PER_BUFFER because of integer division let split_point = num_elements_per_buffer * size; - // This is zero-copy for the buffer since split just increases the the data since + // This is zero-copy for the buffer since split just increases the data since let mut buffer = from.values().clone(); let mut buffers = Vec::with_capacity(num_buffers); for _ in 0..num_buffers - 1 { diff --git a/crates/polars-arrow/src/compute/cast/dictionary_to.rs b/crates/polars-arrow/src/compute/cast/dictionary_to.rs index d67a116ca0de..3f6211ad544e 100644 --- a/crates/polars-arrow/src/compute/cast/dictionary_to.rs +++ b/crates/polars-arrow/src/compute/cast/dictionary_to.rs @@ -1,6 +1,6 @@ use polars_error::{polars_bail, PolarsResult}; -use super::{primitive_as_primitive, primitive_to_primitive, CastOptionsImpl}; +use super::{primitive_to_primitive, CastOptionsImpl}; use crate::array::{Array, DictionaryArray, DictionaryKey}; use crate::compute::cast::cast; use crate::datatypes::ArrowDataType; @@ -23,101 +23,6 @@ macro_rules! key_cast { }}; } -/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] by keeping the -/// keys and casting the values to `values_type`. -/// # Errors -/// This function errors if the values are not castable to `values_type` -pub fn dictionary_to_dictionary_values( - from: &DictionaryArray, - values_type: &ArrowDataType, -) -> PolarsResult> { - let keys = from.keys(); - let values = from.values(); - let length = values.len(); - - let values = cast(values.as_ref(), values_type, CastOptionsImpl::default())?; - - assert_eq!(values.len(), length); // this is guaranteed by `cast` - unsafe { - DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) - } -} - -/// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped -pub fn wrapping_dictionary_to_dictionary_values( - from: &DictionaryArray, - values_type: &ArrowDataType, -) -> PolarsResult> { - let keys = from.keys(); - let values = from.values(); - let length = values.len(); - - let values = cast( - values.as_ref(), - values_type, - CastOptionsImpl { - wrapped: true, - partial: false, - }, - )?; - assert_eq!(values.len(), length); // this is guaranteed by `cast` - unsafe { - DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) - } -} - -/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] backed by a -/// different physical type of the keys, while keeping the values equal. -/// # Errors -/// Errors if any of the old keys' values is larger than the maximum value -/// supported by the new physical type. -pub fn dictionary_to_dictionary_keys( - from: &DictionaryArray, -) -> PolarsResult> -where - K1: DictionaryKey + num_traits::NumCast, - K2: DictionaryKey + num_traits::NumCast, -{ - let keys = from.keys(); - let values = from.values(); - let is_ordered = from.is_ordered(); - - let casted_keys = primitive_to_primitive::(keys, &K2::PRIMITIVE.into()); - - if casted_keys.null_count() > keys.null_count() { - polars_bail!(ComputeError: "overflow") - } else { - let dtype = - ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); - // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` - unsafe { DictionaryArray::try_new_unchecked(dtype, casted_keys, values.clone()) } - } -} - -/// Similar to dictionary_to_dictionary_keys, but overflowing cast is wrapped -pub fn wrapping_dictionary_to_dictionary_keys( - from: &DictionaryArray, -) -> PolarsResult> -where - K1: DictionaryKey + num_traits::AsPrimitive, - K2: DictionaryKey, -{ - let keys = from.keys(); - let values = from.values(); - let is_ordered = from.is_ordered(); - - let casted_keys = primitive_as_primitive::(keys, &K2::PRIMITIVE.into()); - - if casted_keys.null_count() > keys.null_count() { - polars_bail!(ComputeError: "overflow") - } else { - let dtype = - ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); - // some of the values may not fit in `usize` and thus this needs to be checked - DictionaryArray::try_new(dtype, casted_keys, values.clone()) - } -} - pub(super) fn dictionary_cast_dyn( array: &dyn Array, to_type: &ArrowDataType, diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 27f93eb07356..f34d9ebba2a5 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -15,7 +15,7 @@ use binview_to::binview_to_primitive_dyn; pub use binview_to::utf8view_to_utf8; pub use boolean_to::*; pub use decimal_to::*; -pub use dictionary_to::*; +use dictionary_to::*; use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult}; use polars_utils::IdxSize; pub use primitive_to::*; @@ -94,6 +94,7 @@ fn cast_struct( Ok(StructArray::new( to_type.clone(), + array.len(), new_values, array.validity().cloned(), )) @@ -337,16 +338,6 @@ pub fn cast( array.as_any().downcast_ref().unwrap(), ) .boxed()), - UInt8 => binview_to_primitive_dyn::(array, to_type, options), - UInt16 => binview_to_primitive_dyn::(array, to_type, options), - UInt32 => binview_to_primitive_dyn::(array, to_type, options), - UInt64 => binview_to_primitive_dyn::(array, to_type, options), - Int8 => binview_to_primitive_dyn::(array, to_type, options), - Int16 => binview_to_primitive_dyn::(array, to_type, options), - Int32 => binview_to_primitive_dyn::(array, to_type, options), - Int64 => binview_to_primitive_dyn::(array, to_type, options), - Float32 => binview_to_primitive_dyn::(array, to_type, options), - Float64 => binview_to_primitive_dyn::(array, to_type, options), LargeList(inner) if matches!(inner.dtype, ArrowDataType::UInt8) => { let bin_array = view_to_binary::(array.as_any().downcast_ref().unwrap()); Ok(binary_to_list(&bin_array, to_type.clone()).boxed()) @@ -403,17 +394,16 @@ pub fn cast( match to_type { BinaryView => Ok(arr.to_binview().boxed()), LargeUtf8 => Ok(binview_to::utf8view_to_utf8::(arr).boxed()), - UInt8 - | UInt16 - | UInt32 - | UInt64 - | Int8 - | Int16 - | Int32 - | Int64 - | Float32 - | Float64 - | Decimal(_, _) => cast(&arr.to_binview(), to_type, options), + UInt8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), Timestamp(time_unit, None) => { utf8view_to_naive_timestamp_dyn(array, time_unit.to_owned()) }, diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index d017b0a8e212..b0d84d81a3dd 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -8,10 +8,10 @@ use super::CastOptionsImpl; use crate::array::*; use crate::bitmap::Bitmap; use crate::compute::arity::unary; -use crate::datatypes::{ArrowDataType, IntervalUnit, TimeUnit}; +use crate::datatypes::{ArrowDataType, TimeUnit}; use crate::offset::{Offset, Offsets}; use crate::temporal_conversions::*; -use crate::types::{days_ms, f16, months_days_ns, NativeType}; +use crate::types::{f16, NativeType}; pub trait SerPrimitive { fn write(f: &mut Vec, val: Self) -> usize @@ -525,169 +525,6 @@ pub fn timestamp_to_timestamp( } } -fn timestamp_to_utf8_impl( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone: T, -) -> Utf8Array -where - T::Offset: std::fmt::Display, -{ - match time_unit { - TimeUnit::Nanosecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_ns_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Microsecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_us_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Millisecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_ms_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Second => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_s_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - } -} - -#[cfg(feature = "chrono-tz")] -#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] -fn chrono_tz_timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - let timezone = parse_offset_tz(timezone_str)?; - Ok(timestamp_to_utf8_impl::( - from, time_unit, timezone, - )) -} - -#[cfg(not(feature = "chrono-tz"))] -fn chrono_tz_timestamp_to_utf8( - _: &PrimitiveArray, - _: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - panic!( - "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", - timezone_str - ) -} - -/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. -pub fn timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - let timezone = parse_offset(timezone_str); - - if let Ok(timezone) = timezone { - Ok(timestamp_to_utf8_impl::( - from, time_unit, timezone, - )) - } else { - chrono_tz_timestamp_to_utf8(from, time_unit, timezone_str) - } -} - -/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. -pub fn naive_timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, -) -> Utf8Array { - match time_unit { - TimeUnit::Nanosecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_ns_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Microsecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_us_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Millisecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_ms_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Second => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_s_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - } -} - -#[inline] -fn days_ms_to_months_days_ns_scalar(from: days_ms) -> months_days_ns { - months_days_ns::new(0, from.days(), from.milliseconds() as i64 * 1000) -} - -/// Casts [`days_ms`]s to [`months_days_ns`]. This operation is infalible and lossless. -pub fn days_ms_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { - unary( - from, - days_ms_to_months_days_ns_scalar, - ArrowDataType::Interval(IntervalUnit::MonthDayNano), - ) -} - -#[inline] -fn months_to_months_days_ns_scalar(from: i32) -> months_days_ns { - months_days_ns::new(from, 0, 0) -} - -/// Casts months represented as [`i32`]s to [`months_days_ns`]. This operation is infalible and lossless. -pub fn months_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { - unary( - from, - months_to_months_days_ns_scalar, - ArrowDataType::Interval(IntervalUnit::MonthDayNano), - ) -} - /// Casts f16 into f32 pub fn f16_to_f32(from: &PrimitiveArray) -> PrimitiveArray { unary(from, |x| x.to_f32(), ArrowDataType::Float32) diff --git a/crates/polars-arrow/src/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/compute/take/fixed_size_list.rs index 9eccd4bc043b..2a52a1ae3fd1 100644 --- a/crates/polars-arrow/src/compute/take/fixed_size_list.rs +++ b/crates/polars-arrow/src/compute/take/fixed_size_list.rs @@ -15,22 +15,34 @@ // specific language governing permissions and limitations // under the License. +use std::mem::ManuallyDrop; + +use polars_utils::itertools::Itertools; +use polars_utils::IdxSize; + use super::Index; use crate::array::growable::{Growable, GrowableFixedSizeList}; -use crate::array::{FixedSizeListArray, PrimitiveArray}; +use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray}; +use crate::bitmap::MutableBitmap; +use crate::compute::take::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; +use crate::compute::utils::combine_validities_and; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::legacy::prelude::FromData; +use crate::with_match_primitive_type; -/// `take` implementation for FixedSizeListArrays -pub(super) unsafe fn take_unchecked( +pub(super) unsafe fn take_unchecked_slow( values: &FixedSizeListArray, indices: &PrimitiveArray, ) -> FixedSizeListArray { + let take_len = std::cmp::min(values.len(), 1); let mut capacity = 0; let arrays = indices .values() .iter() .map(|index| { let index = index.to_usize(); - let slice = values.clone().sliced(index, 1); + let slice = values.clone().sliced_unchecked(index, take_len); capacity += slice.len(); slice }) @@ -61,3 +73,261 @@ pub(super) unsafe fn take_unchecked( growable.into() } } + +fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) { + if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype { + get_stride_and_leaf_type(inner.dtype(), *size_inner * size) + } else { + (size, dtype) + } +} + +fn get_leaves(array: &FixedSizeListArray) -> &dyn Array { + if let Some(array) = array.values().as_any().downcast_ref::() { + get_leaves(array) + } else { + &**array.values() + } +} + +fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) { + match array.dtype().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let arr = array.as_any().downcast_ref::>().unwrap(); + let values = arr.values(); + (bytemuck::cast_slice(values), size_of::<$T>()) + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn from_buffer(mut buf: ManuallyDrop>, dtype: &ArrowDataType) -> ArrayRef { + match dtype.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let ptr = buf.as_mut_ptr(); + let len_units = buf.len(); + let cap_units = buf.capacity(); + + let buf = Vec::from_raw_parts( + ptr as *mut $T, + len_units / size_of::<$T>(), + cap_units / size_of::<$T>(), + ); + + PrimitiveArray::from_data_default(buf.into(), None).boxed() + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn aligned_vec(dt: &ArrowDataType, n_bytes: usize) -> Vec { + match dt.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let n_units = (n_bytes / size_of::<$T>()) + 1; + + let mut aligned: Vec<$T> = Vec::with_capacity(n_units); + + let ptr = aligned.as_mut_ptr(); + let len_units = aligned.len(); + let cap_units = aligned.capacity(); + + std::mem::forget(aligned); + + Vec::from_raw_parts( + ptr as *mut u8, + len_units * size_of::<$T>(), + cap_units * size_of::<$T>(), + ) + + }), + _ => { + unimplemented!() + }, + } +} + +fn arr_no_validities_recursive(arr: &dyn Array) -> bool { + arr.validity().is_none() + && arr + .as_any() + .downcast_ref::() + .map_or(true, |x| arr_no_validities_recursive(x.values().as_ref())) +} + +/// `take` implementation for FixedSizeListArrays +pub(super) unsafe fn take_unchecked( + values: &FixedSizeListArray, + indices: &PrimitiveArray, +) -> ArrayRef { + let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); + if leaf_type.to_physical_type().is_primitive() + && arr_no_validities_recursive(values.values().as_ref()) + { + let leaves = get_leaves(values); + + let (leaves_buf, leave_size) = get_buffer_and_size(leaves); + let bytes_per_element = leave_size * stride; + + let n_idx = indices.len(); + let total_bytes = bytes_per_element * n_idx; + + let mut buf = ManuallyDrop::new(aligned_vec(leaves.dtype(), total_bytes)); + let dst = buf.spare_capacity_mut(); + + let mut count = 0; + let outer_validity = if indices.null_count() == 0 { + for i in indices.values().iter() { + let i = i.to_usize(); + + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + count += 1; + } + None + } else { + let mut new_validity = MutableBitmap::with_capacity(indices.len()); + new_validity.extend_constant(indices.len(), true); + for i in indices.iter() { + if let Some(i) = i { + let i = i.to_usize(); + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + } else { + new_validity.set_unchecked(count, false); + std::ptr::write_bytes( + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + 0, + bytes_per_element, + ); + } + + count += 1; + } + Some(new_validity.freeze()) + }; + + assert_eq!(count * bytes_per_element, total_bytes); + buf.set_len(total_bytes); + + let outer_validity = combine_validities_and( + outer_validity.as_ref(), + values + .validity() + .map(|x| { + if indices.has_nulls() { + take_bitmap_nulls_unchecked(x, indices) + } else { + take_bitmap_unchecked(x, indices.as_slice().unwrap()) + } + }) + .as_ref(), + ); + + let leaves = from_buffer(buf, leaves.dtype()); + let mut shape = values.get_dims(); + shape[0] = Dimension::new(indices.len() as _); + let shape = shape + .into_iter() + .map(ReshapeDimension::Specified) + .collect_vec(); + + FixedSizeListArray::from_shape(leaves.clone(), &shape) + .unwrap() + .with_validity(outer_validity) + } else { + take_unchecked_slow(values, indices).boxed() + } +} + +#[cfg(test)] +mod tests { + use crate::array::StaticArray; + use crate::datatypes::ArrowDataType; + + /// Test gather for FixedSizeListArray with outer validity but no inner validities. + #[test] + fn test_arr_gather_nulls_outer_validity_19482() { + use polars_utils::IdxSize; + + use super::take_unchecked; + use crate::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use crate::bitmap::Bitmap; + use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::from_slice([1, 2, 3, 4])), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap() + .with_validity(Some(Bitmap::from_iter([true, false]))); // FixedSizeListArray[[1, 2], None] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [true, false] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } + + #[test] + fn test_arr_gather_nulls_inner_validity() { + use polars_utils::IdxSize; + + use super::take_unchecked; + use crate::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::full_null(4, ArrowDataType::Int64)), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap(); // FixedSizeListArray[[None, None], [None, None]] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [false, true] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } +} diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index aed14823af1e..bdd782a1d609 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -68,7 +68,7 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { let array = values.as_any().downcast_ref().unwrap(); - Box::new(fixed_size_list::take_unchecked(array, indices)) + fixed_size_list::take_unchecked(array, indices) }, BinaryView => { take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed() diff --git a/crates/polars-arrow/src/compute/take/structure.rs b/crates/polars-arrow/src/compute/take/structure.rs index caad9f4ee0a4..472535fe126e 100644 --- a/crates/polars-arrow/src/compute/take/structure.rs +++ b/crates/polars-arrow/src/compute/take/structure.rs @@ -30,5 +30,5 @@ pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> St .validity() .map(|b| super::bitmap::take_bitmap_nulls_unchecked(b, indices)); let validity = combine_validities_and(validity.as_ref(), indices.validity()); - StructArray::new(array.dtype().clone(), values, validity) + StructArray::new(array.dtype().clone(), indices.len(), values, validity) } diff --git a/crates/polars-arrow/src/compute/temporal.rs b/crates/polars-arrow/src/compute/temporal.rs index 309493fbbbdb..98bc920caeb3 100644 --- a/crates/polars-arrow/src/compute/temporal.rs +++ b/crates/polars-arrow/src/compute/temporal.rs @@ -75,8 +75,6 @@ macro_rules! date_like { } /// Extracts the years of a temporal array as [`PrimitiveArray`]. -/// -/// Use [`can_year`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn year(array: &dyn Array) -> PolarsResult> { date_like!(year, array, ArrowDataType::Int32) } @@ -84,7 +82,6 @@ pub fn year(array: &dyn Array) -> PolarsResult> { /// Extracts the months of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 12. -/// Use [`can_month`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn month(array: &dyn Array) -> PolarsResult> { date_like!(month, array, ArrowDataType::Int8) } @@ -92,7 +89,6 @@ pub fn month(array: &dyn Array) -> PolarsResult> { /// Extracts the days of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 32 (Last day depends on month). -/// Use [`can_day`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn day(array: &dyn Array) -> PolarsResult> { date_like!(day, array, ArrowDataType::Int8) } @@ -100,7 +96,6 @@ pub fn day(array: &dyn Array) -> PolarsResult> { /// Extracts weekday of a temporal array as [`PrimitiveArray`]. /// /// Monday is 1, Tuesday is 2, ..., Sunday is 7. -/// Use [`can_weekday`] to check if this operation is supported for the target [`ArrowDataType`] pub fn weekday(array: &dyn Array) -> PolarsResult> { date_like!(i8_weekday, array, ArrowDataType::Int8) } @@ -108,7 +103,6 @@ pub fn weekday(array: &dyn Array) -> PolarsResult> { /// Extracts ISO week of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 53 (Last week depends on the year). -/// Use [`can_iso_week`] to check if this operation is supported for the target [`ArrowDataType`] pub fn iso_week(array: &dyn Array) -> PolarsResult> { date_like!(i8_iso_week, array, ArrowDataType::Int8) } @@ -345,86 +339,3 @@ where }, } } - -/// Checks if an array of type `datatype` can perform year operation -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::temporal::can_year; -/// use polars_arrow::datatypes::{ArrowDataType}; -/// -/// assert_eq!(can_year(&ArrowDataType::Date32), true); -/// assert_eq!(can_year(&ArrowDataType::Int8), false); -/// ``` -pub fn can_year(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `datatype` can perform month operation -pub fn can_month(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `datatype` can perform day operation -pub fn can_day(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `dtype` can perform weekday operation -pub fn can_weekday(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `dtype` can perform ISO week operation -pub fn can_iso_week(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -fn can_date(dtype: &ArrowDataType) -> bool { - matches!( - dtype, - ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, _) - ) -} - -/// Checks if an array of type `datatype` can perform hour operation -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::temporal::can_hour; -/// use polars_arrow::datatypes::{ArrowDataType, TimeUnit}; -/// -/// assert_eq!(can_hour(&ArrowDataType::Time32(TimeUnit::Second)), true); -/// assert_eq!(can_hour(&ArrowDataType::Int8), false); -/// ``` -pub fn can_hour(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform minute operation -pub fn can_minute(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform second operation -pub fn can_second(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform nanosecond operation -pub fn can_nanosecond(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -fn can_time(dtype: &ArrowDataType) -> bool { - matches!( - dtype, - ArrowDataType::Time32(TimeUnit::Second) - | ArrowDataType::Time32(TimeUnit::Millisecond) - | ArrowDataType::Time64(TimeUnit::Microsecond) - | ArrowDataType::Time64(TimeUnit::Nanosecond) - | ArrowDataType::Date32 - | ArrowDataType::Date64 - | ArrowDataType::Timestamp(_, _) - ) -} diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index 0b8e1ecd69f4..2402283aa4ef 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -38,6 +38,7 @@ pub fn combine_validities_or(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> _ => None, } } + pub fn combine_validities_and_not( opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>, diff --git a/crates/polars-arrow/src/datatypes/field.rs b/crates/polars-arrow/src/datatypes/field.rs index 8bf18af82f46..b1a5baf5c0ee 100644 --- a/crates/polars-arrow/src/datatypes/field.rs +++ b/crates/polars-arrow/src/datatypes/field.rs @@ -60,60 +60,3 @@ impl Field { &self.dtype } } - -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::Field { - fn from(value: Field) -> Self { - Self::new( - value.name.to_string(), - value.dtype.into(), - value.is_nullable, - ) - .with_metadata( - value - .metadata - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(), - ) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Field { - fn from(value: arrow_schema::Field) -> Self { - (&value).into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&arrow_schema::Field> for Field { - fn from(value: &arrow_schema::Field) -> Self { - let dtype = value.data_type().clone().into(); - let metadata = value - .metadata() - .iter() - .map(|(k, v)| (PlSmallStr::from_str(k), PlSmallStr::from_str(v))) - .collect(); - Self::new( - PlSmallStr::from_str(value.name().as_str()), - dtype, - value.is_nullable(), - ) - .with_metadata(metadata) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Field { - fn from(value: arrow_schema::FieldRef) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&arrow_schema::FieldRef> for Field { - fn from(value: &arrow_schema::FieldRef) -> Self { - value.as_ref().into() - } -} diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 8f2226c709e6..0c0b7024bc71 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -2,6 +2,7 @@ mod field; mod physical_type; +pub mod reshape; mod schema; use std::collections::BTreeMap; @@ -175,143 +176,6 @@ pub enum ArrowDataType { Unknown, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::DataType { - fn from(value: ArrowDataType) -> Self { - use arrow_schema::{Field as ArrowField, UnionFields}; - - match value { - ArrowDataType::Null => Self::Null, - ArrowDataType::Boolean => Self::Boolean, - ArrowDataType::Int8 => Self::Int8, - ArrowDataType::Int16 => Self::Int16, - ArrowDataType::Int32 => Self::Int32, - ArrowDataType::Int64 => Self::Int64, - ArrowDataType::UInt8 => Self::UInt8, - ArrowDataType::UInt16 => Self::UInt16, - ArrowDataType::UInt32 => Self::UInt32, - ArrowDataType::UInt64 => Self::UInt64, - ArrowDataType::Float16 => Self::Float16, - ArrowDataType::Float32 => Self::Float32, - ArrowDataType::Float64 => Self::Float64, - ArrowDataType::Timestamp(unit, tz) => { - Self::Timestamp(unit.into(), tz.map(|x| Arc::::from(x.as_str()))) - }, - ArrowDataType::Date32 => Self::Date32, - ArrowDataType::Date64 => Self::Date64, - ArrowDataType::Time32(unit) => Self::Time32(unit.into()), - ArrowDataType::Time64(unit) => Self::Time64(unit.into()), - ArrowDataType::Duration(unit) => Self::Duration(unit.into()), - ArrowDataType::Interval(unit) => Self::Interval(unit.into()), - ArrowDataType::Binary => Self::Binary, - ArrowDataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), - ArrowDataType::LargeBinary => Self::LargeBinary, - ArrowDataType::Utf8 => Self::Utf8, - ArrowDataType::LargeUtf8 => Self::LargeUtf8, - ArrowDataType::List(f) => Self::List(Arc::new((*f).into())), - ArrowDataType::FixedSizeList(f, size) => { - Self::FixedSizeList(Arc::new((*f).into()), size as _) - }, - ArrowDataType::LargeList(f) => Self::LargeList(Arc::new((*f).into())), - ArrowDataType::Struct(f) => Self::Struct(f.into_iter().map(ArrowField::from).collect()), - ArrowDataType::Union(fields, Some(ids), mode) => { - let ids = ids.into_iter().map(|x| x as _); - let fields = fields.into_iter().map(ArrowField::from); - Self::Union(UnionFields::new(ids, fields), mode.into()) - }, - ArrowDataType::Union(fields, None, mode) => { - let ids = 0..fields.len() as i8; - let fields = fields.into_iter().map(ArrowField::from); - Self::Union(UnionFields::new(ids, fields), mode.into()) - }, - ArrowDataType::Map(f, ordered) => Self::Map(Arc::new((*f).into()), ordered), - ArrowDataType::Dictionary(key, value, _) => Self::Dictionary( - Box::new(ArrowDataType::from(key).into()), - Box::new((*value).into()), - ), - ArrowDataType::Decimal(precision, scale) => { - Self::Decimal128(precision as _, scale as _) - }, - ArrowDataType::Decimal256(precision, scale) => { - Self::Decimal256(precision as _, scale as _) - }, - ArrowDataType::Extension(_, d, _) => (*d).into(), - ArrowDataType::BinaryView | ArrowDataType::Utf8View => { - panic!("view datatypes not supported by arrow-rs") - }, - ArrowDataType::Unknown => unimplemented!(), - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for ArrowDataType { - fn from(value: arrow_schema::DataType) -> Self { - use arrow_schema::DataType; - match value { - DataType::Null => Self::Null, - DataType::Boolean => Self::Boolean, - DataType::Int8 => Self::Int8, - DataType::Int16 => Self::Int16, - DataType::Int32 => Self::Int32, - DataType::Int64 => Self::Int64, - DataType::UInt8 => Self::UInt8, - DataType::UInt16 => Self::UInt16, - DataType::UInt32 => Self::UInt32, - DataType::UInt64 => Self::UInt64, - DataType::Float16 => Self::Float16, - DataType::Float32 => Self::Float32, - DataType::Float64 => Self::Float64, - DataType::Timestamp(unit, tz) => { - Self::Timestamp(unit.into(), tz.map(|x| PlSmallStr::from_str(x.as_ref()))) - }, - DataType::Date32 => Self::Date32, - DataType::Date64 => Self::Date64, - DataType::Time32(unit) => Self::Time32(unit.into()), - DataType::Time64(unit) => Self::Time64(unit.into()), - DataType::Duration(unit) => Self::Duration(unit.into()), - DataType::Interval(unit) => Self::Interval(unit.into()), - DataType::Binary => Self::Binary, - DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), - DataType::LargeBinary => Self::LargeBinary, - DataType::Utf8 => Self::Utf8, - DataType::LargeUtf8 => Self::LargeUtf8, - DataType::List(f) => Self::List(Box::new(f.into())), - DataType::FixedSizeList(f, size) => Self::FixedSizeList(Box::new(f.into()), size as _), - DataType::LargeList(f) => Self::LargeList(Box::new(f.into())), - DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()), - DataType::Union(fields, mode) => { - let ids = fields.iter().map(|(x, _)| x as _).collect(); - let fields = fields.iter().map(|(_, f)| f.into()).collect(); - Self::Union(fields, Some(ids), mode.into()) - }, - DataType::Map(f, ordered) => Self::Map(Box::new(f.into()), ordered), - DataType::Dictionary(key, value) => { - let key = match *key { - DataType::Int8 => IntegerType::Int8, - DataType::Int16 => IntegerType::Int16, - DataType::Int32 => IntegerType::Int32, - DataType::Int64 => IntegerType::Int64, - DataType::UInt8 => IntegerType::UInt8, - DataType::UInt16 => IntegerType::UInt16, - DataType::UInt32 => IntegerType::UInt32, - DataType::UInt64 => IntegerType::UInt64, - d => panic!("illegal dictionary key type: {d}"), - }; - Self::Dictionary(key, Box::new((*value).into()), false) - }, - DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _), - DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), - DataType::RunEndEncoded(_, _) => { - panic!("Run-end encoding not supported by polars_arrow") - }, - // This ensures that it doesn't fail to compile when new variants are added to Arrow - #[allow(unreachable_patterns)] - dtype => unimplemented!("unsupported datatype: {dtype}"), - } - } -} - /// Mode of [`ArrowDataType::Union`] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -322,26 +186,6 @@ pub enum UnionMode { Sparse, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::UnionMode { - fn from(value: UnionMode) -> Self { - match value { - UnionMode::Dense => Self::Dense, - UnionMode::Sparse => Self::Sparse, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for UnionMode { - fn from(value: arrow_schema::UnionMode) -> Self { - match value { - arrow_schema::UnionMode::Dense => Self::Dense, - arrow_schema::UnionMode::Sparse => Self::Sparse, - } - } -} - impl UnionMode { /// Constructs a [`UnionMode::Sparse`] if the input bool is true, /// or otherwise constructs a [`UnionMode::Dense`] @@ -378,30 +222,6 @@ pub enum TimeUnit { Nanosecond, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::TimeUnit { - fn from(value: TimeUnit) -> Self { - match value { - TimeUnit::Nanosecond => Self::Nanosecond, - TimeUnit::Millisecond => Self::Millisecond, - TimeUnit::Microsecond => Self::Microsecond, - TimeUnit::Second => Self::Second, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for TimeUnit { - fn from(value: arrow_schema::TimeUnit) -> Self { - match value { - arrow_schema::TimeUnit::Nanosecond => Self::Nanosecond, - arrow_schema::TimeUnit::Millisecond => Self::Millisecond, - arrow_schema::TimeUnit::Microsecond => Self::Microsecond, - arrow_schema::TimeUnit::Second => Self::Second, - } - } -} - /// Interval units defined in Arrow #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -415,28 +235,6 @@ pub enum IntervalUnit { MonthDayNano, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::IntervalUnit { - fn from(value: IntervalUnit) -> Self { - match value { - IntervalUnit::YearMonth => Self::YearMonth, - IntervalUnit::DayTime => Self::DayTime, - IntervalUnit::MonthDayNano => Self::MonthDayNano, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for IntervalUnit { - fn from(value: arrow_schema::IntervalUnit) -> Self { - match value { - arrow_schema::IntervalUnit::YearMonth => Self::YearMonth, - arrow_schema::IntervalUnit::DayTime => Self::DayTime, - arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano, - } - } -} - impl ArrowDataType { /// the [`PhysicalType`] of this [`ArrowDataType`]. pub fn to_physical_type(&self) -> PhysicalType { @@ -568,6 +366,25 @@ impl ArrowDataType { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } + pub fn is_numeric(&self) -> bool { + use ArrowDataType as D; + matches!( + self, + D::Int8 + | D::Int16 + | D::Int32 + | D::Int64 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Float32 + | D::Float64 + | D::Decimal(_, _) + | D::Decimal256(_, _) + ) + } + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { ArrowDataType::FixedSizeList( Box::new(Field::new( diff --git a/crates/polars-arrow/src/datatypes/physical_type.rs b/crates/polars-arrow/src/datatypes/physical_type.rs index 174c0401ca3f..732a129055a6 100644 --- a/crates/polars-arrow/src/datatypes/physical_type.rs +++ b/crates/polars-arrow/src/datatypes/physical_type.rs @@ -57,6 +57,10 @@ impl PhysicalType { false } } + + pub fn is_primitive(&self) -> bool { + matches!(self, Self::Primitive(_)) + } } /// the set of valid indices types of a dictionary-encoded Array. diff --git a/crates/polars-core/src/datatypes/reshape.rs b/crates/polars-arrow/src/datatypes/reshape.rs similarity index 100% rename from crates/polars-core/src/datatypes/reshape.rs rename to crates/polars-arrow/src/datatypes/reshape.rs diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs index 7f7e9c409782..5e179be6cca7 100644 --- a/crates/polars-arrow/src/ffi/array.rs +++ b/crates/polars-arrow/src/ffi/array.rs @@ -104,6 +104,7 @@ impl ArrowArray { ArrowDataType::BinaryView | ArrowDataType::Utf8View ); + #[allow(unused_mut)] let (offset, mut buffers, children, dictionary) = offset_buffers_children_dictionary(array.as_ref()); @@ -224,11 +225,7 @@ unsafe fn get_buffer_ptr( ); } - if array - .buffers - .align_offset(std::mem::align_of::<*mut *const u8>()) - != 0 - { + if array.buffers.align_offset(align_of::<*mut *const u8>()) != 0 { polars_bail!( ComputeError: "an ArrowArray of type {dtype:?} must have buffer {index} aligned to type {}", @@ -293,7 +290,7 @@ unsafe fn create_buffer( // We have to check alignment. // This is the zero-copy path. - if ptr.align_offset(std::mem::align_of::()) == 0 { + if ptr.align_offset(align_of::()) == 0 { let storage = SharedStorage::from_internal_arrow_array(ptr, len, owner); Ok(Buffer::from_storage(storage).sliced(offset, len - offset)) } @@ -623,7 +620,7 @@ pub struct ArrowArrayChild<'a> { parent: InternalArrowArray, } -impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { +impl ArrowArrayRef for ArrowArrayChild<'_> { /// the dtype as declared in the schema fn dtype(&self) -> &ArrowDataType { &self.dtype diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs index f9423f83305a..f2f8af90c167 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -195,9 +195,8 @@ fn deserialize_value<'a>( array.push(Some(value)) }, PrimitiveType::Float32 => { - let value = - f32::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); - block = &block[std::mem::size_of::()..]; + let value = f32::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; let array = array .as_mut_any() .downcast_mut::>() @@ -205,9 +204,8 @@ fn deserialize_value<'a>( array.push(Some(value)) }, PrimitiveType::Float64 => { - let value = - f64::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); - block = &block[std::mem::size_of::()..]; + let value = f64::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; let array = array .as_mut_any() .downcast_mut::>() @@ -404,10 +402,10 @@ fn skip_item<'a>( let _ = util::zigzag_i64(&mut block)?; }, PrimitiveType::Float32 => { - block = &block[std::mem::size_of::()..]; + block = &block[size_of::()..]; }, PrimitiveType::Float64 => { - block = &block[std::mem::size_of::()..]; + block = &block[size_of::()..]; }, PrimitiveType::MonthDayNano => { block = &block[12..]; @@ -507,7 +505,9 @@ pub fn deserialize( }? } } + RecordBatchT::try_new( + rows, arrays .iter_mut() .zip(projection.iter()) diff --git a/crates/polars-arrow/src/io/avro/read/nested.rs b/crates/polars-arrow/src/io/avro/read/nested.rs index fc7e07487d83..e197b827d732 100644 --- a/crates/polars-arrow/src/io/avro/read/nested.rs +++ b/crates/polars-arrow/src/io/avro/read/nested.rs @@ -211,6 +211,7 @@ impl MutableArray for FixedItemsUtf8Dictionary { #[derive(Debug)] pub struct DynMutableStructArray { dtype: ArrowDataType, + length: usize, values: Vec>, validity: Option, } @@ -219,6 +220,7 @@ impl DynMutableStructArray { pub fn new(values: Vec>, dtype: ArrowDataType) -> Self { Self { dtype, + length: 0, values, validity: None, } @@ -234,12 +236,14 @@ impl DynMutableStructArray { if let Some(validity) = &mut self.validity { validity.push(true) } + self.length += 1; Ok(()) } #[inline] fn push_null(&mut self) { self.values.iter_mut().for_each(|x| x.push_null()); + self.length += 1; match &mut self.validity { Some(validity) => validity.push(false), None => self.init_validity(), @@ -258,7 +262,7 @@ impl DynMutableStructArray { impl MutableArray for DynMutableStructArray { fn len(&self) -> usize { - self.values[0].len() + self.length } fn validity(&self) -> Option<&MutableBitmap> { @@ -270,6 +274,7 @@ impl MutableArray for DynMutableStructArray { Box::new(StructArray::new( self.dtype.clone(), + self.length, values, std::mem::take(&mut self.validity).map(|x| x.into()), )) @@ -280,6 +285,7 @@ impl MutableArray for DynMutableStructArray { std::sync::Arc::new(StructArray::new( self.dtype.clone(), + self.length, values, std::mem::take(&mut self.validity).map(|x| x.into()), )) diff --git a/crates/polars-arrow/src/io/flight/mod.rs b/crates/polars-arrow/src/io/flight/mod.rs deleted file mode 100644 index c02a4889f7bb..000000000000 --- a/crates/polars-arrow/src/io/flight/mod.rs +++ /dev/null @@ -1,241 +0,0 @@ -//! Serialization and deserialization to Arrow's flight protocol - -use arrow_format::flight::data::{FlightData, SchemaResult}; -use arrow_format::ipc; -use arrow_format::ipc::planus::ReadAsRoot; -use polars_error::{polars_bail, polars_err, PolarsResult}; - -use super::ipc::read::Dictionaries; -pub use super::ipc::write::default_ipc_fields; -use super::ipc::{IpcField, IpcSchema}; -use crate::array::Array; -use crate::datatypes::*; -pub use crate::io::ipc::write::common::WriteOptions; -use crate::io::ipc::write::common::{encode_chunk, DictionaryTracker, EncodedData}; -use crate::io::ipc::{read, write}; -use crate::record_batch::RecordBatchT; - -/// Serializes [`RecordBatchT`] to a vector of [`FlightData`] representing the serialized dictionaries -/// and a [`FlightData`] representing the batch. -/// # Errors -/// This function errors iff `fields` is not consistent with `columns` -pub fn serialize_batch( - chunk: &RecordBatchT>, - fields: &[IpcField], - options: &WriteOptions, -) -> PolarsResult<(Vec, FlightData)> { - if fields.len() != chunk.arrays().len() { - polars_bail!(oos = "The argument `fields` must be consistent with the columns' schema. Use e.g. &polars_arrow::io::flight::default_ipc_fields(&schema.fields)"); - } - - let mut dictionary_tracker = DictionaryTracker { - dictionaries: Default::default(), - cannot_replace: false, - }; - - let (encoded_dictionaries, encoded_batch) = - encode_chunk(chunk, fields, &mut dictionary_tracker, options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - let flight_batch = encoded_batch.into(); - - Ok((flight_dictionaries, flight_batch)) -} - -impl From for FlightData { - fn from(data: EncodedData) -> Self { - FlightData { - data_header: data.ipc_message, - data_body: data.arrow_data, - ..Default::default() - } - } -} - -/// Serializes a [`ArrowSchema`] to [`SchemaResult`]. -pub fn serialize_schema_to_result( - schema: &ArrowSchema, - ipc_fields: Option<&[IpcField]>, -) -> SchemaResult { - SchemaResult { - schema: _serialize_schema(schema, ipc_fields), - } -} - -/// Serializes a [`ArrowSchema`] to [`FlightData`]. -pub fn serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> FlightData { - FlightData { - data_header: _serialize_schema(schema, ipc_fields), - ..Default::default() - } -} - -/// Convert a [`ArrowSchema`] to bytes in the format expected in [`arrow_format::flight::data::FlightInfo`]. -pub fn serialize_schema_to_info( - schema: &ArrowSchema, - ipc_fields: Option<&[IpcField]>, -) -> PolarsResult> { - let encoded_data = if let Some(ipc_fields) = ipc_fields { - schema_as_encoded_data(schema, ipc_fields) - } else { - let ipc_fields = default_ipc_fields(schema.iter_values()); - schema_as_encoded_data(schema, &ipc_fields) - }; - - let mut schema = vec![]; - write::common_sync::write_message(&mut schema, &encoded_data)?; - Ok(schema) -} - -fn _serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> Vec { - if let Some(ipc_fields) = ipc_fields { - write::schema_to_bytes(schema, ipc_fields) - } else { - let ipc_fields = default_ipc_fields(schema.iter_values()); - write::schema_to_bytes(schema, &ipc_fields) - } -} - -fn schema_as_encoded_data(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> EncodedData { - EncodedData { - ipc_message: write::schema_to_bytes(schema, ipc_fields), - arrow_data: vec![], - } -} - -/// Deserialize an IPC message into [`ArrowSchema`], [`IpcSchema`]. -/// Use to deserialize [`FlightData::data_header`] and [`SchemaResult::schema`]. -pub fn deserialize_schemas(bytes: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema)> { - read::deserialize_schema(bytes) -} - -/// Deserializes [`FlightData`] representing a record batch message to [`RecordBatchT`]. -pub fn deserialize_batch( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &read::Dictionaries, -) -> PolarsResult>> { - // check that the data_header is a record batch message - let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header) - .map_err(|_err| polars_err!(oos = "Unable to get root as message: {err:?}"))?; - - let length = data.data_body.len(); - let mut reader = std::io::Cursor::new(&data.data_body); - - match message.header()?.ok_or_else(|| { - polars_err!(oos = "Unable to convert flight data header to a record batch".to_string()) - })? { - ipc::MessageHeaderRef::RecordBatch(batch) => read::read_record_batch( - batch, - fields, - ipc_schema, - None, - None, - dictionaries, - message.version()?, - &mut reader, - 0, - length as u64, - &mut Default::default(), - ), - _ => polars_bail!(oos = "flight currently only supports reading RecordBatch messages"), - } -} - -/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`. -pub fn deserialize_dictionary( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &mut read::Dictionaries, -) -> PolarsResult<()> { - let message = ipc::MessageRef::read_as_root(&data.data_header)?; - - let chunk = if let ipc::MessageHeaderRef::DictionaryBatch(chunk) = message - .header()? - .ok_or_else(|| polars_err!(oos = "Header is required"))? - { - chunk - } else { - return Ok(()); - }; - - let length = data.data_body.len(); - let mut reader = std::io::Cursor::new(&data.data_body); - read::read_dictionary( - chunk, - fields, - ipc_schema, - dictionaries, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - - Ok(()) -} - -/// Deserializes [`FlightData`] into either a [`RecordBatchT`] (when the message is a record batch) -/// or by upserting into `dictionaries` (when the message is a dictionary) -pub fn deserialize_message( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &mut Dictionaries, -) -> PolarsResult>>> { - let FlightData { - data_header, - data_body, - .. - } = data; - - let message = arrow_format::ipc::MessageRef::read_as_root(data_header)?; - let header = message - .header()? - .ok_or_else(|| polars_err!(oos = "IPC Message must contain a header"))?; - - match header { - ipc::MessageHeaderRef::RecordBatch(batch) => { - let length = data_body.len(); - let mut reader = std::io::Cursor::new(data_body); - - let chunk = read::read_record_batch( - batch, - fields, - ipc_schema, - None, - None, - dictionaries, - arrow_format::ipc::MetadataVersion::V5, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - - Ok(chunk.into()) - }, - ipc::MessageHeaderRef::DictionaryBatch(dict_batch) => { - let length = data_body.len(); - let mut reader = std::io::Cursor::new(data_body); - - read::read_dictionary( - dict_batch, - fields, - ipc_schema, - dictionaries, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - Ok(None) - }, - t => polars_bail!(ComputeError: - "Reading types other than record batches not yet supported, unable to read {t:?}" - ), - } -} diff --git a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs index 5cf68f1d1d95..a4b72af817e9 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs @@ -7,6 +7,7 @@ use super::super::super::IpcField; use super::super::deserialize::{read, skip}; use super::super::read_basic::*; use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use super::try_get_array_length; use crate::array::StructArray; use crate::datatypes::ArrowDataType; use crate::io::ipc::read::array::try_get_field_node; @@ -28,6 +29,7 @@ pub fn read_struct( scratch: &mut Vec, ) -> PolarsResult { let field_node = try_get_field_node(field_nodes, &dtype)?; + let length = try_get_array_length(field_node, limit)?; let validity = read_validity( buffers, @@ -64,7 +66,7 @@ pub fn read_struct( }) .collect::>>()?; - StructArray::try_new(dtype, values, validity) + StructArray::try_new(dtype, length, values, validity) } pub fn skip_struct( diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index 2458cba702b9..6b893c0e8ce3 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -42,7 +42,7 @@ impl<'a, A, I: Iterator> ProjectionIter<'a, A, I> { } } -impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { +impl> Iterator for ProjectionIter<'_, A, I> { type Item = ProjectionResult; fn next(&mut self) -> Option { @@ -188,7 +188,16 @@ pub fn read_record_batch( }) .collect::>>()? }; - RecordBatchT::try_new(columns) + + let length = batch + .length() + .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData)) + .unwrap() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + RecordBatchT::try_new(length, columns) } fn find_first_dict_field_d<'a>( @@ -353,6 +362,8 @@ pub fn apply_projection( chunk: RecordBatchT>, map: &PlHashMap, ) -> RecordBatchT> { + let length = chunk.len(); + // re-order according to projection let arrays = chunk.into_arrays(); let mut new_arrays = arrays.clone(); @@ -360,7 +371,7 @@ pub fn apply_projection( map.iter() .for_each(|(old, new)| new_arrays[*new] = arrays[*old].clone()); - RecordBatchT::new(new_arrays) + RecordBatchT::new(length, new_arrays) } #[cfg(test)] diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index 6c831064d5a1..a83e1b758d80 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -9,7 +9,7 @@ use polars_utils::aliases::{InitHashMaps, PlHashMap}; use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; use super::common::*; use super::schema::fb_to_schema; -use super::{Dictionaries, OutOfSpecKind}; +use super::{Dictionaries, OutOfSpecKind, SendableIterator}; use crate::array::Array; use crate::datatypes::ArrowSchemaRef; use crate::io::ipc::IpcSchema; @@ -129,14 +129,7 @@ pub fn read_file_dictionaries( Ok(dictionaries) } -/// Reads the footer's length and magic number in footer -fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> { - // read footer length and magic number in footer - let end = reader.seek(SeekFrom::End(-10))? + 10; - - let mut footer: [u8; 10] = [0; 10]; - - reader.read_exact(&mut footer)?; +pub(super) fn decode_footer_len(footer: [u8; 10], end: u64) -> PolarsResult<(u64, usize)> { let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC_V2 { @@ -152,6 +145,17 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> Ok((end, footer_len)) } +/// Reads the footer's length and magic number in footer +fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10))? + 10; + + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer)?; + decode_footer_len(footer, end) +} + fn read_footer(reader: &mut R, footer_len: usize) -> PolarsResult> { // read footer reader.seek(SeekFrom::End(-10 - footer_len as i64))?; @@ -187,29 +191,61 @@ fn deserialize_footer_blocks( Ok((footer, blocks)) } -pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { - let (footer, blocks) = deserialize_footer_blocks(footer_data)?; +pub(super) fn deserialize_footer_ref(footer_data: &[u8]) -> PolarsResult { + arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err))) +} - let ipc_schema = footer +pub(super) fn deserialize_schema_ref_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult { + footer .schema() .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferSchema(err)))? - .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingSchema))?; - let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingSchema)) +} + +/// Get the IPC blocks from the footer containing record batches +pub(super) fn iter_recordbatch_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_> { + let blocks = footer + .record_batches() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingRecordBatches))?; + + Ok(blocks.iter().map(|block| { + block + .try_into() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + })) +} +pub(super) fn iter_dictionary_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_>> +{ let dictionaries = footer .dictionaries() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))? - .map(|dictionaries| { - dictionaries - .into_iter() - .map(|block| { - block.try_into().map_err(|err| { - polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) - }) - }) - .collect::>>() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))?; + + Ok(dictionaries.map(|dicts| { + dicts.into_iter().map(|block| { + block.try_into().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) }) + })) +} + +pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { + let footer = deserialize_footer_ref(footer_data)?; + let blocks = iter_recordbatch_blocks_from_footer(footer)?.collect::>>()?; + let dictionaries = iter_dictionary_blocks_from_footer(footer)? + .map(|dicts| dicts.collect::>>()) .transpose()?; + let ipc_schema = deserialize_schema_ref_from_footer(footer)?; + let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; Ok(FileMetadata { schema: Arc::new(schema), diff --git a/crates/polars-arrow/src/io/ipc/read/file_async.rs b/crates/polars-arrow/src/io/ipc/read/file_async.rs deleted file mode 100644 index 567a58c1a1fb..000000000000 --- a/crates/polars-arrow/src/io/ipc/read/file_async.rs +++ /dev/null @@ -1,350 +0,0 @@ -//! Async reader for Arrow IPC files -use std::io::SeekFrom; - -use arrow_format::ipc::planus::ReadAsRoot; -use arrow_format::ipc::{Block, MessageHeaderRef}; -use futures::stream::BoxStream; -use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt}; -use polars_error::{polars_bail, polars_err, PolarsResult}; -use polars_utils::aliases::PlHashMap; - -use super::common::{apply_projection, prepare_projection, read_dictionary, read_record_batch}; -use super::file::{deserialize_footer, get_record_batch}; -use super::{Dictionaries, FileMetadata, OutOfSpecKind}; -use crate::array::*; -use crate::datatypes::ArrowSchema; -use crate::io::ipc::{IpcSchema, ARROW_MAGIC_V2, CONTINUATION_MARKER}; -use crate::record_batch::RecordBatchT; - -/// Async reader for Arrow IPC files -pub struct FileStream<'a> { - stream: BoxStream<'a, PolarsResult>>>, - schema: Option, - metadata: FileMetadata, -} - -impl<'a> FileStream<'a> { - /// Create a new IPC file reader. - /// - /// # Examples - /// See [`FileSink`](crate::io::ipc::write::file_async::FileSink). - pub fn new( - reader: R, - metadata: FileMetadata, - projection: Option>, - limit: Option, - ) -> Self - where - R: AsyncRead + AsyncSeek + Unpin + Send + 'a, - { - let (projection, schema) = if let Some(projection) = projection { - let (p, h, schema) = prepare_projection(&metadata.schema, projection); - (Some((p, h)), Some(schema)) - } else { - (None, None) - }; - - let stream = Self::stream(reader, None, metadata.clone(), projection, limit); - Self { - stream, - metadata, - schema, - } - } - - /// Get the metadata from the IPC file. - pub fn metadata(&self) -> &FileMetadata { - &self.metadata - } - - /// Get the projected schema from the IPC file. - pub fn schema(&self) -> &ArrowSchema { - self.schema.as_ref().unwrap_or(&self.metadata.schema) - } - - fn stream( - mut reader: R, - mut dictionaries: Option, - metadata: FileMetadata, - projection: Option<(Vec, PlHashMap)>, - limit: Option, - ) -> BoxStream<'a, PolarsResult>>> - where - R: AsyncRead + AsyncSeek + Unpin + Send + 'a, - { - async_stream::try_stream! { - // read dictionaries - cached_read_dictionaries(&mut reader, &metadata, &mut dictionaries).await?; - - let mut meta_buffer = Default::default(); - let mut block_buffer = Default::default(); - let mut scratch = Default::default(); - let mut remaining = limit.unwrap_or(usize::MAX); - for block in 0..metadata.blocks.len() { - let chunk = read_batch( - &mut reader, - dictionaries.as_mut().unwrap(), - &metadata, - projection.as_ref().map(|x| x.0.as_ref()), - Some(remaining), - block, - &mut meta_buffer, - &mut block_buffer, - &mut scratch - ).await?; - remaining -= chunk.len(); - - let chunk = if let Some((_, map)) = &projection { - // re-order according to projection - apply_projection(chunk, map) - } else { - chunk - }; - - yield chunk; - } - } - .boxed() - } -} - -impl<'a> Stream for FileStream<'a> { - type Item = PolarsResult>>; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.get_mut().stream.poll_next_unpin(cx) - } -} - -/// Reads the footer's length and magic number in footer -async fn read_footer_len(reader: &mut R) -> PolarsResult { - // read footer length and magic number in footer - reader.seek(SeekFrom::End(-10)).await?; - let mut footer: [u8; 10] = [0; 10]; - - reader.read_exact(&mut footer).await?; - let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); - - if footer[4..] != ARROW_MAGIC_V2 { - polars_bail!(oos = OutOfSpecKind::InvalidFooter) - } - footer_len - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength)) -} - -/// Read the metadata from an IPC file. -pub async fn read_file_metadata_async(reader: &mut R) -> PolarsResult -where - R: AsyncRead + AsyncSeek + Unpin, -{ - let footer_size = read_footer_len(reader).await?; - // Read footer - reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?; - - let mut footer = vec![]; - footer.try_reserve(footer_size)?; - reader - .take(footer_size as u64) - .read_to_end(&mut footer) - .await?; - - deserialize_footer(&footer, u64::MAX) -} - -#[allow(clippy::too_many_arguments)] -async fn read_batch( - mut reader: R, - dictionaries: &mut Dictionaries, - metadata: &FileMetadata, - projection: Option<&[usize]>, - limit: Option, - block: usize, - meta_buffer: &mut Vec, - block_buffer: &mut Vec, - scratch: &mut Vec, -) -> PolarsResult>> -where - R: AsyncRead + AsyncSeek + Unpin, -{ - let block = metadata.blocks[block]; - - let offset: u64 = block - .offset - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - - reader.seek(SeekFrom::Start(offset)).await?; - let mut meta_buf = [0; 4]; - reader.read_exact(&mut meta_buf).await?; - if meta_buf == CONTINUATION_MARKER { - reader.read_exact(&mut meta_buf).await?; - } - - let meta_len = i32::from_le_bytes(meta_buf) - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; - - meta_buffer.clear(); - meta_buffer.try_reserve(meta_len)?; - (&mut reader) - .take(meta_len as u64) - .read_to_end(meta_buffer) - .await?; - - let message = arrow_format::ipc::MessageRef::read_as_root(meta_buffer) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - - let batch = get_record_batch(message)?; - - let block_length: usize = message - .body_length() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; - - block_buffer.clear(); - block_buffer.try_reserve(block_length)?; - reader - .take(block_length as u64) - .read_to_end(block_buffer) - .await?; - - let mut cursor = std::io::Cursor::new(&block_buffer); - - read_record_batch( - batch, - &metadata.schema, - &metadata.ipc_schema, - projection, - limit, - dictionaries, - message - .version() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferVersion(err)))?, - &mut cursor, - 0, - metadata.size, - scratch, - ) -} - -async fn read_dictionaries( - mut reader: R, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - blocks: &[Block], - scratch: &mut Vec, -) -> PolarsResult -where - R: AsyncRead + AsyncSeek + Unpin, -{ - let mut dictionaries = Default::default(); - let mut data: Vec = vec![]; - let mut buffer: Vec = vec![]; - - for block in blocks { - let offset: u64 = block - .offset - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - - let length: usize = block - .body_length - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - - read_dictionary_message(&mut reader, offset, &mut data).await?; - - let message = arrow_format::ipc::MessageRef::read_as_root(data.as_ref()) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - - let header = message - .header() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? - .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; - - match header { - MessageHeaderRef::DictionaryBatch(batch) => { - buffer.clear(); - buffer.try_reserve(length)?; - (&mut reader) - .take(length as u64) - .read_to_end(&mut buffer) - .await?; - let mut cursor = std::io::Cursor::new(&buffer); - read_dictionary( - batch, - fields, - ipc_schema, - &mut dictionaries, - &mut cursor, - 0, - u64::MAX, - scratch, - )?; - }, - _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType), - } - } - Ok(dictionaries) -} - -async fn read_dictionary_message( - mut reader: R, - offset: u64, - data: &mut Vec, -) -> PolarsResult<()> -where - R: AsyncRead + AsyncSeek + Unpin, -{ - let mut message_size = [0; 4]; - reader.seek(SeekFrom::Start(offset)).await?; - reader.read_exact(&mut message_size).await?; - if message_size == CONTINUATION_MARKER { - reader.read_exact(&mut message_size).await?; - } - let footer_size = i32::from_le_bytes(message_size); - - let footer_size: usize = footer_size - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - - data.clear(); - data.try_reserve(footer_size)?; - (&mut reader) - .take(footer_size as u64) - .read_to_end(data) - .await?; - - Ok(()) -} - -async fn cached_read_dictionaries( - reader: &mut R, - metadata: &FileMetadata, - dictionaries: &mut Option, -) -> PolarsResult<()> { - match (&dictionaries, metadata.dictionaries.as_deref()) { - (None, Some(blocks)) => { - let new_dictionaries: hashbrown::HashMap, ahash::RandomState> = - read_dictionaries( - reader, - &metadata.schema, - &metadata.ipc_schema, - blocks, - &mut Default::default(), - ) - .await?; - *dictionaries = Some(new_dictionaries); - }, - (None, None) => { - *dictionaries = Some(Default::default()); - }, - _ => {}, - }; - Ok(()) -} diff --git a/crates/polars-arrow/src/io/ipc/read/flight.rs b/crates/polars-arrow/src/io/ipc/read/flight.rs new file mode 100644 index 000000000000..8c35fa1cfd60 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/flight.rs @@ -0,0 +1,457 @@ +use std::io::SeekFrom; +use std::pin::Pin; +use std::sync::Arc; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef}; +use futures::{Stream, StreamExt}; +use polars_error::{polars_bail, polars_err, PolarsResult}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; + +use crate::datatypes::ArrowSchema; +use crate::io::ipc::read::common::read_record_batch; +use crate::io::ipc::read::file::{ + decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer, + iter_recordbatch_blocks_from_footer, +}; +use crate::io::ipc::read::schema::deserialize_stream_metadata; +use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata}; +use crate::io::ipc::write::common::EncodedData; +use crate::mmap::{mmap_dictionary_from_batch, mmap_record}; +use crate::record_batch::RecordBatch; + +async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>( + reader: &mut R, + block: &arrow_format::ipc::Block, + scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + reader.seek(SeekFrom::Start(offset)).await?; + read_ipc_message(reader, scratch).await +} + +/// Read an encapsulated IPC Message from the reader +async fn read_ipc_message<'a, R: AsyncRead + Unpin>( + reader: &mut R, + scratch: &'a mut Vec, +) -> PolarsResult> { + let mut message_size: [u8; 4] = [0; 4]; + + reader.read_exact(&mut message_size).await?; + if message_size == crate::io::ipc::CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + }; + let message_length = i32::from_le_bytes(message_size); + + let message_length: usize = message_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + scratch.clear(); + scratch.try_reserve(message_length)?; + reader + .take(message_length as u64) + .read_to_end(scratch) + .await?; + + arrow_format::ipc::MessageRef::read_as_root(scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +async fn read_footer_len( + reader: &mut R, +) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10)).await? + 10; + + let mut footer: [u8; 10] = [0; 10]; + reader.read_exact(&mut footer).await?; + + decode_footer_len(footer, end) +} + +async fn read_footer( + reader: &mut R, + footer_len: usize, +) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64)).await?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + + reader + .take(footer_len as u64) + .read_to_end(&mut serialized_footer) + .await?; + Ok(serialized_footer) +} + +fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData { + // Turn the IPC schema into an encapsulated message + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + // Assumed the conversion is infallible. + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new( + schema.try_into().unwrap(), + ))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = arrow_format::ipc::planus::Builder::new(); + let header = builder.finish(&message, None).to_vec(); + + // Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use + // `data_header` and `data_body`. + EncodedData { + ipc_message: header, + arrow_data: vec![], + } +} + +async fn block_to_raw_message<'a, R>( + reader: &mut R, + block: &arrow_format::ipc::Block, + encoded_data: &mut EncodedData, +) -> PolarsResult<()> +where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, +{ + debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty()); + let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?; + + let block_length: u64 = message + .body_length() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + reader + .take(block_length) + .read_to_end(&mut encoded_data.arrow_data) + .await?; + + Ok(()) +} + +pub async fn into_flight_stream( + reader: &mut R, +) -> PolarsResult> + '_> { + Ok(async_stream::try_stream! { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + let data_blocks = iter_recordbatch_blocks_from_footer(footer)?; + let dict_blocks = iter_dictionary_blocks_from_footer(footer)?; + + let schema_ref = deserialize_schema_ref_from_footer(footer)?; + let schema = schema_to_raw_message(schema_ref); + + yield schema; + + if let Some(dict_blocks_iter) = dict_blocks { + for d in dict_blocks_iter { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }; + + for d in data_blocks { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }) +} + +pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> { + footer: Option<*const FooterRef<'static>>, + footer_data: Vec, + dict_blocks: Option>>>, + data_blocks: Option>>>, + reader: &'a mut R, +} + +impl Drop for FlightStreamProducer<'_, R> { + fn drop(&mut self) { + if let Some(p) = self.footer { + unsafe { + let _ = Box::from_raw(p as *mut FooterRef<'static>); + } + } + } +} + +unsafe impl Send for FlightStreamProducer<'_, R> {} + +impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { + pub async fn new(reader: &'a mut R) -> PolarsResult>> { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + + Ok(Box::pin(Self { + footer: None, + footer_data, + dict_blocks: None, + data_blocks: None, + reader, + })) + } + + pub fn init(self: &mut Pin>) -> PolarsResult<()> { + let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let footer = Box::new(footer); + + #[allow(clippy::unnecessary_cast)] + let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>; + + self.footer = Some(ptr); + let footer = &unsafe { **self.footer.as_ref().unwrap() }; + + self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?) + as Box>); + self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)? + .map(|i| Box::new(i) as Box>); + + Ok(()) + } + + pub fn get_schema(self: &Pin>) -> PolarsResult { + let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") }; + + let schema_ref = deserialize_schema_ref_from_footer(*footer)?; + let schema = schema_to_raw_message(schema_ref); + + Ok(schema) + } + + pub async fn next_dict( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + assert!(self.data_blocks.is_some(), "init must be called first"); + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + if let Some(iter) = &mut self.dict_blocks { + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } else { + Ok(None) + } + } + + pub async fn next_data( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + let iter = self + .data_blocks + .as_mut() + .expect("init must be called first"); + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } +} + +pub struct FlightConsumer { + dictionaries: Dictionaries, + md: StreamMetadata, + scratch: Vec, +} + +impl FlightConsumer { + pub fn new(first: EncodedData) -> PolarsResult { + let md = deserialize_stream_metadata(&first.ipc_message)?; + Ok(Self { + dictionaries: Default::default(), + md, + scratch: vec![], + }) + } + + pub fn schema(&self) -> &ArrowSchema { + &self.md.schema + } + + pub fn consume(&mut self, msg: EncodedData) -> PolarsResult> { + // Parse the header + let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + + // Either append to the dictionaries and return None or return Some(ArrowChunk) + match header { + MessageHeaderRef::Schema(_) => { + polars_bail!(ComputeError: "Unexpected schema message while parsing Stream"); + }, + // Add to dictionary state and continue iteration + MessageHeaderRef::DictionaryBatch(batch) => unsafe { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + mmap_dictionary_from_batch( + &self.md.schema, + &self.md.ipc_schema.fields, + &arrow_data, + batch, + &mut self.dictionaries, + 0, + ) + .map(|_| None) + }, + // Return Batch + MessageHeaderRef::RecordBatch(batch) => { + if batch.compression()?.is_some() { + let data_size = msg.arrow_data.len() as u64; + let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice()); + read_record_batch( + batch, + &self.md.schema, + &self.md.ipc_schema, + None, + None, + &self.dictionaries, + self.md.version, + &mut reader, + 0, + data_size, + &mut self.scratch, + ) + .map(Some) + } else { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + unsafe { + mmap_record( + &self.md.schema, + &self.md.ipc_schema.fields, + arrow_data.clone(), + batch, + 0, + &self.dictionaries, + ) + .map(Some) + } + } + }, + _ => unimplemented!(), + } + } +} + +pub struct FlightstreamConsumer> + Unpin> { + inner: FlightConsumer, + stream: S, +} + +impl> + Unpin> FlightstreamConsumer { + pub async fn new(mut stream: S) -> PolarsResult { + let Some(first) = stream.next().await else { + polars_bail!(ComputeError: "expected the schema") + }; + let first = first?; + + Ok(FlightstreamConsumer { + inner: FlightConsumer::new(first)?, + stream, + }) + } + + pub async fn next_batch(&mut self) -> PolarsResult> { + while let Some(msg) = self.stream.next().await { + let msg = msg?; + let option_recordbatch = self.inner.consume(msg)?; + if option_recordbatch.is_some() { + return Ok(option_recordbatch); + } + } + Ok(None) + } +} + +#[cfg(test)] +mod test { + use std::path::{Path, PathBuf}; + + use tokio::fs::File; + + use super::*; + use crate::record_batch::RecordBatch; + + fn get_file_path() -> PathBuf { + let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc") + } + + fn read_file(path: &Path) -> RecordBatch { + let mut file = std::fs::File::open(path).unwrap(); + let md = crate::io::ipc::read::read_file_metadata(&mut file).unwrap(); + let mut ipc_reader = crate::io::ipc::read::FileReader::new(&mut file, md, None, None); + ipc_reader.next().unwrap().unwrap() + } + + #[tokio::test] + async fn test_file_flight_simple() { + let path = &get_file_path(); + let mut file = tokio::fs::File::open(path).await.unwrap(); + let stream = into_flight_stream(&mut file).await.unwrap(); + + let mut c = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } + + #[tokio::test] + async fn test_file_flight_amortized() { + let path = &get_file_path(); + let mut file = File::open(path).await.unwrap(); + let mut p = FlightStreamProducer::new(&mut file).await.unwrap(); + p.init().unwrap(); + + let mut batches = vec![]; + + let schema = p.get_schema().unwrap(); + batches.push(schema); + + let mut ed = EncodedData::default(); + if p.next_dict(&mut ed).await.unwrap().is_some() { + batches.push(ed); + } + + let mut ed = EncodedData::default(); + p.next_data(&mut ed).await.unwrap(); + batches.push(ed); + + let mut c = + FlightstreamConsumer::new(Box::pin(futures::stream::iter(batches.into_iter().map(Ok)))) + .await + .unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs index 74d9a93a9309..88411f9b905f 100644 --- a/crates/polars-arrow/src/io/ipc/read/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -11,27 +11,18 @@ mod common; mod deserialize; mod error; pub(crate) mod file; +#[cfg(feature = "io_flight")] +mod flight; mod read_basic; mod reader; mod schema; mod stream; -pub use error::OutOfSpecKind; -pub use file::get_row_count; - -#[cfg(feature = "io_ipc_read_async")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] -pub mod stream_async; - -#[cfg(feature = "io_ipc_read_async")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] -pub mod file_async; - pub(crate) use common::first_dict_field; -#[cfg(feature = "io_flight")] -pub(crate) use common::{read_dictionary, read_record_batch}; +pub use error::OutOfSpecKind; pub use file::{ - deserialize_footer, read_batch, read_file_dictionaries, read_file_metadata, FileMetadata, + deserialize_footer, get_row_count, read_batch, read_file_dictionaries, read_file_metadata, + FileMetadata, }; use polars_utils::aliases::PlHashMap; pub use reader::FileReader; @@ -45,3 +36,10 @@ pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; pub(crate) type Version = arrow_format::ipc::MetadataVersion; + +#[cfg(feature = "io_flight")] +pub use flight::*; + +pub trait SendableIterator: Send + Iterator {} + +impl SendableIterator for T {} diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs index 09005ea4222e..9721137adf5c 100644 --- a/crates/polars-arrow/src/io/ipc/read/read_basic.rs +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -17,10 +17,10 @@ fn read_swapped( is_little_endian: bool, ) -> PolarsResult<()> { // slow case where we must reverse bits - let mut slice = vec![0u8; length * std::mem::size_of::()]; + let mut slice = vec![0u8; length * size_of::()]; reader.read_exact(&mut slice)?; - let chunks = slice.chunks_exact(std::mem::size_of::()); + let chunks = slice.chunks_exact(size_of::()); if !is_little_endian { // machine is little endian, file is big endian buffer @@ -67,7 +67,7 @@ fn read_uncompressed_buffer( length: usize, is_little_endian: bool, ) -> PolarsResult> { - let required_number_of_bytes = length.saturating_mul(std::mem::size_of::()); + let required_number_of_bytes = length.saturating_mul(size_of::()); if required_number_of_bytes > buffer_length { polars_bail!( oos = OutOfSpecKind::InvalidBuffer { diff --git a/crates/polars-arrow/src/io/ipc/read/schema.rs b/crates/polars-arrow/src/io/ipc/read/schema.rs index 7fe6141e9b14..3e665a884eed 100644 --- a/crates/polars-arrow/src/io/ipc/read/schema.rs +++ b/crates/polars-arrow/src/io/ipc/read/schema.rs @@ -157,9 +157,6 @@ fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField) let fields = field .children()? .ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?; - if fields.is_empty() { - polars_bail!(oos = "IPC: Struct must contain at least one child"); - } let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { let (field, fields) = deserialize_field(field?)?; Ok((field, fields)) diff --git a/crates/polars-arrow/src/io/ipc/read/stream_async.rs b/crates/polars-arrow/src/io/ipc/read/stream_async.rs deleted file mode 100644 index ab29550d8a14..000000000000 --- a/crates/polars-arrow/src/io/ipc/read/stream_async.rs +++ /dev/null @@ -1,238 +0,0 @@ -//! APIs to read Arrow streams asynchronously - -use arrow_format::ipc::planus::ReadAsRoot; -use futures::future::BoxFuture; -use futures::{AsyncRead, AsyncReadExt, FutureExt, Stream}; -use polars_error::*; - -use super::super::CONTINUATION_MARKER; -use super::common::{read_dictionary, read_record_batch}; -use super::schema::deserialize_stream_metadata; -use super::{Dictionaries, OutOfSpecKind, StreamMetadata}; -use crate::array::*; -use crate::record_batch::RecordBatchT; - -/// A (private) state of stream messages -struct ReadState { - pub reader: R, - pub metadata: StreamMetadata, - pub dictionaries: Dictionaries, - /// The internal buffer to read data inside the messages (records and dictionaries) to - pub data_buffer: Vec, - /// The internal buffer to read messages to - pub message_buffer: Vec, -} - -/// The state of an Arrow stream -enum StreamState { - /// The stream does not contain new chunks (and it has not been closed) - Waiting(ReadState), - /// The stream contain a new chunk - Some((ReadState, RecordBatchT>)), -} - -/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously -pub async fn read_stream_metadata_async( - reader: &mut R, -) -> PolarsResult { - // determine metadata length - let mut meta_size: [u8; 4] = [0; 4]; - reader.read_exact(&mut meta_size).await?; - let meta_len = { - // If a continuation marker is encountered, skip over it and read - // the size from the next four bytes. - if meta_size == CONTINUATION_MARKER { - reader.read_exact(&mut meta_size).await?; - } - i32::from_le_bytes(meta_size) - }; - - let meta_len: usize = meta_len.try_into().map_err( - |_| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), - )?; - - let mut meta_buffer = vec![]; - meta_buffer.try_reserve(meta_len)?; - reader - .take(meta_len as u64) - .read_to_end(&mut meta_buffer) - .await?; - - deserialize_stream_metadata(&meta_buffer) -} - -/// Reads the next item, yielding `None` if the stream has been closed, -/// or a [`StreamState`] otherwise. -async fn maybe_next( - mut state: ReadState, -) -> PolarsResult>> { - let mut scratch = Default::default(); - // determine metadata length - let mut meta_length: [u8; 4] = [0; 4]; - - match state.reader.read_exact(&mut meta_length).await { - Ok(()) => (), - Err(e) => { - return if e.kind() == std::io::ErrorKind::UnexpectedEof { - // Handle EOF without the "0xFFFFFFFF 0x00000000" - // valid according to: - // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format - Ok(Some(StreamState::Waiting(state))) - } else { - Err(PolarsError::from(e)) - }; - }, - } - - let meta_length = { - // If a continuation marker is encountered, skip over it and read - // the size from the next four bytes. - if meta_length == CONTINUATION_MARKER { - state.reader.read_exact(&mut meta_length).await?; - } - i32::from_le_bytes(meta_length) - }; - - let meta_length: usize = meta_length.try_into().map_err( - |_| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), - )?; - - if meta_length == 0 { - // the stream has ended, mark the reader as finished - return Ok(None); - } - - state.message_buffer.clear(); - state.message_buffer.try_reserve(meta_length)?; - (&mut state.reader) - .take(meta_length as u64) - .read_to_end(&mut state.message_buffer) - .await?; - - let message = arrow_format::ipc::MessageRef::read_as_root(state.message_buffer.as_ref()) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - - let header = message - .header() - .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferHeader(err)))? - .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; - - let block_length: usize = message - .body_length() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? - .try_into() - .map_err(|_err| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; - - match header { - arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { - state.data_buffer.clear(); - state.data_buffer.try_reserve(block_length)?; - (&mut state.reader) - .take(block_length as u64) - .read_to_end(&mut state.data_buffer) - .await?; - - let chunk = read_record_batch( - batch, - &state.metadata.schema, - &state.metadata.ipc_schema, - None, - None, - &state.dictionaries, - state.metadata.version, - &mut std::io::Cursor::new(&state.data_buffer), - 0, - state.data_buffer.len() as u64, - &mut scratch, - )?; - - Ok(Some(StreamState::Some((state, chunk)))) - }, - arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { - state.data_buffer.clear(); - state.data_buffer.try_reserve(block_length)?; - (&mut state.reader) - .take(block_length as u64) - .read_to_end(&mut state.data_buffer) - .await?; - - let file_size = state.data_buffer.len() as u64; - - let mut dict_reader = std::io::Cursor::new(&state.data_buffer); - - read_dictionary( - batch, - &state.metadata.schema, - &state.metadata.ipc_schema, - &mut state.dictionaries, - &mut dict_reader, - 0, - file_size, - &mut scratch, - )?; - - // read the next message until we encounter a Chunk> message - Ok(Some(StreamState::Waiting(state))) - }, - _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType), - } -} - -/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`RecordBatchT`]s. -pub struct AsyncStreamReader<'a, R: AsyncRead + Unpin + Send + 'a> { - metadata: StreamMetadata, - future: Option>>>>, -} - -impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncStreamReader<'a, R> { - /// Creates a new [`AsyncStreamReader`] - pub fn new(reader: R, metadata: StreamMetadata) -> Self { - let state = ReadState { - reader, - metadata: metadata.clone(), - dictionaries: Default::default(), - data_buffer: Default::default(), - message_buffer: Default::default(), - }; - let future = Some(maybe_next(state).boxed()); - Self { metadata, future } - } - - /// Return the schema of the stream - pub fn metadata(&self) -> &StreamMetadata { - &self.metadata - } -} - -impl<'a, R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<'a, R> { - type Item = PolarsResult>>; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - use std::pin::Pin; - use std::task::Poll; - let me = Pin::into_inner(self); - - match &mut me.future { - Some(fut) => match fut.as_mut().poll(cx) { - Poll::Ready(Ok(None)) => { - me.future = None; - Poll::Ready(None) - }, - Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => { - me.future = Some(Box::pin(maybe_next(state))); - Poll::Ready(Some(Ok(batch))) - }, - Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending, - Poll::Ready(Err(err)) => { - me.future = None; - Poll::Ready(Some(Err(err))) - }, - Poll::Pending => Poll::Pending, - }, - None => Poll::Ready(None), - } - } -} diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index a49c7fdcd790..6fa1b5f0d8c4 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -482,7 +482,7 @@ pub struct Record<'a> { fields: Option>, } -impl<'a> Record<'a> { +impl Record<'_> { /// Get the IPC fields for this record. pub fn fields(&self) -> Option<&[IpcField]> { self.fields.as_deref() diff --git a/crates/polars-arrow/src/io/ipc/write/common_async.rs b/crates/polars-arrow/src/io/ipc/write/common_async.rs deleted file mode 100644 index daadfcb5e25e..000000000000 --- a/crates/polars-arrow/src/io/ipc/write/common_async.rs +++ /dev/null @@ -1,66 +0,0 @@ -use futures::{AsyncWrite, AsyncWriteExt}; -use polars_error::PolarsResult; - -use super::super::CONTINUATION_MARKER; -use super::common::{pad_to_64, EncodedData}; - -/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -pub async fn write_message( - mut writer: W, - encoded: EncodedData, -) -> PolarsResult<(usize, usize)> { - let arrow_data_len = encoded.arrow_data.len(); - - let a = 64 - 1; - let buffer = encoded.ipc_message; - let flatbuf_size = buffer.len(); - let prefix_size = 8; // the message length - let aligned_size = (flatbuf_size + prefix_size + a) & !a; - let padding_bytes = aligned_size - flatbuf_size - prefix_size; - - write_continuation(&mut writer, (aligned_size - prefix_size) as i32).await?; - - // write the flatbuf - if flatbuf_size > 0 { - writer.write_all(&buffer).await?; - } - // write padding - writer.write_all(&vec![0; padding_bytes]).await?; - - // write arrow data - let body_len = if arrow_data_len > 0 { - write_body_buffers(writer, &encoded.arrow_data).await? - } else { - 0 - }; - - Ok((aligned_size, body_len)) -} - -/// Write a record batch to the writer, writing the message size before the message -/// if the record batch is being written to a stream -pub async fn write_continuation( - mut writer: W, - total_len: i32, -) -> PolarsResult { - writer.write_all(&CONTINUATION_MARKER).await?; - writer.write_all(&total_len.to_le_bytes()[..]).await?; - Ok(8) -} - -async fn write_body_buffers( - mut writer: W, - data: &[u8], -) -> PolarsResult { - let len = data.len(); - let pad_len = pad_to_64(data.len()); - let total_len = len + pad_len; - - // write body buffer - writer.write_all(data).await?; - if pad_len > 0 { - writer.write_all(&vec![0u8; pad_len][..]).await?; - } - - Ok(total_len) -} diff --git a/crates/polars-arrow/src/io/ipc/write/file_async.rs b/crates/polars-arrow/src/io/ipc/write/file_async.rs deleted file mode 100644 index aaae101785bc..000000000000 --- a/crates/polars-arrow/src/io/ipc/write/file_async.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Async writer for IPC files. - -use std::task::Poll; - -use arrow_format::ipc::planus::Builder; -use arrow_format::ipc::{Block, Footer, MetadataVersion}; -use futures::future::BoxFuture; -use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; -use polars_error::{PolarsError, PolarsResult}; - -use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; -use super::common_async::{write_continuation, write_message}; -use super::schema::serialize_schema; -use super::{default_ipc_fields, schema_to_bytes, Record}; -use crate::datatypes::*; -use crate::io::ipc::{IpcField, ARROW_MAGIC_V2}; - -type WriteOutput = (usize, Option, Vec, Option); - -/// Sink that writes array [`chunks`](crate::record_batch::RecordBatchT) as an IPC file. -/// -/// The file header is automatically written before writing the first chunk, and the file footer is -/// automatically written when the sink is closed. -pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> { - writer: Option, - task: Option>>>, - options: WriteOptions, - dictionary_tracker: DictionaryTracker, - offset: usize, - fields: Vec, - record_blocks: Vec, - dictionary_blocks: Vec, - schema: ArrowSchema, -} - -impl<'a, W> FileSink<'a, W> -where - W: AsyncWrite + Unpin + Send + 'a, -{ - /// Create a new file writer. - pub fn new( - writer: W, - schema: ArrowSchema, - ipc_fields: Option>, - options: WriteOptions, - ) -> Self { - let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(schema.iter_values())); - let encoded = EncodedData { - ipc_message: schema_to_bytes(&schema, &fields), - arrow_data: vec![], - }; - let task = Some(Self::start(writer, encoded).boxed()); - Self { - writer: None, - task, - options, - fields, - offset: 0, - schema, - dictionary_tracker: DictionaryTracker { - dictionaries: Default::default(), - cannot_replace: true, - }, - record_blocks: vec![], - dictionary_blocks: vec![], - } - } - - async fn start(mut writer: W, encoded: EncodedData) -> PolarsResult> { - writer.write_all(&ARROW_MAGIC_V2[..]).await?; - writer.write_all(&[0, 0]).await?; - let (meta, data) = write_message(&mut writer, encoded).await?; - - Ok((meta + data + 8, None, vec![], Some(writer))) - } - - async fn write( - mut writer: W, - mut offset: usize, - record: EncodedData, - dictionaries: Vec, - ) -> PolarsResult> { - let mut dict_blocks = vec![]; - for dict in dictionaries { - let (meta, data) = write_message(&mut writer, dict).await?; - let block = Block { - offset: offset as i64, - meta_data_length: meta as i32, - body_length: data as i64, - }; - dict_blocks.push(block); - offset += meta + data; - } - let (meta, data) = write_message(&mut writer, record).await?; - let block = Block { - offset: offset as i64, - meta_data_length: meta as i32, - body_length: data as i64, - }; - offset += meta + data; - Ok((offset, Some(block), dict_blocks, Some(writer))) - } - - async fn finish(mut writer: W, footer: Footer) -> PolarsResult> { - write_continuation(&mut writer, 0).await?; - let footer = { - let mut builder = Builder::new(); - builder.finish(&footer, None).to_owned() - }; - writer.write_all(&footer[..]).await?; - writer - .write_all(&(footer.len() as i32).to_le_bytes()) - .await?; - writer.write_all(&ARROW_MAGIC_V2).await?; - writer.close().await?; - - Ok((0, None, vec![], None)) - } - - fn poll_write(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - if let Some(task) = &mut self.task { - match futures::ready!(task.poll_unpin(cx)) { - Ok((offset, record, mut dictionaries, writer)) => { - self.task = None; - self.writer = writer; - self.offset = offset; - if let Some(block) = record { - self.record_blocks.push(block); - } - self.dictionary_blocks.append(&mut dictionaries); - Poll::Ready(Ok(())) - }, - Err(error) => { - self.task = None; - Poll::Ready(Err(error)) - }, - } - } else { - Poll::Ready(Ok(())) - } - } -} - -impl<'a, W> Sink> for FileSink<'a, W> -where - W: AsyncWrite + Unpin + Send + 'a, -{ - type Error = PolarsError; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.get_mut().poll_write(cx) - } - - fn start_send(self: std::pin::Pin<&mut Self>, item: Record<'_>) -> PolarsResult<()> { - let this = self.get_mut(); - - if let Some(writer) = this.writer.take() { - let fields = item.fields().unwrap_or_else(|| &this.fields[..]); - - let (dictionaries, record) = encode_chunk( - item.columns(), - fields, - &mut this.dictionary_tracker, - &this.options, - )?; - - this.task = Some(Self::write(writer, this.offset, record, dictionaries).boxed()); - Ok(()) - } else { - let io_err = std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "writer is closed"); - Err(PolarsError::from(io_err)) - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.get_mut().poll_write(cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - match futures::ready!(this.poll_write(cx)) { - Ok(()) => { - if let Some(writer) = this.writer.take() { - let schema = serialize_schema(&this.schema, &this.fields); - let footer = Footer { - version: MetadataVersion::V5, - schema: Some(Box::new(schema)), - dictionaries: Some(std::mem::take(&mut this.dictionary_blocks)), - record_batches: Some(std::mem::take(&mut this.record_blocks)), - custom_metadata: None, - }; - this.task = Some(Self::finish(writer, footer).boxed()); - this.poll_write(cx) - } else { - Poll::Ready(Ok(())) - } - }, - Err(error) => Poll::Ready(Err(error)), - } - } -} diff --git a/crates/polars-arrow/src/io/ipc/write/mod.rs b/crates/polars-arrow/src/io/ipc/write/mod.rs index d8afc1571721..2291448d3012 100644 --- a/crates/polars-arrow/src/io/ipc/write/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/mod.rs @@ -5,7 +5,7 @@ mod serialize; mod stream; pub(crate) mod writer; -pub use common::{Compression, Record, WriteOptions}; +pub use common::{Compression, EncodedData, Record, WriteOptions}; pub use schema::schema_to_bytes; pub use serialize::write; use serialize::write_dictionary; @@ -14,16 +14,6 @@ pub use writer::FileWriter; pub(crate) mod common_sync; -#[cfg(feature = "io_ipc_write_async")] -mod common_async; -#[cfg(feature = "io_ipc_write_async")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] -pub mod stream_async; - -#[cfg(feature = "io_ipc_write_async")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] -pub mod file_async; - use super::IpcField; use crate::datatypes::{ArrowDataType, Field}; diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs index f13098477d4d..43af5add63a6 100644 --- a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs @@ -283,7 +283,7 @@ fn _write_buffer_from_iter>( is_little_endian: bool, ) { let len = buffer.size_hint().0; - arrow_data.reserve(len * std::mem::size_of::()); + arrow_data.reserve(len * size_of::()); if is_little_endian { buffer .map(|x| T::to_le_bytes(&x)) @@ -303,7 +303,7 @@ fn _write_compressed_buffer_from_iter>( compression: Compression, ) { let len = buffer.size_hint().0; - let mut swapped = Vec::with_capacity(len * std::mem::size_of::()); + let mut swapped = Vec::with_capacity(len * size_of::()); if is_little_endian { buffer .map(|x| T::to_le_bytes(&x)) diff --git a/crates/polars-arrow/src/io/ipc/write/stream_async.rs b/crates/polars-arrow/src/io/ipc/write/stream_async.rs deleted file mode 100644 index 3718d6f82b29..000000000000 --- a/crates/polars-arrow/src/io/ipc/write/stream_async.rs +++ /dev/null @@ -1,158 +0,0 @@ -//! `async` writing of arrow streams - -use std::pin::Pin; -use std::task::Poll; - -use futures::future::BoxFuture; -use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; -use polars_error::{PolarsError, PolarsResult}; - -use super::super::IpcField; -pub use super::common::WriteOptions; -use super::common::{encode_chunk, DictionaryTracker, EncodedData}; -use super::common_async::{write_continuation, write_message}; -use super::{default_ipc_fields, schema_to_bytes, Record}; -use crate::datatypes::*; - -/// A sink that writes array [`chunks`](crate::record_batch::RecordBatchT) as an IPC stream. -/// -/// The stream header is automatically written before writing the first chunk. -pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> { - writer: Option, - task: Option>>>, - options: WriteOptions, - dictionary_tracker: DictionaryTracker, - fields: Vec, -} - -impl<'a, W> StreamSink<'a, W> -where - W: AsyncWrite + Unpin + Send + 'a, -{ - /// Create a new [`StreamSink`]. - pub fn new( - writer: W, - schema: &ArrowSchema, - ipc_fields: Option>, - write_options: WriteOptions, - ) -> Self { - let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(schema.iter_values())); - let task = Some(Self::start(writer, schema, &fields[..])); - Self { - writer: None, - task, - fields, - dictionary_tracker: DictionaryTracker { - dictionaries: Default::default(), - cannot_replace: false, - }, - options: write_options, - } - } - - fn start( - mut writer: W, - schema: &ArrowSchema, - ipc_fields: &[IpcField], - ) -> BoxFuture<'a, PolarsResult>> { - let message = EncodedData { - ipc_message: schema_to_bytes(schema, ipc_fields), - arrow_data: vec![], - }; - async move { - write_message(&mut writer, message).await?; - Ok(Some(writer)) - } - .boxed() - } - - fn write(&mut self, record: Record<'_>) -> PolarsResult<()> { - let fields = record.fields().unwrap_or(&self.fields[..]); - let (dictionaries, message) = encode_chunk( - record.columns(), - fields, - &mut self.dictionary_tracker, - &self.options, - )?; - - if let Some(mut writer) = self.writer.take() { - self.task = Some( - async move { - for d in dictionaries { - write_message(&mut writer, d).await?; - } - write_message(&mut writer, message).await?; - Ok(Some(writer)) - } - .boxed(), - ); - Ok(()) - } else { - let io_err = std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "writer closed".to_string(), - ); - Err(PolarsError::from(io_err)) - } - } - - fn poll_complete(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - if let Some(task) = &mut self.task { - match futures::ready!(task.poll_unpin(cx)) { - Ok(writer) => { - self.writer = writer; - self.task = None; - Poll::Ready(Ok(())) - }, - Err(error) => { - self.task = None; - Poll::Ready(Err(error)) - }, - } - } else { - Poll::Ready(Ok(())) - } - } -} - -impl<'a, W> Sink> for StreamSink<'a, W> -where - W: AsyncWrite + Unpin + Send, -{ - type Error = PolarsError; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - self.get_mut().poll_complete(cx) - } - - fn start_send(self: Pin<&mut Self>, item: Record<'_>) -> PolarsResult<()> { - self.get_mut().write(item) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - self.get_mut().poll_complete(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { - let this = self.get_mut(); - match this.poll_complete(cx) { - Poll::Ready(Ok(())) => { - if let Some(mut writer) = this.writer.take() { - this.task = Some( - async move { - write_continuation(&mut writer, 0).await?; - writer.flush().await?; - writer.close().await?; - Ok(None) - } - .boxed(), - ); - this.poll_complete(cx) - } else { - Poll::Ready(Ok(())) - } - }, - res => res, - } - } -} diff --git a/crates/polars-arrow/src/io/mod.rs b/crates/polars-arrow/src/io/mod.rs index a2b178304df0..1ae39a8ef766 100644 --- a/crates/polars-arrow/src/io/mod.rs +++ b/crates/polars-arrow/src/io/mod.rs @@ -1,15 +1,7 @@ -#![forbid(unsafe_code)] -//! Contains modules to interface with other formats such as [`csv`], -//! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. - #[cfg(feature = "io_ipc")] #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] pub mod ipc; -#[cfg(feature = "io_flight")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_flight")))] -pub mod flight; - #[cfg(feature = "io_avro")] #[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))] pub mod avro; diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index f15ac1811f96..ed53797db9b9 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -238,7 +238,13 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box>(); - StructArray::new(dtype.clone(), new_values, array.validity().cloned()).boxed() + StructArray::new( + dtype.clone(), + array.len(), + new_values, + array.validity().cloned(), + ) + .boxed() }, _ => new_null_array(dtype.clone(), array.len()), } diff --git a/crates/polars-arrow/src/legacy/conversion.rs b/crates/polars-arrow/src/legacy/conversion.rs index e2cb028036a0..9504e45d878c 100644 --- a/crates/polars-arrow/src/legacy/conversion.rs +++ b/crates/polars-arrow/src/legacy/conversion.rs @@ -5,7 +5,7 @@ use crate::types::NativeType; pub fn chunk_to_struct(chunk: RecordBatchT, fields: Vec) -> StructArray { let dtype = ArrowDataType::Struct(fields); - StructArray::new(dtype, chunk.into_arrays(), None) + StructArray::new(dtype, chunk.len(), chunk.into_arrays(), None) } /// Returns its underlying [`Vec`], if possible. diff --git a/crates/polars-arrow/src/legacy/kernels/atan2.rs b/crates/polars-arrow/src/legacy/kernels/atan2.rs deleted file mode 100644 index 40d3d527b24a..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/atan2.rs +++ /dev/null @@ -1,12 +0,0 @@ -use num_traits::Float; - -use crate::array::PrimitiveArray; -use crate::compute::arity::binary; -use crate::types::NativeType; - -pub fn atan2(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray -where - T: Float + NativeType, -{ - binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| a.atan2(b)) -} diff --git a/crates/polars-arrow/src/legacy/kernels/float.rs b/crates/polars-arrow/src/legacy/kernels/float.rs deleted file mode 100644 index 22413fd0b4c5..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/float.rs +++ /dev/null @@ -1,54 +0,0 @@ -use num_traits::Float; - -use crate::array::{ArrayRef, BooleanArray, PrimitiveArray}; -use crate::bitmap::Bitmap; -use crate::legacy::array::default_arrays::FromData; -use crate::types::NativeType; - -pub fn is_nan(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_nan())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_not_nan(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| !v.is_nan())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_finite(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_finite())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_infinite(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_infinite())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} diff --git a/crates/polars-arrow/src/legacy/kernels/mod.rs b/crates/polars-arrow/src/legacy/kernels/mod.rs index cdb7cb9c2057..89b31684beed 100644 --- a/crates/polars-arrow/src/legacy/kernels/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/mod.rs @@ -2,15 +2,12 @@ use std::iter::Enumerate; use crate::array::BooleanArray; use crate::bitmap::utils::BitChunks; -pub mod atan2; pub mod concatenate; pub mod ewm; #[cfg(feature = "compute_take")] pub mod fixed_size_list; -pub mod float; #[cfg(feature = "compute_take")] pub mod list; -pub mod pow; pub mod rolling; pub mod set; pub mod sort_partition; @@ -62,8 +59,7 @@ impl<'a> MaskedSlicesIterator<'a> { pub(crate) fn new(mask: &'a BooleanArray) -> Self { let chunks = mask.values().chunks::(); - let chunk_bits = 8 * std::mem::size_of::(); - let chunk_len = mask.len() / chunk_bits; + let chunk_len = mask.len() / 64; let remainder_len = chunks.remainder_len(); let remainder_mask = chunks.remainder(); @@ -141,7 +137,7 @@ impl<'a> MaskedSlicesIterator<'a> { } } -impl<'a> Iterator for MaskedSlicesIterator<'a> { +impl Iterator for MaskedSlicesIterator<'_> { type Item = (usize, usize); fn next(&mut self) -> Option { @@ -213,7 +209,7 @@ impl<'a> BinaryMaskedSliceIterator<'a> { } } -impl<'a> Iterator for BinaryMaskedSliceIterator<'a> { +impl Iterator for BinaryMaskedSliceIterator<'_> { type Item = (usize, usize, bool); fn next(&mut self) -> Option { diff --git a/crates/polars-arrow/src/legacy/kernels/pow.rs b/crates/polars-arrow/src/legacy/kernels/pow.rs deleted file mode 100644 index a790e6193129..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/pow.rs +++ /dev/null @@ -1,13 +0,0 @@ -use num_traits::pow::Pow; - -use crate::array::PrimitiveArray; -use crate::compute::arity::binary; -use crate::types::NativeType; - -pub fn pow(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray -where - T: Pow + NativeType, - F: NativeType, -{ - binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| Pow::pow(a, b)) -} diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs index 8f8004c570e8..51f3c95d2a56 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs @@ -93,5 +93,5 @@ pub struct RollingVarParams { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingQuantileParams { pub prob: f64, - pub interpol: QuantileInterpolOptions, + pub method: QuantileMethod, } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 1b9695358dcb..7abe2455e61f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, Num, NumCast}; pub use quantile::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub use sum::*; pub use variance::*; @@ -69,17 +70,22 @@ where Ok(Box::new(arr)) } -#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum QuantileInterpolOptions { +#[strum(serialize_all = "snake_case")] +pub enum QuantileMethod { #[default] Nearest, Lower, Higher, Midpoint, Linear, + Equiprobable, } +#[deprecated(note = "use QuantileMethod instead")] +pub type QuantileInterpolOptions = QuantileMethod; + pub(super) fn rolling_apply_weights( values: &[T], window_size: usize, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index ab3919b9aaaa..bf0ad01e79c3 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -2,13 +2,13 @@ use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; -use super::QuantileInterpolOptions::*; +use super::QuantileMethod::*; use super::*; pub struct QuantileWindow<'a, T: NativeType> { sorted: SortedBuf<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -34,7 +34,7 @@ impl< Self { sorted: SortedBuf::new(slice, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -42,7 +42,7 @@ impl< let vals = self.sorted.update(start, end); let length = vals.len(); - let idx = match self.interpol { + let idx = match self.method { Linear => { // Maybe add a fast path for median case? They could branch depending on odd/even. let length_f = length as f64; @@ -92,6 +92,7 @@ impl< let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; std::cmp::min(idx, length - 1) }, + Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize, }; // SAFETY: @@ -134,7 +135,7 @@ where unreachable!("expected Quantile params"); }; let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>( - params.interpol, + params.method, min_periods, window_size, values, @@ -170,7 +171,7 @@ where Ok(rolling_apply_weighted_quantile( values, params.prob, - params.interpol, + params.method, window_size, min_periods, offset_fn, @@ -182,7 +183,7 @@ where } #[inline] -fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T +fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T where T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, { @@ -201,7 +202,7 @@ where (s_old, v_old, vk) = (s, vk, v); s += w; } - match (h == s_old, interp) { + match (h == s_old, method) { (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter (_, Lower) => v_old, (_, Higher) => vk, @@ -212,6 +213,14 @@ where vk } }, + (_, Equiprobable) => { + let threshold = (wsum * p).ceil() - 1.0; + if s > threshold { + vk + } else { + v_old + } + }, (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(), // This is seemingly the canonical way to do it. (_, Linear) => { @@ -224,7 +233,7 @@ where fn rolling_apply_weighted_quantile( values: &[T], p: f64, - interpolation: QuantileInterpolOptions, + method: QuantileMethod, window_size: usize, min_periods: usize, det_offsets_fn: Fo, @@ -252,7 +261,7 @@ where .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w)); } buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0)); - compute_wq(&buf, p, wsum, interpolation) + compute_wq(&buf, p, wsum, method) }) .collect_trusted::>(); @@ -273,7 +282,7 @@ mod test { let values = &[1.0, 2.0, 3.0, 4.0]; let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: Linear, + method: Linear, })); let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); @@ -305,18 +314,19 @@ mod test { fn test_rolling_quantile_limits() { let values = &[1.0f64, 2.0, 3.0, 4.0]; - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -328,7 +338,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs index 52039fe77572..7f0c65d42bd7 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs @@ -25,7 +25,7 @@ pub struct SortedMinMax<'a, T: NativeType> { null_count: usize, } -impl<'a, T: NativeType> SortedMinMax<'a, T> { +impl SortedMinMax<'_, T> { fn count_nulls(&self, start: usize, end: usize) -> usize { let (bytes, offset, _) = self.validity.as_slice(); count_zeros(bytes, offset + start, end - start) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 259316513fe5..3d5dd664bd34 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray; pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { sorted: SortedBufNulls<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -39,7 +39,7 @@ impl< Self { sorted: SortedBufNulls::new(slice, validity, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -53,21 +53,22 @@ impl< let values = &values[null_count..]; let length = values.len(); - let mut idx = match self.interpol { - QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => { + let mut idx = match self.method { + QuantileMethod::Nearest => ((length as f64) * self.prob) as usize, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { ((length as f64 - 1.0) * self.prob).floor() as usize }, - QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Equiprobable => { + ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize + }, }; idx = std::cmp::min(idx, length - 1); // we can unwrap because we sliced of the nulls - match self.interpol { - QuantileInterpolOptions::Midpoint => { + match self.method { + QuantileMethod::Midpoint => { let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; Some( (values.get_unchecked_release(idx).unwrap() @@ -75,7 +76,7 @@ impl< / T::from::(2.0f64).unwrap(), ) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let float_idx = (length as f64 - 1.0) * self.prob; let top_idx = f64::ceil(float_idx) as usize; @@ -136,7 +137,7 @@ where }; let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>( - params.interpol, + params.method, min_periods, window_size, arr.clone(), @@ -171,7 +172,7 @@ mod test { ); let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: QuantileInterpolOptions::Linear, + method: QuantileMethod::Linear, })); let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone()); @@ -210,18 +211,19 @@ mod test { Some(Bitmap::from(&[true, false, false, true, true])), ); - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -233,7 +235,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs index d1392ef7cb50..599bdb241b07 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs @@ -9,7 +9,7 @@ pub struct SumWindow<'a, T> { pub(super) null_count: usize, } -impl<'a, T: NativeType + IsFloat + Add + Sub> SumWindow<'a, T> { +impl + Sub> SumWindow<'_, T> { // compute sum from the entire window unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { let mut sum = None; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs index ee97d4cb15a3..8252c8931c4f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs @@ -9,8 +9,8 @@ pub(super) struct SumSquaredWindow<'a, T> { null_count: usize, } -impl<'a, T: NativeType + IsFloat + Add + Sub + Mul> - SumSquaredWindow<'a, T> +impl + Sub + Mul> + SumSquaredWindow<'_, T> { // compute sum from the entire window unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs index 40a464e6f5bc..c616310ee568 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs @@ -11,7 +11,7 @@ use polars_utils::slice::{GetSaferUnchecked, SliceAble}; use polars_utils::sort::arg_sort_ascending; use polars_utils::total_ord::TotalOrd; -use crate::legacy::prelude::QuantileInterpolOptions; +use crate::legacy::prelude::QuantileMethod; use crate::pushable::Pushable; use crate::types::NativeType; @@ -32,7 +32,7 @@ struct Block<'a, A> { nulls_in_window: usize, } -impl<'a, A> Debug for Block<'a, A> +impl Debug for Block<'_, A> where A: Indexable, A::Item: Debug + Copy, @@ -443,7 +443,7 @@ where } } -impl<'a, A> LenGet for BlockUnion<'a, A> +impl LenGet for BlockUnion<'_, A> where A: Indexable + Bounded + NullCount + Clone, ::Item: TotalOrd + Copy + Debug, @@ -573,7 +573,7 @@ struct QuantileUpdate { inner: M, quantile: f64, min_periods: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl QuantileUpdate @@ -581,12 +581,12 @@ where M: LenGet, ::Item: Default + IsNull + Copy + FinishLinear + Debug, { - fn new(interpol: QuantileInterpolOptions, min_periods: usize, quantile: f64, inner: M) -> Self { + fn new(method: QuantileMethod, min_periods: usize, quantile: f64, inner: M) -> Self { Self { min_periods, quantile, inner, - interpol, + method, } } @@ -602,8 +602,8 @@ where let valid_length_f = valid_length as f64; - use QuantileInterpolOptions::*; - match self.interpol { + use QuantileMethod::*; + match self.method { Linear => { let float_idx_top = (valid_length_f - 1.0) * self.quantile; let idx = float_idx_top.floor() as usize; @@ -623,6 +623,10 @@ where let idx = std::cmp::min(idx, valid_length - 1); self.inner.get(idx + null_count) }, + Equiprobable => { + let idx = ((valid_length_f * self.quantile).ceil() - 1.0).max(0.0) as usize; + self.inner.get(idx + null_count) + }, Midpoint => { let idx = (valid_length_f * self.quantile) as usize; let idx = std::cmp::min(idx, valid_length - 1); @@ -651,7 +655,7 @@ where } pub(super) fn rolling_quantile::Item>>( - interpol: QuantileInterpolOptions, + method: QuantileMethod, min_periods: usize, k: usize, values: A, @@ -709,7 +713,7 @@ where // SAFETY: bounded by capacity unsafe { block_left.undelete(i) }; - let mut mu = QuantileUpdate::new(interpol, min_periods, quantile, &mut block_left); + let mut mu = QuantileUpdate::new(method, min_periods, quantile, &mut block_left); out.push(mu.quantile()); } for i in 1..n_blocks + 1 { @@ -747,7 +751,7 @@ where let mut union = BlockUnion::new(&mut *ptr_left, &mut *ptr_right); union.set_state(j); let q: ::Item = - QuantileUpdate::new(interpol, min_periods, quantile, union).quantile(); + QuantileUpdate::new(method, min_periods, quantile, union).quantile(); out.push(q); } } @@ -1062,22 +1066,22 @@ mod test { 2.0, 8.0, 5.0, 9.0, 1.0, 2.0, 4.0, 2.0, 4.0, 8.1, -1.0, 2.9, 1.2, 23.0, ] .as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 8.0, 5.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 2.9, 1.2, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 5, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 5, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 5.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 7, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 7, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 3.5, 4.0, 4.0, 4.0, 4.0, 2.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 4, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 4, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 6.5, 3.5, 3.0, 2.0, 3.0, 4.0, 3.0, 3.45, 2.05, 2.05, ]; @@ -1087,7 +1091,7 @@ mod test { #[test] fn test_median_2() { let values = [10, 10, 15, 13, 9, 5, 3, 13, 19, 15, 19].as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [10, 10, 10, 13, 13, 9, 5, 5, 13, 15, 19]; assert_eq!(out, expected); } diff --git a/crates/polars-arrow/src/legacy/kernels/time.rs b/crates/polars-arrow/src/legacy/kernels/time.rs index 08bc285c7ffe..73caabb7587b 100644 --- a/crates/polars-arrow/src/legacy/kernels/time.rs +++ b/crates/polars-arrow/src/legacy/kernels/time.rs @@ -9,6 +9,7 @@ use polars_error::PolarsResult; use polars_error::{polars_bail, PolarsError}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub enum Ambiguous { Earliest, @@ -32,8 +33,9 @@ impl FromStr for Ambiguous { } } -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum NonExistent { Null, Raise, diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs index 88b2dd48bbea..6afeb0c6c9be 100644 --- a/crates/polars-arrow/src/legacy/prelude.rs +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array}; pub use crate::legacy::array::default_arrays::*; pub use crate::legacy::array::*; pub use crate::legacy::index::*; -pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; +pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod; pub use crate::legacy::kernels::rolling::{ RollingFnParams, RollingQuantileParams, RollingVarParams, }; @@ -11,3 +11,6 @@ pub use crate::legacy::kernels::{Ambiguous, NonExistent}; pub type LargeStringArray = Utf8Array; pub type LargeBinaryArray = BinaryArray; pub type LargeListArray = ListArray; + +#[allow(deprecated)] +pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; diff --git a/crates/polars-arrow/src/mmap/array.rs b/crates/polars-arrow/src/mmap/array.rs index 8822824858c6..060ab745fcc0 100644 --- a/crates/polars-arrow/src/mmap/array.rs +++ b/crates/polars-arrow/src/mmap/array.rs @@ -35,7 +35,7 @@ fn check_bytes_len_and_is_aligned( bytes: &[u8], expected_len: usize, ) -> PolarsResult { - if bytes.len() < std::mem::size_of::() * expected_len { + if bytes.len() < size_of::() * expected_len { polars_bail!(ComputeError: "buffer's length is too small in mmap") }; @@ -281,7 +281,7 @@ fn mmap_primitive>( let bytes = get_bytes(data_ref, block_offset, buffers)?; let is_aligned = check_bytes_len_and_is_aligned::

(bytes, num_rows)?; - let out = if is_aligned || std::mem::size_of::() <= 8 { + let out = if is_aligned || size_of::() <= 8 { assert!( is_aligned, "primitive type with size <= 8 bytes should have been aligned" diff --git a/crates/polars-arrow/src/mmap/mod.rs b/crates/polars-arrow/src/mmap/mod.rs index b934c31de563..6ad0ca776d7d 100644 --- a/crates/polars-arrow/src/mmap/mod.rs +++ b/crates/polars-arrow/src/mmap/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; mod array; use arrow_format::ipc::planus::ReadAsRoot; -use arrow_format::ipc::{Block, MessageRef, RecordBatchRef}; +use arrow_format::ipc::{Block, DictionaryBatchRef, MessageRef, RecordBatchRef}; use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult}; use polars_utils::pl_str::PlSmallStr; @@ -71,7 +71,7 @@ fn get_buffers_nodes(batch: RecordBatchRef) -> PolarsResult<(VecDeque Ok((buffers, field_nodes)) } -unsafe fn _mmap_record>( +pub(crate) unsafe fn mmap_record>( fields: &ArrowSchema, ipc_fields: &[IpcField], data: Arc, @@ -86,6 +86,13 @@ unsafe fn _mmap_record>( .map(|v| v.iter().map(|v| v as usize).collect::>()) .unwrap_or_else(VecDeque::new); + let length = batch + .length() + .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData)) + .unwrap() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + fields .iter_values() .map(|f| &f.dtype) @@ -104,26 +111,7 @@ unsafe fn _mmap_record>( ) }) .collect::>() - .and_then(RecordBatchT::try_new) -} - -unsafe fn _mmap_unchecked>( - fields: &ArrowSchema, - ipc_fields: &[IpcField], - data: Arc, - block: Block, - dictionaries: &Dictionaries, -) -> PolarsResult>> { - let (message, offset) = read_message(data.as_ref().as_ref(), block)?; - let batch = get_record_batch(message)?; - _mmap_record( - fields, - ipc_fields, - data.clone(), - batch, - offset, - dictionaries, - ) + .and_then(|arr| RecordBatchT::try_new(length, arr)) } /// Memory maps an record batch from an IPC file into a [`RecordBatchT`]. @@ -147,7 +135,7 @@ pub unsafe fn mmap_unchecked>( let (message, offset) = read_message(data.as_ref().as_ref(), block)?; let batch = get_record_batch(message)?; - _mmap_record( + mmap_record( &metadata.schema, &metadata.ipc_schema.fields, data.clone(), @@ -158,19 +146,29 @@ pub unsafe fn mmap_unchecked>( } unsafe fn mmap_dictionary>( - metadata: &FileMetadata, + schema: &ArrowSchema, + ipc_fields: &[IpcField], data: Arc, block: Block, dictionaries: &mut Dictionaries, ) -> PolarsResult<()> { let (message, offset) = read_message(data.as_ref().as_ref(), block)?; let batch = get_dictionary_batch(&message)?; + mmap_dictionary_from_batch(schema, ipc_fields, &data, batch, dictionaries, offset) +} +pub(crate) unsafe fn mmap_dictionary_from_batch>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + data: &Arc, + batch: DictionaryBatchRef, + dictionaries: &mut Dictionaries, + offset: usize, +) -> PolarsResult<()> { let id = batch .id() .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferId(err)))?; - let (first_field, first_ipc_field) = - first_dict_field(id, &metadata.schema, &metadata.ipc_schema.fields)?; + let (first_field, first_ipc_field) = first_dict_field(id, schema, ipc_fields)?; let batch = batch .data() @@ -188,7 +186,7 @@ unsafe fn mmap_dictionary>( // Make a fake schema for the dictionary batch. let field = Field::new(PlSmallStr::EMPTY, value_type.clone(), false); - let chunk = _mmap_record( + let chunk = mmap_record( &std::iter::once((field.name.clone(), field)).collect(), &[first_ipc_field.clone()], data.clone(), @@ -211,7 +209,21 @@ pub unsafe fn mmap_dictionaries_unchecked>( metadata: &FileMetadata, data: Arc, ) -> PolarsResult { - let blocks = if let Some(blocks) = &metadata.dictionaries { + mmap_dictionaries_unchecked2( + metadata.schema.as_ref(), + &metadata.ipc_schema.fields, + metadata.dictionaries.as_ref(), + data, + ) +} + +pub(crate) unsafe fn mmap_dictionaries_unchecked2>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + dictionaries: Option<&Vec>, + data: Arc, +) -> PolarsResult { + let blocks = if let Some(blocks) = &dictionaries { blocks } else { return Ok(Default::default()); @@ -219,9 +231,8 @@ pub unsafe fn mmap_dictionaries_unchecked>( let mut dictionaries = Default::default(); - blocks - .iter() - .cloned() - .try_for_each(|block| mmap_dictionary(metadata, data.clone(), block, &mut dictionaries))?; + blocks.iter().cloned().try_for_each(|block| { + mmap_dictionary(schema, ipc_fields, data.clone(), block, &mut dictionaries) + })?; Ok(dictionaries) } diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index ae4583dfe6f4..ca148655c2e3 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -415,7 +415,7 @@ impl OffsetsBuffer { &self.0 } - /// Returns the length an array with these offsets would be. + /// Returns what the length an array with these offsets would be. #[inline] pub fn len_proxy(&self) -> usize { self.0.len() - 1 @@ -513,6 +513,53 @@ impl OffsetsBuffer { self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) } + /// Returns `(offset, len)` pairs. + #[inline] + pub fn offset_and_length_iter(&self) -> impl Iterator + '_ { + self.windows(2).map(|x| { + let [l, r] = x else { unreachable!() }; + let l = l.to_usize(); + let r = r.to_usize(); + (l, r - l) + }) + } + + /// Offset and length of the primitive (leaf) array for a double+ nested list for every outer + /// row. + pub fn leaf_ranges_iter( + offsets: &[Self], + ) -> impl Iterator> + '_ { + let others = &offsets[1..]; + + offsets[0].windows(2).map(move |x| { + let [l, r] = x else { unreachable!() }; + let mut l = l.to_usize(); + let mut r = r.to_usize(); + + for o in others { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + }) + } + + /// Return the full range of the leaf array used by the list. + pub fn leaf_full_start_end(offsets: &[Self]) -> core::ops::Range { + let mut l = offsets[0].first().to_usize(); + let mut r = offsets[0].last().to_usize(); + + for o in &offsets[1..] { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + } + /// Returns the inner [`Buffer`]. #[inline] pub fn into_inner(self) -> Buffer { diff --git a/crates/polars-arrow/src/record_batch.rs b/crates/polars-arrow/src/record_batch.rs index b11fd6c899a2..f58d129831f1 100644 --- a/crates/polars-arrow/src/record_batch.rs +++ b/crates/polars-arrow/src/record_batch.rs @@ -1,7 +1,7 @@ //! Contains [`RecordBatchT`], a container of [`Array`] where every array has the //! same length. -use polars_error::{polars_bail, PolarsResult}; +use polars_error::{polars_ensure, PolarsResult}; use crate::array::{Array, ArrayRef}; @@ -9,6 +9,7 @@ use crate::array::{Array, ArrayRef}; /// the same length, [`RecordBatchT::len`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct RecordBatchT> { + length: usize, arrays: Vec, } @@ -16,29 +17,26 @@ pub type RecordBatch = RecordBatchT; impl> RecordBatchT { /// Creates a new [`RecordBatchT`]. - /// # Panic - /// Iff the arrays do not have the same length - pub fn new(arrays: Vec) -> Self { - Self::try_new(arrays).unwrap() + /// + /// # Panics + /// + /// I.f.f. the length does not match the length of any of the arrays + pub fn new(length: usize, arrays: Vec) -> Self { + Self::try_new(length, arrays).unwrap() } /// Creates a new [`RecordBatchT`]. + /// /// # Error - /// Iff the arrays do not have the same length - pub fn try_new(arrays: Vec) -> PolarsResult { - if !arrays.is_empty() { - let len = arrays.first().unwrap().as_ref().len(); - if arrays - .iter() - .map(|array| array.as_ref()) - .any(|array| array.len() != len) - { - polars_bail!(ComputeError: - "RecordBatch requires all its arrays to have an equal number of rows".to_string(), - ); - } - } - Ok(Self { arrays }) + /// + /// I.f.f. the length does not match the length of any of the arrays + pub fn try_new(length: usize, arrays: Vec) -> PolarsResult { + polars_ensure!( + arrays.iter().all(|arr| arr.as_ref().len() == length), + ComputeError: "RecordBatch requires all its arrays to have an equal number of rows", + ); + + Ok(Self { length, arrays }) } /// returns the [`Array`]s in [`RecordBatchT`] @@ -53,10 +51,7 @@ impl> RecordBatchT { /// returns the number of rows of every array pub fn len(&self) -> usize { - self.arrays - .first() - .map(|x| x.as_ref().len()) - .unwrap_or_default() + self.length } /// returns whether the columns have any rows diff --git a/crates/polars-arrow/src/storage.rs b/crates/polars-arrow/src/storage.rs index 7cab3235b25a..ddde815b5b10 100644 --- a/crates/polars-arrow/src/storage.rs +++ b/crates/polars-arrow/src/storage.rs @@ -1,37 +1,93 @@ use std::marker::PhantomData; use std::mem::ManuallyDrop; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; use std::sync::atomic::{AtomicU64, Ordering}; +use bytemuck::Pod; + +// Allows us to transmute between types while also keeping the original +// stats and drop method of the Vec around. +struct VecVTable { + size: usize, + align: usize, + drop_buffer: unsafe fn(*mut (), usize), +} + +impl VecVTable { + const fn new() -> Self { + unsafe fn drop_buffer(ptr: *mut (), cap: usize) { + unsafe { drop(Vec::from_raw_parts(ptr.cast::(), 0, cap)) } + } + + Self { + size: size_of::(), + align: align_of::(), + drop_buffer: drop_buffer::, + } + } + + fn new_static() -> &'static Self { + const { &Self::new::() } + } +} + use crate::ffi::InternalArrowArray; enum BackingStorage { Vec { - capacity: usize, + original_capacity: usize, // Elements, not bytes. + vtable: &'static VecVTable, }, InternalArrowArray(InternalArrowArray), - #[cfg(feature = "arrow_rs")] - ArrowBuffer(arrow_buffer::Buffer), } struct SharedStorageInner { ref_count: AtomicU64, ptr: *mut T, - length: usize, + length_in_bytes: usize, backing: Option, // https://github.com/rust-lang/rfcs/blob/master/text/0769-sound-generic-drop.md#phantom-data phantom: PhantomData, } +impl SharedStorageInner { + pub fn from_vec(mut v: Vec) -> Self { + let length_in_bytes = v.len() * size_of::(); + let original_capacity = v.capacity(); + let ptr = v.as_mut_ptr(); + core::mem::forget(v); + Self { + ref_count: AtomicU64::new(1), + ptr, + length_in_bytes, + backing: Some(BackingStorage::Vec { + original_capacity, + vtable: VecVTable::new_static::(), + }), + phantom: PhantomData, + } + } +} + impl Drop for SharedStorageInner { fn drop(&mut self) { match self.backing.take() { Some(BackingStorage::InternalArrowArray(a)) => drop(a), - #[cfg(feature = "arrow_rs")] - Some(BackingStorage::ArrowBuffer(b)) => drop(b), - Some(BackingStorage::Vec { capacity }) => unsafe { - drop(Vec::from_raw_parts(self.ptr, self.length, capacity)) + Some(BackingStorage::Vec { + original_capacity, + vtable, + }) => unsafe { + // Drop the elements in our slice. + if std::mem::needs_drop::() { + core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut( + self.ptr, + self.length_in_bytes / size_of::(), + )); + } + + // Free the buffer. + (vtable.drop_buffer)(self.ptr.cast(), original_capacity); }, None => {}, } @@ -48,12 +104,13 @@ unsafe impl Sync for SharedStorage {} impl SharedStorage { pub fn from_static(slice: &'static [T]) -> Self { - let length = slice.len(); + #[expect(clippy::manual_slice_size_calculation)] + let length_in_bytes = slice.len() * size_of::(); let ptr = slice.as_ptr().cast_mut(); let inner = SharedStorageInner { ref_count: AtomicU64::new(2), // Never used, but 2 so it won't pass exclusivity tests. ptr, - length, + length_in_bytes, backing: None, phantom: PhantomData, }; @@ -63,20 +120,9 @@ impl SharedStorage { } } - pub fn from_vec(mut v: Vec) -> Self { - let length = v.len(); - let capacity = v.capacity(); - let ptr = v.as_mut_ptr(); - core::mem::forget(v); - let inner = SharedStorageInner { - ref_count: AtomicU64::new(1), - ptr, - length, - backing: Some(BackingStorage::Vec { capacity }), - phantom: PhantomData, - }; + pub fn from_vec(v: Vec) -> Self { Self { - inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), + inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(), phantom: PhantomData, } } @@ -85,7 +131,7 @@ impl SharedStorage { let inner = SharedStorageInner { ref_count: AtomicU64::new(1), ptr: ptr.cast_mut(), - length: len, + length_in_bytes: len * size_of::(), backing: Some(BackingStorage::InternalArrowArray(arr)), phantom: PhantomData, }; @@ -96,39 +142,40 @@ impl SharedStorage { } } -#[cfg(feature = "arrow_rs")] -impl SharedStorage { - pub fn from_arrow_buffer(buffer: arrow_buffer::Buffer) -> Self { - let ptr = buffer.as_ptr(); - let align_offset = ptr.align_offset(std::mem::align_of::()); - assert_eq!(align_offset, 0, "arrow_buffer::Buffer misaligned"); - let length = buffer.len() / std::mem::size_of::(); +pub struct SharedStorageAsVecMut<'a, T> { + ss: &'a mut SharedStorage, + vec: ManuallyDrop>, +} - let inner = SharedStorageInner { - ref_count: AtomicU64::new(1), - ptr: ptr as *mut T, - length, - backing: Some(BackingStorage::ArrowBuffer(buffer)), - phantom: PhantomData, - }; - Self { - inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), - phantom: PhantomData, - } +impl Deref for SharedStorageAsVecMut<'_, T> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.vec + } +} + +impl DerefMut for SharedStorageAsVecMut<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.vec } +} - pub fn into_arrow_buffer(self) -> arrow_buffer::Buffer { - let ptr = NonNull::new(self.as_ptr() as *mut u8).unwrap(); - let len = self.len() * std::mem::size_of::(); - let arc = std::sync::Arc::new(self); - unsafe { arrow_buffer::Buffer::from_custom_allocation(ptr, len, arc) } +impl Drop for SharedStorageAsVecMut<'_, T> { + fn drop(&mut self) { + unsafe { + // Restore the SharedStorage. + let vec = ManuallyDrop::take(&mut self.vec); + let inner = self.ss.inner.as_ptr(); + inner.write(SharedStorageInner::from_vec(vec)); + } } } impl SharedStorage { #[inline(always)] pub fn len(&self) -> usize { - self.inner().length + self.inner().length_in_bytes / size_of::() } #[inline(always)] @@ -156,21 +203,55 @@ impl SharedStorage { pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> { self.is_exclusive().then(|| { let inner = self.inner(); - unsafe { core::slice::from_raw_parts_mut(inner.ptr, inner.length) } + let len = inner.length_in_bytes / size_of::(); + unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) } }) } - pub fn try_into_vec(mut self) -> Result, Self> { - let Some(BackingStorage::Vec { capacity }) = self.inner().backing else { - return Err(self); + /// Try to take the vec backing this SharedStorage, leaving this as an empty slice. + pub fn try_take_vec(&mut self) -> Option> { + // We may only go back to a Vec if we originally came from a Vec + // where the desired size/align matches the original. + let Some(BackingStorage::Vec { + original_capacity, + vtable, + }) = self.inner().backing + else { + return None; }; - if self.is_exclusive() { - let slf = ManuallyDrop::new(self); - let inner = slf.inner(); - Ok(unsafe { Vec::from_raw_parts(inner.ptr, inner.length, capacity) }) - } else { - Err(self) + + if vtable.size != size_of::() || vtable.align != align_of::() { + return None; + } + + // If there are other references we can't get an exclusive reference. + if !self.is_exclusive() { + return None; } + + let ret; + unsafe { + let inner = &mut *self.inner.as_ptr(); + let len = inner.length_in_bytes / size_of::(); + ret = Vec::from_raw_parts(inner.ptr, len, original_capacity); + inner.length_in_bytes = 0; + inner.backing = None; + } + Some(ret) + } + + /// Attempts to call the given function with this SharedStorage as a + /// reference to a mutable Vec. If this SharedStorage can't be converted to + /// a Vec the function is not called and instead returned as an error. + pub fn try_as_mut_vec(&mut self) -> Option> { + Some(SharedStorageAsVecMut { + vec: ManuallyDrop::new(self.try_take_vec()?), + ss: self, + }) + } + + pub fn try_into_vec(mut self) -> Result, Self> { + self.try_take_vec().ok_or(self) } #[inline(always)] @@ -186,6 +267,39 @@ impl SharedStorage { } } +impl SharedStorage { + fn try_transmute(self) -> Result, Self> { + let inner = self.inner(); + + // The length of the array in bytes must be a multiple of the target size. + // We can skip this check if the size of U divides the size of T. + if size_of::() % size_of::() != 0 && inner.length_in_bytes % size_of::() != 0 { + return Err(self); + } + + // The pointer must be properly aligned for U. + // We can skip this check if the alignment of U divides the alignment of T. + if align_of::() % align_of::() != 0 && !inner.ptr.cast::().is_aligned() { + return Err(self); + } + + Ok(SharedStorage { + inner: self.inner.cast(), + phantom: PhantomData, + }) + } +} + +impl SharedStorage { + /// Create a [`SharedStorage`][SharedStorage] from a [`Vec`] of [`Pod`]. + pub fn bytes_from_pod_vec(v: Vec) -> Self { + // This can't fail, bytes is compatible with everything. + SharedStorage::from_vec(v) + .try_transmute::() + .unwrap_or_else(|_| unreachable!()) + } +} + impl Deref for SharedStorage { type Target = [T]; @@ -193,7 +307,8 @@ impl Deref for SharedStorage { fn deref(&self) -> &Self::Target { unsafe { let inner = self.inner(); - core::slice::from_raw_parts(inner.ptr, inner.length) + let len = inner.length_in_bytes / size_of::(); + core::slice::from_raw_parts(inner.ptr, len) } } } diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs index b5672f6dd626..1cf8dc10846a 100644 --- a/crates/polars-arrow/src/temporal_conversions.rs +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -3,10 +3,14 @@ use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; use polars_error::{polars_err, PolarsResult}; +#[cfg(feature = "compute_cast")] use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "compute_cast")] use crate::array::{PrimitiveArray, Utf8ViewArray}; -use crate::datatypes::{ArrowDataType, TimeUnit}; +#[cfg(feature = "compute_cast")] +use crate::datatypes::ArrowDataType; +use crate::datatypes::TimeUnit; /// Number of seconds in a day pub const SECONDS_IN_DAY: i64 = 86_400; @@ -316,6 +320,7 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> .ok() } +#[cfg(feature = "compute_cast")] fn utf8view_to_timestamp_impl( array: &Utf8ViewArray, fmt: &str, @@ -365,6 +370,7 @@ fn chrono_tz_utf_to_timestamp( } #[cfg(not(feature = "chrono-tz"))] +#[cfg(feature = "compute_cast")] fn chrono_tz_utf_to_timestamp( _: &Utf8ViewArray, _: &str, @@ -387,6 +393,7 @@ fn chrono_tz_utf_to_timestamp( /// # Error /// /// This function errors iff `timezone` is not parsable to an offset. +#[cfg(feature = "compute_cast")] pub(crate) fn utf8view_to_timestamp( array: &Utf8ViewArray, fmt: &str, @@ -408,6 +415,7 @@ pub(crate) fn utf8view_to_timestamp( /// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. /// Timezones are ignored. /// Null elements remain null; non-parsable elements are set to null. +#[cfg(feature = "compute_cast")] pub(crate) fn utf8view_to_naive_timestamp( array: &Utf8ViewArray, fmt: &str, diff --git a/crates/polars-arrow/src/types/aligned_bytes.rs b/crates/polars-arrow/src/types/aligned_bytes.rs new file mode 100644 index 000000000000..2c9bf9aed977 --- /dev/null +++ b/crates/polars-arrow/src/types/aligned_bytes.rs @@ -0,0 +1,112 @@ +use bytemuck::{Pod, Zeroable}; + +use super::{days_ms, f16, i256, months_days_ns}; +use crate::array::View; + +/// Define that a type has the same byte alignment and size as `B`. +/// +/// # Safety +/// +/// This is safe to implement if both types have the same alignment and size. +pub unsafe trait AlignedBytesCast: Pod {} + +/// A representation of a type as raw bytes with the same alignment as the original type. +pub trait AlignedBytes: Pod + Zeroable + Copy + Default + Eq { + const ALIGNMENT: usize; + const SIZE: usize; + + type Unaligned: AsRef<[u8]> + + AsMut<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default + + IntoIterator + + Pod; + + fn to_unaligned(&self) -> Self::Unaligned; + fn from_unaligned(unaligned: Self::Unaligned) -> Self; + + /// Safely cast a mutable reference to a [`Vec`] of `T` to a mutable reference of `Self`. + fn cast_vec_ref_mut>(vec: &mut Vec) -> &mut Vec { + if cfg!(debug_assertions) { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + } + + // SAFETY: SameBytes guarantees that T: + // 1. has the same size + // 2. has the same alignment + // 3. is Pod (therefore has no life-time issues) + unsafe { std::mem::transmute(vec) } + } +} + +macro_rules! impl_aligned_bytes { + ( + $(($name:ident, $size:literal, $alignment:literal, [$($eq_type:ty),*]),)+ + ) => { + $( + /// Bytes with a size and alignment. + /// + /// This is used to reduce the monomorphizations for routines that solely rely on the size + /// and alignment of types. + #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Pod, Zeroable)] + #[repr(C, align($alignment))] + pub struct $name([u8; $size]); + + impl AlignedBytes for $name { + const ALIGNMENT: usize = $alignment; + const SIZE: usize = $size; + + type Unaligned = [u8; $size]; + + #[inline(always)] + fn to_unaligned(&self) -> Self::Unaligned { + self.0 + } + #[inline(always)] + fn from_unaligned(unaligned: Self::Unaligned) -> Self { + Self(unaligned) + } + } + + impl AsRef<[u8; $size]> for $name { + #[inline(always)] + fn as_ref(&self) -> &[u8; $size] { + &self.0 + } + } + + $( + impl From<$eq_type> for $name { + #[inline(always)] + fn from(value: $eq_type) -> Self { + bytemuck::must_cast(value) + } + } + impl From<$name> for $eq_type { + #[inline(always)] + fn from(value: $name) -> Self { + bytemuck::must_cast(value) + } + } + unsafe impl AlignedBytesCast<$name> for $eq_type {} + )* + )+ + } +} + +impl_aligned_bytes! { + (Bytes1Alignment1, 1, 1, [u8, i8]), + (Bytes2Alignment2, 2, 2, [u16, i16, f16]), + (Bytes4Alignment4, 4, 4, [u32, i32, f32]), + (Bytes8Alignment8, 8, 8, [u64, i64, f64]), + (Bytes8Alignment4, 8, 4, [days_ms]), + (Bytes12Alignment4, 12, 4, [[u32; 3]]), + (Bytes16Alignment4, 16, 4, [View]), + (Bytes16Alignment8, 16, 8, [months_days_ns]), + (Bytes16Alignment16, 16, 16, [u128, i128]), + (Bytes32Alignment16, 32, 16, [i256]), +} diff --git a/crates/polars-arrow/src/types/bit_chunk.rs b/crates/polars-arrow/src/types/bit_chunk.rs index be4445a5d77a..60892689060f 100644 --- a/crates/polars-arrow/src/types/bit_chunk.rs +++ b/crates/polars-arrow/src/types/bit_chunk.rs @@ -72,7 +72,7 @@ impl BitChunkIter { /// Creates a new [`BitChunkIter`] with `len` bits. #[inline] pub fn new(value: T, len: usize) -> Self { - assert!(len <= std::mem::size_of::() * 8); + assert!(len <= size_of::() * 8); Self { value, remaining: len, diff --git a/crates/polars-arrow/src/types/mod.rs b/crates/polars-arrow/src/types/mod.rs index 49b4d315408e..c6f653a32311 100644 --- a/crates/polars-arrow/src/types/mod.rs +++ b/crates/polars-arrow/src/types/mod.rs @@ -20,6 +20,8 @@ //! Finally, this module contains traits used to compile code based on [`NativeType`] optimized //! for SIMD, at [`mod@simd`]. +mod aligned_bytes; +pub use aligned_bytes::*; mod bit_chunk; pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes}; mod index; diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index 6f869df32602..9cf9a6b56c46 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -1,11 +1,13 @@ +use std::hash::{Hash, Hasher}; use std::ops::Neg; use std::panic::RefUnwindSafe; use bytemuck::{Pod, Zeroable}; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -use polars_utils::total_ord::{TotalEq, TotalOrd}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap}; +use super::aligned_bytes::*; use super::PrimitiveType; /// Sealed trait implemented by all physical types that can be allocated, @@ -26,6 +28,7 @@ pub trait NativeType: + TotalOrd + IsNull + MinMax + + AlignedBytesCast { /// The corresponding variant of [`PrimitiveType`]. const PRIMITIVE: PrimitiveType; @@ -41,6 +44,11 @@ pub trait NativeType: + Default + IntoIterator; + /// Type denoting its representation as aligned bytes. + /// + /// This is `[u8; N]` where `N = size_of::` and has alignment `align_of::`. + type AlignedBytes: AlignedBytes + From + Into; + /// To bytes in little endian fn to_le_bytes(&self) -> Self::Bytes; @@ -55,11 +63,13 @@ pub trait NativeType: } macro_rules! native_type { - ($type:ty, $primitive_type:expr) => { + ($type:ty, $aligned:ty, $primitive_type:expr) => { impl NativeType for $type { const PRIMITIVE: PrimitiveType = $primitive_type; type Bytes = [u8; std::mem::size_of::()]; + type AlignedBytes = $aligned; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { Self::to_le_bytes(*self) @@ -83,18 +93,18 @@ macro_rules! native_type { }; } -native_type!(u8, PrimitiveType::UInt8); -native_type!(u16, PrimitiveType::UInt16); -native_type!(u32, PrimitiveType::UInt32); -native_type!(u64, PrimitiveType::UInt64); -native_type!(i8, PrimitiveType::Int8); -native_type!(i16, PrimitiveType::Int16); -native_type!(i32, PrimitiveType::Int32); -native_type!(i64, PrimitiveType::Int64); -native_type!(f32, PrimitiveType::Float32); -native_type!(f64, PrimitiveType::Float64); -native_type!(i128, PrimitiveType::Int128); -native_type!(u128, PrimitiveType::UInt128); +native_type!(u8, Bytes1Alignment1, PrimitiveType::UInt8); +native_type!(u16, Bytes2Alignment2, PrimitiveType::UInt16); +native_type!(u32, Bytes4Alignment4, PrimitiveType::UInt32); +native_type!(u64, Bytes8Alignment8, PrimitiveType::UInt64); +native_type!(i8, Bytes1Alignment1, PrimitiveType::Int8); +native_type!(i16, Bytes2Alignment2, PrimitiveType::Int16); +native_type!(i32, Bytes4Alignment4, PrimitiveType::Int32); +native_type!(i64, Bytes8Alignment8, PrimitiveType::Int64); +native_type!(f32, Bytes4Alignment4, PrimitiveType::Float32); +native_type!(f64, Bytes8Alignment8, PrimitiveType::Float64); +native_type!(i128, Bytes16Alignment16, PrimitiveType::Int128); +native_type!(u128, Bytes16Alignment16, PrimitiveType::UInt128); /// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type. #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)] @@ -150,7 +160,10 @@ impl MinMax for days_ms { impl NativeType for days_ms { const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs; + type Bytes = [u8; 8]; + type AlignedBytes = Bytes8Alignment4; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { let days = self.0.to_le_bytes(); @@ -288,7 +301,10 @@ impl MinMax for months_days_ns { impl NativeType for months_days_ns { const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano; + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment8; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { let months = self.months().to_le_bytes(); @@ -434,6 +450,44 @@ impl PartialEq for f16 { } } +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f16(x: f16) -> f16 { + // zero out the sign bit if the f16 is zero. + let convert_zero = f16(x.0 & (0x7FFF | (u16::from(x.0 & 0x7FFF == 0) << 15))); + if convert_zero.is_nan() { + f16::from_bits(0x7c00) // Canonical quiet NaN. + } else { + convert_zero + } +} + +impl TotalHash for f16 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f16(*self).to_bits().hash(state) + } +} + +impl ToTotalOrd for f16 { + type TotalOrdItem = TotalOrdWrap; + type SourceItem = f16; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + impl IsNull for f16 { const HAS_NULLS: bool = false; type Inner = f16; @@ -619,7 +673,10 @@ impl MinMax for f16 { impl NativeType for f16 { const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + type Bytes = [u8; 2]; + type AlignedBytes = Bytes2Alignment2; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { self.0.to_le_bytes() @@ -719,6 +776,7 @@ impl NativeType for i256 { const PRIMITIVE: PrimitiveType = PrimitiveType::Int256; type Bytes = [u8; 32]; + type AlignedBytes = Bytes32Alignment16; #[inline] fn to_le_bytes(&self) -> Self::Bytes { diff --git a/crates/polars-arrow/src/util/macros.rs b/crates/polars-arrow/src/util/macros.rs index b09a9d5d5473..fb5bd61ebba0 100644 --- a/crates/polars-arrow/src/util/macros.rs +++ b/crates/polars-arrow/src/util/macros.rs @@ -13,6 +13,7 @@ macro_rules! with_match_primitive_type {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Int128 => __with_ty__! { i128 }, Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, _ => panic!("operator does not support primitive `{:?}`", diff --git a/crates/polars-compute/src/arithmetic/mod.rs b/crates/polars-compute/src/arithmetic/mod.rs index 83471f219d68..cb74881ed4a5 100644 --- a/crates/polars-compute/src/arithmetic/mod.rs +++ b/crates/polars-compute/src/arithmetic/mod.rs @@ -141,5 +141,6 @@ impl ArithmeticKernel for PrimitiveArray { } mod float; +pub mod pl_num; mod signed; mod unsigned; diff --git a/crates/polars-compute/src/arithmetic/pl_num.rs b/crates/polars-compute/src/arithmetic/pl_num.rs new file mode 100644 index 000000000000..c792deacfd52 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/pl_num.rs @@ -0,0 +1,229 @@ +use core::any::TypeId; + +use arrow::types::NativeType; +use polars_utils::floor_divmod::FloorDivMod; + +/// Implements basic arithmetic between scalars with the same behavior as `ArithmeticKernel`. +/// +/// Note, however, that the user is responsible for setting the validity of +/// results for e.g. div/mod operations with 0 in the denominator. +/// +/// This is intended as a low-level utility for custom arithmetic loops +/// (e.g. in list arithmetic). In most cases prefer using `ArithmeticKernel` or +/// `ArithmeticChunked` instead. +pub trait PlNumArithmetic: Sized + Copy + 'static { + type TrueDivT: NativeType; + + fn wrapping_abs(self) -> Self; + fn wrapping_neg(self) -> Self; + fn wrapping_add(self, rhs: Self) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; + fn wrapping_mul(self, rhs: Self) -> Self; + fn wrapping_floor_div(self, rhs: Self) -> Self; + fn wrapping_trunc_div(self, rhs: Self) -> Self; + fn wrapping_mod(self, rhs: Self) -> Self; + + fn true_div(self, rhs: Self) -> Self::TrueDivT; + + #[inline(always)] + fn legacy_div(self, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::() { + let ret = self.true_div(rhs); + unsafe { core::mem::transmute_copy(&ret) } + } else { + self.wrapping_floor_div(rhs) + } + } +} + +macro_rules! impl_signed_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.wrapping_abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).0 + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + if rhs != 0 { + self.wrapping_div(rhs) + } else { + 0 + } + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).1 + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_signed_pl_num_arith!(i8); +impl_signed_pl_num_arith!(i16); +impl_signed_pl_num_arith!(i32); +impl_signed_pl_num_arith!(i64); +impl_signed_pl_num_arith!(i128); + +macro_rules! impl_unsigned_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + if rhs != 0 { + self / rhs + } else { + 0 + } + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + self.wrapping_floor_div(rhs) + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + if rhs != 0 { + self % rhs + } else { + 0 + } + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_unsigned_pl_num_arith!(u8); +impl_unsigned_pl_num_arith!(u16); +impl_unsigned_pl_num_arith!(u32); +impl_unsigned_pl_num_arith!(u64); +impl_unsigned_pl_num_arith!(u128); + +macro_rules! impl_float_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = $T; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + -self + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self + rhs + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self - rhs + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self * rhs + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).floor() + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).trunc() + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + l - r * (l / r).floor() + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self / rhs + } + } + }; +} + +impl_float_pl_num_arith!(f32); +impl_float_pl_num_arith!(f64); diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs index 968c82ef96e2..de253a0aeca7 100644 --- a/crates/polars-compute/src/arithmetic/signed.rs +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -133,13 +133,13 @@ macro_rules! impl_signed_arith_kernel { } fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { - if lhs == 0 { - return rhs.fill_with(0); - } - let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0) + }; ret.with_validity(valid) } @@ -165,13 +165,13 @@ macro_rules! impl_signed_arith_kernel { } fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { - if lhs == 0 { - return rhs.fill_with(0); - } - let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 }); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 }) + }; ret.with_validity(valid) } @@ -205,13 +205,13 @@ macro_rules! impl_signed_arith_kernel { } fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { - if lhs == 0 { - return rhs.fill_with(0); - } - let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1) + }; ret.with_validity(valid) } diff --git a/crates/polars-compute/src/arithmetic/unsigned.rs b/crates/polars-compute/src/arithmetic/unsigned.rs index 82406a2f94c9..46fc0037597d 100644 --- a/crates/polars-compute/src/arithmetic/unsigned.rs +++ b/crates/polars-compute/src/arithmetic/unsigned.rs @@ -95,13 +95,13 @@ macro_rules! impl_unsigned_arith_kernel { } fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { - if lhs == 0 { - return rhs.fill_with(0); - } - let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 }); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 }) + }; ret.with_validity(valid) } @@ -125,13 +125,13 @@ macro_rules! impl_unsigned_arith_kernel { } fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { - if lhs == 0 { - return rhs.fill_with(0); - } - let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 }); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 }) + }; ret.with_validity(valid) } diff --git a/crates/polars-compute/src/arity.rs b/crates/polars-compute/src/arity.rs index 33c8b0eb0584..7f99752e7afe 100644 --- a/crates/polars-compute/src/arity.rs +++ b/crates/polars-compute/src/arity.rs @@ -52,9 +52,7 @@ where let len = arr.len(); // Reuse memory if possible. - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(values) = arr.get_mut_values() { let ptr = values.as_mut_ptr(); // SAFETY: checked same size & alignment I/O, NativeType is always Pod. @@ -93,9 +91,7 @@ where let validity = combine_validities_and(lhs.validity(), rhs.validity()); // Reuse memory if possible. - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(lv) = lhs.get_mut_values() { let lp = lv.as_mut_ptr(); let rp = rhs.values().as_ptr(); @@ -106,9 +102,7 @@ where return lhs.transmute::().with_validity(validity); } } - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(rv) = rhs.get_mut_values() { let lp = lhs.values().as_ptr(); let rp = rv.as_mut_ptr(); diff --git a/crates/polars-compute/src/cardinality.rs b/crates/polars-compute/src/cardinality.rs new file mode 100644 index 000000000000..d28efa9d051e --- /dev/null +++ b/crates/polars-compute/src/cardinality.rs @@ -0,0 +1,159 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, PrimitiveArray, + Utf8Array, Utf8ViewArray, +}; +use arrow::datatypes::PhysicalType; +use arrow::types::Offset; +use arrow::with_match_primitive_type_full; +use polars_utils::total_ord::ToTotalOrd; + +use crate::hyperloglogplus::HyperLogLog; + +/// Get an estimate for the *cardinality* of the array (i.e. the number of unique values) +/// +/// This is not currently implemented for nested types. +pub fn estimate_cardinality(array: &dyn Array) -> usize { + if array.is_empty() { + return 0; + } + + if array.null_count() == array.len() { + return 1; + } + + // Estimate the cardinality with HyperLogLog + use PhysicalType as PT; + match array.dtype().to_physical_type() { + PT::Null => 1, + + PT::Boolean => { + let mut cardinality = 0; + + let array = array.as_any().downcast_ref::().unwrap(); + + cardinality += usize::from(array.has_nulls()); + + if let Some(unset_bits) = array.values().lazy_unset_bits() { + cardinality += 1 + usize::from(unset_bits != array.len()); + } else { + cardinality += 2; + } + + cardinality + }, + + PT::Primitive(primitive_type) => with_match_primitive_type_full!(primitive_type, |$T| { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.copied().unwrap_or_default(); + hll.add(&v.to_total_ord()); + } + } else { + for v in array.values_iter() { + hll.add(&v.to_total_ord()); + } + } + + hll.count() + }), + PT::FixedSizeBinary => { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() + }, + PT::Binary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::LargeBinary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::Utf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::LargeUtf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::BinaryView => { + binary_view_array_estimate(array.as_any().downcast_ref::().unwrap()) + }, + PT::Utf8View => binary_view_array_estimate( + &array + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(), + ), + PT::List => unimplemented!(), + PT::FixedSizeList => unimplemented!(), + PT::LargeList => unimplemented!(), + PT::Struct => unimplemented!(), + PT::Union => unimplemented!(), + PT::Map => unimplemented!(), + PT::Dictionary(_) => unimplemented!(), + } +} + +fn binary_offset_array_estimate(array: &BinaryArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} + +fn binary_view_array_estimate(array: &BinaryViewArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} diff --git a/crates/polars-compute/src/comparisons/simd.rs b/crates/polars-compute/src/comparisons/simd.rs index f855ed4ad1c0..95ea9707744a 100644 --- a/crates/polars-compute/src/comparisons/simd.rs +++ b/crates/polars-compute/src/comparisons/simd.rs @@ -17,7 +17,7 @@ where T: NativeType, F: FnMut(&[T; N], &[T; N]) -> M, { - assert_eq!(N, std::mem::size_of::() * 8); + assert_eq!(N, size_of::() * 8); assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -29,7 +29,7 @@ where let rhs_rest = rhs_chunks.remainder(); let num_masks = n.div_ceil(N); - let mut v: Vec = Vec::with_capacity(num_masks * std::mem::size_of::()); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); let mut p = v.as_mut_ptr() as *mut M; for (l, r) in lhs_chunks.zip(rhs_chunks) { unsafe { @@ -53,7 +53,7 @@ where } unsafe { - v.set_len(num_masks * std::mem::size_of::()); + v.set_len(num_masks * size_of::()); } Bitmap::from_u8_vec(v, n) @@ -64,7 +64,7 @@ where T: NativeType, F: FnMut(&[T; N]) -> M, { - assert_eq!(N, std::mem::size_of::() * 8); + assert_eq!(N, size_of::() * 8); let n = arg.len(); let arg_buf = arg.values().as_slice(); @@ -72,7 +72,7 @@ where let arg_rest = arg_chunks.remainder(); let num_masks = n.div_ceil(N); - let mut v: Vec = Vec::with_capacity(num_masks * std::mem::size_of::()); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); let mut p = v.as_mut_ptr() as *mut M; for a in arg_chunks { unsafe { @@ -91,7 +91,7 @@ where } unsafe { - v.set_len(num_masks * std::mem::size_of::()); + v.set_len(num_masks * size_of::()); } Bitmap::from_u8_vec(v, n) diff --git a/crates/polars-compute/src/filter/avx512.rs b/crates/polars-compute/src/filter/avx512.rs index 11466b137be8..237aed8ed483 100644 --- a/crates/polars-compute/src/filter/avx512.rs +++ b/crates/polars-compute/src/filter/avx512.rs @@ -5,7 +5,7 @@ use core::arch::x86_64::*; // structured functions. macro_rules! simd_filter { ($values: ident, $mask_bytes: ident, $out: ident, |$subchunk: ident, $m: ident: $MaskT: ty| $body:block) => {{ - const MASK_BITS: usize = std::mem::size_of::<$MaskT>() * 8; + const MASK_BITS: usize = <$MaskT>::BITS as usize; // Do a 64-element loop for sparse fast path. let chunks = $values.chunks_exact(64); diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs index 9cc542b60978..4671fa4ff592 100644 --- a/crates/polars-compute/src/filter/primitive.rs +++ b/crates/polars-compute/src/filter/primitive.rs @@ -19,7 +19,7 @@ fn nop_filter<'a, T: Pod>( } pub fn filter_values(values: &[T], mask: &Bitmap) -> Vec { - match (std::mem::size_of::(), std::mem::align_of::()) { + match (size_of::(), align_of::()) { (1, 1) => cast_vec(filter_values_u8(cast_slice(values), mask)), (2, 2) => cast_vec(filter_values_u16(cast_slice(values), mask)), (4, 4) => cast_vec(filter_values_u32(cast_slice(values), mask)), diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index da56c65983db..a0957daeafcc 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -1,4 +1,6 @@ #![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "simd", feature(core_intrinsics))] // For fadd_algebraic. +#![cfg_attr(feature = "simd", allow(internal_features))] #![cfg_attr(feature = "simd", feature(avx512_target_feature))] #![cfg_attr( all(feature = "simd", target_arch = "x86_64"), @@ -10,6 +12,8 @@ use arrow::types::NativeType; pub mod arithmetic; pub mod arity; pub mod bitwise; +#[cfg(feature = "approx_unique")] +pub mod cardinality; pub mod comparisons; pub mod filter; pub mod float_sum; @@ -19,6 +23,7 @@ pub mod if_then_else; pub mod min_max; pub mod size; pub mod unique; +pub mod var_cov; // Trait to enable the scalar blanket implementation. pub trait NotSimdPrimitive: NativeType {} diff --git a/crates/polars-compute/src/var_cov.rs b/crates/polars-compute/src/var_cov.rs new file mode 100644 index 000000000000..d6c0267faec6 --- /dev/null +++ b/crates/polars-compute/src/var_cov.rs @@ -0,0 +1,327 @@ +// Some formulae: +// mean_x = sum(weight[i] * x[i]) / sum(weight) +// dp_xy = weighted sum of deviation products of variables x, y, written in +// the paper as simply XY. +// dp_xy = sum(weight[i] * (x[i] - mean_x) * (y[i] - mean_y)) +// +// cov(x, y) = dp_xy / sum(weight) +// var(x) = cov(x, x) +// +// Algorithms from: +// Numerically stable parallel computation of (co-)variance. +// Schubert, E., & Gertz, M. (2018). +// +// Key equations from the paper: +// (17) for mean update, (23) for dp update (and also Table 1). + +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; +use num_traits::AsPrimitive; + +const CHUNK_SIZE: usize = 128; + +#[inline(always)] +fn alg_add(a: f64, b: f64) -> f64 { + #[cfg(feature = "simd")] + { + std::intrinsics::fadd_algebraic(a, b) + } + #[cfg(not(feature = "simd"))] + { + a + b + } +} + +fn alg_sum(it: impl IntoIterator) -> f64 { + it.into_iter().fold(0.0, alg_add) +} + +#[derive(Default, Clone)] +pub struct VarState { + weight: f64, + mean: f64, + dp: f64, +} + +#[derive(Default, Clone)] +pub struct CovState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xy: f64, +} + +#[derive(Default, Clone)] +pub struct PearsonState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xx: f64, + dp_xy: f64, + dp_yy: f64, +} + +impl VarState { + fn new(x: &[f64]) -> Self { + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let mean = alg_sum(x.iter().copied()) / weight; + Self { + weight, + mean, + dp: alg_sum(x.iter().map(|&xi| (xi - mean) * (xi - mean))), + } + } + + pub fn add_one(&mut self, x: f64) { + // Just a specialized version of + // self.combine(&Self { weight: 1.0, mean: x, dp: 0.0 }) + let new_weight = self.weight + 1.0; + let delta_mean = self.mean - x; + let new_mean = self.mean - delta_mean / new_weight; + self.dp += (new_mean - x) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean = self.mean - other.mean; + let new_mean = self.mean - delta_mean * other_weight_frac; + self.dp += other.dp + other.weight * (new_mean - other.mean) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + Some(self.dp / (self.weight - ddof as f64)) + } + } +} + +impl CovState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum(x.iter().copied()) * inv_weight; + let mean_y = alg_sum(y.iter().copied()) * inv_weight; + Self { + weight, + mean_x, + mean_y, + dp_xy: alg_sum( + x.iter() + .zip(y) + .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y)), + ), + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = self.mean_x - other.mean_x; + let delta_mean_y = self.mean_y - other.mean_y; + let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac; + self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + Some(self.dp_xy / (self.weight - ddof as f64)) + } + } +} + +impl PearsonState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum(x.iter().copied()) * inv_weight; + let mean_y = alg_sum(y.iter().copied()) * inv_weight; + let mut dp_xx = 0.0; + let mut dp_xy = 0.0; + let mut dp_yy = 0.0; + for (xi, yi) in x.iter().zip(y.iter()) { + dp_xx = alg_add(dp_xx, (xi - mean_x) * (xi - mean_x)); + dp_xy = alg_add(dp_xy, (xi - mean_x) * (yi - mean_y)); + dp_yy = alg_add(dp_yy, (yi - mean_y) * (yi - mean_y)); + } + Self { + weight, + mean_x, + mean_y, + dp_xx, + dp_xy, + dp_yy, + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = self.mean_x - other.mean_x; + let delta_mean_y = self.mean_y - other.mean_y; + let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac; + self.dp_xx += other.dp_xx + other.weight * (new_mean_x - other.mean_x) * delta_mean_x; + self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y; + self.dp_yy += other.dp_yy + other.weight * (new_mean_y - other.mean_y) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self, _ddof: u8) -> f64 { + // The division by sample_weight - ddof on both sides cancels out. + let denom = (self.dp_xx * self.dp_yy).sqrt(); + if denom == 0.0 { + f64::NAN + } else { + self.dp_xy / denom + } + } +} + +fn chunk_as_float(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64]), +{ + let mut chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for val in it { + if i >= CHUNK_SIZE { + f(&chunk); + i = 0; + } + chunk[i] = val.as_(); + i += 1; + } + if i > 0 { + f(&chunk[..i]); + } +} + +fn chunk_as_float_binary(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64], &[f64]), +{ + let mut left_chunk = [0.0; CHUNK_SIZE]; + let mut right_chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for (l, r) in it { + if i >= CHUNK_SIZE { + f(&left_chunk, &right_chunk); + i = 0; + } + left_chunk[i] = l.as_(); + right_chunk[i] = r.as_(); + i += 1; + } + if i > 0 { + f(&left_chunk[..i], &right_chunk[..i]); + } +} + +pub fn var(arr: &PrimitiveArray) -> VarState +where + T: NativeType + AsPrimitive, +{ + let mut out = VarState::default(); + if arr.has_nulls() { + chunk_as_float(arr.non_null_values_iter(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } else { + chunk_as_float(arr.values().iter().copied(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } + out +} + +pub fn cov(x: &PrimitiveArray, y: &PrimitiveArray) -> CovState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = CovState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&CovState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&CovState::new(l, r)), + ); + } + out +} + +pub fn pearson_corr(x: &PrimitiveArray, y: &PrimitiveArray) -> PearsonState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = PearsonState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } + out +} diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index c9d68f0ee173..bb5cdc85cdac 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -17,7 +17,6 @@ polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true, optional = true } bitflags = { workspace = true } bytemuck = { workspace = true } chrono = { workspace = true, optional = true } @@ -37,6 +36,7 @@ regex = { workspace = true, optional = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } thiserror = { workspace = true } xxhash-rust = { workspace = true } @@ -100,7 +100,7 @@ partition_by = ["algorithm_group_by"] describe = [] timezones = ["temporal", "chrono", "chrono-tz", "arrow/chrono-tz", "arrow/timezones"] dynamic_group_by = ["dtype-datetime", "dtype-date"] -arrow_rs = ["arrow-array", "arrow/arrow_rs"] +list_arithmetic = [] # opt-in datatypes for Series dtype-date = ["temporal"] @@ -148,6 +148,7 @@ docs-selection = [ "describe", "partition_by", "algorithm_group_by", + "list_arithmetic", ] [package.metadata.docs.rs] diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs index 645a2a168e90..c568f55b5675 100644 --- a/crates/polars-core/src/chunked_array/builder/list/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -19,6 +19,8 @@ pub use null::*; pub use primitive::*; use super::*; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; pub trait ListBuilderTrait { fn append_opt_series(&mut self, opt_s: Option<&Series>) -> PolarsResult<()> { @@ -85,17 +87,17 @@ pub fn get_list_builder( value_capacity: usize, list_capacity: usize, name: PlSmallStr, -) -> PolarsResult> { +) -> Box { match inner_type_logical { #[cfg(feature = "dtype-categorical")] DataType::Categorical(Some(rev_map), ordering) => { - return Ok(create_categorical_chunked_listbuilder( + return create_categorical_chunked_listbuilder( name, *ordering, list_capacity, value_capacity, rev_map.clone(), - )) + ) }, #[cfg(feature = "dtype-categorical")] DataType::Enum(Some(rev_map), ordering) => { @@ -106,7 +108,7 @@ pub fn get_list_builder( value_capacity, (**rev_map).clone(), ); - return Ok(Box::new(list_builder)); + return Box::new(list_builder); }, _ => {}, } @@ -115,27 +117,34 @@ pub fn get_list_builder( match &physical_type { #[cfg(feature = "object")] - DataType::Object(_, _) => polars_bail!(opq = list_builder, &physical_type), + DataType::Object(_, _) => { + let builder = get_object_builder(PlSmallStr::EMPTY, 0).get_list_builder( + name, + value_capacity, + list_capacity, + ); + Box::new(builder) + }, #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => Ok(Box::new(AnonymousOwnedListBuilder::new( + DataType::Struct(_) => Box::new(AnonymousOwnedListBuilder::new( name, list_capacity, Some(inner_type_logical.clone()), - ))), - DataType::Null => Ok(Box::new(ListNullChunkedBuilder::new(name, list_capacity))), - DataType::List(_) => Ok(Box::new(AnonymousOwnedListBuilder::new( + )), + DataType::Null => Box::new(ListNullChunkedBuilder::new(name, list_capacity)), + DataType::List(_) => Box::new(AnonymousOwnedListBuilder::new( name, list_capacity, Some(inner_type_logical.clone()), - ))), + )), #[cfg(feature = "dtype-array")] - DataType::Array(..) => Ok(Box::new(AnonymousOwnedListBuilder::new( + DataType::Array(..) => Box::new(AnonymousOwnedListBuilder::new( name, list_capacity, Some(inner_type_logical.clone()), - ))), + )), #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => Ok(Box::new( + DataType::Decimal(_, _) => Box::new( ListPrimitiveChunkedBuilder::::new_with_values_type( name, list_capacity, @@ -143,7 +152,7 @@ pub fn get_list_builder( physical_type, inner_type_logical.clone(), ), - )), + ), _ => { macro_rules! get_primitive_builder { ($type:ty) => {{ @@ -177,13 +186,13 @@ pub fn get_list_builder( Box::new(builder) }}; } - Ok(match_dtype_to_logical_apply_macro!( + match_dtype_to_logical_apply_macro!( physical_type, get_primitive_builder, get_string_builder, get_binary_builder, get_bool_builder - )) + ) }, } } diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 1b8228d4ea69..fc8f993400d1 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -125,7 +125,7 @@ fn cast_single_to_struct( new_fields.push(Series::full_null(fld.name.clone(), length, &fld.dtype)); } - StructChunked::from_series(name, new_fields.iter()).map(|ca| ca.into_series()) + StructChunked::from_series(name, length, new_fields.iter()).map(|ca| ca.into_series()) } impl ChunkedArray @@ -206,10 +206,9 @@ where // - remain signed // - unsigned -> signed // this may still fail with overflow? - let dtype = self.dtype(); - let to_signed = dtype.is_signed_integer(); - let unsigned2unsigned = dtype.is_unsigned_integer() && dtype.is_unsigned_integer(); + let unsigned2unsigned = + self.dtype().is_unsigned_integer() && dtype.is_unsigned_integer(); let allowed = to_signed || unsigned2unsigned; if (allowed) diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 0a58d5192d41..e9b9efd2cae0 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -674,7 +674,11 @@ where right.offsets().range().try_into().unwrap(), ); - arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into()) + if missing { + arity::unary_mut_with_options(lhs, |a| broadcast_op(a, &values).into()) + } else { + arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into()) + } }, (1, _) => { let left = lhs.chunks()[0] @@ -699,9 +703,19 @@ where left.offsets().range().try_into().unwrap(), ); - arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into()) + if missing { + arity::unary_mut_with_options(rhs, |a| broadcast_op(a, &values).into()) + } else { + arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into()) + } + }, + _ => { + if missing { + arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } else { + arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } }, - _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } } @@ -758,7 +772,7 @@ fn struct_helper( b: &StructChunked, op: F, reduce: R, - value: bool, + op_is_ne: bool, is_missing: bool, ) -> BooleanChunked where @@ -769,16 +783,43 @@ where let len_b = b.len(); let broadcasts = len_a == 1 || len_b == 1; if (a.len() != b.len() && !broadcasts) || a.struct_fields().len() != b.struct_fields().len() { - BooleanChunked::full(PlSmallStr::EMPTY, value, a.len()) + BooleanChunked::full(PlSmallStr::EMPTY, op_is_ne, a.len()) } else { let (a, b) = align_chunks_binary(a, b); + let mut out = a .fields_as_series() .iter() .zip(b.fields_as_series().iter()) .map(|(l, r)| op(l, r)) - .reduce(reduce) - .unwrap(); + .reduce(&reduce) + .unwrap_or_else(|| BooleanChunked::full(PlSmallStr::EMPTY, !op_is_ne, a.len())); + + if is_missing && (a.has_nulls() || b.has_nulls()) { + // Do some allocations so that we can use the Series dispatch, it otherwise + // gets complicated dealing with combinations of ==, != and broadcasting. + let default = || { + BooleanChunked::with_chunk(PlSmallStr::EMPTY, BooleanArray::from_slice([true])) + .into_series() + }; + let validity_to_series = |x| unsafe { + BooleanChunked::with_chunk( + PlSmallStr::EMPTY, + BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None), + ) + .into_series() + }; + + out = reduce( + out, + op( + &a.rechunk_validity() + .map_or_else(default, validity_to_series), + &b.rechunk_validity() + .map_or_else(default, validity_to_series), + ), + ) + } if !is_missing && (a.null_count() > 0 || b.null_count() > 0) { let mut a = a.into_owned(); @@ -874,7 +915,11 @@ where } } - arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into()) + if missing { + arity::unary_mut_with_options(lhs, |a| broadcast_op(a, right.values()).into()) + } else { + arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into()) + } }, (1, _) => { let left = lhs.chunks()[0] @@ -894,9 +939,19 @@ where } } - arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into()) + if missing { + arity::unary_mut_with_options(rhs, |a| broadcast_op(a, left.values()).into()) + } else { + arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into()) + } + }, + _ => { + if missing { + arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } else { + arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } }, - _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } } diff --git a/crates/polars-core/src/chunked_array/float.rs b/crates/polars-core/src/chunked_array/float.rs index 8376629cc403..5d9bae240062 100644 --- a/crates/polars-core/src/chunked_array/float.rs +++ b/crates/polars-core/src/chunked_array/float.rs @@ -1,4 +1,3 @@ -use arrow::legacy::kernels::float::*; use arrow::legacy::kernels::set::set_at_nulls; use num_traits::Float; use polars_utils::total_ord::{canonical_f32, canonical_f64}; @@ -12,16 +11,16 @@ where T::Native: Float, { pub fn is_nan(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_nan::) + unary_elementwise_values(self, |x| x.is_nan()) } pub fn is_not_nan(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_not_nan::) + unary_elementwise_values(self, |x| !x.is_nan()) } pub fn is_finite(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_finite) + unary_elementwise_values(self, |x| x.is_finite()) } pub fn is_infinite(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_infinite) + unary_elementwise_values(self, |x| x.is_infinite()) } #[must_use] diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 33e984b94e0f..c84dc20a5d63 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -211,6 +211,7 @@ where /// Create a new [`ChunkedArray`] from existing chunks. /// /// # Safety + /// /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. pub unsafe fn from_chunks_and_dtype( name: PlSmallStr, @@ -225,10 +226,15 @@ where assert_eq!(chunks[0].dtype(), &dtype.to_arrow(CompatLevel::newest())) } } - let field = Arc::new(Field::new(name, dtype)); - ChunkedArray::new_with_compute_len(field, chunks) + + Self::from_chunks_and_dtype_unchecked(name, chunks, dtype) } + /// Create a new [`ChunkedArray`] from existing chunks. + /// + /// # Safety + /// + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. pub(crate) unsafe fn from_chunks_and_dtype_unchecked( name: PlSmallStr, chunks: Vec, diff --git a/crates/polars-core/src/chunked_array/from_iterator.rs b/crates/polars-core/src/chunked_array/from_iterator.rs index 72f2bc8c60cb..ba9e8d1e6ccc 100644 --- a/crates/polars-core/src/chunked_array/from_iterator.rs +++ b/crates/polars-core/src/chunked_array/from_iterator.rs @@ -81,12 +81,12 @@ impl PolarsAsRef for &str {} // &["foo", "bar"] impl PolarsAsRef for &&str {} -impl<'a> PolarsAsRef for Cow<'a, str> {} +impl PolarsAsRef for Cow<'_, str> {} impl PolarsAsRef<[u8]> for Vec {} impl PolarsAsRef<[u8]> for &[u8] {} // TODO: remove! impl PolarsAsRef<[u8]> for &&[u8] {} -impl<'a> PolarsAsRef<[u8]> for Cow<'a, [u8]> {} +impl PolarsAsRef<[u8]> for Cow<'_, [u8]> {} impl FromIterator for StringChunked where @@ -142,8 +142,7 @@ where capacity * 5, capacity, PlSmallStr::EMPTY, - ) - .unwrap(); + ); builder.append_series(v.borrow()).unwrap(); for s in it { @@ -199,42 +198,23 @@ impl FromIterator> for ListChunked { } builder.finish() } else { - match first_s.dtype() { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let mut builder = - first_s.get_list_builder(PlSmallStr::EMPTY, capacity * 5, capacity); - for _ in 0..init_null_count { - builder.append_null(); - } - builder.append_series(first_s).unwrap(); - - for opt_s in it { - builder.append_opt_series(opt_s.as_ref()).unwrap(); - } - builder.finish() - }, - _ => { - // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. - let mut builder = get_list_builder( - first_s.dtype(), - capacity * 5, - capacity, - PlSmallStr::EMPTY, - ) - .unwrap(); + // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. + let mut builder = get_list_builder( + first_s.dtype(), + capacity * 5, + capacity, + PlSmallStr::EMPTY, + ); - for _ in 0..init_null_count { - builder.append_null(); - } - builder.append_series(first_s).unwrap(); + for _ in 0..init_null_count { + builder.append_null(); + } + builder.append_series(first_s).unwrap(); - for opt_s in it { - builder.append_opt_series(opt_s.as_ref()).unwrap(); - } - builder.finish() - }, + for opt_s in it { + builder.append_opt_series(opt_s.as_ref()).unwrap(); } + builder.finish() } }, } diff --git a/crates/polars-core/src/chunked_array/from_iterator_par.rs b/crates/polars-core/src/chunked_array/from_iterator_par.rs index 5c9abf4620af..a90b27da5722 100644 --- a/crates/polars-core/src/chunked_array/from_iterator_par.rs +++ b/crates/polars-core/src/chunked_array/from_iterator_par.rs @@ -177,33 +177,13 @@ fn materialize_list( value_capacity: usize, list_capacity: usize, ) -> ListChunked { - match &dtype { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let s = vectors - .iter() - .flatten() - .find_map(|opt_s| opt_s.as_ref()) - .unwrap(); - let mut builder = s.get_list_builder(name, value_capacity, list_capacity); - - for v in vectors { - for val in v { - builder.append_opt_series(val.as_ref()).unwrap(); - } - } - builder.finish() - }, - dtype => { - let mut builder = get_list_builder(dtype, value_capacity, list_capacity, name).unwrap(); - for v in vectors { - for val in v { - builder.append_opt_series(val.as_ref()).unwrap(); - } - } - builder.finish() - }, + let mut builder = get_list_builder(&dtype, value_capacity, list_capacity, name); + for v in vectors { + for val in v { + builder.append_opt_series(val.as_ref()).unwrap(); + } } + builder.finish() } impl FromParallelIterator> for ListChunked { diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index 728ffc5a8cff..a87d888968f3 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -25,7 +25,7 @@ pub trait PolarsIterator: ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen { } -unsafe impl<'a, I> TrustedLen for Box + 'a> {} +unsafe impl TrustedLen for Box + '_> {} /// Implement [`PolarsIterator`] for every iterator that implements the needed traits. impl PolarsIterator for T where @@ -79,7 +79,7 @@ impl<'a> BoolIterNoNull<'a> { } } -impl<'a> Iterator for BoolIterNoNull<'a> { +impl Iterator for BoolIterNoNull<'_> { type Item = bool; fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a> Iterator for BoolIterNoNull<'a> { } } -impl<'a> DoubleEndedIterator for BoolIterNoNull<'a> { +impl DoubleEndedIterator for BoolIterNoNull<'_> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -112,7 +112,7 @@ impl<'a> DoubleEndedIterator for BoolIterNoNull<'a> { } /// all arrays have known size. -impl<'a> ExactSizeIterator for BoolIterNoNull<'a> {} +impl ExactSizeIterator for BoolIterNoNull<'_> {} impl BooleanChunked { #[allow(clippy::wrong_self_convention)] @@ -339,7 +339,7 @@ impl<'a> FixedSizeListIterNoNull<'a> { } #[cfg(feature = "dtype-array")] -impl<'a> Iterator for FixedSizeListIterNoNull<'a> { +impl Iterator for FixedSizeListIterNoNull<'_> { type Item = Series; fn next(&mut self) -> Option { @@ -367,7 +367,7 @@ impl<'a> Iterator for FixedSizeListIterNoNull<'a> { } #[cfg(feature = "dtype-array")] -impl<'a> DoubleEndedIterator for FixedSizeListIterNoNull<'a> { +impl DoubleEndedIterator for FixedSizeListIterNoNull<'_> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -388,7 +388,7 @@ impl<'a> DoubleEndedIterator for FixedSizeListIterNoNull<'a> { /// all arrays have known size. #[cfg(feature = "dtype-array")] -impl<'a> ExactSizeIterator for FixedSizeListIterNoNull<'a> {} +impl ExactSizeIterator for FixedSizeListIterNoNull<'_> {} #[cfg(feature = "dtype-array")] impl ArrayChunked { diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index f84c7d751874..2c48da805171 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -18,7 +18,7 @@ pub struct AmortizedListIter<'a, I: Iterator>> { inner_dtype: DataType, } -impl<'a, I: Iterator>> AmortizedListIter<'a, I> { +impl>> AmortizedListIter<'_, I> { pub(crate) unsafe fn new( len: usize, series_container: Series, @@ -37,7 +37,7 @@ impl<'a, I: Iterator>> AmortizedListIter<'a, I> { } } -impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a, I> { +impl>> Iterator for AmortizedListIter<'_, I> { type Item = Option; fn next(&mut self) -> Option { @@ -106,8 +106,8 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // # Safety // we correctly implemented size_hint -unsafe impl<'a, I: Iterator>> TrustedLen for AmortizedListIter<'a, I> {} -impl<'a, I: Iterator>> ExactSizeIterator for AmortizedListIter<'a, I> {} +unsafe impl>> TrustedLen for AmortizedListIter<'_, I> {} +impl>> ExactSizeIterator for AmortizedListIter<'_, I> {} impl ListChunked { /// This is an iterator over a [`ListChunked`] that saves allocations. @@ -152,7 +152,7 @@ impl ListChunked { let (s, ptr) = unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; - // SAFETY: ptr belongs the the Series.. + // SAFETY: ptr belongs the Series.. unsafe { AmortizedListIter::new( self.len(), @@ -393,7 +393,7 @@ mod test { #[test] fn test_iter_list() { - let mut builder = get_list_builder(&DataType::Int32, 10, 10, PlSmallStr::EMPTY).unwrap(); + let mut builder = get_list_builder(&DataType::Int32, 10, 10, PlSmallStr::EMPTY); builder .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3])) .unwrap(); diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 8b730966b1bc..7aff61172f04 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -46,25 +46,6 @@ impl ListChunked { } } - /// Returns an iterator over the offsets of this chunked array. - /// - /// The offsets are returned as though the array consisted of a single chunk. - pub fn iter_offsets(&self) -> impl Iterator + '_ { - let mut offsets = self.downcast_iter().map(|arr| arr.offsets().iter()); - let first_iter = offsets.next().unwrap(); - - // The first offset doesn't have to be 0, it can be sliced to `n` in the array. - // So we must correct for this. - let correction = first_iter.clone().next().unwrap(); - - OffsetsIterator { - current_offsets_iter: first_iter, - current_adjusted_offset: 0, - offset_adjustment: -correction, - offsets_iters: offsets, - } - } - /// Ignore the list indices and apply `func` to the inner type as [`Series`]. pub fn apply_to_inner( &self, @@ -110,33 +91,14 @@ impl ListChunked { ) }) } -} - -pub struct OffsetsIterator<'a, N> -where - N: Iterator>, -{ - offsets_iters: N, - current_offsets_iter: std::slice::Iter<'a, i64>, - current_adjusted_offset: i64, - offset_adjustment: i64, -} -impl<'a, N> Iterator for OffsetsIterator<'a, N> -where - N: Iterator>, -{ - type Item = i64; - - fn next(&mut self) -> Option { - if let Some(offset) = self.current_offsets_iter.next() { - self.current_adjusted_offset = offset + self.offset_adjustment; - Some(self.current_adjusted_offset) - } else { - self.current_offsets_iter = self.offsets_iters.next()?; - let first = self.current_offsets_iter.next().unwrap(); - self.offset_adjustment = self.current_adjusted_offset - first; - self.next() - } + pub fn rechunk_and_trim_to_normalized_offsets(&self) -> Self { + Self::with_chunk( + self.name().clone(), + self.rechunk() + .downcast_get(0) + .unwrap() + .trim_to_normalized_offsets_recursive(), + ) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index a59ff68e40d9..8ccd455e4bd0 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -203,6 +203,24 @@ impl CategoricalChunked { } } + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// # Safety + /// It's not checked that the indices are in-bounds or that the dtype is + /// correct. + pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self { + debug_assert!(matches!( + dtype, + DataType::Enum { .. } | DataType::Categorical { .. } + )); + let mut logical = Logical::::new_logical::(idx); + logical.2 = Some(dtype); + Self { + physical: logical, + bit_settings: Default::default(), + } + } + /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. /// /// # Safety @@ -425,7 +443,7 @@ pub struct CatIter<'a> { iter: Box> + 'a>, } -unsafe impl<'a> TrustedLen for CatIter<'a> {} +unsafe impl TrustedLen for CatIter<'_> {} impl<'a> Iterator for CatIter<'a> { type Item = Option<&'a str>; @@ -445,7 +463,7 @@ impl<'a> Iterator for CatIter<'a> { } } -impl<'a> ExactSizeIterator for CatIter<'a> {} +impl ExactSizeIterator for CatIter<'_> {} #[cfg(test)] mod test { diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index 7b851c5def54..707fc79f0364 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -66,8 +66,9 @@ impl CategoricalChunked { let mut counts = groups.group_count(); counts.rename(PlSmallStr::from_static("counts")); + let height = counts.len(); let cols = vec![values.into_series().into(), counts.into_series().into()]; - let df = unsafe { DataFrame::new_no_checks(cols) }; + let df = unsafe { DataFrame::new_no_checks(height, cols) }; df.sort( ["counts"], SortMultipleOptions::default().with_order_descending(true), diff --git a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs index a0bd2687af63..d5c6d6e857b8 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -2,12 +2,12 @@ use std::hash::{Hash, Hasher}; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use hashbrown::hash_map::RawEntryMut; +use hashbrown::hash_table::Entry; +use hashbrown::HashTable; use once_cell::sync::Lazy; use polars_utils::aliases::PlRandomState; use polars_utils::pl_str::PlSmallStr; -use crate::datatypes::{InitHashMaps2, PlIdHashMap}; use crate::hashing::_HASHMAP_INIT_SIZE; /// We use atomic reference counting to determine how many threads use the @@ -131,7 +131,7 @@ impl Hash for Key { } pub(crate) struct SCacheInner { - map: PlIdHashMap, + map: HashTable, pub(crate) uuid: u32, payloads: Vec, } @@ -149,26 +149,23 @@ impl SCacheInner { #[inline] pub(crate) fn insert_from_hash(&mut self, h: u64, s: &str) -> u32 { let mut global_idx = self.payloads.len() as u32; - // Note that we don't create the PlSmallStr to search the key in the hashmap - // as PlSmallStr may allocate a string - let entry = self.map.raw_entry_mut().from_hash(h, |key| { - (key.hash == h) && { - let pos = key.idx as usize; - let value = unsafe { self.payloads.get_unchecked(pos) }; + let entry = self.map.entry( + h, + |k| { + let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; s == value.as_str() - } - }); + }, + |k| k.hash, + ); match entry { - RawEntryMut::Occupied(entry) => { - global_idx = entry.key().idx; + Entry::Occupied(entry) => { + global_idx = entry.get().idx; }, - RawEntryMut::Vacant(entry) => { + Entry::Vacant(entry) => { let idx = self.payloads.len() as u32; let key = Key::new(h, idx); - entry.insert_hashed_nocheck(h, key, ()); - - // only just now we allocate the string + entry.insert(key); self.payloads.push(PlSmallStr::from_str(s)); }, } @@ -179,15 +176,11 @@ impl SCacheInner { pub(crate) fn get_cat(&self, s: &str) -> Option { let h = StringCache::get_hash_builder().hash_one(s); self.map - .raw_entry() - .from_hash(h, |key| { - (key.hash == h) && { - let pos = key.idx as usize; - let value = unsafe { self.payloads.get_unchecked(pos) }; - s == value.as_str() - } + .find(h, |k| { + let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; + s == value.as_str() }) - .map(|(k, _)| k.idx) + .map(|k| k.idx) } #[inline] @@ -200,7 +193,7 @@ impl SCacheInner { impl Default for SCacheInner { fn default() -> Self { Self { - map: PlIdHashMap::with_capacity(_HASHMAP_INIT_SIZE), + map: HashTable::with_capacity(_HASHMAP_INIT_SIZE), uuid: STRING_CACHE_UUID_CTR.fetch_add(1, Ordering::AcqRel), payloads: Vec::with_capacity(_HASHMAP_INIT_SIZE), } diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs index e143a59a7f7b..5279099f1cd7 100644 --- a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs @@ -85,16 +85,12 @@ impl EnumChunkedBuilder { let length = arr.len() as IdxSize; let ca = unsafe { UInt32Chunked::new_with_dims( - Arc::new(Field::new( - self.name, - DataType::Enum(Some(self.rev.clone()), self.ordering), - )), + Arc::new(Field::new(self.name, DataType::UInt32)), vec![Box::new(arr)], length, null_count, ) }; - // SAFETY: keys and values are in bounds unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering) diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 45f63847e97f..9383634cdcac 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -80,7 +80,7 @@ pub(crate) fn get_object_type() -> DataType { Box::new(ObjectChunkedBuilder::::new(name, capacity)) as Box }); - let object_size = std::mem::size_of::(); + let object_size = size_of::(); let physical_dtype = ArrowDataType::FixedSizeBinary(object_size); let registry = ObjectRegistry::new(object_builder, physical_dtype); diff --git a/crates/polars-core/src/chunked_array/object/extension/list.rs b/crates/polars-core/src/chunked_array/object/extension/list.rs index 1918039d647e..2d34315c378d 100644 --- a/crates/polars-core/src/chunked_array/object/extension/list.rs +++ b/crates/polars-core/src/chunked_array/object/extension/list.rs @@ -18,7 +18,7 @@ impl ObjectChunked { } } -struct ExtensionListBuilder { +pub(crate) struct ExtensionListBuilder { values_builder: ObjectChunkedBuilder, offsets: Vec, fast_explode: bool, diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 5a049da4a01f..f9167b200211 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -1,5 +1,5 @@ pub(crate) mod drop; -mod list; +pub(super) mod list; pub(crate) mod polars_extension; use std::mem; @@ -29,7 +29,7 @@ pub fn set_polars_allow_extension(toggle: bool) { /// `n_t_vals` must represent the correct number of `T` values in that allocation unsafe fn create_drop(mut ptr: *const u8, n_t_vals: usize) -> Box { Box::new(move || { - let t_size = std::mem::size_of::() as isize; + let t_size = size_of::() as isize; for _ in 0..n_t_vals { let _ = std::ptr::read_unaligned(ptr as *const T); ptr = ptr.offset(t_size) @@ -55,7 +55,7 @@ impl Drop for ExtensionSentinel { // https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d // not entirely sure if padding bytes in T are initialized or not. unsafe fn any_as_u8_slice(p: &T) -> &[u8] { - std::slice::from_raw_parts((p as *const T) as *const u8, std::mem::size_of::()) + std::slice::from_raw_parts((p as *const T) as *const u8, size_of::()) } /// Create an extension Array that can be sent to arrow and (once wrapped in `[PolarsExtension]` will @@ -67,8 +67,8 @@ pub(crate) fn create_extension> + TrustedLen, T: Si if !(POLARS_ALLOW_EXTENSION.load(Ordering::Relaxed) || std::env::var(env).is_ok()) { panic!("creating extension types not allowed - try setting the environment variable {env}") } - let t_size = std::mem::size_of::(); - let t_alignment = std::mem::align_of::(); + let t_size = size_of::(); + let t_alignment = align_of::(); let n_t_vals = iter.size_hint().1.unwrap(); let mut buf = Vec::with_capacity(n_t_vals * t_size); diff --git a/crates/polars-core/src/chunked_array/object/iterator.rs b/crates/polars-core/src/chunked_array/object/iterator.rs index 7a5c6e00b590..7abb9c46f4ee 100644 --- a/crates/polars-core/src/chunked_array/object/iterator.rs +++ b/crates/polars-core/src/chunked_array/object/iterator.rs @@ -54,7 +54,7 @@ impl<'a, T: PolarsObject> std::iter::Iterator for ObjectIter<'a, T> { } } -impl<'a, T: PolarsObject> std::iter::DoubleEndedIterator for ObjectIter<'a, T> { +impl std::iter::DoubleEndedIterator for ObjectIter<'_, T> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -75,7 +75,7 @@ impl<'a, T: PolarsObject> std::iter::DoubleEndedIterator for ObjectIter<'a, T> { } /// all arrays have known size. -impl<'a, T: PolarsObject> std::iter::ExactSizeIterator for ObjectIter<'a, T> {} +impl std::iter::ExactSizeIterator for ObjectIter<'_, T> {} impl<'a, T: PolarsObject> IntoIterator for &'a ObjectArray { type Item = Option<&'a T>; diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index a7e3d2f9952d..8f4711976856 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -164,7 +164,7 @@ where } fn dtype(&self) -> &ArrowDataType { - &ArrowDataType::FixedSizeBinary(std::mem::size_of::()) + &ArrowDataType::FixedSizeBinary(size_of::()) } fn slice(&mut self, offset: usize, length: usize) { @@ -275,7 +275,7 @@ impl StaticArray for ObjectArray { impl ParameterFreeDtypeStaticArray for ObjectArray { fn get_dtype() -> ArrowDataType { - ArrowDataType::FixedSizeBinary(std::mem::size_of::()) + ArrowDataType::FixedSizeBinary(size_of::()) } } diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index e84c7ab69ba5..4bda1162bb94 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -13,7 +13,7 @@ use polars_utils::pl_str::PlSmallStr; use crate::chunked_array::object::builder::ObjectChunkedBuilder; use crate::datatypes::AnyValue; -use crate::prelude::PolarsObject; +use crate::prelude::{ListBuilderTrait, PolarsObject}; use crate::series::{IntoSeries, Series}; /// Takes a `name` and `capacity` and constructs a new builder. @@ -71,6 +71,13 @@ pub trait AnonymousObjectBuilder { /// Take the current state and materialize as a [`Series`] /// the builder should not be used after that. fn to_series(&mut self) -> Series; + + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box; } impl AnonymousObjectBuilder for ObjectChunkedBuilder { @@ -87,6 +94,18 @@ impl AnonymousObjectBuilder for ObjectChunkedBuilder { let builder = std::mem::take(self); builder.finish().into_series() } + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box { + Box::new(super::extension::list::ExtensionListBuilder::::new( + name, + values_capacity, + list_capacity, + )) + } } pub fn register_object_builder( diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index cf79b0acb473..0a059eb54274 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, One, ToPrimitive, Zero}; use polars_compute::float_sum; use polars_compute::min_max::MinMaxKernel; use polars_utils::min_max::MinMax; +use polars_utils::sync::SyncPtr; pub use quantile::*; pub use var::*; @@ -369,12 +370,8 @@ where ::Simd: Add::Simd> + compute::aggregate::Sum, { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -385,12 +382,8 @@ where } impl QuantileAggSeries for Float32Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float32, v.into())) } @@ -401,12 +394,8 @@ impl QuantileAggSeries for Float32Chunked { } impl QuantileAggSeries for Float64Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -553,17 +542,59 @@ impl CategoricalChunked { #[cfg(feature = "dtype-categorical")] impl ChunkAggSeries for CategoricalChunked { fn min_reduce(&self) -> Scalar { - let av: AnyValue = self.min_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().min() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.min_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } fn max_reduce(&self) -> Scalar { - let av: AnyValue = self.max_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().max() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.max_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } } impl BinaryChunked { - pub(crate) fn max_binary(&self) -> Option<&[u8]> { + pub fn max_binary(&self) -> Option<&[u8]> { if self.is_empty() { return None; } @@ -587,7 +618,7 @@ impl BinaryChunked { } } - pub(crate) fn min_binary(&self) -> Option<&[u8]> { + pub fn min_binary(&self) -> Option<&[u8]> { if self.is_empty() { return None; } @@ -735,19 +766,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_f64.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i64.quantile(0.9, interpol).unwrap(), None); + for method in methods { + assert_eq!(test_f32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_f64.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i64.quantile(0.9, method).unwrap(), None); } } @@ -758,19 +790,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i64.quantile(0.5, interpol).unwrap(), Some(1.0)); + for method in methods { + assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0)); } } @@ -793,37 +826,38 @@ mod test { &[None, Some(1i64), Some(5i64), Some(1i64)], ); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.0, interpol).unwrap(), test_f32.min()); - assert_eq!(test_f32.quantile(1.0, interpol).unwrap(), test_f32.max()); + for method in methods { + assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min()); + assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max()); assert_eq!( - test_i32.quantile(0.0, interpol).unwrap().unwrap(), + test_i32.quantile(0.0, method).unwrap().unwrap(), test_i32.min().unwrap() as f64 ); assert_eq!( - test_i32.quantile(1.0, interpol).unwrap().unwrap(), + test_i32.quantile(1.0, method).unwrap().unwrap(), test_i32.max().unwrap() as f64 ); - assert_eq!(test_f64.quantile(0.0, interpol).unwrap(), test_f64.min()); - assert_eq!(test_f64.quantile(1.0, interpol).unwrap(), test_f64.max()); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), test_f64.median()); + assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min()); + assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max()); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median()); assert_eq!( - test_i64.quantile(0.0, interpol).unwrap().unwrap(), + test_i64.quantile(0.0, method).unwrap().unwrap(), test_i64.min().unwrap() as f64 ); assert_eq!( - test_i64.quantile(1.0, interpol).unwrap().unwrap(), + test_i64.quantile(1.0, method).unwrap().unwrap(), test_i64.max().unwrap() as f64 ); } @@ -837,72 +871,56 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(5.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(3.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(3.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(3.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert!( + (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001 + ); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); - assert!( - (ca.quantile(0.6, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - - 3.4) - .abs() - < 0.0000001 + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(3.0) ); let ca = UInt32Chunked::new( @@ -922,68 +940,54 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(2.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(6.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(5.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(6.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(7.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(6.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.6) + ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(6.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(5.0) ); } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index d6218e81d463..f7716c864559 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -4,11 +4,7 @@ pub trait QuantileAggSeries { /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. fn median_reduce(&self) -> Scalar; /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult; + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult; } /// helper @@ -16,18 +12,23 @@ fn quantile_idx( quantile: f64, length: usize, null_count: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> (usize, f64, usize) { - let float_idx = ((length - null_count) as f64 - 1.0) * quantile + null_count as f64; - let mut base_idx = match interpol { - QuantileInterpolOptions::Nearest => { + let nonnull_count = (length - null_count) as f64; + let float_idx = (nonnull_count - 1.0) * quantile + null_count as f64; + let mut base_idx = match method { + QuantileMethod::Nearest => { let idx = float_idx.round() as usize; - return (float_idx.round() as usize, 0.0, idx); + return (idx, 0.0, idx); + }, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + float_idx as usize + }, + QuantileMethod::Higher => float_idx.ceil() as usize, + QuantileMethod::Equiprobable => { + let idx = ((nonnull_count * quantile).ceil() - 1.0).max(0.0) as usize + null_count; + return (idx, 0.0, idx); }, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => float_idx as usize, - QuantileInterpolOptions::Higher => float_idx.ceil() as usize, }; base_idx = base_idx.clamp(0, length - 1); @@ -57,7 +58,7 @@ fn midpoint_interpol(lower: T, upper: T) -> T { fn quantile_slice( vals: &mut [T], quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { polars_ensure!((0.0..=1.0).contains(&quantile), ComputeError: "quantile should be between 0.0 and 1.0", @@ -68,21 +69,21 @@ fn quantile_slice( if vals.len() == 1 { return Ok(vals[0].to_f64()); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, method); let (_lhs, lower, rhs) = vals.select_nth_unstable_by(idx, TotalOrd::tot_cmp); if idx == top_idx { Ok(lower.to_f64()) } else { - match interpol { - QuantileInterpolOptions::Midpoint => { + match method { + QuantileMethod::Midpoint => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(Some(midpoint_interpol( lower.to_f64().unwrap(), upper.to_f64().unwrap(), ))) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(linear_interpol( lower.to_f64().unwrap(), @@ -100,7 +101,7 @@ fn quantile_slice( fn generic_quantile( ca: ChunkedArray, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> where T: PolarsNumericType, @@ -117,12 +118,12 @@ where return Ok(None); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, method); let sorted = ca.sort(false); let lower = sorted.get(idx).map(|v| v.to_f64().unwrap()); - let opt = match interpol { - QuantileInterpolOptions::Midpoint => { + let opt = match method { + QuantileMethod::Midpoint => { if top_idx == idx { lower } else { @@ -130,7 +131,7 @@ where midpoint_interpol(lower.unwrap(), upper.unwrap()).to_f64() } }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { if top_idx == idx { lower } else { @@ -149,22 +150,18 @@ where T: PolarsIntegerType, T::Native: TotalOrd, { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -177,61 +174,52 @@ where pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } impl ChunkQuantile for Float32Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let out = if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) }; out.map(|v| v.map(|v| v as f32)) } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } impl ChunkQuantile for Float64Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -239,20 +227,19 @@ impl Float64Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } @@ -260,20 +247,19 @@ impl Float32Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol).map(|v| v.map(|v| v as f32)) + quantile_slice(slice, quantile, method).map(|v| v.map(|v| v as f32)) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs index 1ca04cc2b30e..ea332f0cc432 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs @@ -1,4 +1,4 @@ -use arity::unary_elementwise_values; +use polars_compute::var_cov::VarState; use super::*; @@ -15,20 +15,11 @@ where ChunkedArray: ChunkAgg, { fn var(&self, ddof: u8) -> Option { - let n_values = self.len() - self.null_count(); - if n_values <= ddof as usize { - return None; + let mut out = VarState::default(); + for arr in self.downcast_iter() { + out.combine(&polars_compute::var_cov::var(arr)) } - - let mean = self.mean()?; - let squared: Float64Chunked = unary_elementwise_values(self, |value| { - let tmp = value.to_f64().unwrap() - mean; - tmp * tmp - }); - - squared - .sum() - .map(|sum| sum / (n_values as f64 - ddof as f64)) + out.finalize(ddof) } fn std(&self, ddof: u8) -> Option { diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 383c76d63600..1cbf0da390e7 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -132,7 +132,7 @@ where impl ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, for<'a> T::Physical<'a>: TotalOrd, { /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index 7b20d77e2444..7236fde9993c 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -8,8 +8,8 @@ use crate::series::BitRepr; fn reinterpret_chunked_array( ca: &ChunkedArray, ) -> ChunkedArray { - assert!(std::mem::size_of::() == std::mem::size_of::()); - assert!(std::mem::align_of::() == std::mem::align_of::()); + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); let chunks = ca.downcast_iter().map(|array| { let buf = array.values().clone(); @@ -29,8 +29,8 @@ fn reinterpret_chunked_array( fn reinterpret_list_chunked( ca: &ListChunked, ) -> ListChunked { - assert!(std::mem::size_of::() == std::mem::size_of::()); - assert!(std::mem::align_of::() == std::mem::align_of::()); + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); let chunks = ca.downcast_iter().map(|array| { let inner_arr = array @@ -105,7 +105,7 @@ where T: PolarsNumericType, { fn to_bit_repr(&self) -> BitRepr { - let is_large = std::mem::size_of::() == 8; + let is_large = size_of::() == 8; if is_large { if matches!(self.dtype(), DataType::UInt64) { @@ -118,7 +118,7 @@ where BitRepr::Large(reinterpret_chunked_array(self)) } else { - BitRepr::Small(if std::mem::size_of::() == 4 { + BitRepr::Small(if size_of::() == 4 { if matches!(self.dtype(), DataType::UInt32) { let ca = self.clone(); // Convince the compiler we are this type. This preserves flags. diff --git a/crates/polars-core/src/chunked_array/ops/bits.rs b/crates/polars-core/src/chunked_array/ops/bits.rs new file mode 100644 index 000000000000..178df3a73d09 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/bits.rs @@ -0,0 +1,21 @@ +use super::BooleanChunked; + +impl BooleanChunked { + pub fn num_trues(&self) -> usize { + self.downcast_iter() + .map(|arr| match arr.validity() { + None => arr.values().set_bits(), + Some(validity) => arr.values().num_intersections_with(validity), + }) + .sum() + } + + pub fn num_falses(&self) -> usize { + self.downcast_iter() + .map(|arr| match arr.validity() { + None => arr.values().unset_bits(), + Some(validity) => (!arr.values()).num_intersections_with(validity), + }) + .sum() + } +} diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 2d1bddb2f4e3..e050fc6022ca 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -269,7 +269,7 @@ mod test { #[test] fn test_explode_list() -> PolarsResult<()> { - let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); builder .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 3])) @@ -300,7 +300,7 @@ mod test { #[test] fn test_explode_empty_list_slot() -> PolarsResult<()> { // primitive - let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); builder .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32, 2])) .unwrap(); @@ -319,7 +319,7 @@ mod test { ); // more primitive - let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); builder .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32])) .unwrap(); @@ -344,7 +344,7 @@ mod test { ); // string - let mut builder = get_list_builder(&DataType::String, 5, 5, PlSmallStr::from_static("a"))?; + let mut builder = get_list_builder(&DataType::String, 5, 5, PlSmallStr::from_static("a")); builder .append_series(&Series::new(PlSmallStr::EMPTY, &["abc"])) .unwrap(); @@ -390,7 +390,7 @@ mod test { ); // boolean - let mut builder = get_list_builder(&DataType::Boolean, 5, 5, PlSmallStr::from_static("a"))?; + let mut builder = get_list_builder(&DataType::Boolean, 5, 5, PlSmallStr::from_static("a")); builder .append_series(&Series::new(PlSmallStr::EMPTY, &[true])) .unwrap(); diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index e33d38118891..8d34b80519dd 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -101,8 +101,7 @@ impl ChunkFullNull for BinaryOffsetChunked { impl ChunkFull<&Series> for ListChunked { fn full(name: PlSmallStr, value: &Series, length: usize) -> ListChunked { - let mut builder = - get_list_builder(value.dtype(), value.len() * length, length, name).unwrap(); + let mut builder = get_list_builder(value.dtype(), value.len() * length, length, name); for _ in 0..length { builder.append_series(value).unwrap(); } @@ -207,8 +206,7 @@ impl ListChunked { #[cfg(feature = "dtype-struct")] impl ChunkFullNull for StructChunked { fn full_null(name: PlSmallStr, length: usize) -> StructChunked { - let s = [Series::new_null(PlSmallStr::EMPTY, length)]; - StructChunked::from_series(name, s.iter()) + StructChunked::from_series(name, length, [].iter()) .unwrap() .with_outer_validity(Some(Bitmap::new_zeroed(length))) } diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index cb24305f75f6..fc162626bc27 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -143,7 +143,7 @@ unsafe fn gather_idx_array_unchecked( impl + ?Sized> ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &I) -> Self { @@ -178,7 +178,7 @@ pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> impl ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { @@ -312,3 +312,47 @@ impl IdxCa { f(&ca) } } + +#[cfg(feature = "dtype-array")] +impl ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +#[cfg(feature = "dtype-array")] +impl + ?Sized> ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} + +impl ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +impl + ?Sized> ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 061278a22cc9..c0daaa72bdf6 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -11,6 +11,7 @@ mod apply; mod approx_n_unique; pub mod arity; mod bit_repr; +mod bits; #[cfg(feature = "bitwise")] mod bitwise_reduce; pub(crate) mod chunkops; @@ -33,6 +34,7 @@ pub(crate) mod nulls; mod reverse; #[cfg(feature = "rolling_window")] pub(crate) mod rolling_window; +pub mod row_encode; pub mod search_sorted; mod set; mod shift; @@ -277,11 +279,7 @@ pub trait ChunkQuantile { } /// Aggregate a given quantile of the ChunkedArray. /// Returns `None` if the array is empty or only contains null values. - fn quantile( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult> { Ok(None) } } @@ -576,7 +574,7 @@ impl ChunkExpandAtIndex for StructChunked { }) .collect::>(); - StructArray::new(chunk.dtype().clone(), values, None).boxed() + StructArray::new(chunk.dtype().clone(), length, values, None).boxed() }; // SAFETY: chunks are from self. diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs new file mode 100644 index 000000000000..5ac627327389 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -0,0 +1,220 @@ +use arrow::compute::utils::combine_validities_and_many; +use polars_row::{convert_columns, EncodingField, RowsEncoded}; +use rayon::prelude::*; + +use crate::prelude::*; +use crate::utils::_split_offsets; +use crate::POOL; + +pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => s.rechunk(), + Binary | Boolean => s.clone(), + BinaryOffset => s.clone(), + String => s.str().unwrap().as_binary().into_series(), + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = s.struct_().unwrap(); + let new_fields = ca + .fields_as_series() + .iter() + .map(convert_series_for_row_encoding) + .collect::>>()?; + let mut out = + StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?; + out.zip_outer_validity(ca); + out.into_series() + }, + // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => s.clone(), + List(inner) if !inner.is_nested() => s.clone(), + Null => s.clone(), + _ => { + let phys = s.to_physical_repr().into_owned(); + polars_ensure!( + phys.dtype().is_numeric(), + InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() + ); + phys + }, + }; + Ok(out) +} + +pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { + let by = convert_series_for_row_encoding(by)?; + let by = by.rechunk(); + + let out = match by.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + let ca = by.categorical().unwrap(); + if ca.uses_lexical_ordering() { + by.to_arrow(0, CompatLevel::newest()) + } else { + ca.physical().chunks[0].clone() + } + }, + // Take physical + _ => by.chunks()[0].clone(), + }; + Ok(out) +} + +pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +// Almost the same but broadcast nulls to the row-encoded array. +pub fn encode_rows_vertical_par_unordered_broadcast_nulls( + by: &[Series], +) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + + let validities = sliced + .iter() + .flat_map(|s| { + let s = s.rechunk(); + #[allow(clippy::unnecessary_to_owned)] + s.chunks() + .to_vec() + .into_iter() + .map(|arr| arr.validity().cloned()) + }) + .collect::>(); + + let validity = combine_validities_and_many(&validities); + Ok(rows.into_array().with_validity_typed(validity)) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +pub fn encode_rows_unordered(by: &[Series]) -> PolarsResult { + let rows = _get_rows_encoded_unordered(by)?; + Ok(BinaryOffsetChunked::with_chunk( + PlSmallStr::EMPTY, + rows.into_array(), + )) +} + +pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + for by in by { + let arr = _get_rows_encoded_compat_array(by)?; + let field = EncodingField::new_unsorted(); + match arr.dtype() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + for arr in arr.values() { + cols.push(arr.clone() as ArrayRef); + fields.push(field) + } + }, + _ => { + cols.push(arr); + fields.push(field) + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + debug_assert_eq!(by.len(), descending.len()); + debug_assert_eq!(by.len(), nulls_last.len()); + + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + + for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { + let by = by.as_materialized_series(); + let arr = _get_rows_encoded_compat_array(by)?; + let sort_field = EncodingField { + descending: *desc, + nulls_last: *null_last, + no_order: false, + }; + match arr.dtype() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + let arr = arr.propagate_nulls(); + for value_arr in arr.values() { + cols.push(value_arr.clone() as ArrayRef); + fields.push(sort_field); + } + }, + _ => { + cols.push(arr); + fields.push(sort_field); + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded_ca( + name: PlSmallStr, + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + _get_rows_encoded(by, descending, nulls_last) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} + +pub fn _get_rows_encoded_arr( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult> { + _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) +} + +pub fn _get_rows_encoded_ca_unordered( + name: PlSmallStr, + by: &[Series], +) -> PolarsResult { + _get_rows_encoded_unordered(by) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs index cad95d6b1d10..7f257f23f59e 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -1,6 +1,7 @@ use polars_utils::itertools::Itertools; use super::*; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; #[derive(Eq)] struct CompareRow<'a> { diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 5653039ff02e..2291cc2306e1 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,10 +1,8 @@ -use arrow::compute::utils::combine_validities_and_many; use compare_inner::NullOrderCmp; -use polars_row::{convert_columns, EncodingField, RowsEncoded}; use polars_utils::itertools::Itertools; use super::*; -use crate::utils::_split_offsets; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; pub(crate) fn args_validate( ca: &ChunkedArray, @@ -86,181 +84,6 @@ pub(crate) fn arg_sort_multiple_impl( Ok(ca.into_inner()) } -pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { - let by = convert_sort_column_multi_sort(by)?; - let by = by.rechunk(); - - let out = match by.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - let ca = by.categorical().unwrap(); - if ca.uses_lexical_ordering() { - by.to_arrow(0, CompatLevel::newest()) - } else { - ca.physical().chunks[0].clone() - } - }, - // Take physical - _ => by.chunks()[0].clone(), - }; - Ok(out) -} - -pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - Ok(rows.into_array()) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -// Almost the same but broadcast nulls to the row-encoded array. -pub fn encode_rows_vertical_par_unordered_broadcast_nulls( - by: &[Series], -) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - - let validities = sliced - .iter() - .flat_map(|s| { - let s = s.rechunk(); - #[allow(clippy::unnecessary_to_owned)] - s.chunks() - .to_vec() - .into_iter() - .map(|arr| arr.validity().cloned()) - }) - .collect::>(); - - let validity = combine_validities_and_many(&validities); - Ok(rows.into_array().with_validity_typed(validity)) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -pub(crate) fn encode_rows_unordered(by: &[Series]) -> PolarsResult { - let rows = _get_rows_encoded_unordered(by)?; - Ok(BinaryOffsetChunked::with_chunk( - PlSmallStr::EMPTY, - rows.into_array(), - )) -} - -pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - for by in by { - let arr = _get_rows_encoded_compat_array(by)?; - let field = EncodingField::new_unsorted(); - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for arr in arr.values() { - cols.push(arr.clone() as ArrayRef); - fields.push(field) - } - }, - _ => { - cols.push(arr); - fields.push(field) - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - debug_assert_eq!(by.len(), descending.len()); - debug_assert_eq!(by.len(), nulls_last.len()); - - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - - for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { - let by = by.as_materialized_series(); - let arr = _get_rows_encoded_compat_array(by)?; - let sort_field = EncodingField { - descending: *desc, - nulls_last: *null_last, - no_order: false, - }; - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - let arr = arr.propagate_nulls(); - for value_arr in arr.values() { - cols.push(value_arr.clone() as ArrayRef); - fields.push(sort_field); - } - }, - _ => { - cols.push(arr); - fields.push(sort_field); - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded_ca( - name: PlSmallStr, - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - _get_rows_encoded(by, descending, nulls_last) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - -pub fn _get_rows_encoded_arr( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult> { - _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) -} - -pub fn _get_rows_encoded_ca_unordered( - name: PlSmallStr, - by: &[Series], -) -> PolarsResult { - _get_rows_encoded_unordered(by) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - pub(crate) fn argsort_multiple_row_fmt( by: &[Column], mut descending: Vec, diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 0aa70dae1c83..727f2ace15a8 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -18,6 +18,9 @@ use compare_inner::NonNull; use rayon::prelude::*; pub use slice::*; +use crate::chunked_array::ops::row_encode::{ + _get_rows_encoded_ca, convert_series_for_row_encoding, +}; use crate::prelude::compare_inner::TotalOrdInner; use crate::prelude::sort::arg_sort_multiple::*; use crate::prelude::*; @@ -708,43 +711,6 @@ impl ChunkSort for BooleanChunked { } } -pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult { - use DataType::*; - let out = match s.dtype() { - #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => s.rechunk(), - Binary | Boolean => s.clone(), - BinaryOffset => s.clone(), - String => s.str().unwrap().as_binary().into_series(), - #[cfg(feature = "dtype-struct")] - Struct(_) => { - let ca = s.struct_().unwrap(); - let new_fields = ca - .fields_as_series() - .iter() - .map(convert_sort_column_multi_sort) - .collect::>>()?; - let mut out = StructChunked::from_series(ca.name().clone(), new_fields.iter())?; - out.zip_outer_validity(ca); - out.into_series() - }, - // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here - #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => s.clone(), - List(inner) if !inner.is_nested() => s.clone(), - Null => s.clone(), - _ => { - let phys = s.to_physical_repr().into_owned(); - polars_ensure!( - phys.dtype().is_numeric(), - InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() - ); - phys - }, - }; - Ok(out) -} - pub fn _broadcast_bools(n_cols: usize, values: &mut Vec) { if n_cols > values.len() && values.len() == 1 { while n_cols != values.len() { @@ -762,7 +728,7 @@ pub(crate) fn prepare_arg_sort( let mut columns = columns .iter() .map(Column::as_materialized_series) - .map(convert_sort_column_multi_sort) + .map(convert_series_for_row_encoding) .map(|s| s.map(Column::from)) .collect::>>()?; diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 61f530280324..2fb29ac1ef46 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -271,7 +271,7 @@ impl ChunkZip for StructChunked { let if_true = if_true.as_ref(); let if_false = if_false.as_ref(); - let (l, r, mask) = align_chunks_ternary(if_true, if_false, mask); + let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask); // Prepare the boolean arrays such that Null maps to false. // This prevents every field doing that. @@ -287,14 +287,14 @@ impl ChunkZip for StructChunked { } // Zip all the fields. - let fields = l + let fields = if_true .fields_as_series() .iter() - .zip(r.fields_as_series()) + .zip(if_false.fields_as_series()) .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs)) .collect::>>()?; - let mut out = StructChunked::from_series(self.name().clone(), fields.iter())?; + let mut out = StructChunked::from_series(self.name().clone(), length, fields.iter())?; fn rechunk_bitmaps( total_length: usize, @@ -330,138 +330,145 @@ impl ChunkZip for StructChunked { // We need to take two things into account: // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`. // 2. `l` and `r` might still need to be broadcasted. - if (l.null_count + r.null_count) > 0 { + if (if_true.null_count + if_false.null_count) > 0 { // Create one validity mask that spans the entirety of out. - let rechunked_validity = match (l.len(), r.len()) { - (1, 1) if length != 1 => match (l.null_count() == 0, r.null_count() == 0) { - (true, true) => None, - (false, true) => { - if mask.chunks().len() == 1 { - let m = mask.chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .values(); - Some(!m) - } else { - rechunk_bitmaps( - length, - mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))), - ) - } - }, - (true, false) => { - if mask.chunks().len() == 1 { - let m = mask.chunks()[0] - .as_any() - .downcast_ref::() - .unwrap() - .values(); - Some(m.clone()) - } else { - rechunk_bitmaps( - length, - mask.downcast_iter() - .map(|m| (m.len(), Some(m.values().clone()))), - ) - } - }, - (false, false) => Some(Bitmap::new_zeroed(length)), + let rechunked_validity = match (if_true.len(), if_false.len()) { + (1, 1) if length != 1 => { + match (if_true.null_count() == 0, if_false.null_count() == 0) { + (true, true) => None, + (false, true) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(!m) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter() + .map(|m| (m.len(), Some(m.values().clone()))), + ) + } + }, + (true, false) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(m.clone()) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))), + ) + } + }, + (false, false) => Some(Bitmap::new_zeroed(length)), + } }, (1, _) if length != 1 => { - debug_assert!(r + debug_assert!(if_false .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(r, m)| r == m)); - let combine = if l.null_count() == 0 { - |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m)) + let combine = if if_true.null_count() == 0 { + |if_false: Option<&Bitmap>, m: &Bitmap| { + if_false.map(|v| arrow::bitmap::or(v, m)) + } } else { - |r: Option<&Bitmap>, m: &Bitmap| { - Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m))) + |if_false: Option<&Bitmap>, m: &Bitmap| { + Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m))) } }; - if r.chunks().len() == 1 { - let r = r.chunks()[0].validity(); + if if_false.chunks().len() == 1 { + let if_false = if_false.chunks()[0].validity(); let m = mask.chunks()[0] .as_any() .downcast_ref::() .unwrap() .values(); - let validity = combine(r, m); - validity.and_then(|v| (v.unset_bits() > 0).then_some(v)) + let validity = combine(if_false, m); + validity.filter(|v| v.unset_bits() > 0) } else { rechunk_bitmaps( length, - r.chunks() - .iter() - .zip(mask.downcast_iter()) - .map(|(chunk, mask)| { + if_false.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { (mask.len(), combine(chunk.validity(), mask.values())) - }), + }, + ), ) } }, (_, 1) if length != 1 => { - debug_assert!(l + debug_assert!(if_true .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(l, m)| l == m)); - let combine = if r.null_count() == 0 { - |l: Option<&Bitmap>, m: &Bitmap| l.map(|l| arrow::bitmap::or_not(l, m)) + let combine = if if_false.null_count() == 0 { + |if_true: Option<&Bitmap>, m: &Bitmap| { + if_true.map(|v| arrow::bitmap::or_not(v, m)) + } } else { - |l: Option<&Bitmap>, m: &Bitmap| { - Some(l.map_or_else(|| m.clone(), |l| arrow::bitmap::and(l, m))) + |if_true: Option<&Bitmap>, m: &Bitmap| { + Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m))) } }; - if l.chunks().len() == 1 { - let l = l.chunks()[0].validity(); + if if_true.chunks().len() == 1 { + let if_true = if_true.chunks()[0].validity(); let m = mask.chunks()[0] .as_any() .downcast_ref::() .unwrap() .values(); - let validity = combine(l, m); - validity.and_then(|v| (v.unset_bits() > 0).then_some(v)) + let validity = combine(if_true, m); + validity.filter(|v| v.unset_bits() > 0) } else { rechunk_bitmaps( length, - l.chunks() - .iter() - .zip(mask.downcast_iter()) - .map(|(chunk, mask)| { + if_true.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { (mask.len(), combine(chunk.validity(), mask.values())) - }), + }, + ), ) } }, (_, _) => { - debug_assert!(l + debug_assert!(if_true .chunk_lengths() - .zip(r.chunk_lengths()) + .zip(if_false.chunk_lengths()) .all(|(l, r)| l == r)); - debug_assert!(l + debug_assert!(if_true .chunk_lengths() .zip(mask.chunk_lengths()) .all(|(l, r)| l == r)); - let validities = l + let validities = if_true .chunks() .iter() - .zip(r.chunks()) + .zip(if_false.chunks()) .map(|(l, r)| (l.validity(), r.validity())); rechunk_bitmaps( length, validities .zip(mask.downcast_iter()) - .map(|((lv, rv), mask)| { - (mask.len(), if_then_else_validity(mask.values(), lv, rv)) + .map(|((if_true, if_false), mask)| { + ( + mask.len(), + if_then_else_validity(mask.values(), if_true, if_false), + ) }), ) }, diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 1ad3d2b7abd7..b927cdc8f9ee 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -192,10 +192,7 @@ impl DataFrame { match n.get(0) { Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), - None => { - let new_cols = self.columns.iter().map(Column::clear).collect_trusted(); - Ok(unsafe { DataFrame::new_no_checks(new_cols) }) - }, + None => Ok(self.clear()), } } @@ -237,10 +234,7 @@ impl DataFrame { let n = (self.height() as f64 * frac) as usize; self.sample_n_literal(n, with_replacement, shuffle, seed) }, - None => { - let new_cols = self.columns.iter().map(Column::clear).collect_trusted(); - Ok(unsafe { DataFrame::new_no_checks(new_cols) }) - }, + None => Ok(self.clear()), } } } diff --git a/crates/polars-core/src/chunked_array/struct_/frame.rs b/crates/polars-core/src/chunked_array/struct_/frame.rs index 83f0f1299667..b175b3a04832 100644 --- a/crates/polars-core/src/chunked_array/struct_/frame.rs +++ b/crates/polars-core/src/chunked_array/struct_/frame.rs @@ -5,6 +5,6 @@ use crate::prelude::StructChunked; impl DataFrame { pub fn into_struct(self, name: PlSmallStr) -> StructChunked { - StructChunked::from_columns(name, &self.columns).expect("same invariants") + StructChunked::from_columns(name, self.height(), &self.columns).expect("same invariants") } } diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 0c4eb50ddc58..625da8881117 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -10,8 +10,8 @@ use polars_utils::aliases::PlHashMap; use polars_utils::itertools::Itertools; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::chunked_array::ChunkedArray; -use crate::prelude::sort::arg_sort_multiple::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::prelude::*; use crate::series::Series; use crate::utils::Container; @@ -20,12 +20,22 @@ pub type StructChunked = ChunkedArray; fn constructor<'a, I: ExactSizeIterator + Clone>( name: PlSmallStr, + length: usize, fields: I, ) -> PolarsResult { + if fields.len() == 0 { + let dtype = DataType::Struct(Vec::new()); + let arrow_dtype = dtype.to_physical().to_arrow(CompatLevel::newest()); + let chunks = vec![StructArray::new(arrow_dtype, length, Vec::new(), None).boxed()]; + + // SAFETY: We construct each chunk above to have the `Struct` data type. + return Ok(unsafe { StructChunked::from_chunks_and_dtype(name, chunks, dtype) }); + } + // Different chunk lengths: rechunk and recurse. if !fields.clone().map(|s| s.n_chunks()).all_equal() { let fields = fields.map(|s| s.rechunk()).collect::>(); - return constructor(name, fields.iter()); + return constructor(name, length, fields.iter()); } let n_chunks = fields.clone().next().unwrap().n_chunks(); @@ -39,11 +49,11 @@ fn constructor<'a, I: ExactSizeIterator + Clone>( .map(|field| field.chunks()[c_i].clone()) .collect::>(); - if !fields.iter().map(|arr| arr.len()).all_equal() { + if !fields.iter().all(|arr| length == arr.len()) { return Err(()); } - Ok(StructArray::new(arrow_dtype.clone(), fields, None).boxed()) + Ok(StructArray::new(arrow_dtype.clone(), length, fields, None).boxed()) }) .collect::, ()>>(); @@ -59,40 +69,44 @@ fn constructor<'a, I: ExactSizeIterator + Clone>( // Different chunk lengths: rechunk and recurse. Err(_) => { let fields = fields.map(|s| s.rechunk()).collect::>(); - constructor(name, fields.iter()) + constructor(name, length, fields.iter()) }, } } impl StructChunked { - pub fn from_columns(name: PlSmallStr, fields: &[Column]) -> PolarsResult { - Self::from_series(name, fields.iter().map(|c| c.as_materialized_series())) + pub fn from_columns(name: PlSmallStr, length: usize, fields: &[Column]) -> PolarsResult { + Self::from_series( + name, + length, + fields.iter().map(|c| c.as_materialized_series()), + ) } pub fn from_series<'a, I: ExactSizeIterator + Clone>( name: PlSmallStr, + length: usize, fields: I, ) -> PolarsResult { let mut names = PlHashSet::with_capacity(fields.len()); - let first_len = fields.clone().next().map(|s| s.len()).unwrap_or(0); - let mut max_len = first_len; - let mut all_equal_len = true; - let mut is_empty = false; + let mut needs_to_broadcast = false; for s in fields.clone() { let s_len = s.len(); - max_len = std::cmp::max(max_len, s_len); - if s_len != first_len { - all_equal_len = false; - } - if s_len == 0 { - is_empty = true; + if s_len != length && s_len != 1 { + polars_bail!( + ShapeMismatch: "expected struct fields to have given length. given = {length}, field length = {s_len}." + ); } + + needs_to_broadcast |= length != 1 && s_len == 1; + polars_ensure!( names.insert(s.name()), Duplicate: "multiple fields with name '{}' found", s.name() ); + match s.dtype() { #[cfg(feature = "object")] DataType::Object(_, _) => { @@ -102,29 +116,27 @@ impl StructChunked { } } - if !all_equal_len { - let mut new_fields = Vec::with_capacity(fields.len()); - for s in fields { - let s_len = s.len(); - if is_empty { - new_fields.push(s.clear()) - } else if s_len == max_len { - new_fields.push(s.clone()) - } else if s_len == 1 { - new_fields.push(s.new_from_index(0, max_len)) + if !needs_to_broadcast { + return constructor(name, length, fields); + } + + if length == 0 { + // @NOTE: There are columns that are being broadcasted so we need to clear those. + let new_fields = fields.map(|s| s.clear()).collect::>(); + + return constructor(name, length, new_fields.iter()); + } + + let new_fields = fields + .map(|s| { + if s.len() == length { + s.clone() } else { - polars_bail!( - ShapeMismatch: "expected all fields to have equal length" - ); + s.new_from_index(0, length) } - } - constructor(name, new_fields.iter()) - } else if fields.len() == 0 { - let fields = [Series::new_null(PlSmallStr::EMPTY, 0)]; - constructor(name, fields.iter()) - } else { - constructor(name, fields) - } + }) + .collect::>(); + constructor(name, length, new_fields.iter()) } pub fn struct_fields(&self) -> &[Field] { @@ -185,7 +197,8 @@ impl StructChunked { }) .collect::>>()?; - let mut out = Self::from_series(self.name().clone(), new_fields.iter())?; + let mut out = + Self::from_series(self.name().clone(), struct_len, new_fields.iter())?; if self.null_count > 0 { out.zip_outer_validity(self); } @@ -241,7 +254,7 @@ impl StructChunked { } }) .collect::>>()?; - let mut out = Self::from_series(self.name().clone(), fields.iter())?; + let mut out = Self::from_series(self.name().clone(), self.len(), fields.iter())?; if self.null_count > 0 { out.zip_outer_validity(self); } @@ -286,7 +299,7 @@ impl StructChunked { .iter() .map(func) .collect::>>()?; - Self::from_series(self.name().clone(), fields.iter()).map(|mut ca| { + Self::from_series(self.name().clone(), self.len(), fields.iter()).map(|mut ca| { if self.null_count > 0 { // SAFETY: we don't change types/ lengths. unsafe { @@ -361,10 +374,10 @@ impl StructChunked { .fields_as_series() .into_iter() .map(Column::from) - .collect(); + .collect::>(); // SAFETY: invariants for struct are the same - unsafe { DataFrame::new_no_checks(columns) } + unsafe { DataFrame::new_no_checks(self.len(), columns) } } /// Get access to one of this `[StructChunked]`'s fields diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 4155a9bf14e9..a1a87e3001bd 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -528,6 +528,8 @@ impl<'a> AnyValue<'a> { match self { AnyValue::Null => true, AnyValue::List(s) => s.null_count() == s.len(), + #[cfg(feature = "dtype-array")] + AnyValue::Array(s, _) => s.null_count() == s.len(), #[cfg(feature = "dtype-struct")] AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()), _ => false, @@ -852,13 +854,13 @@ impl AnyValue<'_> { } } -impl<'a> Hash for AnyValue<'a> { +impl Hash for AnyValue<'_> { fn hash(&self, state: &mut H) { self.hash_impl(state, false) } } -impl<'a> Eq for AnyValue<'a> {} +impl Eq for AnyValue<'_> {} impl<'a, T> From> for AnyValue<'a> where diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index cd79349bfcd8..30d96649762f 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; #[cfg(feature = "dtype-array")] use polars_utils::format_tuple; +use polars_utils::itertools::Itertools; use super::*; #[cfg(feature = "object")] @@ -115,9 +116,10 @@ impl PartialEq for DataType { use DataType::*; { match (self, other) { - // Don't include rev maps in comparisons #[cfg(feature = "dtype-categorical")] - (Categorical(_, _), Categorical(_, _)) => true, + // Don't include rev maps in comparisons + // TODO: include ordering in comparison + (Categorical(_, _ordering_l), Categorical(_, _ordering_r)) => true, #[cfg(feature = "dtype-categorical")] // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)` (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true, @@ -183,6 +185,8 @@ impl DataType { pub fn is_known(&self) -> bool { match self { DataType::List(inner) => inner.is_known(), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => inner.is_known(), #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => fields.iter().all(|fld| fld.dtype.is_known()), DataType::Unknown(_) => false, @@ -190,6 +194,35 @@ impl DataType { } } + /// Materialize this datatype if it is unknown. All other datatypes + /// are left unchanged. + pub fn materialize_unknown(&self) -> PolarsResult { + match self { + DataType::Unknown(u) => u + .materialize() + .ok_or_else(|| polars_err!(SchemaMismatch: "failed to materialize unknown type")), + DataType::List(inner) => Ok(DataType::List(Box::new(inner.materialize_unknown()?))), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, size) => Ok(DataType::Array( + Box::new(inner.materialize_unknown()?), + *size, + )), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => Ok(DataType::Struct( + fields + .iter() + .map(|f| { + PolarsResult::Ok(Field::new( + f.name().clone(), + f.dtype().materialize_unknown()?, + )) + }) + .try_collect_vec()?, + )), + _ => Ok(self.clone()), + } + } + #[cfg(feature = "dtype-array")] /// Get the full shape of a multidimensional array. pub fn get_shape(&self) -> Option> { @@ -648,6 +681,8 @@ impl DataType { match self { Null => true, List(field) => field.is_nested_null(), + #[cfg(feature = "dtype-array")] + Array(field, _) => field.is_nested_null(), #[cfg(feature = "dtype-struct")] Struct(fields) => fields.iter().all(|fld| fld.dtype.is_nested_null()), _ => false, @@ -663,6 +698,10 @@ impl DataType { pub fn matches_schema_type(&self, schema_type: &DataType) -> PolarsResult { match (self, schema_type) { (DataType::List(l), DataType::List(r)) => l.matches_schema_type(r), + #[cfg(feature = "dtype-array")] + (DataType::Array(l, sl), DataType::Array(r, sr)) => { + Ok(l.matches_schema_type(r)? && sl == sr) + }, #[cfg(feature = "dtype-struct")] (DataType::Struct(l), DataType::Struct(r)) => { let mut must_cast = false; diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index f3bc3571505c..b85caeec0a2e 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -165,7 +165,7 @@ impl DataType { ArrowDataType::Extension(name, _, _) if name.as_str() == "POLARS_EXTENSION_TYPE" => { #[cfg(feature = "object")] { - DataType::Object("extension", None) + DataType::Object("object", None) } #[cfg(not(feature = "object"))] { diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 64266a0066db..8d84d47be978 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -13,7 +13,6 @@ mod any_value; mod dtype; mod field; mod into_scalar; -mod reshape; #[cfg(feature = "object")] mod static_array_collect; mod time_unit; @@ -26,6 +25,7 @@ use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; pub use aliases::*; pub use any_value::*; pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; +pub use arrow::datatypes::reshape::*; #[cfg(feature = "dtype-categorical")] use arrow::datatypes::IntegerType; pub use arrow::datatypes::{ArrowDataType, TimeUnit as ArrowTimeUnit}; @@ -35,14 +35,13 @@ use bytemuck::Zeroable; pub use dtype::*; pub use field::*; pub use into_scalar::*; -use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Zero}; +use num_traits::{AsPrimitive, Bounded, FromPrimitive, Num, NumCast, One, Zero}; use polars_compute::arithmetic::HasPrimitiveArithmeticKernel; use polars_compute::float_sum::FloatSum; use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -pub use reshape::*; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] @@ -300,7 +299,7 @@ unsafe impl PolarsDataType for ObjectType { type OwnedPhysical = T; type ZeroablePhysical<'a> = Option<&'a T>; type Array = ObjectArray; - type IsNested = TrueT; + type IsNested = FalseT; type HasViews = FalseT; type IsStruct = FalseT; type IsObject = TrueT; @@ -356,6 +355,7 @@ pub trait NumericNative: + IsFloat + HasPrimitiveArithmeticKernel::Native> + FloatSum + + AsPrimitive + MinMax + IsNull { diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index c930b9e94da7..88fbfae96701 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -125,6 +125,26 @@ fn get_str_len_limit() -> usize { fn get_list_len_limit() -> usize { parse_env_var_limit(FMT_TABLE_CELL_LIST_LEN, DEFAULT_LIST_LEN_LIMIT) } +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn get_ellipsis() -> &'static str { + match std::env::var(FMT_TABLE_FORMATTING).as_deref().unwrap_or("") { + preset if preset.starts_with("ASCII") => "...", + _ => "…", + } +} + +fn estimate_string_width(s: &str) -> usize { + // get a slightly more accurate estimate of a string's screen + // width, accounting (very roughly) for multibyte characters + let n_chars = s.chars().count(); + let n_bytes = s.len(); + if n_bytes == n_chars { + n_chars + } else { + let adjust = n_bytes as f64 / n_chars as f64; + std::cmp::min(n_chars * 2, (n_chars as f64 * adjust).ceil() as usize) + } +} macro_rules! format_array { ($f:ident, $a:expr, $dtype:expr, $name:expr, $array_type:expr) => {{ @@ -424,7 +444,7 @@ impl Debug for DataFrame { } } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] -fn make_str_val(v: &str, truncate: usize) -> String { +fn make_str_val(v: &str, truncate: usize, ellipsis: &String) -> String { let v_trunc = &v[..v .char_indices() .take(truncate) @@ -434,14 +454,19 @@ fn make_str_val(v: &str, truncate: usize) -> String { if v == v_trunc { v.to_string() } else { - format!("{v_trunc}…") + format!("{v_trunc}{ellipsis}") } } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] -fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { - let name = make_str_val(f.name(), str_truncate); - let name_length = name.len(); +fn field_to_str( + f: &Field, + str_truncate: usize, + ellipsis: &String, + padding: usize, +) -> (String, usize) { + let name = make_str_val(f.name(), str_truncate, ellipsis); + let name_length = estimate_string_width(name.as_str()); let mut column_name = name; if env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) { column_name = "".to_string(); @@ -473,11 +498,11 @@ fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { format!("{column_name}{separator}{column_dtype}") }; let mut s_len = std::cmp::max(name_length, dtype_length); - let separator_length = separator.trim().len(); + let separator_length = estimate_string_width(separator.trim()); if s_len < separator_length { s_len = separator_length; } - (s, s_len + 2) + (s, s_len + padding) } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] @@ -487,27 +512,29 @@ fn prepare_row( n_last: usize, str_truncate: usize, max_elem_lengths: &mut [usize], + ellipsis: &String, + padding: usize, ) -> Vec { let reduce_columns = n_first + n_last < row.len(); let n_elems = n_first + n_last + reduce_columns as usize; let mut row_strings = Vec::with_capacity(n_elems); for (idx, v) in row[0..n_first].iter().enumerate() { - let elem_str = make_str_val(v, str_truncate); - let elem_len = elem_str.len() + 2; + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; if max_elem_lengths[idx] < elem_len { max_elem_lengths[idx] = elem_len; }; row_strings.push(elem_str); } if reduce_columns { - row_strings.push("…".to_string()); - max_elem_lengths[n_first] = 3; + row_strings.push(ellipsis.to_string()); + max_elem_lengths[n_first] = ellipsis.chars().count() + padding; } let elem_offset = n_first + reduce_columns as usize; for (idx, v) in row[row.len() - n_last..].iter().enumerate() { - let elem_str = make_str_val(v, str_truncate); - let elem_len = elem_str.len() + 2; + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; let elem_idx = elem_offset + idx; if max_elem_lengths[elem_idx] < elem_len { max_elem_lengths[elem_idx] = elem_len; @@ -542,16 +569,36 @@ impl Display for DataFrame { "The column lengths in the DataFrame are not equal." ); + let table_style = std::env::var(FMT_TABLE_FORMATTING).unwrap_or("DEFAULT".to_string()); + let is_utf8 = !table_style.starts_with("ASCII"); + let preset = match table_style.as_str() { + "ASCII_FULL" => ASCII_FULL, + "ASCII_FULL_CONDENSED" => ASCII_FULL_CONDENSED, + "ASCII_NO_BORDERS" => ASCII_NO_BORDERS, + "ASCII_BORDERS_ONLY" => ASCII_BORDERS_ONLY, + "ASCII_BORDERS_ONLY_CONDENSED" => ASCII_BORDERS_ONLY_CONDENSED, + "ASCII_HORIZONTAL_ONLY" => ASCII_HORIZONTAL_ONLY, + "ASCII_MARKDOWN" | "MARKDOWN" => ASCII_MARKDOWN, + "UTF8_FULL" => UTF8_FULL, + "UTF8_FULL_CONDENSED" => UTF8_FULL_CONDENSED, + "UTF8_NO_BORDERS" => UTF8_NO_BORDERS, + "UTF8_BORDERS_ONLY" => UTF8_BORDERS_ONLY, + "UTF8_HORIZONTAL_ONLY" => UTF8_HORIZONTAL_ONLY, + "NOTHING" => NOTHING, + _ => UTF8_FULL_CONDENSED, + }; + let ellipsis = get_ellipsis().to_string(); + let ellipsis_len = ellipsis.chars().count(); let max_n_cols = get_col_limit(); let max_n_rows = get_row_limit(); let str_truncate = get_str_len_limit(); + let padding = 2; // eg: one char either side of the value let (n_first, n_last) = if self.width() > max_n_cols { ((max_n_cols + 1) / 2, max_n_cols / 2) } else { (self.width(), 0) }; - let reduce_columns = n_first + n_last < self.width(); let n_tbl_cols = n_first + n_last + reduce_columns as usize; let mut names = Vec::with_capacity(n_tbl_cols); @@ -559,39 +606,19 @@ impl Display for DataFrame { let fields = self.fields(); for field in fields[0..n_first].iter() { - let (s, l) = field_to_str(field, str_truncate); + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); names.push(s); name_lengths.push(l); } if reduce_columns { - names.push("…".into()); - name_lengths.push(3); + names.push(ellipsis.clone()); + name_lengths.push(ellipsis_len); } for field in fields[self.width() - n_last..].iter() { - let (s, l) = field_to_str(field, str_truncate); + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); names.push(s); name_lengths.push(l); } - let (preset, is_utf8) = match std::env::var(FMT_TABLE_FORMATTING) - .as_deref() - .unwrap_or("DEFAULT") - { - "ASCII_FULL" => (ASCII_FULL, false), - "ASCII_FULL_CONDENSED" => (ASCII_FULL_CONDENSED, false), - "ASCII_NO_BORDERS" => (ASCII_NO_BORDERS, false), - "ASCII_BORDERS_ONLY" => (ASCII_BORDERS_ONLY, false), - "ASCII_BORDERS_ONLY_CONDENSED" => (ASCII_BORDERS_ONLY_CONDENSED, false), - "ASCII_HORIZONTAL_ONLY" => (ASCII_HORIZONTAL_ONLY, false), - "ASCII_MARKDOWN" => (ASCII_MARKDOWN, false), - "UTF8_FULL" => (UTF8_FULL, true), - "UTF8_FULL_CONDENSED" => (UTF8_FULL_CONDENSED, true), - "UTF8_NO_BORDERS" => (UTF8_NO_BORDERS, true), - "UTF8_BORDERS_ONLY" => (UTF8_BORDERS_ONLY, true), - "UTF8_HORIZONTAL_ONLY" => (UTF8_HORIZONTAL_ONLY, true), - "NOTHING" => (NOTHING, false), - "DEFAULT" => (UTF8_FULL_CONDENSED, true), - _ => (UTF8_FULL_CONDENSED, true), - }; let mut table = Table::new(); table @@ -601,7 +628,6 @@ impl Display for DataFrame { if is_utf8 && env_is_true(FMT_TABLE_ROUNDED_CORNERS) { table.apply_modifier(UTF8_ROUND_CORNERS); } - let mut constraints = Vec::with_capacity(n_tbl_cols); let mut max_elem_lengths: Vec = vec![0; n_tbl_cols]; @@ -610,7 +636,6 @@ impl Display for DataFrame { // Truncate the table if we have more rows than the // configured maximum number of rows let mut rows = Vec::with_capacity(std::cmp::max(max_n_rows, 2)); - let half = max_n_rows / 2; let rest = max_n_rows % 2; @@ -621,13 +646,20 @@ impl Display for DataFrame { .map(|c| c.str_value(i).unwrap()) .collect(); - let row_strings = - prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); - + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); rows.push(row_strings); } - let dots = rows[0].iter().map(|_| "…".to_string()).collect(); + let dots = vec![ellipsis.clone(); rows[0].len()]; rows.push(dots); + for i in (height - half)..height { let row = self .get_columns() @@ -635,8 +667,15 @@ impl Display for DataFrame { .map(|c| c.str_value(i).unwrap()) .collect(); - let row_strings = - prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); rows.push(row_strings); } table.add_rows(rows); @@ -654,6 +693,8 @@ impl Display for DataFrame { n_last, str_truncate, &mut max_elem_lengths, + &ellipsis, + padding, ); table.add_row(row_strings); } else { @@ -662,10 +703,9 @@ impl Display for DataFrame { } } } else if height > 0 { - let dots: Vec = self.columns.iter().map(|_| "…".to_string()).collect(); + let dots: Vec = vec![ellipsis.clone(); self.columns.len()]; table.add_row(dots); } - let tbl_fallback_width = 100; let tbl_width = std::env::var("POLARS_TABLE_WIDTH") .map(|s| { @@ -683,10 +723,10 @@ impl Display for DataFrame { lower: Width::Fixed(l as u16), upper: Width::Fixed(u as u16), }; - let min_col_width = 5; + let min_col_width = std::cmp::max(5, 3 + padding); for (idx, elem_len) in max_elem_lengths.iter().enumerate() { let mx = std::cmp::min( - str_truncate + 3, // (3 = 2 space chars of padding + ellipsis char) + str_truncate + ellipsis_len + padding, std::cmp::max(name_lengths[idx], *elem_len), ); if mx <= min_col_width { @@ -1011,7 +1051,7 @@ fn format_blob(f: &mut Formatter<'_>, bytes: &[u8]) -> fmt::Result { } } if bytes.len() > width { - write!(f, "\"...")?; + write!(f, "\"…")?; } else { write!(f, "\"")?; } @@ -1138,9 +1178,7 @@ impl Series { if self.is_empty() { return "[]".to_owned(); } - let max_items = get_list_len_limit(); - match max_items { 0 => "[…]".to_owned(), _ if max_items >= self.len() => { diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs index 6d184b2960c9..cf6683d0a57c 100644 --- a/crates/polars-core/src/frame/arithmetic.rs +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -25,7 +25,7 @@ macro_rules! impl_arithmetic { .map(|s| s.map(Column::from)) .collect::>() })?; - Ok(unsafe { DataFrame::new_no_checks(cols) }) + Ok(unsafe { DataFrame::new_no_checks($self.height(), cols) }) }}; } diff --git a/crates/polars-core/src/frame/chunks.rs b/crates/polars-core/src/frame/chunks.rs index 16fa8f7c1ff9..801c0d11b9c8 100644 --- a/crates/polars-core/src/frame/chunks.rs +++ b/crates/polars-core/src/frame/chunks.rs @@ -33,7 +33,8 @@ impl DataFrame { .map(Column::from) .collect::>(); - DataFrame::new_no_checks(columns) + let height = Self::infer_height(&columns); + DataFrame::new_no_checks(height, columns) }) } diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 313a4ec89df2..e66c1ad12875 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -372,15 +372,18 @@ impl Column { #[inline] pub fn new_from_index(&self, index: usize, length: usize) -> Self { + if index >= self.len() { + return Self::full_null(self.name().clone(), length, self.dtype()); + } + match self { - Column::Series(s) => s.new_from_index(index, length).into(), - Column::Scalar(s) => { - if index >= s.len() { - Self::full_null(s.name().clone(), length, s.dtype()) - } else { - s.resize(length).into() - } + Column::Series(s) => { + // SAFETY: Bounds check done before. + let av = unsafe { s.get_unchecked(index) }; + let scalar = Scalar::new(self.dtype().clone(), av.into_static()); + Self::new_scalar(self.name().clone(), scalar, length) }, + Column::Scalar(s) => s.resize(length).into(), } } @@ -557,12 +560,12 @@ impl Column { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { // @scalar-opt unsafe { self.as_materialized_series() - .agg_quantile(groups, quantile, interpol) + .agg_quantile(groups, quantile, method) } .into() } @@ -677,25 +680,6 @@ impl Column { .vec_hash_combine(build_hasher, hashes) } - /// # Safety - /// - /// Indexes need to be in bounds. - pub(crate) unsafe fn equal_element( - &self, - idx_self: usize, - idx_other: usize, - other: &Column, - ) -> bool { - // @scalar-opt - unsafe { - self.as_materialized_series().equal_element( - idx_self, - idx_other, - other.as_materialized_series(), - ) - } - } - pub fn append(&mut self, other: &Column) -> PolarsResult<&mut Self> { // @scalar-opt self.into_materialized_series() @@ -715,7 +699,7 @@ impl Column { pub fn into_frame(self) -> DataFrame { // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names - unsafe { DataFrame::new_no_checks(vec![self]) } + unsafe { DataFrame::new_no_checks(self.len(), vec![self]) } } pub fn unique_stable(&self) -> PolarsResult { diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index 5ec5d98a1597..e0282f4275fb 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -4,7 +4,7 @@ impl TryFrom for DataFrame { type Error = PolarsError; fn try_from(arr: StructArray) -> PolarsResult { - let (fld, arrs, nulls) = arr.into_data(); + let (fld, _length, arrs, nulls) = arr.into_data(); polars_ensure!( nulls.is_none(), ComputeError: "cannot deserialize struct with nulls into a DataFrame" diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index fe71148cd49b..aaf24a470969 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -236,7 +236,7 @@ impl Series { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { @@ -247,13 +247,12 @@ impl Series { use DataType::*; match s.dtype() { - Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol), - Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol), + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method), dt if dt.is_numeric() || dt.is_temporal() => { let ca = s.to_physical_repr(); let physical_type = ca.dtype(); - let s = - apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol); + let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method); if dt.is_logical() { // back to physical and then // back to logical type diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 092d660fb4d2..19b8d5c2d061 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -13,7 +13,7 @@ use arrow::legacy::kernels::rolling::no_nulls::{ }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; use arrow::legacy::kernels::take_agg::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; use arrow::legacy::trusted_len::TrustedLenPush; use arrow::types::NativeType; use num_traits::pow::Pow; @@ -295,8 +295,7 @@ impl_take_extremum!(float: f64); /// This trait will ensure the specific dispatch works without complicating /// the trait bounds. trait QuantileDispatcher { - fn _quantile(self, quantile: f64, interpol: QuantileInterpolOptions) - -> PolarsResult>; + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult>; fn _median(self) -> Option; } @@ -307,12 +306,8 @@ where T::Native: Ord, ChunkedArray: IntoSeries, { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -320,24 +315,16 @@ where } impl QuantileDispatcher for Float32Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() } } impl QuantileDispatcher for Float64Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -348,7 +335,7 @@ unsafe fn agg_quantile_generic( ca: &ChunkedArray, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series where T: PolarsNumericType, @@ -371,7 +358,7 @@ where } let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check - take._quantile(quantile, interpol).unwrap_unchecked() + take._quantile(quantile, method).unwrap_unchecked() }) }, GroupsProxy::Slice { groups, .. } => { @@ -390,7 +377,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ), Some(validity) => { @@ -400,7 +387,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ) }, @@ -418,7 +405,7 @@ where let arr_group = _slice_from_offsets(ca, first, len); // unwrap checked with invalid quantile check arr_group - ._quantile(quantile, interpol) + ._quantile(quantile, method) .unwrap_unchecked() .map(|flt| NumCast::from(flt).unwrap_unchecked()) }, @@ -450,7 +437,7 @@ where }) }, GroupsProxy::Slice { .. } => { - agg_quantile_generic::(ca, groups, 0.5, QuantileInterpolOptions::Linear) + agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) }, } } @@ -977,9 +964,9 @@ impl Float32Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float32Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float32Type>(self, groups) @@ -990,9 +977,9 @@ impl Float64Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) @@ -1184,9 +1171,9 @@ where &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index bdaa439a1232..12b4b27de7e2 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -3,8 +3,8 @@ use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered; use crate::config::verbose; -use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; use crate::series::BitRepr; use crate::utils::flatten::flatten_par; diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 89c72f5a0eac..9dee1e1f411a 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -21,7 +21,7 @@ mod proxy; pub use into_groups::*; pub use proxy::*; -use crate::prelude::sort::arg_sort_multiple::{ +use crate::chunked_array::ops::row_encode::{ encode_rows_unordered, encode_rows_vertical_par_unordered, }; @@ -594,18 +594,14 @@ impl<'df> GroupBy<'df> { /// /// ```rust /// # use polars_core::prelude::*; - /// # use arrow::legacy::prelude::QuantileInterpolOptions; + /// # use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> PolarsResult { - /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileInterpolOptions::default()) + /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default()) /// } /// ``` #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] - pub fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { polars_ensure!( (0.0..=1.0).contains(&quantile), ComputeError: "`quantile` should be within 0.0 and 1.0" @@ -614,9 +610,9 @@ impl<'df> GroupBy<'df> { for agg_col in agg_cols { let new_name = fmt_group_by_column( agg_col.name().as_str(), - GroupByMethod::Quantile(quantile, interpol), + GroupByMethod::Quantile(quantile, method), ); - let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, interpol) }; + let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) }; agg.rename(new_name); cols.push(agg); } @@ -795,7 +791,7 @@ impl<'df> GroupBy<'df> { new_cols.extend_from_slice(&self.selected_keys); let cols = self.df.select_columns_impl(agg.as_slice())?; new_cols.extend(cols); - Ok(unsafe { DataFrame::new_no_checks(new_cols) }) + Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) }) } } else { Ok(self.df.clone()) @@ -868,7 +864,7 @@ pub enum GroupByMethod { Sum, Groups, NUnique, - Quantile(f64, QuantileInterpolOptions), + Quantile(f64, QuantileMethod), Count { include_nulls: bool, }, diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 86d952828c85..ef228981f7e5 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::mem::MaybeUninit; use num_traits::{FromPrimitive, ToPrimitive}; use polars_utils::idx_vec::IdxVec; @@ -17,162 +18,136 @@ where T: PolarsIntegerType, T::Native: ToPrimitive + FromPrimitive + Debug, { - // Use the indexes as perfect groups - pub fn group_tuples_perfect( + /// Use the indexes as perfect groups. + /// + /// # Safety + /// This ChunkedArray must contain each value in [0..num_groups) at least + /// once, and nothing outside this range. + pub unsafe fn group_tuples_perfect( &self, - max: usize, + num_groups: usize, mut multithreaded: bool, group_capacity: usize, ) -> GroupsProxy { multithreaded &= POOL.current_num_threads() > 1; + // The latest index will be used for the null sentinel. let len = if self.null_count() > 0 { - // we add one to store the null sentinel group - max + 2 + // We add one to store the null sentinel group. + num_groups + 1 } else { - max + 1 + num_groups }; - - // the latest index will be used for the null sentinel let null_idx = len.saturating_sub(1); - let n_threads = POOL.current_num_threads(); + let n_threads = POOL.current_num_threads(); let chunk_size = len / n_threads; let (groups, first) = if multithreaded && chunk_size > 1 { - let mut groups: Vec = unsafe { aligned_vec(len) }; + let mut groups: Vec = Vec::new(); groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); - let mut first: Vec = unsafe { aligned_vec(len) }; - - // ensure we keep aligned to cache lines - let chunk_size = (chunk_size * std::mem::size_of::()).next_multiple_of(64); - let chunk_size = chunk_size / std::mem::size_of::(); - - let mut cache_line_offsets = Vec::with_capacity(n_threads + 1); - cache_line_offsets.push(0); - let mut current_offset = chunk_size; - - while current_offset <= len { - cache_line_offsets.push(current_offset); - current_offset += chunk_size; + let mut first: Vec = Vec::with_capacity(len); + + // Round up offsets to nearest cache line for groups to reduce false sharing. + let groups_start = groups.as_ptr(); + let mut per_thread_offsets = Vec::with_capacity(n_threads + 1); + per_thread_offsets.push(0); + for t in 0..n_threads { + let ideal_offset = (t + 1) * chunk_size; + let cache_aligned_offset = + ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128); + if t == n_threads - 1 { + per_thread_offsets.push(len); + } else { + per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len)); + } } - cache_line_offsets.push(current_offset); let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) }; let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) }; - - // The number of threads is dependent on the number of categoricals/ unique values - // as every at least writes to a single cache line - // lower bound per thread: - // 32bit: 16 - // 64bit: 8 POOL.install(|| { - (0..cache_line_offsets.len() - 1) - .into_par_iter() - .for_each(|thread_no| { - let mut row_nr = 0 as IdxSize; - let start = cache_line_offsets[thread_no]; - let start = T::Native::from_usize(start).unwrap(); - let end = cache_line_offsets[thread_no + 1]; - let end = T::Native::from_usize(end).unwrap(); - - // SAFETY: we don't alias - let groups = - unsafe { std::slice::from_raw_parts_mut(groups_ptr.get(), len) }; - let first = unsafe { std::slice::from_raw_parts_mut(first_ptr.get(), len) }; - - for arr in self.downcast_iter() { - if arr.null_count() == 0 { - for &cat in arr.values().as_slice() { + (0..n_threads).into_par_iter().for_each(|thread_no| { + // We use raw pointers because the slices would overlap. + // However, each thread has its own range it is responsible for. + let groups = groups_ptr.get(); + let first = first_ptr.get(); + let start = per_thread_offsets[thread_no]; + let start = T::Native::from_usize(start).unwrap(); + let end = per_thread_offsets[thread_no + 1]; + let end = T::Native::from_usize(end).unwrap(); + + if start == end && thread_no != n_threads - 1 { + return; + }; + + let push_to_group = |cat, row_nr| unsafe { + debug_assert!(cat < len); + let buf = &mut *groups.add(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first.add(cat) = row_nr; + } + }; + + let mut row_nr = 0 as IdxSize; + for arr in self.downcast_iter() { + if arr.null_count() == 0 { + for &cat in arr.values().as_slice() { + if cat >= start && cat < end { + push_to_group(cat.to_usize().unwrap(), row_nr); + } + + row_nr += 1; + } + } else { + for opt_cat in arr.iter() { + if let Some(&cat) = opt_cat { if cat >= start && cat < end { - let cat = cat.to_usize().unwrap(); - let buf = unsafe { groups.get_unchecked_release_mut(cat) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - // SAFETY: we just pushed - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(cat) = *first_value - } - } + push_to_group(cat.to_usize().unwrap(), row_nr); } - row_nr += 1; + } else if thread_no == n_threads - 1 { + // Last thread handles null values. + push_to_group(null_idx, row_nr); } - } else { - for opt_cat in arr.iter() { - if let Some(&cat) = opt_cat { - // cannot factor out due to bchk - if cat >= start && cat < end { - let cat = cat.to_usize().unwrap(); - let buf = - unsafe { groups.get_unchecked_release_mut(cat) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - // SAFETY: we just pushed - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(cat) = - *first_value - } - } - } - } - // last thread handles null values - else if thread_no == cache_line_offsets.len() - 2 { - let buf = - unsafe { groups.get_unchecked_release_mut(null_idx) }; - buf.push(row_nr); - unsafe { - if buf.len() == 1 { - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(null_idx) = - *first_value - } - } - } - row_nr += 1; - } + row_nr += 1; } } - }); + } + }); }); unsafe { - groups.set_len(len); first.set_len(len); } (groups, first) } else { let mut groups = Vec::with_capacity(len); - let mut first = vec![IdxSize::MAX; len]; + let mut first = Vec::with_capacity(len); + let first_out = first.spare_capacity_mut(); groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); + let mut push_to_group = |cat, row_nr| unsafe { + let buf: &mut IdxVec = groups.get_unchecked_release_mut(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first_out.get_unchecked_release_mut(cat) = MaybeUninit::new(row_nr); + } + }; + let mut row_nr = 0 as IdxSize; for arr in self.downcast_iter() { for opt_cat in arr.iter() { if let Some(cat) = opt_cat { - let group_id = cat.to_usize().unwrap(); - let buf = unsafe { groups.get_unchecked_release_mut(group_id) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - *first.get_unchecked_release_mut(group_id) = row_nr; - } - } + push_to_group(cat.to_usize().unwrap(), row_nr); } else { - let buf = unsafe { groups.get_unchecked_release_mut(null_idx) }; - buf.push(row_nr); - unsafe { - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(null_idx) = *first_value - } + push_to_group(null_idx, row_nr); } row_nr += 1; } } + unsafe { + first.set_len(len); + } (groups, first) }; @@ -201,7 +176,7 @@ impl CategoricalChunked { } // on relative small tables this isn't much faster than the default strategy // but on huge tables, this can be > 2x faster - cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0) + unsafe { cats.group_tuples_perfect(cached.len(), multithreaded, 0) } } else { self.physical().group_tuples(multithreaded, sorted).unwrap() } @@ -220,26 +195,3 @@ impl CategoricalChunked { out } } - -#[repr(C, align(64))] -struct AlignTo64([u8; 64]); - -/// There are no guarantees that the [`Vec`] will remain aligned if you reallocate the data. -/// This means that you cannot reallocate so you will need to know how big to allocate up front. -unsafe fn aligned_vec(n: usize) -> Vec { - assert!(std::mem::align_of::() <= 64); - let n_units = (n * std::mem::size_of::() / std::mem::size_of::()) + 1; - - let mut aligned: Vec = Vec::with_capacity(n_units); - - let ptr = aligned.as_mut_ptr(); - let cap_units = aligned.capacity(); - - std::mem::forget(aligned); - - Vec::from_raw_parts( - ptr as *mut T, - 0, - cap_units * std::mem::size_of::() / std::mem::size_of::(), - ) -} diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index d1c04162b7b9..63b1a8022108 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -546,7 +546,7 @@ pub enum GroupsIndicator<'a> { Slice([IdxSize; 2]), } -impl<'a> GroupsIndicator<'a> { +impl GroupsIndicator<'_> { pub fn len(&self) -> usize { match self { GroupsIndicator::Idx(g) => g.1.len(), diff --git a/crates/polars-core/src/frame/horizontal.rs b/crates/polars-core/src/frame/horizontal.rs index 31c072991d87..0886c6b3f958 100644 --- a/crates/polars-core/src/frame/horizontal.rs +++ b/crates/polars-core/src/frame/horizontal.rs @@ -32,6 +32,15 @@ impl DataFrame { /// - the length of all [`Column`] is equal to the height of this [`DataFrame`] /// - the columns names are unique pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self { + // If we don't have any columns yet, copy the height from the given columns. + if let Some(fst) = columns.first() { + if self.width() == 0 { + // SAFETY: The functions invariants asks for all columns to be the same length so + // that makes that a valid height. + unsafe { self.set_height(fst.len()) }; + } + } + self.columns.extend_from_slice(columns); self } @@ -68,7 +77,7 @@ impl DataFrame { /// Concat [`DataFrame`]s horizontally. /// Concat horizontally and extend with null values if lengths don't match pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult { - let max_len = dfs + let output_height = dfs .iter() .map(|df| df.height()) .max() @@ -77,18 +86,22 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars let owned_df; // if not all equal length, extend the DataFrame with nulls - let dfs = if !dfs.iter().all(|df| df.height() == max_len) { + let dfs = if !dfs.iter().all(|df| df.height() == output_height) { owned_df = dfs .iter() .cloned() .map(|mut df| { - if df.height() != max_len { - let diff = max_len - df.height(); - df.columns.iter_mut().for_each(|s| { - // @scalar-opt - let s = s.into_materialized_series(); - *s = s.extend_constant(AnyValue::Null, diff).unwrap() + if df.height() != output_height { + let diff = output_height - df.height(); + + // SAFETY: We extend each column with nulls to the point of being of length + // `output_height`. Then, we set the height of the resulting dataframe. + unsafe { df.get_columns_mut() }.iter_mut().for_each(|c| { + *c = c.extend_constant(AnyValue::Null, diff).unwrap(); }); + unsafe { + df.set_height(output_height); + } } df }) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a06c5856a5f1..0d4230ff3f91 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -13,6 +13,7 @@ use crate::prelude::*; #[cfg(feature = "row_hash")] use crate::utils::split_df; use crate::utils::{slice_offsets, try_get_supertype, Container, NoNull}; +use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; #[cfg(feature = "dataframe_arithmetic")] mod arithmetic; @@ -32,6 +33,7 @@ use arrow::record_batch::RecordBatch; use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] @@ -48,8 +50,9 @@ pub enum NullStrategy { Propagate, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum UniqueKeepStrategy { /// Keep the first unique row. First, @@ -170,7 +173,8 @@ where /// ``` #[derive(Clone)] pub struct DataFrame { - // invariant: Column.len() is the same for each column + height: usize, + // invariant: columns[i].len() == height for each 0 >= i > columns.len() pub(crate) columns: Vec, } @@ -286,33 +290,25 @@ impl DataFrame { pub fn new(columns: Vec) -> PolarsResult { ensure_names_unique(&columns, |s| s.name().as_str())?; - if columns.len() > 1 { - let first_len = columns[0].len(); - for col in &columns { - polars_ensure!( - col.len() == first_len, - ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", - columns[0].len(), first_len, col.name(), col.len() - ); - } + let Some(fst) = columns.first() else { + return Ok(DataFrame { height: 0, columns }); + }; + + let height = fst.len(); + for col in &columns[1..] { + polars_ensure!( + col.len() == height, + ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", + columns[0].len(), height, col.name(), col.len() + ); } - Ok(DataFrame { columns }) + Ok(DataFrame { height, columns }) } /// Converts a sequence of columns into a DataFrame, broadcasting length-1 /// columns to match the other columns. pub fn new_with_broadcast(columns: Vec) -> PolarsResult { - ensure_names_unique(&columns, |s| s.name().as_str())?; - unsafe { Self::new_with_broadcast_no_checks(columns) } - } - - /// Converts a sequence of columns into a DataFrame, broadcasting length-1 - /// columns to match the other columns. - /// - /// # Safety - /// Does not check that the column names are unique (which they must be). - pub unsafe fn new_with_broadcast_no_checks(mut columns: Vec) -> PolarsResult { // The length of the longest non-unit length column determines the // broadcast length. If all columns are unit-length the broadcast length // is one. @@ -322,23 +318,51 @@ impl DataFrame { .filter(|l| *l != 1) .max() .unwrap_or(1); + Self::new_with_broadcast_len(columns, broadcast_len) + } + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to broadcast_len. + pub fn new_with_broadcast_len( + columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { + ensure_names_unique(&columns, |s| s.name().as_str())?; + unsafe { Self::new_with_broadcast_no_namecheck(columns, broadcast_len) } + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + /// + /// # Safety + /// Does not check that the column names are unique (which they must be). + pub unsafe fn new_with_broadcast_no_namecheck( + mut columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { for col in &mut columns { // Length not equal to the broadcast len, needs broadcast or is an error. let len = col.len(); if len != broadcast_len { if len != 1 { let name = col.name().to_owned(); - let longest_column = columns.iter().max_by_key(|c| c.len()).unwrap().name(); + let extra_info = + if let Some(c) = columns.iter().find(|c| c.len() == broadcast_len) { + format!(" (matching column '{}')", c.name()) + } else { + String::new() + }; polars_bail!( - ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", - name, len, longest_column, broadcast_len + ShapeMismatch: "could not create a new DataFrame: series {name:?} has length {len} while trying to broadcast to length {broadcast_len}{extra_info}", ); } *col = col.new_from_index(0, broadcast_len); } } - Ok(unsafe { DataFrame::new_no_checks(columns) }) + + let length = if columns.is_empty() { 0 } else { broadcast_len }; + + Ok(unsafe { DataFrame::new_no_checks(length, columns) }) } /// Creates an empty `DataFrame` usable in a compile time context (such as static initializers). @@ -350,7 +374,10 @@ impl DataFrame { /// static EMPTY: DataFrame = DataFrame::empty(); /// ``` pub const fn empty() -> Self { - DataFrame { columns: vec![] } + DataFrame { + height: 0, + columns: vec![], + } } /// Create an empty `DataFrame` with empty columns as per the `schema`. @@ -359,7 +386,7 @@ impl DataFrame { .iter() .map(|(name, dtype)| Column::from(Series::new_empty(name.clone(), dtype))) .collect(); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(0, cols) } } /// Create an empty `DataFrame` with empty columns as per the `schema`. @@ -368,7 +395,7 @@ impl DataFrame { .iter_values() .map(|fld| Column::from(Series::new_empty(fld.name.clone(), &(fld.dtype().into())))) .collect(); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(0, cols) } } /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. @@ -453,7 +480,22 @@ impl DataFrame { self } - /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the `Series`. + /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the + /// `Series`. + /// + /// Calculates the height from the first column or `0` if no columns are given. + /// + /// # Safety + /// + /// It is the callers responsibility to uphold the contract of all `Series` + /// having an equal length and a unique name, if not this may panic down the line. + pub unsafe fn new_no_checks_height_from_first(columns: Vec) -> DataFrame { + let height = columns.first().map_or(0, Column::len); + unsafe { Self::new_no_checks(height, columns) } + } + + /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the + /// `Series`. /// /// It is advised to use [DataFrame::new] in favor of this method. /// @@ -461,23 +503,24 @@ impl DataFrame { /// /// It is the callers responsibility to uphold the contract of all `Series` /// having an equal length and a unique name, if not this may panic down the line. - pub unsafe fn new_no_checks(columns: Vec) -> DataFrame { - #[cfg(debug_assertions)] - { - Self::new(columns).unwrap() - } - #[cfg(not(debug_assertions))] - { - Self::_new_no_checks_impl(columns) + pub unsafe fn new_no_checks(height: usize, columns: Vec) -> DataFrame { + if cfg!(debug_assertions) { + ensure_names_unique(&columns, |s| s.name().as_str()).unwrap(); + + for col in &columns { + assert_eq!(col.len(), height); + } } + + unsafe { Self::_new_no_checks_impl(height, columns) } } /// This will not panic even in debug mode - there are some (rare) use cases where a DataFrame /// is temporarily constructed containing duplicates for dispatching to functions. A DataFrame /// constructed with this method is generally highly unsafe and should not be long-lived. #[allow(clippy::missing_safety_doc)] - pub const unsafe fn _new_no_checks_impl(columns: Vec) -> DataFrame { - DataFrame { columns } + pub const unsafe fn _new_no_checks_impl(height: usize, columns: Vec) -> DataFrame { + DataFrame { height, columns } } /// Create a new `DataFrame` but does not check the length of the `Series`, @@ -492,15 +535,11 @@ impl DataFrame { pub unsafe fn new_no_length_checks(columns: Vec) -> PolarsResult { ensure_names_unique(&columns, |s| s.name().as_str())?; - Ok({ - #[cfg(debug_assertions)] - { - Self::new(columns).unwrap() - } - #[cfg(not(debug_assertions))] - { - DataFrame { columns } - } + Ok(if cfg!(debug_assertions) { + Self::new(columns).unwrap() + } else { + let height = Self::infer_height(&columns); + DataFrame { height, columns } }) } @@ -642,11 +681,33 @@ impl DataFrame { /// Get mutable access to the underlying columns. /// /// # Safety - /// The caller must ensure the length of all [`Series`] remains equal. + /// + /// The caller must ensure the length of all [`Series`] remains equal to `height` or + /// [`DataFrame::set_height`] is called afterwards with the appropriate `height`. pub unsafe fn get_columns_mut(&mut self) -> &mut Vec { &mut self.columns } + #[inline] + /// Remove all the columns in the [`DataFrame`] but keep the `height`. + pub fn clear_columns(&mut self) { + unsafe { self.get_columns_mut() }.clear() + } + + #[inline] + /// Extend the columns without checking for name collisions or height. + /// + /// # Safety + /// + /// The caller needs to ensure that: + /// - Column names are unique within the resulting [`DataFrame`]. + /// - The length of each appended column matches the height of the [`DataFrame`]. For + /// `DataFrame`]s with no columns (ZCDFs), it is important that the height is set afterwards + /// with [`DataFrame::set_height`]. + pub unsafe fn column_extend_unchecked(&mut self, iter: impl Iterator) { + unsafe { self.get_columns_mut() }.extend(iter) + } + /// Take ownership of the underlying columns vec. pub fn take_columns(self) -> Vec { self.columns @@ -806,10 +867,7 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn shape(&self) -> (usize, usize) { - match self.columns.as_slice() { - &[] => (0, 0), - v => (v[0].len(), v.len()), - } + (self.height, self.columns.len()) } /// Get the width of the [`DataFrame`] which is the number of columns. @@ -872,7 +930,16 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn is_empty(&self) -> bool { - self.height() == 0 + matches!(self.shape(), (0, _) | (_, 0)) + } + + /// Set the height (i.e. number of rows) of this [`DataFrame`]. + /// + /// # Safety + /// + /// This needs to be equal to the length of all the columns. + pub unsafe fn set_height(&mut self, height: usize) { + self.height = height; } /// Add multiple [`Series`] to a [`DataFrame`]. @@ -1020,6 +1087,7 @@ impl DataFrame { left.append(right)?; Ok(()) })?; + self.height += other.height; Ok(self) } @@ -1036,6 +1104,7 @@ impl DataFrame { .for_each(|(left, right)| { left.append(right).expect("should not fail"); }); + self.height += other.height; } /// Extend the memory backed by this [`DataFrame`] with the values from `other`. @@ -1059,6 +1128,7 @@ impl DataFrame { "unable to extend a DataFrame of width {} with a DataFrame of width {}", self.width(), other.width(), ); + self.columns .iter_mut() .zip(other.columns.iter()) @@ -1066,7 +1136,9 @@ impl DataFrame { ensure_can_extend(&*left, right)?; left.extend(right)?; Ok(()) - }) + })?; + self.height += other.height; + Ok(()) } /// Remove a column by name and return the column removed. @@ -1173,7 +1245,7 @@ impl DataFrame { } }); - Ok(unsafe { DataFrame::new_no_checks(new_cols) }) + Ok(unsafe { DataFrame::new_no_checks(self.height(), new_cols) }) } /// Drop columns that are in `names`. @@ -1198,7 +1270,7 @@ impl DataFrame { } }); - unsafe { DataFrame::new_no_checks(new_cols) } + unsafe { DataFrame::new_no_checks(self.height(), new_cols) } } /// Insert a new column at a given index without checking for duplicates. @@ -1213,6 +1285,11 @@ impl DataFrame { ShapeMismatch: "unable to add a column of length {} to a DataFrame of height {}", column.len(), self.height(), ); + + if self.width() == 0 { + self.height = column.len(); + } + self.columns.insert(index, column); Ok(self) } @@ -1232,6 +1309,10 @@ impl DataFrame { if let Some(idx) = self.get_column_index(column.name().as_str()) { self.replace_column(idx, column)?; } else { + if self.width() == 0 { + self.height = column.len(); + } + self.columns.push(column); } Ok(()) @@ -1274,7 +1355,13 @@ impl DataFrame { debug_assert!(self.width() == 0 || self.height() == column.len()); debug_assert!(self.get_column_index(column.name().as_str()).is_none()); + // SAFETY: Invariant of function guarantees for case `width` > 0. We set the height + // properly for `width` == 0. + if self.width() == 0 { + unsafe { self.set_height(column.len()) }; + } unsafe { self.get_columns_mut() }.push(column); + self } @@ -1288,6 +1375,10 @@ impl DataFrame { self.replace_column(idx, c)?; } } else { + if self.width() == 0 { + self.height = c.len(); + } + self.columns.push(c); } Ok(()) @@ -1393,14 +1484,6 @@ impl DataFrame { self.columns.get(idx) } - /// Select a mutable series by index. - /// - /// *Note: the length of the Series should remain the same otherwise the DataFrame is invalid.* - /// For this reason the method is not public - fn select_at_idx_mut(&mut self, idx: usize) -> Option<&mut Column> { - self.columns.get_mut(idx) - } - /// Select column(s) from this [`DataFrame`] by range and return a new [`DataFrame`] /// /// # Examples @@ -1559,7 +1642,7 @@ impl DataFrame { pub fn _select_impl_unchecked(&self, cols: &[PlSmallStr]) -> PolarsResult { let selected = self.select_columns_impl(cols)?; - Ok(unsafe { DataFrame::new_no_checks(selected) }) + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) } /// Select with a known schema. @@ -1596,7 +1679,7 @@ impl DataFrame { ensure_names_unique(cols, |s| s.as_str())?; } let selected = self.select_columns_impl_with_schema(cols, schema)?; - Ok(unsafe { DataFrame::new_no_checks(selected) }) + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) } /// A non generic implementation to reduce compiler bloat. @@ -1625,7 +1708,7 @@ impl DataFrame { fn select_physical_impl(&self, cols: &[PlSmallStr]) -> PolarsResult { ensure_names_unique(cols, |s| s.as_str())?; let selected = self.select_columns_physical_impl(cols)?; - Ok(unsafe { DataFrame::new_no_checks(selected) }) + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) } /// Select column(s) from this [`DataFrame`] and return them into a [`Vec`]. @@ -1701,13 +1784,21 @@ impl DataFrame { Ok(selected) } - /// Select a mutable series by name. - /// *Note: the length of the Series should remain the same otherwise the DataFrame is invalid.* - /// For this reason the method is not public - fn select_mut(&mut self, name: &str) -> Option<&mut Column> { - let opt_idx = self.get_column_index(name); + fn filter_height(&self, filtered: &[Column], mask: &BooleanChunked) -> usize { + // If there is a filtered column just see how many columns there are left. + if let Some(fst) = filtered.first() { + return fst.len(); + } - opt_idx.and_then(|idx| self.select_at_idx_mut(idx)) + // Otherwise, count the number of values that would be filtered and return that height. + let num_trues = mask.num_trues(); + if mask.len() == self.height() { + num_trues + } else { + // This is for broadcasting masks + debug_assert!(num_trues == 0 || num_trues == 1); + self.height() * num_trues + } } /// Take the [`DataFrame`] rows by a boolean mask. @@ -1723,13 +1814,17 @@ impl DataFrame { /// ``` pub fn filter(&self, mask: &BooleanChunked) -> PolarsResult { let new_col = self.try_apply_columns_par(&|s| s.filter(mask))?; - Ok(unsafe { DataFrame::new_no_checks(new_col) }) + let height = self.filter_height(&new_col, mask); + + Ok(unsafe { DataFrame::new_no_checks(height, new_col) }) } /// Same as `filter` but does not parallelize. pub fn _filter_seq(&self, mask: &BooleanChunked) -> PolarsResult { let new_col = self.try_apply_columns(&|s| s.filter(mask))?; - Ok(unsafe { DataFrame::new_no_checks(new_col) }) + let height = self.filter_height(&new_col, mask); + + Ok(unsafe { DataFrame::new_no_checks(height, new_col) }) } /// Take [`DataFrame`] rows by index values. @@ -1746,7 +1841,7 @@ impl DataFrame { pub fn take(&self, indices: &IdxCa) -> PolarsResult { let new_col = POOL.install(|| self.try_apply_columns_par(&|s| s.take(indices)))?; - Ok(unsafe { DataFrame::new_no_checks(new_col) }) + Ok(unsafe { DataFrame::new_no_checks(indices.len(), new_col) }) } /// # Safety @@ -1766,7 +1861,7 @@ impl DataFrame { .map(Column::from) .collect() }; - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(idx.len(), cols) } } pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { @@ -1782,7 +1877,7 @@ impl DataFrame { .map(Column::from) .collect() }; - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(idx.len(), cols) } } /// Rename a column in the [`DataFrame`]. @@ -1805,7 +1900,9 @@ impl DataFrame { self.columns.iter().all(|c| c.name() != &name), Duplicate: "column rename attempted with already existing name \"{name}\"" ); - self.select_mut(column) + + self.get_column_index(column) + .and_then(|idx| self.columns.get_mut(idx)) .ok_or_else(|| polars_err!(col_not_found = column)) .map(|c| c.rename(name))?; Ok(self) @@ -1987,20 +2084,23 @@ impl DataFrame { } unsafe { - DataFrame::new_no_checks(vec![ - column_names.finish().into_column(), - repr_ca.finish().into_column(), - sorted_asc_ca.finish().into_column(), - sorted_dsc_ca.finish().into_column(), - fast_explode_list_ca.finish().into_column(), - min_value_ca.finish().into_column(), - max_value_ca.finish().into_column(), - IdxCa::from_slice_options( - PlSmallStr::from_static("distinct_count"), - &distinct_count_ca[..], - ) - .into_column(), - ]) + DataFrame::new_no_checks( + self.width(), + vec![ + column_names.finish().into_column(), + repr_ca.finish().into_column(), + sorted_asc_ca.finish().into_column(), + sorted_dsc_ca.finish().into_column(), + fast_explode_list_ca.finish().into_column(), + min_value_ca.finish().into_column(), + max_value_ca.finish().into_column(), + IdxCa::from_slice_options( + PlSmallStr::from_static("distinct_count"), + &distinct_count_ca[..], + ) + .into_column(), + ], + ) } } @@ -2392,20 +2492,31 @@ impl DataFrame { .iter() .map(|s| s.slice(offset, length)) .collect::>(); - unsafe { DataFrame::new_no_checks(col) } + + let height = if let Some(fst) = col.first() { + fst.len() + } else { + let (_, length) = slice_offsets(offset, length, self.height()); + length + }; + + unsafe { DataFrame::new_no_checks(height, col) } } /// Split [`DataFrame`] at the given `offset`. pub fn split_at(&self, offset: i64) -> (Self, Self) { let (a, b) = self.columns.iter().map(|s| s.split_at(offset)).unzip(); - let a = unsafe { DataFrame::new_no_checks(a) }; - let b = unsafe { DataFrame::new_no_checks(b) }; + + let (idx, _) = slice_offsets(offset, 0, self.height()); + + let a = unsafe { DataFrame::new_no_checks(idx, a) }; + let b = unsafe { DataFrame::new_no_checks(self.height() - idx, b) }; (a, b) } pub fn clear(&self) -> Self { let col = self.columns.iter().map(|s| s.clear()).collect::>(); - unsafe { DataFrame::new_no_checks(col) } + unsafe { DataFrame::new_no_checks(0, col) } } #[must_use] @@ -2415,7 +2526,7 @@ impl DataFrame { } // @scalar-opt let columns = self._apply_columns_par(&|s| s.slice(offset, length)); - unsafe { DataFrame::new_no_checks(columns) } + unsafe { DataFrame::new_no_checks(length, columns) } } #[must_use] @@ -2429,7 +2540,7 @@ impl DataFrame { out.shrink_to_fit(); out }); - unsafe { DataFrame::new_no_checks(columns) } + unsafe { DataFrame::new_no_checks(length, columns) } } /// Get the head of the [`DataFrame`]. @@ -2472,7 +2583,10 @@ impl DataFrame { .iter() .map(|c| c.head(length)) .collect::>(); - unsafe { DataFrame::new_no_checks(col) } + + let height = length.unwrap_or(HEAD_DEFAULT_LENGTH); + let height = usize::min(height, self.height()); + unsafe { DataFrame::new_no_checks(height, col) } } /// Get the tail of the [`DataFrame`]. @@ -2512,7 +2626,10 @@ impl DataFrame { .iter() .map(|c| c.tail(length)) .collect::>(); - unsafe { DataFrame::new_no_checks(col) } + + let height = length.unwrap_or(TAIL_DEFAULT_LENGTH); + let height = usize::min(height, self.height()); + unsafe { DataFrame::new_no_checks(height, col) } } /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches. @@ -2568,7 +2685,7 @@ impl DataFrame { #[must_use] pub fn reverse(&self) -> Self { let col = self.columns.iter().map(|s| s.reverse()).collect::>(); - unsafe { DataFrame::new_no_checks(col) } + unsafe { DataFrame::new_no_checks(self.height(), col) } } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -2578,7 +2695,7 @@ impl DataFrame { #[must_use] pub fn shift(&self, periods: i64) -> Self { let col = self._apply_columns_par(&|s| s.shift(periods)); - unsafe { DataFrame::new_no_checks(col) } + unsafe { DataFrame::new_no_checks(self.height(), col) } } /// Replace None values with one of the following strategies: @@ -2592,7 +2709,7 @@ impl DataFrame { pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { let col = self.try_apply_columns_par(&|s| s.fill_null(strategy))?; - Ok(unsafe { DataFrame::new_no_checks(col) }) + Ok(unsafe { DataFrame::new_no_checks(self.height(), col) }) } /// Aggregate the column horizontally to their min values. @@ -2734,8 +2851,9 @@ impl DataFrame { dtype.is_numeric() || matches!(dtype, DataType::Boolean) }) .cloned() - .collect(); - let numeric_df = unsafe { DataFrame::_new_no_checks_impl(columns) }; + .collect::>(); + polars_ensure!(!columns.is_empty(), InvalidOperation: "'horizontal_mean' expected at least 1 numerical column"); + let numeric_df = unsafe { DataFrame::_new_no_checks_impl(self.height(), columns) }; let sum = || numeric_df.sum_horizontal(null_strategy); @@ -2933,7 +3051,9 @@ impl DataFrame { return df.filter(&mask); }, }; - Ok(unsafe { DataFrame::new_no_checks(columns) }) + + let height = Self::infer_height(&columns); + Ok(unsafe { DataFrame::new_no_checks(height, columns) }) } /// Get a mask of all the unique rows in the [`DataFrame`]. @@ -2994,7 +3114,7 @@ impl DataFrame { .iter() .map(|c| Column::new(c.name().clone(), [c.null_count() as IdxSize])) .collect(); - unsafe { Self::new_no_checks(cols) } + unsafe { Self::new_no_checks(1, cols) } } /// Hash and combine the row values @@ -3176,6 +3296,10 @@ impl DataFrame { } DataFrame::new(new_cols) } + + pub(crate) fn infer_height(cols: &[Column]) -> usize { + cols.first().map_or(0, Column::len) + } } pub struct RecordBatchIter<'a> { @@ -3186,32 +3310,34 @@ pub struct RecordBatchIter<'a> { parallel: bool, } -impl<'a> Iterator for RecordBatchIter<'a> { +impl Iterator for RecordBatchIter<'_> { type Item = RecordBatch; fn next(&mut self) -> Option { if self.idx >= self.n_chunks { - None + return None; + } + + // Create a batch of the columns with the same chunk no. + let batch_cols: Vec = if self.parallel { + let iter = self + .columns + .par_iter() + .map(Column::as_materialized_series) + .map(|s| s.to_arrow(self.idx, self.compat_level)); + POOL.install(|| iter.collect()) } else { - // Create a batch of the columns with the same chunk no. - let batch_cols = if self.parallel { - let iter = self - .columns - .par_iter() - .map(Column::as_materialized_series) - .map(|s| s.to_arrow(self.idx, self.compat_level)); - POOL.install(|| iter.collect()) - } else { - self.columns - .iter() - .map(Column::as_materialized_series) - .map(|s| s.to_arrow(self.idx, self.compat_level)) - .collect() - }; - self.idx += 1; + self.columns + .iter() + .map(Column::as_materialized_series) + .map(|s| s.to_arrow(self.idx, self.compat_level)) + .collect() + }; + self.idx += 1; - Some(RecordBatch::new(batch_cols)) - } + let length = batch_cols.first().map_or(0, |arr| arr.len()); + + Some(RecordBatch::new(length, batch_cols)) } fn size_hint(&self) -> (usize, Option) { @@ -3232,7 +3358,10 @@ impl Iterator for PhysRecordBatchIter<'_> { .iter_mut() .map(|phys_iter| phys_iter.next().cloned()) .collect::>>() - .map(RecordBatch::new) + .map(|arrs| { + let length = arrs.first().map_or(0, |arr| arr.len()); + RecordBatch::new(length, arrs) + }) } fn size_hint(&self) -> (usize, Option) { @@ -3481,25 +3610,4 @@ mod test { assert_eq!(df.get_column_names(), &["a", "b", "c"]); Ok(()) } - - #[test] - fn test_empty_df_hstack() -> PolarsResult<()> { - let mut base = df!( - "a" => [1, 2, 3], - "b" => [1, 2, 3] - )?; - - // has got columns, but no rows - let mut df = base.clear(); - let out = df.with_column(Series::new("c".into(), [1]))?; - assert_eq!(out.shape(), (0, 3)); - assert!(out.iter().all(|s| s.len() == 0)); - - // no columns - base.columns = vec![]; - let out = base.with_column(Series::new("c".into(), [1]))?; - assert_eq!(out.shape(), (1, 1)); - - Ok(()) - } } diff --git a/crates/polars-core/src/frame/row/av_buffer.rs b/crates/polars-core/src/frame/row/av_buffer.rs index 5d8da9c55666..03cac11538d5 100644 --- a/crates/polars-core/src/frame/row/av_buffer.rs +++ b/crates/polars-core/src/frame/row/av_buffer.rs @@ -547,6 +547,7 @@ impl<'a> AnyValueBufferTrusted<'a> { } } + /// Clear `self` and give `capacity`, returning the old contents as a [`Series`]. pub fn reset(&mut self, capacity: usize) -> Series { use AnyValueBufferTrusted::*; match self { @@ -616,15 +617,33 @@ impl<'a> AnyValueBufferTrusted<'a> { }, #[cfg(feature = "dtype-struct")] Struct(b) => { + // @Q? Maybe we need to add a length parameter here for ZFS's. I am not very happy + // with just setting the length to zero for that case. + if b.is_empty() { + return StructChunked::from_series(PlSmallStr::EMPTY, 0, [].iter()) + .unwrap() + .into_series(); + } + + let mut min_len = usize::MAX; + let mut max_len = usize::MIN; + let v = b .iter_mut() .map(|(b, name)| { let mut s = b.reset(capacity); + + min_len = min_len.min(s.len()); + max_len = max_len.max(s.len()); + s.rename(name.clone()); s }) .collect::>(); - StructChunked::from_series(PlSmallStr::EMPTY, v.iter()) + + let length = if min_len == 0 { 0 } else { max_len }; + + StructChunked::from_series(PlSmallStr::EMPTY, length, v.iter()) .unwrap() .into_series() }, diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs index 1d11dcd9ecc0..97891c76d478 100644 --- a/crates/polars-core/src/frame/row/dataframe.rs +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -51,6 +51,12 @@ impl DataFrame { where I: Iterator>, { + if schema.is_empty() { + let height = rows.count(); + let columns = Vec::new(); + return Ok(unsafe { DataFrame::new_no_checks(height, columns) }); + } + let capacity = rows.size_hint().0; let mut buffers: Vec<_> = schema diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index 87904e6f98cb..ad8831ebda54 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -116,10 +116,6 @@ pub fn infer_schema( } fn add_or_insert(values: &mut Tracker, key: PlSmallStr, dtype: DataType) { - if dtype == DataType::Null { - return; - } - if values.contains_key(&key) { let x = values.get_mut(&key).unwrap(); x.insert(dtype); @@ -210,7 +206,7 @@ pub fn rows_to_schema_first_non_null( .iter_values() .enumerate() .filter_map(|(i, dtype)| { - // double check struct and list types types + // double check struct and list types // nested null values can be wrongly inferred by front ends match dtype { DataType::Null | DataType::List(_) => Some(i), diff --git a/crates/polars-core/src/frame/row/transpose.rs b/crates/polars-core/src/frame/row/transpose.rs index 0f41bb2749d5..6910dfc7800c 100644 --- a/crates/polars-core/src/frame/row/transpose.rs +++ b/crates/polars-core/src/frame/row/transpose.rs @@ -84,7 +84,7 @@ impl DataFrame { })); }, }; - Ok(unsafe { DataFrame::new_no_checks(cols_t) }) + Ok(unsafe { DataFrame::new_no_checks(new_height, cols_t) }) } pub fn transpose( diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs index 50ce5d14e491..ebc5006e0493 100644 --- a/crates/polars-core/src/functions.rs +++ b/crates/polars-core/src/functions.rs @@ -38,7 +38,7 @@ pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { None => columns.push(Column::full_null(name.clone(), height, dtype)), } } - unsafe { DataFrame::new_no_checks(columns) } + unsafe { DataFrame::new_no_checks(height, columns) } }) .collect::>(); diff --git a/crates/polars-core/src/hashing/identity.rs b/crates/polars-core/src/hashing/identity.rs index e917291f1586..a1ae697106f9 100644 --- a/crates/polars-core/src/hashing/identity.rs +++ b/crates/polars-core/src/hashing/identity.rs @@ -33,28 +33,3 @@ impl Hasher for IdHasher { } pub type IdBuildHasher = BuildHasherDefault; - -#[derive(Debug)] -/// Contains an idx of a row in a DataFrame and the precomputed hash of that row. -/// -/// That hash still needs to be used to create another hash to be able to resize hashmaps without -/// accidental quadratic behavior. So do not use an Identity function! -pub struct IdxHash { - // idx in row of Series, DataFrame - pub idx: IdxSize, - // precomputed hash of T - pub hash: u64, -} - -impl Hash for IdxHash { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash) - } -} - -impl IdxHash { - #[inline] - pub(crate) fn new(idx: IdxSize, hash: u64) -> Self { - IdxHash { idx, hash } - } -} diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs index 8f966eb2f317..370401eb6e3e 100644 --- a/crates/polars-core/src/hashing/mod.rs +++ b/crates/polars-core/src/hashing/mod.rs @@ -1,15 +1,11 @@ mod identity; pub(crate) mod vector_hasher; -use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; +use std::hash::{BuildHasherDefault, Hash, Hasher}; -use hashbrown::hash_map::RawEntryMut; -use hashbrown::HashMap; pub use identity::*; pub use vector_hasher::*; -use crate::prelude::*; - // hash combine from c++' boost lib #[inline] pub fn _boost_hash_combine(l: u64, r: u64) -> u64 { @@ -19,73 +15,3 @@ pub fn _boost_hash_combine(l: u64, r: u64) -> u64 { // We must strike a balance between cache // Overallocation seems a lot more expensive than resizing so we start reasonable small. pub const _HASHMAP_INIT_SIZE: usize = 512; - -/// Utility function used as comparison function in the hashmap. -/// The rationale is that equality is an AND operation and therefore its probability of success -/// declines rapidly with the number of keys. Instead of first copying an entire row from both -/// sides and then do the comparison, we do the comparison value by value catching early failures -/// eagerly. -/// -/// # Safety -/// Doesn't check any bounds -#[inline] -pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool { - for s in keys.get_columns() { - if !s.equal_element(idx_a, idx_b, s) { - return false; - } - } - true -} - -/// Populate a multiple key hashmap with row indexes. -/// -/// Instead of the keys (which could be very large), the row indexes are stored. -/// To check if a row is equal the original DataFrame is also passed as ref. -/// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared -/// on equality. -pub fn populate_multiple_key_hashmap( - hash_tbl: &mut HashMap, - // row index - idx: IdxSize, - // hash - original_h: u64, - // keys of the hash table (will not be inserted, the indexes will be used) - // the keys are needed for the equality check - keys: &DataFrame, - // value to insert - vacant_fn: G, - // function that gets a mutable ref to the occupied value in the hash table - mut occupied_fn: F, -) where - G: Fn() -> V, - F: FnMut(&mut V), - H: BuildHasher, -{ - let entry = hash_tbl - .raw_entry_mut() - // uses the idx to probe rows in the original DataFrame with keys - // to check equality to find an entry - // this does not invalidate the hashmap as this equality function is not used - // during rehashing/resize (then the keys are already known to be unique). - // Only during insertion and probing an equality function is needed - .from_hash(original_h, |idx_hash| { - // first check the hash values - // before we incur a cache miss - idx_hash.hash == original_h && { - let key_idx = idx_hash.idx; - // SAFETY: - // indices in a group_by operation are always in bounds. - unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } - } - }); - match entry { - RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); - }, - RawEntryMut::Occupied(mut entry) => { - let (_k, v) = entry.get_key_value_mut(); - occupied_fn(v); - }, - } -} diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 7dfb07c64d58..e00e45f1ede8 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -1,4 +1,5 @@ use arrow::bitmap::utils::get_bit_unchecked; +use polars_utils::hashing::folded_multiply; use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use rayon::prelude::*; use xxhash_rust::xxh3::xxh3_64_with_seed; @@ -30,11 +31,6 @@ pub trait VecHash { } } -pub(crate) const fn folded_multiply(s: u64, by: u64) -> u64 { - let result = (s as u128).wrapping_mul(by as u128); - ((result & 0xffff_ffff_ffff_ffff) as u64) ^ ((result >> 64) as u64) -} - pub(crate) fn get_null_hash_value(random_state: &PlRandomState) -> u64 { // we just start with a large prime number and hash that twice // to get a constant hash value for null/None diff --git a/crates/polars-core/src/named_from.rs b/crates/polars-core/src/named_from.rs index 4d5714e4e517..8eee78d5ec6e 100644 --- a/crates/polars-core/src/named_from.rs +++ b/crates/polars-core/src/named_from.rs @@ -9,7 +9,7 @@ use chrono::NaiveDateTime; #[cfg(feature = "dtype-time")] use chrono::NaiveTime; -use crate::chunked_array::builder::{get_list_builder, AnonymousListBuilder}; +use crate::chunked_array::builder::get_list_builder; use crate::prelude::*; pub trait NamedFrom { @@ -135,22 +135,13 @@ impl> NamedFrom for Series { let dt = series_slice[0].dtype(); - // inner type is also list so we need the anonymous builder - if let DataType::List(_) = dt { - let mut builder = AnonymousListBuilder::new(name, list_cap, Some(dt.clone())); - for s in series_slice { - builder.append_series(s).unwrap(); - } - builder.finish().into_series() - } else { - let values_cap = series_slice.iter().fold(0, |acc, s| acc + s.len()); + let values_cap = series_slice.iter().fold(0, |acc, s| acc + s.len()); - let mut builder = get_list_builder(dt, values_cap, list_cap, name).unwrap(); - for series in series_slice { - builder.append_series(series).unwrap(); - } - builder.finish().into_series() + let mut builder = get_list_builder(dt, values_cap, list_cap, name); + for series in series_slice { + builder.append_series(series).unwrap(); } + builder.finish().into_series() } } @@ -165,7 +156,7 @@ impl]>> NamedFrom]> for Series { None => &DataType::Null, }; - let mut builder = get_list_builder(dt, values_cap, series_slice.len(), name).unwrap(); + let mut builder = get_list_builder(dt, values_cap, series_slice.len(), name); for series in series_slice { builder.append_opt_series(series.as_ref()).unwrap(); } diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index d100cf91172f..38fe0377f9c4 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -20,6 +20,8 @@ pub trait SchemaExt { fn iter_fields(&self) -> impl ExactSizeIterator + '_; fn to_supertype(&mut self, other: &Schema) -> PolarsResult; + + fn materialize_unknown_dtypes(&self) -> PolarsResult; } impl SchemaExt for Schema { @@ -88,6 +90,13 @@ impl SchemaExt for Schema { } Ok(changed) } + + /// Materialize all unknown dtypes in this schema. + fn materialize_unknown_dtypes(&self) -> PolarsResult { + self.iter() + .map(|(name, dtype)| Ok((name.clone(), dtype.materialize_unknown()?))) + .collect() + } } pub trait SchemaNamesAndDtypes { diff --git a/crates/polars-core/src/serde/chunked_array.rs b/crates/polars-core/src/serde/chunked_array.rs index 145f05c9af38..7a643d17185e 100644 --- a/crates/polars-core/src/serde/chunked_array.rs +++ b/crates/polars-core/src/serde/chunked_array.rs @@ -173,9 +173,10 @@ impl Serialize for StructChunked { )); } - let mut state = serializer.serialize_map(Some(3))?; + let mut state = serializer.serialize_map(Some(4))?; state.serialize_entry("name", self.name())?; state.serialize_entry("datatype", self.dtype())?; + state.serialize_entry("length", &self.len())?; state.serialize_entry("values", &self.fields_as_series())?; state.end() } diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 0ef07e702374..0fb9d9f05f18 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -87,8 +87,8 @@ impl Serialize for Series { )), dt => { with_match_physical_numeric_polars_type!(dt, |$T| { - let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); - ca.serialize(serializer) + let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); + ca.serialize(serializer) }) }, } @@ -100,7 +100,7 @@ impl<'de> Deserialize<'de> for Series { where D: Deserializer<'de>, { - const FIELDS: &[&str] = &["name", "datatype", "bit_settings", "values"]; + const FIELDS: &[&str] = &["name", "datatype", "bit_settings", "length", "values"]; struct SeriesVisitor; @@ -109,7 +109,7 @@ impl<'de> Deserialize<'de> for Series { fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { formatter - .write_str("struct {name: , datatype: , bit_settings?: , values: }") + .write_str("struct {name: , datatype: , bit_settings?: , length?: , values: }") } fn visit_map(self, mut map: A) -> std::result::Result @@ -118,6 +118,7 @@ impl<'de> Deserialize<'de> for Series { { let mut name: Option> = None; let mut dtype = None; + let mut length = None; let mut bit_settings: Option = None; let mut values_set = false; while let Some(key) = map.next_key::>().unwrap() { @@ -134,6 +135,8 @@ impl<'de> Deserialize<'de> for Series { "bit_settings" => { bit_settings = Some(map.next_value()?); }, + // length is only used for struct at the moment + "length" => length = Some(map.next_value()?), "values" => { // we delay calling next_value until we know the dtype values_set = true; @@ -275,9 +278,29 @@ impl<'de> Deserialize<'de> for Series { Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { + DataType::Struct(fields) => { + let length = length.ok_or_else(|| de::Error::missing_field("length"))?; let values: Vec = map.next_value()?; - let ca = StructChunked::from_series(name.clone(), values.iter()).unwrap(); + + if fields.len() != values.len() { + let expected = format!("expected {} value series", fields.len()); + let expected = expected.as_str(); + return Err(de::Error::invalid_length(values.len(), &expected)); + } + + for (f, v) in fields.iter().zip(values.iter()) { + if f.dtype() != v.dtype() { + let err = format!( + "type mismatch for struct. expected: {}, given: {}", + f.dtype(), + v.dtype() + ); + return Err(de::Error::custom(err)); + } + } + + let ca = StructChunked::from_series(name.clone(), length, values.iter()) + .unwrap(); let mut s = ca.into_series(); s.rename(name); Ok(s) diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 0ca52862b3b9..6149690c85c5 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -657,18 +657,11 @@ fn any_values_to_list( DataType::Categorical(Some(Arc::new(RevMapping::default())), *ordering) }, - // Structs don't support empty fields yet. - // We must ensure the data-types match what we do physical - #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) if fields.is_empty() => { - DataType::Struct(vec![Field::new(PlSmallStr::EMPTY, DataType::Null)]) - }, - _ => inner_type.clone(), }; let mut builder = - get_list_builder(&list_inner_type, capacity * 5, capacity, PlSmallStr::EMPTY)?; + get_list_builder(&list_inner_type, capacity * 5, capacity, PlSmallStr::EMPTY); for av in avs { match av { @@ -832,7 +825,9 @@ fn any_values_to_struct( ) -> PolarsResult { // Fast path for structs with no fields. if fields.is_empty() { - return Ok(StructChunked::full_null(PlSmallStr::EMPTY, values.len()).into_series()); + return Ok( + StructChunked::from_series(PlSmallStr::EMPTY, values.len(), [].iter())?.into_series(), + ); } // The physical series fields of the struct. @@ -873,7 +868,8 @@ fn any_values_to_struct( series_fields.push(s) } - let mut out = StructChunked::from_series(PlSmallStr::EMPTY, series_fields.iter())?; + let mut out = + StructChunked::from_series(PlSmallStr::EMPTY, values.len(), series_fields.iter())?; if has_outer_validity { let mut validity = MutableBitmap::new(); validity.extend_constant(values.len(), true); diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 2e613ea7e1a0..f9e5ff42139b 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -146,7 +146,11 @@ fn broadcast_array(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<(ArrayChunk }, (a, b) if a == b => (lhs.clone(), rhs.clone()), _ => { - polars_bail!(InvalidOperation: "can only do arithmetic of arrays of the same type and shape; got {} and {}", lhs.dtype(), rhs.dtype()) + polars_bail!( + InvalidOperation: + "can only do arithmetic of arrays of the same type and shape; got {} and {}", + lhs.dtype(), rhs.dtype() + ) }, }; Ok(out) @@ -392,23 +396,35 @@ pub(crate) fn coerce_lhs_rhs<'a>( if let Some(result) = coerce_time_units(lhs, rhs) { return Ok(result); } - let dtype = match (lhs.dtype(), rhs.dtype()) { + let (left_dtype, right_dtype) = (lhs.dtype(), rhs.dtype()); + let leaf_super_dtype = match (left_dtype, right_dtype) { #[cfg(feature = "dtype-struct")] (DataType::Struct(_), DataType::Struct(_)) => { return Ok((Cow::Borrowed(lhs), Cow::Borrowed(rhs))) }, - _ => try_get_supertype(lhs.dtype(), rhs.dtype())?, + _ => try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?, }; - let left = if lhs.dtype() == &dtype { + let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); + let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); + + // Cast List<->Array to List + if (left_dtype.is_list() && right_dtype.is_array()) + || (left_dtype.is_array() && right_dtype.is_list()) + { + new_left_dtype = try_get_supertype(&new_left_dtype, &new_right_dtype)?; + new_right_dtype = new_left_dtype.clone(); + } + + let left = if lhs.dtype() == &new_left_dtype { Cow::Borrowed(lhs) } else { - Cow::Owned(lhs.cast(&dtype)?) + Cow::Owned(lhs.cast(&new_left_dtype)?) }; - let right = if rhs.dtype() == &dtype { + let right = if rhs.dtype() == &new_right_dtype { Cow::Borrowed(rhs) } else { - Cow::Owned(rhs.cast(&dtype)?) + Cow::Owned(rhs.cast(&new_right_dtype)?) }; Ok((left, right)) } @@ -522,6 +538,9 @@ impl Add for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.add(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Add.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.add_to(rhs.as_ref()) @@ -540,6 +559,9 @@ impl Sub for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.sub(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Sub.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.subtract(rhs.as_ref()) @@ -574,6 +596,9 @@ impl Mul for &Series { let out = rhs.multiply(self)?; Ok(out.with_name(self.name().clone())) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Mul.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.multiply(rhs.as_ref()) @@ -595,19 +620,18 @@ impl Div for &Series { use DataType::*; match (self.dtype(), rhs.dtype()) { #[cfg(feature = "dtype-struct")] - (Struct(_), Struct(_)) => { - _struct_arithmetic(self, rhs, |a, b| a.div(b)) - }, + (Struct(_), Struct(_)) => _struct_arithmetic(self, rhs, |a, b| a.div(b)), (Duration(_), _) => self.divide(rhs), - | (Date, _) + (Date, _) | (Datetime(_, _), _) | (Time, _) - // temporal rhs - | (_ , Duration(_)) - | (_ , Time) - | (_ , Date) - | (_ , Datetime(_, _)) - => polars_bail!(opq = div, self.dtype(), rhs.dtype()), + | (_, Duration(_)) + | (_, Time) + | (_, Date) + | (_, Datetime(_, _)) => polars_bail!(opq = div, self.dtype(), rhs.dtype()), + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Div.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.divide(rhs.as_ref()) @@ -631,6 +655,9 @@ impl Rem for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.rem(b)) }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list_borrowed::NumericListOp::Rem.execute(self, rhs) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.remainder(rhs.as_ref()) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 1628780d7b0e..b96b78687c13 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -1,177 +1,1058 @@ //! Allow arithmetic operations for ListChunked. +//! use polars_error::{feature_gated, PolarsResult}; -use super::*; -use crate::chunked_array::builder::AnonymousListBuilder; - -/// Given an ArrayRef with some primitive values, wrap it in list(s) until it -/// matches the requested shape. -fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { - if let Some(list_chunk) = shape.as_any().downcast_ref::() { - let result = LargeListArray::new( - list_chunk.dtype().clone(), - list_chunk.offsets().clone(), - reshape_list_based_on(data, list_chunk.values()), - list_chunk.validity().cloned(), - ); - Box::new(result) - } else { - data.clone() +use polars_error::{feature_gated, PolarsResult}; + +use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series}; + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Add.execute(&lhs.clone().into_series(), rhs) } -} -/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or -/// more nulls. -fn does_list_have_nulls(data: &ArrayRef) -> bool { - if let Some(list_chunk) = data.as_any().downcast_ref::() { - if list_chunk - .validity() - .map(|bitmap| bitmap.unset_bits() > 0) - .unwrap_or(false) - { - true - } else { - does_list_have_nulls(list_chunk.values()) - } - } else { - false + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Sub.execute(&lhs.clone().into_series(), rhs) + } + + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Mul.execute(&lhs.clone().into_series(), rhs) } + + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Div.execute(&lhs.clone().into_series(), rhs) + } + + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::Rem.execute(&lhs.clone().into_series(), rhs) + } +} + +#[derive(Debug, Clone)] +pub enum NumericListOp { + Add, + Sub, + Mul, + Div, + Rem, + FloorDiv, } -/// Return whether the left and right have the same shape. We assume neither has -/// any nulls, recursively. -fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { - debug_assert!(!does_list_have_nulls(left)); - debug_assert!(!does_list_have_nulls(right)); - let left_as_list = left.as_any().downcast_ref::(); - let right_as_list = right.as_any().downcast_ref::(); - match (left_as_list, right_as_list) { - (Some(left), Some(right)) => { - left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) - }, - (None, None) => left.len() == right.len(), - _ => false, +impl NumericListOp { + #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))] + pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + feature_gated!("list_arithmetic", { + use either::Either; + + // `trim_to_normalized_offsets` ensures we don't perform excessive + // memory allocation / compute on memory regions that have been + // sliced out. + let lhs = lhs.list_rechunk_and_trim_to_normalized_offsets(); + let rhs = rhs.list_rechunk_and_trim_to_normalized_offsets(); + + let binary_op_exec = match BinaryListNumericOpHelper::try_new( + self.clone(), + lhs.name().clone(), + lhs.dtype(), + rhs.dtype(), + lhs.len(), + rhs.len(), + { + let (a, b) = lhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, lhs.clone()) + }, + { + let (a, b) = rhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, rhs.clone()) + }, + lhs.rechunk_validity(), + rhs.rechunk_validity(), + )? { + Either::Left(v) => v, + Either::Right(ca) => return Ok(ca.into_series()), + }; + + Ok(binary_op_exec.finish()?.into_series()) + }) } } -impl ListChunked { - /// Helper function for NumOpsDispatchInner implementation for ListChunked. - /// - /// Run the given `op` on `self` and `rhs`. - fn arithm_helper( - &self, - rhs: &Series, - op: &dyn Fn(&Series, &Series) -> PolarsResult, - has_nulls: Option, - ) -> PolarsResult { - polars_ensure!( - self.len() == rhs.len(), - InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", - self.len(), - rhs.len() - ); - - let mut has_nulls = has_nulls.unwrap_or(false); - if !has_nulls { - for chunk in self.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; - } +#[cfg(feature = "list_arithmetic")] +use inner::BinaryListNumericOpHelper; + +#[cfg(feature = "list_arithmetic")] +mod inner { + use arrow::bitmap::Bitmap; + use arrow::compute::utils::combine_validities_and; + use arrow::offset::OffsetsBuffer; + use either::Either; + use num_traits::Zero; + use polars_compute::arithmetic::pl_num::PlNumArithmetic; + use polars_compute::arithmetic::ArithmeticKernel; + use polars_compute::comparisons::TotalEqKernel; + use polars_utils::float::IsFloat; + + use super::super::*; + + impl NumericListOp { + fn name(&self) -> &'static str { + match self { + Self::Add => "add", + Self::Sub => "sub", + Self::Mul => "mul", + Self::Div => "div", + Self::Rem => "rem", + Self::FloorDiv => "floor_div", } } - if !has_nulls { - for chunk in rhs.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; + + fn try_get_leaf_supertype( + &self, + prim_dtype_lhs: &DataType, + prim_dtype_rhs: &DataType, + ) -> PolarsResult { + let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + Ok(if matches!(self, Self::Div) { + if dtype.is_float() { + dtype + } else { + DataType::Float64 } - } + } else { + dtype + }) } - if has_nulls { - // A slower implementation since we can't just add the underlying - // values Arrow arrays. Given nulls, the two values arrays might not - // line up the way we expect. - let mut result = AnonymousListBuilder::new( - self.name().clone(), - self.len(), - Some(self.inner_dtype().clone()), - ); - let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { - let (Some(a_owner), Some(b_owner)) = (a, b) else { - // Operations with nulls always result in nulls: - return Ok(None); - }; - let a = a_owner.as_ref(); - let b = b_owner.as_ref(); - polars_ensure!( - a.len() == b.len(), - InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", - a.len(), - b.len() - ); - let chunk_result = if let Ok(a_listchunked) = a.list() { - // If `a` contains more lists, we're going to reach this - // function recursively, and again have to decide whether to - // use the fast path (no nulls) or slow path (there were - // nulls). Since we know there were nulls, that means we - // have to stick to the slow path, so pass that information - // along. - a_listchunked.arithm_helper(b, op, Some(true)) - } else { - op(a, b) - }; - chunk_result.map(Some) - }).collect::>>>()?; - for s in combined.iter() { - if let Some(s) = s { - result.append_series(s)?; - } else { - result.append_null(); + + /// For operations that perform divisions on integers, sets the validity to NULL on rows where + /// the denominator is 0. + fn prepare_numeric_op_side_validities( + &self, + lhs: &mut PrimitiveArray, + rhs: &mut PrimitiveArray, + swapped: bool, + ) where + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + if !T::Native::is_float() { + match self { + Self::Div | Self::Rem | Self::FloorDiv => { + let target = if swapped { lhs } else { rhs }; + let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero()); + let validity = combine_validities_and(target.validity(), Some(&ne_0)); + target.set_validity(validity); + }, + _ => {}, } } - return Ok(result.finish().into()); } - let l_rechunked = self.clone().rechunk().into_series(); - let l_leaf_array = l_rechunked.get_leaf_array(); - let r_leaf_array = rhs.rechunk().get_leaf_array(); - polars_ensure!( - lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), - InvalidOperation: "can only do arithmetic operations on lists of the same size" - ); - - let result = op(&l_leaf_array, &r_leaf_array)?; - - // We now need to wrap the Arrow arrays with the metadata that turns - // them into lists: - // TODO is there a way to do this without cloning the underlying data? - let result_chunks = result.chunks(); - assert_eq!(result_chunks.len(), 1); - let left_chunk = &l_rechunked.chunks()[0]; - let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); - - unsafe { - let mut result = - ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); - result.compute_len(); - Ok(result.into()) + + /// For list<->primitive where the primitive is broadcasted, we can dispatch to + /// `ArithmeticKernel`, which can have optimized codepaths for when one side is + /// a scalar. + fn apply_array_to_scalar( + &self, + arr_lhs: PrimitiveArray, + r: T::Native, + swapped: bool, + ) -> PrimitiveArray { + match self { + Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r), + Self::Sub => { + if swapped { + ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r) + } + }, + Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r), + Self::Div => { + if swapped { + ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::legacy_div_scalar(arr_lhs, r) + } + }, + Self::Rem => { + if swapped { + ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r) + } + }, + Self::FloorDiv => { + if swapped { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r) + } + }, + } } } + + macro_rules! with_match_numeric_list_op { + ($op:expr, $swapped:expr, | $_:tt $OP:tt | $($body:tt)* ) => ({ + macro_rules! __with_func__ {( $_ $OP:tt ) => ( $($body)* )} + + match $op { + NumericListOp::Add => __with_func__! { (PlNumArithmetic::wrapping_add) }, + NumericListOp::Sub => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_sub(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_sub) } + } + }, + NumericListOp::Mul => __with_func__! { (PlNumArithmetic::wrapping_mul) }, + NumericListOp::Div => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::legacy_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::legacy_div) } + } + }, + NumericListOp::Rem => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_mod(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_mod) } + } + }, + NumericListOp::FloorDiv => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_floor_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_floor_div) } + } + }, + } + }) } -impl NumOpsDispatchInner for ListType { - fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + #[derive(Debug)] + enum BinaryOpApplyType { + ListToList, + ListToPrimitive, + PrimitiveToList, } - fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + + #[derive(Debug)] + enum Broadcast { + Left, + Right, + #[allow(clippy::enum_variant_names)] + NoBroadcast, } - fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + + /// Utility to perform a binary operation between the primitive values of + /// 2 columns, where at least one of the columns is a `ListChunked` type. + pub(super) struct BinaryListNumericOpHelper { + op: NumericListOp, + output_name: PlSmallStr, + op_apply_type: BinaryOpApplyType, + broadcast: Broadcast, + output_dtype: DataType, + output_primitive_dtype: DataType, + output_len: usize, + /// Outer validity of the result, we always materialize this to reduce the + /// amount of code paths we need. + outer_validity: Bitmap, + // The series are stored as they are used for list broadcasting. + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + list_to_prim_lhs: Option<(Box, usize)>, + swapped: bool, } - fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + + /// This lets us separate some logic into `new()` to reduce the amount of + /// monomorphized code. + impl BinaryListNumericOpHelper { + /// Checks that: + /// * Dtypes are compatible: + /// * list<->primitive | primitive<->list + /// * list<->list both contain primitives (e.g. List) + /// * Primitive dtypes match + /// * Lengths are compatible: + /// * 1<->n | n<->1 + /// * n<->n + /// * Both sides have at least 1 non-NULL outer row. + /// + /// Does not check: + /// * Whether the offsets are aligned for list<->list, this will be checked during execution. + /// + /// This returns an `Either` which may contain the final result to simplify + /// the implementation. + #[allow(clippy::too_many_arguments)] + pub(super) fn try_new( + op: NumericListOp, + output_name: PlSmallStr, + dtype_lhs: &DataType, + dtype_rhs: &DataType, + len_lhs: usize, + len_rhs: usize, + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + validity_lhs: Option, + validity_rhs: Option, + ) -> PolarsResult> { + let prim_dtype_lhs = dtype_lhs.leaf_dtype(); + let prim_dtype_rhs = dtype_rhs.leaf_dtype(); + + let output_primitive_dtype = + op.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) { + (l @ DataType::List(a), r @ DataType::List(b)) => { + // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user + // directly adds 2 series together it bypasses the DSL. + // This is currently duplicated code and should be replaced one day with an assert after Series ops get + // checked properly. + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op.name(), l, r, + ); + } + (BinaryOpApplyType::ListToList, l) + }, + (list_dtype @ DataType::List(_), x) + if x.is_numeric() || x.is_bool() || x.is_null() => + { + (BinaryOpApplyType::ListToPrimitive, list_dtype) + }, + (x, list_dtype @ DataType::List(_)) + if x.is_numeric() || x.is_bool() || x.is_null() => + { + (BinaryOpApplyType::PrimitiveToList, list_dtype) + }, + (l, r) => polars_bail!( + InvalidOperation: + "{} operation not supported for dtypes: {} != {}", + op.name(), l, r, + ), + }; + + let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); + + let (broadcast, output_len) = match (len_lhs, len_rhs) { + (l, r) if l == r => (Broadcast::NoBroadcast, l), + (1, v) => (Broadcast::Left, v), + (v, 1) => (Broadcast::Right, v), + (l, r) => polars_bail!( + ShapeMismatch: + "cannot {} two columns of differing lengths: {} != {}", + op.name(), l, r + ), + }; + + let DataType::List(output_inner_dtype) = &output_dtype else { + unreachable!() + }; + + // # NULL semantics + // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]] + // * Essentially as if the NULL primitive was added to every primitive in the row of the list column. + // * NULL (List[Int64]) + 1 (Int64) => NULL + // * NULL (List[Int64]) + [1] (List[Int64]) => NULL + + if output_len == 0 + || (len_lhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive + ) + && validity_lhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + || (len_rhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + ) + && validity_rhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + { + return Ok(Either::Right(ListChunked::full_null_with_dtype( + output_name.clone(), + output_len, + output_inner_dtype.as_ref(), + ))); + } + + // At this point: + // * All unit length list columns have a valid outer value. + + // The outer validity is just the validity of any non-broadcasting lists. + let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) { + // Both lists with same length, we combine the validity. + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => { + combine_validities_and(l.as_ref(), r.as_ref()) + }, + // Match all other combinations that have non-broadcasting lists. + ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive, + Broadcast::NoBroadcast | Broadcast::Right, + v, + _, + ) + | ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList, + Broadcast::NoBroadcast | Broadcast::Left, + _, + v, + ) => v, + _ => None, + } + .unwrap_or_else(|| Bitmap::new_with_value(true, output_len)); + + Ok(Either::Left(Self { + op, + output_name, + op_apply_type, + broadcast, + output_dtype: output_dtype.clone(), + output_primitive_dtype, + output_len, + outer_validity, + data_lhs, + data_rhs, + list_to_prim_lhs: None, + swapped: false, + })) + } + + pub(super) fn finish(mut self) -> PolarsResult { + // We have physical codepaths for a subset of the possible combinations of broadcasting and + // column types. The remaining combinations are handled by dispatching to the physical + // codepaths after operand swapping and/or materialized broadcasting. + // + // # Physical impl table + // Legend + // * | N | // impl "N" + // * | [N] | // dispatches to impl "N" + // + // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight + // ListToList | [1] | 0 | 1 | + // ListToPrimitive | [2] | 2 | 3 | // list broadcasting just materializes and dispatches to NoBroadcast + // PrimitiveToList | [3] | [2] | [2] | + + self.swapped = true; + + match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToList, Broadcast::Right) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { + // We materialize the list columns with `new_from_index`, as otherwise we'd have to + // implement logic that broadcasts the offsets and validities across multiple levels + // of nesting. But we will re-use the materialized memory to store the result. + + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_rhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::NoBroadcast; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToList, Broadcast::Left) => { + self.broadcast = Broadcast::Right; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_lhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.broadcast = Broadcast::NoBroadcast; + + // This does not swap! We are just dispatching to `NoBroadcast` + // after materializing the broadcasted list array. + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Right; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + } + } + + fn _finish_impl_dispatch(&mut self) -> PolarsResult { + let output_dtype = self.output_dtype.clone(); + let output_len = self.output_len; + + let prim_lhs = self + .data_lhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + let prim_rhs = self + .data_rhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + + debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); + let prim_dtype = prim_lhs.dtype(); + debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); + + // Safety: Leaf dtypes have been checked to be numeric by `try_new()` + let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { + self._finish_impl::<$T>(prim_lhs, prim_rhs) + })?; + + debug_assert_eq!(out.dtype(), &output_dtype); + assert_eq!(out.len(), output_len); + + Ok(out) + } + + /// Internal use only - contains physical impls. + fn _finish_impl( + &mut self, + prim_s_lhs: Series, + prim_s_rhs: Series, + ) -> PolarsResult + where + T::Native: PlNumArithmetic, + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + #[inline(never)] + fn check_mismatch_pos( + mismatch_pos: usize, + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + ) -> PolarsResult<()> { + if mismatch_pos < offsets_lhs.len_proxy() { + // RHS could be broadcasted + let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 { + 0 + } else { + mismatch_pos + }); + polars_bail!( + ShapeMismatch: + "list lengths differed at index {}: {} != {}", + mismatch_pos, + offsets_lhs.length_at(mismatch_pos), len_r + ) + } + Ok(()) + } + + let mut arr_lhs = { + let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + let mut arr_rhs = { + let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + match (&self.op_apply_type, &self.broadcast) { + // We skip for this because it dispatches to `ArithmeticKernel`, which handles the + // validities for us. + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {}, + _ if self.list_to_prim_lhs.is_none() => { + self.op.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {}, + _ => unreachable!(), + } + + // + // General notes + // * Lists can be: + // * Sliced, in which case the primitive/leaf array needs to be indexed starting from an + // offset instead of 0. + // * Masked, in which case the masked rows are permitted to have non-matching widths. + // + + let out = match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy()); + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + // Counter that stops being incremented at the first row position with mismatching + // list lengths. + let mut mismatch_pos = 0; + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + .enumerate() + { + if + (mismatch_pos == i) + & ( + (lhs_len == rhs_len) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + ) + { + mismatch_pos += 1; + } + + // Both sides are lists, we restrict the index to the min length to avoid + // OOB memory access. + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { out_ptr.add(l_idx).write(v) }; + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + /// Reduce monomorphization + #[inline(never)] + fn combine_validities_list_to_list_no_broadcast( + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + { + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_no_broadcast( + offsets_lhs, + offsets_rhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToList, Broadcast::Right) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + assert_eq!(offsets_rhs.len_proxy(), 1); + let rhs_start = *offsets_rhs.first() as usize; + let width = offsets_rhs.range() as usize; + + let mut mismatch_pos = 0; + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() { + if ((lhs_len == width) & (mismatch_pos == i)) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + { + mismatch_pos += 1; + } + + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { + out_ptr.add(l_idx).write(v); + } + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + #[inline(never)] + fn combine_validities_list_to_list_broadcast_right( + offsets_lhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + width: usize, + rhs_start: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() { + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_broadcast_right( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + width, + rhs_start, + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + if self.list_to_prim_lhs.is_none() => + { + let offsets_lhs = self.data_lhs.0.as_slice(); + + // Notes + // * Primitive indexing starts from 0 + // * Output is aligned to LHS array + + let n_values = arr_lhs.len(); + let mut out_vec = Vec::::with_capacity(n_values); + let out_ptr = out_vec.as_mut_ptr(); + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs.value_unchecked(l_idx); + let v = $OP(l, r); + out_ptr.add(l_idx).write(v); + } + } + } + }); + + unsafe { out_vec.set_len(n_values) } + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + // If we are dispatched here, it means that the LHS array is a unique allocation created + // after a unit-length list column was broadcasted, so this codepath mutably stores the + // results back into the LHS array to save memory. + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + let offsets_lhs = self.data_lhs.0.as_slice(); + + let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap(); + let arr = arr + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let mut arr_lhs = core::mem::take(arr); + + self.op.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ); + + let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap(); + assert_eq!(arr_lhs_mut_slice.len(), n_values); + + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx); + *l = $OP(*l, r); + } + } + } + }); + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = arr_lhs.with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + assert_eq!(arr_rhs.len(), 1); + + let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { + // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + return self.finish_offsets_and_validities( + Box::new( + arr_lhs.clone().with_validity(Some(Bitmap::new_with_value( + false, + arr_lhs.len(), + ))), + ), + offsets, + validities, + ); + }; + + let arr = self.op.apply_array_to_scalar::(arr_lhs, r, self.swapped); + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) + | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) + | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + if cfg!(debug_assertions) { + panic!("operation was not re-written: {:?}", v) + } else { + unreachable!() + } + }, + }?; + + Ok(out) + } + + /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every + /// level. + fn finish_offsets_and_validities( + &mut self, + leaf_array: Box, + offsets: Vec>, + validities: Vec>, + ) -> PolarsResult { + assert!(!offsets.is_empty()); + assert_eq!(offsets.len(), validities.len()); + let mut results = leaf_array; + + let mut iter = offsets.into_iter().zip(validities).rev(); + + while iter.len() > 1 { + let (offsets, validity) = iter.next().unwrap(); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + results = Box::new(LargeListArray::new(dtype, offsets, results, validity)); + } + + // The combined outer validity is pre-computed during `try_new()` + let (offsets, _) = iter.next().unwrap(); + let validity = core::mem::take(&mut self.outer_validity); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + let results = LargeListArray::new(dtype, offsets, results, Some(validity)); + + Ok(ListChunked::with_chunk( + core::mem::take(&mut self.output_name), + results, + )) + } + + fn materialize_broadcasted_list( + side_data: &mut (Vec>, Vec>, Series), + output_len: usize, + output_primitive_dtype: &DataType, + ) -> (Box, usize) { + let s = &side_data.2; + assert_eq!(s.len(), 1); + + let expected_n_values = { + let offsets = s.list_offsets_and_validities_recursive().0; + output_len * OffsetsBuffer::::leaf_full_start_end(&offsets).len() + }; + + let ca = s.list().unwrap(); + // Remember to cast the leaf primitives to the supertype. + let ca = ca + .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone())) + .unwrap(); + assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data. + let ca = ca.new_from_index(0, output_len).rechunk(); + + let s = ca.into_series(); + + *side_data = { + let (a, b) = s.list_offsets_and_validities_recursive(); + // `Series::default()`: This field in the tuple is no longer used. + (a, b, Series::default()) + }; + + let n_values = OffsetsBuffer::::leaf_full_start_end(&side_data.0).len(); + assert_eq!(n_values, expected_n_values); + + let mut s = s.get_leaf_array(); + let v = unsafe { s.chunks_mut() }; + + assert_eq!(v.len(), 1); + (v.swap_remove(0), n_values) + } } - fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + + /// Used in 2 places, so it's outside here. + #[inline(never)] + fn combine_validities_list_to_primitive_no_broadcast( + offsets_lhs: &[OffsetsBuffer], + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + // Materialize a full-true validity to re-use the codepath, as we still + // need to spread the bits from the RHS to the correct positions. + (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)), + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() { + let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) }; + for l_idx in l_range { + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) } } diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index d7d7dbdb8a0e..0a5550b7b0f3 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -6,6 +6,7 @@ use std::borrow::Cow; use std::ops::{Add, Div, Mul, Rem, Sub}; pub use borrowed::*; +pub use list_borrowed::NumericListOp; use num_traits::{Num, NumCast}; use crate::prelude::*; diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 7f61f99895f4..f532a0e61d6f 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -10,6 +10,7 @@ use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; ))] use arrow::temporal_conversions::*; use polars_error::feature_gated; +use polars_utils::itertools::Itertools; use crate::chunked_array::cast::{cast_chunks, CastOptions}; #[cfg(feature = "object")] @@ -575,38 +576,53 @@ unsafe fn to_physical_and_dtype( }, ArrowDataType::Struct(_fields) => { feature_gated!("dtype-struct", { - debug_assert_eq!(arrays.len(), 1); - let arr = arrays[0].clone(); - let arr = arr.as_any().downcast_ref::().unwrap(); - let (values, dtypes): (Vec<_>, Vec<_>) = arr - .values() + let mut pl_fields = None; + let arrays = arrays .iter() - .zip(_fields.iter()) - .map(|(value, field)| { - let mut out = - to_physical_and_dtype(vec![value.clone()], Some(&field.metadata)); - (out.0.pop().unwrap(), out.1) + .map(|arr| { + let arr = arr.as_any().downcast_ref::().unwrap(); + let (values, dtypes): (Vec<_>, Vec<_>) = arr + .values() + .iter() + .zip(_fields.iter()) + .map(|(value, field)| { + let mut out = to_physical_and_dtype( + vec![value.clone()], + Some(&field.metadata), + ); + (out.0.pop().unwrap(), out.1) + }) + .unzip(); + + let arrow_fields = values + .iter() + .zip(_fields.iter()) + .map(|(arr, field)| { + ArrowField::new(field.name.clone(), arr.dtype().clone(), true) + }) + .collect(); + let arrow_array = Box::new(StructArray::new( + ArrowDataType::Struct(arrow_fields), + arr.len(), + values, + arr.validity().cloned(), + )) as ArrayRef; + + if pl_fields.is_none() { + pl_fields = Some( + _fields + .iter() + .zip(dtypes) + .map(|(field, dtype)| Field::new(field.name.clone(), dtype)) + .collect_vec(), + ) + } + + arrow_array }) - .unzip(); + .collect_vec(); - let arrow_fields = values - .iter() - .zip(_fields.iter()) - .map(|(arr, field)| { - ArrowField::new(field.name.clone(), arr.dtype().clone(), true) - }) - .collect(); - let arrow_array = Box::new(StructArray::new( - ArrowDataType::Struct(arrow_fields), - values, - arr.validity().cloned(), - )) as ArrayRef; - let polars_fields = _fields - .iter() - .zip(dtypes) - .map(|(field, dtype)| Field::new(field.name.clone(), dtype)) - .collect(); - (vec![arrow_array], DataType::Struct(polars_fields)) + (arrays, DataType::Struct(pl_fields.unwrap())) }) }, // Use Series architecture to convert nested logical types to physical. diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index b91df29a0a38..ace52993b8a1 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -358,11 +358,7 @@ impl SeriesTrait for SeriesWrap { Ok(Scalar::new(self.dtype().clone(), av)) } - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { Ok(Scalar::new(self.dtype().clone(), AnyValue::Null)) } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 30125ccc15b6..612505057eca 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -404,13 +404,9 @@ impl SeriesTrait for SeriesWrap { Ok(self.apply_scale(self.0.std_reduce(ddof))) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { self.0 - .quantile_reduce(quantile, interpol) + .quantile_reduce(quantile, method) .map(|v| self.apply_scale(v)) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 13b121aee0ca..803ca813aa1c 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -501,12 +501,8 @@ impl SeriesTrait for SeriesWrap { v.as_duration(self.0.time_unit()), )) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.0.quantile_reduce(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.0.quantile_reduce(quantile, method)?; let to = self.dtype().to_physical(); let v = v.value().cast(&to); Ok(Scalar::new( diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 24be56671d69..846e326d35b2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -365,9 +365,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] fn and_reduce(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9d8357a905bc..b2cb97e39b69 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -468,9 +468,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 805f06d86bac..d40e53d1a01e 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -232,7 +232,14 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.0._apply_fields(|s| s.reverse()).unwrap().into_series() + let validity = self + .rechunk_validity() + .map(|x| x.into_iter().rev().collect::()); + self.0 + ._apply_fields(|s| s.reverse()) + .unwrap() + .with_outer_validity(validity) + .into_series() } fn shift(&self, periods: i64) -> Series { diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index 1213c3346525..aa703fb533a1 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -44,7 +44,13 @@ impl Series { s.to_arrow(0, compat_level) }) .collect::>(); - StructArray::new(dt.to_arrow(compat_level), values, arr.validity().cloned()).boxed() + StructArray::new( + dt.to_arrow(compat_level), + arr.len(), + values, + arr.validity().cloned(), + ) + .boxed() }, // special list branch to // make sure that we recursively apply all logical types. @@ -79,6 +85,34 @@ impl Series { ); Box::new(arr) }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, width) => { + let ca = self.array().unwrap(); + let arr = ca.chunks[chunk_idx].clone(); + let arr = arr.as_any().downcast_ref::().unwrap(); + + let new_values = if let DataType::Null = &**inner { + arr.values().clone() + } else { + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + &inner.to_physical(), + ) + .cast_unchecked(inner) + .unwrap() + }; + + s.to_arrow(0, compat_level) + }; + + let dtype = + FixedSizeListArray::default_datatype(inner.to_arrow(compat_level), *width); + let arr = + FixedSizeListArray::new(dtype, arr.len(), new_values, arr.validity().cloned()); + Box::new(arr) + }, #[cfg(feature = "dtype-categorical")] dt @ (DataType::Categorical(_, ordering) | DataType::Enum(_, ordering)) => { let ca = self.categorical().unwrap(); diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 72cb3b67dc41..fadd2b4f570f 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -23,6 +23,7 @@ use arrow::offset::Offsets; pub use from::*; pub use iterator::{SeriesIter, SeriesPhysIter}; use num_traits::NumCast; +use polars_utils::itertools::Itertools; pub use series_trait::{IsSorted, *}; use crate::chunked_array::cast::CastOptions; @@ -257,7 +258,7 @@ impl Series { pub fn into_frame(self) -> DataFrame { // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names - unsafe { DataFrame::new_no_checks(vec![self.into()]) } + unsafe { DataFrame::new_no_checks(self.len(), vec![self.into()]) } } /// Rename series. @@ -298,11 +299,6 @@ impl Series { Self::try_from((name, array)) } - #[cfg(feature = "arrow_rs")] - pub fn from_arrow_rs(name: PlSmallStr, array: &dyn arrow_array::Array) -> PolarsResult { - Self::from_arrow(name, array.into()) - } - /// Shrink the capacity of this array to fit its length. pub fn shrink_to_fit(&mut self) { self._get_inner_mut().shrink_to_fit() @@ -595,15 +591,17 @@ impl Series { lhs.zip_with_same_type(mask, rhs.as_ref()) } - /// Cast a datelike Series to their physical representation. - /// Primitives remain unchanged + /// Converts a Series to their physical representation, if they have one, + /// otherwise the series is left unchanged. /// /// * Date -> Int32 - /// * Datetime-> Int64 + /// * Datetime -> Int64 + /// * Duration -> Int64 /// * Time -> Int64 /// * Categorical -> UInt32 /// * List(inner) -> List(physical of inner) - /// + /// * Array(inner) -> Array(physical of inner) + /// * Struct -> Struct with physical repr of each struct column pub fn to_physical_repr(&self) -> Cow { use DataType::*; match self.dtype() { @@ -623,6 +621,11 @@ impl Series { Cow::Owned(ca.physical().clone().into_series()) }, List(inner) => Cow::Owned(self.cast(&List(Box::new(inner.to_physical()))).unwrap()), + #[cfg(feature = "dtype-array")] + Array(inner, size) => Cow::Owned( + self.cast(&Array(Box::new(inner.to_physical()), *size)) + .unwrap(), + ), #[cfg(feature = "dtype-struct")] Struct(_) => { let arr = self.struct_().unwrap(); @@ -632,7 +635,8 @@ impl Series { .map(|s| s.to_physical_repr().into_owned()) .collect(); let mut ca = - StructChunked::from_series(self.name().clone(), fields.iter()).unwrap(); + StructChunked::from_series(self.name().clone(), arr.len(), fields.iter()) + .unwrap(); if arr.null_count() > 0 { ca.zip_outer_validity(arr); @@ -643,6 +647,75 @@ impl Series { } } + /// Attempts to convert a Series to dtype, only allowing conversions from + /// physical to logical dtypes--the inverse of to_physical_repr(). + /// + /// # Safety + /// When converting from UInt32 to Categorical it is not checked that the + /// values are in-bound for the categorical mapping. + pub unsafe fn to_logical_repr_unchecked(&self, dtype: &DataType) -> PolarsResult { + use DataType::*; + + let err = || { + Err( + polars_err!(ComputeError: "can't cast from {} to {} in to_logical_repr_unchecked", self.dtype(), dtype), + ) + }; + + match dtype { + dt if self.dtype() == dt => Ok(self.clone()), + #[cfg(feature = "dtype-date")] + Date => Ok(self.i32()?.clone().into_date().into_series()), + #[cfg(feature = "dtype-datetime")] + Datetime(u, z) => Ok(self + .i64()? + .clone() + .into_datetime(*u, z.clone()) + .into_series()), + #[cfg(feature = "dtype-duration")] + Duration(u) => Ok(self.i64()?.clone().into_duration(*u).into_series()), + #[cfg(feature = "dtype-time")] + Time => Ok(self.i64()?.clone().into_time().into_series()), + #[cfg(feature = "dtype-categorical")] + Categorical { .. } | Enum { .. } => { + Ok(CategoricalChunked::from_cats_and_dtype_unchecked( + self.u32()?.clone(), + dtype.clone(), + ) + .into_series()) + }, + List(inner) => { + if let List(self_inner) = self.dtype() { + if inner.to_physical() == **self_inner { + return self.cast(dtype); + } + } + err() + }, + #[cfg(feature = "dtype-struct")] + Struct(target_fields) => { + let ca = self.struct_().unwrap(); + if ca.struct_fields().len() != target_fields.len() { + return err(); + } + let fields = ca + .fields_as_series() + .iter() + .zip(target_fields) + .map(|(s, tf)| s.to_logical_repr_unchecked(tf.dtype())) + .try_collect_vec()?; + let mut result = + StructChunked::from_series(self.name().clone(), ca.len(), fields.iter())?; + if ca.null_count() > 0 { + result.zip_outer_validity(ca); + } + Ok(result.into_series()) + }, + + _ => err(), + } + } + /// Take by index if ChunkedArray contains a single chunk. /// /// # Safety @@ -872,8 +945,7 @@ impl Series { DataType::Categorical(Some(rv), _) | DataType::Enum(Some(rv), _) => match &**rv { RevMapping::Local(arr, _) => size += estimated_bytes_size(arr), RevMapping::Global(map, arr, _) => { - size += - map.capacity() * std::mem::size_of::() * 2 + estimated_bytes_size(arr); + size += map.capacity() * size_of::() * 2 + estimated_bytes_size(arr); }, }, _ => {}, @@ -931,7 +1003,7 @@ fn equal_outer_type(dtype: &DataType) -> bool { } } -impl<'a, T> AsRef> for dyn SeriesTrait + 'a +impl AsRef> for dyn SeriesTrait + '_ where T: 'static + PolarsDataType, { @@ -948,7 +1020,7 @@ where } } -impl<'a, T> AsMut> for dyn SeriesTrait + 'a +impl AsMut> for dyn SeriesTrait + '_ where T: 'static + PolarsDataType, { diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs index 78e8fb795f27..3a708a404a82 100644 --- a/crates/polars-core/src/series/ops/null.rs +++ b/crates/polars-core/src/series/ops/null.rs @@ -57,7 +57,7 @@ impl Series { .iter() .map(|fld| Series::full_null(fld.name().clone(), size, fld.dtype())) .collect::>(); - let ca = StructChunked::from_series(name, fields.iter()).unwrap(); + let ca = StructChunked::from_series(name, size, fields.iter()).unwrap(); if !fields.is_empty() { ca.with_outer_validity(Some(Bitmap::new_zeroed(size))) diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index 544754755e6e..85998aa54de3 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -1,8 +1,9 @@ use std::borrow::Cow; use arrow::array::*; +use arrow::bitmap::Bitmap; use arrow::legacy::kernels::list::array_to_unit_list; -use arrow::offset::Offsets; +use arrow::offset::{Offsets, OffsetsBuffer}; use polars_error::{polars_bail, polars_ensure, PolarsResult}; use polars_utils::format_tuple; @@ -11,16 +12,11 @@ use crate::datatypes::{DataType, ListChunked}; use crate::prelude::{IntoSeries, Series, *}; fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series { - let mut ca = match s.dtype() { - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - ListChunked::with_chunk(name, array_to_unit_list(s.array_ref(0).clone())) - }, - _ => ListChunked::from_chunk_iter( - name, - s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())), - ), - }; + let mut ca = ListChunked::from_chunk_iter( + name, + s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())), + ); + ca.set_inner_dtype(s.dtype().clone()); ca.set_fast_explode(); ca.into_series() @@ -56,6 +52,35 @@ impl Series { } } + /// TODO: Move this somewhere else? + pub fn list_offsets_and_validities_recursive( + &self, + ) -> (Vec>, Vec>) { + let mut offsets = vec![]; + let mut validities = vec![]; + + let mut s = self.rechunk(); + + while let DataType::List(_) = s.dtype() { + let ca = s.list().unwrap(); + offsets.push(ca.offsets().unwrap()); + validities.push(ca.rechunk_validity()); + s = ca.get_inner(); + } + + (offsets, validities) + } + + /// For ListArrays, recursively normalizes the offsets to begin from 0, and + /// slices excess length from the values array. + pub fn list_rechunk_and_trim_to_normalized_offsets(&self) -> Self { + if let Some(ca) = self.try_list() { + ca.rechunk_and_trim_to_normalized_offsets().into_series() + } else { + self.rechunk() + } + } + /// Convert the values of this Series to a ListChunked with a length of 1, /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`. pub fn implode(&self) -> PolarsResult { @@ -91,7 +116,7 @@ impl Series { InvalidOperation: "at least one dimension must be specified" ); - let leaf_array = self.get_leaf_array(); + let leaf_array = self.get_leaf_array().rechunk(); let size = leaf_array.len(); let mut total_dim_size = 1; @@ -259,7 +284,7 @@ impl Series { ); let mut builder = - get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone())?; + get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone()); let mut offset = 0u64; for _ in 0..rows { @@ -285,7 +310,7 @@ mod test { fn test_to_list() -> PolarsResult<()> { let s = Series::new("a".into(), &[1, 2, 3]); - let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone())?; + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()); builder.append_series(&s).unwrap(); let expected = builder.finish(); diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 2dc8de00dcd7..0352343baa82 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::borrow::Cow; use std::sync::RwLockReadGuard; +use arrow::bitmap::{Bitmap, MutableBitmap}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -327,6 +328,26 @@ pub trait SeriesTrait: /// Aggregate all chunks to a contiguous array of memory. fn rechunk(&self) -> Series; + fn rechunk_validity(&self) -> Option { + if self.chunks().len() == 1 { + return self.chunks()[0].validity().cloned(); + } + + if !self.has_nulls() || self.is_empty() { + return None; + } + + let mut bm = MutableBitmap::with_capacity(self.len()); + for arr in self.chunks() { + if let Some(v) = arr.validity() { + bm.extend_from_bitmap(v); + } else { + bm.extend_constant(arr.len(), true); + } + } + Some(bm.into()) + } + /// Drop all null values and return a new Series. fn drop_nulls(&self) -> Series { if self.null_count() == 0 { @@ -498,11 +519,7 @@ pub trait SeriesTrait: polars_bail!(opq = std, self._dtype()); } /// Get the quantile of the ChunkedArray as a new Series of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { polars_bail!(opq = quantile, self._dtype()); } /// Get the bitwise AND of the Series as a new Series of length 1, @@ -596,7 +613,7 @@ pub trait SeriesTrait: } } -impl<'a> (dyn SeriesTrait + 'a) { +impl (dyn SeriesTrait + '_) { pub fn unpack(&self) -> PolarsResult<&ChunkedArray> where N: 'static + PolarsDataType, diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs index b96ce61dab82..733ffcd60c08 100644 --- a/crates/polars-core/src/utils/flatten.rs +++ b/crates/polars-core/src/utils/flatten.rs @@ -17,8 +17,10 @@ pub fn flatten_df_iter(df: &DataFrame) -> impl Iterator + '_ { out.set_sorted_flag(s.is_sorted_flag()); Column::from(out) }) - .collect(); - let df = unsafe { DataFrame::new_no_checks(columns) }; + .collect::>(); + + let height = DataFrame::infer_height(&columns); + let df = unsafe { DataFrame::new_no_checks(height, columns) }; if df.is_empty() { None } else { diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 169c30fd0498..05171fe35cfe 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -1159,7 +1159,7 @@ pub fn coalesce_nulls_columns(a: &Column, b: &Column) -> (Column, Column) { } pub fn operation_exceeded_idxsize_msg(operation: &str) -> String { - if core::mem::size_of::() == core::mem::size_of::() { + if size_of::() == size_of::() { format!( "{} exceeded the maximum supported limit of {} rows. Consider installing 'polars-u64-idx'.", operation, diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 027e85886793..18b8c9ddd00a 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -498,3 +498,56 @@ fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> { }, } } + +pub fn merge_dtypes_many + Clone, D: AsRef>( + into_iter: I, +) -> PolarsResult { + let mut iter = into_iter.clone().into_iter(); + + let mut st = iter + .next() + .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype")) + .map(|d| d.as_ref().clone())?; + + for d in iter { + st = try_get_supertype(d.as_ref(), &st)?; + } + + match st { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(_), ordering) => { + // This merges the global rev maps with linear complexity. + // If we do a binary reduce, it would be quadratic. + let mut iter = into_iter.into_iter(); + let first_dt = iter.next().unwrap(); + let first_dt = first_dt.as_ref(); + let DataType::Categorical(Some(rm), _) = first_dt else { + unreachable!() + }; + + let mut merger = GlobalRevMapMerger::new(rm.clone()); + + for d in iter { + if let DataType::Categorical(Some(rm), _) = d.as_ref() { + merger.merge_map(rm)? + } + } + let rev_map = merger.finish(); + + Ok(DataType::Categorical(Some(rev_map), ordering)) + }, + // This would be quadratic if we do this with the binary `merge_dtypes`. + DataType::List(inner) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + _ => Ok(st), + } +} diff --git a/crates/polars-error/src/constants.rs b/crates/polars-error/src/constants.rs index 910c1e62a499..b6367e3abb2e 100644 --- a/crates/polars-error/src/constants.rs +++ b/crates/polars-error/src/constants.rs @@ -11,7 +11,7 @@ pub static FALSE: &str = "false"; #[cfg(not(feature = "python"))] pub static LENGTH_LIMIT_MSG: &str = - "polars' maximum length reached. Consider compiling with 'bigidx' feature."; + "Polars' maximum length reached. Consider compiling with 'bigidx' feature."; #[cfg(feature = "python")] pub static LENGTH_LIMIT_MSG: &str = - "polars' maximum length reached. Consider installing 'polars-u64-idx'."; + "Polars' maximum length reached. Consider installing 'polars-u64-idx'."; diff --git a/crates/polars-expr/Cargo.toml b/crates/polars-expr/Cargo.toml index 1b2b6063de9b..29aa34652146 100644 --- a/crates/polars-expr/Cargo.toml +++ b/crates/polars-expr/Cargo.toml @@ -12,6 +12,8 @@ description = "Physical expression implementation of the Polars project." ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } +hashbrown = { workspace = true } +num-traits = { workspace = true } once_cell = { workspace = true } polars-compute = { workspace = true } polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } @@ -19,8 +21,10 @@ polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } polars-ops = { workspace = true, features = ["chunked_ids"] } polars-plan = { workspace = true } +polars-row = { workspace = true } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } +rand = { workspace = true } rayon = { workspace = true } [features] @@ -72,5 +76,5 @@ bitwise = ["polars-core/bitwise", "polars-plan/bitwise"] round_series = ["polars-plan/round_series", "polars-ops/round_series"] is_between = ["polars-plan/is_between"] dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "temporal"] -propagate_nans = ["polars-plan/propagate_nans"] +propagate_nans = ["polars-plan/propagate_nans", "polars-ops/propagate_nans"] panic_on_schema = ["polars-plan/panic_on_schema"] diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index e1d2a1e716ab..f1cfa5251899 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -535,11 +535,13 @@ impl PartitionedAggregation for AggregationExpr { }; let mut count_s = series.agg_valid_count(groups); count_s.rename(PlSmallStr::from_static("__POLARS_COUNT")); - Ok( - StructChunked::from_series(new_name, [agg_s, count_s].iter()) - .unwrap() - .into_series(), + Ok(StructChunked::from_series( + new_name, + agg_s.len(), + [agg_s, count_s].iter(), ) + .unwrap() + .into_series()) } }, GroupByMethod::Implode => { @@ -713,19 +715,19 @@ impl PartitionedAggregation for AggregationExpr { pub struct AggQuantileExpr { pub(crate) input: Arc, pub(crate) quantile: Arc, - pub(crate) interpol: QuantileInterpolOptions, + pub(crate) method: QuantileMethod, } impl AggQuantileExpr { pub fn new( input: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { Self { input, quantile, - interpol, + method, } } @@ -748,7 +750,7 @@ impl PhysicalExpr for AggQuantileExpr { let input = self.input.evaluate(df, state)?; let quantile = self.get_quantile(df, state)?; input - .quantile_reduce(quantile, self.interpol) + .quantile_reduce(quantile, self.method) .map(|sc| sc.into_series(input.name().clone())) } #[allow(clippy::ptr_arg)] @@ -769,7 +771,7 @@ impl PhysicalExpr for AggQuantileExpr { let mut agg = unsafe { ac.flat_naive() .into_owned() - .agg_quantile(ac.groups(), quantile, self.interpol) + .agg_quantile(ac.groups(), quantile, self.method) }; agg.rename(keep_name); Ok(AggregationContext::from_agg_state( diff --git a/crates/polars-expr/src/expressions/alias.rs b/crates/polars-expr/src/expressions/alias.rs index 6b38d8dc8270..8d321263a3f5 100644 --- a/crates/polars-expr/src/expressions/alias.rs +++ b/crates/polars-expr/src/expressions/alias.rs @@ -59,6 +59,10 @@ impl PhysicalExpr for AliasExpr { )) } + fn is_literal(&self) -> bool { + self.physical_expr.is_literal() + } + fn is_scalar(&self) -> bool { self.physical_expr.is_scalar() } diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index a52cff4ca2f5..53579b763033 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; +use polars_core::chunked_array::builder::get_list_builder; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "parquet")] @@ -22,11 +23,11 @@ pub struct ApplyExpr { function_operates_on_scalar: bool, allow_rename: bool, pass_name_to_apply: bool, - input_schema: Option, + input_schema: SchemaRef, allow_threading: bool, check_lengths: bool, allow_group_aware: bool, - output_dtype: Option, + output_field: Field, } impl ApplyExpr { @@ -37,8 +38,8 @@ impl ApplyExpr { expr: Expr, options: FunctionOptions, allow_threading: bool, - input_schema: Option, - output_dtype: Option, + input_schema: SchemaRef, + output_field: Field, returns_scalar: bool, ) -> Self { #[cfg(debug_assertions)] @@ -61,30 +62,7 @@ impl ApplyExpr { allow_threading, check_lengths: options.check_lengths(), allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE), - output_dtype, - } - } - - pub(crate) fn new_minimal( - inputs: Vec>, - function: SpecialEq>, - expr: Expr, - collect_groups: ApplyOptions, - ) -> Self { - Self { - inputs, - function, - expr, - collect_groups, - function_returns_scalar: false, - function_operates_on_scalar: false, - allow_rename: false, - pass_name_to_apply: false, - input_schema: None, - allow_threading: true, - check_lengths: true, - allow_group_aware: true, - output_dtype: None, + output_field, } } @@ -122,11 +100,8 @@ impl ApplyExpr { Ok(ac) } - fn get_input_schema(&self, df: &DataFrame) -> Cow { - match &self.input_schema { - Some(schema) => Cow::Borrowed(schema.as_ref()), - None => Cow::Owned(df.schema()), - } + fn get_input_schema(&self, _df: &DataFrame) -> Cow { + Cow::Borrowed(self.input_schema.as_ref()) } /// Evaluates and flattens `Option` to `Column`. @@ -134,7 +109,7 @@ impl ApplyExpr { if let Some(out) = self.function.call_udf(inputs)? { Ok(out) } else { - let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap(); + let field = self.to_field(self.input_schema.as_ref()).unwrap(); Ok(Column::full_null(field.name().clone(), 1, field.dtype())) } } @@ -178,9 +153,11 @@ impl ApplyExpr { }; let ca: ListChunked = if self.allow_threading { - let dtype = match &self.output_dtype { - Some(dtype) if dtype.is_known() && !dtype.is_null() => Some(dtype.clone()), - _ => None, + let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null() + { + Some(self.output_field.dtype.clone()) + } else { + None }; let lst = agg.list().unwrap(); @@ -265,46 +242,51 @@ impl ApplyExpr { // Length of the items to iterate over. let len = iters[0].size_hint().0; - if len == 0 { - drop(iters); - - // Take the first aggregation context that as that is the input series. - let mut ac = acs.swap_remove(0); - ac.with_update_groups(UpdateGroups::No); - - let agg_state = if self.function_returns_scalar { - AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype)) - } else { - match self.collect_groups { - ApplyOptions::ElementWise | ApplyOptions::ApplyList => ac - .agg_state() - .map(|_| Series::new_empty(field.name().clone(), &field.dtype)), - ApplyOptions::GroupWise => AggState::AggregatedList(Series::new_empty( - field.name().clone(), - &DataType::List(Box::new(field.dtype.clone())), - )), - } - }; - - ac.with_agg_state(agg_state); - return Ok(ac); - } - - let ca = (0..len) - .map(|_| { + let ca = if len == 0 { + let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name); + for _ in 0..len { container.clear(); for iter in &mut iters { match iter.next().unwrap() { - None => return Ok(None), + None => { + builder.append_null(); + }, Some(s) => container.push(s.deep_clone().into()), } } - self.function + let out = self + .function .call_udf(&mut container) - .map(|r| r.map(|c| c.as_materialized_series().clone())) - }) - .collect::>()? - .with_name(field.name.clone()); + .map(|r| r.map(|c| c.as_materialized_series().clone()))?; + + builder.append_opt_series(out.as_ref())? + } + builder.finish() + } else { + // We still need this branch to materialize unknown/ data dependent types in eager. :( + (0..len) + .map(|_| { + container.clear(); + for iter in &mut iters { + match iter.next().unwrap() { + None => return Ok(None), + Some(s) => container.push(s.deep_clone().into()), + } + } + self.function + .call_udf(&mut container) + .map(|r| r.map(|c| c.as_materialized_series().clone())) + }) + .collect::>()? + .with_name(field.name.clone()) + }; + #[cfg(debug_assertions)] + { + let inner = ca.dtype().inner_dtype().unwrap(); + if field.dtype.is_known() { + assert_eq!(inner, &field.dtype); + } + } drop(iters); @@ -443,7 +425,7 @@ impl PhysicalExpr for ApplyExpr { self.expr.to_field(input_schema, Context::Default) } #[cfg(feature = "parquet")] - fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> { + fn as_stats_evaluator(&self) -> Option<&dyn StatsEvaluator> { let function = match &self.expr { Expr::Function { function, .. } => function, _ => return None, @@ -543,14 +525,6 @@ fn apply_multiple_elementwise<'a>( impl StatsEvaluator for ApplyExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { let read = self.should_read_impl(stats)?; - if ExecutionState::new().verbose() { - if read { - eprintln!("parquet file must be read, statistics not sufficient for predicate.") - } else { - eprintln!("parquet file can be skipped, the statistics were sufficient to apply the predicate.") - } - } - Ok(read) } } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index d0b00bf2ddac..23f50af45273 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -75,11 +75,8 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu let right_dt = right.dtype().cast_leaf(Float64); left.cast(&left_dt)? / right.cast(&right_dt)? }, - dt @ List(_) => { - let left_dt = dt.cast_leaf(Float64); - let right_dt = right.dtype().cast_leaf(Float64); - left.cast(&left_dt)? / right.cast(&right_dt)? - }, + List(_) => left / right, + _ if right.dtype().is_list() => left / right, _ => { if right.dtype().is_temporal() { return left / right; @@ -354,7 +351,7 @@ mod stats { use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), - Operator::NotEq => apply_operator_stats_eq(min_max, literal), + Operator::NotEq => apply_operator_stats_neq(min_max, literal), Operator::Gt => { // Literal is bigger than max value, selection needs all rows. C::gt(literal, min_max).map(|ca| ca.any()).unwrap_or(false) @@ -457,10 +454,6 @@ mod stats { impl StatsEvaluator for BinaryExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { - if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { - return Ok(true); - } - use Operator::*; match ( self.left.as_stats_evaluator(), diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs index 6bac214f140c..8a59d6c25ddb 100644 --- a/crates/polars-expr/src/expressions/column.rs +++ b/crates/polars-expr/src/expressions/column.rs @@ -9,11 +9,11 @@ use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExp pub struct ColumnExpr { name: PlSmallStr, expr: Expr, - schema: Option, + schema: SchemaRef, } impl ColumnExpr { - pub fn new(name: PlSmallStr, expr: Expr, schema: Option) -> Self { + pub fn new(name: PlSmallStr, expr: Expr, schema: SchemaRef) -> Self { Self { name, expr, schema } } } @@ -141,42 +141,37 @@ impl PhysicalExpr for ColumnExpr { Some(&self.expr) } fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { - let out = match &self.schema { - None => self.process_by_linear_search(df, state, false), - Some(schema) => { - match schema.get_full(&self.name) { - Some((idx, _, _)) => { - // check if the schema was correct - // if not do O(n) search - match df.get_columns().get(idx) { - Some(out) => self.process_by_idx( - out.as_materialized_series(), - state, - schema, - df, - true, - ), - None => { - // partitioned group_by special case - if let Some(schema) = state.get_schema() { - self.process_from_state_schema(df, state, &schema) - } else { - self.process_by_linear_search(df, state, true) - } - }, - } - }, - // in the future we will throw an error here - // now we do a linear search first as the lazy reported schema may still be incorrect - // in debug builds we panic so that it can be fixed when occurring + let out = match self.schema.get_full(&self.name) { + Some((idx, _, _)) => { + // check if the schema was correct + // if not do O(n) search + match df.get_columns().get(idx) { + Some(out) => self.process_by_idx( + out.as_materialized_series(), + state, + &self.schema, + df, + true, + ), None => { - if self.name.starts_with(CSE_REPLACED) { - return self.process_cse(df, schema); + // partitioned group_by special case + if let Some(schema) = state.get_schema() { + self.process_from_state_schema(df, state, &schema) + } else { + self.process_by_linear_search(df, state, true) } - self.process_by_linear_search(df, state, true) }, } }, + // in the future we will throw an error here + // now we do a linear search first as the lazy reported schema may still be incorrect + // in debug builds we panic so that it can be fixed when occurring + None => { + if self.name.starts_with(CSE_REPLACED) { + return self.process_cse(df, &self.schema); + } + self.process_by_linear_search(df, state, true) + }, }; self.check_external_context(out, state) } diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index 5c9fc86a9c27..c9b1c26b6a94 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -127,7 +127,7 @@ impl GatherExpr { let idx: IdxCa = match groups.as_ref() { GroupsProxy::Idx(groups) => { if groups.all().iter().zip(idx).any(|(g, idx)| match idx { - None => true, + None => false, Some(idx) => idx >= g.len() as IdxSize, }) { self.oob_err()?; @@ -148,7 +148,7 @@ impl GatherExpr { }, GroupsProxy::Slice { groups, .. } => { if groups.iter().zip(idx).any(|(g, idx)| match idx { - None => true, + None => false, Some(idx) => idx >= g[1], }) { self.oob_err()?; @@ -255,7 +255,7 @@ impl GatherExpr { idx.series().len(), groups.len(), ac.series().name().clone(), - )?; + ); let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); for (s, idx) in iter { diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs index 6b1d54d0ac13..b42851e49d2a 100644 --- a/crates/polars-expr/src/expressions/group_iter.rs +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -4,7 +4,7 @@ use polars_core::series::amortized_iter::AmortSeries; use super::*; -impl<'a> AggregationContext<'a> { +impl AggregationContext<'_> { pub(super) fn iter_groups( &mut self, keep_names: bool, diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 8a74033953dc..15550c517fe7 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -72,13 +72,6 @@ impl AggState { AggState::NotAggregated(s) => AggState::NotAggregated(func(s)?), }) } - - fn map(&self, func: F) -> Self - where - F: FnOnce(&Series) -> Series, - { - self.try_map(|s| Ok(func(s))).unwrap() - } } // lazy update strategy diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index c776e4b951dd..37600c71f06a 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -230,7 +230,7 @@ impl PhysicalExpr for TernaryExpr { // * `zip_with` can be called directly with the series // * mix of unit literals and AggregatedList // * `zip_with` can be called with the flat values after the offsets - // have been been checked for alignment + // have been checked for alignment let ac_target = non_literal_acs.first().unwrap(); let agg_state_out = match ac_target.agg_state() { diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index b47d1744f662..f843c0e83d95 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -588,8 +588,11 @@ impl PhysicalExpr for WindowExpr { .1, ) } else { - let df_right = unsafe { DataFrame::new_no_checks(keys) }; - let df_left = unsafe { DataFrame::new_no_checks(group_by_columns) }; + let df_right = + unsafe { DataFrame::new_no_checks_height_from_first(keys) }; + let df_left = unsafe { + DataFrame::new_no_checks_height_from_first(group_by_columns) + }; Ok(private_left_join_multiple_keys(&df_left, &df_right, true)?.1) } }; @@ -751,7 +754,7 @@ where unsafe { values.set_len(len) } ChunkedArray::new_vec(ca.name().clone(), values).into_series() } else { - // We don't use a mutable bitmap as bits will have have race conditions! + // We don't use a mutable bitmap as bits will have race conditions! // A single byte might alias if we write from single threads. let mut validity: Vec = vec![false; len]; let validity_ptr = validity.as_mut_ptr(); diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs new file mode 100644 index 000000000000..43091244c661 --- /dev/null +++ b/crates/polars-expr/src/groups/mod.rs @@ -0,0 +1,67 @@ +use std::any::Any; +use std::path::Path; + +use polars_core::prelude::*; +use polars_utils::aliases::PlRandomState; +use polars_utils::IdxSize; + +mod row_encoded; + +/// A Grouper maps keys to groups, such that duplicate keys map to the same group. +pub trait Grouper: Any + Send { + /// Creates a new empty Grouper similar to this one. + fn new_empty(&self) -> Box; + + /// Returns the number of groups in this Grouper. + fn num_groups(&self) -> IdxSize; + + /// Inserts the given keys into this Grouper, mutating groups_idxs such + /// that group_idxs[i] is the group index of keys[..][i]. + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec); + + /// Adds the given Grouper into this one, mutating groups_idxs such that + /// the ith group of other now has group index group_idxs[i] in self. + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec); + + /// Partitions this Grouper into the given number of partitions. + /// + /// Updates partition_idxs such that the ith group of self moves to partition + /// partition_idxs[i]. + /// + /// It is guaranteed that two equal keys in two independent partition_into + /// calls map to the same partition index if the seed and the number of + /// partitions is equal. + fn partition( + &self, + seed: u64, + num_partitions: usize, + partition_idxs: &mut Vec, + ) -> Vec>; + + /// Returns the keys in this Grouper in group order, that is the key for + /// group i is returned in row i. + fn get_keys_in_group_order(&self) -> DataFrame; + + /// Returns the keys in this Grouper, mutating group_idxs such that the ith + /// key returned corresponds to group group_idxs[i]. + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame; + + /// Stores this Grouper at the given path. + fn store_ooc(&self, _path: &Path) { + unimplemented!(); + } + + /// Loads this Grouper from the given path. + fn load_ooc(&mut self, _path: &Path) { + unimplemented!(); + } + + fn as_any(&self) -> &dyn Any; +} + +pub fn new_hash_grouper(key_schema: Arc, random_state: PlRandomState) -> Box { + Box::new(row_encoded::RowEncodedHashGrouper::new( + key_schema, + random_state, + )) +} diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs new file mode 100644 index 000000000000..1a2fd5209436 --- /dev/null +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -0,0 +1,251 @@ +use std::mem::MaybeUninit; + +use hashbrown::hash_table::{Entry, HashTable}; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_unordered; +use polars_row::EncodingField; +use polars_utils::aliases::PlRandomState; +use polars_utils::hashing::{folded_multiply, hash_to_partition}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; +use rand::Rng; + +use super::*; + +struct Group { + key_hash: u64, + key_offset: usize, + key_length: u32, + group_idx: IdxSize, +} + +impl Group { + unsafe fn key<'k>(&self, key_data: &'k [u8]) -> &'k [u8] { + key_data.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) + } +} + +#[derive(Default)] +pub struct RowEncodedHashGrouper { + key_schema: Arc, + table: HashTable, + key_data: Vec, + + // Used for computing canonical hashes. + random_state: PlRandomState, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl RowEncodedHashGrouper { + pub fn new(key_schema: Arc, random_state: PlRandomState) -> Self { + Self { + key_schema, + random_state, + seed: rand::random::() | 1, + ..Default::default() + } + } + + fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let num_groups = self.table.len(); + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |g| unsafe { hash == g.key_hash && key == g.key(&self.key_data) }, + |g| g.key_hash.wrapping_mul(self.seed), + ); + + match entry { + Entry::Occupied(e) => e.get().group_idx, + Entry::Vacant(e) => { + let group_idx: IdxSize = num_groups.try_into().unwrap(); + let group = Group { + key_hash: hash, + key_offset: self.key_data.len(), + key_length: key.len().try_into().unwrap(), + group_idx, + }; + self.key_data.extend(key); + e.insert(group); + group_idx + }, + } + } + + /// Insert a key, without checking that it is unique. + fn insert_key_unique(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let group_idx = self.table.len().try_into().unwrap(); + let group = Group { + key_hash: hash, + key_offset: self.key_data.len(), + key_length: key.len().try_into().unwrap(), + group_idx, + }; + self.key_data.extend(key); + self.table + .insert_unique(hash.wrapping_mul(self.seed), group, |g| { + g.key_hash.wrapping_mul(self.seed) + }); + group_idx + } + + fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { + let key_dtypes = self + .key_schema + .iter() + .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let fields = vec![EncodingField::new_unsorted(); key_dtypes.len()]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &key_dtypes) }; + + let cols = self + .key_schema + .iter() + .zip(key_columns) + .map(|((name, dt), col)| { + let s = Series::try_from((name.clone(), col)).unwrap(); + unsafe { s.to_logical_repr_unchecked(dt) } + .unwrap() + .into_column() + }) + .collect(); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl Grouper for RowEncodedHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new( + self.key_schema.clone(), + self.random_state.clone(), + )) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec) { + let series = keys + .get_columns() + .iter() + .map(|c| c.as_materialized_series().clone()) + .collect_vec(); + let keys_encoded = _get_rows_encoded_unordered(&series[..]) + .unwrap() + .into_array(); + assert!(keys_encoded.len() == keys[0].len()); + + group_idxs.clear(); + group_idxs.reserve(keys_encoded.len()); + for key in keys_encoded.values_iter() { + let hash = self.random_state.hash_one(key); + unsafe { + group_idxs.push_unchecked(self.insert_key(hash, key)); + } + } + } + + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec) { + let other = other.as_any().downcast_ref::().unwrap(); + + // TODO: cardinality estimation. + self.table + .reserve(other.table.len(), |g| g.key_hash.wrapping_mul(self.seed)); + + unsafe { + group_idxs.clear(); + group_idxs.reserve(other.table.len()); + let idx_out = group_idxs.spare_capacity_mut(); + for group in other.table.iter() { + let group_key = group.key(&other.key_data); + let new_idx = self.insert_key(group.key_hash, group_key); + *idx_out.get_unchecked_mut(group.group_idx as usize) = MaybeUninit::new(new_idx); + } + group_idxs.set_len(other.table.len()); + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.table.len()); + unsafe { + let out = key_rows.spare_capacity_mut(); + for group in &self.table { + *out.get_unchecked_mut(group.group_idx as usize) = + MaybeUninit::new(group.key(&self.key_data)); + } + key_rows.set_len(self.table.len()); + } + self.finalize_keys(key_rows) + } + + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame { + group_idxs.clear(); + group_idxs.reserve(self.table.len()); + self.finalize_keys( + self.table + .iter() + .map(|group| unsafe { + group_idxs.push(group.group_idx); + group.key(&self.key_data) + }) + .collect(), + ) + } + + fn partition( + &self, + seed: u64, + num_partitions: usize, + partition_idxs: &mut Vec, + ) -> Vec> { + assert!(num_partitions > 0); + + // Two-pass algorithm to prevent reallocations. + let mut partition_size = vec![(0, 0); num_partitions]; // (keys, bytes) + unsafe { + for group in self.table.iter() { + let ph = folded_multiply(group.key_hash, seed | 1); + let p_idx = hash_to_partition(ph, num_partitions); + let (p_keys, p_bytes) = partition_size.get_unchecked_mut(p_idx as usize); + *p_keys += 1; + *p_bytes += group.key_length as usize; + } + } + + let mut rng = rand::thread_rng(); + let mut partitions = partition_size + .into_iter() + .map(|(keys, bytes)| Self { + key_schema: self.key_schema.clone(), + table: HashTable::with_capacity(keys), + key_data: Vec::with_capacity(bytes), + random_state: self.random_state.clone(), + seed: rng.gen::() | 1, + }) + .collect_vec(); + + unsafe { + partition_idxs.clear(); + partition_idxs.reserve(self.table.len()); + let partition_idxs_out = partition_idxs.spare_capacity_mut(); + for group in self.table.iter() { + let ph = folded_multiply(group.key_hash, seed | 1); + let p_idx = hash_to_partition(ph, num_partitions); + let p = partitions.get_unchecked_mut(p_idx); + p.insert_key_unique(group.key_hash, group.key(&self.key_data)); + *partition_idxs_out.get_unchecked_mut(group.group_idx as usize) = + MaybeUninit::new(p_idx as IdxSize); + } + partition_idxs.set_len(self.table.len()); + } + + partitions.into_iter().map(|p| Box::new(p) as _).collect() + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 9981e47f1451..2778f4621dc2 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -1,4 +1,5 @@ mod expressions; +pub mod groups; pub mod planner; pub mod prelude; pub mod reduce; diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 8878b420af02..c4006de0c8ec 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -25,7 +25,7 @@ pub fn create_physical_expressions_from_irs( exprs: &[ExprIR], context: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, ) -> PolarsResult>> { create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker) @@ -35,7 +35,7 @@ pub(crate) fn create_physical_expressions_check_state( exprs: &[ExprIR], context: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, checker: F, ) -> PolarsResult>> @@ -57,7 +57,7 @@ pub(crate) fn create_physical_expressions_from_nodes( exprs: &[Node], context: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, ) -> PolarsResult>> { create_physical_expressions_from_nodes_check_state( @@ -69,7 +69,7 @@ pub(crate) fn create_physical_expressions_from_nodes_check_state( exprs: &[Node], context: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, checker: F, ) -> PolarsResult>> @@ -165,7 +165,7 @@ pub fn create_physical_expr( expr_ir: &ExprIR, ctxt: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, ) -> PolarsResult> { let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?; @@ -185,7 +185,7 @@ fn create_physical_expr_inner( expression: Node, ctxt: Context, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, state: &mut ExpressionConversionState, ) -> PolarsResult> { use AExpr::*; @@ -309,7 +309,7 @@ fn create_physical_expr_inner( Column(column) => Ok(Arc::new(ColumnExpr::new( column.clone(), node_to_expr(expression, expr_arena), - schema.cloned(), + schema.clone(), ))), Sort { expr, options } => { let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; @@ -402,7 +402,9 @@ fn create_physical_expr_inner( }, _ => { if let IRAggExpr::Quantile { - quantile, interpol, .. + quantile, + method: interpol, + .. } = agg { let quantile = @@ -410,22 +412,18 @@ fn create_physical_expr_inner( return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol))); } - let field = schema - .map(|schema| { - expr_arena.get(expression).to_field( - schema, - Context::Aggregation, - expr_arena, - ) - }) - .transpose()?; + let field = expr_arena.get(expression).to_field( + schema, + Context::Aggregation, + expr_arena, + )?; let groupby = GroupByMethod::from(agg.clone()); let agg_type = AggregationType { groupby, allow_threading: false, }; - Ok(Arc::new(AggregationExpr::new(input, agg_type, field))) + Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field)))) }, } }, @@ -475,12 +473,10 @@ fn create_physical_expr_inner( options, } => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_dtype = schema.and_then(|schema| { + let output_dtype = expr_arena .get(expression) - .to_dtype(schema, Context::Default, expr_arena) - .ok() - }); + .to_field(schema, Context::Default, expr_arena)?; let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); @@ -504,7 +500,7 @@ fn create_physical_expr_inner( node_to_expr(expression, expr_arena), *options, state.allow_threading, - schema.cloned(), + schema.clone(), output_dtype, is_scalar, ))) @@ -516,12 +512,10 @@ fn create_physical_expr_inner( .. } => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_dtype = schema.and_then(|schema| { + let output_field = expr_arena .get(expression) - .to_dtype(schema, Context::Default, expr_arena) - .ok() - }); + .to_field(schema, Context::Default, expr_arena)?; let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); // Will be reset in the function so get that here. @@ -544,8 +538,8 @@ fn create_physical_expr_inner( node_to_expr(expression, expr_arena), *options, state.allow_threading, - schema.cloned(), - output_dtype, + schema.clone(), + output_field, is_scalar, ))) }, @@ -570,11 +564,25 @@ fn create_physical_expr_inner( let function = SpecialEq::new(Arc::new( move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some), ) as Arc); - Ok(Arc::new(ApplyExpr::new_minimal( + + let field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; + Ok(Arc::new(ApplyExpr::new( vec![input], function, node_to_expr(expression, expr_arena), - ApplyOptions::GroupWise, + FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + cast_to_supertypes: None, + check_lengths: Default::default(), + flags: Default::default(), + }, + state.allow_threading, + schema.clone(), + field, + false, ))) }, Alias(input, name) => { diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs index af3f72733efd..55a4b325bda1 100644 --- a/crates/polars-expr/src/reduce/convert.rs +++ b/crates/polars-expr/src/reduce/convert.rs @@ -1,81 +1,50 @@ -use polars_core::error::feature_gated; +// use polars_core::error::feature_gated; use polars_plan::prelude::*; use polars_utils::arena::{Arena, Node}; -use super::len::LenReduce; -use super::mean::MeanReduce; -use super::min_max::{MaxReduce, MinReduce}; -#[cfg(feature = "propagate_nans")] -use super::nan_min_max::{NanMaxReduce, NanMinReduce}; -use super::sum::SumReduce; use super::*; +use crate::reduce::len::LenReduce; +use crate::reduce::mean::new_mean_reduction; +use crate::reduce::min_max::{new_max_reduction, new_min_reduction}; +use crate::reduce::sum::new_sum_reduction; +use crate::reduce::var_std::new_var_std_reduction; /// Converts a node into a reduction + its associated selector expression. pub fn into_reduction( node: Node, expr_arena: &mut Arena, schema: &Schema, -) -> PolarsResult<(Box, Node)> { +) -> PolarsResult<(Box, Node)> { let get_dt = |node| { expr_arena .get(node) - .to_dtype(schema, Context::Default, expr_arena) + .to_dtype(schema, Context::Default, expr_arena)? + .materialize_unknown() }; let out = match expr_arena.get(node) { AExpr::Agg(agg) => match agg { - IRAggExpr::Sum(input) => ( - Box::new(SumReduce::new(get_dt(*input)?)) as Box, - *input, - ), + IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?), *input), + IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?), *input), IRAggExpr::Min { propagate_nans, input, - } => { - let dt = get_dt(*input)?; - if *propagate_nans && dt.is_float() { - feature_gated!("propagate_nans", { - let out: Box = match dt { - DataType::Float32 => Box::new(NanMinReduce::::new()), - DataType::Float64 => Box::new(NanMinReduce::::new()), - _ => unreachable!(), - }; - (out, *input) - }) - } else { - ( - Box::new(MinReduce::new(dt.clone())) as Box, - *input, - ) - } - }, + } => (new_min_reduction(get_dt(*input)?, *propagate_nans), *input), IRAggExpr::Max { propagate_nans, input, - } => { - let dt = get_dt(*input)?; - if *propagate_nans && dt.is_float() { - feature_gated!("propagate_nans", { - let out: Box = match dt { - DataType::Float32 => Box::new(NanMaxReduce::::new()), - DataType::Float64 => Box::new(NanMaxReduce::::new()), - _ => unreachable!(), - }; - (out, *input) - }) - } else { - (Box::new(MaxReduce::new(dt.clone())) as _, *input) - } + } => (new_max_reduction(get_dt(*input)?, *propagate_nans), *input), + IRAggExpr::Var(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, false, *ddof), *input) }, - IRAggExpr::Mean(input) => { - let out: Box = Box::new(MeanReduce::new(get_dt(*input)?)); - (out, *input) + IRAggExpr::Std(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, true, *ddof), *input) }, - _ => unreachable!(), + _ => todo!(), }, AExpr::Len => { // Compute length on the first column, or if none exist we'll use // a zero-length dummy series. - let out: Box = Box::new(LenReduce::new()); + let out: Box = Box::new(LenReduce::default()); let expr = if let Some(first_column) = schema.iter_names().next() { expr_arena.add(AExpr::Column(first_column.as_str().into())) } else { diff --git a/crates/polars-expr/src/reduce/len.rs b/crates/polars-expr/src/reduce/len.rs index 1e11a505410d..57641b1a02b6 100644 --- a/crates/polars-expr/src/reduce/len.rs +++ b/crates/polars-expr/src/reduce/len.rs @@ -1,42 +1,76 @@ use polars_core::error::constants::LENGTH_LIMIT_MSG; use super::*; +use crate::reduce::partition::partition_vec; -#[derive(Clone)] -pub struct LenReduce {} +#[derive(Default)] +pub struct LenReduce { + groups: Vec, +} -impl LenReduce { - pub fn new() -> Self { - Self {} +impl GroupedReduction for LenReduce { + fn new_empty(&self) -> Box { + Box::new(Self::default()) } -} -impl Reduction for LenReduce { - fn new_reducer(&self) -> Box { - Box::new(LenReduceState { len: 0 }) + fn resize(&mut self, num_groups: IdxSize) { + self.groups.resize(num_groups as usize, 0); } -} -pub struct LenReduceState { - len: u64, -} + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + self.groups[group_idx as usize] += values.len() as u64; + Ok(()) + } -impl ReductionState for LenReduceState { - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - self.len += batch.len() as u64; + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + assert!(values.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for g in group_idxs.iter() { + *self.groups.get_unchecked_mut(*g as usize) += 1; + } + } Ok(()) } - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - self.len += other.len; + assert!(other.groups.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.groups.iter()) { + *self.groups.get_unchecked_mut(*g as usize) += v; + } + } Ok(()) } - fn finalize(&self) -> PolarsResult { - #[allow(clippy::useless_conversion)] - let as_idx: IdxSize = self.len.try_into().expect(LENGTH_LIMIT_MSG); - Ok(Scalar::new(IDX_DTYPE, as_idx.into())) + fn finalize(&mut self) -> PolarsResult { + let ca: IdxCa = self + .groups + .drain(..) + .map(|l| IdxSize::try_from(l).expect(LENGTH_LIMIT_MSG)) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } + + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition_vec(self.groups, partition_sizes, partition_idxs) + .into_iter() + .map(|groups| Box::new(Self { groups }) as _) + .collect() } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs index e8b19b342de6..4a8ec962f237 100644 --- a/crates/polars-expr/src/reduce/mean.rs +++ b/crates/polars-expr/src/reduce/mean.rs @@ -1,56 +1,164 @@ +use std::marker::PhantomData; + +use num_traits::{AsPrimitive, Zero}; +use polars_core::with_match_physical_numeric_polars_type; + use super::*; -#[derive(Clone)] -pub struct MeanReduce { - dtype: DataType, +pub fn new_mean_reduction(dtype: DataType) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolMeanReducer)), + _ if dtype.is_numeric() || dtype.is_temporal() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, NumMeanReducer::<$T>(PhantomData))) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new(dtype, NumMeanReducer::(PhantomData))), + _ => unimplemented!(), + } } -impl MeanReduce { - pub fn new(dtype: DataType) -> Self { - Self { dtype } +fn finish_output(values: Vec<(f64, usize)>, dtype: &DataType) -> Series { + match dtype { + DataType::Float32 => { + let ca: Float32Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64) as f32)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + dt if dt.is_numeric() => { + let ca: Float64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| s / c as f64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_prec, scale) => { + let inv_scale_factor = 1.0 / 10u128.pow(scale.unwrap() as u32) as f64; + let ca: Float64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| s / c as f64 * inv_scale_factor)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + #[cfg(feature = "dtype-datetime")] + DataType::Date => { + const MS_IN_DAY: i64 = 86_400_000; + let ca: Int64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64 * MS_IN_DAY as f64) as i64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_datetime(TimeUnit::Milliseconds, None).into_series() + }, + DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time => { + let ca: Int64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64) as i64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype).unwrap() + }, + _ => unimplemented!(), } } -impl Reduction for MeanReduce { - fn new_reducer(&self) -> Box { - Box::new(MeanReduceState { - dtype: self.dtype.clone(), - sum: 0.0, - count: 0, - }) +struct NumMeanReducer(PhantomData); +impl Clone for NumMeanReducer { + fn clone(&self) -> Self { + Self(PhantomData) } } -pub struct MeanReduceState { - dtype: DataType, - sum: f64, - count: u64, +impl Reducer for NumMeanReducer +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg + IntoSeries, +{ + type Dtype = T; + type Value = (f64, usize); + + #[inline(always)] + fn init(&self) -> Self::Value { + (0.0, 0) + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + a.0 += b.unwrap_or(T::Native::zero()).as_(); + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + v.0 += ChunkAgg::_sum_as_f64(ca); + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + Ok(finish_output(v, dtype)) + } } -impl ReductionState for MeanReduceState { - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let count = batch.len() as u64 - batch.null_count() as u64; - self.count += count; - self.sum += batch._sum_as_f64(); - Ok(()) +#[derive(Clone)] +struct BoolMeanReducer; + +impl Reducer for BoolMeanReducer { + type Dtype = BooleanType; + type Value = (usize, usize); + + #[inline(always)] + fn init(&self) -> Self::Value { + (0, 0) + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; } - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - self.sum += other.sum; - self.count += other.count; - Ok(()) + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + a.0 += b.unwrap_or(false) as usize; + a.1 += b.is_some() as usize; } - fn finalize(&self) -> PolarsResult { - let val = (self.count > 0).then(|| self.sum / self.count as f64); - Ok(polars_core::scalar::reduce::mean_reduce( - val, - self.dtype.clone(), - )) + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + v.0 += ca.sum().unwrap_or(0) as usize; + v.1 += ca.len() - ca.null_count(); } - fn as_any(&self) -> &dyn Any { - self + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + assert!(dtype == &DataType::Boolean); + let ca: Float64Chunked = v + .into_iter() + .map(|(s, c)| (c != 0).then(|| s as f64 / c as f64)) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) } } diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs index 27cf3d5b5727..de25d3efc927 100644 --- a/crates/polars-expr/src/reduce/min_max.rs +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -1,115 +1,492 @@ +use std::borrow::Cow; +use std::marker::PhantomData; + +use arrow::array::BooleanArray; +use arrow::bitmap::Bitmap; +use num_traits::Bounded; +use polars_core::with_match_physical_integer_polars_type; +#[cfg(feature = "propagate_nans")] +use polars_ops::prelude::nan_propagating_aggregate::ca_nan_agg; +use polars_utils::float::IsFloat; +use polars_utils::min_max::MinMax; + use super::*; +use crate::reduce::partition::partition_mask; -#[derive(Clone)] -pub struct MinReduce { - dtype: DataType, +pub fn new_min_reduction(dtype: DataType, propagate_nans: bool) -> Box { + use DataType::*; + use VecMaskGroupedReduction as VMGR; + match dtype { + Boolean => Box::new(BoolMinGroupedReduction::default()), + #[cfg(feature = "propagate_nans")] + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + #[cfg(feature = "propagate_nans")] + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMinReducer)), + _ if dtype.is_integer() || dtype.is_temporal() => { + with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), + _ => unimplemented!(), + } } -impl MinReduce { - pub fn new(dtype: DataType) -> Self { - Self { dtype } +pub fn new_max_reduction(dtype: DataType, propagate_nans: bool) -> Box { + use DataType::*; + use VecMaskGroupedReduction as VMGR; + match dtype { + Boolean => Box::new(BoolMaxGroupedReduction::default()), + #[cfg(feature = "propagate_nans")] + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + #[cfg(feature = "propagate_nans")] + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMaxReducer)), + _ if dtype.is_integer() || dtype.is_temporal() => { + with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), + _ => unimplemented!(), } } -impl Reduction for MinReduce { - fn new_reducer(&self) -> Box { - Box::new(MinReduceState { - value: Scalar::new(self.dtype.clone(), AnyValue::Null), - }) +// These two variants ignore nans. +struct Min(PhantomData); +struct Max(PhantomData); + +// These two variants propagate nans. +#[cfg(feature = "propagate_nans")] +struct NanMin(PhantomData); +#[cfg(feature = "propagate_nans")] +struct NanMax(PhantomData); + +impl NumericReduction for Min +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + if T::Native::is_float() { + T::Native::nan_value() + } else { + T::Native::max_value() + } + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::min_ignore_nan(a, b) } -} -struct MinReduceState { - value: Scalar, + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ChunkAgg::min(ca) + } } -impl MinReduceState { - fn update_with_value(&mut self, other: &AnyValue<'static>) { - // AnyValue uses total ordering, so NaN is greater than any value. - // This means other < self.value.value() already ignores incoming NaNs. - // We still must check if self is NaN and if so replace. - if self.value.is_null() - || !other.is_null() && (other < self.value.value() || self.value.is_nan()) - { - self.value.update(other.clone()); +impl NumericReduction for Max +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + if T::Native::is_float() { + T::Native::nan_value() + } else { + T::Native::min_value() } } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::max_ignore_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ChunkAgg::max(ca) + } } -impl ReductionState for MinReduceState { - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.min_reduce()?; - self.update_with_value(sc.value()); - Ok(()) +#[cfg(feature = "propagate_nans")] +impl NumericReduction for NanMin { + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + T::Native::max_value() } - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - self.update_with_value(other.value.value()); - Ok(()) + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::min_propagate_nan(a, b) } - fn finalize(&self) -> PolarsResult { - Ok(self.value.clone()) + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ca_nan_agg(ca, MinMax::min_propagate_nan) } +} - fn as_any(&self) -> &dyn Any { - self +#[cfg(feature = "propagate_nans")] +impl NumericReduction for NanMax { + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + T::Native::min_value() + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::max_propagate_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ca_nan_agg(ca, MinMax::max_propagate_nan) } } #[derive(Clone)] -pub struct MaxReduce { - dtype: DataType, +struct BinaryMinReducer; +#[derive(Clone)] +struct BinaryMaxReducer; + +impl Reducer for BinaryMinReducer { + type Dtype = BinaryType; + type Value = Option>; // TODO: evaluate SmallVec. + + fn init(&self) -> Self::Value { + None + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Owned(s.cast(&DataType::Binary).unwrap()) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref()) + } + + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>) { + match (a, b) { + (_, None) => {}, + (l @ None, Some(r)) => *l = Some(r.to_owned()), + (Some(l), Some(r)) => { + if l.as_slice() > r { + l.clear(); + l.extend_from_slice(r); + } + }, + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked) { + self.reduce_one(v, ca.min_binary()) + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) + } } -impl MaxReduce { - pub fn new(dtype: DataType) -> Self { - Self { dtype } +impl Reducer for BinaryMaxReducer { + type Dtype = BinaryType; + type Value = Option>; // TODO: evaluate SmallVec. + + #[inline(always)] + fn init(&self) -> Self::Value { + None + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Owned(s.cast(&DataType::Binary).unwrap()) + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref()) + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>) { + match (a, b) { + (_, None) => {}, + (l @ None, Some(r)) => *l = Some(r.to_owned()), + (Some(l), Some(r)) => { + if l.as_slice() < r { + l.clear(); + l.extend_from_slice(r); + } + }, + } + } + + #[inline(always)] + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked) { + self.reduce_one(v, ca.max_binary()) + } + + #[inline(always)] + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) } } -impl Reduction for MaxReduce { - fn new_reducer(&self) -> Box { - Box::new(MaxReduceState { - value: Scalar::new(self.dtype.clone(), AnyValue::Null), +#[derive(Default)] +pub struct BoolMinGroupedReduction { + values: MutableBitmap, + mask: MutableBitmap, +} + +impl GroupedReduction for BoolMinGroupedReduction { + fn new_empty(&self) -> Box { + Box::new(Self::default()) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, true); + self.mask.resize(num_groups as usize, false); + } + + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + let ca: &BooleanChunked = values.as_ref().as_ref(); + if !ca.all() { + self.values.set(group_idx as usize, false); + } + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + assert!(values.len() == group_idxs.len()); + let ca: &BooleanChunked = values.as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + self.values + .and_pos_unchecked(*g as usize, ov.unwrap_or(true)); + self.mask.or_pos_unchecked(*g as usize, ov.is_some()); + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.values.len() == other.values.len()); + assert!(self.mask.len() == other.mask.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + self.values.and_pos_unchecked(*g as usize, v); + self.mask.or_pos_unchecked(*g as usize, o); + } + } + Ok(()) + } + + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + let p_values = partition_mask(&self.values.freeze(), partition_sizes, partition_idxs); + let p_mask = partition_mask(&self.mask.freeze(), partition_sizes, partition_idxs); + p_values + .into_iter() + .zip(p_mask) + .map(|(values, mask)| { + Box::new(Self { + values: values.into_mut(), + mask: mask.into_mut(), + }) as _ + }) + .collect() + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + let arr = BooleanArray::from(v.freeze()) + .with_validity(Some(m.freeze())) + .boxed(); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + &DataType::Boolean, + ) }) } + + fn as_any(&self) -> &dyn Any { + self + } } -struct MaxReduceState { - value: Scalar, +#[derive(Default)] +pub struct BoolMaxGroupedReduction { + values: MutableBitmap, + mask: MutableBitmap, } -impl MaxReduceState { - fn update_with_value(&mut self, other: &AnyValue<'static>) { - // AnyValue uses total ordering, so NaN is greater than any value. - // This means other > self.value.value() might have false positives. - // We also must check if self is NaN and if so replace. - if self.value.is_null() - || !other.is_null() - && (other > self.value.value() && !other.is_nan() || self.value.is_nan()) - { - self.value.update(other.clone()); +impl GroupedReduction for BoolMaxGroupedReduction { + fn new_empty(&self) -> Box { + Box::new(Self::default()) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, false); + self.mask.resize(num_groups as usize, false); + } + + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + let ca: &BooleanChunked = values.as_ref().as_ref(); + if ca.any() { + self.values.set(group_idx as usize, true); } + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) } -} -impl ReductionState for MaxReduceState { - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.max_reduce()?; - self.update_with_value(sc.value()); + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + assert!(values.len() == group_idxs.len()); + let ca: &BooleanChunked = values.as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + self.values + .or_pos_unchecked(*g as usize, ov.unwrap_or(false)); + self.mask.or_pos_unchecked(*g as usize, ov.is_some()); + } + } Ok(()) } - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - self.update_with_value(other.value.value()); + assert!(other.values.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + self.values.or_pos_unchecked(*g as usize, v); + self.mask.or_pos_unchecked(*g as usize, o); + } + } Ok(()) } - fn finalize(&self) -> PolarsResult { - Ok(self.value.clone()) + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + let arr = BooleanArray::from(v.freeze()) + .with_validity(Some(m.freeze())) + .boxed(); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + &DataType::Boolean, + ) + }) + } + + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + let p_values = partition_mask(&self.values.freeze(), partition_sizes, partition_idxs); + let p_mask = partition_mask(&self.mask.freeze(), partition_sizes, partition_idxs); + p_values + .into_iter() + .zip(p_mask) + .map(|(values, mask)| { + Box::new(Self { + values: values.into_mut(), + mask: mask.into_mut(), + }) as _ + }) + .collect() } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs index 26f9749b4479..bfe4cb56417b 100644 --- a/crates/polars-expr/src/reduce/mod.rs +++ b/crates/polars-expr/src/reduce/mod.rs @@ -2,39 +2,421 @@ mod convert; mod len; mod mean; mod min_max; -#[cfg(feature = "propagate_nans")] -mod nan_min_max; +mod partition; mod sum; +mod var_std; use std::any::Any; +use std::borrow::Cow; +use std::marker::PhantomData; +use arrow::array::{Array, PrimitiveArray, StaticArray}; +use arrow::bitmap::{Bitmap, MutableBitmap}; pub use convert::into_reduction; use polars_core::prelude::*; -pub trait Reduction: Send { - /// Create a new reducer for this Reduction. - fn new_reducer(&self) -> Box; -} +/// A reduction with groups. +/// +/// Each group has its own reduction state that values can be aggregated into. +pub trait GroupedReduction: Any + Send { + /// Returns a new empty reduction. + fn new_empty(&self) -> Box; + + /// Resizes this GroupedReduction to the given number of groups. + /// + /// While not an actual member of the trait, the safety preconditions below + /// refer to self.num_groups() as given by the last call of this function. + fn resize(&mut self, num_groups: IdxSize); -pub trait ReductionState: Any + Send { - /// Adds the given series into the reduction. - fn update(&mut self, batch: &Series) -> PolarsResult<()>; + /// Updates the specified group with the given values. + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()>; - /// Adds the elements of the given series at the given indices into the reduction. + /// Updates this GroupedReduction with new values. values[i] should + /// be added to reduction self[group_idxs[i]]. /// /// # Safety - /// Implementations may elide bound checks. - unsafe fn update_gathered(&mut self, batch: &Series, idx: &[IdxSize]) -> PolarsResult<()> { - let batch = batch.take_unchecked_from_slice(idx); - self.update(&batch) - } + /// group_idxs[i] < self.num_groups() for all i. + unsafe fn update_groups(&mut self, values: &Series, group_idxs: &[IdxSize]) + -> PolarsResult<()>; + + /// Combines this GroupedReduction with another. Group other[i] + /// should be combined into group self[group_idxs[i]]. + /// + /// # Safety + /// group_idxs[i] < self.num_groups() for all i. + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()>; - /// Combines this reduction with another. - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()>; + /// Partitions this GroupedReduction into several partitions. + /// + /// The ith group of this GroupedReduction should becomes the group_idxs[i] + /// group in partition partition_idxs[i]. + /// + /// # Safety + /// partitions_idxs[i] < partition_sizes.len() for all i. + /// group_idxs[i] < partition_sizes[partition_idxs[i]] for all i. + /// Each partition p has an associated set of group_idxs, this set contains + /// 0..partition_size[p] exactly once. + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec>; - /// Returns a final result from the reduction. - fn finalize(&self) -> PolarsResult; + /// Returns the finalized value per group as a Series. + /// + /// After this operation the number of groups is reset to 0. + fn finalize(&mut self) -> PolarsResult; - /// Returns this ReductionState as a dyn Any. + /// Returns this GroupedReduction as a dyn Any. fn as_any(&self) -> &dyn Any; } + +// Helper traits used in the VecGroupedReduction and VecMaskGroupedReduction to +// reduce code duplication. +pub trait Reducer: Send + Sync + Clone + 'static { + type Dtype: PolarsDataType; + type Value: Clone + Send + Sync + 'static; + fn init(&self) -> Self::Value; + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Borrowed(s) + } + fn combine(&self, a: &mut Self::Value, b: &Self::Value); + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + ); + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray); + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult; +} + +pub trait NumericReduction: Send + Sync + 'static { + type Dtype: PolarsNumericType; + fn init() -> ::Native; + fn combine( + a: ::Native, + b: ::Native, + ) -> ::Native; + fn reduce_ca( + ca: &ChunkedArray, + ) -> Option<::Native>; +} + +struct NumReducer(PhantomData); +impl NumReducer { + fn new() -> Self { + Self(PhantomData) + } +} +impl Clone for NumReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumReducer { + type Dtype = ::Dtype; + type Value = <::Dtype as PolarsNumericType>::Native; + + #[inline(always)] + fn init(&self) -> Self::Value { + ::init() + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + *a = ::combine(*a, *b); + } + + #[inline(always)] + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + ) { + if let Some(b) = b { + *a = ::combine(*a, b); + } + } + + #[inline(always)] + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + if let Some(r) = ::reduce_ca(ca) { + *v = ::combine(*v, r); + } + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + let arr = Box::new(PrimitiveArray::::from_vec(v).with_validity(m)); + Ok(unsafe { Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], dtype) }) + } +} + +pub struct VecGroupedReduction { + values: Vec, + in_dtype: DataType, + reducer: R, +} + +impl VecGroupedReduction { + fn new(in_dtype: DataType, reducer: R) -> Self { + Self { + values: Vec::new(), + in_dtype, + reducer, + } + } +} + +impl GroupedReduction for VecGroupedReduction +where + R: Reducer, +{ + fn new_empty(&self) -> Box { + Box::new(Self { + values: Vec::new(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, self.reducer.init()); + } + + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + let values = self.reducer.cast_series(values); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca); + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + assert!(values.len() == group_idxs.len()); + let values = self.reducer.cast_series(values); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + if values.has_nulls() { + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, ov); + } + } else { + let mut offset = 0; + for arr in ca.downcast_iter() { + let subgroup = &group_idxs[offset..offset + arr.len()]; + for (g, v) in subgroup.iter().zip(arr.values_iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, Some(v)); + } + offset += arr.len(); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(group_idxs.len() == other.values.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.values.iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + } + } + Ok(()) + } + + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec(self.values, partition_sizes, partition_idxs) + .into_iter() + .map(|values| { + Box::new(Self { + values, + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) as _ + }) + .collect() + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + self.reducer.finish(v, None, &self.in_dtype) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +pub struct VecMaskGroupedReduction { + values: Vec, + mask: MutableBitmap, + in_dtype: DataType, + reducer: R, +} + +impl VecMaskGroupedReduction { + fn new(in_dtype: DataType, reducer: R) -> Self { + Self { + values: Vec::new(), + mask: MutableBitmap::new(), + in_dtype, + reducer, + } + } +} + +impl GroupedReduction for VecMaskGroupedReduction +where + R: Reducer, +{ + fn new_empty(&self) -> Box { + Box::new(Self { + values: Vec::new(), + mask: MutableBitmap::new(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, self.reducer.init()); + self.mask.resize(num_groups as usize, false); + } + + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &self.in_dtype); + let values = values.to_physical_repr(); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca); + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &self.in_dtype); + assert!(values.len() == group_idxs.len()); + let values = values.to_physical_repr(); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + if let Some(v) = ov { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, Some(v)); + self.mask.set_unchecked(*g as usize, true); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(group_idxs.len() == other.values.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + if o { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + self.mask.set_unchecked(*g as usize, true); + } + } + } + Ok(()) + } + + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec_mask( + self.values, + &self.mask.freeze(), + partition_sizes, + partition_idxs, + ) + .into_iter() + .map(|(values, mask)| { + Box::new(Self { + values, + mask: mask.into_mut(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) as _ + }) + .collect() + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + self.reducer.finish(v, Some(m.freeze()), &self.in_dtype) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/nan_min_max.rs b/crates/polars-expr/src/reduce/nan_min_max.rs deleted file mode 100644 index 4a42ce37d3a5..000000000000 --- a/crates/polars-expr/src/reduce/nan_min_max.rs +++ /dev/null @@ -1,141 +0,0 @@ -use std::marker::PhantomData; - -use polars_compute::min_max::MinMaxKernel; -use polars_core::datatypes::PolarsFloatType; -use polars_utils::min_max::MinMax; - -use super::*; - -#[derive(Clone)] -pub struct NanMinReduce { - _phantom: PhantomData, -} - -impl NanMinReduce { - pub fn new() -> Self { - Self { - _phantom: PhantomData, - } - } -} - -impl Reduction for NanMinReduce -where - F::Array: for<'a> MinMaxKernel = F::Native>, -{ - fn new_reducer(&self) -> Box { - Box::new(NanMinReduceState:: { value: None }) - } -} - -struct NanMinReduceState { - value: Option, -} - -impl NanMinReduceState { - fn update_with_value(&mut self, other: Option) { - if let Some(other) = other { - if let Some(value) = self.value { - self.value = Some(MinMax::min_propagate_nan(value, other)); - } else { - self.value = Some(other); - } - } - } -} - -impl ReductionState for NanMinReduceState -where - F::Array: for<'a> MinMaxKernel = F::Native>, -{ - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let ca = batch.unpack::().unwrap(); - let reduced = ca - .downcast_iter() - .filter_map(MinMaxKernel::min_propagate_nan_kernel) - .reduce(MinMax::min_propagate_nan); - self.update_with_value(reduced); - Ok(()) - } - - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - self.update_with_value(other.value); - Ok(()) - } - - fn finalize(&self) -> PolarsResult { - Ok(Scalar::new(F::get_dtype(), AnyValue::from(self.value))) - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -#[derive(Clone)] -pub struct NanMaxReduce { - _phantom: PhantomData, -} - -impl NanMaxReduce { - pub fn new() -> Self { - Self { - _phantom: PhantomData, - } - } -} - -impl Reduction for NanMaxReduce -where - F::Array: for<'a> MinMaxKernel = F::Native>, -{ - fn new_reducer(&self) -> Box { - Box::new(NanMaxReduceState:: { value: None }) - } -} - -struct NanMaxReduceState { - value: Option, -} - -impl NanMaxReduceState { - fn update_with_value(&mut self, other: Option) { - if let Some(other) = other { - if let Some(value) = self.value { - self.value = Some(MinMax::max_propagate_nan(value, other)); - } else { - self.value = Some(other); - } - } - } -} - -impl ReductionState for NanMaxReduceState -where - F::Array: for<'a> MinMaxKernel = F::Native>, -{ - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let ca = batch.unpack::().unwrap(); - let reduced = ca - .downcast_iter() - .filter_map(MinMaxKernel::max_propagate_nan_kernel) - .reduce(MinMax::max_propagate_nan); - self.update_with_value(reduced); - Ok(()) - } - - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - self.update_with_value(other.value); - Ok(()) - } - - fn finalize(&self) -> PolarsResult { - Ok(Scalar::new(F::get_dtype(), AnyValue::from(self.value))) - } - - fn as_any(&self) -> &dyn Any { - self - } -} diff --git a/crates/polars-expr/src/reduce/partition.rs b/crates/polars-expr/src/reduce/partition.rs new file mode 100644 index 000000000000..0152035879bd --- /dev/null +++ b/crates/polars-expr/src/reduce/partition.rs @@ -0,0 +1,105 @@ +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; +use polars_utils::IdxSize; + +/// Partitions this Vec into multiple Vecs. +/// +/// # Safety +/// partitions_idxs[i] < partition_sizes.len() for all i. +/// idx_in_partition[i] < partition_sizes[partition_idxs[i]] for all i. +/// Each partition p has an associated set of idx_in_partition, this set +/// contains 0..partition_size[p] exactly once. +pub unsafe fn partition_vec( + v: Vec, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec> { + assert!(partition_idxs.len() == v.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| Vec::::with_capacity(*sz as usize)) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for (i, val) in v.into_iter().enumerate() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + debug_assert!(p_idx < partitions.len()); + let p = partitions.get_unchecked_mut(p_idx); + p.push_unchecked(val); + } + + for (p, sz) in partitions.iter_mut().zip(partition_sizes) { + p.set_len(*sz as usize); + } + } + + partitions +} + +/// # Safety +/// Same as partition_vec. +pub unsafe fn partition_mask( + m: &Bitmap, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec { + assert!(partition_idxs.len() == m.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| BitmapBuilder::with_capacity(*sz as usize)) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for i in 0..m.len() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + let p = partitions.get_unchecked_mut(p_idx); + p.push_unchecked(m.get_bit_unchecked(i)); + } + } + + partitions +} + +/// A fused loop of partition_vec and partition_mask. +/// # Safety +/// Same as partition_vec. +pub unsafe fn partition_vec_mask( + v: Vec, + m: &Bitmap, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec<(Vec, BitmapBuilder)> { + assert!(partition_idxs.len() == v.len()); + assert!(m.len() == v.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| { + ( + Vec::::with_capacity(*sz as usize), + BitmapBuilder::with_capacity(*sz as usize), + ) + }) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for (i, val) in v.into_iter().enumerate() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + let (pv, pm) = partitions.get_unchecked_mut(p_idx); + pv.push_unchecked(val); + pm.push_unchecked(m.get_bit_unchecked(i)); + } + + for (p, sz) in partitions.iter_mut().zip(partition_sizes) { + p.0.set_len(*sz as usize); + } + } + + partitions +} diff --git a/crates/polars-expr/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs index 0f1d094ded3f..466d5ffb9f9d 100644 --- a/crates/polars-expr/src/reduce/sum.rs +++ b/crates/polars-expr/src/reduce/sum.rs @@ -1,58 +1,157 @@ -use polars_core::prelude::{AnyValue, DataType}; +use std::borrow::Cow; + +use arrow::array::PrimitiveArray; +use num_traits::Zero; use super::*; -#[derive(Clone)] -pub struct SumReduce { - dtype: DataType, +pub struct SumReduce { + sums: Vec, + in_dtype: DataType, } -impl SumReduce { - pub fn new(dtype: DataType) -> Self { - // We cast small dtypes up in the sum, we must also do this when - // returning the empty sum to be consistent. - use DataType::*; - let dtype = match dtype { - Boolean => IDX_DTYPE, - Int8 | UInt8 | Int16 | UInt16 => Int64, - dt => dt, - }; - Self { dtype } +impl SumReduce { + fn new(in_dtype: DataType) -> Self { + SumReduce { + sums: Vec::new(), + in_dtype, + } } } -impl Reduction for SumReduce { - fn new_reducer(&self) -> Box { - let value = Scalar::new(self.dtype.clone(), AnyValue::zero_sum(&self.dtype)); - Box::new(SumReduceState { value }) +pub fn new_sum_reduction(dtype: DataType) -> Box { + use DataType::*; + match dtype { + Boolean => Box::new(SumReduce::::new(dtype)), + Int8 | UInt8 | Int16 | UInt16 => Box::new(SumReduce::::new(dtype)), + UInt32 => Box::new(SumReduce::::new(dtype)), + UInt64 => Box::new(SumReduce::::new(dtype)), + Int32 => Box::new(SumReduce::::new(dtype)), + Int64 => Box::new(SumReduce::::new(dtype)), + Float32 => Box::new(SumReduce::::new(dtype)), + Float64 => Box::new(SumReduce::::new(dtype)), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(SumReduce::::new(dtype)), + Duration(_) => Box::new(SumReduce::::new(dtype)), + _ => unimplemented!(), } } -struct SumReduceState { - value: Scalar, +fn cast_sum_input<'a>(s: &'a Series, dt: &DataType) -> PolarsResult> { + use DataType::*; + match dt { + Boolean => Ok(Cow::Owned(s.cast(&IDX_DTYPE)?)), + Int8 | UInt8 | Int16 | UInt16 => Ok(Cow::Owned(s.cast(&Int64)?)), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Ok(Cow::Owned( + s.decimal().unwrap().physical().clone().into_series(), + )), + #[cfg(feature = "dtype-duration")] + Duration(_) => Ok(Cow::Owned( + s.duration().unwrap().physical().clone().into_series(), + )), + _ => Ok(Cow::Borrowed(s)), + } } -impl SumReduceState { - fn add_value(&mut self, other: &AnyValue<'_>) { - self.value.update(self.value.value().add(other)); +fn out_dtype(in_dtype: &DataType) -> DataType { + use DataType::*; + match in_dtype { + Boolean => IDX_DTYPE, + Int8 | UInt8 | Int16 | UInt16 => Int64, + dt => dt.clone(), } } -impl ReductionState for SumReduceState { - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let reduced = batch.sum_reduce()?; - self.add_value(reduced.value()); +impl GroupedReduction for SumReduce +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg + IntoSeries, +{ + fn new_empty(&self) -> Box { + Box::new(Self { + sums: Vec::new(), + in_dtype: self.in_dtype.clone(), + }) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.sums.resize(num_groups as usize, T::Native::zero()); + } + + fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &self.in_dtype); + let values = cast_sum_input(values, &self.in_dtype)?; + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + self.sums[group_idx as usize] += ChunkAgg::sum(ca).unwrap_or(T::Native::zero()); Ok(()) } - fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + unsafe fn update_groups( + &mut self, + values: &Series, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &self.in_dtype); + let values = cast_sum_input(values, &self.in_dtype)?; + assert!(values.len() == group_idxs.len()); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(ca.iter()) { + *self.sums.get_unchecked_mut(*g as usize) += v.unwrap_or(T::Native::zero()); + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - self.add_value(other.value.value()); + assert!(self.in_dtype == other.in_dtype); + assert!(other.sums.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.sums.iter()) { + *self.sums.get_unchecked_mut(*g as usize) += *v; + } + } Ok(()) } - fn finalize(&self) -> PolarsResult { - Ok(self.value.clone()) + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec(self.sums, partition_sizes, partition_idxs) + .into_iter() + .map(|sums| { + Box::new(Self { + sums, + in_dtype: self.in_dtype.clone(), + }) as _ + }) + .collect() + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.sums); + let arr = Box::new(PrimitiveArray::::from_vec(v)); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + &out_dtype(&self.in_dtype), + ) + }) } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-expr/src/reduce/var_std.rs b/crates/polars-expr/src/reduce/var_std.rs new file mode 100644 index 000000000000..7993c06e6a4e --- /dev/null +++ b/crates/polars-expr/src/reduce/var_std.rs @@ -0,0 +1,168 @@ +use std::marker::PhantomData; + +use num_traits::AsPrimitive; +use polars_compute::var_cov::VarState; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub fn new_var_std_reduction(dtype: DataType, is_std: bool, ddof: u8) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolVarStdReducer { is_std, ddof })), + _ if dtype.is_numeric() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, VarStdReducer::<$T> { + is_std, + ddof, + needs_cast: false, + _phantom: PhantomData, + })) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new( + dtype, + VarStdReducer:: { + is_std, + ddof, + needs_cast: true, + _phantom: PhantomData, + }, + )), + Duration(..) => todo!(), + _ => unimplemented!(), + } +} + +struct VarStdReducer { + is_std: bool, + ddof: u8, + needs_cast: bool, + _phantom: PhantomData, +} + +impl Clone for VarStdReducer { + fn clone(&self) -> Self { + Self { + is_std: self.is_std, + ddof: self.ddof, + needs_cast: self.needs_cast, + _phantom: PhantomData, + } + } +} + +impl Reducer for VarStdReducer { + type Dtype = T; + type Value = VarState; + + fn init(&self) -> Self::Value { + VarState::default() + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + if self.needs_cast { + Cow::Owned(s.cast(&DataType::Float64).unwrap()) + } else { + Cow::Borrowed(s) + } + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.combine(b) + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + if let Some(x) = b { + a.add_one(x.as_()); + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + for arr in ca.downcast_iter() { + v.combine(&polars_compute::var_cov::var(arr)) + } + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|s| { + let var = s.finalize(self.ddof); + if self.is_std { + var.map(f64::sqrt) + } else { + var + } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} + +#[derive(Clone)] +struct BoolVarStdReducer { + is_std: bool, + ddof: u8, +} + +impl Reducer for BoolVarStdReducer { + type Dtype = BooleanType; + type Value = (usize, usize); + + fn init(&self) -> Self::Value { + (0, 0) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + a.0 += b.unwrap_or(false) as usize; + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + v.0 += ca.sum().unwrap_or(0) as usize; + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|v| { + if v.1 <= self.ddof as usize { + return None; + } + + let sum = v.0 as f64; // Both the sum and sum-of-squares, letting us simplify. + let n = v.1; + let var = sum * (1.0 - sum / n as f64) / ((n - self.ddof as usize) as f64); + if self.is_std { + Some(var.sqrt()) + } else { + Some(var) + } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} diff --git a/crates/polars-expr/src/state/node_timer.rs b/crates/polars-expr/src/state/node_timer.rs index 48aa65e12c17..c3114d3029cd 100644 --- a/crates/polars-expr/src/state/node_timer.rs +++ b/crates/polars-expr/src/state/node_timer.rs @@ -57,8 +57,9 @@ impl NodeTimer { let mut end = end.into_inner(); end.rename(PlSmallStr::from_static("end")); + let height = nodes_s.len(); let columns = vec![nodes_s, start.into_column(), end.into_column()]; - let df = unsafe { DataFrame::new_no_checks(columns) }; + let df = unsafe { DataFrame::new_no_checks(height, columns) }; df.sort(vec!["start"], SortMultipleOptions::default()) } } diff --git a/crates/polars-ffi/src/version_0.rs b/crates/polars-ffi/src/version_0.rs index 3cffd4425045..504f6cc126d1 100644 --- a/crates/polars-ffi/src/version_0.rs +++ b/crates/polars-ffi/src/version_0.rs @@ -132,7 +132,7 @@ impl CallerContext { self.bitflags |= 1 << k } - /// Parallelism is done by polars' main engine, the plugin should not run run its own parallelism. + /// Parallelism is done by polars' main engine, the plugin should not run its own parallelism. /// If this is `false`, the plugin could use parallelism without (much) contention with polars /// parallelism strategies. pub fn parallel(&self) -> bool { diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index ca3d313e08ae..2e0a51acc9e4 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -37,6 +37,7 @@ num-traits = { workspace = true } object_store = { workspace = true, optional = true } once_cell = { workspace = true } percent-encoding = { workspace = true } +pyo3 = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true } reqwest = { workspace = true, optional = true } @@ -108,7 +109,6 @@ async = [ "futures", "tokio", "tokio-util", - "arrow/io_ipc_write_async", "polars-error/regex", "polars-parquet?/async", ] @@ -130,7 +130,7 @@ gcp = ["object_store/gcp", "cloud"] http = ["object_store/http", "cloud"] temporal = ["dtype-datetime", "dtype-date", "dtype-time"] simd = [] -python = ["polars-error/python"] +python = ["pyo3", "polars-error/python", "polars-utils/python"] [package.metadata.docs.rs] all-features = true diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs new file mode 100644 index 000000000000..e6de837488c1 --- /dev/null +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -0,0 +1,738 @@ +use std::fmt::Debug; +use std::future::Future; +use std::hash::Hash; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +#[cfg(feature = "aws")] +pub use object_store::aws::AwsCredential; +#[cfg(feature = "azure")] +pub use object_store::azure::AzureCredential; +#[cfg(feature = "gcp")] +pub use object_store::gcp::GcpCredential; +use polars_core::config; +use polars_error::{polars_bail, PolarsResult}; +#[cfg(feature = "python")] +use polars_utils::python_function::PythonFunction; +#[cfg(feature = "python")] +use python_impl::PythonCredentialProvider; + +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum PlCredentialProvider { + /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly + Function(CredentialProviderFunction), + #[cfg(feature = "python")] + Python(python_impl::PythonCredentialProvider), +} + +impl PlCredentialProvider { + /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH) + /// + /// This functionality is unstable. + pub fn from_func( + // Internal notes + // * This function is exposed as the Rust API for `PlCredentialProvider` + func: impl Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync + + 'static, + ) -> Self { + Self::Function(CredentialProviderFunction(Arc::new(func))) + } + + #[cfg(feature = "python")] + pub fn from_python_func(func: PythonFunction) -> Self { + Self::Python(python_impl::PythonCredentialProvider(Arc::new(func))) + } + + #[cfg(feature = "python")] + pub fn from_python_func_object(func: pyo3::PyObject) -> Self { + Self::Python(python_impl::PythonCredentialProvider(Arc::new( + PythonFunction(func), + ))) + } + + pub(super) fn func_addr(&self) -> usize { + match self { + Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize, + #[cfg(feature = "python")] + Self::Python(PythonCredentialProvider(v)) => Arc::as_ptr(v) as *const () as usize, + } + } +} + +pub enum ObjectStoreCredential { + #[cfg(feature = "aws")] + Aws(Arc), + #[cfg(feature = "azure")] + Azure(Arc), + #[cfg(feature = "gcp")] + Gcp(Arc), + /// For testing purposes + None, +} + +impl ObjectStoreCredential { + fn variant_name(&self) -> &'static str { + match self { + #[cfg(feature = "aws")] + Self::Aws(_) => "Aws", + #[cfg(feature = "azure")] + Self::Azure(_) => "Azure", + #[cfg(feature = "gcp")] + Self::Gcp(_) => "Gcp", + Self::None => "None", + } + } + + fn panic_type_mismatch(&self, expected: &str) { + panic!( + "impl error: credential type mismatch: expected {}, got {} instead", + expected, + self.variant_name() + ) + } + + #[cfg(feature = "aws")] + fn unwrap_aws(self) -> Arc { + let Self::Aws(v) = self else { + self.panic_type_mismatch("aws"); + unreachable!() + }; + v + } + + #[cfg(feature = "azure")] + fn unwrap_azure(self) -> Arc { + let Self::Azure(v) = self else { + self.panic_type_mismatch("azure"); + unreachable!() + }; + v + } + + #[cfg(feature = "gcp")] + fn unwrap_gcp(self) -> Arc { + let Self::Gcp(v) = self else { + self.panic_type_mismatch("gcp"); + unreachable!() + }; + v + } +} + +pub trait IntoCredentialProvider: Sized { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + unimplemented!() + } +} + +impl IntoCredentialProvider for PlCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + match self { + Self::Function(v) => v.into_aws_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_aws_provider(), + } + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + match self { + Self::Function(v) => v.into_azure_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_azure_provider(), + } + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + match self { + Self::Function(v) => v.into_gcp_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_gcp_provider(), + } + } +} + +type CredentialProviderFunctionImpl = Arc< + dyn Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync, +>; + +/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc. +#[derive(Clone)] +pub struct CredentialProviderFunction(CredentialProviderFunctionImpl); + +macro_rules! build_to_object_store_err { + ($s:expr) => {{ + fn to_object_store_err( + e: impl std::error::Error + Send + Sync + 'static, + ) -> object_store::Error { + object_store::Error::Generic { + store: $s, + source: Box::new(e), + } + } + + to_object_store_err + }}; +} + +impl IntoCredentialProvider for CredentialProviderFunction { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::aws::AwsCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_aws(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-aws")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + })), + )) + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::azure::AzureCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_azure(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-azure")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))), + )) + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::gcp::GcpCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_gcp(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-gcp")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(GcpCredential { + bearer: String::new(), + })), + )) + } +} + +impl Debug for CredentialProviderFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "credential provider function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for CredentialProviderFunction {} + +impl PartialEq for CredentialProviderFunction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Hash for CredentialProviderFunction { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for PlCredentialProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cfg(feature = "python")] + { + Ok(Self::Python(PythonCredentialProvider::deserialize( + deserializer, + )?)) + } + #[cfg(not(feature = "python"))] + { + use serde::de::Error; + Err(D::Error::custom("cannot deserialize PlCredentialProvider")) + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PlCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + #[cfg(feature = "python")] + if let PlCredentialProvider::Python(v) = self { + return v.serialize(serializer); + } + + Err(S::Error::custom(format!("cannot serialize {:?}", self))) + } +} + +/// Avoids calling the credential provider function if we have not yet passed the expiry time. +#[derive(Debug)] +struct FetchedCredentialsCache(tokio::sync::Mutex<(C, u64)>); + +impl FetchedCredentialsCache { + fn new(init_creds: C) -> Self { + Self(tokio::sync::Mutex::new((init_creds, 0))) + } + + async fn get_maybe_update( + &self, + // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from + // a `Fn() -> Pin>` by having it wrapped in an `async { f() }` block. We + // will not poll that block if the credentials have not yet expired. + update_func: impl Future>, + ) -> PolarsResult { + let verbose = config::verbose(); + + fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String { + if last_fetched_expiry == u64::MAX { + "expiry = (never expires)".into() + } else { + format!( + "expiry = {} (in {} seconds)", + last_fetched_expiry, + last_fetched_expiry.saturating_sub(now) + ) + } + } + + let mut inner = self.0.lock().await; + let (last_fetched_credentials, last_fetched_expiry) = &mut *inner; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Ensure the credential is valid for at least this many seconds to + // accommodate for latency. + const REQUEST_TIME_BUFFER: u64 = 7; + + if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER { + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Call update_func: current_time = {}\ + , last_fetched_expiry = {}", + current_time, *last_fetched_expiry + ) + } + let (credentials, expiry) = update_func.await?; + + *last_fetched_credentials = credentials; + *last_fetched_expiry = expiry; + + if expiry < current_time && expiry != 0 { + polars_bail!( + ComputeError: + "credential expiry time {} is older than system time {} \ + by {} seconds", + expiry, + current_time, + current_time - expiry + ) + } + + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Finish update_func: new {}", + expiry_msg( + *last_fetched_expiry, + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + ) + ) + } + } else if verbose { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + eprintln!( + "[FetchedCredentialsCache]: Using cached credentials: \ + current_time = {}, {}", + now, + expiry_msg(*last_fetched_expiry, now) + ) + } + + Ok(last_fetched_credentials.clone()) + } +} + +#[cfg(feature = "python")] +mod python_impl { + use std::hash::Hash; + use std::sync::Arc; + + use polars_error::PolarsError; + use polars_utils::python_function::PythonFunction; + use pyo3::exceptions::PyValueError; + use pyo3::pybacked::PyBackedStr; + use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; + use pyo3::Python; + + use super::IntoCredentialProvider; + + #[derive(Clone, Debug)] + pub struct PythonCredentialProvider(pub(super) Arc); + + impl From for PythonCredentialProvider { + fn from(value: PythonFunction) -> Self { + Self(Arc::new(value)) + } + } + + impl IntoCredentialProvider for PythonCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = object_store::aws::AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::>()?; + + match k.as_ref() { + "aws_access_key_id" => { + credentials.key_id = v.ok_or_else(|| { + PyValueError::new_err("aws_access_key_id was None") + })?; + }, + "aws_secret_access_key" => { + credentials.secret_key = v.ok_or_else(|| { + PyValueError::new_err("aws_secret_access_key was None") + })? + }, + "aws_session_token" => credentials.token = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for aws: {}, \ + valid configuration keys are: \ + {}, {}, {}", + v, + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token" + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.key_id.is_empty() { + return Err(PolarsError::ComputeError( + "aws_access_key_id was empty or not given".into(), + )); + } + + if credentials.secret_key.is_empty() { + return Err(PolarsError::ComputeError( + "aws_secret_access_key was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry)) + }) + })) + .into_aws_provider() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = + object_store::azure::AzureCredential::BearerToken(String::new()); + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + // We only support bearer for now + match k.as_ref() { + "bearer_token" => { + credentials = + object_store::azure::AzureCredential::BearerToken(v) + }, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for azure: {}, \ + valid configuration keys are: {}", + v, "bearer_token", + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + let object_store::azure::AzureCredential::BearerToken(bearer) = &credentials + else { + unreachable!() + }; + + if bearer.is_empty() { + return Err(PolarsError::ComputeError( + "bearer was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry)) + }) + })) + .into_azure_provider() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = object_store::gcp::GcpCredential { + bearer: String::new(), + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + match k.as_ref() { + "bearer_token" => credentials.bearer = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for gcp: {}, \ + valid configuration keys are: {}", + v, "bearer_token", + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.bearer.is_empty() { + return Err(PolarsError::ComputeError( + "bearer was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry)) + }) + })) + .into_gcp_provider() + } + } + + impl Eq for PythonCredentialProvider {} + + impl PartialEq for PythonCredentialProvider { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } + } + + impl Hash for PythonCredentialProvider { + fn hash(&self, state: &mut H) { + // # Safety + // * Inner is an `Arc` + // * Visibility is limited to super + // * No code in `mod python_impl` or `super` mutates the Arc inner. + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } + } + + #[cfg(feature = "serde")] + mod _serde_impl { + use polars_utils::python_function::PySerializeWrap; + + use super::PythonCredentialProvider; + + impl serde::Serialize for PythonCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + PySerializeWrap(self.0.as_ref()).serialize(serializer) + } + } + + impl<'a> serde::Deserialize<'a> for PythonCredentialProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + PySerializeWrap::::deserialize(deserializer) + .map(|x| x.0.into()) + } + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "serde")] + #[allow(clippy::redundant_pattern_matching)] + #[test] + fn test_serde() { + use super::*; + + assert!(matches!( + serde_json::to_string(&Some(PlCredentialProvider::from_func(|| { + Box::pin(core::future::ready(PolarsResult::Ok(( + ObjectStoreCredential::None, + 0, + )))) + }))), + Err(_) + )); + + assert!(matches!( + serde_json::to_string(&Option::::None), + Ok(String { .. }) + )); + + assert!(matches!( + serde_json::from_str::>( + serde_json::to_string(&Option::::None) + .unwrap() + .as_str() + ), + Ok(None) + )); + } +} diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs index b41f7d45cf21..7ae2d99444a7 100644 --- a/crates/polars-io/src/cloud/mod.rs +++ b/crates/polars-io/src/cloud/mod.rs @@ -19,3 +19,6 @@ pub use object_store_setup::*; pub use options::*; #[cfg(feature = "cloud")] pub use polars_object_store::*; + +#[cfg(feature = "cloud")] +pub mod credential_provider; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index b6464b109535..0bb33b333c18 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -3,12 +3,14 @@ use std::sync::Arc; use object_store::local::LocalFileSystem; use object_store::ObjectStore; use once_cell::sync::Lazy; +use polars_core::config; use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; use polars_utils::aliases::PlHashMap; use tokio::sync::RwLock; use url::Url; use super::{parse_url, CloudLocation, CloudOptions, CloudType}; +use crate::cloud::CloudConfig; /// Object stores must be cached. Every object-store will do DNS lookups and /// get rate limited when querying the DNS (can take up to 5s). @@ -28,10 +30,40 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { } /// Get the key of a url for object store registration. -/// The credential info will be removed fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String { + #[derive(Clone, Debug, PartialEq, Hash, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + struct S { + max_retries: usize, + #[cfg(feature = "file_cache")] + file_cache_ttl: u64, + config: Option, + #[cfg(feature = "cloud")] + credential_provider: usize, + } + // We include credentials as they can expire, so users will send new credentials for the same url. - let creds = serde_json::to_string(&options).unwrap_or_else(|_| "".into()); + let creds = serde_json::to_string(&options.map( + |CloudOptions { + // Destructure to ensure this breaks if anything changes. + max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl, + config, + #[cfg(feature = "cloud")] + credential_provider, + }| { + S { + max_retries: *max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl: *file_cache_ttl, + config: config.clone(), + #[cfg(feature = "cloud")] + credential_provider: credential_provider.as_ref().map_or(0, |x| x.func_addr()), + } + }, + )) + .unwrap(); format!( "{}://{}<\\creds\\>{}", url.scheme(), @@ -58,6 +90,8 @@ pub async fn build_object_store( let parsed = parse_url(url).map_err(to_compute_err)?; let cloud_location = CloudLocation::from_url(&parsed, glob)?; + // FIXME: `credential_provider` is currently serializing the entire Python function here + // into a string with pickle for this cache key because we are using `serde_json::to_string` let key = url_and_creds_to_key(&parsed, options); let mut allow_cache = true; @@ -124,6 +158,12 @@ pub async fn build_object_store( let mut cache = OBJECT_STORE_CACHE.write().await; // Clear the cache if we surpass a certain amount of buckets. if cache.len() > 8 { + if config::verbose() { + eprintln!( + "build_object_store: clearing store cache (cache.len(): {})", + cache.len() + ); + } cache.clear() } cache.insert(key, store.clone()); diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index efaab673f634..9549b837f06d 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -36,6 +36,8 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "cloud")] use url::Url; +#[cfg(feature = "cloud")] +use super::credential_provider::PlCredentialProvider; #[cfg(feature = "file_cache")] use crate::file_cache::get_env_file_cache_ttl; #[cfg(feature = "aws")] @@ -75,6 +77,8 @@ pub struct CloudOptions { #[cfg(feature = "file_cache")] pub file_cache_ttl: u64, pub(crate) config: Option, + #[cfg(feature = "cloud")] + pub(crate) credential_provider: Option, } impl Default for CloudOptions { @@ -84,6 +88,8 @@ impl Default for CloudOptions { #[cfg(feature = "file_cache")] file_cache_ttl: get_env_file_cache_ttl(), config: None, + #[cfg(feature = "cloud")] + credential_provider: Default::default(), } } } @@ -248,6 +254,15 @@ impl CloudOptions { self } + #[cfg(feature = "cloud")] + pub fn with_credential_provider( + mut self, + credential_provider: Option, + ) -> Self { + self.credential_provider = credential_provider; + self + } + /// Set the configuration for AWS connections. This is the preferred API from rust. #[cfg(feature = "aws")] pub fn with_aws)>>( @@ -263,6 +278,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for AWS. #[cfg(feature = "aws")] pub async fn build_aws(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = AmazonS3Builder::from_env().with_url(url); if let Some(options) = &self.config { let CloudConfig::Aws(options) = options else { @@ -346,11 +363,17 @@ impl CloudOptions { }; }; - builder + let builder = builder .with_client_options(get_client_options()) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_aws_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for Azure connections. This is the preferred API from rust. @@ -368,6 +391,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for Azure. #[cfg(feature = "azure")] pub fn build_azure(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = MicrosoftAzureBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Azure(options) = options else { @@ -378,12 +403,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_azure_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for GCP connections. This is the preferred API from rust. @@ -401,6 +432,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for GCP. #[cfg(feature = "gcp")] pub fn build_gcp(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = GoogleCloudStorageBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Gcp(options) = options else { @@ -411,12 +444,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_gcp_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } #[cfg(feature = "http")] diff --git a/crates/polars-io/src/csv/read/buffer.rs b/crates/polars-io/src/csv/read/buffer.rs index 712201ceaca6..1d69e0c7132e 100644 --- a/crates/polars-io/src/csv/read/buffer.rs +++ b/crates/polars-io/src/csv/read/buffer.rs @@ -148,7 +148,7 @@ where pub struct Utf8Field { name: PlSmallStr, - mutable: MutableBinaryViewArray, + mutable: MutableBinaryViewArray<[u8]>, scratch: Vec, quote_char: u8, encoding: CsvEncoding, @@ -172,7 +172,7 @@ impl Utf8Field { } #[inline] -fn validate_utf8(bytes: &[u8]) -> bool { +pub(super) fn validate_utf8(bytes: &[u8]) -> bool { simdutf8::basic::from_utf8(bytes).is_ok() } @@ -190,7 +190,7 @@ impl ParsedBuffer for Utf8Field { if missing_is_null { self.mutable.push_null() } else { - self.mutable.push(Some("")) + self.mutable.push(Some([])) } return Ok(()); } @@ -199,7 +199,7 @@ impl ParsedBuffer for Utf8Field { let escaped_bytes = if needs_escaping { self.scratch.clear(); self.scratch.reserve(bytes.len()); - polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); + polars_ensure!(bytes.len() > 1 && bytes.last() == Some(&self.quote_char), ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); // SAFETY: // we just allocated enough capacity and data_len is correct. @@ -208,36 +208,41 @@ impl ParsedBuffer for Utf8Field { escape_field(bytes, self.quote_char, self.scratch.spare_capacity_mut()); self.scratch.set_len(n_written); } + self.scratch.as_slice() } else { bytes }; - // It is important that this happens after escaping, as invalid escaped string can produce - // invalid utf8. - let parse_result = validate_utf8(escaped_bytes); + if matches!(self.encoding, CsvEncoding::LossyUtf8) | ignore_errors { + // It is important that this happens after escaping, as invalid escaped string can produce + // invalid utf8. + let parse_result = validate_utf8(escaped_bytes); - match parse_result { - true => { - let value = unsafe { std::str::from_utf8_unchecked(escaped_bytes) }; - self.mutable.push_value(value) - }, - false => { - if matches!(self.encoding, CsvEncoding::LossyUtf8) { - // TODO! do this without allocating - let s = String::from_utf8_lossy(escaped_bytes); - self.mutable.push_value(s.as_ref()) - } else if ignore_errors { - self.mutable.push_null() - } else { - // If field before escaping is valid utf8, the escaping is incorrect. - if needs_escaping && validate_utf8(bytes) { - polars_bail!(ComputeError: "string field is not properly escaped"); + match parse_result { + true => { + let value = escaped_bytes; + self.mutable.push_value(value) + }, + false => { + if matches!(self.encoding, CsvEncoding::LossyUtf8) { + // TODO! do this without allocating + let s = String::from_utf8_lossy(escaped_bytes); + self.mutable.push_value(s.as_ref().as_bytes()) + } else if ignore_errors { + self.mutable.push_null() } else { - polars_bail!(ComputeError: "invalid utf-8 sequence"); + // If field before escaping is valid utf8, the escaping is incorrect. + if needs_escaping && validate_utf8(bytes) { + polars_bail!(ComputeError: "string field is not properly escaped"); + } else { + polars_bail!(ComputeError: "invalid utf-8 sequence"); + } } - } - }, + }, + } + } else { + self.mutable.push_value(escaped_bytes) } Ok(()) @@ -631,7 +636,8 @@ impl Buffer { Buffer::Utf8(v) => { let arr = v.mutable.freeze(); - StringChunked::with_chunk(v.name.clone(), arr).into_series() + StringChunked::with_chunk(v.name.clone(), unsafe { arr.to_utf8view_unchecked() }) + .into_series() }, #[allow(unused_variables)] Buffer::Categorical(buf) => { diff --git a/crates/polars-io/src/csv/read/parser.rs b/crates/polars-io/src/csv/read/parser.rs index 2b11013be951..9272dc6d65b8 100644 --- a/crates/polars-io/src/csv/read/parser.rs +++ b/crates/polars-io/src/csv/read/parser.rs @@ -835,7 +835,7 @@ pub(super) fn parse_lines( \n\ You might want to try:\n\ - increasing `infer_schema_length` (e.g. `infer_schema_length=10000`),\n\ - - specifying correct dtype with the `dtypes` argument\n\ + - specifying correct dtype with the `schema_overrides` argument\n\ - setting `ignore_errors` to `True`,\n\ - adding `{}` to the `null_values` list.\n\n\ Original error: ```{}```", diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index 47fa19f2e95d..52f29ee0a128 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -82,7 +82,7 @@ pub(crate) fn cast_columns( }) .collect::>>() })?; - *df = unsafe { DataFrame::new_no_checks(cols) } + *df = unsafe { DataFrame::new_no_checks(df.height(), cols) } } else { // cast to the original dtypes in the schema for fld in to_cast { @@ -126,7 +126,7 @@ pub(crate) struct CoreReader<'a> { truncate_ragged_lines: bool, } -impl<'a> fmt::Debug for CoreReader<'a> { +impl fmt::Debug for CoreReader<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Reader") .field("schema", &self.schema) @@ -191,7 +191,7 @@ impl<'a> CoreReader<'a> { if let Some(b) = decompress(&reader_bytes, total_n_rows, separator, quote_char, eol_char) { - reader_bytes = ReaderBytes::Owned(b); + reader_bytes = ReaderBytes::Owned(b.into()); } } @@ -467,24 +467,29 @@ impl<'a> CoreReader<'a> { continue; } - let b = unsafe { - bytes.get_unchecked_release(total_offset..total_offset + position) - }; - // The parsers will not create a null row if we end on a new line. - if b.last() == Some(&self.eol_char) { - chunk_size *= 2; - continue; - } + let end = total_offset + position + 1; + let b = unsafe { bytes.get_unchecked_release(total_offset..end) }; - total_offset += position + 1; + total_offset = end; (b, count) }; + let check_utf8 = matches!(self.encoding, CsvEncoding::Utf8) + && self.schema.iter_fields().any(|f| f.dtype().is_string()); if !b.is_empty() { let results = results.clone(); let projection = projection.as_ref(); let slf = &(*self); s.spawn(move |_| { + if check_utf8 && !super::buffer::validate_utf8(b) { + let mut results = results.lock().unwrap(); + results.push(( + b.as_ptr() as usize, + Err(polars_err!(ComputeError: "invalid utf-8 sequence")), + )); + return; + } + let result = slf .read_chunk(b, projection, 0, count, starting_point_offset, b.len()) .and_then(|mut df| { @@ -625,6 +630,6 @@ fn read_chunk( let columns = buffers .into_iter() .map(|buf| buf.into_series().map(Column::from)) - .collect::>()?; - Ok(unsafe { DataFrame::new_no_checks(columns) }) + .collect::>>()?; + Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) }) } diff --git a/crates/polars-io/src/csv/read/read_impl/batched.rs b/crates/polars-io/src/csv/read/read_impl/batched.rs index 3bf6e2dd4e32..90e0b4e4e37c 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched.rs @@ -66,7 +66,7 @@ struct ChunkOffsetIter<'a> { eol_char: u8, } -impl<'a> Iterator for ChunkOffsetIter<'a> { +impl Iterator for ChunkOffsetIter<'_> { type Item = (usize, usize); fn next(&mut self) -> Option { @@ -209,7 +209,7 @@ pub struct BatchedCsvReader<'a> { decimal_comma: bool, } -impl<'a> BatchedCsvReader<'a> { +impl BatchedCsvReader<'_> { pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { if n == 0 || self.remaining == 0 { return Ok(None); diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs index e31948c85d26..098f27d82c0d 100644 --- a/crates/polars-io/src/csv/read/reader.rs +++ b/crates/polars-io/src/csv/read/reader.rs @@ -318,30 +318,29 @@ where } #[cfg(feature = "temporal")] -fn parse_dates(mut df: DataFrame, fixed_schema: &Schema) -> DataFrame { +fn parse_dates(df: DataFrame, fixed_schema: &Schema) -> DataFrame { use polars_core::POOL; - let cols = unsafe { std::mem::take(df.get_columns_mut()) } - .into_par_iter() - .map(|c| { - match c.dtype() { - DataType::String => { - let ca = c.str().unwrap(); - // don't change columns that are in the fixed schema. - if fixed_schema.index_of(c.name()).is_some() { - return c; - } - - #[cfg(feature = "dtype-time")] - if let Ok(ca) = ca.as_time(None, false) { - return ca.into_column(); - } - c - }, - _ => c, - } - }); + let height = df.height(); + let cols = df.take_columns().into_par_iter().map(|c| { + match c.dtype() { + DataType::String => { + let ca = c.str().unwrap(); + // don't change columns that are in the fixed schema. + if fixed_schema.index_of(c.name()).is_some() { + return c; + } + + #[cfg(feature = "dtype-time")] + if let Ok(ca) = ca.as_time(None, false) { + return ca.into_column(); + } + c + }, + _ => c, + } + }); let cols = POOL.install(|| cols.collect::>()); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(height, cols) } } diff --git a/crates/polars-io/src/csv/read/schema_inference.rs b/crates/polars-io/src/csv/read/schema_inference.rs index 3684df1a2bac..a01ec5ddef3f 100644 --- a/crates/polars-io/src/csv/read/schema_inference.rs +++ b/crates/polars-io/src/csv/read/schema_inference.rs @@ -296,7 +296,7 @@ fn infer_file_schema_inner( buf.push(eol_char); return infer_file_schema_inner( - &ReaderBytes::Owned(buf), + &ReaderBytes::Owned(buf.into()), separator, max_read_rows, has_header, @@ -481,7 +481,7 @@ fn infer_file_schema_inner( rb.extend_from_slice(reader_bytes); rb.push(eol_char); return infer_file_schema_inner( - &ReaderBytes::Owned(rb), + &ReaderBytes::Owned(rb.into()), separator, max_read_rows, has_header, diff --git a/crates/polars-io/src/csv/read/splitfields.rs b/crates/polars-io/src/csv/read/splitfields.rs index 68714c4fddd7..ac06b185a37a 100644 --- a/crates/polars-io/src/csv/read/splitfields.rs +++ b/crates/polars-io/src/csv/read/splitfields.rs @@ -265,10 +265,10 @@ mod inner { .unwrap_unchecked_release() }; let simd_bytes = SimdVec::from(lane); - let eol_mask = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask(); - let sep_mask = simd_bytes.simd_eq(self.simd_separator).to_bitmask(); + let has_eol = simd_bytes.simd_eq(self.simd_eol_char); + let has_sep = simd_bytes.simd_eq(self.simd_separator); let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask(); - let mut end_mask = sep_mask | eol_mask; + let mut end_mask = (has_sep | has_eol).to_bitmask(); let mut not_in_quote_field = prefix_xorsum_inclusive(quote_mask); @@ -360,12 +360,12 @@ mod inner { .unwrap_unchecked_release() }; let simd_bytes = SimdVec::from(lane); - let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask(); - let has_separator = simd_bytes.simd_eq(self.simd_separator).to_bitmask(); - let has_any = has_separator | has_eol_char; + let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char); + let has_separator = simd_bytes.simd_eq(self.simd_separator); + let has_any_mask = (has_separator | has_eol_char).to_bitmask(); - if has_any != 0 { - total_idx += has_any.trailing_zeros() as usize; + if has_any_mask != 0 { + total_idx += has_any_mask.trailing_zeros() as usize; break; } else { total_idx += SIMD_SIZE; diff --git a/crates/polars-io/src/csv/write/write_impl/serializer.rs b/crates/polars-io/src/csv/write/write_impl/serializer.rs index e9a3e055278a..6a4f964d88b3 100644 --- a/crates/polars-io/src/csv/write/write_impl/serializer.rs +++ b/crates/polars-io/src/csv/write/write_impl/serializer.rs @@ -686,17 +686,11 @@ pub(super) fn serializer_for<'a>( }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, scale) => { - let array = array.as_any().downcast_ref().unwrap(); - match options.quote_style { - QuoteStyle::Never => Box::new(decimal_serializer(array, scale.unwrap_or(0))) - as Box, - _ => Box::new(quote_serializer(decimal_serializer( - array, - scale.unwrap_or(0), - ))), - } + quote_if_always!(decimal_serializer, scale.unwrap_or(0)) + }, + _ => { + polars_bail!(ComputeError: "datatype {dtype} cannot be written to CSV\n\nConsider using JSON or a binary format.") }, - _ => polars_bail!(ComputeError: "datatype {dtype} cannot be written to csv"), }; Ok(serializer) } diff --git a/crates/polars-io/src/hive.rs b/crates/polars-io/src/hive.rs index 77e65647fa56..df755eab56f3 100644 --- a/crates/polars-io/src/hive.rs +++ b/crates/polars-io/src/hive.rs @@ -1,4 +1,6 @@ +use polars_core::frame::column::ScalarColumn; use polars_core::frame::DataFrame; +use polars_core::prelude::Column; use polars_core::series::Series; /// Materializes hive partitions. @@ -22,13 +24,17 @@ pub(crate) fn materialize_hive_partitions( return; } - let hive_columns_iter = hive_columns + let hive_columns_sc = hive_columns .iter() - .map(|s| s.new_from_index(0, num_rows).into()); + .map(|s| ScalarColumn::new(s.name().clone(), s.first(), num_rows).into()) + .collect::>(); if reader_schema.index_of(hive_columns[0].name()).is_none() || df.width() == 0 { // Fast-path - all hive columns are at the end - unsafe { df.get_columns_mut() }.extend(hive_columns_iter); + if df.width() == 0 { + unsafe { df.set_height(num_rows) }; + } + unsafe { df.hstack_mut_unchecked(&hive_columns_sc) }; return; } @@ -39,9 +45,8 @@ pub(crate) fn materialize_hive_partitions( // We have a slightly involved algorithm here because `reader_schema` may contain extra // columns that were excluded from a projection pushdown. - let hive_columns = hive_columns_iter.collect::>(); // Safety: These are both non-empty at the start - let mut series_arr = [df_columns, hive_columns.as_slice()]; + let mut series_arr = [df_columns, hive_columns_sc.as_slice()]; let mut schema_idx_arr = [ reader_schema.index_of(series_arr[0][0].name()).unwrap(), reader_schema.index_of(series_arr[1][0].name()).unwrap(), @@ -71,6 +76,6 @@ pub(crate) fn materialize_hive_partitions( out_columns.extend_from_slice(series_arr[0]); out_columns.extend_from_slice(series_arr[1]); - *unsafe { df.get_columns_mut() } = out_columns; + *df = unsafe { DataFrame::new_no_checks(num_rows, out_columns) }; } } diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index ab6805d18967..100d37b2c941 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -1,6 +1,6 @@ //! # (De)serializing Arrows IPC format. //! -//! Arrow IPC is a [binary format format](https://arrow.apache.org/docs/python/ipc.html). +//! Arrow IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). //! It is the recommended way to serialize and deserialize Polars DataFrames as this is most true //! to the data schema. //! diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index 6b16579ac93d..6393c639cf35 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -1,6 +1,6 @@ //! # (De)serializing Arrows Streaming IPC format. //! -//! Arrow Streaming IPC is a [binary format format](https://arrow.apache.org/docs/python/ipc.html). +//! Arrow Streaming IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). //! It used for sending an arbitrary length sequence of record batches. //! The format must be processed from start to end, and does not support random access. //! It is different than IPC, if you can't deserialize a file with `IpcReader::new`, it's probably an IPC Stream File. diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs index f0343642482e..b8749c5a7392 100644 --- a/crates/polars-io/src/ipc/mmap.rs +++ b/crates/polars-io/src/ipc/mmap.rs @@ -96,9 +96,10 @@ impl ArrowReader for MMapChunkIter<'_> { let chunk = match &self.projection { None => chunk, Some(proj) => { + let length = chunk.len(); let cols = chunk.into_arrays(); let arrays = proj.iter().map(|i| cols[*i].clone()).collect(); - RecordBatch::new(arrays) + RecordBatch::new(length, arrays) }, }; Ok(Some(chunk)) diff --git a/crates/polars-io/src/ipc/mod.rs b/crates/polars-io/src/ipc/mod.rs index d78362f5555f..1e341de98c56 100644 --- a/crates/polars-io/src/ipc/mod.rs +++ b/crates/polars-io/src/ipc/mod.rs @@ -7,9 +7,6 @@ mod ipc_stream; #[cfg(feature = "ipc")] mod mmap; mod write; -#[cfg(all(feature = "async", feature = "ipc"))] -mod write_async; - #[cfg(feature = "ipc")] pub use ipc_file::{IpcReader, IpcScanOptions}; #[cfg(feature = "cloud")] diff --git a/crates/polars-io/src/ipc/write_async.rs b/crates/polars-io/src/ipc/write_async.rs deleted file mode 100644 index 5ed459a715d2..000000000000 --- a/crates/polars-io/src/ipc/write_async.rs +++ /dev/null @@ -1,59 +0,0 @@ -use arrow::io::ipc::write::file_async::FileSink; -use arrow::io::ipc::write::WriteOptions; -use futures::{AsyncWrite, SinkExt}; -use polars_core::prelude::*; - -use crate::ipc::IpcWriter; - -impl IpcWriter { - pub fn new_async(writer: W) -> Self { - IpcWriter { - writer, - compression: None, - compat_level: CompatLevel::oldest(), - } - } - - pub fn batched_async(self, schema: &Schema) -> PolarsResult> { - let writer = FileSink::new( - self.writer, - schema.to_arrow(CompatLevel::oldest()), - None, - WriteOptions { - compression: self.compression.map(|c| c.into()), - }, - ); - - Ok(BatchedWriterAsync { writer }) - } -} - -pub struct BatchedWriterAsync<'a, W> -where - W: AsyncWrite + Unpin + Send + 'a, -{ - writer: FileSink<'a, W>, -} - -impl<'a, W> BatchedWriterAsync<'a, W> -where - W: AsyncWrite + Unpin + Send + 'a, -{ - /// Write a batch to the parquet writer. - /// - /// # Panics - /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. - pub async fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { - let iter = df.iter_chunks(CompatLevel::oldest(), true); - for batch in iter { - self.writer.feed(batch.into()).await?; - } - Ok(()) - } - - /// Writes the footer of the IPC file. - pub async fn finish(&mut self) -> PolarsResult<()> { - self.writer.close().await?; - Ok(()) - } -} diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index 208b1ba8befa..df2ff8bb9c28 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -236,7 +236,7 @@ pub fn remove_bom(bytes: &[u8]) -> PolarsResult<&[u8]> { Ok(bytes) } } -impl<'a, R> SerReader for JsonReader<'a, R> +impl SerReader for JsonReader<'_, R> where R: MmapBytesReader, { @@ -287,6 +287,8 @@ where } } + let allow_extra_fields_in_struct = self.schema.is_some(); + // struct type let dtype = if let Some(mut schema) = self.schema { if let Some(overwrite) = self.schema_overwrite { @@ -338,7 +340,11 @@ where dtype }; - let arr = polars_json::json::deserialize(&json_value, dtype)?; + let arr = polars_json::json::deserialize( + &json_value, + dtype, + allow_extra_fields_in_struct, + )?; let arr = arr.as_any().downcast_ref::().ok_or_else( || polars_err!(ComputeError: "can only deserialize json objects"), )?; diff --git a/crates/polars-io/src/mmap.rs b/crates/polars-io/src/mmap.rs index df91f32942f9..2373257469e7 100644 --- a/crates/polars-io/src/mmap.rs +++ b/crates/polars-io/src/mmap.rs @@ -1,9 +1,8 @@ use std::fs::File; use std::io::{BufReader, Cursor, Read, Seek}; -use std::sync::Arc; use polars_core::config::verbose; -use polars_utils::mmap::{MMapSemaphore, MemSlice}; +use polars_utils::mmap::MemSlice; /// Trait used to get a hold to file handler or to the underlying bytes /// without performing a Read. @@ -67,8 +66,7 @@ impl MmapBytesReader for &mut T { // Handle various forms of input bytes pub enum ReaderBytes<'a> { Borrowed(&'a [u8]), - Owned(Vec), - Mapped(MMapSemaphore, &'a File), + Owned(MemSlice), } impl std::ops::Deref for ReaderBytes<'_> { @@ -77,19 +75,21 @@ impl std::ops::Deref for ReaderBytes<'_> { match self { Self::Borrowed(ref_bytes) => ref_bytes, Self::Owned(vec) => vec, - Self::Mapped(mmap, _) => mmap.as_ref(), } } } -/// Require 'static to force the caller to do any transmute as it's usually much -/// clearer to see there whether it's sound. +/// There are some places that perform manual lifetime management after transmuting `ReaderBytes` +/// to have a `'static` inner lifetime. The advantage to doing this is that it lets you construct a +/// `MemSlice` from the `ReaderBytes` in a zero-copy manner regardless of the underlying enum +/// variant. impl ReaderBytes<'static> { - pub fn into_mem_slice(self) -> MemSlice { + /// Construct a `MemSlice` in a zero-copy manner from the underlying bytes, with the assumption + /// that the underlying bytes have a `'static` lifetime. + pub fn to_memslice(&self) -> MemSlice { match self { ReaderBytes::Borrowed(v) => MemSlice::from_static(v), - ReaderBytes::Owned(v) => MemSlice::from_vec(v), - ReaderBytes::Mapped(v, _) => MemSlice::from_mmap(Arc::new(v)), + ReaderBytes::Owned(v) => v.clone(), } } } @@ -104,16 +104,14 @@ impl<'a, T: 'a + MmapBytesReader> From<&'a mut T> for ReaderBytes<'a> { }, None => { if let Some(f) = m.to_file() { - let f = unsafe { std::mem::transmute::<&File, &'a File>(f) }; - let mmap = MMapSemaphore::new_from_file(f).unwrap(); - ReaderBytes::Mapped(mmap, f) + ReaderBytes::Owned(MemSlice::from_file(f).unwrap()) } else { if verbose() { eprintln!("could not memory map file; read to buffer.") } let mut buf = vec![]; m.read_to_end(&mut buf).expect("could not read"); - ReaderBytes::Owned(buf) + ReaderBytes::Owned(MemSlice::from_vec(buf)) } }, } diff --git a/crates/polars-io/src/ndjson/buffer.rs b/crates/polars-io/src/ndjson/buffer.rs index 2bb2a028f1ca..1c9938979af5 100644 --- a/crates/polars-io/src/ndjson/buffer.rs +++ b/crates/polars-io/src/ndjson/buffer.rs @@ -12,9 +12,9 @@ use simd_json::{BorrowedValue as Value, KnownKey, StaticNode}; #[derive(Debug, Clone, PartialEq)] pub(crate) struct BufferKey<'a>(pub(crate) KnownKey<'a>); -impl<'a> Eq for BufferKey<'a> {} +impl Eq for BufferKey<'_> {} -impl<'a> Hash for BufferKey<'a> { +impl Hash for BufferKey<'_> { fn hash(&self, state: &mut H) { self.0.key().hash(state) } diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index a72b4ccf7038..2fabd7dcd589 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -133,7 +133,7 @@ where } } -impl<'a> JsonLineReader<'a, File> { +impl JsonLineReader<'_, File> { /// This is the recommended way to create a json reader as this allows for fastest parsing. pub fn from_path>(path: P) -> PolarsResult { let path = crate::resolve_homedir(&path.into()); @@ -141,7 +141,7 @@ impl<'a> JsonLineReader<'a, File> { Ok(Self::new(f).with_path(Some(path))) } } -impl<'a, R> SerReader for JsonLineReader<'a, R> +impl SerReader for JsonLineReader<'_, R> where R: MmapBytesReader, { diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs index 1fec749af5ce..cc0020cc7857 100644 --- a/crates/polars-io/src/parquet/read/mod.rs +++ b/crates/polars-io/src/parquet/read/mod.rs @@ -33,6 +33,7 @@ or set 'streaming'", pub use options::{ParallelStrategy, ParquetOptions}; use polars_error::{ErrString, PolarsError}; +pub use read_impl::{create_sorting_map, try_set_sorted_flag}; #[cfg(feature = "cloud")] pub use reader::ParquetAsyncReader; pub use reader::{BatchedParquetReader, ParquetReader}; diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs index eb8f7747f078..a3269341c1a3 100644 --- a/crates/polars-io/src/parquet/read/predicates.rs +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -1,3 +1,4 @@ +use polars_core::config; use polars_core::prelude::*; use polars_parquet::read::statistics::{deserialize, Statistics}; use polars_parquet::read::RowGroupMetadata; @@ -50,18 +51,38 @@ pub fn read_this_row_group( md: &RowGroupMetadata, schema: &ArrowSchema, ) -> PolarsResult { + if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { + return Ok(true); + } + + let mut should_read = true; + if let Some(pred) = predicate { if let Some(pred) = pred.as_stats_evaluator() { if let Some(stats) = collect_statistics(md, schema)? { - let should_read = pred.should_read(&stats); + let pred_result = pred.should_read(&stats); + // a parquet file may not have statistics of all columns - if matches!(should_read, Ok(false)) { - return Ok(false); - } else if !matches!(should_read, Err(PolarsError::ColumnNotFound(_))) { - let _ = should_read?; + match pred_result { + Err(PolarsError::ColumnNotFound(errstr)) => { + return Err(PolarsError::ColumnNotFound(errstr)) + }, + Ok(false) => should_read = false, + _ => {}, } } } + + if config::verbose() { + if should_read { + eprintln!( + "parquet row group must be read, statistics not sufficient for predicate." + ); + } else { + eprintln!("parquet row group can be skipped, the statistics were sufficient to apply the predicate."); + } + } } - Ok(true) + + Ok(should_read) } diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 45aa2260de30..de22b639bf8b 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -7,14 +7,14 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowSchemaRef; use polars_core::chunked_array::builder::NullChunkedBuilder; use polars_core::prelude::*; +use polars_core::series::IsSorted; use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -use polars_core::POOL; +use polars_core::{config, POOL}; use polars_parquet::parquet::error::ParquetResult; use polars_parquet::parquet::statistics::Statistics; use polars_parquet::read::{ self, ColumnChunkMetadata, FileMetadata, Filter, PhysicalType, RowGroupMetadata, }; -use polars_utils::mmap::MemSlice; use rayon::prelude::*; #[cfg(feature = "cloud")] @@ -43,6 +43,9 @@ fn assert_dtypes(dtype: &ArrowDataType) { // These should all be casted to the BinaryView / Utf8View variants D::Utf8 | D::Binary | D::LargeUtf8 | D::LargeBinary => unreachable!(), + // These should be casted to Float32 + D::Float16 => unreachable!(), + // This should have been converted to a LargeList D::List(_) => unreachable!(), @@ -60,6 +63,57 @@ fn assert_dtypes(dtype: &ArrowDataType) { } } +fn should_copy_sortedness(dtype: &DataType) -> bool { + // @NOTE: For now, we are a bit conservative with this. + use DataType as D; + + matches!( + dtype, + D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64 + ) +} + +pub fn try_set_sorted_flag( + series: &mut Series, + col_idx: usize, + sorting_map: &PlHashMap, +) { + if let Some(is_sorted) = sorting_map.get(&col_idx) { + if should_copy_sortedness(series.dtype()) { + if config::verbose() { + eprintln!( + "Parquet conserved SortingColumn for column chunk of '{}' to {is_sorted:?}", + series.name() + ); + } + + series.set_sorted_flag(*is_sorted); + } + } +} + +pub fn create_sorting_map(md: &RowGroupMetadata) -> PlHashMap { + let capacity = md.sorting_columns().map_or(0, |s| s.len()); + let mut sorting_map = PlHashMap::with_capacity(capacity); + + if let Some(sorting_columns) = md.sorting_columns() { + for sorting in sorting_columns { + let prev_value = sorting_map.insert( + sorting.column_idx as usize, + if sorting.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }, + ); + + debug_assert!(prev_value.is_none()); + } + } + + sorting_map +} + fn column_idx_to_series( column_i: usize, // The metadata belonging to this column @@ -68,6 +122,8 @@ fn column_idx_to_series( file_schema: &ArrowSchema, store: &mmap::ColumnStore, ) -> PolarsResult { + let did_filter = filter.is_some(); + let field = file_schema.get_at_index(column_i).unwrap().1; #[cfg(debug_assertions)] @@ -91,6 +147,11 @@ fn column_idx_to_series( _ => {}, } + // We cannot trust the statistics if we filtered the parquet already. + if did_filter { + return Ok(series); + } + // See if we can find some statistics for this series. If we cannot find anything just return // the series as is. let Ok(Some(stats)) = stats.map(|mut s| s.pop().flatten()) else { @@ -320,6 +381,8 @@ fn rg_to_dfs_prefiltered( } } + let sorting_map = create_sorting_map(md); + // Collect the data for the live columns let live_columns = (0..num_live_columns) .into_par_iter() @@ -338,14 +401,18 @@ fn rg_to_dfs_prefiltered( let part = iter.collect::>(); - column_idx_to_series(col_idx, part.as_slice(), None, schema, store) - .map(Column::from) + let mut series = + column_idx_to_series(col_idx, part.as_slice(), None, schema, store)?; + + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; // Apply the predicate to the live columns and save the dataframe and the bitmask let md = &file_metadata.row_groups[rg_idx]; - let mut df = unsafe { DataFrame::new_no_checks(live_columns) }; + let mut df = unsafe { DataFrame::new_no_checks(md.num_rows(), live_columns) }; materialize_hive_partitions( &mut df, @@ -445,7 +512,7 @@ fn rg_to_dfs_prefiltered( array.filter(&mask_arr) }; - let array = if mask_setting.should_prefilter( + let mut series = if mask_setting.should_prefilter( prefilter_cost, &schema.get_at_index(col_idx).unwrap().1.dtype, ) { @@ -454,14 +521,17 @@ fn rg_to_dfs_prefiltered( post()? }; - debug_assert_eq!(array.len(), filter_mask.set_bits()); + debug_assert_eq!(series.len(), filter_mask.set_bits()); - Ok(array.into_column()) + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; debug_assert!(dead_columns.iter().all(|v| v.len() == df.height())); + let height = df.height(); let mut live_columns = df.take_columns(); assert_eq!( @@ -507,7 +577,7 @@ fn rg_to_dfs_prefiltered( // SAFETY: This is completely based on the schema so all column names are unique // and the length is given by the parquet file which should always be the same. - let df = unsafe { DataFrame::new_no_checks(columns) }; + let df = unsafe { DataFrame::new_no_checks(height, columns) }; PolarsResult::Ok(Some(df)) }) @@ -568,6 +638,8 @@ fn rg_to_dfs_optionally_par_over_columns( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = if let ParallelStrategy::Columns = parallel { POOL.install(|| { projection @@ -585,14 +657,17 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>() })? @@ -612,19 +687,22 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()? }; - let mut df = unsafe { DataFrame::new_no_checks(columns) }; + let mut df = unsafe { DataFrame::new_no_checks(rg_slice.1, columns) }; if let Some(rc) = &row_index { df.with_row_index_mut(rc.name.clone(), Some(*previous_row_count + rc.offset)); } @@ -704,6 +782,8 @@ fn rg_to_dfs_par_over_rg( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = projection .iter() .map(|column_i| { @@ -719,18 +799,21 @@ fn rg_to_dfs_par_over_rg( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; - let mut df = unsafe { DataFrame::new_no_checks(columns) }; + let mut df = unsafe { DataFrame::new_no_checks(slice.1, columns) }; if let Some(rc) = &row_index { df.with_row_index_mut( @@ -824,10 +907,9 @@ pub fn read_parquet( } let reader = ReaderBytes::from(&mut reader); - let store = mmap::ColumnStore::Local( - unsafe { std::mem::transmute::, ReaderBytes<'static>>(reader) } - .into_mem_slice(), - ); + let store = mmap::ColumnStore::Local(unsafe { + std::mem::transmute::, ReaderBytes<'static>>(reader).to_memslice() + }); let dfs = rg_to_dfs( &store, @@ -875,9 +957,7 @@ impl FetchRowGroupsFromMmapReader { fn fetch_row_groups(&mut self, _row_groups: Range) -> PolarsResult { // @TODO: we can something smarter here with mmap - Ok(mmap::ColumnStore::Local(MemSlice::from_vec( - self.0.deref().to_vec(), - ))) + Ok(mmap::ColumnStore::Local(self.0.to_memslice())) } } diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 2a70ef2c5046..25d1f51b098b 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -89,9 +89,15 @@ impl ParquetReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema()?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema()?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -104,7 +110,7 @@ impl ParquetReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" @@ -328,9 +334,15 @@ impl ParquetAsyncReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema().await?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema().await?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -343,7 +355,7 @@ impl ParquetAsyncReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs index 1795cda6ebd0..71c59fecb31d 100644 --- a/crates/polars-io/src/path_utils/mod.rs +++ b/crates/polars-io/src/path_utils/mod.rs @@ -99,7 +99,7 @@ struct HiveIdxTracker<'a> { check_directory_level: bool, } -impl<'a> HiveIdxTracker<'a> { +impl HiveIdxTracker<'_> { fn update(&mut self, i: usize, path_idx: usize) -> PolarsResult<()> { let check_directory_level = self.check_directory_level; let paths = self.paths; diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs index 023d61fe525b..dda13bf1ea51 100644 --- a/crates/polars-io/src/utils/other.rs +++ b/crates/polars-io/src/utils/other.rs @@ -6,14 +6,14 @@ use once_cell::sync::Lazy; use polars_core::prelude::*; #[cfg(any(feature = "ipc_streaming", feature = "parquet"))] use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref}; -use polars_utils::mmap::MMapSemaphore; +use polars_utils::mmap::{MMapSemaphore, MemSlice}; use regex::{Regex, RegexBuilder}; use crate::mmap::{MmapBytesReader, ReaderBytes}; -pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( - reader: &'a mut R, -) -> PolarsResult> { +pub fn get_reader_bytes( + reader: &mut R, +) -> PolarsResult> { // we have a file so we can mmap // only seekable files are mmap-able if let Some((file, offset)) = reader @@ -23,14 +23,8 @@ pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( { let mut options = memmap::MmapOptions::new(); options.offset(offset); - - // somehow bck thinks borrows alias - // this is sound as file was already bound to 'a - use std::fs::File; - - let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; let mmap = MMapSemaphore::new_from_file_with_options(file, options)?; - Ok(ReaderBytes::Mapped(mmap, file)) + Ok(ReaderBytes::Owned(MemSlice::from_mmap(Arc::new(mmap)))) } else { // we can get the bytes for free if reader.to_bytes().is_some() { @@ -40,7 +34,7 @@ pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( // we have to read to an owned buffer to get the bytes. let mut bytes = Vec::with_capacity(1024 * 128); reader.read_to_end(&mut bytes)?; - Ok(ReaderBytes::Owned(bytes)) + Ok(ReaderBytes::Owned(bytes.into())) } } } @@ -79,20 +73,43 @@ pub(crate) fn columns_to_projection( Ok(prj) } +#[cfg(debug_assertions)] +fn check_offsets(dfs: &[DataFrame]) { + dfs.windows(2).for_each(|s| { + let a = &s[0].get_columns()[0]; + let b = &s[1].get_columns()[0]; + + let prev = a.get(a.len() - 1).unwrap().extract::().unwrap(); + let next = b.get(0).unwrap().extract::().unwrap(); + assert_eq!(prev + 1, next); + }) +} + /// Because of threading every row starts from `0` or from `offset`. /// We must correct that so that they are monotonically increasing. #[cfg(any(feature = "csv", feature = "json"))] pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { if !dfs.is_empty() { - let mut previous = dfs[0].height() as IdxSize + offset; - for df in &mut dfs[1..] { + let mut previous = offset; + for df in &mut *dfs { + if df.is_empty() { + continue; + } let n_read = df.height() as IdxSize; if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = &*s + previous; + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } } previous += n_read; } } + #[cfg(debug_assertions)] + { + check_offsets(dfs) + } } /// Because of threading every row starts from `0` or from `offset`. @@ -101,15 +118,21 @@ pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], offset: IdxSize) { assert_eq!(dfs.len(), heights.len()); if !dfs.is_empty() { - let mut previous = heights[0] + offset; - for i in 1..dfs.len() { + let mut previous = offset; + for i in 0..dfs.len() { let df = &mut dfs[i]; - let n_read = heights[i]; + if df.is_empty() { + continue; + } if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = &*s + previous; + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } } - + let n_read = heights[i]; previous += n_read; } } diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs index eb4c12954a8d..5e6977eff3e2 100644 --- a/crates/polars-json/src/json/deserialize.rs +++ b/crates/polars-json/src/json/deserialize.rs @@ -17,169 +17,245 @@ const JSON_NULL_VALUE: BorrowedValue = BorrowedValue::Static(StaticNode::Null); fn deserialize_boolean_into<'a, A: Borrow>>( target: &mut MutableBooleanArray, rows: &[A], -) { - let iter = rows.iter().map(|row| match row.borrow() { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::Bool(v)) => Some(v), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "boolean") } fn deserialize_primitive_into<'a, T: NativeType + NumCast, A: Borrow>>( target: &mut MutablePrimitiveArray, rows: &[A], -) { - let iter = rows.iter().map(|row| match row.borrow() { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::I64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::U64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::F64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::Bool(v)) => T::from(*v as u8), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "numeric") } -fn deserialize_binary<'a, A: Borrow>>(rows: &[A]) -> BinaryArray { - let iter = rows.iter().map(|row| match row.borrow() { +fn deserialize_binary<'a, A: Borrow>>( + rows: &[A], +) -> PolarsResult> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::String(v) => Some(v.as_bytes()), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); - BinaryArray::from_trusted_len_iter(iter) + let out = BinaryArray::from_trusted_len_iter(iter); + check_err_idx(rows, err_idx, "binary")?; + Ok(out) } fn deserialize_utf8_into<'a, O: Offset, A: Borrow>>( target: &mut MutableUtf8Array, rows: &[A], -) { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); let mut scratch = String::new(); - for row in rows { + for (i, row) in rows.iter().enumerate() { match row.borrow() { BorrowedValue::String(v) => target.push(Some(v.as_ref())), BorrowedValue::Static(StaticNode::Bool(v)) => { target.push(Some(if *v { "true" } else { "false" })) }, - BorrowedValue::Static(node) if !matches!(node, StaticNode::Null) => { + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { write!(scratch, "{node}").unwrap(); target.push(Some(scratch.as_str())); scratch.clear(); }, - _ => target.push_null(), + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, } } + check_err_idx(rows, err_idx, "string") } fn deserialize_utf8view_into<'a, A: Borrow>>( target: &mut MutableBinaryViewArray, rows: &[A], -) { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); let mut scratch = String::new(); - for row in rows { + for (i, row) in rows.iter().enumerate() { match row.borrow() { BorrowedValue::String(v) => target.push_value(v.as_ref()), BorrowedValue::Static(StaticNode::Bool(v)) => { target.push_value(if *v { "true" } else { "false" }) }, - BorrowedValue::Static(node) if !matches!(node, StaticNode::Null) => { + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { write!(scratch, "{node}").unwrap(); target.push_value(scratch.as_str()); scratch.clear(); }, - _ => target.push_null(), + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, } } + check_err_idx(rows, err_idx, "string") } fn deserialize_list<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> ListArray { + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { + let mut err_idx = rows.len(); let child = ListArray::::get_child_type(&dtype); let mut validity = MutableBitmap::with_capacity(rows.len()); let mut offsets = Offsets::::with_capacity(rows.len()); let mut inner = vec![]; - rows.iter().for_each(|row| match row.borrow() { - BorrowedValue::Array(value) => { - inner.extend(value.iter()); - validity.push(true); - offsets - .try_push(value.len()) - .expect("List offset is too large :/"); - }, - BorrowedValue::Static(StaticNode::Null) => { - validity.push(false); - offsets.extend_constant(1) - }, - value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { - inner.push(value); - validity.push(true); - offsets.try_push(1).expect("List offset is too large :/"); - }, - _ => { - validity.push(false); - offsets.extend_constant(1); - }, - }); + rows.iter() + .enumerate() + .for_each(|(i, row)| match row.borrow() { + BorrowedValue::Array(value) => { + inner.extend(value.iter()); + validity.push(true); + offsets + .try_push(value.len()) + .expect("List offset is too large :/"); + }, + BorrowedValue::Static(StaticNode::Null) => { + validity.push(false); + offsets.extend_constant(1) + }, + value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { + inner.push(value); + validity.push(true); + offsets.try_push(1).expect("List offset is too large :/"); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + }); + + check_err_idx(rows, err_idx, "list")?; - let values = _deserialize(&inner, child.clone()); + let values = _deserialize(&inner, child.clone(), allow_extra_fields_in_struct)?; - ListArray::::new(dtype, offsets.into(), values, validity.into()) + Ok(ListArray::::new( + dtype, + offsets.into(), + values, + validity.into(), + )) } fn deserialize_struct<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> StructArray { + allow_extra_fields_in_struct: bool, +) -> PolarsResult { + let mut err_idx = rows.len(); let fields = StructArray::get_fields(&dtype); - let mut values = fields + let mut out_values = fields .iter() .map(|f| (f.name.as_str(), (f.dtype(), vec![]))) .collect::>(); let mut validity = MutableBitmap::with_capacity(rows.len()); + // Custom error tracker + let mut extra_field = None; - rows.iter().for_each(|row| { + rows.iter().enumerate().for_each(|(i, row)| { match row.borrow() { - BorrowedValue::Object(value) => { - values.iter_mut().for_each(|(s, (_, inner))| { - inner.push(value.get(*s).unwrap_or(&JSON_NULL_VALUE)) - }); + BorrowedValue::Object(values) => { + let mut n_matched = 0usize; + for (&key, &mut (_, ref mut inner)) in out_values.iter_mut() { + if let Some(v) = values.get(key) { + n_matched += 1; + inner.push(v) + } else { + inner.push(&JSON_NULL_VALUE) + } + } + validity.push(true); + + if n_matched < values.len() && extra_field.is_none() { + for k in values.keys() { + if !out_values.contains_key(k.as_ref()) { + extra_field = Some(k.as_ref()) + } + } + } }, - _ => { - values + BorrowedValue::Static(StaticNode::Null) => { + out_values .iter_mut() .for_each(|(_, (_, inner))| inner.push(&JSON_NULL_VALUE)); validity.push(false); }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, }; }); + if let Some(v) = extra_field { + if !allow_extra_fields_in_struct { + polars_bail!(ComputeError: "extra key in struct data: {}", v) + } + } + + check_err_idx(rows, err_idx, "struct")?; + // ensure we collect in the proper order let values = fields .iter() .map(|fld| { - let (dtype, vals) = values.get(fld.name.as_str()).unwrap(); - _deserialize(vals, (*dtype).clone()) + let (dtype, vals) = out_values.get(fld.name.as_str()).unwrap(); + _deserialize(vals, (*dtype).clone(), allow_extra_fields_in_struct) }) - .collect::>(); + .collect::>>()?; - StructArray::new(dtype.clone(), values, validity.into()) + Ok(StructArray::new( + dtype.clone(), + rows.len(), + values, + validity.into(), + )) } fn fill_array_from( - f: fn(&mut MutablePrimitiveArray, &[B]), + f: fn(&mut MutablePrimitiveArray, &[B]) -> PolarsResult<()>, dtype: ArrowDataType, rows: &[B], -) -> Box +) -> PolarsResult> where T: NativeType, A: From> + Array, { let mut array = MutablePrimitiveArray::::with_capacity(rows.len()).to(dtype); - f(&mut array, rows); - Box::new(A::from(array)) + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) } /// A trait describing an array with a backing store that can be preallocated to @@ -236,22 +312,34 @@ impl Container for MutableUtf8Array { } } -fn fill_generic_array_from(f: fn(&mut M, &[B]), rows: &[B]) -> Box +fn fill_generic_array_from( + f: fn(&mut M, &[B]) -> PolarsResult<()>, + rows: &[B], +) -> PolarsResult> where M: Container, A: From + Array, { let mut array = M::with_capacity(rows.len()); - f(&mut array, rows); - Box::new(A::from(array)) + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) } pub(crate) fn _deserialize<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> Box { + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { match &dtype { - ArrowDataType::Null => Box::new(NullArray::new(dtype, rows.len())), + ArrowDataType::Null => { + if let Some(err_idx) = (0..rows.len()) + .find(|i| !matches!(rows[*i].borrow(), BorrowedValue::Static(StaticNode::Null))) + { + check_err_idx(rows, err_idx, "null")?; + } + + Ok(Box::new(NullArray::new(dtype, rows.len()))) + }, ArrowDataType::Boolean => { fill_generic_array_from::<_, _, BooleanArray>(deserialize_boolean_into, rows) }, @@ -277,7 +365,8 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) }, ArrowDataType::Timestamp(tu, tz) => { - let iter = rows.iter().map(|row| match row.borrow() { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::I64(v)) => Some(*v), BorrowedValue::String(v) => match (tu, tz) { (_, None) => temporal_conversions::utf8_to_naive_timestamp_scalar(v, "%+", tu), @@ -286,9 +375,15 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( temporal_conversions::utf8_to_timestamp_scalar(v, "%+", &tz, tu) }, }, - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); - Box::new(Int64Array::from_iter(iter).to(dtype)) + let out = Box::new(Int64Array::from_iter(iter).to(dtype)); + check_err_idx(rows, err_idx, "timestamp")?; + Ok(out) }, ArrowDataType::UInt8 => { fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) @@ -315,19 +410,51 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( ArrowDataType::Utf8View => { fill_generic_array_from::<_, _, Utf8ViewArray>(deserialize_utf8view_into, rows) }, - ArrowDataType::LargeList(_) => Box::new(deserialize_list(rows, dtype)), - ArrowDataType::LargeBinary => Box::new(deserialize_binary(rows)), - ArrowDataType::Struct(_) => Box::new(deserialize_struct(rows, dtype)), + ArrowDataType::LargeList(_) => Ok(Box::new(deserialize_list( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), + ArrowDataType::LargeBinary => Ok(Box::new(deserialize_binary(rows)?)), + ArrowDataType::Struct(_) => Ok(Box::new(deserialize_struct( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), _ => todo!(), } } -pub fn deserialize(json: &BorrowedValue, dtype: ArrowDataType) -> PolarsResult> { +pub fn deserialize( + json: &BorrowedValue, + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { match json { BorrowedValue::Array(rows) => match dtype { - ArrowDataType::LargeList(inner) => Ok(_deserialize(rows, inner.dtype)), + ArrowDataType::LargeList(inner) => { + _deserialize(rows, inner.dtype, allow_extra_fields_in_struct) + }, _ => todo!("read an Array from a non-Array data type"), }, - _ => Ok(_deserialize(&[json], dtype)), + _ => _deserialize(&[json], dtype, allow_extra_fields_in_struct), } } + +fn check_err_idx<'a>( + rows: &[impl Borrow>], + err_idx: usize, + type_name: &'static str, +) -> PolarsResult<()> { + if err_idx != rows.len() { + polars_bail!( + ComputeError: + r#"error deserializing value "{:?}" as {}. \ + Try increasing `infer_schema_length` or specifying a schema. + "#, + rows[err_idx].borrow(), type_name, + ) + } + + Ok(()) +} diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs index a23b245b68b2..6796ef7436bb 100644 --- a/crates/polars-json/src/json/write/mod.rs +++ b/crates/polars-json/src/json/write/mod.rs @@ -101,7 +101,7 @@ impl<'a> RecordSerializer<'a> { } } -impl<'a> FallibleStreamingIterator for RecordSerializer<'a> { +impl FallibleStreamingIterator for RecordSerializer<'_> { type Item = [u8]; type Error = PolarsError; diff --git a/crates/polars-json/src/ndjson/deserialize.rs b/crates/polars-json/src/ndjson/deserialize.rs index 94a482b7b275..4441691cf034 100644 --- a/crates/polars-json/src/ndjson/deserialize.rs +++ b/crates/polars-json/src/ndjson/deserialize.rs @@ -18,19 +18,28 @@ pub fn deserialize_iter<'a>( dtype: ArrowDataType, buf_size: usize, count: usize, + allow_extra_fields_in_struct: bool, ) -> PolarsResult { let mut arr: Vec> = Vec::new(); let mut buf = Vec::with_capacity(std::cmp::min(buf_size + count + 2, u32::MAX as usize)); buf.push(b'['); - fn _deserializer(s: &mut [u8], dtype: ArrowDataType) -> PolarsResult> { + fn _deserializer( + s: &mut [u8], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, + ) -> PolarsResult> { let out = simd_json::to_borrowed_value(s) .map_err(|e| PolarsError::ComputeError(format!("json parsing error: '{e}'").into()))?; - Ok(if let BorrowedValue::Array(rows) = out { - super::super::json::deserialize::_deserialize(&rows, dtype.clone()) + if let BorrowedValue::Array(rows) = out { + super::super::json::deserialize::_deserialize( + &rows, + dtype.clone(), + allow_extra_fields_in_struct, + ) } else { unreachable!() - }) + } } let mut row_iter = rows.peekable(); @@ -42,7 +51,11 @@ pub fn deserialize_iter<'a>( if buf.len() + next_row_length >= u32::MAX as usize { let _ = buf.pop(); buf.push(b']'); - arr.push(_deserializer(&mut buf, dtype.clone())?); + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); buf.clear(); buf.push(b'['); } @@ -53,9 +66,13 @@ pub fn deserialize_iter<'a>( buf.push(b']'); if arr.is_empty() { - _deserializer(&mut buf, dtype.clone()) + _deserializer(&mut buf, dtype.clone(), allow_extra_fields_in_struct) } else { - arr.push(_deserializer(&mut buf, dtype.clone())?); + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); concatenate_owned_unchecked(&arr) } } diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 2dfd642cde1f..78f8274fb079 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -71,7 +71,7 @@ temporal = [ ] # debugging purposes fmt = ["polars-core/fmt", "polars-plan/fmt"] -strings = ["polars-plan/strings"] +strings = ["polars-plan/strings", "polars-stream?/strings"] future = [] dtype-full = [ @@ -163,7 +163,7 @@ bitwise = [ "polars-plan/bitwise", "polars-expr/bitwise", "polars-core/bitwise", - "polars-stream/bitwise", + "polars-stream?/bitwise", "polars-ops/bitwise", ] approx_unique = ["polars-plan/approx_unique"] @@ -203,6 +203,7 @@ dynamic_group_by = [ "temporal", "polars-expr/dynamic_group_by", "polars-mem-engine/dynamic_group_by", + "polars-stream?/dynamic_group_by", ] ewma = ["polars-plan/ewma"] ewma_by = ["polars-plan/ewma_by"] @@ -257,7 +258,7 @@ replace = ["polars-plan/replace"] binary_encoding = ["polars-plan/binary_encoding"] string_encoding = ["polars-plan/string_encoding"] -bigidx = ["polars-plan/bigidx"] +bigidx = ["polars-plan/bigidx", "polars-utils/bigidx"] polars_cloud = ["polars-plan/polars_cloud"] panic_on_schema = ["polars-plan/panic_on_schema", "polars-expr/panic_on_schema"] diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index 92783aa874ee..b94295ac12e4 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -52,6 +52,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { // Ensure we get the new schema. let output_field = eval_field_to_dtype(c.field().as_ref(), &expr, false); + let schema = Arc::new(Schema::from_iter(std::iter::once(output_field.clone()))); let expr = expr.clone(); let mut arena = Arena::with_capacity(10); @@ -60,7 +61,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { &aexpr, Context::Default, &arena, - None, + &schema, &mut ExpressionConversionState::new(true, 0), )?; @@ -99,9 +100,9 @@ pub trait ExprEvalExtension: IntoExpr + Sized { let c = c.slice(0, len); if (len - c.null_count()) >= min_periods { unsafe { - df_container.get_columns_mut().push(c.into_column()); + df_container.with_column_unchecked(c.into_column()); let out = phys_expr.evaluate(&df_container, &state)?.into_column(); - df_container.get_columns_mut().clear(); + df_container.clear_columns(); finish(out) } } else { diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index 4dae2529bc14..d73e4be5d13e 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -86,9 +86,9 @@ fn run_per_sublist( lst.into_iter() .map(|s| { s.and_then(|s| unsafe { - df_container.get_columns_mut().push(s.into_column()); + df_container.with_column_unchecked(s.into_column()); let out = phys_expr.evaluate(&df_container, &state); - df_container.get_columns_mut().clear(); + df_container.clear_columns(); match out { Ok(s) => Some(s), Err(e) => { diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index d7b11f7ab7fa..cf90c5232450 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -252,7 +252,7 @@ impl LazyFrame { /// Return a String describing the logical plan. /// - /// If `optimized` is `true`, explains the optimized plan. If `optimized` is `false, + /// If `optimized` is `true`, explains the optimized plan. If `optimized` is `false`, /// explains the naive, un-optimized plan. pub fn explain(&self, optimized: bool) -> PolarsResult { if optimized { @@ -612,12 +612,12 @@ impl LazyFrame { lp_arena, expr_arena, scratch, - Some(&|expr, expr_arena| { + Some(&|expr, expr_arena, schema| { let phys_expr = create_physical_expr( expr, Context::Default, expr_arena, - None, + schema, &mut ExpressionConversionState::new(true, 0), ) .ok()?; @@ -714,48 +714,12 @@ impl LazyFrame { pub fn collect(self) -> PolarsResult { #[cfg(feature = "new_streaming")] { - let auto_new_streaming = - std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); - if self.opt_state.contains(OptFlags::NEW_STREAMING) || auto_new_streaming { - // Try to run using the new streaming engine, falling back - // if it fails in a todo!() error if auto_new_streaming is set. - let mut new_stream_lazy = self.clone(); - new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; - new_stream_lazy.opt_state &= !OptFlags::STREAMING; - let mut alp_plan = new_stream_lazy.to_alp_optimized()?; - let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { - input: alp_plan.lp_top, - payload: SinkType::Memory, - }); - - let f = || { - polars_stream::run_query( - stream_lp_top, - alp_plan.lp_arena, - &mut alp_plan.expr_arena, - ) - }; - match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { - Ok(r) => return r, - Err(e) => { - // Fallback to normal engine if error is due to not being implemented - // and auto_new_streaming is set, otherwise propagate error. - if auto_new_streaming - && e.downcast_ref::<&str>() - .map(|s| s.starts_with("not yet implemented")) - .unwrap_or(false) - { - if polars_core::config::verbose() { - eprintln!("caught unimplemented error in new streaming engine, falling back to normal engine"); - } - } else { - std::panic::resume_unwind(e); - } - }, - } + let mut slf = self; + if let Some(df) = slf.try_new_streaming_if_requested(SinkType::Memory) { + return Ok(df?.unwrap()); } - let mut alp_plan = self.to_alp_optimized()?; + let mut alp_plan = slf.to_alp_optimized()?; let mut physical_plan = create_physical_plan( alp_plan.lp_top, &mut alp_plan.lp_arena, @@ -895,6 +859,54 @@ impl LazyFrame { ) } + #[cfg(feature = "new_streaming")] + pub fn try_new_streaming_if_requested( + &mut self, + payload: SinkType, + ) -> Option>> { + let auto_new_streaming = std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); + + if self.opt_state.contains(OptFlags::NEW_STREAMING) || auto_new_streaming { + // Try to run using the new streaming engine, falling back + // if it fails in a todo!() error if auto_new_streaming is set. + let mut new_stream_lazy = self.clone(); + new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; + new_stream_lazy.opt_state &= !OptFlags::STREAMING; + let mut alp_plan = match new_stream_lazy.to_alp_optimized() { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { + input: alp_plan.lp_top, + payload, + }); + + let f = || { + polars_stream::run_query(stream_lp_top, alp_plan.lp_arena, &mut alp_plan.expr_arena) + }; + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { + Ok(v) => return Some(v), + Err(e) => { + // Fallback to normal engine if error is due to not being implemented + // and auto_new_streaming is set, otherwise propagate error. + if auto_new_streaming + && e.downcast_ref::<&str>() + .map(|s| s.starts_with("not yet implemented")) + .unwrap_or(false) + { + if polars_core::config::verbose() { + eprintln!("caught unimplemented error in new streaming engine, falling back to normal engine"); + } + } else { + std::panic::resume_unwind(e); + } + }, + } + } + + None + } + #[cfg(any( feature = "ipc", feature = "parquet", @@ -903,11 +915,21 @@ impl LazyFrame { feature = "json", ))] fn sink(mut self, payload: SinkType, msg_alternative: &str) -> Result<(), PolarsError> { - self.opt_state |= OptFlags::STREAMING; + #[cfg(feature = "new_streaming")] + { + if self + .try_new_streaming_if_requested(payload.clone()) + .is_some() + { + return Ok(()); + } + } + self.logical_plan = DslPlan::Sink { input: Arc::new(self.logical_plan), payload, }; + self.opt_state |= OptFlags::STREAMING; let (mut state, mut physical_plan, is_streaming) = self.prepare_collect(true)?; polars_ensure!( is_streaming, @@ -1004,7 +1026,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1012,7 +1034,7 @@ impl LazyFrame { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` @@ -1327,7 +1349,7 @@ impl LazyFrame { right_on: E, args: JoinArgs, ) -> LazyFrame { - // if any of the nodes reads from files we must activate this this plan as well. + // if any of the nodes reads from files we must activate this plan as well. if other.opt_state.contains(OptFlags::FILE_CACHING) { self.opt_state |= OptFlags::FILE_CACHING; } @@ -1495,10 +1517,10 @@ impl LazyFrame { } /// Aggregate all the columns as their quantile values. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { self.map_private(DslFunction::Stats(StatsFunction::Quantile { quantile, - interpol, + method, })) } @@ -1885,7 +1907,7 @@ impl LazyGroupBy { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1893,7 +1915,7 @@ impl LazyGroupBy { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 3059384a1c8c..f3dff5710170 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -104,7 +104,7 @@ //! use polars_core::prelude::*; //! use polars_core::df; //! use polars_lazy::prelude::*; -//! use arrow::legacy::prelude::QuantileInterpolOptions; +//! use arrow::legacy::prelude::QuantileMethod; //! //! fn example() -> PolarsResult { //! let df = df!( @@ -118,7 +118,7 @@ //! .agg([ //! col("rain").min().alias("min_rain"), //! col("rain").sum().alias("sum_rain"), -//! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), +//! col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), //! ]) //! .sort(["date"], Default::default()) //! .collect() diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index 453337e616f8..08673ca1f032 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -24,21 +24,25 @@ pub(crate) fn prepare_expression_for_context( // create a dummy lazyframe and run a very simple optimization run so that // type coercion and simplify expression optimizations run. let column = Series::full_null(name, 0, dtype); - let lf = column - .into_frame() + let df = column.into_frame(); + let input_schema = Arc::new(df.schema()); + let lf = df .lazy() .without_optimizations() .with_simplify_expr(true) .select([expr.clone()]); let optimized = lf.optimize(&mut lp_arena, &mut expr_arena)?; let lp = lp_arena.get(optimized); - let aexpr = lp.get_exprs().pop().unwrap(); + let aexpr = lp + .get_exprs() + .pop() + .ok_or_else(|| polars_err!(ComputeError: "expected expressions in the context"))?; create_physical_expr( &aexpr, ctxt, &expr_arena, - None, + &input_schema, &mut ExpressionConversionState::new(true, 0), ) } diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index 777f769866d0..ad4d8cd1fb48 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -50,7 +50,7 @@ impl PhysicalPipedExpr for Wrap { fn to_physical_piped_expr( expr: &ExprIR, expr_arena: &Arena, - schema: Option<&SchemaRef>, + schema: &SchemaRef, ) -> PolarsResult> { // this is a double Arc explore if we can create a single of it. create_physical_expr( diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 7100c083bd47..6c84af4510b5 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -163,13 +163,19 @@ pub(crate) fn insert_streaming_nodes( execution_id += 1; match lp_arena.get(root) { Filter { input, predicate } - if is_streamable(predicate.node(), expr_arena, Context::Default) => + if is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Default::default()), + ) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - HStack { input, exprs, .. } if all_streamable(exprs, expr_arena, Context::Default) => { + HStack { input, exprs, .. } + if all_streamable(exprs, expr_arena, Default::default()) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) @@ -194,7 +200,13 @@ pub(crate) fn insert_streaming_nodes( state.operators_sinks.push(PipelineNode::Sink(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - Select { input, expr, .. } if all_streamable(expr, expr_arena, Context::Default) => { + Select { input, expr, .. } + if all_streamable( + expr, + expr_arena, + IsStreamableContext::new(Default::default()), + ) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index 6ed8e1cc67c8..615b74e7738f 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -204,7 +204,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> { let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = c.optimize(&mut lp_arena, &mut expr_arena).unwrap(); - // Ensure we get only one cache and the it is not above the join + // Ensure we get only one cache and it is not above the join // and ensure that every cache only has 1 hit. let cache_ids = (&lp_arena) .iter(lp) @@ -218,7 +218,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> { .. } => { assert_eq!(*cache_hits, 1); - assert!(matches!(lp_arena.get(*input), IR::DataFrameScan { .. })); + assert!(matches!(lp_arena.get(*input), IR::SimpleProjection { .. })); Some(*id) }, diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 4d482202cd67..95cbf586be67 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1486,7 +1486,7 @@ fn test_singleton_broadcast() -> PolarsResult<()> { #[test] fn test_list_in_select_context() -> PolarsResult<()> { let s = Column::new("a".into(), &[1, 2, 3]); - let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()).unwrap(); + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()); builder.append_series(s.as_materialized_series()).unwrap(); let expected = builder.finish().into_column(); diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs index 83c6ec2e5bda..ad41378b3086 100644 --- a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -144,7 +144,7 @@ fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResul if keys.len() == 1 { // we sample as that will work also with sorted data. - // not that sampling without replacement is very very expensive. don't do that. + // not that sampling without replacement is *very* expensive. don't do that. let s = keys[0].sample_n(sample_size, true, false, None).unwrap(); // fast multi-threaded way to get unique. let groups = s.as_materialized_series().group_tuples(true, false)?; @@ -156,7 +156,7 @@ fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResul .map(|s| s.slice(offset, sample_size)) .map(Column::from) .collect::>(); - let df = unsafe { DataFrame::new_no_checks(keys) }; + let df = unsafe { DataFrame::new_no_checks_height_from_first(keys) }; let names = df.get_column_names().into_iter().cloned(); let gb = df.group_by(names).unwrap(); Ok(finish(gb.get_groups())) diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index 477cdd79a162..47464849582e 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -340,9 +340,12 @@ pub(super) fn check_expand_literals( } // @scalar-opt - let selected_columns = selected_columns.into_iter().map(Column::from).collect(); + let selected_columns = selected_columns + .into_iter() + .map(Column::from) + .collect::>(); - let df = unsafe { DataFrame::new_no_checks(selected_columns) }; + let df = unsafe { DataFrame::new_no_checks_height_from_first(selected_columns) }; // a literal could be projected to a zero length dataframe. // This prevents a panic. diff --git a/crates/polars-mem-engine/src/executors/scan/ndjson.rs b/crates/polars-mem-engine/src/executors/scan/ndjson.rs index 58862bd71f9e..1f90e07a72c1 100644 --- a/crates/polars-mem-engine/src/executors/scan/ndjson.rs +++ b/crates/polars-mem-engine/src/executors/scan/ndjson.rs @@ -1,5 +1,6 @@ use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; +use polars_io::prelude::{JsonLineReader, SerReader}; use polars_io::utils::compression::maybe_decompress_bytes; use super::*; diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs index 067895ed593f..f74da737d0ac 100644 --- a/crates/polars-mem-engine/src/executors/scan/python_scan.rs +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -23,11 +23,9 @@ fn python_df_to_rust(py: Python, df: Bound) -> PolarsResult { let (ptr, len, cap) = raw_parts; unsafe { - Ok(DataFrame::new_no_checks(Vec::from_raw_parts( - ptr as *mut Column, - len, - cap, - ))) + Ok(DataFrame::new_no_checks_height_from_first( + Vec::from_raw_parts(ptr as *mut Column, len, cap), + )) } } diff --git a/crates/polars-mem-engine/src/executors/stack.rs b/crates/polars-mem-engine/src/executors/stack.rs index a5bb2f78ad89..e325e2982a00 100644 --- a/crates/polars-mem-engine/src/executors/stack.rs +++ b/crates/polars-mem-engine/src/executors/stack.rs @@ -64,7 +64,7 @@ impl StackExec { // new, unique column names. It is immediately // followed by a projection which pulls out the // possibly mismatching column lengths. - unsafe { df.get_columns_mut() }.extend(res.into_iter().map(Column::from)); + unsafe { df.column_extend_unchecked(res.into_iter().map(Column::from)) }; } else { let (df_height, df_width) = df.shape(); diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index e1b53bea2151..3a5e525867fb 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -168,7 +168,7 @@ fn create_physical_plan_impl( e, Context::Default, expr_arena, - Some(&options.schema), + &options.schema, &mut state, ) }; @@ -210,7 +210,7 @@ fn create_physical_plan_impl( }, SinkType::File { file_type, .. } => { polars_bail!(InvalidOperation: - "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_parquet()'" + "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_{file_type:?}()'" ) }, #[cfg(feature = "cloud")] @@ -239,7 +239,11 @@ fn create_physical_plan_impl( Ok(Box::new(executors::SliceExec { input, offset, len })) }, Filter { input, predicate } => { - let mut streamable = is_streamable(predicate.node(), expr_arena, Context::Default); + let mut streamable = is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); if streamable { // This can cause problems with string caches @@ -264,7 +268,7 @@ fn create_physical_plan_impl( &predicate, Context::Default, expr_arena, - Some(&input_schema), + &input_schema, &mut state, )?; Ok(Box::new(executors::FilterExec::new( @@ -297,7 +301,7 @@ fn create_physical_plan_impl( &pred, Context::Default, expr_arena, - output_schema.as_ref(), + output_schema.as_ref().unwrap_or(&file_info.schema), &mut state, ) }) @@ -374,16 +378,20 @@ fn create_physical_plan_impl( POOL.current_num_threads() > expr.len(), state.expr_depth, ); - - let streamable = - options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default); let phys_expr = create_physical_expressions_from_irs( &expr, Context::Default, expr_arena, - Some(&input_schema), + &input_schema, &mut state, )?; + + let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false)) + // If all columns are literal we would get a 1 row per thread. + && !phys_expr.iter().all(|p| { + p.is_literal() + }); + Ok(Box::new(executors::ProjectionExec { input, expr: phys_expr, @@ -419,13 +427,7 @@ fn create_physical_plan_impl( let mut state = ExpressionConversionState::new(true, state.expr_depth); let selection = predicate .map(|pred| { - create_physical_expr( - &pred, - Context::Default, - expr_arena, - Some(&schema), - &mut state, - ) + create_physical_expr(&pred, Context::Default, expr_arena, &schema, &mut state) }) .transpose()?; Ok(Box::new(executors::DataFrameExec { @@ -446,7 +448,7 @@ fn create_physical_plan_impl( &by_column, Context::Default, expr_arena, - Some(input_schema.as_ref()), + input_schema.as_ref(), &mut ExpressionConversionState::new(true, state.expr_depth), )?; let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; @@ -488,14 +490,14 @@ fn create_physical_plan_impl( &keys, Context::Default, expr_arena, - Some(&input_schema), + &input_schema, &mut ExpressionConversionState::new(true, state.expr_depth), )?; let phys_aggs = create_physical_expressions_from_irs( &aggs, Context::Aggregation, expr_arena, - Some(&input_schema), + &input_schema, &mut ExpressionConversionState::new(true, state.expr_depth), )?; @@ -594,21 +596,24 @@ fn create_physical_plan_impl( } else { false }; + let schema_left = lp_arena.get(input_left).schema(lp_arena).into_owned(); + let schema_right = lp_arena.get(input_right).schema(lp_arena).into_owned(); let input_left = create_physical_plan_impl(input_left, lp_arena, expr_arena, state)?; let input_right = create_physical_plan_impl(input_right, lp_arena, expr_arena, state)?; + let left_on = create_physical_expressions_from_irs( &left_on, Context::Default, expr_arena, - None, + &schema_left, &mut ExpressionConversionState::new(true, state.expr_depth), )?; let right_on = create_physical_expressions_from_irs( &right_on, Context::Default, expr_arena, - None, + &schema_right, &mut ExpressionConversionState::new(true, state.expr_depth), )?; let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone()); @@ -630,8 +635,12 @@ fn create_physical_plan_impl( let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - let streamable = - options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default); + let streamable = options.should_broadcast + && all_streamable( + &exprs, + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let mut state = ExpressionConversionState::new( POOL.current_num_threads() > exprs.len(), @@ -642,7 +651,7 @@ fn create_physical_plan_impl( &exprs, Context::Default, expr_arena, - Some(&input_schema), + &input_schema, &mut state, )?; Ok(Box::new(executors::StackExec { diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 027d846b485e..63e52cffc1e4 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -34,8 +34,10 @@ rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true } +regex-syntax = { workspace = true } serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } unicode-reverse = { workspace = true, optional = true } [dependencies.jsonpath_lib] diff --git a/crates/polars-ops/src/chunked_array/array/to_struct.rs b/crates/polars-ops/src/chunked_array/array/to_struct.rs index b00dbbf4d43b..2342c4909159 100644 --- a/crates/polars-ops/src/chunked_array/array/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/array/to_struct.rs @@ -23,7 +23,6 @@ pub trait ToStruct: AsArray { .as_deref() .unwrap_or(&arr_default_struct_name_gen); - polars_ensure!(n_fields != 0, ComputeError: "cannot create a struct with 0 fields"); let fields = POOL.install(|| { (0..n_fields) .into_par_iter() @@ -40,7 +39,7 @@ pub trait ToStruct: AsArray { .collect::>>() })?; - StructChunked::from_series(ca.name().clone(), fields.iter()) + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } } diff --git a/crates/polars-ops/src/chunked_array/cov.rs b/crates/polars-ops/src/chunked_array/cov.rs index 5a9b952097b8..dbfa6b48f4fb 100644 --- a/crates/polars-ops/src/chunked_array/cov.rs +++ b/crates/polars-ops/src/chunked_array/cov.rs @@ -1,196 +1,34 @@ -use num_traits::{ToPrimitive, Zero}; -use polars_compute::float_sum::FloatSum; +use num_traits::AsPrimitive; +use polars_compute::var_cov::{CovState, PearsonState}; use polars_core::prelude::*; use polars_core::utils::align_chunks_binary; -const COV_BUF_SIZE: usize = 64; - -/// Calculates the sum of x[i] * y[i] from 0..k. -fn multiply_sum(x: &[f64; COV_BUF_SIZE], y: &[f64; COV_BUF_SIZE], k: usize) -> f64 { - assert!(k <= COV_BUF_SIZE); - let tmp: [f64; COV_BUF_SIZE] = std::array::from_fn(|i| x[i] * y[i]); - FloatSum::sum(&tmp[..k]) -} - /// Compute the covariance between two columns. pub fn cov(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option where T: PolarsNumericType, - T::Native: ToPrimitive, -{ - if a.len() != b.len() { - None - } else { - let (a, b) = align_chunks_binary(a, b); - - let out = if a.null_count() > 0 || b.null_count() > 0 { - let iters = a.downcast_iter().zip(b.downcast_iter()).map(|(a, b)| { - a.into_iter().zip(b).filter_map(|(a, b)| match (a, b) { - (Some(a), Some(b)) => Some((*a, *b)), - _ => None, - }) - }); - online_cov(iters, ddof) - } else { - let iters = a - .downcast_iter() - .zip(b.downcast_iter()) - .map(|(a, b)| a.values_iter().copied().zip(b.values_iter().copied())); - online_cov(iters, ddof) - }; - Some(out) - } -} - -/// # Arguments -/// `iter` - Iterator over `T` tuple where any `Option` would skip the tuple. -fn online_cov(iters: I, ddof: u8) -> f64 -where - I: Iterator, - J: IntoIterator + Clone, - T: ToPrimitive, + T::Native: AsPrimitive, + ChunkedArray: ChunkVar, { - // The algorithm is derived from - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version - // We simply set the weights to 1.0. This allows us to simplify the expressions - // a lot, and move out subtractions out of sums. - let mut mean_x = 0.0; - let mut mean_y = 0.0; - let mut cxy = 0.0; - let mut n = 0.0; - - let mut x_tmp = [0.0; COV_BUF_SIZE]; - let mut y_tmp = [0.0; COV_BUF_SIZE]; - - for iter in iters { - let mut iter = iter.clone().into_iter(); - - loop { - let mut k = 0; - for (x, y) in iter.by_ref().take(COV_BUF_SIZE) { - let x = x.to_f64().unwrap(); - let y = y.to_f64().unwrap(); - - x_tmp[k] = x; - y_tmp[k] = y; - k += 1; - } - if k == 0 { - break; - } - - // TODO: combine these all in one SIMD'ized pass. - let xsum: f64 = FloatSum::sum(&x_tmp[..k]); - let ysum: f64 = FloatSum::sum(&y_tmp[..k]); - let xysum = multiply_sum(&x_tmp, &y_tmp, k); - - let old_mean_x = mean_x; - let old_mean_y = mean_y; - n += k as f64; - mean_x += (xsum - k as f64 * old_mean_x) / n; - mean_y += (ysum - k as f64 * old_mean_y) / n; - - cxy += xysum - xsum * old_mean_y - ysum * mean_x + mean_x * old_mean_y * (k as f64); - } + let (a, b) = align_chunks_binary(a, b); + let mut out = CovState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::var_cov::cov(a, b)) } - - cxy / (n - ddof as f64) + out.finalize(ddof) } /// Compute the pearson correlation between two columns. pub fn pearson_corr(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option where T: PolarsNumericType, - T::Native: ToPrimitive, + T::Native: AsPrimitive, ChunkedArray: ChunkVar, { let (a, b) = align_chunks_binary(a, b); - - let out = if a.null_count() > 0 || b.null_count() > 0 { - let iters = a.downcast_iter().zip(b.downcast_iter()).map(|(a, b)| { - a.into_iter().zip(b).filter_map(|(a, b)| match (a, b) { - (Some(a), Some(b)) => Some((*a, *b)), - _ => None, - }) - }); - online_pearson_corr(iters, ddof) - } else { - let iters = a - .downcast_iter() - .zip(b.downcast_iter()) - .map(|(a, b)| a.values_iter().copied().zip(b.values_iter().copied())); - online_pearson_corr(iters, ddof) - }; - Some(out) -} - -/// # Arguments -/// `iter` - Iterator over `T` tuple where any `Option` would skip the tuple. -fn online_pearson_corr(iters: I, ddof: u8) -> f64 -where - I: Iterator, - J: IntoIterator + Clone, - T: ToPrimitive, -{ - // Algorithm is same as cov, we just maintain cov(X, X), cov(X, Y), and - // cov(Y, Y), noting that var(X) = cov(X, X). - // Then corr(X, Y) = cov(X, Y)/(std(X) * std(Y)). - let mut mean_x = 0.0; - let mut mean_y = 0.0; - let mut cxy = 0.0; - let mut cxx = 0.0; - let mut cyy = 0.0; - let mut n = 0.0; - - let mut x_tmp = [0.0; COV_BUF_SIZE]; - let mut y_tmp = [0.0; COV_BUF_SIZE]; - - for iter in iters { - let mut iter = iter.clone().into_iter(); - - loop { - let mut k = 0; - for (x, y) in iter.by_ref().take(COV_BUF_SIZE) { - let x = x.to_f64().unwrap(); - let y = y.to_f64().unwrap(); - - x_tmp[k] = x; - y_tmp[k] = y; - k += 1; - } - if k == 0 { - break; - } - - // TODO: combine these all in one SIMD'ized pass. - let xsum: f64 = FloatSum::sum(&x_tmp[..k]); - let ysum: f64 = FloatSum::sum(&y_tmp[..k]); - let xxsum = multiply_sum(&x_tmp, &x_tmp, k); - let xysum = multiply_sum(&x_tmp, &y_tmp, k); - let yysum = multiply_sum(&y_tmp, &y_tmp, k); - - let old_mean_x = mean_x; - let old_mean_y = mean_y; - n += k as f64; - mean_x += (xsum - k as f64 * old_mean_x) / n; - mean_y += (ysum - k as f64 * old_mean_y) / n; - - cxx += xxsum - xsum * old_mean_x - xsum * mean_x + mean_x * old_mean_x * (k as f64); - cxy += xysum - xsum * old_mean_y - ysum * mean_x + mean_x * old_mean_y * (k as f64); - cyy += yysum - ysum * old_mean_y - ysum * mean_y + mean_y * old_mean_y * (k as f64); - } - } - - let sample_n = n - ddof as f64; - let sample_cov = cxy / sample_n; - let sample_std_x = (cxx / sample_n).sqrt(); - let sample_std_y = (cyy / sample_n).sqrt(); - - let denom = sample_std_x * sample_std_y; - let result = sample_cov / denom; - if denom.is_zero() { - f64::NAN - } else { - result + let mut out = PearsonState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::var_cov::pearson_corr(a, b)) } + Some(out.finalize(ddof)) } diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 345f3689984c..b31c77e8e365 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -24,7 +24,7 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns(&|s| s.take_chunked_unchecked(idx, sorted)); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks_height_from_first(cols) } } /// Take elements by a slice of optional [`ChunkId`]s. /// @@ -35,7 +35,7 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns(&|s| s.take_opt_chunked_unchecked(idx)); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks_height_from_first(cols) } } /// # Safety @@ -45,7 +45,7 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns_par(&|s| s.take_chunked_unchecked(idx, sorted)); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks_height_from_first(cols) } } /// # Safety @@ -57,7 +57,7 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns_par(&|s| s.take_opt_chunked_unchecked(idx)); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks_height_from_first(cols) } } } diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index ca906d12851c..dbf8337ae637 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -136,7 +136,7 @@ where let out = fields.pop().unwrap(); out.with_name(ca.name().clone()) } else { - StructChunked::from_series(ca.name().clone(), fields.iter()) + StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter()) .unwrap() .into_series() } diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 3584fa792d07..fc498c25ce44 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -653,7 +653,7 @@ pub trait ListNameSpaceImpl: AsList { ca.get_values_size() + vals_size_other + 1, length, ca.name().clone(), - )?; + ); ca.into_iter().for_each(|opt_s| { let opt_s = opt_s.map(|mut s| { for append in &to_append { @@ -690,7 +690,7 @@ pub trait ListNameSpaceImpl: AsList { ca.get_values_size() + vals_size_other + 1, length, ca.name().clone(), - )?; + ); for _ in 0..ca.len() { let mut acc = match first_iter.next().unwrap() { diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index cdd245bce8b7..fad1bcebb9a1 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -5,83 +5,220 @@ use polars_utils::pl_str::PlSmallStr; use super::*; -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ListToStructArgs { + FixedWidth(Arc<[PlSmallStr]>), + InferWidth { + infer_field_strategy: ListToStructWidthStrategy, + get_index_name: Option, + /// If this is 0, it means unbounded. + max_fields: usize, + }, +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum ListToStructWidthStrategy { FirstNonNull, MaxWidth, } -fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize { - match n_fields { - ListToStructWidthStrategy::MaxWidth => { - let mut max = 0; - - ca.downcast_iter().for_each(|arr| { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - let len = (*o - last) as usize; - max = std::cmp::max(max, len); - last = *o; +impl ListToStructArgs { + pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult { + let DataType::List(inner_dtype) = input_dtype else { + polars_bail!( + InvalidOperation: + "attempted list to_struct on non-list dtype: {}", + input_dtype + ); + }; + let inner_dtype = inner_dtype.as_ref(); + + match self { + Self::FixedWidth(names) => Ok(DataType::Struct( + names + .iter() + .map(|x| Field::new(x.clone(), inner_dtype.clone())) + .collect::>(), + )), + Self::InferWidth { + get_index_name, + max_fields, + .. + } if *max_fields > 0 => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + Ok(DataType::Struct( + (0..*max_fields) + .map(|i| Field::new(get_index_name_func(i), inner_dtype.clone())) + .collect::>(), + )) + }, + Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), + } + } + + fn det_n_fields(&self, ca: &ListChunked) -> usize { + match self { + Self::FixedWidth(v) => v.len(), + Self::InferWidth { + infer_field_strategy, + max_fields, + .. + } => { + let inferred = match infer_field_strategy { + ListToStructWidthStrategy::MaxWidth => { + let mut max = 0; + + ca.downcast_iter().for_each(|arr| { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + let len = (*o - last) as usize; + max = std::cmp::max(max, len); + last = *o; + } + }); + max + }, + ListToStructWidthStrategy::FirstNonNull => { + let mut len = 0; + for arr in ca.downcast_iter() { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + len = (*o - last) as usize; + if len > 0 { + break; + } + last = *o; + } + if len > 0 { + break; + } + } + len + }, + }; + + if *max_fields > 0 { + inferred.min(*max_fields) + } else { + inferred } - }); - max - }, - ListToStructWidthStrategy::FirstNonNull => { - let mut len = 0; - for arr in ca.downcast_iter() { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - len = (*o - last) as usize; - if len > 0 { - break; - } - last = *o; + }, + } + } + + fn set_output_names(&self, columns: &mut [Series]) { + match self { + Self::FixedWidth(v) => { + assert_eq!(columns.len(), v.len()); + + for (c, name) in columns.iter_mut().zip(v.iter()) { + c.rename(name.clone()); } - if len > 0 { - break; + }, + Self::InferWidth { get_index_name, .. } => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + + for (i, c) in columns.iter_mut().enumerate() { + c.rename(get_index_name_func(i)); } - } - len - }, + }, + } + } +} + +#[derive(Clone)] +pub struct NameGenerator(pub Arc PlSmallStr + Send + Sync>); + +impl NameGenerator { + pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self { + Self(Arc::new(func)) + } +} + +impl std::fmt::Debug for NameGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "list::to_struct::NameGenerator function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for NameGenerator {} + +impl PartialEq for NameGenerator { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) } } -pub type NameGenerator = Arc PlSmallStr + Send + Sync>; +impl std::hash::Hash for NameGenerator { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr { format_pl_smallstr!("field_{idx}") } pub trait ToStruct: AsList { - fn to_struct( - &self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - ) -> PolarsResult { + fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult { let ca = self.as_list(); - let n_fields = det_n_fields(ca, n_fields); + let n_fields = args.det_n_fields(ca); - let name_generator = name_generator - .as_deref() - .unwrap_or(&_default_struct_name_gen); - - polars_ensure!(n_fields != 0, ComputeError: "cannot create a struct with 0 fields"); - let fields = POOL.install(|| { + let mut fields = POOL.install(|| { (0..n_fields) .into_par_iter() - .map(|i| { - ca.lst_get(i as i64, true).map(|mut s| { - s.rename(name_generator(i)); - s - }) - }) + .map(|i| ca.lst_get(i as i64, true)) .collect::>>() })?; - StructChunked::from_series(ca.name().clone(), fields.iter()) + args.set_output_names(&mut fields); + + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } } impl ToStruct for ListChunked {} + +#[cfg(feature = "serde")] +mod _serde_impl { + use super::*; + + impl serde::Serialize for NameGenerator { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + Err(S::Error::custom( + "cannot serialize name generator function for to_struct, \ + consider passing a list of field names instead.", + )) + } + } + + impl<'de> serde::Deserialize<'de> for NameGenerator { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + Err(D::Error::custom( + "invalid data: attempted to deserialize list::to_struct::NameGenerator", + )) + } + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/case.rs b/crates/polars-ops/src/chunked_array/strings/case.rs index 7bb348e28803..4004f122143f 100644 --- a/crates/polars-ops/src/chunked_array/strings/case.rs +++ b/crates/polars-ops/src/chunked_array/strings/case.rs @@ -5,7 +5,7 @@ fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec) { out.clear(); out.reserve(b.len()); - const USIZE_SIZE: usize = std::mem::size_of::(); + const USIZE_SIZE: usize = size_of::(); const MAGIC_UNROLL: usize = 2; const N: usize = USIZE_SIZE * MAGIC_UNROLL; const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]); diff --git a/crates/polars-ops/src/chunked_array/strings/escape_regex.rs b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs new file mode 100644 index 000000000000..1edb9146e9f4 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs @@ -0,0 +1,21 @@ +use polars_core::prelude::{StringChunked, StringChunkedBuilder}; + +#[inline] +pub fn escape_regex_str(s: &str) -> String { + regex_syntax::escape(s) +} + +pub fn escape_regex(ca: &StringChunked) -> StringChunked { + let mut buffer = String::new(); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + for opt_s in ca.iter() { + if let Some(s) = opt_s { + buffer.clear(); + regex_syntax::escape_into(s, &mut buffer); + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } + builder.finish() +} diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index cb26d66f7aff..e7068e96b614 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -36,7 +36,7 @@ fn extract_groups_array( } let values = builders.into_iter().map(|a| a.freeze().boxed()).collect(); - Ok(StructArray::new(dtype.clone(), values, arr.validity().cloned()).boxed()) + Ok(StructArray::new(dtype.clone(), arr.len(), values, arr.validity().cloned()).boxed()) } #[cfg(feature = "extract_groups")] @@ -50,6 +50,7 @@ pub(super) fn extract_groups( if n_fields == 1 { return StructChunked::from_series( ca.name().clone(), + ca.len(), [Series::new_null(ca.name().clone(), ca.len())].iter(), ) .map(|ca| ca.into_series()); diff --git a/crates/polars-ops/src/chunked_array/strings/json_path.rs b/crates/polars-ops/src/chunked_array/strings/json_path.rs index fe8783530d6a..7ef813d63027 100644 --- a/crates/polars-ops/src/chunked_array/strings/json_path.rs +++ b/crates/polars-ops/src/chunked_array/strings/json_path.rs @@ -98,6 +98,8 @@ pub trait Utf8JsonPathImpl: AsString { infer_schema_len: Option, ) -> PolarsResult { let ca = self.as_string(); + // Ignore extra fields instead of erroring if the dtype was explicitly given. + let allow_extra_fields_in_struct = dtype.is_some(); let dtype = match dtype { Some(dt) => dt, None => ca.json_infer(infer_schema_len)?, @@ -110,6 +112,7 @@ pub trait Utf8JsonPathImpl: AsString { dtype.to_arrow(CompatLevel::newest()), buf_size, ca.len(), + allow_extra_fields_in_struct, ) .map_err(|e| polars_err!(ComputeError: "error deserializing JSON: {}", e))?; Series::try_from((PlSmallStr::EMPTY, array)) @@ -204,6 +207,7 @@ mod tests { let expected_series = StructChunked::from_series( "".into(), + 4, [ Series::new("a".into(), &[None, Some(1), Some(2), None]), Series::new("b".into(), &[None, Some("hello"), Some("goodbye"), None]), diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index b9149983307b..326349c36815 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -3,6 +3,8 @@ mod case; #[cfg(feature = "strings")] mod concat; #[cfg(feature = "strings")] +mod escape_regex; +#[cfg(feature = "strings")] mod extract; #[cfg(feature = "find_many")] mod find_many; @@ -20,12 +22,13 @@ mod split; mod strip; #[cfg(feature = "strings")] mod substring; - #[cfg(all(not(feature = "nightly"), feature = "strings"))] mod unicode_internals; #[cfg(feature = "strings")] pub use concat::*; +#[cfg(feature = "strings")] +pub use escape_regex::*; #[cfg(feature = "find_many")] pub use find_many::*; #[cfg(feature = "extract_jsonpath")] diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 812dfbfcba91..93574e5f3080 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -640,6 +640,12 @@ pub trait StringNameSpaceImpl: AsString { substring::tail(ca, n.i64()?) } + #[cfg(feature = "strings")] + /// Escapes all regular expression meta characters in the string. + fn str_escape_regex(&self) -> StringChunked { + let ca = self.as_string(); + escape_regex::escape_regex(ca) + } } impl StringNameSpaceImpl for StringChunked {} diff --git a/crates/polars-ops/src/chunked_array/strings/split.rs b/crates/polars-ops/src/chunked_array/strings/split.rs index 31c15a70cb08..fce94fa1ad09 100644 --- a/crates/polars-ops/src/chunked_array/strings/split.rs +++ b/crates/polars-ops/src/chunked_array/strings/split.rs @@ -149,7 +149,7 @@ where }) .collect::>(); - StructChunked::from_series(ca.name().clone(), fields.iter()) + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } pub fn split_helper<'a, F, I>(ca: &'a StringChunked, by: &'a StringChunked, op: F) -> ListChunked diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index b4d347170cbb..4c845a2ba541 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -18,6 +18,7 @@ pub type ChunkJoinIds = Vec; use polars_core::export::once_cell::sync::Lazy; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; #[derive(Clone, PartialEq, Eq, Debug, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -108,8 +109,9 @@ impl JoinArgs { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum JoinType { Inner, Left, diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 9332b10e392b..e4f2e7e8fa62 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -1,104 +1,20 @@ use std::hash::Hash; -use hashbrown::HashMap; use num_traits::Zero; -use polars_core::hashing::{ - IdxHash, _df_rows_to_hashes_threaded_vertical, populate_multiple_key_hashmap, - _HASHMAP_INIT_SIZE, -}; +use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; use polars_core::series::BitRepr; use polars_core::utils::flatten::flatten_nullable; -use polars_core::utils::{_set_partition_size, split_and_flatten}; -use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL}; +use polars_core::utils::split_and_flatten; +use polars_core::{with_match_physical_float_polars_type, POOL}; use polars_utils::abs_diff::AbsDiff; -use polars_utils::aliases::PlRandomState; use polars_utils::hashing::{hash_to_partition, DirtyHash}; -use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; -use polars_utils::pl_str::PlSmallStr; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; -use polars_utils::unitvec; use rayon::prelude::*; use super::*; - -/// Compare the rows of two [`DataFrame`]s -pub(crate) unsafe fn compare_df_rows2( - left: &DataFrame, - right: &DataFrame, - left_idx: usize, - right_idx: usize, - join_nulls: bool, -) -> bool { - for (l, r) in left.get_columns().iter().zip(right.get_columns()) { - let l = l.get_unchecked(left_idx); - let r = r.get_unchecked(right_idx); - if !l.eq_missing(&r, join_nulls) { - return false; - } - } - true -} - -pub(crate) fn create_probe_table( - hashes: &[UInt64Chunked], - keys: &DataFrame, -) -> Vec> { - let n_partitions = _set_partition_size(); - - // We will create a hashtable in every thread. - // We use the hash to partition the keys to the matching hashtable. - // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions) - .into_par_iter() - .map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || unitvec![idx], - |v| v.push(idx), - ) - } - idx += 1; - }); - - offset += len as IdxSize; - } - } - hash_tbl - }) - .collect() - }) -} - -pub(crate) fn get_offsets(probe_hashes: &[UInt64Chunked]) -> Vec { - probe_hashes - .iter() - .map(|ph| ph.len()) - .scan(0, |state, val| { - let out = *state; - *state += val; - Some(out) - }) - .collect() -} +use crate::frame::join::{prepare_binary, prepare_keys_multiple}; fn compute_len_offsets>(iter: I) -> Vec { let mut cumlen = 0; @@ -238,14 +154,16 @@ where Ok(flatten_nullable(&bufs)) } -fn asof_join_by_binary( - by_left: &BinaryChunked, - by_right: &BinaryChunked, +fn asof_join_by_binary( + by_left: &ChunkedArray, + by_right: &ChunkedArray, left_asof: &ChunkedArray, right_asof: &ChunkedArray, filter: F, ) -> IdxArr where + B: PolarsDataType, + for<'b> ::ValueT<'b>: AsRef<[u8]>, T: PolarsDataType, A: for<'a> AsofJoinState>, F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, @@ -254,14 +172,8 @@ where let left_val_arr = left_asof.downcast_iter().next().unwrap(); let right_val_arr = right_asof.downcast_iter().next().unwrap(); - let n_threads = POOL.current_num_threads(); - let split_by_left = split_and_flatten(by_left, n_threads); - let split_by_right = split_and_flatten(by_right, n_threads); - let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len())); - - let hb = PlRandomState::default(); - let prep_by_left = prepare_bytes(&split_by_left, &hb); - let prep_by_right = prepare_bytes(&split_by_right, &hb); + let (prep_by_left, prep_by_right, _, _) = prepare_binary::(by_left, by_right, false); + let offsets = compute_len_offsets(prep_by_left.iter().map(|s| s.len())); let hash_tbls = build_tables(prep_by_right, false); let n_tables = hash_tbls.len(); @@ -303,87 +215,6 @@ where flatten_nullable(&bufs) } -fn asof_join_by_multiple( - by_left: &mut DataFrame, - by_right: &mut DataFrame, - left_asof: &ChunkedArray, - right_asof: &ChunkedArray, - filter: F, -) -> IdxArr -where - T: PolarsDataType, - A: for<'a> AsofJoinState>, - F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, -{ - let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk()); - let left_val_arr = left_asof.downcast_iter().next().unwrap(); - let right_val_arr = right_asof.downcast_iter().next().unwrap(); - - let n_threads = POOL.current_num_threads(); - let split_by_left = split_and_flatten(by_left, n_threads); - let split_by_right = split_and_flatten(by_right, n_threads); - - let (build_hashes, random_state) = - _df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap(); - let (probe_hashes, _) = - _df_rows_to_hashes_threaded_vertical(&split_by_left, Some(random_state)).unwrap(); - - let hash_tbls = create_probe_table(&build_hashes, by_right); - drop(build_hashes); // Early drop to reduce memory pressure. - let offsets = get_offsets(&probe_hashes); - let n_tables = hash_tbls.len(); - - // Now we probe the right hand side for each left hand side. - let iter = probe_hashes - .into_par_iter() - .zip(offsets) - .map(|(hash_by_left, offset)| { - let mut results = Vec::with_capacity(hash_by_left.len()); - let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); - - let mut ctr = 0; - for by_left_view in hash_by_left.data_views() { - for h_left in by_left_view.iter().copied() { - let idx_left = offset + ctr; - ctr += 1; - let opt_left_val = left_val_arr.get(idx_left); - - let Some(left_val) = opt_left_val else { - results.push(NullableIdxSize::null()); - continue; - }; - - let group_probe_table = - unsafe { hash_tbls.get_unchecked(hash_to_partition(h_left, n_tables)) }; - - let entry = group_probe_table.raw_entry().from_hash(h_left, |idx_hash| { - let idx_right = idx_hash.idx; - // SAFETY: indices in a join operation are always in bounds. - unsafe { - compare_df_rows2(by_left, by_right, idx_left, idx_right as usize, false) - } - }); - let Some((_, right_grp_idxs)) = entry else { - results.push(NullableIdxSize::null()); - continue; - }; - let id = asof_in_group::( - left_val, - right_val_arr, - &right_grp_idxs[..], - &mut group_states, - &filter, - ); - - results.push(materialize_nullable(id)); - } - } - results - }); - let bufs = POOL.install(|| iter.collect::>()); - flatten_nullable(&bufs) -} - #[allow(clippy::too_many_arguments)] fn dispatch_join_by_type( left_asof: &ChunkedArray, @@ -409,12 +240,16 @@ where DataType::String => { let left_by = &left_by_s.str().unwrap().as_binary(); let right_by = right_by_s.str().unwrap().as_binary(); - asof_join_by_binary::(left_by, &right_by, left_asof, right_asof, filter) + asof_join_by_binary::( + left_by, &right_by, left_asof, right_asof, filter, + ) }, DataType::Binary => { let left_by = &left_by_s.binary().unwrap(); let right_by = right_by_s.binary().unwrap(); - asof_join_by_binary::(left_by, right_by, left_asof, right_asof, filter) + asof_join_by_binary::( + left_by, right_by, left_asof, right_asof, filter, + ) }, x if x.is_float() => { with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { @@ -458,7 +293,15 @@ where #[cfg(feature = "dtype-categorical")] _check_categorical_src(lhs.dtype(), rhs.dtype())?; } - asof_join_by_multiple::(left_by, right_by, left_asof, right_asof, filter) + + // TODO: @scalar-opt. + let left_by_series: Vec<_> = left_by.materialized_column_iter().cloned().collect(); + let right_by_series: Vec<_> = right_by.materialized_column_iter().cloned().collect(); + let lhs_keys = prepare_keys_multiple(&left_by_series, false)?; + let rhs_keys = prepare_keys_multiple(&right_by_series, false)?; + asof_join_by_binary::( + &lhs_keys, &rhs_keys, left_asof, right_asof, filter, + ) }; Ok(out) } @@ -673,7 +516,7 @@ pub trait AsofJoinBy: IntoDf { .filter(|s| !drop_these.contains(&s.name())) .cloned() .collect(); - let proj_other_df = unsafe { DataFrame::new_no_checks(cols) }; + let proj_other_df = unsafe { DataFrame::new_no_checks(other_df.height(), cols) }; let left = self_df.clone(); diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs index 71e813cdac39..cb122a649c4f 100644 --- a/crates/polars-ops/src/frame/join/asof/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "dtype-categorical")] use super::_check_categorical_src; -use super::{_finish_join, build_tables, prepare_bytes}; +use super::{_finish_join, build_tables}; use crate::frame::IntoDf; use crate::series::SeriesMethods; diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs index f5c91de88a74..b3193ce76628 100644 --- a/crates/polars-ops/src/frame/join/dispatch_left_right.rs +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -90,14 +90,14 @@ fn materialize_left_join( if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - left._create_left_df_from_slice(left_idx, true, true) + left._create_left_df_from_slice(left_idx, true, args.slice.is_some(), true) }, ChunkJoinIds::Right(left_idx) => unsafe { let mut left_idx = &*left_idx; if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - left.create_left_df_chunked(left_idx, true) + left.create_left_df_chunked(left_idx, true, args.slice.is_some()) }, }; @@ -133,7 +133,8 @@ fn materialize_left_join( if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - let materialize_left = || unsafe { left._create_left_df_from_slice(&left_idx, true, true) }; + let materialize_left = + || unsafe { left._create_left_df_from_slice(&left_idx, true, args.slice.is_some(), true) }; let mut right_idx = &*right_idx; if let Some((offset, len)) = args.slice { diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index 1420d7b66062..0bf0a86cd972 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -56,7 +56,7 @@ pub fn _coalesce_full_join( df_left: &DataFrame, ) -> DataFrame { // No need to allocate the schema because we already - // know for certain that the column name for left left is `name` + // know for certain that the column name for left is `name` // and for right is `name + suffix` let schema_left = if keys_left == keys_right { Schema::default() diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 35e4ea9403af..8d5533c41e5b 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -12,7 +12,7 @@ use polars_core::POOL; use polars_utils::index::ChunkId; pub(super) use single_keys::*; #[cfg(feature = "asof_join")] -pub(super) use single_keys_dispatch::prepare_bytes; +pub(super) use single_keys_dispatch::prepare_binary; pub use single_keys_dispatch::SeriesJoin; use single_keys_inner::*; use single_keys_left::*; @@ -55,9 +55,18 @@ pub trait JoinDispatch: IntoDf { /// # Safety /// Join tuples must be in bounds #[cfg(feature = "chunked_ids")] - unsafe fn create_left_df_chunked(&self, chunk_ids: &[ChunkId], left_join: bool) -> DataFrame { + unsafe fn create_left_df_chunked( + &self, + chunk_ids: &[ChunkId], + left_join: bool, + was_sliced: bool, + ) -> DataFrame { let df_self = self.to_df(); - if left_join && chunk_ids.len() == df_self.height() { + + let left_join_no_duplicate_matches = + left_join && !was_sliced && chunk_ids.len() == df_self.height(); + + if left_join_no_duplicate_matches { df_self.clone() } else { // left join keys are in ascending order @@ -76,10 +85,15 @@ pub trait JoinDispatch: IntoDf { &self, join_tuples: &[IdxSize], left_join: bool, + was_sliced: bool, sorted_tuple_idx: bool, ) -> DataFrame { let df_self = self.to_df(); - if left_join && join_tuples.len() == df_self.height() { + + let left_join_no_duplicate_matches = + left_join && !was_sliced && join_tuples.len() == df_self.height(); + + if left_join_no_duplicate_matches { df_self.clone() } else { // left join tuples are always in ascending order diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs index a8093873ea51..f79e8759d9e8 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -507,27 +507,7 @@ where } } -#[cfg(feature = "asof_join")] -pub fn prepare_bytes<'a>( - been_split: &'a [BinaryChunked], - hb: &PlRandomState, -) -> Vec>> { - POOL.install(|| { - been_split - .par_iter() - .map(|ca| { - ca.iter() - .map(|opt_b| { - let hash = hb.hash_one(opt_b); - BytesHash::new(opt_b, hash) - }) - .collect::>() - }) - .collect() - }) -} - -fn prepare_binary<'a, T>( +pub(crate) fn prepare_binary<'a, T>( ca: &'a ChunkedArray, other: &'a ChunkedArray, // In inner join and outer join, the shortest relation will be used to create a hash table. diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index a180b293ca0f..d018c0e7f756 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -47,7 +47,7 @@ pub fn _merge_sorted_dfs( }) .collect::>()?; - Ok(unsafe { DataFrame::new_no_checks(new_columns) }) + Ok(unsafe { DataFrame::new_no_checks(left.height() + right.height(), new_columns) }) } fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsResult { @@ -85,7 +85,7 @@ fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsR .zip(rhs.fields_as_series()) .map(|(lhs, rhs)| merge_series(lhs, &rhs, merge_indicator)) .collect::>>()?; - StructChunked::from_series(PlSmallStr::EMPTY, new_fields.iter()) + StructChunked::from_series(PlSmallStr::EMPTY, new_fields[0].len(), new_fields.iter()) .unwrap() .into_series() }, diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 81f4fe54e7e4..fd59ef7f4a1c 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -34,11 +34,11 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; pub use iejoin::{IEJoinOptions, InequalityOperator}; #[cfg(feature = "merge_sorted")] pub use merge_sorted::_merge_sorted_dfs; -use polars_core::hashing::_HASHMAP_INIT_SIZE; #[allow(unused_imports)] -use polars_core::prelude::sort::arg_sort_multiple::{ +use polars_core::chunked_array::ops::row_encode::{ encode_rows_vertical_par_unordered, encode_rows_vertical_par_unordered_broadcast_nulls, }; +use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; pub(super) use polars_core::series::IsSorted; use polars_core::utils::slice_offsets; @@ -506,7 +506,14 @@ trait DataFrameJoinOpsPrivate: IntoDf { let (df_left, df_right) = POOL.join( // SAFETY: join indices are known to be in bounds - || unsafe { left_df._create_left_df_from_slice(join_tuples_left, false, sorted) }, + || unsafe { + left_df._create_left_df_from_slice( + join_tuples_left, + false, + args.slice.is_some(), + sorted, + ) + }, || unsafe { if let Some(drop_names) = drop_names { other.drop_many(drop_names) diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 4604920351eb..d72ff4488251 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -11,9 +11,6 @@ use polars_core::utils::accumulate_dataframes_horizontal; #[cfg(feature = "to_dummies")] use polars_core::POOL; -#[allow(unused_imports)] -use crate::prelude::*; - pub trait IntoDf { fn to_df(&self) -> &DataFrame; } @@ -94,6 +91,8 @@ pub trait DataFrameOps: IntoDf { separator: Option<&str>, drop_first: bool, ) -> PolarsResult { + use crate::series::ToDummies; + let df = self.to_df(); let set: PlHashSet<&str> = if let Some(columns) = columns { diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 5e9ece7a3878..0c5478e7b8b6 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -233,7 +233,7 @@ fn pivot_impl( already exists in the DataFrame. Please rename it prior to calling `pivot`.") } // @scalar-opt - let columns_struct = StructChunked::from_columns(column.clone(), fields) + let columns_struct = StructChunked::from_columns(column.clone(), fields[0].len(), fields) .unwrap() .into_column(); let mut binding = pivot_df.clone(); diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index 0e0de1083c5b..b7058de0a05e 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -240,13 +240,13 @@ pub(super) fn compute_col_idx( let col_locations = match column_agg_physical.dtype() { T::Int32 | T::UInt32 => { let Some(BitRepr::Small(ca)) = column_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); }; compute_col_idx_numeric(&ca) }, T::Int64 | T::UInt64 => { let Some(BitRepr::Large(ca)) = column_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); }; compute_col_idx_numeric(&ca) }, @@ -413,13 +413,13 @@ pub(super) fn compute_row_idx( match index_agg_physical.dtype() { T::Int32 | T::UInt32 => { let Some(BitRepr::Small(ca)) = index_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); }; compute_row_index(index, &ca, count, index_s.dtype()) }, T::Int64 | T::UInt64 => { let Some(BitRepr::Large(ca)) = index_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); }; compute_row_index(index, &ca, count, index_s.dtype()) }, @@ -485,9 +485,12 @@ pub(super) fn compute_row_idx( } else { let binding = pivot_df.select(index.iter().cloned())?; let fields = binding.get_columns(); - let index_struct_series = - StructChunked::from_columns(PlSmallStr::from_static("placeholder"), fields)? - .into_series(); + let index_struct_series = StructChunked::from_columns( + PlSmallStr::from_static("placeholder"), + fields[0].len(), + fields, + )? + .into_series(); let index_agg = unsafe { index_struct_series.agg_first(groups) }; let index_agg_physical = index_agg.to_physical_repr(); let ca = index_agg_physical.struct_()?; diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs index 89c38e88c37b..49eeaeba4498 100644 --- a/crates/polars-ops/src/frame/pivot/unpivot.rs +++ b/crates/polars-ops/src/frame/pivot/unpivot.rs @@ -4,7 +4,7 @@ use polars_core::datatypes::{DataType, PlSmallStr}; use polars_core::frame::column::Column; use polars_core::frame::DataFrame; use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR}; -use polars_core::utils::try_get_supertype; +use polars_core::utils::merge_dtypes_many; use polars_error::{polars_err, PolarsResult}; use polars_utils::aliases::PlHashSet; @@ -104,18 +104,19 @@ pub trait UnpivotDF: IntoDf { let len = self_.height(); - // if value vars is empty we take all columns that are not in id_vars. + // If value vars is empty we take all columns that are not in id_vars. if on.is_empty() { - // return empty frame if there are no columns available to use as value vars + // Return empty frame if there are no columns available to use as value vars. if index.len() == self_.width() { let variable_col = Column::new_empty(variable_name, &DataType::String); let value_col = Column::new_empty(value_name, &DataType::Null); let mut out = self_.select(index).unwrap().clear().take_columns(); + out.push(variable_col); out.push(value_col); - return Ok(unsafe { DataFrame::new_no_checks(out) }); + return Ok(unsafe { DataFrame::new_no_checks(0, out) }); } let index_set = PlHashSet::from_iter(index.iter().cloned()); @@ -132,15 +133,14 @@ pub trait UnpivotDF: IntoDf { .collect(); } - // values will all be placed in single column, so we must find their supertype + // Values will all be placed in single column, so we must find their supertype let schema = self_.schema(); - let mut iter = on + let dtypes = on .iter() - .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))); - let mut st = iter.next().unwrap()?.clone(); - for dt in iter { - st = try_get_supertype(&st, dt?)?; - } + .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))) + .collect::>>()?; + + let st = merge_dtypes_many(dtypes.iter())?; // The column name of the variable that is unpivoted let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); @@ -166,7 +166,7 @@ pub trait UnpivotDF: IntoDf { let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?; let col = &columns[pos]; let value_col = col.cast(&st).map_err( - |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()), + |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}\n\nConsider casting to String.", col.dtype()), )?; values.extend_from_slice(value_col.as_materialized_series().chunks()) } diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index 52cc2ee5a67a..b7b8d3e9f179 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -63,7 +63,7 @@ fn map_cats( ._with_fast_unique(label_has_value.iter().all(bool::clone)) .into_series(), ]; - Ok(StructChunked::from_series(out_name, outvals.iter())?.into_series()) + Ok(StructChunked::from_series(out_name, outvals[0].len(), outvals.iter())?.into_series()) } else { Ok(bld .drain_iter_and_finish(s_iter.map(|opt| { @@ -144,11 +144,7 @@ pub fn qcut( let s2 = s.sort(SortOptions::default())?; let ca = s2.f64()?; - let f = |&p| { - ca.quantile(p, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - }; + let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap(); let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); diff --git a/crates/polars-ops/src/series/ops/floor_divide.rs b/crates/polars-ops/src/series/ops/floor_divide.rs index 4c5075ecad42..b8aa78c4ec01 100644 --- a/crates/polars-ops/src/series/ops/floor_divide.rs +++ b/crates/polars-ops/src/series/ops/floor_divide.rs @@ -1,6 +1,7 @@ use polars_compute::arithmetic::ArithmeticKernel; use polars_core::chunked_array::ops::arity::apply_binary_kernel_broadcast; use polars_core::prelude::*; +use polars_core::series::arithmetic::NumericListOp; #[cfg(feature = "dtype-struct")] use polars_core::series::arithmetic::_struct_arithmetic; use polars_core::with_match_physical_numeric_polars_type; @@ -24,6 +25,9 @@ pub fn floor_div_series(a: &Series, b: &Series) -> PolarsResult { (DataType::Struct(_), DataType::Struct(_)) => { return _struct_arithmetic(a, b, floor_div_series); }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + return NumericListOp::FloorDiv.execute(a, b); + }, _ => {}, } diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 6e07e212bedd..663ac3664c8e 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -2,28 +2,32 @@ use polars_core::frame::NullStrategy; use polars_core::prelude::*; pub fn max_horizontal(s: &[Column]) -> PolarsResult> { - let df = unsafe { DataFrame::_new_no_checks_impl(Vec::from(s)) }; + let df = + unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; df.max_horizontal() .map(|s| s.map(Column::from)) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn min_horizontal(s: &[Column]) -> PolarsResult> { - let df = unsafe { DataFrame::_new_no_checks_impl(Vec::from(s)) }; + let df = + unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; df.min_horizontal() .map(|s| s.map(Column::from)) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn sum_horizontal(s: &[Column]) -> PolarsResult> { - let df = unsafe { DataFrame::_new_no_checks_impl(Vec::from(s)) }; + let df = + unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; df.sum_horizontal(NullStrategy::Ignore) .map(|s| s.map(Column::from)) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn mean_horizontal(s: &[Column]) -> PolarsResult> { - let df = unsafe { DataFrame::_new_no_checks_impl(Vec::from(s)) }; + let df = + unsafe { DataFrame::_new_no_checks_impl(s.first().map_or(0, Column::len), Vec::from(s)) }; df.mean_horizontal(NullStrategy::Ignore) .map(|s| s.map(Column::from)) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) diff --git a/crates/polars-ops/src/series/ops/is_between.rs b/crates/polars-ops/src/series/ops/is_between.rs index 053493d552f6..96b1074b4d82 100644 --- a/crates/polars-ops/src/series/ops/is_between.rs +++ b/crates/polars-ops/src/series/ops/is_between.rs @@ -3,9 +3,11 @@ use std::ops::BitAnd; use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedInterval { #[default] Both, diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index f2d8f8128777..7c5697429372 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -237,6 +237,7 @@ fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsRe old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD")); new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW")); + let len = old.len(); let cols = if add_mask { // @scalar-opt let mask = Column::new(PlSmallStr::from_static("__POLARS_REPLACE_MASK"), &[true]) @@ -245,7 +246,7 @@ fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsRe } else { vec![old.into(), new.into()] }; - let out = unsafe { DataFrame::new_no_checks(cols) }; + let out = unsafe { DataFrame::new_no_checks(len, cols) }; Ok(out) } diff --git a/crates/polars-ops/src/series/ops/rle.rs b/crates/polars-ops/src/series/ops/rle.rs index 9277913558a2..ed88648fee72 100644 --- a/crates/polars-ops/src/series/ops/rle.rs +++ b/crates/polars-ops/src/series/ops/rle.rs @@ -31,7 +31,7 @@ pub fn rle(s: &Column) -> PolarsResult { Series::from_vec(PlSmallStr::from_static("len"), lengths).into(), vals.to_owned(), ]; - Ok(StructChunked::from_columns(s.name().clone(), &outvals)?.into_column()) + Ok(StructChunked::from_columns(s.name().clone(), vals.len(), &outvals)?.into_column()) } /// Similar to `rle`, but maps values to run IDs. diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index 437f49dad480..eb2cf3a228c1 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -44,9 +44,10 @@ impl ToDummies for Series { }; ca.into_column() }) - .collect(); + .collect::>(); - Ok(unsafe { DataFrame::new_no_checks(sort_columns(columns)) }) + // SAFETY: `dummies_helper` functions preserve `self.len()` length + unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) } } } diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index c29fcc431c98..47d467f7f7ba 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -1,7 +1,7 @@ use num_traits::Bounded; -use polars_core::prelude::arity::unary_elementwise_values; #[cfg(feature = "dtype-struct")] -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_ca; +use polars_core::prelude::arity::unary_elementwise_values; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; @@ -39,8 +39,9 @@ pub trait SeriesMethods: SeriesSealed { counts.into_column() }; + let height = counts.len(); let cols = vec![values, counts]; - let df = unsafe { DataFrame::new_no_checks(cols) }; + let df = unsafe { DataFrame::new_no_checks(height, cols) }; if sort { df.sort( [name], diff --git a/crates/polars-parquet/Cargo.toml b/crates/polars-parquet/Cargo.toml index 26a57b22e713..8ae9108a1fa7 100644 --- a/crates/polars-parquet/Cargo.toml +++ b/crates/polars-parquet/Cargo.toml @@ -22,12 +22,12 @@ fallible-streaming-iterator = { workspace = true, optional = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } num-traits = { workspace = true } -polars-compute = { workspace = true } +polars-compute = { workspace = true, features = ["approx_unique"] } polars-error = { workspace = true } +polars-parquet-format = "0.1" polars-utils = { workspace = true, features = ["mmap"] } simdutf8 = { workspace = true } -parquet-format-safe = "0.2" streaming-decompression = "0.1" async-stream = { version = "0.3.3", optional = true } @@ -61,6 +61,6 @@ gzip_zlib_ng = ["flate2/zlib-ng"] lz4 = ["dep:lz4"] lz4_flex = ["dep:lz4_flex"] -async = ["async-stream", "futures", "parquet-format-safe/async"] +async = ["async-stream", "futures", "polars-parquet-format/async"] bloom_filter = ["xxhash-rust"] serde_types = ["serde"] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs index 6777f7e639c9..86e46756788a 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs @@ -362,7 +362,7 @@ impl DeltaGatherer for StatGatherer { } } -impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut DeltaCollector<'a, 'b> { +impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut DeltaCollector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -394,7 +394,7 @@ impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut Delta } } -impl<'a, 'b> DeltaCollector<'a, 'b> { +impl DeltaCollector<'_, '_> { pub fn flush(&mut self, target: &mut MutableBinaryViewArray<[u8]>) { if !self.pushed_lengths.is_empty() { let start_bytes_len = target.total_bytes_len(); @@ -428,7 +428,7 @@ impl<'a, 'b> DeltaCollector<'a, 'b> { } } -impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for DeltaBytesCollector<'a, 'b> { +impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for DeltaBytesCollector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -621,7 +621,7 @@ impl utils::Decoder for BinViewDecoder { max_length: &'b mut usize, } - impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -709,7 +709,7 @@ impl utils::Decoder for BinViewDecoder { ) -> ParquetResult<()> { struct DictionaryTranslator<'a>(&'a [View]); - impl<'a> HybridRleGatherer for DictionaryTranslator<'a> { + impl HybridRleGatherer for DictionaryTranslator<'_> { type Target = MutableBinaryViewArray<[u8]>; fn target_reserve(&self, target: &mut Self::Target, n: usize) { @@ -803,7 +803,7 @@ impl utils::Decoder for BinViewDecoder { translator: DictionaryTranslator<'b>, } - impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs index af2e504d2646..51026f483bd7 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -161,7 +161,7 @@ impl HybridRleGatherer for BitmapGatherer { // @TODO: The slice impl here can speed some stuff up } struct BitmapCollector<'a, 'b>(&'b mut HybridRleDecoder<'a>); -impl<'a, 'b> BatchableCollector for BitmapCollector<'a, 'b> { +impl BatchableCollector for BitmapCollector<'_, '_> { fn reserve(target: &mut MutableBitmap, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs index de2bfe2e47f3..478c7cca0f2e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs @@ -184,7 +184,7 @@ pub(crate) struct DictArrayTranslator { dict_size: usize, } -impl<'a, 'b, K: DictionaryKey> BatchableCollector<(), Vec> for DictArrayCollector<'a, 'b> { +impl BatchableCollector<(), Vec> for DictArrayCollector<'_, '_> { fn reserve(target: &mut Vec, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 3825d528c8f5..5657a20dd151 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -154,7 +154,7 @@ impl Decoder for BinaryDecoder { size: usize, } - impl<'a, 'b> BatchableCollector<(), Vec> for FixedSizeBinaryCollector<'a, 'b> { + impl BatchableCollector<(), Vec> for FixedSizeBinaryCollector<'_, '_> { fn reserve(target: &mut Vec, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index b5b083f8b882..415517207a87 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -403,7 +403,7 @@ pub fn columns_to_iter_recursive( let (mut nested, last_array) = field_to_nested_array(init.clone(), &mut columns, &mut types, last_field)?; debug_assert!(matches!(nested.last().unwrap(), NestedContent::Struct)); - let (_, _, struct_validity) = nested.pop().unwrap(); + let (length, _, struct_validity) = nested.pop().unwrap(); let mut field_arrays = Vec::>::with_capacity(fields.len()); field_arrays.push(last_array); @@ -431,6 +431,7 @@ pub fn columns_to_iter_recursive( nested, Box::new(StructArray::new( ArrowDataType::Struct(fields.clone()), + length, field_arrays, struct_validity, )), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index ab769848ca92..d37a6d4bf3b1 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -156,9 +156,35 @@ impl Nested { if is_valid && self.num_invalids != 0 { debug_assert!(!is_primitive); - let validity = self.validity.as_mut().unwrap(); - validity.extend_constant(self.num_valids, true); - validity.extend_constant(self.num_invalids, false); + // @NOTE: Having invalid items might not necessarily mean that we have a validity mask. + // + // For instance, if we have a optional struct with a required list in it, that struct + // will have a validity mask and the list will not. In the arrow representation of this + // array, however, the list will still have invalid items where the struct is null. + // + // Array: + // [ + // { 'x': [1] }, + // None, + // { 'x': [1, 2] }, + // ] + // + // Arrow: + // struct = [ list[0] None list[2] ] + // list = { + // values = [ 1, 1, 2 ], + // offsets = [ 0, 1, 1, 3 ], + // } + // + // Parquet: + // [ 1, 1, 2 ] + definition + repetition levels + // + // As you can see we need to insert an invalid item into the list even though it does + // not have a validity mask. + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(self.num_valids, true); + validity.extend_constant(self.num_invalids, false); + } self.num_valids = 0; self.num_invalids = 0; @@ -174,8 +200,6 @@ impl Nested { } fn push_default(&mut self, length: i64) { - debug_assert!(self.validity.is_some()); - let is_primitive = matches!(self.content, NestedContent::Primitive); self.num_invalids += usize::from(!is_primitive); @@ -191,8 +215,8 @@ pub struct BatchedNestedDecoder<'a, 'b, 'c, D: utils::NestedDecoder> { decoder: &'c mut D, } -impl<'a, 'b, 'c, D: utils::NestedDecoder> BatchableCollector<(), D::DecodedState> - for BatchedNestedDecoder<'a, 'b, 'c, D> +impl BatchableCollector<(), D::DecodedState> + for BatchedNestedDecoder<'_, '_, '_, D> { fn reserve(_target: &mut D::DecodedState, _n: usize) { unreachable!() diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs index 0a43141abd06..eb4815b6fbfe 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -55,7 +55,7 @@ where let values = split_buffer(page)?.values; Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( values, - std::mem::size_of::

(), + size_of::

(), )?)) }, _ => Err(utils::not_implemented(page)), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs index ed8e0a541a68..ff9c0b014b08 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -58,7 +58,7 @@ where let values = split_buffer(page)?.values; Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( values, - std::mem::size_of::

(), + size_of::

(), )?)) }, (Encoding::DeltaBinaryPacked, _) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs index 1a9d50a66d31..5539595fda48 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs @@ -125,8 +125,8 @@ where pub(crate) _pd: std::marker::PhantomData, } -impl<'a, 'b, P, T, D: DecoderFunction> BatchableCollector<(), Vec> - for PlainDecoderFnCollector<'a, 'b, P, T, D> +impl> BatchableCollector<(), Vec> + for PlainDecoderFnCollector<'_, '_, P, T, D> where T: NativeType, P: ParquetNativeType, @@ -167,7 +167,7 @@ where D: DecoderFunction, { values - .chunks_exact(std::mem::size_of::

()) + .chunks_exact(size_of::

()) .map(decode) .map(|v| decoder.decode(v)) .collect::>() @@ -239,7 +239,7 @@ where } } -impl<'a, 'b, P, T, D> BatchableCollector<(), Vec> for DeltaCollector<'a, 'b, P, T, D> +impl BatchableCollector<(), Vec> for DeltaCollector<'_, '_, P, T, D> where T: NativeType, P: ParquetNativeType, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs index 56912934e100..f86523b74d9b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -1,7 +1,7 @@ use arrow::array::{Array, DictionaryArray, DictionaryKey, FixedSizeBinaryArray, PrimitiveArray}; use arrow::datatypes::{ArrowDataType, IntervalUnit, TimeUnit}; use arrow::match_integer_type; -use arrow::types::{days_ms, i256}; +use arrow::types::{days_ms, i256, NativeType}; use ethnum::I256; use polars_error::{polars_bail, PolarsResult}; @@ -275,6 +275,34 @@ pub fn page_iter_to_array( primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), + + // Float16 + (PhysicalType::FixedLenByteArray(2), Float32) => { + // @NOTE: To reduce code bloat, we just use the FixedSizeBinary decoder. + + let mut fsb_array = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(2), + fixed_size_binary::BinaryDecoder { size: 2 }, + )?.collect_n(filter)?; + + + let validity = fsb_array.take_validity(); + let values = fsb_array.values().as_slice(); + assert_eq!(values.len() % 2, 0); + let values = values.chunks_exact(2); + let values = values.map(|v| { + // SAFETY: We know that `v` is always of size two. + let le_bytes: [u8; 2] = unsafe { v.try_into().unwrap_unchecked() }; + let v = arrow::types::f16::from_le_bytes(le_bytes); + v.to_f32() + }).collect(); + + let array = PrimitiveArray::::new(dtype, values, validity); + + Box::new(array) + }, + (PhysicalType::Float, Float32) => Box::new(PageDecoder::new( pages, dtype, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs index 330ad77a7c44..f66544ee3183 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs @@ -16,7 +16,7 @@ impl<'a, P: ParquetNativeType> ArrayChunks<'a, P> { /// /// This returns null if the `bytes` slice's length is not a multiple of the size of `P::Bytes`. pub(crate) fn new(bytes: &'a [u8]) -> Option { - if bytes.len() % std::mem::size_of::() != 0 { + if bytes.len() % size_of::() != 0 { return None; } @@ -47,4 +47,4 @@ impl<'a, P: ParquetNativeType> Iterator for ArrayChunks<'a, P> { } } -impl<'a, P: ParquetNativeType> ExactSizeIterator for ArrayChunks<'a, P> {} +impl ExactSizeIterator for ArrayChunks<'_, P> {} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs index dba00fc97930..7c6cf840bdce 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs @@ -433,7 +433,7 @@ where } } -impl<'a, 'b, 'c, O, T> BatchableCollector> for TranslatedHybridRle<'a, 'b, 'c, O, T> +impl BatchableCollector> for TranslatedHybridRle<'_, '_, '_, O, T> where O: Clone + Default, T: Translator, @@ -487,7 +487,7 @@ where } } -impl<'a, 'b, 'c, O, G> BatchableCollector> for GatheredHybridRle<'a, 'b, 'c, O, G> +impl BatchableCollector> for GatheredHybridRle<'_, '_, '_, O, G> where O: Clone, G: HybridRleGatherer>, @@ -516,8 +516,8 @@ where } } -impl<'a, 'b, 'c, T> BatchableCollector> - for TranslatedHybridRle<'a, 'b, 'c, View, T> +impl BatchableCollector> + for TranslatedHybridRle<'_, '_, '_, View, T> where T: Translator, { diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs index 915936c81296..f73c401bd72b 100644 --- a/crates/polars-parquet/src/arrow/read/schema/metadata.rs +++ b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -35,6 +35,7 @@ fn convert_dtype(mut dtype: ArrowDataType) -> ArrowDataType { convert_field(field); } }, + Float16 => dtype = Float32, Binary | LargeBinary => dtype = BinaryView, Utf8 | LargeUtf8 => dtype = Utf8View, Dictionary(_, ref mut dtype, _) | Extension(_, ref mut dtype, _) => { diff --git a/crates/polars-parquet/src/arrow/read/schema/mod.rs b/crates/polars-parquet/src/arrow/read/schema/mod.rs index 347cd49faefd..ea27aa03d46d 100644 --- a/crates/polars-parquet/src/arrow/read/schema/mod.rs +++ b/crates/polars-parquet/src/arrow/read/schema/mod.rs @@ -40,7 +40,7 @@ impl Default for SchemaInferenceOptions { /// /// # Error /// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, -/// indicating that that the file's arrow metadata was incorrectly written. +/// indicating that the file's arrow metadata was incorrectly written. pub fn infer_schema(file_metadata: &FileMetadata) -> PolarsResult { infer_schema_with_options(file_metadata, &None) } diff --git a/crates/polars-parquet/src/arrow/read/statistics/mod.rs b/crates/polars-parquet/src/arrow/read/statistics/mod.rs index cb65827de9a6..22ba71caba17 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/mod.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/mod.rs @@ -3,7 +3,7 @@ use std::collections::VecDeque; use arrow::array::*; use arrow::datatypes::{ArrowDataType, Field, IntervalUnit, PhysicalType}; -use arrow::types::i256; +use arrow::types::{f16, i256, NativeType}; use arrow::with_match_primitive_type_full; use ethnum::I256; use polars_error::{polars_bail, PolarsResult}; @@ -28,6 +28,7 @@ mod struct_; mod utf8; use self::list::DynMutableListArray; +use super::PrimitiveLogicalType; /// Arrow-deserialized parquet Statistics of a file #[derive(Debug, PartialEq)] @@ -319,7 +320,11 @@ fn push( null_count, ); }, - Struct(_) => { + Struct(fields) => { + if fields.is_empty() { + return Ok(()); + } + let min = min .as_mut_any() .downcast_mut::() @@ -338,11 +343,11 @@ fn push( .unwrap(); return min - .inner + .inner_mut() .iter_mut() - .zip(max.inner.iter_mut()) - .zip(distinct_count.inner.iter_mut()) - .zip(null_count.inner.iter_mut()) + .zip(max.inner_mut()) + .zip(distinct_count.inner_mut()) + .zip(null_count.inner_mut()) .try_for_each(|(((min, max), distinct_count), null_count)| { push( stats, @@ -545,6 +550,37 @@ fn push( } } +pub fn cast_statistics( + statistics: ParquetStatistics, + primitive_type: &ParquetPrimitiveType, + output_type: &ArrowDataType, +) -> ParquetStatistics { + use {ArrowDataType as DT, PrimitiveLogicalType as PT}; + + match (primitive_type.logical_type, output_type) { + (Some(PT::Float16), DT::Float32) => { + let statistics = statistics.expect_fixedlen(); + + let primitive_type = primitive_type.clone(); + + ParquetStatistics::Float(PrimitiveStatistics:: { + primitive_type, + null_count: statistics.null_count, + distinct_count: statistics.distinct_count, + min_value: statistics + .min_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + max_value: statistics + .max_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + }) + }, + _ => statistics, + } +} + /// Deserializes the statistics in the column chunks from a single `row_group` /// into [`Statistics`] associated from `field`'s name. /// @@ -558,9 +594,13 @@ pub fn deserialize<'a>( let mut stats = field_md .map(|column| { + let primitive_type = &column.descriptor().descriptor.primitive_type; Ok(( - column.statistics().transpose()?, - column.descriptor().descriptor.primitive_type.clone(), + column + .statistics() + .transpose()? + .map(|stats| cast_statistics(stats, primitive_type, &field.dtype)), + primitive_type.clone(), )) }) .collect::, ParquetPrimitiveType)>>>()?; diff --git a/crates/polars-parquet/src/arrow/read/statistics/struct_.rs b/crates/polars-parquet/src/arrow/read/statistics/struct_.rs index bb6889d569f8..8adfaed734c4 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/struct_.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/struct_.rs @@ -7,7 +7,7 @@ use super::make_mutable; #[derive(Debug)] pub struct DynMutableStructArray { dtype: ArrowDataType, - pub inner: Vec>, + inner: Vec>, } impl DynMutableStructArray { @@ -16,6 +16,9 @@ impl DynMutableStructArray { ArrowDataType::Struct(inner) => inner, _ => unreachable!(), }; + + assert!(!inners.is_empty()); + let inner = inners .iter() .map(|f| make_mutable(f.dtype(), capacity)) @@ -23,7 +26,12 @@ impl DynMutableStructArray { Ok(Self { dtype, inner }) } + + pub fn inner_mut(&mut self) -> &mut [Box] { + &mut self.inner + } } + impl MutableArray for DynMutableStructArray { fn dtype(&self) -> &ArrowDataType { &self.dtype @@ -38,9 +46,9 @@ impl MutableArray for DynMutableStructArray { } fn as_box(&mut self) -> Box { + let len = self.len(); let inner = self.inner.iter_mut().map(|x| x.as_box()).collect(); - - Box::new(StructArray::new(self.dtype.clone(), inner, None)) + Box::new(StructArray::new(self.dtype.clone(), len, inner, None)) } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-parquet/src/arrow/write/binary/basic.rs b/crates/polars-parquet/src/arrow/write/binary/basic.rs index 0032a5d1e29d..7af6ab5f353d 100644 --- a/crates/polars-parquet/src/arrow/write/binary/basic.rs +++ b/crates/polars-parquet/src/arrow/write/binary/basic.rs @@ -30,15 +30,15 @@ pub(crate) fn encode_plain( ) { if options.is_optional() && array.validity().is_some() { let len_before = buffer.len(); - let capacity = array.get_values_size() - + (array.len() - array.null_count()) * std::mem::size_of::(); + let capacity = + array.get_values_size() + (array.len() - array.null_count()) * size_of::(); buffer.reserve(capacity); encode_non_null_values(array.non_null_values_iter(), buffer); // Ensure we allocated properly. debug_assert_eq!(buffer.len() - len_before, capacity); } else { let len_before = buffer.len(); - let capacity = array.get_values_size() + array.len() * std::mem::size_of::(); + let capacity = array.get_values_size() + array.len() * size_of::(); buffer.reserve(capacity); encode_non_null_values(array.values_iter(), buffer); // Ensure we allocated properly. diff --git a/crates/polars-parquet/src/arrow/write/binview/basic.rs b/crates/polars-parquet/src/arrow/write/binview/basic.rs index 8251818e722f..018e1627d25f 100644 --- a/crates/polars-parquet/src/arrow/write/binview/basic.rs +++ b/crates/polars-parquet/src/arrow/write/binview/basic.rs @@ -16,8 +16,8 @@ pub(crate) fn encode_plain( buffer: &mut Vec, ) { if options.is_optional() && array.validity().is_some() { - let capacity = array.total_bytes_len() - + (array.len() - array.null_count()) * std::mem::size_of::(); + let capacity = + array.total_bytes_len() + (array.len() - array.null_count()) * size_of::(); let len_before = buffer.len(); buffer.reserve(capacity); @@ -26,7 +26,7 @@ pub(crate) fn encode_plain( // Append the non-null values. debug_assert_eq!(buffer.len() - len_before, capacity); } else { - let capacity = array.total_bytes_len() + array.len() * std::mem::size_of::(); + let capacity = array.total_bytes_len() + array.len() * size_of::(); let len_before = buffer.len(); buffer.reserve(capacity); diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index 4e0d57302314..fc97c268c0fd 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -3,16 +3,17 @@ use arrow::array::{ }; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; -use arrow::datatypes::{ArrowDataType, IntegerType}; +use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType}; +use arrow::legacy::utils::CustomIterTools; +use arrow::trusted_len::TrustMyLength; use arrow::types::NativeType; use polars_compute::min_max::MinMaxKernel; use polars_error::{polars_bail, PolarsResult}; -use polars_utils::unwrap::UnwrapUncheckedRelease; use super::binary::{ build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, }; -use super::fixed_len_bytes::{ +use super::fixed_size_binary::{ build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain, }; use super::pages::PrimitiveNested; @@ -31,33 +32,51 @@ use crate::parquet::CowBuffer; use crate::write::DynIter; trait MinMaxThreshold { - const DELTA_THRESHOLD: Self; + const DELTA_THRESHOLD: usize; + const BITMASK_THRESHOLD: usize; + + fn from_start_and_offset(start: Self, offset: usize) -> Self; } macro_rules! minmaxthreshold_impls { - ($($t:ty => $threshold:literal,)+) => { + ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => { $( - impl MinMaxThreshold for $t { - const DELTA_THRESHOLD: Self = $threshold; + impl MinMaxThreshold for $signed { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + ((offset as $unsigned) as $signed) + } + } + impl MinMaxThreshold for $unsigned { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + (offset as $unsigned) + } } )+ }; } minmaxthreshold_impls! { - i8 => 16, - i16 => 256, - i32 => 512, - i64 => 2048, - u8 => 16, - u16 => 256, - u32 => 512, - u64 => 2048, + i8, u8 => 16, u8::MAX as usize, + i16, u16 => 256, u16::MAX as usize, + i32, u32 => 512, u16::MAX as usize, + i64, u64 => 2048, u16::MAX as usize, +} + +enum DictionaryDecision { + NotWorth, + TryAgain, + Found(DictionaryArray), } fn min_max_integer_encode_as_dictionary_optional<'a, E, T>( array: &'a dyn Array, -) -> Option> +) -> DictionaryDecision where E: std::fmt::Debug, T: NativeType @@ -65,26 +84,82 @@ where + std::cmp::Ord + TryInto + std::ops::Sub - + num_traits::CheckedSub, + + num_traits::CheckedSub + + num_traits::cast::AsPrimitive, std::ops::RangeInclusive: Iterator, PrimitiveArray: MinMaxKernel = T>, { - use ArrowDataType as DT; - let (min, max): (T, T) = as MinMaxKernel>::min_max_ignore_nan_kernel( + let min_max = as MinMaxKernel>::min_max_ignore_nan_kernel( array.as_any().downcast_ref().unwrap(), - )?; + ); + + let Some((min, max)) = min_max else { + return DictionaryDecision::TryAgain; + }; debug_assert!(max >= min, "{max} >= {min}"); - if !max - .checked_sub(&min) - .is_some_and(|v| v <= T::DELTA_THRESHOLD) - { - return None; + let Some(diff) = max.checked_sub(&min) else { + return DictionaryDecision::TryAgain; + }; + + let diff = diff.as_(); + + if diff > T::BITMASK_THRESHOLD { + return DictionaryDecision::TryAgain; + } + + let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1); + + let array = array.as_any().downcast_ref::>().unwrap(); + + if array.has_nulls() { + for v in array.non_null_values_iter() { + let offset = (v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } + } else { + for v in array.values_iter() { + let offset = (*v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } } - // @TODO: This currently overestimates the values, it might be interesting to use the unique - // kernel here. - let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), (min..=max).collect(), None); + let cardinality = seen_mask.set_bits(); + + let mut is_worth_it = false; + + is_worth_it |= cardinality <= T::DELTA_THRESHOLD; + is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75; + + if !is_worth_it { + return DictionaryDecision::NotWorth; + } + + let seen_mask = seen_mask.freeze(); + + // SAFETY: We just did the calculation for this. + let indexes = seen_mask + .true_idx_iter() + .map(|idx| T::from_start_and_offset(min, idx)); + let indexes = unsafe { TrustMyLength::new(indexes, cardinality) }; + let indexes = indexes.collect_trusted::>(); + + let mut lookup = vec![0u16; diff + 1]; + + for (i, &idx) in indexes.iter().enumerate() { + lookup[(idx - min).as_()] = i as u16; + } + + use ArrowDataType as DT; + let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None); let values = Box::new(values); let keys: Buffer = array @@ -93,20 +168,19 @@ where .unwrap() .values() .iter() - .map(|v| unsafe { + .map(|v| { // @NOTE: // Since the values might contain nulls which have a undefined value. We just // clamp the values to between the min and max value. This way, they will still - // be valid dictionary keys. This is mostly to make the - // unwrap_unchecked_release not produce any unsafety. - (*v.clamp(&min, &max) - min) - .try_into() - .unwrap_unchecked_release() + // be valid dictionary keys. + let idx = *v.clamp(&min, &max) - min; + let value = unsafe { lookup.get_unchecked(idx.as_()) }; + (*value).into() }) .collect(); let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned()); - Some( + DictionaryDecision::Found( DictionaryArray::::try_new( ArrowDataType::Dictionary( IntegerType::UInt32, @@ -126,26 +200,15 @@ pub(crate) fn encode_as_dictionary_optional( type_: PrimitiveType, options: WriteOptions, ) -> Option>>> { - use ArrowDataType as DT; - let fast_dictionary = match array.dtype() { - DT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), - DT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), - DT::Int32 | DT::Date32 | DT::Time32(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i32>(array) - }, - DT::Int64 | DT::Date64 | DT::Time64(_) | DT::Timestamp(_, _) | DT::Duration(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i64>(array) - }, - DT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), - DT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), - DT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), - DT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), - _ => None, - }; + if array.is_empty() { + let array = DictionaryArray::::new_empty(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(array.dtype().clone()), + false, // @TODO: This might be able to be set to true? + )); - if let Some(fast_dictionary) = fast_dictionary { return Some(array_to_pages( - &fast_dictionary, + &array, type_, nested, options, @@ -153,9 +216,44 @@ pub(crate) fn encode_as_dictionary_optional( )); } + use arrow::types::PrimitiveType as PT; + let fast_dictionary = match array.dtype().to_physical_type() { + PhysicalType::Primitive(pt) => match pt { + PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), + PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), + PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array), + PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array), + PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), + PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), + PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), + PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), + _ => DictionaryDecision::TryAgain, + }, + _ => DictionaryDecision::TryAgain, + }; + + match fast_dictionary { + DictionaryDecision::NotWorth => return None, + DictionaryDecision::Found(dictionary_array) => { + return Some(array_to_pages( + &dictionary_array, + type_, + nested, + options, + Encoding::RleDictionary, + )) + }, + DictionaryDecision::TryAgain => {}, + } + let dtype = Box::new(array.dtype().clone()); - let len_before = array.len(); + let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array); + + if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 { + return None; + } + // This does the group by. let array = arrow::compute::cast::cast( array, @@ -169,10 +267,6 @@ pub(crate) fn encode_as_dictionary_optional( .downcast_ref::>() .unwrap(); - if (array.values().len() as f64) / (len_before as f64) > 0.75 { - return None; - } - Some(array_to_pages( array, type_, diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs new file mode 100644 index 000000000000..27151ce51f70 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs @@ -0,0 +1,47 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{utils, EncodeNullability, Encoding, WriteOptions}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs similarity index 79% rename from crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs rename to crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs index 9277b9c78a98..58f11adfa491 100644 --- a/crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs @@ -1,12 +1,13 @@ +mod basic; +mod nested; + use arrow::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; use arrow::types::i256; -use polars_error::PolarsResult; +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; use super::binary::ord_binary; -use super::{utils, EncodeNullability, StatisticsOptions, WriteOptions}; -use crate::arrow::read::schema::is_nullable; -use crate::parquet::encoding::Encoding; -use crate::parquet::page::DataPage; +use super::{EncodeNullability, StatisticsOptions}; use crate::parquet::schema::types::PrimitiveType; use crate::parquet::statistics::FixedLenStatistics; @@ -27,44 +28,6 @@ pub(crate) fn encode_plain( } } -pub fn array_to_page( - array: &FixedSizeBinaryArray, - options: WriteOptions, - type_: PrimitiveType, - statistics: Option, -) -> PolarsResult { - let is_optional = is_nullable(&type_.field_info); - let encode_options = EncodeNullability::new(is_optional); - - let validity = array.validity(); - - let mut buffer = vec![]; - utils::write_def_levels( - &mut buffer, - is_optional, - validity, - array.len(), - options.version, - )?; - - let definition_levels_byte_length = buffer.len(); - - encode_plain(array, encode_options, &mut buffer); - - utils::build_plain_page( - buffer, - array.len(), - array.len(), - array.null_count(), - 0, - definition_levels_byte_length, - statistics.map(|x| x.serialize()), - type_, - options, - Encoding::Plain, - ) -} - pub(super) fn build_statistics( array: &FixedSizeBinaryArray, primitive_type: PrimitiveType, diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs new file mode 100644 index 000000000000..81175cf5db18 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs @@ -0,0 +1,39 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{nested, utils, EncodeNullability, Encoding, Nested, WriteOptions}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs index 02f0165d04c7..17a342ac9d67 100644 --- a/crates/polars-parquet/src/arrow/write/mod.rs +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -17,7 +17,7 @@ mod binview; mod boolean; mod dictionary; mod file; -mod fixed_len_bytes; +mod fixed_size_binary; mod nested; mod pages; mod primitive; @@ -528,7 +528,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( &array, type_.clone(), &options.statistics, @@ -536,7 +536,7 @@ pub fn array_to_page_simple( } else { None }; - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) }, ArrowDataType::Interval(IntervalUnit::DayTime) => { let array = array @@ -555,7 +555,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( &array, type_.clone(), &options.statistics, @@ -563,12 +563,12 @@ pub fn array_to_page_simple( } else { None }; - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) }, ArrowDataType::FixedSizeBinary(_) => { let array = array.as_any().downcast_ref().unwrap(); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( array, type_.clone(), &options.statistics, @@ -577,7 +577,7 @@ pub fn array_to_page_simple( None }; - fixed_len_bytes::array_to_page(array, options, type_, statistics) + fixed_size_binary::array_to_page(array, options, type_, statistics) }, ArrowDataType::Decimal256(precision, _) => { let precision = *precision; @@ -620,7 +620,7 @@ pub fn array_to_page_simple( } else if precision <= 38 { let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( array, type_.clone(), size, @@ -641,7 +641,7 @@ pub fn array_to_page_simple( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } else { let size = 32; let array = array @@ -649,7 +649,7 @@ pub fn array_to_page_simple( .downcast_ref::>() .unwrap(); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256( + let stats = fixed_size_binary::build_statistics_decimal256( array, type_.clone(), size, @@ -670,7 +670,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } }, ArrowDataType::Decimal(precision, _) => { @@ -715,7 +715,7 @@ pub fn array_to_page_simple( let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal( + let stats = fixed_size_binary::build_statistics_decimal( array, type_.clone(), size, @@ -736,7 +736,7 @@ pub fn array_to_page_simple( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } }, other => polars_bail!(nyi = "Writing parquet pages for data type {other:?}"), @@ -858,7 +858,7 @@ fn array_to_page_nested( let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal( + let stats = fixed_size_binary::build_statistics_decimal( array, type_.clone(), size, @@ -879,7 +879,7 @@ fn array_to_page_nested( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } }, Decimal256(precision, _) => { @@ -919,7 +919,7 @@ fn array_to_page_nested( } else if precision <= 38 { let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( array, type_.clone(), size, @@ -940,7 +940,7 @@ fn array_to_page_nested( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } else { let size = 32; let array = array @@ -948,7 +948,7 @@ fn array_to_page_nested( .downcast_ref::>() .unwrap(); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256( + let stats = fixed_size_binary::build_statistics_decimal256( array, type_.clone(), size, @@ -969,7 +969,7 @@ fn array_to_page_nested( array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } }, other => polars_bail!(nyi = "Writing nested parquet pages for data type {other:?}"), diff --git a/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs index 546efd034a9c..961393bf4ed2 100644 --- a/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs +++ b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs @@ -79,7 +79,7 @@ pub fn num_values(nested: &[Nested]) -> usize { BufferedDremelIter::new(nested).count() } -impl<'a> Level<'a> { +impl Level<'_> { /// Fetch the number of elements given on the next level at `offset` on this level fn next_level_length(&self, offset: usize, is_valid: bool) -> usize { match self.lengths { @@ -407,7 +407,7 @@ impl<'a> BufferedDremelIter<'a> { } } -impl<'a> Iterator for BufferedDremelIter<'a> { +impl Iterator for BufferedDremelIter<'_> { type Item = DremelValue; fn next(&mut self) -> Option { diff --git a/crates/polars-parquet/src/arrow/write/pages.rs b/crates/polars-parquet/src/arrow/write/pages.rs index ce2cb0c1cc29..f943ebfbc671 100644 --- a/crates/polars-parquet/src/arrow/write/pages.rs +++ b/crates/polars-parquet/src/arrow/write/pages.rs @@ -600,6 +600,7 @@ mod tests { let array = StructArray::new( ArrowDataType::Struct(fields), + 4, vec![boolean.clone(), int.clone()], Some(Bitmap::from([true, true, false, true])), ); @@ -664,6 +665,7 @@ mod tests { let array = StructArray::new( ArrowDataType::Struct(fields), + 4, vec![boolean.clone(), int.clone()], Some(Bitmap::from([true, true, false, true])), ); @@ -675,6 +677,7 @@ mod tests { let array = StructArray::new( ArrowDataType::Struct(fields), + 4, vec![Box::new(array.clone()), Box::new(array)], None, ); @@ -767,6 +770,7 @@ mod tests { let array = StructArray::new( ArrowDataType::Struct(fields), + 4, vec![boolean.clone(), int.clone()], Some(Bitmap::from([true, true, false, true])), ); @@ -872,7 +876,7 @@ mod tests { let key_array = Utf8Array::::from_slice(["k1", "k2", "k3", "k4", "k5", "k6"]).boxed(); let val_array = Int32Array::from_slice([42, 28, 19, 31, 21, 17]).boxed(); - let kv_array = StructArray::try_new(kv_type, vec![key_array, val_array], None) + let kv_array = StructArray::try_new(kv_type, 6, vec![key_array, val_array], None) .unwrap() .boxed(); let offsets = OffsetsBuffer::try_from(vec![0, 2, 3, 4, 6]).unwrap(); diff --git a/crates/polars-parquet/src/arrow/write/primitive/basic.rs b/crates/polars-parquet/src/arrow/write/primitive/basic.rs index d970e3659dcb..88bab6ca5ea6 100644 --- a/crates/polars-parquet/src/arrow/write/primitive/basic.rs +++ b/crates/polars-parquet/src/arrow/write/primitive/basic.rs @@ -38,7 +38,7 @@ where let mut iter = validity.iter(); let values = array.values().as_slice(); - buffer.reserve(std::mem::size_of::() * (array.len() - null_count)); + buffer.reserve(size_of::() * (array.len() - null_count)); let mut offset = 0; let mut remaining_valid = array.len() - null_count; @@ -61,7 +61,7 @@ where } } - buffer.reserve(std::mem::size_of::

() * array.len()); + buffer.reserve(size_of::

() * array.len()); buffer.extend( array .values() diff --git a/crates/polars-parquet/src/arrow/write/schema.rs b/crates/polars-parquet/src/arrow/write/schema.rs index 1403a7f4eeec..61c7f9fad218 100644 --- a/crates/polars-parquet/src/arrow/write/schema.rs +++ b/crates/polars-parquet/src/arrow/write/schema.rs @@ -290,7 +290,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { ArrowDataType::Struct(fields) => { if fields.is_empty() { polars_bail!(InvalidOperation: - "Parquet does not support writing empty structs".to_string(), + "Unable to write struct type with no child field to Parquet. Consider adding a dummy child field.".to_string(), ) } // recursively convert children to types/nodes diff --git a/crates/polars-parquet/src/parquet/bloom_filter/read.rs b/crates/polars-parquet/src/parquet/bloom_filter/read.rs index deda00b36272..fe4ee718cb7e 100644 --- a/crates/polars-parquet/src/parquet/bloom_filter/read.rs +++ b/crates/polars-parquet/src/parquet/bloom_filter/read.rs @@ -1,7 +1,7 @@ use std::io::{Read, Seek, SeekFrom}; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::{ +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::{ BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHeader, SplitBlockAlgorithm, Uncompressed, }; diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs index 41bfb5f557bf..f8e90f65e3ee 100644 --- a/crates/polars-parquet/src/parquet/compression.rs +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -245,7 +245,7 @@ fn try_decompress_hadoop(input_buf: &[u8], output_buf: &mut [u8]) -> ParquetResu // The Hadoop Lz4Codec source code can be found here: // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc - const SIZE_U32: usize = std::mem::size_of::(); + const SIZE_U32: usize = size_of::(); const PREFIX_LEN: usize = SIZE_U32 * 2; let mut input_len = input_buf.len(); let mut input = input_buf; diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs index 6e37507d137f..b5ea9b815dc1 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs @@ -13,7 +13,7 @@ pub struct Decoder<'a, T: Unpackable> { _pd: std::marker::PhantomData, } -impl<'a, T: Unpackable> Default for Decoder<'a, T> { +impl Default for Decoder<'_, T> { fn default() -> Self { Self { packed: [].chunks(1), @@ -56,7 +56,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { num_bits: usize, length: usize, ) -> ParquetResult { - let block_size = std::mem::size_of::() * num_bits; + let block_size = size_of::() * num_bits; if packed.len() * 8 < length * num_bits { return Err(ParquetError::oos(format!( @@ -78,7 +78,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. pub fn try_new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { - let block_size = std::mem::size_of::() * num_bits; + let block_size = size_of::() * num_bits; if num_bits == 0 { return Err(ParquetError::oos("Bitpacking requires num_bits > 0")); @@ -114,7 +114,7 @@ pub struct ChunkedDecoder<'a, 'b, T: Unpackable> { pub(crate) decoder: &'b mut Decoder<'a, T>, } -impl<'a, 'b, T: Unpackable> Iterator for ChunkedDecoder<'a, 'b, T> { +impl Iterator for ChunkedDecoder<'_, '_, T> { type Item = T::Unpacked; #[inline] @@ -136,9 +136,9 @@ impl<'a, 'b, T: Unpackable> Iterator for ChunkedDecoder<'a, 'b, T> { } } -impl<'a, 'b, T: Unpackable> ExactSizeIterator for ChunkedDecoder<'a, 'b, T> {} +impl ExactSizeIterator for ChunkedDecoder<'_, '_, T> {} -impl<'a, 'b, T: Unpackable> ChunkedDecoder<'a, 'b, T> { +impl ChunkedDecoder<'_, '_, T> { /// Get and consume the remainder chunk if it exists pub fn remainder(&mut self) -> Option<(T::Unpacked, usize)> { let remainder_len = self.decoder.len() % T::Unpacked::LENGTH; @@ -181,7 +181,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { } pub fn take(&mut self) -> Self { - let block_size = std::mem::size_of::() * self.num_bits; + let block_size = size_of::() * self.num_bits; let packed = std::mem::replace(&mut self.packed, [].chunks(block_size)); let length = self.length; self.length = 0; @@ -262,7 +262,7 @@ mod tests { use super::super::tests::case1; use super::*; - impl<'a, T: Unpackable> Decoder<'a, T> { + impl Decoder<'_, T> { pub fn collect(self) -> Vec { let mut vec = Vec::new(); self.collect_into(&mut vec); diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs index 793fa6f111d7..1b383e9522f1 100644 --- a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs @@ -96,7 +96,7 @@ where converter: F, } -impl<'a, 'b, T, F> Iterator for DecoderIterator<'a, 'b, T, F> +impl Iterator for DecoderIterator<'_, '_, T, F> where F: Copy + Fn(&[u8]) -> T, { diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs index 1ef6a99a2128..555954ff8069 100644 --- a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs @@ -14,7 +14,7 @@ mod tests { let mut buffer = vec![]; encode(&data, &mut buffer); - let mut decoder = Decoder::try_new(&buffer, std::mem::size_of::())?; + let mut decoder = Decoder::try_new(&buffer, size_of::())?; let values = decoder .iter_converted(|bytes| f32::from_le_bytes(bytes.try_into().unwrap())) .collect::>(); @@ -30,7 +30,7 @@ mod tests { let mut buffer = vec![]; encode(&data, &mut buffer); - let mut decoder = Decoder::try_new(&buffer, std::mem::size_of::())?; + let mut decoder = Decoder::try_new(&buffer, size_of::())?; let values = decoder .iter_converted(|bytes| f64::from_le_bytes(bytes.try_into().unwrap())) .collect::>(); @@ -61,9 +61,9 @@ mod tests { } fn encode(data: &[T], buffer: &mut Vec) { - let element_size = std::mem::size_of::(); + let element_size = size_of::(); let num_elements = data.len(); - let total_length = std::mem::size_of_val(data); + let total_length = size_of_val(data); buffer.resize(total_length, 0); for (i, v) in data.iter().enumerate() { diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs index 03889e0aa5d3..deb95f1dd3a2 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs @@ -56,7 +56,7 @@ impl<'a> Decoder<'a> { mod tests { use super::*; - impl<'a> Iterator for Decoder<'a> { + impl Iterator for Decoder<'_> { type Item = ParquetResult>; fn next(&mut self) -> Option { diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs index f46f22f84adb..0d67dc935857 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs @@ -39,7 +39,7 @@ impl<'a> BitmapIter<'a> { } } -impl<'a> Iterator for BitmapIter<'a> { +impl Iterator for BitmapIter<'_> { type Item = bool; #[inline] diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs index 824638d253ad..95d53b2769e4 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs @@ -44,7 +44,7 @@ impl Iterator for BufferedRle { impl ExactSizeIterator for BufferedRle {} -impl<'a> Iterator for BufferedBitpacked<'a> { +impl Iterator for BufferedBitpacked<'_> { type Item = u32; fn next(&mut self) -> Option { @@ -74,9 +74,9 @@ impl<'a> Iterator for BufferedBitpacked<'a> { } } -impl<'a> ExactSizeIterator for BufferedBitpacked<'a> {} +impl ExactSizeIterator for BufferedBitpacked<'_> {} -impl<'a> Iterator for HybridRleBuffered<'a> { +impl Iterator for HybridRleBuffered<'_> { type Item = u32; fn next(&mut self) -> Option { @@ -94,9 +94,9 @@ impl<'a> Iterator for HybridRleBuffered<'a> { } } -impl<'a> ExactSizeIterator for HybridRleBuffered<'a> {} +impl ExactSizeIterator for HybridRleBuffered<'_> {} -impl<'a> BufferedBitpacked<'a> { +impl BufferedBitpacked<'_> { fn gather_limited_into>( &mut self, target: &mut G::Target, @@ -212,7 +212,7 @@ impl BufferedRle { } } -impl<'a> HybridRleBuffered<'a> { +impl HybridRleBuffered<'_> { pub fn gather_limited_into>( &mut self, target: &mut G::Target, diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs index 1548f6e50a02..c66ef5873439 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs @@ -432,7 +432,7 @@ impl Translator for UnitTranslator { /// [`HybridRleDecoder`]: super::HybridRleDecoder pub struct DictionaryTranslator<'a, T>(pub &'a [T]); -impl<'a, T: Copy> Translator for DictionaryTranslator<'a, T> { +impl Translator for DictionaryTranslator<'_, T> { fn translate(&self, value: u32) -> ParquetResult { self.0 .get(value as usize) diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index c71b6455ddf4..f721274d00cd 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -133,7 +133,7 @@ impl<'a> HybridRleDecoder<'a> { if run_length == 0 { 0 } else { - let mut bytes = [0u8; std::mem::size_of::()]; + let mut bytes = [0u8; size_of::()]; pack.iter().zip(bytes.iter_mut()).for_each(|(src, dst)| { *dst = *src; }); @@ -380,7 +380,7 @@ impl<'a> HybridRleDecoder<'a> { if run_length <= n { run_length } else { - let mut bytes = [0u8; std::mem::size_of::()]; + let mut bytes = [0u8; size_of::()]; pack.iter().zip(bytes.iter_mut()).for_each(|(src, dst)| { *dst = *src; }); diff --git a/crates/polars-parquet/src/parquet/error.rs b/crates/polars-parquet/src/parquet/error.rs index fdd7f8fcadfd..3d00cf3c647f 100644 --- a/crates/polars-parquet/src/parquet/error.rs +++ b/crates/polars-parquet/src/parquet/error.rs @@ -94,8 +94,8 @@ impl From for ParquetError { } } -impl From for ParquetError { - fn from(e: parquet_format_safe::thrift::Error) -> ParquetError { +impl From for ParquetError { + fn from(e: polars_parquet_format::thrift::Error) -> ParquetError { ParquetError::OutOfSpec(format!("Invalid thrift: {}", e)) } } diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs index 30a606d6108a..dba897a8eeea 100644 --- a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::{ColumnChunk, ColumnMetaData, Encoding}; +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Encoding}; use super::column_descriptor::ColumnDescriptor; use crate::parquet::compression::Compression; @@ -10,7 +10,7 @@ use crate::parquet::statistics::Statistics; mod serde_types { pub use std::io::Cursor; - pub use parquet_format_safe::thrift::protocol::{ + pub use polars_parquet_format::thrift::protocol::{ TCompactInputProtocol, TCompactOutputProtocol, }; pub use serde::de::Error as DeserializeError; diff --git a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs index 2705c2a7b70d..ed14a1e130d6 100644 --- a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::ColumnOrder as TColumnOrder; +use polars_parquet_format::ColumnOrder as TColumnOrder; use super::column_order::ColumnOrder; use super::schema_descriptor::SchemaDescriptor; @@ -8,7 +8,7 @@ use crate::parquet::metadata::get_sort_order; pub use crate::parquet::thrift_format::KeyValue; /// Metadata for a Parquet file. -// This is almost equal to [`parquet_format_safe::FileMetaData`] but contains the descriptors, +// This is almost equal to [`polars_parquet_format::FileMetaData`] but contains the descriptors, // which are crucial to deserialize pages. #[derive(Debug, Clone)] pub struct FileMetadata { @@ -65,7 +65,7 @@ impl FileMetadata { /// Deserializes [`crate::parquet::thrift_format::FileMetadata`] into this struct pub fn try_from_thrift( - metadata: parquet_format_safe::FileMetaData, + metadata: polars_parquet_format::FileMetaData, ) -> Result { let schema_descr = SchemaDescriptor::try_from_thrift(&metadata.schema)?; diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs index 9cca27553415..b02983a760ed 100644 --- a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use hashbrown::hash_map::RawEntryMut; -use parquet_format_safe::RowGroup; +use polars_parquet_format::{RowGroup, SortingColumn}; use polars_utils::aliases::{InitHashMaps, PlHashMap}; use polars_utils::idx_vec::UnitVec; use polars_utils::pl_str::PlSmallStr; @@ -41,6 +41,7 @@ pub struct RowGroupMetadata { num_rows: usize, total_byte_size: usize, full_byte_range: core::ops::Range, + sorting_columns: Option>, } impl RowGroupMetadata { @@ -59,6 +60,11 @@ impl RowGroupMetadata { .map(|x| x.iter().map(|&x| &self.columns[x])) } + /// Fetch all columns under this root name if it exists. + pub fn columns_idxs_under_root_iter<'a>(&'a self, root_name: &str) -> Option<&'a [usize]> { + self.column_lookup.get(root_name).map(|x| x.as_slice()) + } + /// Number of rows in this row group. pub fn num_rows(&self) -> usize { self.num_rows @@ -85,6 +91,10 @@ impl RowGroupMetadata { self.columns.iter().map(|x| x.byte_range()) } + pub fn sorting_columns(&self) -> Option<&[SortingColumn]> { + self.sorting_columns.as_deref() + } + /// Method to convert from Thrift. pub(crate) fn try_from_thrift( schema_descr: &SchemaDescriptor, @@ -106,6 +116,8 @@ impl RowGroupMetadata { 0..0 }; + let sorting_columns = rg.sorting_columns.clone(); + let columns = rg .columns .into_iter() @@ -131,6 +143,7 @@ impl RowGroupMetadata { num_rows, total_byte_size, full_byte_range, + sorting_columns, }) } } diff --git a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs index 7c29f983ee1d..c40fcdd1309b 100644 --- a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs +++ b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::SchemaElement; +use polars_parquet_format::SchemaElement; use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/metadata/sort.rs b/crates/polars-parquet/src/parquet/metadata/sort.rs index 93aac06605b6..d75d77134103 100644 --- a/crates/polars-parquet/src/parquet/metadata/sort.rs +++ b/crates/polars-parquet/src/parquet/metadata/sort.rs @@ -56,6 +56,7 @@ fn get_logical_sort_order(logical_type: &PrimitiveLogicalType) -> SortOrder { Timestamp { .. } => SortOrder::Signed, Unknown => SortOrder::Undefined, Uuid => SortOrder::Unsigned, + Float16 => SortOrder::Unsigned, } } diff --git a/crates/polars-parquet/src/parquet/mod.rs b/crates/polars-parquet/src/parquet/mod.rs index ea6b5b2c8357..1926e641fd04 100644 --- a/crates/polars-parquet/src/parquet/mod.rs +++ b/crates/polars-parquet/src/parquet/mod.rs @@ -15,7 +15,7 @@ pub mod write; use std::ops::Deref; -use parquet_format_safe as thrift_format; +use polars_parquet_format as thrift_format; use polars_utils::mmap::MemSlice; pub use streaming_decompression::{fallible_streaming_iterator, FallibleStreamingIterator}; diff --git a/crates/polars-parquet/src/parquet/parquet_bridge.rs b/crates/polars-parquet/src/parquet/parquet_bridge.rs index 21261f7ca011..523ea0e3e12e 100644 --- a/crates/polars-parquet/src/parquet/parquet_bridge.rs +++ b/crates/polars-parquet/src/parquet/parquet_bridge.rs @@ -497,6 +497,7 @@ pub enum PrimitiveLogicalType { Json, Bson, Uuid, + Float16, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -575,6 +576,7 @@ impl TryFrom for PrimitiveLogicalType { ParquetLogicalType::JSON(_) => PrimitiveLogicalType::Json, ParquetLogicalType::BSON(_) => PrimitiveLogicalType::Bson, ParquetLogicalType::UUID(_) => PrimitiveLogicalType::Uuid, + ParquetLogicalType::FLOAT16(_) => PrimitiveLogicalType::Float16, _ => return Err(ParquetError::oos("LogicalType value out of range")), }) } @@ -629,6 +631,7 @@ impl From for ParquetLogicalType { PrimitiveLogicalType::Json => ParquetLogicalType::JSON(Default::default()), PrimitiveLogicalType::Bson => ParquetLogicalType::BSON(Default::default()), PrimitiveLogicalType::Uuid => ParquetLogicalType::UUID(Default::default()), + PrimitiveLogicalType::Float16 => ParquetLogicalType::FLOAT16(Default::default()), } } } diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs index a79989c39e26..1bd457474e06 100644 --- a/crates/polars-parquet/src/parquet/read/compression.rs +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::DataPageHeaderV2; +use polars_parquet_format::DataPageHeaderV2; use super::PageReader; use crate::parquet::compression::{self, Compression}; diff --git a/crates/polars-parquet/src/parquet/read/metadata.rs b/crates/polars-parquet/src/parquet/read/metadata.rs index e14a2a60e997..a260fe71ff06 100644 --- a/crates/polars-parquet/src/parquet/read/metadata.rs +++ b/crates/polars-parquet/src/parquet/read/metadata.rs @@ -1,8 +1,8 @@ use std::cmp::min; use std::io::{Read, Seek, SeekFrom}; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::FileMetaData as TFileMetadata; +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::FileMetaData as TFileMetadata; use super::super::metadata::FileMetadata; use super::super::{DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, HEADER_SIZE, PARQUET_MAGIC}; diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index cd23af0499d7..7dfa2e144d8d 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -1,7 +1,7 @@ use std::io::Seek; use std::sync::OnceLock; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; use polars_utils::mmap::{MemReader, MemSlice}; use super::PageIterator; @@ -13,6 +13,7 @@ use crate::parquet::page::{ ParquetPageHeader, }; use crate::parquet::CowBuffer; +use crate::write::Encoding; /// This meta is a small part of [`ColumnChunkMetadata`]. #[derive(Debug, Clone, PartialEq, Eq)] @@ -96,7 +97,7 @@ impl PageReader { Self::new_with_page_meta(reader, column.into(), scratch, max_page_size) } - /// Create a a new [`PageReader`] with [`PageMetaData`]. + /// Create a new [`PageReader`] with [`PageMetaData`]. /// /// It assumes that the reader has been `sought` (`seek`) to the beginning of `column`. pub fn new_with_page_meta( @@ -251,7 +252,10 @@ pub(super) fn finish_page( })?; if do_verbose { - println!("DictPage ( )"); + eprintln!( + "Parquet DictPage ( num_values: {}, datatype: {:?} )", + dict_header.num_values, descriptor.primitive_type + ); } let is_sorted = dict_header.is_sorted.unwrap_or(false); @@ -275,9 +279,11 @@ pub(super) fn finish_page( })?; if do_verbose { - println!( - "DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + eprintln!( + "Parquet DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } @@ -298,8 +304,10 @@ pub(super) fn finish_page( if do_verbose { println!( - "DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + "Parquet DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } diff --git a/crates/polars-parquet/src/parquet/read/page/stream.rs b/crates/polars-parquet/src/parquet/read/page/stream.rs index fbd36b3ccfe1..7145689493fc 100644 --- a/crates/polars-parquet/src/parquet/read/page/stream.rs +++ b/crates/polars-parquet/src/parquet/read/page/stream.rs @@ -2,7 +2,7 @@ use std::io::SeekFrom; use async_stream::try_stream; use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream}; -use parquet_format_safe::thrift::protocol::TCompactInputStreamProtocol; +use polars_parquet_format::thrift::protocol::TCompactInputStreamProtocol; use polars_utils::mmap::MemSlice; use super::reader::{finish_page, PageMetaData}; diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs index d4f2c692e95d..ccf293e48b08 100644 --- a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -42,7 +42,7 @@ //! println!("{:?}", schema); //! ``` -use parquet_format_safe::Type; +use polars_parquet_format::Type; use polars_utils::pl_str::PlSmallStr; use types::PrimitiveLogicalType; @@ -303,7 +303,7 @@ fn parse_timeunit( }) } -impl<'a> Parser<'a> { +impl Parser<'_> { // Entry function to parse message type, uses internal tokenizer. fn parse_message_type(&mut self) -> ParquetResult { // Check that message type starts with "message". diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs index b0bbe20999bc..7a874fb59e46 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::SchemaElement; +use polars_parquet_format::SchemaElement; use polars_utils::pl_str::PlSmallStr; use super::super::types::ParquetType; diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs index 3aef1fe792fa..db372b733593 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::{ConvertedType, SchemaElement}; +use polars_parquet_format::{ConvertedType, SchemaElement}; use super::super::types::ParquetType; use crate::parquet::schema::types::PrimitiveType; diff --git a/crates/polars-parquet/src/parquet/schema/types/converted_type.rs b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs index 8432167fcd3b..91b4b5ac78f9 100644 --- a/crates/polars-parquet/src/parquet/schema/types/converted_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::ConvertedType; +use polars_parquet_format::ConvertedType; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/schema/types/physical_type.rs b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs index 01595134c6b3..ed7242adac71 100644 --- a/crates/polars-parquet/src/parquet/schema/types/physical_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Type; +use polars_parquet_format::Type; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/schema/types/spec.rs b/crates/polars-parquet/src/parquet/schema/types/spec.rs index f21cdbb9d611..f18f7ae9015b 100644 --- a/crates/polars-parquet/src/parquet/schema/types/spec.rs +++ b/crates/polars-parquet/src/parquet/schema/types/spec.rs @@ -170,6 +170,7 @@ pub fn check_logical_invariants( (String | Json | Bson, PhysicalType::ByteArray) => {}, // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#uuid (Uuid, PhysicalType::FixedLenByteArray(16)) => {}, + (Float16, PhysicalType::FixedLenByteArray(2)) => {}, (a, b) => { return Err(ParquetError::oos(format!( "Cannot annotate {:?} from {:?} fields", diff --git a/crates/polars-parquet/src/parquet/statistics/binary.rs b/crates/polars-parquet/src/parquet/statistics/binary.rs index e9506c375a71..7f1dabf21ec8 100644 --- a/crates/polars-parquet/src/parquet/statistics/binary.rs +++ b/crates/polars-parquet/src/parquet/statistics/binary.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::ParquetResult; use crate::parquet::schema::types::PrimitiveType; @@ -32,8 +32,10 @@ impl BinaryStatistics { distinct_count: self.distinct_count, max_value: self.max_value.clone(), min_value: self.min_value.clone(), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/boolean.rs b/crates/polars-parquet/src/parquet/statistics/boolean.rs index 607897bdddf0..55e478d5b957 100644 --- a/crates/polars-parquet/src/parquet/statistics/boolean.rs +++ b/crates/polars-parquet/src/parquet/statistics/boolean.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; @@ -13,14 +13,14 @@ pub struct BooleanStatistics { impl BooleanStatistics { pub fn deserialize(v: &ParquetStatistics) -> ParquetResult { if let Some(ref v) = v.max_value { - if v.len() != std::mem::size_of::() { + if v.len() != size_of::() { return Err(ParquetError::oos( "The max_value of statistics MUST be plain encoded", )); } }; if let Some(ref v) = v.min_value { - if v.len() != std::mem::size_of::() { + if v.len() != size_of::() { return Err(ParquetError::oos( "The min_value of statistics MUST be plain encoded", )); @@ -49,8 +49,10 @@ impl BooleanStatistics { distinct_count: self.distinct_count, max_value: self.max_value.map(|x| vec![x as u8]), min_value: self.min_value.map(|x| vec![x as u8]), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs index 8de2aef0a508..87642246907d 100644 --- a/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs +++ b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::schema::types::PrimitiveType; @@ -54,8 +54,10 @@ impl FixedLenStatistics { distinct_count: self.distinct_count, max_value: self.max_value.clone(), min_value: self.min_value.clone(), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/mod.rs b/crates/polars-parquet/src/parquet/statistics/mod.rs index 7501df2a3886..03335c27817b 100644 --- a/crates/polars-parquet/src/parquet/statistics/mod.rs +++ b/crates/polars-parquet/src/parquet/statistics/mod.rs @@ -41,6 +41,34 @@ impl Statistics { } } + pub fn clear_min(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.min_value.take(), + S::Boolean(s) => _ = s.min_value.take(), + S::FixedLen(s) => _ = s.min_value.take(), + S::Int32(s) => _ = s.min_value.take(), + S::Int64(s) => _ = s.min_value.take(), + S::Int96(s) => _ = s.min_value.take(), + S::Float(s) => _ = s.min_value.take(), + S::Double(s) => _ = s.min_value.take(), + }; + } + + pub fn clear_max(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.max_value.take(), + S::Boolean(s) => _ = s.max_value.take(), + S::FixedLen(s) => _ = s.max_value.take(), + S::Int32(s) => _ = s.max_value.take(), + S::Int64(s) => _ = s.max_value.take(), + S::Int96(s) => _ = s.max_value.take(), + S::Float(s) => _ = s.max_value.take(), + S::Double(s) => _ = s.max_value.take(), + }; + } + /// Deserializes a raw parquet statistics into [`Statistics`]. /// # Error /// This function errors if it is not possible to read the statistics to the @@ -51,7 +79,7 @@ impl Statistics { primitive_type: PrimitiveType, ) -> ParquetResult { use {PhysicalType as T, PrimitiveStatistics as PrimStat}; - Ok(match primitive_type.physical_type { + let mut stats: Self = match primitive_type.physical_type { T::ByteArray => BinaryStatistics::deserialize(statistics, primitive_type)?.into(), T::Boolean => BooleanStatistics::deserialize(statistics)?.into(), T::Int32 => PrimStat::::deserialize(statistics, primitive_type)?.into(), @@ -62,7 +90,16 @@ impl Statistics { T::FixedLenByteArray(size) => { FixedLenStatistics::deserialize(statistics, size, primitive_type)?.into() }, - }) + }; + + if statistics.is_min_value_exact.is_some_and(|v| !v) { + stats.clear_min(); + } + if statistics.is_max_value_exact.is_some_and(|v| !v) { + stats.clear_max(); + } + + Ok(stats) } } diff --git a/crates/polars-parquet/src/parquet/statistics/primitive.rs b/crates/polars-parquet/src/parquet/statistics/primitive.rs index ed5ae71515b0..7bd8d227faa2 100644 --- a/crates/polars-parquet/src/parquet/statistics/primitive.rs +++ b/crates/polars-parquet/src/parquet/statistics/primitive.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::schema::types::PrimitiveType; @@ -20,7 +20,7 @@ impl PrimitiveStatistics { ) -> ParquetResult { if v.max_value .as_ref() - .is_some_and(|v| v.len() != std::mem::size_of::()) + .is_some_and(|v| v.len() != size_of::()) { return Err(ParquetError::oos( "The max_value of statistics MUST be plain encoded", @@ -28,7 +28,7 @@ impl PrimitiveStatistics { }; if v.min_value .as_ref() - .is_some_and(|v| v.len() != std::mem::size_of::()) + .is_some_and(|v| v.len() != size_of::()) { return Err(ParquetError::oos( "The min_value of statistics MUST be plain encoded", @@ -50,8 +50,10 @@ impl PrimitiveStatistics { distinct_count: self.distinct_count, max_value: self.max_value.map(|x| x.to_le_bytes().as_ref().to_vec()), min_value: self.min_value.map(|x| x.to_le_bytes().as_ref().to_vec()), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/types.rs b/crates/polars-parquet/src/parquet/types.rs index 7591f6ba0bd7..1dd65d0ff622 100644 --- a/crates/polars-parquet/src/parquet/types.rs +++ b/crates/polars-parquet/src/parquet/types.rs @@ -22,7 +22,7 @@ pub trait NativeType: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { macro_rules! native { ($type:ty, $physical_type:expr) => { impl NativeType for $type { - type Bytes = [u8; std::mem::size_of::()]; + type Bytes = [u8; size_of::()]; #[inline] fn to_le_bytes(&self) -> Self::Bytes { Self::to_le_bytes(*self) @@ -51,7 +51,7 @@ native!(f64, PhysicalType::Double); impl NativeType for [u32; 3] { const TYPE: PhysicalType = PhysicalType::Int96; - type Bytes = [u8; std::mem::size_of::()]; + type Bytes = [u8; size_of::()]; #[inline] fn to_le_bytes(&self) -> Self::Bytes { let mut bytes = [0; 12]; @@ -137,7 +137,7 @@ pub fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { #[inline] pub fn decode(chunk: &[u8]) -> T { - assert!(chunk.len() >= std::mem::size_of::<::Bytes>()); + assert!(chunk.len() >= size_of::<::Bytes>()); unsafe { decode_unchecked(chunk) } } diff --git a/crates/polars-parquet/src/parquet/write/column_chunk.rs b/crates/polars-parquet/src/parquet/write/column_chunk.rs index 6ae51a191dc5..8728f289eb65 100644 --- a/crates/polars-parquet/src/parquet/write/column_chunk.rs +++ b/crates/polars-parquet/src/parquet/write/column_chunk.rs @@ -2,10 +2,10 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::{ColumnChunk, ColumnMetaData, Type}; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Type}; use polars_utils::aliases::PlHashSet; #[cfg(feature = "async")] @@ -195,6 +195,8 @@ fn build_column_chunk( statistics, encoding_stats: None, bloom_filter_offset: None, + bloom_filter_length: None, + size_statistics: None, }; Ok(ColumnChunk { diff --git a/crates/polars-parquet/src/parquet/write/dyn_iter.rs b/crates/polars-parquet/src/parquet/write/dyn_iter.rs index f47710b56b22..a232c06375e8 100644 --- a/crates/polars-parquet/src/parquet/write/dyn_iter.rs +++ b/crates/polars-parquet/src/parquet/write/dyn_iter.rs @@ -7,7 +7,7 @@ pub struct DynIter<'a, V> { iter: Box + 'a + Send + Sync>, } -impl<'a, V> Iterator for DynIter<'a, V> { +impl Iterator for DynIter<'_, V> { type Item = V; fn next(&mut self) -> Option { self.iter.next() @@ -35,7 +35,7 @@ pub struct DynStreamingIterator<'a, V, E> { iter: Box + 'a + Send + Sync>, } -impl<'a, V, E> FallibleStreamingIterator for DynStreamingIterator<'a, V, E> { +impl FallibleStreamingIterator for DynStreamingIterator<'_, V, E> { type Item = V; type Error = E; diff --git a/crates/polars-parquet/src/parquet/write/file.rs b/crates/polars-parquet/src/parquet/write/file.rs index 8dd3212bb76a..d46f85dd3138 100644 --- a/crates/polars-parquet/src/parquet/write/file.rs +++ b/crates/polars-parquet/src/parquet/write/file.rs @@ -1,7 +1,7 @@ use std::io::Write; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; -use parquet_format_safe::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::RowGroup; use super::indexes::{write_column_index, write_offset_index}; use super::page::PageWriteSpec; @@ -39,7 +39,7 @@ pub(super) fn end_file( Ok(metadata_len as u64 + FOOTER_SIZE) } -fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec { +fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec { // We only include ColumnOrder for leaf nodes. // Currently only supported ColumnOrder is TypeDefinedOrder so we set this // for all leaf nodes. @@ -47,7 +47,9 @@ fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec ParquetResult ParquetResult>>()?; - Ok(OffsetIndex { page_locations }) + Ok(OffsetIndex { + page_locations, + unencoded_byte_array_data_bytes: None, + }) } diff --git a/crates/polars-parquet/src/parquet/write/indexes/write.rs b/crates/polars-parquet/src/parquet/write/indexes/write.rs index 7c82b1dcc9ae..73325654e518 100644 --- a/crates/polars-parquet/src/parquet/write/indexes/write.rs +++ b/crates/polars-parquet/src/parquet/write/indexes/write.rs @@ -2,9 +2,9 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; use super::serialize::{serialize_column_index, serialize_offset_index}; use crate::parquet::error::ParquetResult; diff --git a/crates/polars-parquet/src/parquet/write/page.rs b/crates/polars-parquet/src/parquet/write/page.rs index f9e527d5a9db..8fb65c3daf12 100644 --- a/crates/polars-parquet/src/parquet/write/page.rs +++ b/crates/polars-parquet/src/parquet/write/page.rs @@ -2,10 +2,10 @@ use std::io::Write; #[cfg(feature = "async")] use futures::{AsyncWrite, AsyncWriteExt}; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::{DictionaryPageHeader, Encoding, PageType}; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{DictionaryPageHeader, Encoding, PageType}; use crate::parquet::compression::Compression; use crate::parquet::error::{ParquetError, ParquetResult}; diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs index 43404dc32a89..dfca3d27f948 100644 --- a/crates/polars-parquet/src/parquet/write/row_group.rs +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -2,7 +2,7 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::{ColumnChunk, RowGroup}; +use polars_parquet_format::{ColumnChunk, RowGroup}; use super::column_chunk::write_column_chunk; #[cfg(feature = "async")] diff --git a/crates/polars-parquet/src/parquet/write/stream.rs b/crates/polars-parquet/src/parquet/write/stream.rs index eca712db65dc..05c50e6e3a2c 100644 --- a/crates/polars-parquet/src/parquet/write/stream.rs +++ b/crates/polars-parquet/src/parquet/write/stream.rs @@ -1,8 +1,8 @@ use std::io::Write; use futures::{AsyncWrite, AsyncWriteExt}; -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::RowGroup; use super::row_group::write_row_group_async; use super::{RowGroupIterColumns, WriteOptions}; @@ -20,7 +20,7 @@ async fn start_file(writer: &mut W) -> ParquetResult async fn end_file( mut writer: &mut W, - metadata: parquet_format_safe::FileMetaData, + metadata: polars_parquet_format::FileMetaData, ) -> ParquetResult { // Write file metadata let mut protocol = TCompactOutputStreamProtocol::new(&mut writer); @@ -169,7 +169,7 @@ impl FileStreamer { } } - let metadata = parquet_format_safe::FileMetaData::new( + let metadata = polars_parquet_format::FileMetaData::new( self.options.version.into(), self.schema.clone().into_thrift(), num_rows, diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs index 9ae6dbc5299d..de8c9d424b44 100644 --- a/crates/polars-pipe/src/executors/operators/projection.rs +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -89,7 +89,8 @@ impl Operator for ProjectionOperator { } } - let chunk = chunk.with_data(unsafe { DataFrame::new_no_checks(projected) }); + let chunk = + chunk.with_data(unsafe { DataFrame::new_no_checks_height_from_first(projected) }); Ok(OperatorResult::Finished(chunk)) } fn split(&self, _thread_no: usize) -> Box { @@ -125,7 +126,7 @@ impl Operator for HstackOperator { .collect::>>()?; let columns = chunk.data.get_columns()[..width].to_vec(); - let mut df = unsafe { DataFrame::new_no_checks(columns) }; + let mut df = unsafe { DataFrame::new_no_checks_height_from_first(columns) }; let schema = &*self.input_schema; if self.options.should_broadcast { diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs index a4f6010bef79..d037937896d3 100644 --- a/crates/polars-pipe/src/executors/operators/reproject.rs +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -27,7 +27,7 @@ pub(crate) fn reproject_chunk( } else { let columns = chunk.data.get_columns(); let cols = positions.iter().map(|i| columns[*i].clone()).collect(); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(chunk.data.height(), cols) } }; *chunk = chunk.with_data(out); Ok(()) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs index 4e81e8531bac..14ce4836096c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -135,7 +135,7 @@ pub(crate) fn convert_to_hash_agg( to_physical: &F, ) -> (DataType, Arc, AggregateFunction) where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { match expr_arena.get(node) { AExpr::Alias(input, _) => convert_to_hash_agg(*input, expr_arena, schema, to_physical), @@ -146,12 +146,9 @@ where ), AExpr::Agg(agg) => match agg { IRAggExpr::Min { input, .. } => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; let agg_fn = match logical_dtype.to_physical() { @@ -170,12 +167,9 @@ where (logical_dtype, phys_expr, agg_fn) }, IRAggExpr::Max { input, .. } => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; let agg_fn = match logical_dtype.to_physical() { @@ -194,12 +188,9 @@ where (logical_dtype, phys_expr, agg_fn) }, IRAggExpr::Sum(input) => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; #[cfg(feature = "dtype-categorical")] @@ -217,7 +208,7 @@ where let agg_fn = match logical_dtype.to_physical() { // Boolean is aggregated as the IDX type. DataType::Boolean => { - if std::mem::size_of::() == 4 { + if size_of::() == 4 { AggregateFunction::SumU32(SumAgg::::new()) } else { AggregateFunction::SumU64(SumAgg::::new()) @@ -240,12 +231,9 @@ where (logical_dtype, phys_expr, agg_fn) }, IRAggExpr::Mean(input) => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; #[cfg(feature = "dtype-categorical")] @@ -270,12 +258,9 @@ where (logical_dtype, phys_expr, agg_fn) }, IRAggExpr::First(input) => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; ( logical_dtype.clone(), @@ -284,12 +269,9 @@ where ) }, IRAggExpr::Last(input) => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; ( logical_dtype.clone(), @@ -298,12 +280,9 @@ where ) }, IRAggExpr::Count(input, _) => { - let phys_expr = to_physical( - &ExprIR::from_node(*input, expr_arena), - expr_arena, - Some(schema), - ) - .unwrap(); + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); let logical_dtype = phys_expr.field(schema).unwrap().dtype; ( logical_dtype, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs index 05947baae209..3162c57828b3 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -279,7 +279,7 @@ impl AggHashTable { .map(|buf| buf.into_series().into_column()), ); physical_agg_to_logical(&mut cols, &self.output_schema); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks_height_from_first(cols) } } } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs index e9fa7ba495cd..98425f66bee0 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs @@ -74,6 +74,8 @@ impl SpillPayload { debug_assert_eq!(self.hashes.len(), self.chunk_idx.len()); debug_assert_eq!(self.hashes.len(), self.keys.len()); + let height = self.hashes.len(); + let hashes = UInt64Chunked::from_vec(PlSmallStr::from_static(HASH_COL), self.hashes).into_column(); let chunk_idx = @@ -87,7 +89,7 @@ impl SpillPayload { cols.push(keys); // @scalar-opt cols.extend(self.aggs.into_iter().map(Column::from)); - unsafe { DataFrame::new_no_checks(cols) } + unsafe { DataFrame::new_no_checks(height, cols) } } fn spilled_to_columns( diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs index 77a939c64290..280cd236afa6 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs @@ -8,7 +8,7 @@ use crate::pipeline::{morsels_per_sink, FORCE_OOC}; pub(super) struct OocState { // OOC // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. mem_track: MemTracker, // sort in-memory or out-of-core pub(super) ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs index 1f79e20bcdab..f2c664087daf 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs @@ -13,7 +13,7 @@ use crate::pipeline::morsels_per_sink; pub(super) struct OocState { // OOC // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. _mem_track: MemTracker, // sort in-memory or out-of-core pub(super) ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs index 8715dd6f3fa9..dce130c29187 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs @@ -212,7 +212,7 @@ where .map(|buf| buf.into_series().into_column()), ); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(unsafe { DataFrame::new_no_checks(cols) }) + Some(unsafe { DataFrame::new_no_checks_height_from_first(cols) }) }) .collect::>(); Ok(dfs) @@ -458,10 +458,11 @@ where fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { let dfs = self.pre_finalize()?; let payload = if self.ooc_state.ooc { - let mut iot = self.ooc_state.io_thread.lock().unwrap(); - // make sure that we reset the shared states - // the OOC group_by will call split as well and it should - // not send continue spilling to disk + let mut guard = self.ooc_state.io_thread.lock().unwrap(); + // Type hint fixes rust-analyzer thinking .take() is an iterator method. + let iot: &mut Option<_> = &mut *guard; + // Make sure that we reset the shared states. The OOC group_by will + // call split as well and it should not send continue spilling to disk. let iot = iot.take().unwrap(); self.ooc_state.ooc = false; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/string.rs b/crates/polars-pipe/src/executors/sinks/group_by/string.rs index 0855e4cbf42d..a6254ba7fdaf 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/string.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/string.rs @@ -216,7 +216,7 @@ impl StringGroupbySink { .map(|buf| buf.into_series().into_column()), ); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(unsafe { DataFrame::new_no_checks(cols) }) + Some(unsafe { DataFrame::new_no_checks_height_from_first(cols) }) }) .collect::>(); diff --git a/crates/polars-pipe/src/executors/sinks/group_by/utils.rs b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs index dd0d8ce4aea0..ab9247e7a17b 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/utils.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs @@ -59,9 +59,9 @@ pub(super) fn finalize_group_by( let df = if dfs.is_empty() { DataFrame::empty_with_schema(output_schema) } else { - let mut df = accumulate_dataframes_vertical_unchecked(dfs); + let df = accumulate_dataframes_vertical_unchecked(dfs); // re init to check duplicates - unsafe { DataFrame::new(std::mem::take(df.get_columns_mut())) }? + DataFrame::new(df.take_columns())? }; match ooc_payload { diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index d6014c344978..77466578d3e9 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -111,7 +111,7 @@ impl Operator for CrossJoinProbe { _context: &PExecutionContext, chunk: &DataChunk, ) -> PolarsResult { - // Expected output is size**2, so this needs to be a a small number. + // Expected output is size**2, so this needs to be a small number. // However, if one of the DataFrames is much smaller than 250, we want // to take rather more from the other DataFrame so we don't end up with // overly small chunks. diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs index 2ab417ad2096..0b7cbfb2c534 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -121,6 +121,9 @@ impl GenericFullOuterJoinProbe { left_df .get_columns_mut() .extend_from_slice(right_df.get_columns()); + + // @TODO: Is this actually the case? + // SAFETY: output_names should be unique. left_df .get_columns_mut() .iter_mut() @@ -265,6 +268,7 @@ impl GenericFullOuterJoinProbe { let right_df = unsafe { DataFrame::new_no_checks( + size, right_df .get_columns() .iter() diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 43589c9783a1..49d51cc2e2fb 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -20,7 +20,7 @@ pub struct SortSink { schema: SchemaRef, chunks: Vec, // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. mem_track: MemTracker, // sort in-memory or out-of-core ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 7d32d10961c3..7c0a35db38a1 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,8 +1,8 @@ use std::any::Any; use arrow::array::BinaryArray; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_compat_array; use polars_core::prelude::sort::_broadcast_bools; -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_row::decode::decode_rows_from_binary; @@ -65,55 +65,56 @@ fn finalize_dataframe( sort_fields: &[EncodingField], schema: &Schema, ) { - unsafe { - let cols = df.get_columns_mut(); - // pop the encoded sort column - let encoded = cols.pop().unwrap(); - - // we decode the row-encoded binary column - // this will be decoded into multiple columns - // these are the columns we sorted by - // those need to be inserted at the `sort_idx` position - // in the `DataFrame`. - if can_decode { - let sort_dtypes = sort_dtypes.expect("should be set if 'can_decode'"); - - let encoded = encoded.binary_offset().unwrap(); - assert_eq!(encoded.chunks().len(), 1); - let arr = encoded.downcast_iter().next().unwrap(); - - // SAFETY: - // temporary extend lifetime - // this is safe as the lifetime in rows stays bound to this scope - let arrays = { - let arr = - std::mem::transmute::<&'_ BinaryArray, &'static BinaryArray>(arr); - decode_rows_from_binary(arr, sort_fields, sort_dtypes, rows) - }; - rows.clear(); - - let arrays = sort_by_idx(&arrays, sort_idx); - let mut sort_idx = sort_idx.to_vec(); - sort_idx.sort_unstable(); - - for (sort_idx, arr) in sort_idx.into_iter().zip(arrays) { - let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap(); - assert_eq!(logical_dtype.to_physical(), DataType::from(arr.dtype())); - let col = - Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr], logical_dtype) - .into_column(); - cols.insert(sort_idx, col); + // pop the encoded sort column + // SAFETY: We only pop a value + let encoded = unsafe { df.get_columns_mut() }.pop().unwrap(); + + // we decode the row-encoded binary column + // this will be decoded into multiple columns + // these are the columns we sorted by + // those need to be inserted at the `sort_idx` position + // in the `DataFrame`. + if can_decode { + let sort_dtypes = sort_dtypes.expect("should be set if 'can_decode'"); + + let encoded = encoded.binary_offset().unwrap(); + assert_eq!(encoded.chunks().len(), 1); + let arr = encoded.downcast_iter().next().unwrap(); + + // SAFETY: + // temporary extend lifetime + // this is safe as the lifetime in rows stays bound to this scope + let arrays = unsafe { + let arr = std::mem::transmute::<&'_ BinaryArray, &'static BinaryArray>(arr); + decode_rows_from_binary(arr, sort_fields, sort_dtypes, rows) + }; + rows.clear(); + + let arrays = sort_by_idx(&arrays, sort_idx); + let mut sort_idx = sort_idx.to_vec(); + sort_idx.sort_unstable(); + + for (sort_idx, arr) in sort_idx.into_iter().zip(arrays) { + let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap(); + assert_eq!(logical_dtype.to_physical(), DataType::from(arr.dtype())); + let col = unsafe { + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr], logical_dtype) } - } + .into_column(); - let first_sort_col = &mut cols[sort_idx[0]]; - let flag = if sort_options.descending[0] { - IsSorted::Descending - } else { - IsSorted::Ascending - }; - first_sort_col.set_sorted_flag(flag) + // SAFETY: col has the same length as the df height because it was popped from df. + unsafe { df.get_columns_mut() }.insert(sort_idx, col); + } } + + // SAFETY: We just change the sorted flag. + let first_sort_col = &mut unsafe { df.get_columns_mut() }[sort_idx[0]]; + let flag = if sort_options.descending[0] { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + first_sort_col.set_sorted_flag(flag) } /// This struct will dispatch all sorting to `SortSink` @@ -200,12 +201,11 @@ impl SortSinkMultiple { fn encode(&mut self, chunk: &mut DataChunk) -> PolarsResult<()> { let df = &mut chunk.data; - let cols = unsafe { df.get_columns_mut() }; self.sort_column.clear(); for i in self.sort_idx.iter() { - let s = &cols[*i]; + let s = &df.get_columns()[*i]; let arr = _get_rows_encoded_compat_array(s.as_materialized_series())?; self.sort_column.push(arr); } @@ -216,6 +216,9 @@ impl SortSinkMultiple { let mut sorted_sort_idx = self.sort_idx.to_vec(); sorted_sort_idx.sort_unstable(); + // SAFETY: We do not adjust the names or lengths or columns. + let cols = unsafe { df.get_columns_mut() }; + sorted_sort_idx .into_iter() .enumerate() diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index 323776deb976..900be25256b4 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -216,14 +216,11 @@ impl Source for CsvSource { }; for data_chunk in &mut out { - // The batched reader creates the column containing all nulls because the schema it - // gets passed contains the column. - for s in unsafe { data_chunk.data.get_columns_mut() } { - if s.name() == ca.name() { - *s = ca.slice(0, s.len()).into_column(); - break; - } - } + let n = data_chunk.data.height(); + // SAFETY: Columns are only replaced with columns + // 1. of the same name, and + // 2. of the same length. + unsafe { data_chunk.data.get_columns_mut() }.push(ca.slice(0, n).into_column()) } } diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs index c1f63019a611..d237510b4636 100644 --- a/crates/polars-pipe/src/operators/chunks.rs +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -39,7 +39,7 @@ pub(crate) fn chunks_to_df_unchecked(chunks: Vec) -> DataFrame { /// /// The benefit of having a series of `DataFrame` that are e.g. 4MB each that /// are then made contiguous is that you're not using a lot of memory (an extra -/// 4MB), but you're still doing better than if you had a series of of 2KB +/// 4MB), but you're still doing better than if you had a series of 2KB /// `DataFrame`s. /// /// Changing the `DataFrame` into contiguous chunks is the caller's diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 0a6a8946feba..b0b19aa26708 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -26,10 +26,10 @@ fn exprs_to_physical( exprs: &[ExprIR], expr_arena: &Arena, to_physical: &F, - schema: Option<&SchemaRef>, + schema: &SchemaRef, ) -> PolarsResult>> where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { exprs .iter() @@ -47,7 +47,7 @@ fn get_source( verbose: bool, ) -> PolarsResult> where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { use IR::*; match source { @@ -58,9 +58,12 @@ where .. } => { let mut df = (*df).clone(); + let schema = output_schema + .clone() + .unwrap_or_else(|| Arc::new(df.schema())); if push_predicate { if let Some(predicate) = selection { - let predicate = to_physical(&predicate, expr_arena, output_schema.as_ref())?; + let predicate = to_physical(&predicate, expr_arena, &schema)?; let op = operators::FilterOperator { predicate }; let op = Box::new(op) as Box; operator_objects.push(op) @@ -83,6 +86,7 @@ where scan_type, } => { let paths = sources.into_paths(); + let schema = output_schema.as_ref().unwrap_or(&file_info.schema); // Add predicate to operators. // Except for parquet, as that format can use statistics to prune file/row-groups. @@ -95,7 +99,7 @@ where { #[cfg(feature = "parquet")] debug_assert!(!matches!(scan_type, FileScan::Parquet { .. })); - let predicate = to_physical(&predicate, expr_arena, output_schema.as_ref())?; + let predicate = to_physical(&predicate, expr_arena, schema)?; let op = operators::FilterOperator { predicate }; let op = Box::new(op) as Box; operator_objects.push(op) @@ -105,7 +109,7 @@ where FileScan::Csv { options, .. } => { let src = sources::CsvSource::new( sources, - file_info.schema, + file_info.reader_schema.clone().unwrap().unwrap_right(), options, file_options, verbose, @@ -121,7 +125,7 @@ where let predicate = predicate .as_ref() .map(|predicate| { - let p = to_physical(predicate, expr_arena, output_schema.as_ref())?; + let p = to_physical(predicate, expr_arena, schema)?; // Arc's all the way down. :( // Temporarily until: https://github.com/rust-lang/rust/issues/65991 // stabilizes @@ -173,7 +177,7 @@ pub fn get_sink( callbacks: &mut CallBacks, ) -> PolarsResult> where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { use IR::*; let out = match lp_arena.get(node) { @@ -272,14 +276,14 @@ where left_on, expr_arena, to_physical, - Some(input_schema_left.as_ref()), + input_schema_left.as_ref(), )?); let input_schema_right = lp_arena.get(*input_right).schema(lp_arena); let join_columns_right = Arc::new(exprs_to_physical( right_on, expr_arena, to_physical, - Some(input_schema_right.as_ref()), + input_schema_right.as_ref(), )?); let swap_eval = || { @@ -448,7 +452,7 @@ where &keys, expr_arena, to_physical, - Some(&input_schema), + &input_schema, )?); let mut aggregation_columns = Vec::with_capacity(aggs.len()); @@ -488,7 +492,7 @@ where keys, expr_arena, to_physical, - Some(&input_schema), + &input_schema, )?); let mut aggregation_columns = Vec::with_capacity(aggs.len()); @@ -568,10 +572,10 @@ fn get_hstack( options: ProjectionOptions, ) -> PolarsResult where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { - Ok(operators::HstackOperator { - exprs: exprs_to_physical(exprs, expr_arena, &to_physical, Some(&input_schema))?, + Ok(HstackOperator { + exprs: exprs_to_physical(exprs, expr_arena, &to_physical, &input_schema)?, input_schema, options, }) @@ -584,7 +588,7 @@ pub fn get_operator( to_physical: &F, ) -> PolarsResult> where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { use IR::*; let op = match lp_arena.get(node) { @@ -602,7 +606,7 @@ where } => { let input_schema = lp_arena.get(*input).schema(lp_arena); let op = operators::ProjectionOperator { - exprs: exprs_to_physical(expr, expr_arena, &to_physical, Some(&input_schema))?, + exprs: exprs_to_physical(expr, expr_arena, &to_physical, input_schema.as_ref())?, options: *options, }; Box::new(op) as Box @@ -627,7 +631,7 @@ where }, Filter { predicate, input } => { let input_schema = lp_arena.get(*input).schema(lp_arena); - let predicate = to_physical(predicate, expr_arena, Some(input_schema.as_ref()))?; + let predicate = to_physical(predicate, expr_arena, input_schema.as_ref())?; let op = operators::FilterOperator { predicate }; Box::new(op) as Box }, @@ -662,7 +666,7 @@ pub fn create_pipeline( callbacks: &mut CallBacks, ) -> PolarsResult where - F: Fn(&ExprIR, &Arena, Option<&SchemaRef>) -> PolarsResult>, + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, { use IR::*; diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 499bb115396a..9c5a3fe32913 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -34,6 +34,7 @@ either = { workspace = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } memmap = { workspace = true } +num-traits = { workspace = true } once_cell = { workspace = true } percent-encoding = { workspace = true } pyo3 = { workspace = true, optional = true } @@ -50,7 +51,7 @@ version_check = { workspace = true } [features] # debugging utility debugging = [] -python = ["dep:pyo3", "ciborium"] +python = ["dep:pyo3", "ciborium", "polars-utils/python"] serde = [ "dep:serde", "polars-core/serde-lazy", @@ -185,7 +186,7 @@ month_start = ["polars-time/month_start"] month_end = ["polars-time/month_end"] offset_by = ["polars-time/offset_by"] -bigidx = ["polars-core/bigidx"] +bigidx = ["polars-core/bigidx", "polars-utils/bigidx"] polars_cloud = ["serde", "ciborium"] ir_serde = ["serde", "polars-utils/ir_serde"] diff --git a/crates/polars-plan/src/client/check.rs b/crates/polars-plan/src/client/check.rs index f76f508643f0..97b1ed23da2e 100644 --- a/crates/polars-plan/src/client/check.rs +++ b/crates/polars-plan/src/client/check.rs @@ -6,6 +6,9 @@ use crate::plans::{DslPlan, FileScan, ScanSources}; /// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud. pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { + if std::env::var("POLARS_SKIP_CLIENT_CHECK").as_deref() == Ok("1") { + return Ok(()); + } for plan_node in dsl.into_iter() { match plan_node { #[cfg(feature = "python")] diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 33eb20e86da6..32fa45528d3a 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -33,7 +33,7 @@ pub enum AggExpr { Quantile { expr: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Arc), AggGroups(Arc), diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 2c0acfeeb13b..483dafcc83f1 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq> { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let udf = python_udf::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(udf)) @@ -399,7 +399,7 @@ impl<'a> Deserialize<'a> for GetOutput { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let get_output = python_udf::PythonGetOutput::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(get_output)) diff --git a/crates/polars-plan/src/dsl/function_expr/bitwise.rs b/crates/polars-plan/src/dsl/function_expr/bitwise.rs index 2d4dd779cff0..1f0be9247993 100644 --- a/crates/polars-plan/src/dsl/function_expr/bitwise.rs +++ b/crates/polars-plan/src/dsl/function_expr/bitwise.rs @@ -2,6 +2,7 @@ use std::fmt; use std::sync::Arc; use polars_core::prelude::*; +use strum_macros::IntoStaticStr; use super::{ColumnsUdf, SpecialEq}; use crate::dsl::FieldsMapper; @@ -21,7 +22,8 @@ pub enum BitwiseFunction { } #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] pub enum BitwiseAggFunction { And, Or, diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs index bd03ede32c84..2ade8737d077 100644 --- a/crates/polars-plan/src/dsl/function_expr/coerce.rs +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -1,5 +1,22 @@ use polars_core::prelude::*; -pub fn as_struct(s: &[Column]) -> PolarsResult { - Ok(StructChunked::from_columns(s[0].name().clone(), s)?.into_column()) +pub fn as_struct(cols: &[Column]) -> PolarsResult { + let Some(fst) = cols.first() else { + polars_bail!(nyi = "turning no columns as_struct"); + }; + + let mut min_length = usize::MAX; + let mut max_length = usize::MIN; + + for col in cols { + let len = col.len(); + + min_length = min_length.min(len); + max_length = max_length.max(len); + } + + // @NOTE: Any additional errors should be handled by the StructChunked::from_columns + let length = if min_length == 0 { 0 } else { max_length }; + + Ok(StructChunked::from_columns(fst.name().clone(), length, cols)?.into_column()) } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 9225dfa4cae0..ddf8fb1fff20 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -4,7 +4,7 @@ use polars_ops::chunked_array::list::*; use super::*; use crate::{map, map_as_slice, wrap}; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ListFunction { Concat, @@ -56,6 +56,8 @@ pub enum ListFunction { Join(bool), #[cfg(feature = "dtype-array")] ToArray(usize), + #[cfg(feature = "list_to_struct")] + ToStruct(ListToStructArgs), } impl ListFunction { @@ -103,6 +105,8 @@ impl ListFunction { #[cfg(feature = "dtype-array")] ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), NUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), } } } @@ -174,6 +178,8 @@ impl Display for ListFunction { Join(_) => "join", #[cfg(feature = "dtype-array")] ToArray(_) => "to_array", + #[cfg(feature = "list_to_struct")] + ToStruct(_) => "to_struct", }; write!(f, "list.{name}") } @@ -235,6 +241,8 @@ impl From for SpecialEq> { #[cfg(feature = "dtype-array")] ToArray(width) => map!(to_array, width), NUnique => map!(n_unique), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => map!(to_struct, &args), } } } @@ -503,7 +511,7 @@ pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult let idx = &args[1]; let ca = ca.list()?; - if idx.len() == 1 && null_on_oob { + if idx.len() == 1 && idx.dtype().is_numeric() && null_on_oob { // fast path let idx = idx.get(0)?.try_extract::()?; let out = ca.lst_get(idx, null_on_oob).map(Column::from)?; @@ -650,6 +658,11 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult { s.cast(&array_dtype) } +#[cfg(feature = "list_to_struct")] +pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult { + Ok(s.list()?.to_struct(args)?.into_series().into()) +} + pub(super) fn n_unique(s: &Column) -> PolarsResult { Ok(s.list()?.lst_n_unique()?.into_column()) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0458b2b4a1d0..17dc82e23e8c 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -47,6 +47,7 @@ pub mod pow; mod random; #[cfg(feature = "range")] mod range; +mod repeat; #[cfg(feature = "rolling_window")] pub mod rolling; #[cfg(feature = "rolling_window_by")] @@ -189,6 +190,7 @@ pub enum FunctionExpr { options: RankOptions, seed: Option, }, + Repeat, #[cfg(feature = "round_series")] Clip { has_min: bool, @@ -452,6 +454,7 @@ impl Hash for FunctionExpr { a.hash(state); b.hash(state); }, + Repeat => {}, #[cfg(feature = "rank")] Rank { options, seed } => { options.hash(state); @@ -651,6 +654,7 @@ impl Display for FunctionExpr { #[cfg(feature = "moment")] Kurtosis(..) => "kurtosis", ArgUnique => "arg_unique", + Repeat => "repeat", #[cfg(feature = "rank")] Rank { .. } => "rank", #[cfg(feature = "round_series")] @@ -950,6 +954,19 @@ impl From for SpecialEq> { Std(options) => map!(rolling::rolling_std, options.clone()), #[cfg(feature = "moment")] Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias), + #[cfg(feature = "cov")] + CorrCov { + rolling_options, + corr_cov_options, + is_corr, + } => { + map_as_slice!( + rolling::rolling_corr_cov, + rolling_options.clone(), + corr_cov_options, + is_corr + ) + }, } }, #[cfg(feature = "rolling_window_by")] @@ -996,6 +1013,7 @@ impl From for SpecialEq> { #[cfg(feature = "moment")] Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias), ArgUnique => map!(dispatch::arg_unique), + Repeat => map_as_slice!(repeat::repeat), #[cfg(feature = "rank")] Rank { options, seed } => map!(dispatch::rank, options, seed), #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-plan/src/dsl/function_expr/pow.rs b/crates/polars-plan/src/dsl/function_expr/pow.rs index 44394e9ae10a..912aafc762d9 100644 --- a/crates/polars-plan/src/dsl/function_expr/pow.rs +++ b/crates/polars-plan/src/dsl/function_expr/pow.rs @@ -1,8 +1,8 @@ -use arrow::legacy::kernels::pow::pow as pow_kernel; use num::pow::Pow; +use num_traits::{One, Zero}; use polars_core::export::num; use polars_core::export::num::{Float, ToPrimitive}; -use polars_core::prelude::arity::unary_elementwise_values; +use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values}; use polars_core::with_match_physical_integer_type; use super::*; @@ -29,30 +29,27 @@ impl Display for PowFunction { fn pow_on_chunked_arrays( base: &ChunkedArray, exponent: &ChunkedArray, -) -> PolarsResult> +) -> ChunkedArray where T: PolarsNumericType, F: PolarsNumericType, T::Native: num::pow::Pow + ToPrimitive, - ChunkedArray: IntoColumn, { - if (base.len() == 1) && (exponent.len() != 1) { - let name = base.name(); - let base = base - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "base is null"))?; - - Ok(Some( - unary_elementwise_values(exponent, |exp| Pow::pow(base, exp)) - .into_column() - .with_name(name.clone()), - )) - } else { - Ok(Some( - polars_core::chunked_array::ops::arity::binary(base, exponent, pow_kernel) - .into_column(), - )) + if exponent.len() == 1 { + if let Some(e) = exponent.get(0) { + if e == F::Native::zero() { + return unary_elementwise_values(base, |_| T::Native::one()); + } + if e == F::Native::one() { + return base.clone(); + } + if e == F::Native::one() + F::Native::one() { + return base * base; + } + } } + + broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?))) } fn pow_on_floats( @@ -93,7 +90,7 @@ where }; Ok(Some(s)) } else { - pow_on_chunked_arrays(base, exponent) + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) } } @@ -133,7 +130,7 @@ where }; Ok(Some(s)) } else { - pow_on_chunked_arrays(base, exponent) + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) } } diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index a61264ce7aca..e220a7107435 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -223,7 +223,7 @@ pub(super) fn datetime_ranges( out.cast(&to_type).map(Column::from) } -impl<'a> FieldsMapper<'a> { +impl FieldsMapper<'_> { pub(super) fn map_to_datetime_range_dtype( &self, time_unit: Option<&TimeUnit>, diff --git a/crates/polars-plan/src/dsl/function_expr/repeat.rs b/crates/polars-plan/src/dsl/function_expr/repeat.rs new file mode 100644 index 000000000000..cebc2ce792c4 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/repeat.rs @@ -0,0 +1,18 @@ +use polars_core::prelude::{polars_ensure, polars_err, Column, PolarsResult}; + +pub fn repeat(args: &[Column]) -> PolarsResult { + let c = &args[0]; + let n = &args[1]; + + polars_ensure!( + n.dtype().is_integer(), + SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() + ); + + let first_value = n.get(0)?; + let n = first_value.extract::().ok_or_else( + || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), + )?; + + Ok(c.new_from_index(0, n)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index c108c92b571a..44af4f1b1510 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -1,6 +1,12 @@ +#[cfg(feature = "cov")] +use std::ops::BitAnd; + +use polars_core::utils::Container; use polars_time::chunkedarray::*; use super::*; +#[cfg(feature = "cov")] +use crate::dsl::pow::pow; #[derive(Clone, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -14,6 +20,13 @@ pub enum RollingFunction { Std(RollingOptionsFixedWindow), #[cfg(feature = "moment")] Skew(usize, bool), + #[cfg(feature = "cov")] + CorrCov { + rolling_options: RollingOptionsFixedWindow, + corr_cov_options: RollingCovOptions, + // Whether is Corr or Cov + is_corr: bool, + }, } impl Display for RollingFunction { @@ -30,6 +43,14 @@ impl Display for RollingFunction { Std(_) => "rolling_std", #[cfg(feature = "moment")] Skew(..) => "rolling_skew", + #[cfg(feature = "cov")] + CorrCov { is_corr, .. } => { + if *is_corr { + "rolling_corr" + } else { + "rolling_cov" + } + }, }; write!(f, "{name}") @@ -47,6 +68,10 @@ impl Hash for RollingFunction { window_size.hash(state); bias.hash(state) }, + #[cfg(feature = "cov")] + CorrCov { is_corr, .. } => { + is_corr.hash(state); + }, _ => {}, } } @@ -111,3 +136,100 @@ pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> Polars .rolling_skew(window_size, bias) .map(Column::from) } + +#[cfg(feature = "cov")] +fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series { + match dtype { + DataType::Float64 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f64) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + DataType::Float32 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f32) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "cov")] +pub(super) fn rolling_corr_cov( + s: &[Column], + rolling_options: RollingOptionsFixedWindow, + cov_options: RollingCovOptions, + is_corr: bool, +) -> PolarsResult { + let mut x = s[0].as_materialized_series().rechunk(); + let mut y = s[1].as_materialized_series().rechunk(); + + if !x.dtype().is_float() { + x = x.cast(&DataType::Float64)?; + } + if !y.dtype().is_float() { + y = y.cast(&DataType::Float64)?; + } + let dtype = x.dtype().clone(); + + let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?; + let rolling_options_count = RollingOptionsFixedWindow { + window_size: rolling_options.window_size, + min_periods: 0, + ..Default::default() + }; + + let count_x_y = if (x.null_count() + y.null_count()) > 0 { + // mask out nulls on both sides before compute mean/var + let valids = x.is_not_null().bitand(y.is_not_null()); + let valids_arr = valids.clone().downcast_into_array(); + let valids_bitmap = valids_arr.values(); + + unsafe { + let xarr = &mut x.chunks_mut()[0]; + *xarr = xarr.with_validity(Some(valids_bitmap.clone())); + let yarr = &mut y.chunks_mut()[0]; + *yarr = yarr.with_validity(Some(valids_bitmap.clone())); + x.compute_len(); + y.compute_len(); + } + valids + .cast(&dtype) + .unwrap() + .rolling_sum(rolling_options_count)? + } else { + det_count_x_y(rolling_options.window_size, x.len(), &dtype) + }; + + let mean_x = x.rolling_mean(rolling_options.clone())?; + let mean_y = y.rolling_mean(rolling_options.clone())?; + let ddof = Series::new( + PlSmallStr::EMPTY, + &[AnyValue::from(cov_options.ddof).cast(&dtype)], + ); + + let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap() + * (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap()) + .unwrap(); + + if is_corr { + let var_x = x.rolling_var(rolling_options.clone())?; + let var_y = y.rolling_var(rolling_options.clone())?; + + let base = (var_x * var_y).unwrap(); + let sc = Scalar::new( + base.dtype().clone(), + AnyValue::Float64(0.5).cast(&dtype).into_static(), + ); + let denominator = pow(&mut [base.into_column(), sc.into_column("".into())]) + .unwrap() + .unwrap() + .take_materialized_series(); + + Ok((numerator / denominator)?.into_column()) + } else { + Ok(numerator.into_column()) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 7cc5b8c5c7ad..606ab81207c4 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -66,6 +66,8 @@ impl FunctionExpr { match rolling_func { Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(), Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), + #[cfg(feature = "cov")] + CorrCov {..} => mapper.map_to_float_dtype(), #[cfg(feature = "moment")] Skew(..) => mapper.map_to_float_dtype(), } @@ -90,6 +92,7 @@ impl FunctionExpr { #[cfg(feature = "moment")] Kurtosis(..) => mapper.with_dtype(DataType::Float64), ArgUnique => mapper.with_dtype(IDX_DTYPE), + Repeat => mapper.with_same_dtype(), #[cfg(feature = "rank")] Rank { options, .. } => mapper.with_dtype(match options.method { RankMethod::Average => DataType::Float64, diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index ba06dc00e67c..039c995557be 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -9,7 +9,7 @@ use polars_core::utils::handle_casting_failures; #[cfg(feature = "dtype-struct")] use polars_utils::format_pl_smallstr; #[cfg(feature = "regex")] -use regex::{escape, Regex}; +use regex::{escape, NoExpand, Regex}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -130,6 +130,8 @@ pub enum StringFunction { ascii_case_insensitive: bool, overlapping: bool, }, + #[cfg(feature = "regex")] + EscapeRegex, } impl StringFunction { @@ -197,6 +199,8 @@ impl StringFunction { ReplaceMany { .. } => mapper.with_same_dtype(), #[cfg(feature = "find_many")] ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "regex")] + EscapeRegex => mapper.with_same_dtype(), } } } @@ -285,6 +289,8 @@ impl Display for StringFunction { ReplaceMany { .. } => "replace_many", #[cfg(feature = "find_many")] ExtractMany { .. } => "extract_many", + #[cfg(feature = "regex")] + EscapeRegex => "escape_regex", }; write!(f, "str.{s}") } @@ -400,6 +406,8 @@ impl From for SpecialEq> { } => { map_as_slice!(extract_many, ascii_case_insensitive, overlapping) }, + #[cfg(feature = "regex")] + EscapeRegex => map!(escape_regex), } } } @@ -836,20 +844,26 @@ fn replace_n<'a>( "replacement value length ({}) does not match string column length ({})", len_val, ca.len(), ); - let literal = literal || is_literal_pat(&pat); + let lit = is_literal_pat(&pat); + let literal_pat = literal || lit; - if literal { + if literal_pat { pat = escape(&pat) } let reg = Regex::new(&pat)?; - let lit = pat.chars().all(|c| !c.is_ascii_punctuation()); let f = |s: &'a str, val: &'a str| { if lit && (s.len() <= 32) { Cow::Owned(s.replacen(&pat, val, 1)) } else { - reg.replace(s, val) + // According to the docs for replace + // when literal = True then capture groups are ignored. + if literal { + reg.replace(s, NoExpand(val)) + } else { + reg.replace(s, val) + } } }; Ok(iter_and_replace(ca, val, f)) @@ -888,15 +902,25 @@ fn replace_all<'a>( "replacement value length ({}) does not match string column length ({})", len_val, ca.len(), ); - let literal = literal || is_literal_pat(&pat); - if literal { + let literal_pat = literal || is_literal_pat(&pat); + + if literal_pat { pat = escape(&pat) } let reg = Regex::new(&pat)?; - let f = |s: &'a str, val: &'a str| reg.replace_all(s, val); + let f = |s: &'a str, val: &'a str| { + // According to the docs for replace_all + // when literal = True then capture groups are ignored. + if literal { + reg.replace_all(s, NoExpand(val)) + } else { + reg.replace_all(s, val) + } + }; + Ok(iter_and_replace(ca, val, f)) }, _ => polars_bail!( @@ -1023,3 +1047,9 @@ pub(super) fn json_path_match(s: &[Column]) -> PolarsResult { let pat = s[1].str()?; Ok(ca.json_path_match(pat)?.into_column()) } + +#[cfg(feature = "regex")] +pub(super) fn escape_regex(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_escape_regex().into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs index acc8020b8e7e..23c8c961f9b2 100644 --- a/crates/polars-plan/src/dsl/function_expr/struct_.rs +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -176,7 +176,7 @@ pub(super) fn rename_fields(s: &Column, names: Arc<[PlSmallStr]>) -> PolarsResul s }) .collect::>(); - let mut out = StructChunked::from_series(ca.name().clone(), fields.iter())?; + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; out.zip_outer_validity(ca); Ok(out.into_column()) } @@ -193,7 +193,7 @@ pub(super) fn prefix_fields(s: &Column, prefix: &str) -> PolarsResult { s }) .collect::>(); - let mut out = StructChunked::from_series(ca.name().clone(), fields.iter())?; + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; out.zip_outer_validity(ca); Ok(out.into_column()) } @@ -210,7 +210,7 @@ pub(super) fn suffix_fields(s: &Column, suffix: &str) -> PolarsResult { s }) .collect::>(); - let mut out = StructChunked::from_series(ca.name().clone(), fields.iter())?; + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; out.zip_outer_validity(ca); Ok(out.into_column()) } @@ -245,7 +245,8 @@ pub(super) fn with_fields(args: &[Column]) -> PolarsResult { } let new_fields = fields.into_values().cloned().collect::>(); - let mut out = StructChunked::from_series(ca.name().clone(), new_fields.iter())?; + let mut out = + StructChunked::from_series(ca.name().clone(), new_fields[0].len(), new_fields.iter())?; out.zip_outer_validity(ca); Ok(out.into_column()) } diff --git a/crates/polars-plan/src/dsl/function_expr/trigonometry.rs b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs index c0d83822aef9..5398f43f4323 100644 --- a/crates/polars-plan/src/dsl/function_expr/trigonometry.rs +++ b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs @@ -1,5 +1,5 @@ -use arrow::legacy::kernels::atan2::atan2 as atan2_kernel; use num::Float; +use polars_core::chunked_array::ops::arity::broadcast_binary_elementwise; use polars_core::export::num; use super::*; @@ -117,23 +117,9 @@ where .unpack_series_matching_type(x.as_materialized_series()) .unwrap(); - if x.len() == 1 { - let x_value = x - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "arctan2 x value is null"))?; - - Ok(Some(y.apply_values(|v| v.atan2(x_value)).into_column())) - } else if y.len() == 1 { - let y_value = y - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "arctan2 y value is null"))?; - - Ok(Some(x.apply_values(|v| y_value.atan2(v)).into_column())) - } else { - Ok(Some( - polars_core::prelude::arity::binary(y, x, atan2_kernel).into_column(), - )) - } + Ok(Some( + broadcast_binary_elementwise(y, x, |yv, xv| Some(yv?.atan2(xv?))).into_column(), + )) } fn apply_trigonometric_function_to_float( diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index dd7521ad20a9..97e14f5df2f8 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -70,8 +70,8 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E } } -#[cfg(feature = "rolling_window")] -pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { +#[cfg(all(feature = "rolling_window", feature = "cov"))] +fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 let rolling_options = RollingOptionsFixedWindow { window_size: options.window_size as usize, @@ -79,59 +79,23 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone()); - let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options); - - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let ddof = options.ddof as f64; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - let numerator = (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))); - let denominator = (var_x * var_y).pow(lit(0.5)); + Expr::Function { + input: vec![x, y], + function: FunctionExpr::RollingExpr(RollingFunction::CorrCov { + rolling_options, + corr_cov_options: options, + is_corr, + }), + options: Default::default(), + } +} - numerator / denominator +#[cfg(all(feature = "rolling_window", feature = "cov"))] +pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { + dispatch_corr_cov(x, y, options, true) } -#[cfg(feature = "rolling_window")] +#[cfg(all(feature = "rolling_window", feature = "cov"))] pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { - // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700 - let rolling_options = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: options.min_periods as usize, - ..Default::default() - }; - - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options); - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - - let ddof = options.ddof as f64; - - (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))) + dispatch_corr_cov(x, y, options, false) } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 542212f8de82..26b6209a720e 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -22,7 +22,7 @@ fn cum_fold_dtype() -> GetOutput { /// Accumulate over multiple columns horizontally / row wise. pub fn fold_exprs(acc: Expr, f: F, exprs: E) -> Expr where - F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync + Clone, + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, E: AsRef<[Expr]>, { let mut exprs = exprs.as_ref().to_vec(); @@ -62,7 +62,7 @@ where /// `collect` is called. pub fn reduce_exprs(f: F, exprs: E) -> Expr where - F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync + Clone, + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, E: AsRef<[Expr]>, { let exprs = exprs.as_ref().to_vec(); @@ -104,7 +104,7 @@ where #[cfg(feature = "dtype-struct")] pub fn cum_reduce_exprs(f: F, exprs: E) -> Expr where - F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync + Clone, + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, E: AsRef<[Expr]>, { let exprs = exprs.as_ref().to_vec(); @@ -126,7 +126,7 @@ where result.push(acc.clone()); } - StructChunked::from_columns(acc.name().clone(), &result) + StructChunked::from_columns(acc.name().clone(), result[0].len(), &result) .map(|ca| Some(ca.into_column())) }, None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), @@ -152,7 +152,7 @@ where #[cfg(feature = "dtype-struct")] pub fn cum_fold_exprs(acc: Expr, f: F, exprs: E, include_init: bool) -> Expr where - F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync + Clone, + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, E: AsRef<[Expr]>, { let mut exprs = exprs.as_ref().to_vec(); @@ -176,7 +176,8 @@ where } } - StructChunked::from_columns(acc.name().clone(), &result).map(|ca| Some(ca.into_column())) + StructChunked::from_columns(acc.name().clone(), result[0].len(), &result) + .map(|ca| Some(ca.into_column())) }); Expr::AnonymousFunction { diff --git a/crates/polars-plan/src/dsl/functions/repeat.rs b/crates/polars-plan/src/dsl/functions/repeat.rs index 21d27a542e99..ea80e2598186 100644 --- a/crates/polars-plan/src/dsl/functions/repeat.rs +++ b/crates/polars-plan/src/dsl/functions/repeat.rs @@ -5,17 +5,20 @@ use super::*; /// Generally you won't need this function, as `lit(value)` already represents a column containing /// only `value` whose length is automatically set to the correct number of rows. pub fn repeat>(value: E, n: Expr) -> Expr { - let function = |s: Column, n: Column| { - polars_ensure!( - n.dtype().is_integer(), - SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() - ); - let first_value = n.get(0)?; - let n = first_value.extract::().ok_or_else( - || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), - )?; - Ok(Some(s.new_from_index(0, n))) + let input = vec![value.into(), n]; + + let expr = Expr::Function { + input, + function: FunctionExpr::Repeat, + options: FunctionOptions { + flags: FunctionFlags::default() + | FunctionFlags::ALLOW_RENAME + | FunctionFlags::CHANGES_LENGTH, + ..Default::default() + }, }; - apply_binary(value.into(), n, function, GetOutput::same_type()) - .alias(PlSmallStr::from_static("repeat")) + + // @NOTE: This alias should probably not be here for consistency, but it is here for backwards + // compatibility until 2.0. + expr.alias(PlSmallStr::from_static("repeat")) } diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index e1ef64ee02ec..4d0e4c105014 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -33,8 +33,8 @@ pub fn median(name: &str) -> Expr { } /// Find a specific quantile of all the values in the column named `name`. -pub fn quantile(name: &str, quantile: Expr, interpol: QuantileInterpolOptions) -> Expr { - col(name).quantile(quantile, interpol) +pub fn quantile(name: &str, quantile: Expr, method: QuantileMethod) -> Expr { + col(name).quantile(quantile, method) } /// Negates a boolean column. @@ -55,7 +55,7 @@ pub fn is_not_null(expr: Expr) -> Expr { /// Casts the column given by `Expr` to a different type. /// /// Follows the rules of Rust casting, with the exception that integers and floats can be cast to `DataType::Date` and -/// `DataType::DateTime(_, _)`. A column consisting entirely of of `Null` can be cast to any type, regardless of the +/// `DataType::DateTime(_, _)`. A column consisting entirely of `Null` can be cast to any type, regardless of the /// nominal type of the column. pub fn cast(expr: Expr, dtype: DataType) -> Expr { Expr::Cast { diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index 48508dba40d8..9a18abda55fd 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -427,7 +427,7 @@ pub fn duration(args: DurationArgs) -> Expr { function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, + flags: FunctionFlags::default(), ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index fb0c7a83b463..3a1a37c9f393 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "list_to_struct")] -use std::sync::RwLock; - use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; @@ -281,50 +278,9 @@ impl ListNameSpace { /// an `upper_bound` of struct fields that will be set. /// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression /// will look in the current schema to determine which columns to select. - pub fn to_struct( - self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - upper_bound: usize, - ) -> Expr { - // heap allocate the output type and fill it later - let out_dtype = Arc::new(RwLock::new(None::)); - + pub fn to_struct(self, args: ListToStructArgs) -> Expr { self.0 - .map( - move |s| { - s.list()? - .to_struct(n_fields, name_generator.clone()) - .map(|s| Some(s.into_column())) - }, - // we don't yet know the fields - GetOutput::map_dtype(move |dt: &DataType| { - polars_ensure!(matches!(dt, DataType::List(_)), SchemaMismatch: "expected 'List' as input to 'list.to_struct' got {}", dt); - let out = out_dtype.read().unwrap(); - match out.as_ref() { - // dtype already set - Some(dt) => Ok(dt.clone()), - // dtype still unknown, set it - None => { - drop(out); - let mut lock = out_dtype.write().unwrap(); - - let inner = dt.inner_dtype().unwrap(); - let fields = (0..upper_bound) - .map(|i| { - let name = _default_struct_name_gen(i); - Field::new(name, inner.clone()) - }) - .collect(); - let dt = DataType::Struct(fields); - - *lock = Some(dt.clone()); - Ok(dt) - }, - } - }), - ) - .with_fmt("list.to_struct") + .map_private(FunctionExpr::ListExpr(ListFunction::ToStruct(args))) } #[cfg(feature = "is_in")] @@ -332,34 +288,24 @@ impl ListNameSpace { pub fn contains>(self, other: E) -> Expr { let other = other.into(); - self.0 - .map_many_private( - FunctionExpr::ListExpr(ListFunction::Contains), - &[other], - false, - None, - ) - .with_function_options(|mut options| { - options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; - options - }) + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Contains), + &[other], + false, + None, + ) } #[cfg(feature = "list_count")] /// Count how often the value produced by ``element`` occurs. pub fn count_matches>(self, element: E) -> Expr { let other = element.into(); - self.0 - .map_many_private( - FunctionExpr::ListExpr(ListFunction::CountMatches), - &[other], - false, - None, - ) - .with_function_options(|mut options| { - options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; - options - }) + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::CountMatches), + &[other], + false, + None, + ) } #[cfg(feature = "list_sets")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 9dd20bc813f5..a88ff858e6ee 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -45,7 +45,7 @@ use std::sync::Arc; pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; pub use expr::*; pub use function_expr::schema::FieldsMapper; pub use function_expr::*; @@ -227,11 +227,11 @@ impl Expr { } /// Compute the quantile per group. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { AggExpr::Quantile { expr: Arc::new(self), quantile: Arc::new(quantile), - interpol, + method, } .into() } @@ -1126,27 +1126,27 @@ impl Expr { }) } - /// "and" operation. + /// Bitwise "and" operation. pub fn and>(self, expr: E) -> Self { binary_expr(self, Operator::And, expr.into()) } - /// "xor" operation. + /// Bitwise "xor" operation. pub fn xor>(self, expr: E) -> Self { binary_expr(self, Operator::Xor, expr.into()) } - /// "or" operation. + /// Bitwise "or" operation. pub fn or>(self, expr: E) -> Self { binary_expr(self, Operator::Or, expr.into()) } - /// "or" operation. + /// Logical "or" operation. pub fn logical_or>(self, expr: E) -> Self { binary_expr(self, Operator::LogicalOr, expr.into()) } - /// "or" operation. + /// Logical "and" operation. pub fn logical_and>(self, expr: E) -> Self { binary_expr(self, Operator::LogicalAnd, expr.into()) } @@ -1358,13 +1358,13 @@ impl Expr { pub fn rolling_quantile_by( self, by: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsDynamicWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) @@ -1385,7 +1385,7 @@ impl Expr { /// Apply a rolling median based on another column. #[cfg(feature = "rolling_window_by")] pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { - self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile_by(by, QuantileMethod::Linear, 0.5, options) } /// Apply a rolling minimum. @@ -1425,7 +1425,7 @@ impl Expr { /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { - self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile(QuantileMethod::Linear, 0.5, options) } /// Apply a rolling quantile. @@ -1434,13 +1434,13 @@ impl Expr { #[cfg(feature = "rolling_window")] pub fn rolling_quantile( self, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling(options, RollingFunction::Quantile) diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs index 1261b4430bec..5eeb67a77ba1 100644 --- a/crates/polars-plan/src/dsl/name.rs +++ b/crates/polars-plan/src/dsl/name.rs @@ -76,7 +76,7 @@ impl ExprNameNameSpace { fd }) .collect::>(); - let mut out = StructChunked::from_series(s.name().clone(), fields.iter())?; + let mut out = StructChunked::from_series(s.name().clone(), s.len(), fields.iter())?; out.zip_outer_validity(s); Ok(Some(out.into_column())) }, diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index 73481796d3e0..259d66af95ae 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -5,6 +5,7 @@ use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::dsl::Selector; @@ -87,8 +88,9 @@ impl Default for WindowType { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum WindowMapping { /// Map the group values to the position #[default] diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index a813dbf64e87..cd133ceb646e 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -9,10 +9,6 @@ use polars_core::schema::Schema; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; -#[cfg(feature = "serde")] -use serde::ser::Error; -#[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::expr_dyn_fn::*; use crate::constants::MAP_LIST_NAME; @@ -25,81 +21,10 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option< pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; -pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); -#[derive(Clone, Debug)] -pub struct PythonFunction(pub PyObject); - -impl From for PythonFunction { - fn from(value: PyObject) -> Self { - Self(value) - } -} - -impl Eq for PythonFunction {} - -impl PartialEq for PythonFunction { - fn eq(&self, other: &Self) -> bool { - Python::with_gil(|py| { - let eq = self.0.getattr(py, "__eq__").unwrap(); - eq.call1(py, (other.0.clone(),)) - .unwrap() - .extract::(py) - // equality can be not implemented, so default to false - .unwrap_or(false) - }) - } -} - -#[cfg(feature = "serde")] -impl Serialize for PythonFunction { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'cloudpickle' or 'pickle'") - .getattr("dumps") - .unwrap(); - - let python_function = self.0.clone(); - - let dumped = pickle - .call1((python_function,)) - .map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?; - let dumped = dumped.extract::().unwrap(); - - serializer.serialize_bytes(&dumped) - }) - } -} - -#[cfg(feature = "serde")] -impl<'a> Deserialize<'a> for PythonFunction { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'a>, - { - use serde::de::Error; - let bytes = Vec::::deserialize(deserializer)?; - - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") - .getattr("loads") - .unwrap(); - let arg = (PyBytes::new_bound(py, &bytes),); - let python_function = pickle - .call1(arg) - .map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?; - - Ok(Self(python_function.into())) - }) - } -} +pub use polars_utils::python_function::{ + PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK, +}; pub struct PythonUdfExpression { python_function: PyObject, @@ -125,19 +50,36 @@ impl PythonUdfExpression { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { - debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); - // skip header - let buf = &buf[MAGIC_BYTE_MARK.len()..]; + // Handle byte mark + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; + + // Handle pickle metadata + let use_cloudpickle = buf[0]; + if use_cloudpickle != 0 { + let ser_py_version = &buf[1..3]; + let cur_py_version = *PYTHON3_VERSION; + polars_ensure!( + ser_py_version == cur_py_version, + InvalidOperation: + "current Python version {:?} does not match the Python version used to serialize the UDF {:?}", + (3, cur_py_version[0], cur_py_version[1]), + (3, ser_py_version[0], ser_py_version[1] ) + ); + } + let buf = &buf[3..]; + + // Load UDF metadata let mut reader = Cursor::new(buf); let (output_type, is_elementwise, returns_scalar): (Option, bool, bool) = ciborium::de::from_reader(&mut reader).map_err(map_err)?; let remainder = &buf[reader.position() as usize..]; + // Load UDF Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") .getattr("loads") .unwrap(); let arg = (PyBytes::new_bound(py, remainder),); @@ -156,7 +98,7 @@ fn from_pyerr(e: PyErr) -> PolarsError { PolarsError::ComputeError(format!("error raised in python: {e}").into()) } -impl DataFrameUdf for PythonFunction { +impl DataFrameUdf for polars_utils::python_function::PythonFunction { fn call_udf(&self, df: DataFrame) -> PolarsResult { let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() }; func(df, &self.0) @@ -189,26 +131,46 @@ impl ColumnsUdf for PythonUdfExpression { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { - buf.extend_from_slice(MAGIC_BYTE_MARK); - ciborium::ser::into_writer( - &( - self.output_type.clone(), - self.is_elementwise, - self.returns_scalar, - ), - &mut *buf, - ) - .unwrap(); + // Write byte marks + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") + // Try pickle to serialize the UDF, otherwise fall back to cloudpickle. + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") .getattr("dumps") .unwrap(); - let dumped = pickle - .call1((self.python_function.clone(),)) - .map_err(from_pyerr)?; + let pickle_result = pickle.call1((self.python_function.clone_ref(py),)); + let (dumped, use_cloudpickle) = match pickle_result { + Ok(dumped) => (dumped, false), + Err(_) => { + let cloudpickle = PyModule::import_bound(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle + .call1((self.python_function.clone_ref(py),)) + .map_err(from_pyerr)?; + (dumped, true) + }, + }; + + // Write pickle metadata + buf.push(use_cloudpickle as u8); + buf.extend_from_slice(&*PYTHON3_VERSION); + + // Write UDF metadata + ciborium::ser::into_writer( + &( + self.output_type.clone(), + self.is_elementwise, + self.returns_scalar, + ), + &mut *buf, + ) + .unwrap(); + + // Write UDF let dumped = dumped.extract::().unwrap(); buf.extend_from_slice(&dumped); Ok(()) @@ -229,8 +191,8 @@ impl PythonGetOutput { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { // Skip header. - debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); - let buf = &buf[MAGIC_BYTE_MARK.len()..]; + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; let mut reader = Cursor::new(buf); let return_dtype: Option = @@ -258,7 +220,7 @@ impl FunctionOutputField for PythonGetOutput { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { - buf.extend_from_slice(MAGIC_BYTE_MARK); + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap(); Ok(()) } diff --git a/crates/polars-plan/src/dsl/selector.rs b/crates/polars-plan/src/dsl/selector.rs index 16e7d7b374e0..7877edb152df 100644 --- a/crates/polars-plan/src/dsl/selector.rs +++ b/crates/polars-plan/src/dsl/selector.rs @@ -11,7 +11,7 @@ pub enum Selector { Add(Box, Box), Sub(Box, Box), ExclusiveOr(Box, Box), - InterSect(Box, Box), + Intersect(Box, Box), Root(Box), } @@ -34,7 +34,7 @@ impl BitAnd for Selector { #[allow(clippy::suspicious_arithmetic_impl)] fn bitand(self, rhs: Self) -> Self::Output { - Selector::InterSect(Box::new(self), Box::new(rhs)) + Selector::Intersect(Box::new(self), Box::new(rhs)) } } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index d392d403d1b6..2514d1a5f6a4 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -592,4 +592,14 @@ impl StringNameSpace { None, ) } + + #[cfg(feature = "regex")] + pub fn escape_regex(self) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::EscapeRegex), + &[], + false, + None, + ) + } } diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 565710c0dbaf..286ea86ac968 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -44,7 +44,7 @@ pub enum IRAggExpr { Quantile { expr: Node, quantile: Node, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Node), Count(Node, bool), @@ -62,7 +62,9 @@ impl Hash for IRAggExpr { Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => { propagate_nans.hash(state) }, - Self::Quantile { interpol, .. } => interpol.hash(state), + Self::Quantile { + method: interpol, .. + } => interpol.hash(state), Self::Std(_, v) | Self::Var(_, v) => v.hash(state), #[cfg(feature = "bitwise")] Self::Bitwise(_, f) => f.hash(state), @@ -92,7 +94,7 @@ impl IRAggExpr { propagate_nans: r, .. }, ) => l == r, - (Quantile { interpol: l, .. }, Quantile { interpol: r, .. }) => l == r, + (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r, (Std(_, l), Std(_, r)) => l == r, (Var(_, l), Var(_, r)) => l == r, #[cfg(feature = "bitwise")] diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 88c44233175a..8cb4b8cc2387 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -19,17 +19,17 @@ impl AExpr { pub fn to_dtype( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { - self.to_field(schema, ctxt, arena).map(|f| f.dtype) + self.to_field(schema, ctx, arena).map(|f| f.dtype) } /// Get Field result of the expression. The schema is the input data. pub fn to_field( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { // During aggregation a column that isn't aggregated gets an extra nesting level @@ -37,7 +37,7 @@ impl AExpr { // But not if we do an aggregation: // col(foo: i64).sum() -> i64 // The `nested` keeps track of the nesting we need to add. - let mut nested = matches!(ctxt, Context::Aggregation) as u8; + let mut nested = matches!(ctx, Context::Aggregation) as u8; let mut field = self.to_field_impl(schema, arena, &mut nested)?; if nested >= 1 { @@ -72,6 +72,7 @@ impl AExpr { }, Explode(expr) => { let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + *nested = nested.saturating_sub(1); if let List(inner) = field.dtype() { Ok(Field::new(field.name().clone(), *inner.clone())) @@ -369,6 +370,28 @@ fn get_arithmetic_field( (_, Time) | (Time, _) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "sub", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + // FIXME: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block. + // Otherwise we will silently permit addition operations between logical types (see above). + // This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors + // if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema + // may be incorrect. + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -394,6 +417,23 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Boolean, Boolean) => IDX_DTYPE, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "add", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -426,6 +466,27 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, }, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op, l, r, + ) + }, + // List<->primitive operations can be done directly after casting the to the primitive + // supertype for the primitive values on both sides. + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + let dtype = list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?); + left_field.coerce(dtype); + return Ok(left_field); + }, _ => { // Avoid needlessly type casting numeric columns during arithmetic // with literals. @@ -465,32 +526,51 @@ fn get_truediv_field( nested: &mut u8, ) -> PolarsResult { let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?; + let right_field = arena.get(right).to_field_impl(schema, arena, nested)?; use DataType::*; - let out_type = match left_field.dtype() { - Float32 => Float32, - dt if dt.is_numeric() => Float64, - #[cfg(feature = "dtype-duration")] - Duration(_) => match arena - .get(right) - .to_field_impl(schema, arena, nested)? - .dtype() + + // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code + // originally (mostly) only looked at the LHS dtype. + let out_type = match (left_field.dtype(), right_field.dtype()) { + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) => { - Duration(_) => Float64, - dt if dt.is_numeric() => return Ok(left_field), - dt => { - polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt) - }, + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "div", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + list_dtype.cast_leaf(match (list_dtype.leaf_dtype(), other_dtype.leaf_dtype()) { + (Float32, Float32) => Float32, + (Float32, Float64) | (Float64, Float32) => Float64, + // FIXME: We should properly recurse on the enclosing match block here. + (dt, _) => dt.clone(), + }) + }, + (Float32, _) => Float32, + (dt, _) if dt.is_numeric() => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), Duration(_)) => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) if dt.is_numeric() => return Ok(left_field), + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) => { + polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt) }, #[cfg(feature = "dtype-datetime")] - Datetime(_, _) => { + (Datetime(_, _), _) => { polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed") }, #[cfg(feature = "dtype-time")] - Time => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), + (Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), #[cfg(feature = "dtype-date")] - Date => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), + (Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), // we don't know what to do here, best return the dtype - dt => dt.clone(), + (dt, _) => dt.clone(), }; left_field.coerce(out_type); diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 7163e18de165..1697a5571d4e 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -85,7 +85,7 @@ impl AExpr { } } - pub(crate) fn replace_inputs(mut self, inputs: &[Node]) -> Self { + pub fn replace_inputs(mut self, inputs: &[Node]) -> Self { use AExpr::*; let input = match &mut self { Column(_) | Literal(_) | Len => return self, diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index aef7cd157334..6520cc476178 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -1,3 +1,5 @@ +use bitflags::bitflags; + use super::*; fn has_series_or_range(ae: &AExpr) -> bool { @@ -7,7 +9,46 @@ fn has_series_or_range(ae: &AExpr) -> bool { ) } -pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> bool { +bitflags! { + #[derive(Default, Copy, Clone)] + struct StreamableFlags: u8 { + const ALLOW_CAST_CATEGORICAL = 1; + } +} + +#[derive(Copy, Clone)] +pub struct IsStreamableContext { + flags: StreamableFlags, + context: Context, +} + +impl Default for IsStreamableContext { + fn default() -> Self { + Self { + flags: StreamableFlags::all(), + context: Default::default(), + } + } +} + +impl IsStreamableContext { + pub fn new(ctx: Context) -> Self { + Self { + flags: StreamableFlags::all(), + context: ctx, + } + } + + pub fn with_allow_cast_categorical(mut self, allow_cast_categorical: bool) -> Self { + self.flags.set( + StreamableFlags::ALLOW_CAST_CATEGORICAL, + allow_cast_categorical, + ); + self + } +} + +pub fn is_streamable(node: Node, expr_arena: &Arena, ctx: IsStreamableContext) -> bool { // check whether leaf column is Col or Lit let mut seen_column = false; let mut seen_lit_range = false; @@ -16,13 +57,14 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> function: FunctionExpr::SetSortedFlag(_), .. } => true, - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => match context - { - Context::Default => matches!( - options.collect_groups, - ApplyOptions::ElementWise | ApplyOptions::ApplyList - ), - Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { + match ctx.context { + Context::Default => matches!( + options.collect_groups, + ApplyOptions::ElementWise | ApplyOptions::ApplyList + ), + Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + } }, AExpr::Column(_) => { seen_column = true; @@ -41,6 +83,10 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> && !has_aexpr(*falsy, expr_arena, has_series_or_range) && !has_aexpr(*predicate, expr_arena, has_series_or_range) }, + #[cfg(feature = "dtype-categorical")] + AExpr::Cast { dtype, .. } if matches!(dtype, DataType::Categorical(_, _)) => { + ctx.flags.contains(StreamableFlags::ALLOW_CAST_CATEGORICAL) + }, AExpr::Alias(_, _) | AExpr::Cast { .. } => true, AExpr::Literal(lv) => match lv { LiteralValue::Series(_) | LiteralValue::Range { .. } => { @@ -64,8 +110,12 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> false } -pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena, context: Context) -> bool { +pub fn all_streamable( + exprs: &[ExprIR], + expr_arena: &Arena, + ctx: IsStreamableContext, +) -> bool { exprs .iter() - .all(|e| is_streamable(e.node(), expr_arena, context)) + .all(|e| is_streamable(e.node(), expr_arena, ctx)) } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index c0178f5b383c..793ab63194d7 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -369,7 +369,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; - return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Context::Default) { + return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Default::default()) { // Split expression that are ANDed into multiple Filter nodes as the optimizer can then // push them down independently. Especially if they refer columns from different tables // this will be more performant. @@ -738,9 +738,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult |name| col(name.clone()).std(ddof), &input_schema, ), - StatsFunction::Quantile { quantile, interpol } => stats_helper( + StatsFunction::Quantile { quantile, method } => stats_helper( |dt| dt.is_numeric(), - |name| col(name.clone()).quantile(quantile.clone(), interpol), + |name| col(name.clone()).quantile(quantile.clone(), method), &input_schema, ), StatsFunction::Mean => stats_helper( diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index 4d1aa76caff5..4709641662f9 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -1,5 +1,4 @@ //! this contains code used for rewriting projections, expanding wildcards, regex selection etc. -use std::ops::BitXor; use super::*; @@ -45,7 +44,7 @@ fn rewrite_special_aliases(expr: Expr) -> PolarsResult { Ok(Expr::Alias(expr, name.clone())) }, Expr::RenameAlias { expr, function } => { - let name = get_single_leaf(&expr).unwrap(); + let name = get_single_leaf(&expr)?; let name = function.call(&name)?; Ok(Expr::Alias(expr, name)) }, @@ -176,26 +175,28 @@ fn expand_columns( schema: &Schema, exclude: &PlHashSet, ) -> PolarsResult<()> { - let mut is_valid = true; + if !expr.into_iter().all(|e| match e { + // check for invalid expansions such as `col([a, b]) + col([c, d])` + Expr::Columns(ref members) => members.as_ref() == names, + _ => true, + }) { + polars_bail!(ComputeError: "expanding more than one `col` is not allowed"); + } for name in names { if !exclude.contains(name) { - let new_expr = expr.clone(); - let (new_expr, new_expr_valid) = replace_columns_with_column(new_expr, names, name); - is_valid &= new_expr_valid; - // we may have regex col in columns. - #[allow(clippy::collapsible_else_if)] + let new_expr = expr.clone().map_expr(|e| match e { + Expr::Columns(_) => Expr::Column((*name).clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }); + #[cfg(feature = "regex")] - { - replace_regex(&new_expr, result, schema, exclude)?; - } + replace_regex(&new_expr, result, schema, exclude)?; + #[cfg(not(feature = "regex"))] - { - let new_expr = rewrite_special_aliases(new_expr)?; - result.push(new_expr) - } + result.push(rewrite_special_aliases(new_expr)?); } } - polars_ensure!(is_valid, ComputeError: "expanding more than one `col` is not allowed"); Ok(()) } @@ -246,30 +247,6 @@ fn replace_dtype_or_index_with_column( }) } -/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the -/// expression chain. -pub(super) fn replace_columns_with_column( - mut expr: Expr, - names: &[PlSmallStr], - column_name: &PlSmallStr, -) -> (Expr, bool) { - let mut is_valid = true; - expr = expr.map_expr(|e| match e { - Expr::Columns(members) => { - // `col([a, b]) + col([c, d])` - if members.as_ref() == names { - Expr::Column(column_name.clone()) - } else { - is_valid = false; - Expr::Columns(members) - } - }, - Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), - e => e, - }); - (expr, is_valid) -} - fn dtypes_match(d1: &DataType, d2: &DataType) -> bool { match (d1, d2) { // note: allow Datetime "*" wildcard for timezones... @@ -550,7 +527,7 @@ fn expand_function_inputs( .flags .contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) => { - *input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags).unwrap(); + *input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags)?; if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) { // Needed to visualize the error *input = vec![Expr::Literal(LiteralValue::Null)]; @@ -562,7 +539,7 @@ fn expand_function_inputs( }) } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] struct ExpansionFlags { multiple_columns: bool, has_nth: bool, @@ -819,42 +796,31 @@ fn replace_selector_inner( members.extend(scratch.drain(..)) }, Selector::Add(lhs, rhs) => { + let mut tmp_members: PlIndexSet = Default::default(); replace_selector_inner(*lhs, members, scratch, schema, keys)?; - let mut rhs_members: PlIndexSet = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - members.extend(rhs_members) + replace_selector_inner(*rhs, &mut tmp_members, scratch, schema, keys)?; + members.extend(tmp_members) }, Selector::ExclusiveOr(lhs, rhs) => { - let mut lhs_members = Default::default(); - replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let xor_members = lhs_members.bitxor(&rhs_members); - *members = xor_members; + *members = tmp_members.symmetric_difference(members).cloned().collect(); }, - Selector::InterSect(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + Selector::Intersect(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - *members = members.intersection(&rhs_members).cloned().collect() + *members = tmp_members.intersection(members).cloned().collect(); }, Selector::Sub(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let mut new_members = PlIndexSet::with_capacity(members.len()); - for e in members.drain(..) { - if !rhs_members.contains(&e) { - new_members.insert(e); - } - } - *members = new_members; + *members = tmp_members.difference(members).cloned().collect(); }, } Ok(()) diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index 95eca45a9bf6..d3e0c17f8098 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -96,7 +96,7 @@ fn to_aexpr_impl_materialized_lit( let e = match expr { Expr::Literal(lv @ LiteralValue::Int(_) | lv @ LiteralValue::Float(_)) => { let av = lv.to_any_value().unwrap(); - Expr::Literal(LiteralValue::try_from(av).unwrap()) + Expr::Literal(LiteralValue::from(av)) }, Expr::Alias(inner, name) if matches!( @@ -109,10 +109,7 @@ fn to_aexpr_impl_materialized_lit( unreachable!() }; let av = lv.to_any_value().unwrap(); - Expr::Alias( - Arc::new(Expr::Literal(LiteralValue::try_from(av).unwrap())), - name, - ) + Expr::Alias(Arc::new(Expr::Literal(LiteralValue::from(av))), name) }, e => e, }; @@ -240,11 +237,11 @@ pub(super) fn to_aexpr_impl( AggExpr::Quantile { expr, quantile, - interpol, + method, } => IRAggExpr::Quantile { expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, - interpol, + method, }, AggExpr::Sum(expr) => { IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 5d2e4c373b30..160b70951962 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -129,14 +129,14 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { IRAggExpr::Quantile { expr, quantile, - interpol, + method, } => { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { expr: Arc::new(expr), quantile: Arc::new(quantile), - interpol, + method, } .into() }, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 60f7fc20f57e..9d63d18b0a46 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -115,7 +115,7 @@ pub fn resolve_join( // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. let all_elementwise = - |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default); + |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Default::default()); polars_ensure!( all_elementwise(&left_on) && all_elementwise(&right_on), InvalidOperation: "All join key expressions must be elementwise." @@ -163,7 +163,6 @@ fn resolve_join_where( .get(input_right) .schema(ctxt.lp_arena) .into_owned(); - for e in &predicates { let no_binary_comparisons = e .into_iter() @@ -174,16 +173,23 @@ fn resolve_join_where( .count(); polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition"); - fn all_in_schema(schema: &Schema, left: &Expr, right: &Expr) -> bool { + fn all_in_schema( + schema: &Schema, + other: Option<&Schema>, + left: &Expr, + right: &Expr, + ) -> bool { let mut iter = expr_to_leaf_column_names_iter(left).chain(expr_to_leaf_column_names_iter(right)); - iter.all(|name| schema.contains(name.as_str())) + iter.all(|name| { + schema.contains(name.as_str()) && other.map_or(true, |s| !s.contains(name.as_str())) + }) } let valid = e.into_iter().all(|e| match e { Expr::BinaryExpr { left, op, right } if op.is_comparison() => { - !(all_in_schema(&schema_left, left, right) - || all_in_schema(&schema_right, left, right)) + !(all_in_schema(&schema_left, None, left, right) + || all_in_schema(&schema_right, Some(&schema_left), left, right)) }, _ => true, }); diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 37d58e004ab1..24f65b3465f5 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -47,55 +47,6 @@ fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { } } -fn process_list_arithmetic( - type_left: DataType, - type_right: DataType, - node_left: Node, - node_right: Node, - op: Operator, - expr_arena: &mut Arena, -) -> PolarsResult> { - match (&type_left, &type_right) { - (DataType::List(_), _) => { - let leaf = type_left.leaf_dtype(); - if type_right != *leaf { - let new_node_right = expr_arena.add(AExpr::Cast { - expr: node_right, - dtype: type_left.cast_leaf(leaf.clone()), - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: node_left, - op, - right: new_node_right, - })) - } else { - Ok(None) - } - }, - (_, DataType::List(_)) => { - let leaf = type_right.leaf_dtype(); - if type_left != *leaf { - let new_node_left = expr_arena.add(AExpr::Cast { - expr: node_left, - dtype: type_right.cast_leaf(leaf.clone()), - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: new_node_left, - op, - right: node_right, - })) - } else { - Ok(None) - } - }, - _ => unreachable!(), - } -} - #[cfg(feature = "dtype-struct")] // Ensure we don't cast to supertype // otherwise we will fill a struct with null fields @@ -265,11 +216,6 @@ pub(super) fn process_binary( (String, a) | (a, String) if a.is_numeric() => { polars_bail!(InvalidOperation: "arithmetic on string and numeric not allowed, try an explicit cast first") }, - (List(_), _) | (_, List(_)) => { - return process_list_arithmetic( - type_left, type_right, node_left, node_right, op, expr_arena, - ) - }, (Datetime(_, _), _) | (_, Datetime(_, _)) | (Date, _) @@ -277,7 +223,9 @@ pub(super) fn process_binary( | (Duration(_), _) | (_, Duration(_)) | (Time, _) - | (_, Time) => return Ok(None), + | (_, Time) + | (List(_), _) + | (_, List(_)) => return Ok(None), #[cfg(feature = "dtype-struct")] (Struct(_), a) | (a, Struct(_)) if a.is_numeric() => { return process_struct_numeric_arithmetic( diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index fc8f520e86ea..fd0c06a39550 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -419,7 +419,7 @@ fn inline_or_prune_cast( }, LiteralValue::StrCat(s) => { let av = AnyValue::String(s).strict_cast(dtype); - return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap()))); + return Ok(av.map(|av| AExpr::Literal(av.into()))); }, // We generate casted literal datetimes, so ensure we cast upon conversion // to create simpler expr trees. @@ -431,7 +431,7 @@ fn inline_or_prune_cast( lv @ (LiteralValue::Int(_) | LiteralValue::Float(_)) => { let av = lv.to_any_value().ok_or_else(|| polars_err!(InvalidOperation: "literal value: {:?} too large for Polars", lv))?; let av = av.strict_cast(dtype); - return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap()))); + return Ok(av.map(|av| AExpr::Literal(av.into()))); }, LiteralValue::Null => match dtype { DataType::Unknown(UnknownKind::Float | UnknownKind::Int(_) | UnknownKind::Str) => { @@ -469,7 +469,7 @@ fn inline_or_prune_cast( None => return Ok(None), } }; - out.try_into()? + out.into() }, } }, diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index e470bd3044bc..f1aa33a7e7dd 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -72,7 +72,7 @@ pub enum StatsFunction { }, Quantile { quantile: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Median, Mean, diff --git a/crates/polars-plan/src/plans/functions/merge_sorted.rs b/crates/polars-plan/src/plans/functions/merge_sorted.rs index 605a628c3c88..6397a8374933 100644 --- a/crates/polars-plan/src/plans/functions/merge_sorted.rs +++ b/crates/polars-plan/src/plans/functions/merge_sorted.rs @@ -31,8 +31,8 @@ pub(super) fn merge_sorted(df: &DataFrame, column: &str) -> PolarsResult char { - if polars_io::path_utils::is_cloud_url(url) { - '/' +fn separator(url: &Path) -> &[char] { + if cfg!(target_family = "windows") { + if polars_io::path_utils::is_cloud_url(url) { + &['/'] + } else { + &['/', '\\'] + } } else { - '\\' + &['/'] } } -/// Determine the path separator for identifying Hive partitions. -#[cfg(not(target_os = "windows"))] -fn separator(_url: &Path) -> char { - '/' -} - /// Parse a Hive partition string (e.g. "column=1.5") into a name and value part. /// /// Returns `None` if the string is not a Hive partition string. diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 51050f2fa877..76d8559a052d 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -420,7 +420,7 @@ impl fmt::Display for OptionExprIRDisplay<'_> { /// Utility structure to write to a [`fmt::Formatter`] whilst escaping the output as a label name pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write); -impl<'a> fmt::Write for EscapeLabel<'a> { +impl fmt::Write for EscapeLabel<'_> { fn write_str(&mut self, mut s: &str) -> fmt::Result { loop { let mut char_indices = s.char_indices(); diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 4ccb74f66238..c4ff7dfffb45 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -413,13 +413,13 @@ impl<'a> ExprIRDisplay<'a> { } } -impl<'a> Display for IRDisplay<'a> { +impl Display for IRDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self._format(f, 0) } } -impl<'a, T: AsExpr> Display for ExprIRSliceDisplay<'a, T> { +impl Display for ExprIRSliceDisplay<'_, T> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { // Display items in slice delimited by a comma @@ -452,13 +452,13 @@ impl<'a, T: AsExpr> Display for ExprIRSliceDisplay<'a, T> { } } -impl<'a, T: AsExpr> fmt::Debug for ExprIRSliceDisplay<'a, T> { +impl fmt::Debug for ExprIRSliceDisplay<'_, T> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } } -impl<'a> Display for ExprIRDisplay<'a> { +impl Display for ExprIRDisplay<'_> { #[recursive] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let root = self.expr_arena.get(self.node); @@ -673,7 +673,7 @@ impl<'a> Display for ExprIRDisplay<'a> { } } -impl<'a> fmt::Debug for ExprIRDisplay<'a> { +impl fmt::Debug for ExprIRDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index a9eb45b6406f..14b658bea22f 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -272,6 +272,6 @@ mod test { #[ignore] #[test] fn test_alp_size() { - assert!(std::mem::size_of::() <= 152); + assert!(size_of::() <= 152); } } diff --git a/crates/polars-plan/src/plans/ir/scan_sources.rs b/crates/polars-plan/src/plans/ir/scan_sources.rs index 789a5c4f4811..bcb80bd6140d 100644 --- a/crates/polars-plan/src/plans/ir/scan_sources.rs +++ b/crates/polars-plan/src/plans/ir/scan_sources.rs @@ -330,4 +330,4 @@ impl<'a> Iterator for ScanSourceIter<'a> { } } -impl<'a> ExactSizeIterator for ScanSourceIter<'a> {} +impl ExactSizeIterator for ScanSourceIter<'_> {} diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index 6d95d7c443ca..74feffd60da0 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -92,7 +92,7 @@ impl LiteralValue { match self { LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_) => { let av = self.to_any_value().unwrap(); - av.try_into().unwrap() + av.into() }, lv => lv, } @@ -266,7 +266,7 @@ impl Literal for String { } } -impl<'a> Literal for &'a str { +impl Literal for &str { fn lit(self) -> Expr { Expr::Literal(LiteralValue::String(PlSmallStr::from_str(self))) } @@ -278,61 +278,58 @@ impl Literal for Vec { } } -impl<'a> Literal for &'a [u8] { +impl Literal for &[u8] { fn lit(self) -> Expr { Expr::Literal(LiteralValue::Binary(self.to_vec())) } } -impl TryFrom> for LiteralValue { - type Error = PolarsError; - fn try_from(value: AnyValue) -> PolarsResult { +impl From> for LiteralValue { + fn from(value: AnyValue) -> Self { match value { - AnyValue::Null => Ok(Self::Null), - AnyValue::Boolean(b) => Ok(Self::Boolean(b)), - AnyValue::String(s) => Ok(Self::String(PlSmallStr::from_str(s))), - AnyValue::Binary(b) => Ok(Self::Binary(b.to_vec())), + AnyValue::Null => Self::Null, + AnyValue::Boolean(b) => Self::Boolean(b), + AnyValue::String(s) => Self::String(PlSmallStr::from_str(s)), + AnyValue::Binary(b) => Self::Binary(b.to_vec()), #[cfg(feature = "dtype-u8")] - AnyValue::UInt8(u) => Ok(Self::UInt8(u)), + AnyValue::UInt8(u) => Self::UInt8(u), #[cfg(feature = "dtype-u16")] - AnyValue::UInt16(u) => Ok(Self::UInt16(u)), - AnyValue::UInt32(u) => Ok(Self::UInt32(u)), - AnyValue::UInt64(u) => Ok(Self::UInt64(u)), + AnyValue::UInt16(u) => Self::UInt16(u), + AnyValue::UInt32(u) => Self::UInt32(u), + AnyValue::UInt64(u) => Self::UInt64(u), #[cfg(feature = "dtype-i8")] - AnyValue::Int8(i) => Ok(Self::Int8(i)), + AnyValue::Int8(i) => Self::Int8(i), #[cfg(feature = "dtype-i16")] - AnyValue::Int16(i) => Ok(Self::Int16(i)), - AnyValue::Int32(i) => Ok(Self::Int32(i)), - AnyValue::Int64(i) => Ok(Self::Int64(i)), - AnyValue::Float32(f) => Ok(Self::Float32(f)), - AnyValue::Float64(f) => Ok(Self::Float64(f)), + AnyValue::Int16(i) => Self::Int16(i), + AnyValue::Int32(i) => Self::Int32(i), + AnyValue::Int64(i) => Self::Int64(i), + AnyValue::Float32(f) => Self::Float32(f), + AnyValue::Float64(f) => Self::Float64(f), #[cfg(feature = "dtype-decimal")] - AnyValue::Decimal(v, scale) => Ok(Self::Decimal(v, scale)), + AnyValue::Decimal(v, scale) => Self::Decimal(v, scale), #[cfg(feature = "dtype-date")] - AnyValue::Date(v) => Ok(LiteralValue::Date(v)), + AnyValue::Date(v) => LiteralValue::Date(v), #[cfg(feature = "dtype-datetime")] - AnyValue::Datetime(value, tu, tz) => Ok(LiteralValue::DateTime(value, tu, tz.cloned())), + AnyValue::Datetime(value, tu, tz) => LiteralValue::DateTime(value, tu, tz.cloned()), #[cfg(feature = "dtype-duration")] - AnyValue::Duration(value, tu) => Ok(LiteralValue::Duration(value, tu)), + AnyValue::Duration(value, tu) => LiteralValue::Duration(value, tu), #[cfg(feature = "dtype-time")] - AnyValue::Time(v) => Ok(LiteralValue::Time(v)), - AnyValue::List(l) => Ok(Self::Series(SpecialEq::new(l))), - AnyValue::StringOwned(o) => Ok(Self::String(o)), + AnyValue::Time(v) => LiteralValue::Time(v), + AnyValue::List(l) => Self::Series(SpecialEq::new(l)), + AnyValue::StringOwned(o) => Self::String(o), #[cfg(feature = "dtype-categorical")] AnyValue::Categorical(c, rev_mapping, arr) | AnyValue::Enum(c, rev_mapping, arr) => { if arr.is_null() { - Ok(Self::String(PlSmallStr::from_str(rev_mapping.get(c)))) + Self::String(PlSmallStr::from_str(rev_mapping.get(c))) } else { unsafe { - Ok(Self::String(PlSmallStr::from_str( + Self::String(PlSmallStr::from_str( arr.deref_unchecked().value(c as usize), - ))) + )) } } }, - v => polars_bail!( - ComputeError: "cannot convert any-value {:?} to literal", v - ), + _ => LiteralValue::OtherScalar(Scalar::new(value.dtype(), value.into_static())), } } } diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 03eb06387cc6..314ca8bb0cb2 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -49,11 +49,12 @@ pub use schema::*; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Default)] pub enum Context { /// Any operation that is done on groups Aggregation, /// Any operation that is done while projection/ selection of data + #[default] Default, } diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs index da13d047d43f..f6968cc6f7d2 100644 --- a/crates/polars-plan/src/plans/optimizer/cache_states.rs +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -348,17 +348,23 @@ pub(super) fn set_cache_states( let lp = IRBuilder::new(new_child, expr_arena, lp_arena) .project_simple(projection) - .unwrap() + .expect("unique names") .build(); let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; - // Remove the projection added by the optimization. - let lp = - if let IR::Select { input, .. } | IR::SimpleProjection { input, .. } = lp { - lp_arena.take(input) + // Optimization can lead to a double projection. Only take the last. + let lp = if let IR::SimpleProjection { input, columns } = lp { + let input = if let IR::SimpleProjection { input: input2, .. } = + lp_arena.get(input) + { + *input2 } else { - lp + input }; + IR::SimpleProjection { input, columns } + } else { + lp + }; lp_arena.replace(child, lp); } } else { diff --git a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs index b3f52c6e30a9..4e109903fdce 100644 --- a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs @@ -141,7 +141,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) // @NOTE: Pruning of re-assigned columns // // We checked if this expression output is also assigned by the input and - // that that assignment is not used in the current WITH_COLUMNS. + // that this assignment is not used in the current WITH_COLUMNS. // Consequently, we are free to prune the input's assignment to the output. // // We immediately prune here to simplify the later code. diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 608c7122f2ec..778efee6aa9b 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -123,7 +123,7 @@ struct MintermIter<'a> { expr_arena: &'a Arena, } -impl<'a> Iterator for MintermIter<'a> { +impl Iterator for MintermIter<'_> { type Item = Node; fn next(&mut self) -> Option { diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs index 6b7763760fa1..700a82720eb4 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs @@ -181,6 +181,11 @@ enum VisitRecord { fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool { match ae { AExpr::Window { .. } => true, + #[cfg(feature = "dtype-struct")] + AExpr::Function { + function: FunctionExpr::AsStruct, + .. + } => true, AExpr::Ternary { .. } => is_groupby, _ => false, } diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs index 075414597edf..f4522ee3a3ca 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs @@ -184,7 +184,7 @@ fn skip_children(lp: &IR) -> bool { } } -impl<'a> Visitor for LpIdentifierVisitor<'a> { +impl Visitor for LpIdentifierVisitor<'_> { type Node = IRNode; type Arena = IRNodeArena; @@ -265,7 +265,7 @@ impl<'a> CommonSubPlanRewriter<'a> { } } -impl<'a> RewritingVisitor for CommonSubPlanRewriter<'a> { +impl RewritingVisitor for CommonSubPlanRewriter<'_> { type Node = IRNode; type Arena = IRNodeArena; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index 7cb0753e5a6d..ff5f2f89ff0d 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -17,7 +17,7 @@ use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; use crate::utils::{check_input_node, has_aexpr}; pub type ExprEval<'a> = - Option<&'a dyn Fn(&ExprIR, &Arena) -> Option>>; + Option<&'a dyn Fn(&ExprIR, &Arena, &SchemaRef) -> Option>>; pub struct PredicatePushDown<'a> { expr_eval: ExprEval<'a>, @@ -364,7 +364,9 @@ impl<'a> PredicatePushDown<'a> { let predicate = predicate_at_scan(acc_predicates, predicate.clone(), expr_arena); if let (Some(hive_parts), Some(predicate)) = (&scan_hive_parts, &predicate) { - if let Some(io_expr) = self.expr_eval.unwrap()(predicate, expr_arena) { + if let Some(io_expr) = + self.expr_eval.unwrap()(predicate, expr_arena, &file_info.schema) + { if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { let paths = sources.as_paths().ok_or_else(|| { polars_err!(nyi = "Hive partitioning of in-memory buffers") @@ -670,7 +672,7 @@ impl<'a> PredicatePushDown<'a> { if let Some(predicate) = predicate { // For IO plugins we only accept streamable expressions as // we want to apply the predicates to the batches. - if !is_streamable(predicate.node(), expr_arena, Context::Default) + if !is_streamable(predicate.node(), expr_arena, Default::default()) && matches!(options.python_source, PythonScanSource::IOPlugin) { let lp = PythonScan { options }; diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs index 8096b5bde3d8..81328fe208e6 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs @@ -63,7 +63,7 @@ pub(super) fn process_hstack( acc_projections, &lp_arena.get(input).schema(lp_arena), expr_arena, - false, + true, // expands_schema ); proj_pd.pushdown_and_assign( diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 61c86e789d95..55d5501dd44e 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -344,6 +344,7 @@ impl ProjectionPushDown { projections_seen, lp_arena, expr_arena, + false, ), SimpleProjection { columns, input, .. } => { let exprs = names_to_expr_irs(columns.iter_names_cloned(), expr_arena); @@ -356,6 +357,7 @@ impl ProjectionPushDown { projections_seen, lp_arena, expr_arena, + true, ) }, DataFrameScan { @@ -509,6 +511,21 @@ impl ProjectionPushDown { file_options.row_index = None; } }; + + if let Some(col_name) = &file_options.include_file_paths { + if output_schema + .as_ref() + .map_or(false, |schema| !schema.contains(col_name)) + { + // Need to remove it from the input schema so + // that projection indices are correct. + let mut file_schema = Arc::unwrap_or_clone(file_info.schema); + file_schema.shift_remove(col_name); + file_info.schema = Arc::new(file_schema); + file_options.include_file_paths = None; + } + }; + let lp = Scan { sources, file_info, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs index 6b1106a7ca19..854cc17d6fbe 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs @@ -54,6 +54,8 @@ pub(super) fn process_projection( projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, + // Whether is SimpleProjection. + simple: bool, ) -> PolarsResult { let mut local_projection = Vec::with_capacity(exprs.len()); @@ -130,7 +132,14 @@ pub(super) fn process_projection( )?; let builder = IRBuilder::new(input, expr_arena, lp_arena); - let lp = proj_pd.finish_node(local_projection, builder); + + let lp = if !local_projection.is_empty() && simple { + builder + .project_simple_nodes(local_projection.into_iter().map(|e| e.node()))? + .build() + } else { + proj_pd.finish_node(local_projection, builder) + }; Ok(lp) } diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs index 1df68a0adcfa..db123a5bd09d 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs @@ -152,7 +152,7 @@ impl OptimizationRule for SimplifyBooleanRule { AExpr::Literal(LiteralValue::Boolean(true)) ) && in_filter => { - // Only in filter as we we might change the name from "literal" + // Only in filter as we might change the name from "literal" // to whatever lhs columns is. return Ok(Some(expr_arena.get(*right).clone())); }, @@ -210,7 +210,7 @@ impl OptimizationRule for SimplifyBooleanRule { AExpr::Literal(LiteralValue::Boolean(false)) ) && in_filter => { - // Only in filter as we we might change the name from "literal" + // Only in filter as we might change the name from "literal" // to whatever lhs columns is. return Ok(Some(expr_arena.get(*right).clone())); }, diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index b656795f53d2..9c2f8497fac8 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -31,7 +31,7 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena) - // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, // `str.contains`, `str.contains_many` etc. - observe a column node is present // but the output height is not dependent on it. - let is_elementwise = is_streamable(expr_ir.node(), arena, Context::Default); + let is_elementwise = is_streamable(expr_ir.node(), arena, Default::default()); let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold( (false, true), |(has_column, lit_scalar), (_node, ae)| { diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 71b287d03b85..62a64319ae2e 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -67,7 +67,7 @@ impl TreeWalker for Expr { Mean(x) => Mean(am(x, f)?), Implode(x) => Implode(am(x, f)?), Count(x, nulls) => Count(am(x, f)?, nulls), - Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol }, Sum(x) => Sum(am(x, f)?), AggGroups(x) => AggGroups(am(x, f)?), Std(x, ddf) => Std(am(x, f)?, ddf), diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index 943515e79f5c..16af7a3071df 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -11,8 +11,10 @@ description = "Enable running Polars workloads in Python" [dependencies] polars-core = { workspace = true, features = ["python"] } polars-error = { workspace = true } +polars-expr = { workspace = true } polars-io = { workspace = true } polars-lazy = { workspace = true, features = ["python"] } +polars-mem-engine = { workspace = true } polars-ops = { workspace = true, features = ["bitwise"] } polars-parquet = { workspace = true, optional = true } polars-plan = { workspace = true } @@ -36,7 +38,7 @@ num-traits = { workspace = true } # https://github.com/PyO3/rust-numpy/issues/409 numpy = { git = "https://github.com/stinodego/rust-numpy.git", rev = "9ba9962ae57ba26e35babdce6f179edf5fe5b9c8", default-features = false } once_cell = { workspace = true } -pyo3 = { workspace = true, features = ["abi3-py38", "chrono", "multiple-pymethods"] } +pyo3 = { workspace = true, features = ["abi3-py39", "chrono", "multiple-pymethods"] } recursive = { workspace = true } serde_json = { workspace = true, optional = true } thiserror = { workspace = true } @@ -77,6 +79,7 @@ features = [ "lazy", "list_eval", "list_to_struct", + "list_arithmetic", "array_to_struct", "log", "mode", @@ -232,7 +235,7 @@ optimizations = [ "streaming", ] -polars_cloud = ["polars/polars_cloud"] +polars_cloud = ["polars/polars_cloud", "polars/ir_serde"] # also includes simd nightly = ["polars/nightly"] @@ -252,7 +255,7 @@ all = [ "binary_encoding", "ffi_plugin", "polars_cloud", - # "new_streaming", + "new_streaming", ] # we cannot conditionally activate simd diff --git a/crates/polars-python/src/cloud.rs b/crates/polars-python/src/cloud.rs index dacca675c551..39410a6fa7a1 100644 --- a/crates/polars-python/src/cloud.rs +++ b/crates/polars-python/src/cloud.rs @@ -1,8 +1,17 @@ -use pyo3::prelude::*; -use pyo3::types::PyBytes; +use std::io::Cursor; + +use polars_core::error::{polars_err, to_compute_err, PolarsResult}; +use polars_expr::state::ExecutionState; +use polars_mem_engine::create_physical_plan; +use polars_plan::plans::{AExpr, IRPlan, IR}; +use polars_plan::prelude::{Arena, Node}; +use pyo3::intern; +use pyo3::prelude::{PyAnyMethods, PyModule, Python, *}; +use pyo3::types::{IntoPyDict, PyBytes}; use crate::error::PyPolarsErr; -use crate::PyLazyFrame; +use crate::lazyframe::visit::NodeTraverser; +use crate::{PyDataFrame, PyLazyFrame}; #[pyfunction] pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python) -> PyResult { @@ -11,3 +20,76 @@ pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python) -> PyResult { Ok(PyBytes::new_bound(py, &bytes).to_object(py)) } + +/// Take a serialized `IRPlan` and execute it on the GPU engine. +/// +/// This is done as a Python function because the `NodeTraverser` class created for this purpose +/// must exactly match the one expected by the `cudf_polars` package. +#[pyfunction] +pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec, py: Python) -> PyResult { + // Deserialize into IRPlan. + let reader = Cursor::new(ir_plan_ser); + let mut ir_plan = ciborium::from_reader::(reader) + .map_err(to_compute_err) + .map_err(PyPolarsErr::from)?; + + // Edit for use with GPU engine. + gpu_post_opt( + py, + ir_plan.lp_top, + &mut ir_plan.lp_arena, + &mut ir_plan.expr_arena, + ) + .map_err(PyPolarsErr::from)?; + + // Convert to physical plan. + let mut physical_plan = + create_physical_plan(ir_plan.lp_top, &mut ir_plan.lp_arena, &ir_plan.expr_arena) + .map_err(PyPolarsErr::from)?; + + // Execute the plan. + let mut state = ExecutionState::new(); + let df = physical_plan + .execute(&mut state) + .map_err(PyPolarsErr::from)?; + + Ok(df.into()) +} + +/// Prepare the IR for execution by the Polars GPU engine. +fn gpu_post_opt( + py: Python, + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + // Get cuDF Python function. + let cudf = PyModule::import_bound(py, intern!(py, "cudf_polars")).unwrap(); + let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap(); + + // Define cuDF config. + let polars = PyModule::import_bound(py, intern!(py, "polars")).unwrap(); + let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap(); + let kwargs = [("raise_on_fail", true)].into_py_dict_bound(py); + let engine = engine.call((), Some(&kwargs)).unwrap(); + + // Define node traverser. + let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena)); + + // Get a copy of the arenas. + let arenas = nt.get_arenas(); + + // Pass the node visitor which allows the Python callback to replace parts of the query plan. + // Remove "cuda" or specify better once we have multiple post-opt callbacks. + let kwargs = [("config", engine)].into_py_dict_bound(py); + lambda + .call((nt,), Some(&kwargs)) + .map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?; + + // Unpack the arena's. + // At this point the `nt` is useless. + std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap()); + std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap()); + + Ok(()) +} diff --git a/crates/polars-python/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs index c9d6f6f17e8b..eb4835ada90f 100644 --- a/crates/polars-python/src/conversion/any_value.rs +++ b/crates/polars-python/src/conversion/any_value.rs @@ -98,12 +98,12 @@ pub(crate) fn any_value_into_py_object(av: AnyValue, py: Python) -> PyObject { #[cfg(feature = "object")] AnyValue::Object(v) => { let object = v.as_any().downcast_ref::().unwrap(); - object.inner.clone() + object.inner.clone_ref(py) }, #[cfg(feature = "object")] AnyValue::ObjectOwned(v) => { let object = v.0.as_any().downcast_ref::().unwrap(); - object.inner.clone() + object.inner.clone_ref(py) }, AnyValue::Binary(v) => PyBytes::new_bound(py, v).into_py(py), AnyValue::BinaryOwned(v) => PyBytes::new_bound(py, &v).into_py(py), @@ -115,7 +115,7 @@ pub(crate) fn any_value_into_py_object(av: AnyValue, py: Python) -> PyObject { let buf = unsafe { std::slice::from_raw_parts( buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), + N * size_of::(), ) }; let digits = PyTuple::new_bound(py, buf.iter().take(n_digits)); @@ -425,8 +425,8 @@ pub(crate) fn py_object_to_any_value<'py>( Ok(get_struct) } else { let ob_type = ob.get_type(); - let type_name = ob_type.qualname().unwrap(); - match &*type_name { + let type_name = ob_type.qualname().unwrap().to_string(); + match type_name.as_str() { // Can't use pyo3::types::PyDateTime with abi3-py37 feature, // so need this workaround instead of `isinstance(ob, datetime)`. "date" => Ok(get_date as InitFn), diff --git a/crates/polars-python/src/conversion/chunked_array.rs b/crates/polars-python/src/conversion/chunked_array.rs index 3a69d61f7dd1..404fe68ce8ef 100644 --- a/crates/polars-python/src/conversion/chunked_array.rs +++ b/crates/polars-python/src/conversion/chunked_array.rs @@ -127,7 +127,7 @@ pub(crate) fn decimal_to_pyobject_iter<'a>( let buf = unsafe { std::slice::from_raw_parts( buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), + N * size_of::(), ) }; let digits = PyTuple::new_bound(py, buf.iter().take(n_digits)); diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index 02f8e5008b88..16abea471d7f 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -59,8 +59,8 @@ unsafe impl Transparent for Option { } pub(crate) fn reinterpret_vec(input: Vec) -> Vec { - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - assert_eq!(std::mem::align_of::(), std::mem::align_of::()); + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); let len = input.len(); let cap = input.capacity(); let mut manual_drop_vec = std::mem::ManuallyDrop::new(input); @@ -336,7 +336,7 @@ impl<'py> FromPyObject<'py> for Wrap { impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let py = ob.py(); - let type_name = ob.get_type().qualname()?; + let type_name = ob.get_type().qualname()?.to_string(); let dtype = match &*type_name { "DataTypeClass" => { @@ -689,8 +689,8 @@ impl From<&dyn PolarsObjectSafe> for &ObjectValue { } impl ToPyObject for ObjectValue { - fn to_object(&self, _py: Python) -> PyObject { - self.inner.clone() + fn to_object(&self, py: Python) -> PyObject { + self.inner.clone_ref(py) } } @@ -978,17 +978,18 @@ impl<'py> FromPyObject<'py> for Wrap { } } -impl<'py> FromPyObject<'py> for Wrap { +impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let parsed = match &*ob.extract::()? { - "lower" => QuantileInterpolOptions::Lower, - "higher" => QuantileInterpolOptions::Higher, - "nearest" => QuantileInterpolOptions::Nearest, - "linear" => QuantileInterpolOptions::Linear, - "midpoint" => QuantileInterpolOptions::Midpoint, + "lower" => QuantileMethod::Lower, + "higher" => QuantileMethod::Higher, + "nearest" => QuantileMethod::Nearest, + "linear" => QuantileMethod::Linear, + "midpoint" => QuantileMethod::Midpoint, + "equiprobable" => QuantileMethod::Equiprobable, v => { return Err(PyValueError::new_err(format!( - "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", + "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint', 'equiprobable'}}, got {v}", ))) } }; diff --git a/crates/polars-python/src/dataframe/construction.rs b/crates/polars-python/src/dataframe/construction.rs index 36d16e57187b..1c753ac88e0e 100644 --- a/crates/polars-python/src/dataframe/construction.rs +++ b/crates/polars-python/src/dataframe/construction.rs @@ -12,6 +12,7 @@ use crate::interop; #[pymethods] impl PyDataFrame { #[staticmethod] + #[pyo3(signature = (data, schema=None, infer_schema_length=None))] pub fn from_rows( py: Python, data: Vec>, diff --git a/crates/polars-python/src/dataframe/export.rs b/crates/polars-python/src/dataframe/export.rs index 29ba4edb5371..a7e9394ab4fd 100644 --- a/crates/polars-python/src/dataframe/export.rs +++ b/crates/polars-python/src/dataframe/export.rs @@ -113,6 +113,7 @@ impl PyDataFrame { .df .iter_chunks(CompatLevel::oldest(), true) .map(|rb| { + let length = rb.len(); let mut rb = rb.into_arrays(); for i in &cat_columns { let arr = rb.get_mut(*i).unwrap(); @@ -128,7 +129,7 @@ impl PyDataFrame { .unwrap(); *arr = out; } - let rb = RecordBatch::new(rb); + let rb = RecordBatch::new(length, rb); interop::arrow::to_py::to_py_rb(&rb, &names, py, &pyarrow) }) diff --git a/crates/polars-python/src/dataframe/general.rs b/crates/polars-python/src/dataframe/general.rs index 5df89ed423c7..ac4febced0f6 100644 --- a/crates/polars-python/src/dataframe/general.rs +++ b/crates/polars-python/src/dataframe/general.rs @@ -95,6 +95,7 @@ impl PyDataFrame { Ok(df.into()) } + #[pyo3(signature = (n, with_replacement, shuffle, seed=None))] pub fn sample_n( &self, n: &PySeries, @@ -109,6 +110,7 @@ impl PyDataFrame { Ok(df.into()) } + #[pyo3(signature = (frac, with_replacement, shuffle, seed=None))] pub fn sample_frac( &self, frac: &PySeries, @@ -260,16 +262,16 @@ impl PyDataFrame { Ok(PyDataFrame::new(df)) } - pub fn gather(&self, indices: Wrap>) -> PyResult { + pub fn gather(&self, py: Python, indices: Wrap>) -> PyResult { let indices = indices.0; let indices = IdxCa::from_vec("".into(), indices); - let df = self.df.take(&indices).map_err(PyPolarsErr::from)?; + let df = Python::allow_threads(py, || self.df.take(&indices).map_err(PyPolarsErr::from))?; Ok(PyDataFrame::new(df)) } - pub fn gather_with_series(&self, indices: &PySeries) -> PyResult { + pub fn gather_with_series(&self, py: Python, indices: &PySeries) -> PyResult { let indices = indices.series.idx().map_err(PyPolarsErr::from)?; - let df = self.df.take(indices).map_err(PyPolarsErr::from)?; + let df = Python::allow_threads(py, || self.df.take(indices).map_err(PyPolarsErr::from))?; Ok(PyDataFrame::new(df)) } @@ -294,6 +296,7 @@ impl PyDataFrame { Ok(()) } + #[pyo3(signature = (offset, length=None))] pub fn slice(&self, offset: i64, length: Option) -> Self { let df = self .df @@ -329,6 +332,7 @@ impl PyDataFrame { } } + #[pyo3(signature = (name, offset=None))] pub fn with_row_index(&self, name: &str, offset: Option) -> PyResult { let df = self .df @@ -391,6 +395,7 @@ impl PyDataFrame { } #[cfg(feature = "pivot")] + #[pyo3(signature = (on, index, value_name=None, variable_name=None))] pub fn unpivot( &self, on: Vec, @@ -586,11 +591,11 @@ impl PyDataFrame { every: &str, stable: bool, ) -> PyResult { + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let out = if stable { - self.df - .upsample_stable(by, index_column, Duration::parse(every)) + self.df.upsample_stable(by, index_column, every) } else { - self.df.upsample(by, index_column, Duration::parse(every)) + self.df.upsample(by, index_column, every) }; let out = out.map_err(PyPolarsErr::from)?; Ok(out.into()) @@ -627,8 +632,8 @@ impl PyDataFrame { pub fn into_raw_parts(&mut self) -> (usize, usize, usize) { // Used for polars-lazy python node. This takes the dataframe from // underneath of you, so don't use this anywhere else. - let mut df = std::mem::take(&mut self.df); - let cols = unsafe { std::mem::take(df.get_columns_mut()) }; + let df = std::mem::take(&mut self.df); + let cols = df.take_columns(); let mut md_cols = ManuallyDrop::new(cols); let ptr = md_cols.as_mut_ptr(); let len = md_cols.len(); diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index 852ccff23a1c..9b34eb7e8ae9 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -188,6 +188,7 @@ impl PyDataFrame { #[staticmethod] #[cfg(feature = "json")] + #[pyo3(signature = (py_f, infer_schema_length=None, schema=None, schema_overrides=None))] pub fn read_json( py: Python, mut py_f: Bound, @@ -220,6 +221,7 @@ impl PyDataFrame { #[staticmethod] #[cfg(feature = "json")] + #[pyo3(signature = (py_f, ignore_errors, schema=None, schema_overrides=None))] pub fn read_ndjson( py: Python, mut py_f: Bound, @@ -337,6 +339,7 @@ impl PyDataFrame { } #[cfg(feature = "csv")] + #[pyo3(signature = (py_f, include_bom, include_header, separator, line_terminator, quote_char, batch_size, datetime_format=None, date_format=None, time_format=None, float_scientific=None, float_precision=None, null_value=None, quote_style=None))] pub fn write_csv( &mut self, py: Python, diff --git a/crates/polars-python/src/dataframe/serde.rs b/crates/polars-python/src/dataframe/serde.rs index 5bd54d5114af..b08d2bd5ed85 100644 --- a/crates/polars-python/src/dataframe/serde.rs +++ b/crates/polars-python/src/dataframe/serde.rs @@ -4,6 +4,7 @@ use std::ops::Deref; use polars::prelude::*; use polars_io::mmap::ReaderBytes; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; use super::PyDataFrame; @@ -25,11 +26,11 @@ impl PyDataFrame { } #[cfg(feature = "ipc_streaming")] - fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + fn __setstate__(&mut self, state: &Bound) -> PyResult<()> { // Used in pickle/pickling - match state.extract::<&PyBytes>(py) { + match state.extract::() { Ok(s) => { - let c = Cursor::new(s.as_bytes()); + let c = Cursor::new(&*s); let reader = IpcStreamReader::new(c); reader diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index cf0c50cd2210..7125388e88cd 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -9,6 +9,7 @@ use pyo3::class::basic::CompareOp; use pyo3::prelude::*; use crate::conversion::{parse_fill_null_strategy, vec_extract_wrapped, Wrap}; +use crate::error::PyPolarsErr; use crate::map::lazy::map_single; use crate::PyExpr; @@ -149,7 +150,7 @@ impl PyExpr { fn implode(&self) -> Self { self.inner.clone().implode().into() } - fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { self.inner .clone() .quantile(quantile.inner, interpolation.0) @@ -361,6 +362,7 @@ impl PyExpr { self.inner.clone().forward_fill(limit).into() } + #[pyo3(signature = (n, fill_value=None))] fn shift(&self, n: Self, fill_value: Option) -> Self { let expr = self.inner.clone(); let out = match fill_value { @@ -470,6 +472,7 @@ impl PyExpr { self.inner.clone().ceil().into() } + #[pyo3(signature = (min=None, max=None))] fn clip(&self, min: Option, max: Option) -> Self { let expr = self.inner.clone(); let out = match (min, max) { @@ -612,15 +615,15 @@ impl PyExpr { period: &str, offset: &str, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingGroupOptions { index_column: index_column.into(), - period: Duration::parse(period), - offset: Duration::parse(offset), + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, closed_window: closed.0, }; - self.inner.clone().rolling(options).into() + Ok(self.inner.clone().rolling(options).into()) } fn and_(&self, expr: Self) -> Self { @@ -739,6 +742,7 @@ impl PyExpr { self.inner.clone().upper_bound().into() } + #[pyo3(signature = (method, descending, seed=None))] fn rank(&self, method: Wrap, descending: bool, seed: Option) -> Self { let options = RankOptions { method: method.0, @@ -809,12 +813,13 @@ impl PyExpr { }; self.inner.clone().ewm_mean(options).into() } - fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> Self { - let half_life = Duration::parse(half_life); - self.inner + fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> PyResult { + let half_life = Duration::try_parse(half_life).map_err(PyPolarsErr::from)?; + Ok(self + .inner .clone() .ewm_mean_by(times.inner, half_life) - .into() + .into()) } fn ewm_std( @@ -896,6 +901,7 @@ impl PyExpr { self.inner.clone().replace(old.inner, new.inner).into() } + #[pyo3(signature = (old, new, default=None, return_dtype=None))] fn replace_strict( &self, old: PyExpr, diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index cb179eb0e859..af3be10449b1 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -4,6 +4,7 @@ use polars::prelude::*; use polars::series::ops::NullBehavior; use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; +use pyo3::types::PySequence; use crate::conversion::Wrap; use crate::PyExpr; @@ -118,6 +119,7 @@ impl PyExpr { self.inner.clone().list().shift(periods.inner).into() } + #[pyo3(signature = (offset, length=None))] fn list_slice(&self, offset: PyExpr, length: Option) -> Self { let length = match length { Some(i) => i.inner, @@ -152,6 +154,7 @@ impl PyExpr { } #[cfg(feature = "list_sample")] + #[pyo3(signature = (n, with_replacement, shuffle, seed=None))] fn list_sample_n( &self, n: PyExpr, @@ -167,6 +170,7 @@ impl PyExpr { } #[cfg(feature = "list_sample")] + #[pyo3(signature = (fraction, with_replacement, shuffle, seed=None))] fn list_sample_fraction( &self, fraction: PyExpr, @@ -211,20 +215,39 @@ impl PyExpr { upper_bound: usize, ) -> PyResult { let name_gen = name_gen.map(|lambda| { - Arc::new(move |idx: usize| { + NameGenerator::from_func(move |idx: usize| { Python::with_gil(|py| { let out = lambda.call1(py, (idx,)).unwrap(); let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); out }) - }) as NameGenerator + }) }); Ok(self .inner .clone() .list() - .to_struct(width_strat.0, name_gen, upper_bound) + .to_struct(ListToStructArgs::InferWidth { + infer_field_strategy: width_strat.0, + get_index_name: name_gen, + max_fields: upper_bound, + }) + .into()) + } + + #[pyo3(signature = (names))] + fn list_to_struct_fixed_width(&self, names: Bound<'_, PySequence>) -> PyResult { + Ok(self + .inner + .clone() + .list() + .to_struct(ListToStructArgs::FixedWidth( + names + .iter()? + .map(|x| Ok(x?.extract::>()?.0)) + .collect::>>()?, + )) .into()) } diff --git a/crates/polars-python/src/expr/rolling.rs b/crates/polars-python/src/expr/rolling.rs index 629f1eab391d..5ef511902613 100644 --- a/crates/polars-python/src/expr/rolling.rs +++ b/crates/polars-python/src/expr/rolling.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyFloat; use crate::conversion::Wrap; +use crate::error::PyPolarsErr; use crate::map::lazy::call_lambda_with_series; use crate::{PyExpr, PySeries}; @@ -34,14 +35,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_sum_by(by.inner, options).into() + Ok(self.inner.clone().rolling_sum_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -70,14 +71,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_min_by(by.inner, options).into() + Ok(self.inner.clone().rolling_min_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -105,14 +106,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_max_by(by.inner, options).into() + Ok(self.inner.clone().rolling_max_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -142,15 +143,15 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_mean_by(by.inner, options).into() + Ok(self.inner.clone().rolling_mean_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] @@ -182,15 +183,15 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })), }; - self.inner.clone().rolling_std_by(by.inner, options).into() + Ok(self.inner.clone().rolling_std_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] @@ -222,15 +223,15 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })), }; - self.inner.clone().rolling_var_by(by.inner, options).into() + Ok(self.inner.clone().rolling_var_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -259,24 +260,25 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner + Ok(self + .inner .clone() .rolling_median_by(by.inner, options) - .into() + .into()) } #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))] fn rolling_quantile( &self, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: usize, weights: Option>, min_periods: Option, @@ -302,22 +304,23 @@ impl PyExpr { &self, by: PyExpr, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner + Ok(self + .inner .clone() .rolling_quantile_by(by.inner, interpolation.0, quantile, options) - .into() + .into()) } fn rolling_skew(&self, window_size: usize, bias: bool) -> Self { diff --git a/crates/polars-python/src/expr/serde.rs b/crates/polars-python/src/expr/serde.rs index 8045a1076d39..9933e9a979b1 100644 --- a/crates/polars-python/src/expr/serde.rs +++ b/crates/polars-python/src/expr/serde.rs @@ -21,9 +21,9 @@ impl PyExpr { Ok(PyBytes::new_bound(py, &writer).to_object(py)) } - fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + fn __setstate__(&mut self, state: &Bound) -> PyResult<()> { // Used in pickle/pickling - match state.extract::(py) { + match state.extract::() { Ok(s) => { let cursor = Cursor::new(&*s); self.inner = ciborium::de::from_reader(cursor) diff --git a/crates/polars-python/src/expr/string.rs b/crates/polars-python/src/expr/string.rs index e238e412dc02..87521a2b7aa1 100644 --- a/crates/polars-python/src/expr/string.rs +++ b/crates/polars-python/src/expr/string.rs @@ -226,6 +226,7 @@ impl PyExpr { } #[cfg(feature = "extract_jsonpath")] + #[pyo3(signature = (dtype=None, infer_schema_len=None))] fn str_json_decode( &self, dtype: Option>, @@ -338,4 +339,9 @@ impl PyExpr { .extract_many(patterns.inner, ascii_case_insensitive, overlapping) .into() } + + #[cfg(feature = "regex")] + fn str_escape_regex(&self) -> Self { + self.inner.clone().str().escape_regex().into() + } } diff --git a/crates/polars-python/src/file.rs b/crates/polars-python/src/file.rs index 33d084c5130c..efbcbff3fc18 100644 --- a/crates/polars-python/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -9,7 +9,7 @@ use std::os::fd::{FromRawFd, RawFd}; use std::path::PathBuf; use polars::io::mmap::MmapBytesReader; -use polars_error::{polars_err, polars_warn}; +use polars_error::polars_err; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString, PyStringMethods}; @@ -25,7 +25,7 @@ pub struct PyFileLikeObject { /// Wraps a `PyObject`, and implements read, seek, and write for it. impl PyFileLikeObject { /// Creates an instance of a `PyFileLikeObject` from a `PyObject`. - /// To assert the object has the required methods methods, + /// To assert the object has the required methods, /// instantiate it with `PyFileLikeObject::require` pub fn new(object: PyObject) -> Self { PyFileLikeObject { inner: object } @@ -284,14 +284,6 @@ pub fn get_python_scan_source_input( })); } - // BytesIO / StringIO is relatively fast, and some code relies on it. - if !py_f.is_exact_instance(&io.getattr("BytesIO").unwrap()) - && !py_f.is_exact_instance(&io.getattr("StringIO").unwrap()) - { - polars_warn!("Polars found a filename. \ - Ensure you pass a path to the file instead of a python file object when possible for best \ - performance."); - } // Unwrap TextIOWrapper // Allow subclasses to allow things like pytest.capture.CaptureIO let py_f = if py_f @@ -397,14 +389,6 @@ fn get_either_buffer_or_path( )); } - // BytesIO / StringIO is relatively fast, and some code relies on it. - if !py_f.is_exact_instance(&io.getattr("BytesIO").unwrap()) - && !py_f.is_exact_instance(&io.getattr("StringIO").unwrap()) - { - polars_warn!("Polars found a filename. \ - Ensure you pass a path to the file instead of a python file object when possible for best \ - performance."); - } // Unwrap TextIOWrapper // Allow subclasses to allow things like pytest.capture.CaptureIO let py_f = if py_f diff --git a/crates/polars-python/src/functions/lazy.rs b/crates/polars-python/src/functions/lazy.rs index cb826a699551..24db48144508 100644 --- a/crates/polars-python/src/functions/lazy.rs +++ b/crates/polars-python/src/functions/lazy.rs @@ -252,7 +252,7 @@ pub fn cum_reduce(lambda: PyObject, exprs: Vec) -> PyExpr { } #[pyfunction] -#[pyo3(signature = (year, month, day, hour=None, minute=None, second=None, microsecond=None, time_unit=Wrap(TimeUnit::Microseconds), time_zone=None, ambiguous=None))] +#[pyo3(signature = (year, month, day, hour=None, minute=None, second=None, microsecond=None, time_unit=Wrap(TimeUnit::Microseconds), time_zone=None, ambiguous=PyExpr::from(dsl::lit(String::from("raise")))))] pub fn datetime( year: PyExpr, month: PyExpr, @@ -263,15 +263,13 @@ pub fn datetime( microsecond: Option, time_unit: Wrap, time_zone: Option>, - ambiguous: Option, + ambiguous: PyExpr, ) -> PyExpr { let year = year.inner; let month = month.inner; let day = day.inner; set_unwrapped_or_0!(hour, minute, second, microsecond); - let ambiguous = ambiguous - .map(|e| e.inner) - .unwrap_or(dsl::lit(String::from("raise"))); + let ambiguous = ambiguous.inner; let time_unit = time_unit.0; let time_zone = time_zone.map(|x| x.0); let args = DatetimeArgs { @@ -466,7 +464,7 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyR format!( "cannot create expression literal for value of type {}.\ \n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.", - value.get_type().qualname().unwrap_or("unknown".to_owned()), + value.get_type().qualname().map(|s|s.to_string()).unwrap_or("unknown".to_owned()), ) ) })?; @@ -478,7 +476,7 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyR }); Ok(dsl::lit(s).into()) }, - _ => Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()), + _ => Ok(Expr::Literal(LiteralValue::from(av)).into()), } } } @@ -517,6 +515,7 @@ pub fn reduce(lambda: PyObject, exprs: Vec) -> PyExpr { } #[pyfunction] +#[pyo3(signature = (value, n, dtype=None))] pub fn repeat(value: PyExpr, n: PyExpr, dtype: Option>) -> PyResult { let mut value = value.inner; let n = n.inner; diff --git a/crates/polars-python/src/functions/meta.rs b/crates/polars-python/src/functions/meta.rs index bc43657e1b12..ba6af7f2c669 100644 --- a/crates/polars-python/src/functions/meta.rs +++ b/crates/polars-python/src/functions/meta.rs @@ -41,6 +41,7 @@ pub fn get_float_fmt() -> PyResult { } #[pyfunction] +#[pyo3(signature = (precision=None))] pub fn set_float_precision(precision: Option) -> PyResult<()> { use polars_core::fmt::set_float_precision; set_float_precision(precision); @@ -54,6 +55,7 @@ pub fn get_float_precision() -> PyResult> { } #[pyfunction] +#[pyo3(signature = (sep=None))] pub fn set_thousands_separator(sep: Option) -> PyResult<()> { use polars_core::fmt::set_thousands_separator; set_thousands_separator(sep); @@ -67,6 +69,7 @@ pub fn get_thousands_separator() -> PyResult> { } #[pyfunction] +#[pyo3(signature = (sep=None))] pub fn set_decimal_separator(sep: Option) -> PyResult<()> { use polars_core::fmt::set_decimal_separator; set_decimal_separator(sep); @@ -80,6 +83,7 @@ pub fn get_decimal_separator() -> PyResult> { } #[pyfunction] +#[pyo3(signature = (trim=None))] pub fn set_trim_decimal_zeros(trim: Option) -> PyResult<()> { use polars_core::fmt::set_trim_decimal_zeros; set_trim_decimal_zeros(trim); diff --git a/crates/polars-python/src/functions/mod.rs b/crates/polars-python/src/functions/mod.rs index 0bb5e55ea23c..ddf58c7acde6 100644 --- a/crates/polars-python/src/functions/mod.rs +++ b/crates/polars-python/src/functions/mod.rs @@ -8,6 +8,7 @@ mod misc; mod random; mod range; mod string_cache; +mod strings; mod whenthen; pub use aggregation::*; @@ -20,4 +21,5 @@ pub use misc::*; pub use random::*; pub use range::*; pub use string_cache::*; +pub use strings::*; pub use whenthen::*; diff --git a/crates/polars-python/src/functions/range.rs b/crates/polars-python/src/functions/range.rs index b07522650de3..b6eae4400dd8 100644 --- a/crates/polars-python/src/functions/range.rs +++ b/crates/polars-python/src/functions/range.rs @@ -71,12 +71,12 @@ pub fn date_range( end: PyExpr, interval: &str, closed: Wrap, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let interval = Duration::parse(interval); + let interval = Duration::try_parse(interval).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::date_range(start, end, interval, closed).into() + Ok(dsl::date_range(start, end, interval, closed).into()) } #[pyfunction] @@ -85,15 +85,16 @@ pub fn date_ranges( end: PyExpr, interval: &str, closed: Wrap, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let interval = Duration::parse(interval); + let interval = Duration::try_parse(interval).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::date_ranges(start, end, interval, closed).into() + Ok(dsl::date_ranges(start, end, interval, closed).into()) } #[pyfunction] +#[pyo3(signature = (start, end, every, closed, time_unit=None, time_zone=None))] pub fn datetime_range( start: PyExpr, end: PyExpr, @@ -101,17 +102,18 @@ pub fn datetime_range( closed: Wrap, time_unit: Option>, time_zone: Option>, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; let time_unit = time_unit.map(|x| x.0); let time_zone = time_zone.map(|x| x.0); - dsl::datetime_range(start, end, every, closed, time_unit, time_zone).into() + Ok(dsl::datetime_range(start, end, every, closed, time_unit, time_zone).into()) } #[pyfunction] +#[pyo3(signature = (start, end, every, closed, time_unit=None, time_zone=None))] pub fn datetime_ranges( start: PyExpr, end: PyExpr, @@ -119,30 +121,40 @@ pub fn datetime_ranges( closed: Wrap, time_unit: Option>, time_zone: Option>, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; let time_unit = time_unit.map(|x| x.0); let time_zone = time_zone.map(|x| x.0); - dsl::datetime_ranges(start, end, every, closed, time_unit, time_zone).into() + Ok(dsl::datetime_ranges(start, end, every, closed, time_unit, time_zone).into()) } #[pyfunction] -pub fn time_range(start: PyExpr, end: PyExpr, every: &str, closed: Wrap) -> PyExpr { +pub fn time_range( + start: PyExpr, + end: PyExpr, + every: &str, + closed: Wrap, +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::time_range(start, end, every, closed).into() + Ok(dsl::time_range(start, end, every, closed).into()) } #[pyfunction] -pub fn time_ranges(start: PyExpr, end: PyExpr, every: &str, closed: Wrap) -> PyExpr { +pub fn time_ranges( + start: PyExpr, + end: PyExpr, + every: &str, + closed: Wrap, +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::time_ranges(start, end, every, closed).into() + Ok(dsl::time_ranges(start, end, every, closed).into()) } diff --git a/crates/polars-python/src/functions/strings.rs b/crates/polars-python/src/functions/strings.rs new file mode 100644 index 000000000000..d75666ecf367 --- /dev/null +++ b/crates/polars-python/src/functions/strings.rs @@ -0,0 +1,7 @@ +use pyo3::prelude::*; + +#[pyfunction] +pub fn escape_regex(s: &str) -> PyResult { + let escaped_s = polars_ops::chunked_array::strings::escape_regex_str(s); + Ok(escaped_s) +} diff --git a/crates/polars-python/src/interop/arrow/to_py.rs b/crates/polars-python/src/interop/arrow/to_py.rs index 017771bb1567..1a90e9eb680a 100644 --- a/crates/polars-python/src/interop/arrow/to_py.rs +++ b/crates/polars-python/src/interop/arrow/to_py.rs @@ -123,10 +123,15 @@ impl Iterator for DataFrameStreamIterator { .columns .iter() .map(|s| s.to_arrow(self.idx, CompatLevel::newest())) - .collect(); + .collect::>(); self.idx += 1; - let array = arrow::array::StructArray::new(self.dtype.clone(), batch_cols, None); + let array = arrow::array::StructArray::new( + self.dtype.clone(), + batch_cols[0].len(), + batch_cols, + None, + ); Some(Ok(Box::new(array))) } } diff --git a/crates/polars-python/src/interop/arrow/to_rust.rs b/crates/polars-python/src/interop/arrow/to_rust.rs index 809bd527a492..1add88c96fd8 100644 --- a/crates/polars-python/src/interop/arrow/to_rust.rs +++ b/crates/polars-python/src/interop/arrow/to_rust.rs @@ -105,7 +105,7 @@ pub fn to_rust_df(rb: &[Bound]) -> PyResult { }?; // no need to check as a record batch has the same guarantees - Ok(unsafe { DataFrame::new_no_checks(columns) }) + Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) }) }) .collect::>>()?; diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index b0307f23e6f8..da0b597418eb 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -41,7 +41,7 @@ impl PyLazyFrame { #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( source, sources, infer_schema_length, schema, schema_overrides, batch_size, n_rows, low_memory, rechunk, - row_index, ignore_errors, include_file_paths, cloud_options, retries, file_cache_ttl + row_index, ignore_errors, include_file_paths, cloud_options, credential_provider, retries, file_cache_ttl ))] fn new_from_ndjson( source: Option, @@ -57,9 +57,11 @@ impl PyLazyFrame { ignore_errors: bool, include_file_paths: Option, cloud_options: Option>, + credential_provider: Option, retries: usize, file_cache_ttl: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; let row_index = row_index.map(|(name, offset)| RowIndex { name: name.into(), offset, @@ -79,7 +81,11 @@ impl PyLazyFrame { let mut cloud_options = parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; - cloud_options = cloud_options.with_max_retries(retries); + cloud_options = cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ); if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; @@ -111,7 +117,7 @@ impl PyLazyFrame { low_memory, comment_prefix, quote_char, null_values, missing_utf8_is_empty_string, infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header, encoding, row_index, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, glob, schema, - cloud_options, retries, file_cache_ttl, include_file_paths + cloud_options, credential_provider, retries, file_cache_ttl, include_file_paths ) )] fn new_from_csv( @@ -143,10 +149,13 @@ impl PyLazyFrame { glob: bool, schema: Option>, cloud_options: Option>, + credential_provider: Option, retries: usize, file_cache_ttl: Option, include_file_paths: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; + let null_values = null_values.map(|w| w.0); let quote_char = quote_char .map(|s| { @@ -198,7 +207,11 @@ impl PyLazyFrame { if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - cloud_options = cloud_options.with_max_retries(retries); + cloud_options = cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ); r = r.with_cloud_options(Some(cloud_options)); } @@ -257,9 +270,11 @@ impl PyLazyFrame { #[cfg(feature = "parquet")] #[staticmethod] - #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index, - low_memory, cloud_options, use_statistics, hive_partitioning, schema, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns) - )] + #[pyo3(signature = ( + source, sources, n_rows, cache, parallel, rechunk, row_index, low_memory, cloud_options, + credential_provider, use_statistics, hive_partitioning, schema, hive_schema, + try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns, + ))] fn new_from_parquet( source: Option, sources: Wrap, @@ -270,6 +285,7 @@ impl PyLazyFrame { row_index: Option<(String, IdxSize)>, low_memory: bool, cloud_options: Option>, + credential_provider: Option, use_statistics: bool, hive_partitioning: Option, schema: Option>, @@ -280,6 +296,8 @@ impl PyLazyFrame { include_file_paths: Option, allow_missing_columns: bool, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; + let parallel = parallel.0; let hive_schema = hive_schema.map(|s| Arc::new(s.0)); @@ -322,7 +340,13 @@ impl PyLazyFrame { let first_path_url = first_path.to_string_lossy(); let cloud_options = parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; - args.cloud_options = Some(cloud_options.with_max_retries(retries)); + args.cloud_options = Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ); } let lf = LazyFrame::scan_parquet_sources(sources, args).map_err(PyPolarsErr::from)?; @@ -332,7 +356,11 @@ impl PyLazyFrame { #[cfg(feature = "ipc")] #[staticmethod] - #[pyo3(signature = (source, sources, n_rows, cache, rechunk, row_index, cloud_options, hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, include_file_paths))] + #[pyo3(signature = ( + source, sources, n_rows, cache, rechunk, row_index, cloud_options,credential_provider, + hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, + include_file_paths + ))] fn new_from_ipc( source: Option, sources: Wrap, @@ -341,6 +369,7 @@ impl PyLazyFrame { rechunk: bool, row_index: Option<(String, IdxSize)>, cloud_options: Option>, + credential_provider: Option, hive_partitioning: Option, hive_schema: Option>, try_parse_hive_dates: bool, @@ -348,6 +377,7 @@ impl PyLazyFrame { file_cache_ttl: Option, include_file_paths: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; let row_index = row_index.map(|(name, offset)| RowIndex { name: name.into(), offset, @@ -386,7 +416,13 @@ impl PyLazyFrame { if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - args.cloud_options = Some(cloud_options.with_max_retries(retries)); + args.cloud_options = Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ); } let lf = LazyFrame::scan_ipc_sources(sources, args).map_err(PyPolarsErr::from)?; @@ -575,6 +611,7 @@ impl PyLazyFrame { Ok((df.into(), time_df.into())) } + #[pyo3(signature = (lambda_post_opt=None))] fn collect(&self, py: Python, lambda_post_opt: Option) -> PyResult { // if we don't allow threads and we have udfs trying to acquire the gil from different // threads we deadlock. @@ -805,7 +842,7 @@ impl PyLazyFrame { offset: &str, closed: Wrap, by: Vec, - ) -> PyLazyGroupBy { + ) -> PyResult { let closed_window = closed.0; let ldf = self.ldf.clone(); let by = by @@ -817,13 +854,13 @@ impl PyLazyFrame { by, RollingGroupOptions { index_column: "".into(), - period: Duration::parse(period), - offset: Duration::parse(offset), + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, closed_window, }, ); - PyLazyGroupBy { lgb: Some(lazy_gb) } + Ok(PyLazyGroupBy { lgb: Some(lazy_gb) }) } fn group_by_dynamic( @@ -837,7 +874,7 @@ impl PyLazyFrame { closed: Wrap, group_by: Vec, start_by: Wrap, - ) -> PyLazyGroupBy { + ) -> PyResult { let closed_window = closed.0; let group_by = group_by .into_iter() @@ -848,9 +885,9 @@ impl PyLazyFrame { index_column.inner, group_by, DynamicGroupOptions { - every: Duration::parse(every), - period: Duration::parse(period), - offset: Duration::parse(offset), + every: Duration::try_parse(every).map_err(PyPolarsErr::from)?, + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, label: label.0, include_boundaries, closed_window, @@ -859,7 +896,7 @@ impl PyLazyFrame { }, ); - PyLazyGroupBy { lgb: Some(lazy_gb) } + Ok(PyLazyGroupBy { lgb: Some(lazy_gb) }) } fn with_context(&self, contexts: Vec) -> Self { @@ -913,6 +950,7 @@ impl PyLazyFrame { .into()) } + #[pyo3(signature = (other, left_on, right_on, allow_parallel, force_parallel, join_nulls, how, suffix, validate, coalesce=None))] fn join( &self, other: Self, @@ -992,6 +1030,7 @@ impl PyLazyFrame { ldf.reverse().into() } + #[pyo3(signature = (n, fill_value=None))] fn shift(&self, n: PyExpr, fill_value: Option) -> Self { let lf = self.ldf.clone(); let out = match fill_value { @@ -1048,7 +1087,7 @@ impl PyLazyFrame { out.into() } - fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { let ldf = self.ldf.clone(); let out = ldf.quantile(quantile.inner, interpolation.0); out.into() @@ -1081,12 +1120,14 @@ impl PyLazyFrame { .into() } + #[pyo3(signature = (subset=None))] fn drop_nulls(&self, subset: Option>) -> Self { let ldf = self.ldf.clone(); let subset = subset.map(|e| e.to_exprs()); ldf.drop_nulls(subset).into() } + #[pyo3(signature = (offset, len=None))] fn slice(&self, offset: i64, len: Option) -> Self { let ldf = self.ldf.clone(); ldf.slice(offset, len.unwrap_or(IdxSize::MAX)).into() @@ -1116,6 +1157,7 @@ impl PyLazyFrame { ldf.unpivot(args).into() } + #[pyo3(signature = (name, offset=None))] fn with_row_index(&self, name: &str, offset: Option) -> Self { let ldf = self.ldf.clone(); ldf.with_row_index(name, offset).into() diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 32c8d3d23b7d..bc4cebb360a2 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (2, 2); + const VERSION: Version = (3, 1); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index e8832e9b5488..06a98e3fe970 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -2,9 +2,7 @@ use polars::datatypes::TimeUnit; #[cfg(feature = "iejoin")] use polars::prelude::InequalityOperator; use polars::series::ops::NullBehavior; -use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; use polars_core::series::IsSorted; -use polars_ops::prelude::ClosedInterval; use polars_ops::series::InterpolationMethod; #[cfg(feature = "search_sorted")] use polars_ops::series::SearchSortedSide; @@ -16,6 +14,7 @@ use polars_plan::prelude::{ WindowMapping, WindowType, }; use polars_time::prelude::RollingGroupOptions; +use polars_time::{Duration, DynamicGroupOptions}; use pyo3::exceptions::PyNotImplementedError; use pyo3::prelude::*; @@ -44,18 +43,6 @@ pub struct Literal { dtype: PyObject, } -impl IntoPy for Wrap { - fn into_py(self, py: Python<'_>) -> PyObject { - match self.0 { - ClosedInterval::Both => "both", - ClosedInterval::Left => "left", - ClosedInterval::Right => "right", - ClosedInterval::None => "none", - } - .into_py(py) - } -} - #[pyclass(name = "Operator")] #[derive(Copy, Clone)] pub enum PyOperator { @@ -174,6 +161,7 @@ pub enum PyStringFunction { ZFill, ContainsMany, ReplaceMany, + EscapeRegex, } #[pymethods] @@ -403,15 +391,25 @@ pub struct PyWindowMapping { impl PyWindowMapping { #[getter] fn kind(&self, py: Python<'_>) -> PyResult { - let result = match self.inner { - WindowMapping::GroupsToRows => "groups_to_rows".to_object(py), - WindowMapping::Explode => "explode".to_object(py), - WindowMapping::Join => "join".to_object(py), - }; + let result: &str = self.inner.into(); Ok(result.into_py(py)) } } +impl IntoPy for Wrap { + fn into_py(self, py: Python<'_>) -> PyObject { + ( + self.0.months(), + self.0.weeks(), + self.0.days(), + self.0.nanoseconds(), + self.0.parsed_int, + self.0.negative(), + ) + .into_py(py) + } +} + #[pyclass(name = "RollingGroupOptions")] pub struct PyRollingGroupOptions { inner: RollingGroupOptions, @@ -426,41 +424,68 @@ impl PyRollingGroupOptions { #[getter] fn period(&self, py: Python<'_>) -> PyResult { - let result = vec![ - self.inner.period.months().to_object(py), - self.inner.period.weeks().to_object(py), - self.inner.period.days().to_object(py), - self.inner.period.nanoseconds().to_object(py), - self.inner.period.parsed_int.to_object(py), - self.inner.period.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.period).into_py(py)) } #[getter] fn offset(&self, py: Python<'_>) -> PyResult { - let result = vec![ - self.inner.offset.months().to_object(py), - self.inner.offset.weeks().to_object(py), - self.inner.offset.days().to_object(py), - self.inner.offset.nanoseconds().to_object(py), - self.inner.offset.parsed_int.to_object(py), - self.inner.offset.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.offset).into_py(py)) } #[getter] fn closed_window(&self, py: Python<'_>) -> PyResult { - let result = match self.inner.closed_window { - polars::time::ClosedWindow::Left => "left".to_object(py), - polars::time::ClosedWindow::Right => "right".to_object(py), - polars::time::ClosedWindow::Both => "both".to_object(py), - polars::time::ClosedWindow::None => "none".to_object(py), - }; - Ok(result.into_py(py)) + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } +} + +#[pyclass(name = "DynamicGroupOptions")] +pub struct PyDynamicGroupOptions { + inner: DynamicGroupOptions, +} + +#[pymethods] +impl PyDynamicGroupOptions { + #[getter] + fn index_column(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.index_column.to_object(py)) + } + + #[getter] + fn every(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.every).into_py(py)) + } + + #[getter] + fn period(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.period).into_py(py)) + } + + #[getter] + fn offset(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.offset).into_py(py)) + } + + #[getter] + fn label(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.label.into(); + Ok(result.to_object(py)) + } + + #[getter] + fn include_boundaries(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.include_boundaries.into_py(py)) + } + + #[getter] + fn closed_window(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } + #[getter] + fn start_by(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.start_by.into(); + Ok(result.to_object(py)) } } @@ -485,6 +510,14 @@ impl PyGroupbyOptions { .map_or_else(|| py.None(), |f| f.to_object(py))) } + #[getter] + fn dynamic(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.dynamic.as_ref().map_or_else( + || py.None(), + |f| PyDynamicGroupOptions { inner: f.clone() }.into_py(py), + )) + } + #[getter] fn rolling(&self, py: Python<'_>) -> PyResult { Ok(self.inner.rolling.as_ref().map_or_else( @@ -700,18 +733,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Quantile { expr, quantile, - interpol, + method: interpol, } => Agg { name: "quantile".to_object(py), arguments: vec![expr.0, quantile.0], - options: match interpol { - QuantileInterpolOptions::Nearest => "nearest", - QuantileInterpolOptions::Lower => "lower", - QuantileInterpolOptions::Higher => "higher", - QuantileInterpolOptions::Midpoint => "midpoint", - QuantileInterpolOptions::Linear => "linear", - } - .to_object(py), + options: Into::<&str>::into(interpol).to_object(py), }, IRAggExpr::Sum(n) => Agg { name: "sum".to_object(py), @@ -741,12 +767,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Bitwise(n, f) => Agg { name: "bitwise".to_object(py), arguments: vec![n.0], - options: match f { - polars::prelude::BitwiseAggFunction::And => "and", - polars::prelude::BitwiseAggFunction::Or => "or", - polars::prelude::BitwiseAggFunction::Xor => "xor", - } - .to_object(py), + options: Into::<&str>::into(f).to_object(py), }, } .into_py(py), @@ -952,6 +973,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::ExtractMany { .. } => { return Err(PyNotImplementedError::new_err("extract_many")) }, + StringFunction::EscapeRegex => { + (PyStringFunction::EscapeRegex.into_py(py),).to_object(py) + }, }, FunctionExpr::StructExpr(_) => { return Err(PyNotImplementedError::new_err("struct expr")) @@ -1030,10 +1054,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { time_zone .as_ref() .map_or_else(|| py.None(), |s| s.to_object(py)), - match non_existent { - NonExistent::Null => "nullify", - NonExistent::Raise => "raise", - }, + Into::<&str>::into(non_existent), ) .into_py(py), TemporalFunction::Combine(time_unit) => { @@ -1073,7 +1094,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { BooleanFunction::IsUnique => (PyBooleanFunction::IsUnique,).into_py(py), BooleanFunction::IsDuplicated => (PyBooleanFunction::IsDuplicated,).into_py(py), BooleanFunction::IsBetween { closed } => { - (PyBooleanFunction::IsBetween, Wrap(*closed)).into_py(py) + (PyBooleanFunction::IsBetween, Into::<&str>::into(closed)).into_py(py) }, #[cfg(feature = "is_in")] BooleanFunction::IsIn => (PyBooleanFunction::IsIn,).into_py(py), @@ -1166,6 +1187,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { RollingFunction::Skew(_, _) => { return Err(PyNotImplementedError::new_err("rolling skew")) }, + RollingFunction::CorrCov { .. } => { + return Err(PyNotImplementedError::new_err("rolling cor_cov")) + }, }, FunctionExpr::RollingExprBy(rolling) => match rolling { RollingFunctionBy::MinBy(_) => { @@ -1202,6 +1226,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { #[cfg(feature = "repeat_by")] FunctionExpr::RepeatBy => ("repeat_by",).to_object(py), FunctionExpr::ArgUnique => ("arg_unique",).to_object(py), + FunctionExpr::Repeat => ("repeat",).to_object(py), FunctionExpr::Rank { options: _, seed: _, diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index ae805e7d0ff0..28c5e459b1e5 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{IdxSize, UniqueKeepStrategy}; +use polars_core::prelude::IdxSize; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; use polars_plan::prelude::{ @@ -273,7 +273,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { options .scan_fn .as_ref() - .map_or_else(|| py.None(), |s| s.0.clone()), + .map_or_else(|| py.None(), |s| s.0.clone_ref(py)), options.with_columns.as_ref().map_or_else( || py.None(), |cols| { @@ -454,7 +454,6 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { ))) })?, maintain_order: *maintain_order, - // TODO: dynamic options options: PyGroupbyOptions::new(options.as_ref().clone()).into_py(py), } .into_py(py), @@ -472,23 +471,16 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { right_on: right_on.iter().map(|e| e.into()).collect(), options: { let how = &options.args.how; - + let name = Into::<&str>::into(how).to_object(py); ( match how { - JoinType::Left => "left".to_object(py), - JoinType::Right => "right".to_object(py), - JoinType::Inner => "inner".to_object(py), - JoinType::Full => "full".to_object(py), #[cfg(feature = "asof_join")] JoinType::AsOf(_) => { return Err(PyNotImplementedError::new_err("asof join")) }, - JoinType::Cross => "cross".to_object(py), - JoinType::Semi => "leftsemi".to_object(py), - JoinType::Anti => "leftanti".to_object(py), #[cfg(feature = "iejoin")] JoinType::IEJoin(ie_options) => ( - "inequality".to_object(py), + name, crate::Wrap(ie_options.operator1).into_py(py), ie_options .operator2 @@ -496,10 +488,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { .map_or_else(|| py.None(), |op| crate::Wrap(*op).into_py(py)), ) .into_py(py), + _ => name, }, options.args.join_nulls, options.args.slice, - options.args.suffix.as_deref(), + options.args.suffix().as_str(), options.args.coalesce.coalesce(how), ) .to_object(py) @@ -529,12 +522,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { IR::Distinct { input, options } => Distinct { input: input.0, options: ( - match options.keep_strategy { - UniqueKeepStrategy::First => "first", - UniqueKeepStrategy::Last => "last", - UniqueKeepStrategy::None => "none", - UniqueKeepStrategy::Any => "any", - }, + Into::<&str>::into(options.keep_strategy), options.subset.as_ref().map_or_else( || py.None(), |f| { diff --git a/crates/polars-python/src/lazygroupby.rs b/crates/polars-python/src/lazygroupby.rs index 52df635efb53..d2ed68f3d568 100644 --- a/crates/polars-python/src/lazygroupby.rs +++ b/crates/polars-python/src/lazygroupby.rs @@ -34,6 +34,7 @@ impl PyLazyGroupBy { lgb.tail(Some(n)).into() } + #[pyo3(signature = (lambda, schema=None))] fn map_groups( &mut self, lambda: PyObject, diff --git a/crates/polars-python/src/map/mod.rs b/crates/polars-python/src/map/mod.rs index ef1bb4e34507..3bf96f91e631 100644 --- a/crates/polars-python/src/map/mod.rs +++ b/crates/polars-python/src/map/mod.rs @@ -122,10 +122,12 @@ fn iterator_to_struct<'a>( .collect::>() }); - Ok(StructChunked::from_series(name, fields.iter()) - .unwrap() - .into_series() - .into()) + Ok( + StructChunked::from_series(name, fields[0].len(), fields.iter()) + .unwrap() + .into_series() + .into(), + ) } fn iterator_to_primitive( @@ -255,8 +257,7 @@ fn iterator_to_list( name: PlSmallStr, capacity: usize, ) -> PyResult { - let mut builder = - get_list_builder(dt, capacity * 5, capacity, name).map_err(PyPolarsErr::from)?; + let mut builder = get_list_builder(dt, capacity * 5, capacity, name); for _ in 0..init_null_count { builder.append_null() } diff --git a/crates/polars-python/src/on_startup.rs b/crates/polars-python/src/on_startup.rs index 9b6f17d46f72..8c0b4275b4ba 100644 --- a/crates/polars-python/src/on_startup.rs +++ b/crates/polars-python/src/on_startup.rs @@ -88,7 +88,7 @@ pub fn register_startup_deps() { Box::new(object) as Box }); - let object_size = std::mem::size_of::(); + let object_size = size_of::(); let physical_dtype = ArrowDataType::FixedSizeBinary(object_size); registry::register_object_builder(object_builder, object_converter, physical_dtype); // register SERIES UDF diff --git a/crates/polars-python/src/series/aggregation.rs b/crates/polars-python/src/series/aggregation.rs index dbcbad59ddac..5aa8ee16639e 100644 --- a/crates/polars-python/src/series/aggregation.rs +++ b/crates/polars-python/src/series/aggregation.rs @@ -105,11 +105,7 @@ impl PySeries { .into_py(py)) } - fn quantile( - &self, - quantile: f64, - interpolation: Wrap, - ) -> PyResult { + fn quantile(&self, quantile: f64, interpolation: Wrap) -> PyResult { let bind = self.series.quantile_reduce(quantile, interpolation.0); let sc = bind.map_err(PyPolarsErr::from)?; diff --git a/crates/polars-python/src/series/buffers.rs b/crates/polars-python/src/series/buffers.rs index 30eddce08e39..939159220277 100644 --- a/crates/polars-python/src/series/buffers.rs +++ b/crates/polars-python/src/series/buffers.rs @@ -251,6 +251,7 @@ fn get_boolean_buffer_length_in_bytes(length: usize, offset: usize) -> usize { impl PySeries { /// Construct a PySeries from information about its underlying buffers. #[staticmethod] + #[pyo3(signature = (dtype, data, validity=None))] unsafe fn _from_buffers( dtype: Wrap, data: Vec, diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index 1398aa0c3dd1..f65822146d2c 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -497,6 +497,7 @@ impl PySeries { Ok(out.into()) } + #[pyo3(signature = (offset, length=None))] fn slice(&self, offset: i64, length: Option) -> Self { let length = length.unwrap_or_else(|| self.series.len()); self.series.slice(offset, length).into() @@ -523,6 +524,7 @@ macro_rules! impl_set_with_mask { #[pymethods] impl PySeries { + #[pyo3(signature = (filter, value))] fn $name(&self, filter: &PySeries, value: Option<$native>) -> PyResult { let series = $name(&self.series, filter, value).map_err(PyPolarsErr::from)?; Ok(Self::new(series)) diff --git a/crates/polars-python/src/series/map.rs b/crates/polars-python/src/series/map.rs index 9e3d9795a758..f003096f7dc9 100644 --- a/crates/polars-python/src/series/map.rs +++ b/crates/polars-python/src/series/map.rs @@ -232,7 +232,7 @@ impl PySeries { PyCFunction::new_closure_bound(py, None, None, move |args, _kwargs| { Python::with_gil(|py| { let out = function_owned.call1(py, args)?; - SERIES.call1(py, ("", out, dtype_py.clone())) + SERIES.call1(py, ("", out, dtype_py.clone_ref(py))) }) })? .to_object(py); diff --git a/crates/polars-python/src/series/numpy_ufunc.rs b/crates/polars-python/src/series/numpy_ufunc.rs index 10d765c3fc25..5df438686447 100644 --- a/crates/polars-python/src/series/numpy_ufunc.rs +++ b/crates/polars-python/src/series/numpy_ufunc.rs @@ -1,4 +1,4 @@ -use std::{mem, ptr}; +use std::ptr; use ndarray::IntoDimension; use numpy::npyffi::types::npy_intp; @@ -30,7 +30,7 @@ unsafe fn aligned_array( let buffer_ptr = buf.as_mut_ptr(); let mut dims = [len].into_dimension(); - let strides = [mem::size_of::() as npy_intp]; + let strides = [size_of::() as npy_intp]; let ptr = PY_ARRAY_API.PyArray_NewFromDescr( py, diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 858ce3f55fcf..04c320e33ec3 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -64,7 +64,10 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp .iter() .map(|struct_fld| decode(rows, field, struct_fld.dtype())) .collect(); - StructArray::new(dtype.clone(), values, None).to_boxed() + StructArray::new(dtype.clone(), rows.len(), values, None).to_boxed() + }, + ArrowDataType::List { .. } | ArrowDataType::LargeList { .. } => { + todo!("list decoding is not yet supported in polars' row encoding") }, dt => { with_match_arrow_primitive_type!(dt, |$T| { diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 363a43b6a9e8..fa9699c72f1b 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -230,7 +230,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE match encoder { Encoder::List { .. } => { let iter = encoder.list_iter(); - crate::variable::encode_iter(iter, out, &EncodingField::new_unsorted()) + crate::variable::encode_iter(iter, out, field) }, Encoder::Leaf(array) => { match array.dtype() { @@ -260,6 +260,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE .map(|opt_s| opt_s.map(|s| s.as_bytes())); crate::variable::encode_iter(iter, out, field) }, + ArrowDataType::Null => {}, // No output needed. dt => { with_match_arrow_primitive_type!(dt, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); @@ -286,6 +287,7 @@ pub fn encoded_size(dtype: &ArrowDataType) -> usize { Float32 => f32::ENCODED_LEN, Float64 => f64::ENCODED_LEN, Boolean => bool::ENCODED_LEN, + Null => 0, dt => unimplemented!("{dt:?}"), } } @@ -371,20 +373,13 @@ fn allocate_rows_buf( for opt_val in iter { unsafe { lengths.push_unchecked( - row_size_fixed - + crate::variable::encoded_len( - opt_val, - &EncodingField::new_unsorted(), - ), + row_size_fixed + crate::variable::encoded_len(opt_val, &field), ); } } } else { for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len( - opt_val, - &EncodingField::new_unsorted(), - ) + *row_length += crate::variable::encoded_len(opt_val, &field) } } processed_count += 1; @@ -637,7 +632,7 @@ mod test { let out = out.into_array(); assert_eq!( out.values().iter().map(|v| *v as usize).sum::(), - 82411 + 42774 ); } } diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index 7932420d1577..89f904e54296 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -25,7 +25,7 @@ impl FromSlice for [u8; N] { pub trait FixedLengthEncoding: Copy + Debug { // 1 is validity 0 or 1 // bit repr of encoding - const ENCODED_LEN: usize = 1 + std::mem::size_of::(); + const ENCODED_LEN: usize = 1 + size_of::(); type Encoded: Sized + Copy + AsRef<[u8]> + AsMut<[u8]>; diff --git a/crates/polars-row/src/lib.rs b/crates/polars-row/src/lib.rs index 823e5c6e4566..ddfbae9ea52b 100644 --- a/crates/polars-row/src/lib.rs +++ b/crates/polars-row/src/lib.rs @@ -120,6 +120,10 @@ //! This approach is loosely inspired by [COBS] encoding, and chosen over more traditional //! [byte stuffing] as it is more amenable to vectorisation, in particular AVX-256. //! +//! For the unordered row encoding we use a simpler scheme, we prepend the length +//! encoded as 4 bytes followed by the raw data, with nulls being marked with a +//! length of u32::MAX. +//! //! ## Dictionary Encoding //! //! [`RowsEncoded`] needs to support converting dictionary encoded arrays with unsorted, and diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index d48f6f51c205..1aa50e8b0e43 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -40,8 +40,8 @@ pub struct RowsEncoded { fn checks(offsets: &[usize]) { assert_eq!( - std::mem::size_of::(), - std::mem::size_of::(), + size_of::(), + size_of::(), "only supported on 64bit arch" ); assert!( diff --git a/crates/polars-row/src/variable.rs b/crates/polars-row/src/variable.rs index 5032e41085a8..2ccdf0f686f3 100644 --- a/crates/polars-row/src/variable.rs +++ b/crates/polars-row/src/variable.rs @@ -13,6 +13,7 @@ use std::mem::MaybeUninit; use arrow::array::{BinaryArray, BinaryViewArray, MutableBinaryViewArray}; +use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; use arrow::offset::Offsets; use polars_utils::slice::{GetSaferUnchecked, Slice2Uninit}; @@ -46,30 +47,12 @@ fn padded_length(a: usize) -> usize { 1 + ceil(a, BLOCK_SIZE) * (BLOCK_SIZE + 1) } -#[inline] -fn padded_length_opt(a: Option) -> usize { - if let Some(a) = a { - padded_length(a) - } else { - 1 - } -} - -#[inline] -fn length_opt(a: Option) -> usize { - if let Some(a) = a { - 1 + a - } else { - 1 - } -} - #[inline] pub fn encoded_len(a: Option<&[u8]>, field: &EncodingField) -> usize { if field.no_order { - length_opt(a.map(|v| v.len())) + 4 + a.map(|v| v.len()).unwrap_or(0) } else { - padded_length_opt(a.map(|v| v.len())) + a.map(|v| padded_length(v.len())).unwrap_or(1) } } @@ -78,30 +61,19 @@ unsafe fn encode_one_no_order( val: Option<&[MaybeUninit]>, field: &EncodingField, ) -> usize { + debug_assert!(field.no_order); match val { - Some([]) => { - let byte = if field.descending { - !EMPTY_SENTINEL - } else { - EMPTY_SENTINEL - }; - *out.get_unchecked_release_mut(0) = MaybeUninit::new(byte); - 1 - }, Some(val) => { - let end_offset = 1 + val.len(); - - // Write `2_u8` to demarcate as non-empty, non-null string - *out.get_unchecked_release_mut(0) = MaybeUninit::new(NON_EMPTY_SENTINEL); - std::ptr::copy_nonoverlapping(val.as_ptr(), out.as_mut_ptr().add(1), val.len()); - - end_offset + assert!(val.len() < u32::MAX as usize); + let encoded_len = (val.len() as u32).to_le_bytes().map(MaybeUninit::new); + std::ptr::copy_nonoverlapping(encoded_len.as_ptr(), out.as_mut_ptr(), 4); + std::ptr::copy_nonoverlapping(val.as_ptr(), out.as_mut_ptr().add(4), val.len()); + 4 + val.len() }, None => { - *out.get_unchecked_release_mut(0) = MaybeUninit::new(get_null_sentinel(field)); - // // write remainder as zeros - // out.get_unchecked_release_mut(1..).fill(MaybeUninit::new(0)); - 1 + let sentinel = u32::MAX.to_le_bytes().map(MaybeUninit::new); + std::ptr::copy_nonoverlapping(sentinel.as_ptr(), out.as_mut_ptr(), 4); + 4 }, } } @@ -258,7 +230,63 @@ unsafe fn decoded_len( } } +unsafe fn decoded_len_unordered(row: &[u8]) -> Option { + let len = u32::from_le_bytes(row.get_unchecked(0..4).try_into().unwrap()); + Some(len).filter(|l| *l < u32::MAX) +} + +unsafe fn decode_binary_unordered(rows: &mut [&[u8]]) -> BinaryArray { + let mut has_nulls = false; + let mut total_len = 0; + for row in rows.iter() { + if let Some(len) = decoded_len_unordered(row) { + total_len += len as usize; + } else { + has_nulls = true; + } + } + + let validity = has_nulls.then(|| { + Bitmap::from_trusted_len_iter_unchecked( + rows.iter().map(|row| decoded_len_unordered(row).is_none()), + ) + }); + + let mut values = Vec::with_capacity(total_len); + let mut offsets = Vec::with_capacity(rows.len() + 1); + offsets.push(0); + for row in rows.iter_mut() { + let len = decoded_len_unordered(row).unwrap_or(0) as usize; + values.extend_from_slice(row.get_unchecked(4..4 + len)); + *row = row.get_unchecked(4 + len..); + offsets.push(values.len() as i64); + } + BinaryArray::new( + ArrowDataType::LargeBinary, + Offsets::new_unchecked(offsets).into(), + values.into(), + validity, + ) +} + +unsafe fn decode_binview_unordered(rows: &mut [&[u8]]) -> BinaryViewArray { + let mut mutable = MutableBinaryViewArray::with_capacity(rows.len()); + for row in rows.iter_mut() { + if let Some(len) = decoded_len_unordered(row) { + mutable.push_value(row.get_unchecked(4..4 + len as usize)); + *row = row.get_unchecked(4 + len as usize..); + } else { + mutable.push_null(); + } + } + mutable.freeze() +} + pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &EncodingField) -> BinaryArray { + if field.no_order { + return decode_binary_unordered(rows); + } + let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { @@ -330,6 +358,10 @@ pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &EncodingField) -> } pub(super) unsafe fn decode_binview(rows: &mut [&[u8]], field: &EncodingField) -> BinaryViewArray { + if field.no_order { + return decode_binview_unordered(rows); + } + let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 9db54d1c3333..b5b875403b6d 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -24,7 +24,6 @@ rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlparser = { workspace = true } -# sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs.git", rev = "ae3b5844c839072c235965fe0d1bddc473dced87" } [dev-dependencies] # to display dataframes in case of test failures @@ -34,6 +33,7 @@ polars-core = { workspace = true, features = ["fmt"] } default = [] nightly = [] binary_encoding = ["polars-lazy/binary_encoding"] +bitwise = ["polars-lazy/bitwise"] csv = ["polars-lazy/csv"] diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 342a5e0883d2..1a060545439c 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -709,6 +709,11 @@ impl SQLContext { }; lf = if group_by_keys.is_empty() { + // The 'having' clause is only valid inside 'group by' + if select_stmt.having.is_some() { + polars_bail!(SQLSyntax: "HAVING clause not valid outside of GROUP BY; found:\n{:?}", select_stmt.having); + }; + // Final/selected cols, accounting for 'SELECT *' modifiers let mut retained_cols = Vec::with_capacity(projections.len()); let have_order_by = query.order_by.is_some(); diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 87b0656d171d..ff1a926f27e3 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -2,7 +2,9 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; -use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit}; +use polars_core::prelude::{ + polars_bail, polars_err, DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, +}; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] use polars_lazy::dsl::ListNameSpaceExtension; @@ -28,11 +30,40 @@ pub(crate) struct SQLFunctionVisitor<'a> { /// SQL functions that are supported by Polars pub(crate) enum PolarsSQLFunctions { + // ---- + // Bitwise functions + // ---- + /// SQL 'bit_and' function. + /// Returns the bitwise AND of the input expressions. + /// ```sql + /// SELECT BIT_AND(column_1, column_2) FROM df; + /// ``` + BitAnd, + /// SQL 'bit_count' function. + /// Returns the number of set bits in the input expression. + /// ```sql + /// SELECT BIT_COUNT(column_1) FROM df; + /// ``` + #[cfg(feature = "bitwise")] + BitCount, + /// SQL 'bit_or' function. + /// Returns the bitwise OR of the input expressions. + /// ```sql + /// SELECT BIT_OR(column_1, column_2) FROM df; + /// ``` + BitOr, + /// SQL 'bit_xor' function. + /// Returns the bitwise XOR of the input expressions. + /// ```sql + /// SELECT BIT_XOR(column_1, column_2) FROM df; + /// ``` + BitXor, + // ---- // Math functions // ---- /// SQL 'abs' function - /// Returns the absolute value of the input column. + /// Returns the absolute value of the input expression. /// ```sql /// SELECT ABS(column_1) FROM df; /// ``` @@ -140,97 +171,97 @@ pub(crate) enum PolarsSQLFunctions { // Trig functions // ---- /// SQL 'cos' function - /// Compute the cosine sine of the input column (in radians). + /// Compute the cosine sine of the input expression (in radians). /// ```sql /// SELECT COS(column_1) FROM df; /// ``` Cos, /// SQL 'cot' function - /// Compute the cotangent of the input column (in radians). + /// Compute the cotangent of the input expression (in radians). /// ```sql /// SELECT COT(column_1) FROM df; /// ``` Cot, /// SQL 'sin' function - /// Compute the sine of the input column (in radians). + /// Compute the sine of the input expression (in radians). /// ```sql /// SELECT SIN(column_1) FROM df; /// ``` Sin, /// SQL 'tan' function - /// Compute the tangent of the input column (in radians). + /// Compute the tangent of the input expression (in radians). /// ```sql /// SELECT TAN(column_1) FROM df; /// ``` Tan, /// SQL 'cosd' function - /// Compute the cosine sine of the input column (in degrees). + /// Compute the cosine sine of the input expression (in degrees). /// ```sql /// SELECT COSD(column_1) FROM df; /// ``` CosD, /// SQL 'cotd' function - /// Compute cotangent of the input column (in degrees). + /// Compute cotangent of the input expression (in degrees). /// ```sql /// SELECT COTD(column_1) FROM df; /// ``` CotD, /// SQL 'sind' function - /// Compute the sine of the input column (in degrees). + /// Compute the sine of the input expression (in degrees). /// ```sql /// SELECT SIND(column_1) FROM df; /// ``` SinD, /// SQL 'tand' function - /// Compute the tangent of the input column (in degrees). + /// Compute the tangent of the input expression (in degrees). /// ```sql /// SELECT TAND(column_1) FROM df; /// ``` TanD, /// SQL 'acos' function - /// Compute inverse cosinus of the input column (in radians). + /// Compute inverse cosinus of the input expression (in radians). /// ```sql /// SELECT ACOS(column_1) FROM df; /// ``` Acos, /// SQL 'asin' function - /// Compute inverse sine of the input column (in radians). + /// Compute inverse sine of the input expression (in radians). /// ```sql /// SELECT ASIN(column_1) FROM df; /// ``` Asin, /// SQL 'atan' function - /// Compute inverse tangent of the input column (in radians). + /// Compute inverse tangent of the input expression (in radians). /// ```sql /// SELECT ATAN(column_1) FROM df; /// ``` Atan, /// SQL 'atan2' function - /// Compute the inverse tangent of column_2/column_1 (in radians). + /// Compute the inverse tangent of column_1/column_2 (in radians). /// ```sql /// SELECT ATAN2(column_1, column_2) FROM df; /// ``` Atan2, /// SQL 'acosd' function - /// Compute inverse cosinus of the input column (in degrees). + /// Compute inverse cosinus of the input expression (in degrees). /// ```sql /// SELECT ACOSD(column_1) FROM df; /// ``` AcosD, /// SQL 'asind' function - /// Compute inverse sine of the input column (in degrees). + /// Compute inverse sine of the input expression (in degrees). /// ```sql /// SELECT ASIND(column_1) FROM df; /// ``` AsinD, /// SQL 'atand' function - /// Compute inverse tangent of the input column (in degrees). + /// Compute inverse tangent of the input expression (in degrees). /// ```sql /// SELECT ATAND(column_1) FROM df; /// ``` AtanD, /// SQL 'atan2d' function - /// Compute the inverse tangent of column_2/column_1 (in degrees). + /// Compute the inverse tangent of column_1/column_2 (in degrees). /// ```sql /// SELECT ATAN2D(column_1) FROM df; /// ``` @@ -504,6 +535,20 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT MEDIAN(column_1) FROM df; /// ``` Median, + /// SQL 'quantile_cont' function + /// Returns the continuous quantile element from the grouping + /// (interpolated value between two closest values). + /// ```sql + /// SELECT QUANTILE_CONT(column_1) FROM df; + /// ``` + QuantileCont, + /// SQL 'quantile_disc' function + /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, + /// and returns the value associated with the subinterval where the quantile value falls. + /// ```sql + /// SELECT QUANTILE_DISC(column_1) FROM df; + /// ``` + QuantileDisc, /// SQL 'min' function /// Returns the smallest (minimum) of all the elements in the grouping. /// ```sql @@ -640,7 +685,11 @@ impl PolarsSQLFunctions { "atan2d", "atand", "avg", + "bit_and", + "bit_count", "bit_length", + "bit_or", + "bit_xor", "cbrt", "ceil", "ceiling", @@ -679,6 +728,7 @@ impl PolarsSQLFunctions { "ltrim", "max", "median", + "quantile_disc", "min", "mod", "nullif", @@ -686,6 +736,8 @@ impl PolarsSQLFunctions { "pi", "pow", "power", + "quantile_cont", + "quantile_disc", "radians", "regexp_like", "replace", @@ -722,6 +774,15 @@ impl PolarsSQLFunctions { fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult { let function_name = function.name.0[0].value.to_lowercase(); Ok(match function_name.as_str() { + // ---- + // Bitwise functions + // ---- + "bit_and" | "bitand" => Self::BitAnd, + #[cfg(feature = "bitwise")] + "bit_count" | "bitcount" => Self::BitCount, + "bit_or" | "bitor" => Self::BitOr, + "bit_xor" | "bitxor" | "xor" => Self::BitXor, + // ---- // Math functions // ---- @@ -818,6 +879,8 @@ impl PolarsSQLFunctions { "last" => Self::Last, "max" => Self::Max, "median" => Self::Median, + "quantile_cont" => Self::QuantileCont, + "quantile_disc" => Self::QuantileDisc, "min" => Self::Min, "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev, "sum" => Self::Sum, @@ -873,6 +936,15 @@ impl SQLFunctionVisitor<'_> { } match function_name { + // ---- + // Bitwise functions + // ---- + BitAnd => self.visit_binary::(Expr::and), + #[cfg(feature = "bitwise")] + BitCount => self.visit_unary(Expr::bitwise_count_ones), + BitOr => self.visit_binary::(Expr::or), + BitXor => self.visit_binary::(Expr::xor), + // ---- // Math functions // ---- @@ -1243,6 +1315,58 @@ impl SQLFunctionVisitor<'_> { Last => self.visit_unary(Expr::last), Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max), Median => self.visit_unary(Expr::median), + QuantileCont => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1]) + }; + Ok(e.quantile(value, QuantileMethod::Linear)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()), + } + }, + QuantileDisc => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1]) + }; + Ok(e.quantile(value, QuantileMethod::Equiprobable)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()), + } + }, Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min), StdDev => self.visit_unary(|e| e.std(1)), Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum), diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index f9caa288cb82..5eb2bdd843b4 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -374,10 +374,9 @@ impl SQLExprVisitor<'_> { }, // identify "CAST(expr AS type) string" and/or "expr::type string" expressions (Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => { - if let Expr::Column(name) = &**expr { - (Some(name.clone()), Some(s), Some(dtype)) - } else { - (None, Some(s), Some(dtype)) + match &**expr { + Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)), + _ => (None, Some(s), Some(dtype)), } }, _ => (None, None, None), @@ -385,23 +384,25 @@ impl SQLExprVisitor<'_> { if expr_dtype.is_none() && self.active_schema.is_none() { right.clone() } else { - let left_dtype = expr_dtype - .unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap()); - + let left_dtype = expr_dtype.or_else(|| { + self.active_schema + .as_ref() + .and_then(|schema| schema.get(&name)) + }); match left_dtype { - DataType::Time if is_iso_time(s) => { + Some(DataType::Time) if is_iso_time(s) => { right.clone().str().to_time(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Date if is_iso_date(s) => { + Some(DataType::Date) if is_iso_date(s) => { right.clone().str().to_date(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => { + Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => { if s.len() == 10 { // handle upcast from ISO date string (10 chars) to datetime lit(format!("{}T00:00:00", s)) @@ -469,48 +470,53 @@ impl SQLExprVisitor<'_> { rhs = self.convert_temporal_strings(&lhs, &rhs); Ok(match op { - SQLBinaryOperator::And => lhs.and(rhs), - SQLBinaryOperator::Divide => lhs / rhs, - SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), - SQLBinaryOperator::Eq => lhs.eq(rhs), - SQLBinaryOperator::Gt => lhs.gt(rhs), - SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), - SQLBinaryOperator::Lt => lhs.lt(rhs), - SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), - SQLBinaryOperator::Minus => lhs - rhs, - SQLBinaryOperator::Modulo => lhs % rhs, - SQLBinaryOperator::Multiply => lhs * rhs, - SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), - SQLBinaryOperator::Or => lhs.or(rhs), - SQLBinaryOperator::Plus => lhs + rhs, - SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), - SQLBinaryOperator::StringConcat => { + // ---- + // Bitwise operators + // ---- + SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), // "x & y" + SQLBinaryOperator::BitwiseOr => lhs.or(rhs), // "x | y" + SQLBinaryOperator::Xor => lhs.xor(rhs), // "x XOR y" + + // ---- + // General operators + // ---- + SQLBinaryOperator::And => lhs.and(rhs), // "x AND y" + SQLBinaryOperator::Divide => lhs / rhs, // "x / y" + SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), // "x // y" + SQLBinaryOperator::Eq => lhs.eq(rhs), // "x = y" + SQLBinaryOperator::Gt => lhs.gt(rhs), // "x > y" + SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), // "x >= y" + SQLBinaryOperator::Lt => lhs.lt(rhs), // "x < y" + SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), // "x <= y" + SQLBinaryOperator::Minus => lhs - rhs, // "x - y" + SQLBinaryOperator::Modulo => lhs % rhs, // "x % y" + SQLBinaryOperator::Multiply => lhs * rhs, // "x * y" + SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), // "x != y" + SQLBinaryOperator::Or => lhs.or(rhs), // "x OR y" + SQLBinaryOperator::Plus => lhs + rhs, // "x + y" + SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), // "x <=> y" + SQLBinaryOperator::StringConcat => { // "x || y" lhs.cast(DataType::String) + rhs.cast(DataType::String) }, - SQLBinaryOperator::Xor => lhs.xor(rhs), - SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), + SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // "x ^@ y" // ---- // Regular expression operators // ---- - // "a ~ b" - SQLBinaryOperator::PGRegexMatch => match rhs { + SQLBinaryOperator::PGRegexMatch => match rhs { // "x ~ y" Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true), _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs), }, - // "a !~ b" - SQLBinaryOperator::PGRegexNotMatch => match rhs { + SQLBinaryOperator::PGRegexNotMatch => match rhs { // "x !~ y" Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(), _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs), }, - // "a ~* b" - SQLBinaryOperator::PGRegexIMatch => match rhs { + SQLBinaryOperator::PGRegexIMatch => match rhs { // "x ~* y" Expr::Literal(LiteralValue::String(pat)) => { lhs.str().contains(lit(format!("(?i){}", pat)), true) }, _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs), }, - // "a !~* b" - SQLBinaryOperator::PGRegexNotIMatch => match rhs { + SQLBinaryOperator::PGRegexNotIMatch => match rhs { // "x !~* y" Expr::Literal(LiteralValue::String(pat)) => { lhs.str().contains(lit(format!("(?i){}", pat)), true).not() }, @@ -521,10 +527,10 @@ impl SQLExprVisitor<'_> { // ---- // LIKE/ILIKE operators // ---- - SQLBinaryOperator::PGLikeMatch - | SQLBinaryOperator::PGNotLikeMatch - | SQLBinaryOperator::PGILikeMatch - | SQLBinaryOperator::PGNotILikeMatch => { + SQLBinaryOperator::PGLikeMatch // "x ~~ y" + | SQLBinaryOperator::PGNotLikeMatch // "x !~~ y" + | SQLBinaryOperator::PGILikeMatch // "x ~~* y" + | SQLBinaryOperator::PGNotILikeMatch => { // "x !~~* y" let expr = if matches!( op, SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch @@ -548,7 +554,7 @@ impl SQLExprVisitor<'_> { // ---- // JSON/Struct field access operators // ---- - SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { + SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { // "x -> y", "x ->> y" Expr::Literal(LiteralValue::String(path)) => { let mut expr = self.struct_field_access_expr(&lhs, &path, false)?; if let SQLBinaryOperator::LongArrow = op { @@ -567,7 +573,7 @@ impl SQLExprVisitor<'_> { polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right) }, }, - SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { + SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { // "x #> y", "x #>> y" if let Expr::Literal(LiteralValue::String(path)) = rhs { let mut expr = self.struct_field_access_expr(&lhs, &path, true)?; if let SQLBinaryOperator::HashLongArrow = op { diff --git a/crates/polars-sql/tests/functions_aggregate.rs b/crates/polars-sql/tests/functions_aggregate.rs new file mode 100644 index 000000000000..092a340f5f18 --- /dev/null +++ b/crates/polars-sql/tests/functions_aggregate.rs @@ -0,0 +1,121 @@ +use polars_core::prelude::*; +use polars_lazy::prelude::*; +use polars_plan::dsl::Expr; +use polars_sql::*; + +fn create_df() -> LazyFrame { + df! { + "Data" => [1000, 2000, 3000, 4000, 5000, 6000] + } + .unwrap() + .lazy() +} + +fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { + let df = create_df(); + let alias = "TEST"; + + let query = format!( + r#" + SELECT + {sql} as {alias} + FROM + df + "# + ); + + let expected = df + .clone() + .select(&[expr.alias(alias)]) + .sort([alias], Default::default()) + .collect() + .unwrap(); + let mut ctx = SQLContext::new(); + ctx.register("df", df); + + let actual = ctx.execute(&query).unwrap().collect().unwrap(); + (expected, actual) +} + +#[test] +fn test_median() { + let expr = col("Data").median(); + + let sql_expr = "MEDIAN(Data)"; + let (expected, actual) = create_expected(expr, sql_expr); + + assert!(expected.equals(&actual)) +} + +#[test] +fn test_quantile_cont() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Data").quantile(lit(q), QuantileMethod::Linear); + + let sql_expr = format!("QUANTILE_CONT(Data, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!( + expected.equals(&actual), + "q: {q}: expected {expected:?}, got {actual:?}" + ) + } +} + +#[test] +fn test_quantile_disc() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Data").quantile(lit(q), QuantileMethod::Equiprobable); + + let sql_expr = format!("QUANTILE_DISC(Data, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!(expected.equals(&actual)) + } +} + +#[test] +fn test_quantile_out_of_range() { + for &q in &["-1", "2", "-0.01", "1.01"] { + for &func in &["QUANTILE_CONT", "QUANTILE_DISC"] { + let query = format!("SELECT {func}(Data, {q})"); + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + let actual = ctx.execute(&query); + assert!(actual.is_err()) + } + } +} + +#[test] +fn test_quantile_disc_conformance() { + let expected = df![ + "q" => [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], + "Data" => [1000, 1000, 2000, 2000, 3000, 3000, 4000, 5000, 5000, 6000, 6000], + ] + .unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + + let mut actual: Option = None; + for &q in &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] { + let res = ctx + .execute(&format!( + "SELECT {q}::float as q, QUANTILE_DISC(Data, {q}) as Data FROM df" + )) + .unwrap() + .collect() + .unwrap(); + actual = if let Some(df) = actual { + Some(df.vstack(&res).unwrap()) + } else { + Some(res) + }; + } + + assert!( + expected.equals(actual.as_ref().unwrap()), + "expected {expected:?}, got {actual:?}" + ) +} diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index 78cdbc9115d0..fc130a035140 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -37,4 +37,11 @@ version_check = { workspace = true } [features] nightly = [] -bitwise = ["polars-core/bitwise", "polars-plan/bitwise"] +bitwise = ["polars-core/bitwise", "polars-plan/bitwise", "polars-expr/bitwise"] +merge_sorted = ["polars-plan/merge_sorted"] +dynamic_group_by = [] +strings = [] + +# We need to specify default features here to match workspace defaults. +# Otherwise we get warnings with cargo check/clippy. +default = ["bitwise"] diff --git a/crates/polars-stream/src/async_executor/mod.rs b/crates/polars-stream/src/async_executor/mod.rs index dec560845b09..23789e5a20df 100644 --- a/crates/polars-stream/src/async_executor/mod.rs +++ b/crates/polars-stream/src/async_executor/mod.rs @@ -1,12 +1,16 @@ +#![allow(clippy::disallowed_types)] + mod park_group; mod task; use std::cell::{Cell, UnsafeCell}; +use std::collections::HashMap; use std::future::Future; use std::marker::PhantomData; -use std::panic::AssertUnwindSafe; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::{Arc, OnceLock, Weak}; +use std::panic::{AssertUnwindSafe, Location}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, LazyLock, OnceLock, Weak}; +use std::time::Duration; use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue}; use crossbeam_utils::CachePadded; @@ -30,6 +34,27 @@ thread_local!( static TLS_THREAD_ID: Cell = const { Cell::new(usize::MAX) }; ); +static NS_SPENT_BLOCKED: LazyLock, u64>>> = + LazyLock::new(Mutex::default); + +static TRACK_WAIT_STATISTICS: AtomicBool = AtomicBool::new(false); + +pub fn track_task_wait_statistics(should_track: bool) { + TRACK_WAIT_STATISTICS.store(should_track, Ordering::Relaxed); +} + +pub fn get_task_wait_statistics() -> Vec<(&'static Location<'static>, Duration)> { + NS_SPENT_BLOCKED + .lock() + .iter() + .map(|(l, ns)| (*l, Duration::from_nanos(*ns))) + .collect() +} + +pub fn clear_task_wait_statistics() { + NS_SPENT_BLOCKED.lock().clear() +} + slotmap::new_key_type! { struct TaskKey; } @@ -48,6 +73,8 @@ struct ScopedTaskMetadata { } struct TaskMetadata { + spawn_location: &'static Location<'static>, + ns_spent_blocked: AtomicU64, priority: TaskPriority, freshly_spawned: AtomicBool, scoped: Option, @@ -55,6 +82,10 @@ struct TaskMetadata { impl Drop for TaskMetadata { fn drop(&mut self) { + *NS_SPENT_BLOCKED + .lock() + .entry(self.spawn_location) + .or_default() += self.ns_spent_blocked.load(Ordering::Relaxed); if let Some(scoped) = &self.scoped { if let Some(completed_tasks) = scoped.completed_tasks.upgrade() { completed_tasks.lock().push(scoped.task_key); @@ -182,6 +213,7 @@ impl Executor { let mut rng = SmallRng::from_rng(&mut rand::thread_rng()).unwrap(); let mut worker = self.park_group.new_worker(); + let mut last_block_start = None; loop { let ttl = &self.thread_task_lists[thread]; @@ -206,11 +238,23 @@ impl Executor { if let Some(task) = self.try_steal_task(thread, &mut rng) { return Some(task); } + + if last_block_start.is_none() && TRACK_WAIT_STATISTICS.load(Ordering::Relaxed) { + last_block_start = Some(std::time::Instant::now()); + } park.park(); None })(); if let Some(task) = task { + if let Some(t) = last_block_start.take() { + if TRACK_WAIT_STATISTICS.load(Ordering::Relaxed) { + let ns: u64 = t.elapsed().as_nanos().try_into().unwrap(); + task.metadata() + .ns_spent_blocked + .fetch_add(ns, Ordering::Relaxed); + } + } worker.recruit_next(); task.run(); } @@ -264,7 +308,7 @@ pub struct TaskScope<'scope, 'env: 'scope> { env: PhantomData<&'env mut &'env ()>, } -impl<'scope, 'env> TaskScope<'scope, 'env> { +impl<'scope> TaskScope<'scope, '_> { // Not Drop because that extends lifetimes. fn destroy(&self) { // Make sure all tasks are cancelled. @@ -280,6 +324,7 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { } } + #[track_caller] pub fn spawn_task( &self, priority: TaskPriority, @@ -288,6 +333,7 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { where ::Output: Send + 'static, { + let spawn_location = Location::caller(); self.clear_completed_tasks(); let mut runnable = None; @@ -301,6 +347,8 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { fut, on_wake, TaskMetadata { + spawn_location, + ns_spent_blocked: AtomicU64::new(0), priority, freshly_spawned: AtomicBool::new(true), scoped: Some(ScopedTaskMetadata { @@ -345,16 +393,20 @@ where } } +#[track_caller] pub fn spawn(priority: TaskPriority, fut: F) -> JoinHandle where ::Output: Send + 'static, { + let spawn_location = Location::caller(); let executor = Executor::global(); let on_wake = move |task| executor.schedule_task(task); let (runnable, join_handle) = task::spawn( fut, on_wake, TaskMetadata { + spawn_location, + ns_spent_blocked: AtomicU64::new(0), priority, freshly_spawned: AtomicBool::new(true), scoped: None, diff --git a/crates/polars-stream/src/async_executor/park_group.rs b/crates/polars-stream/src/async_executor/park_group.rs index d9da30ce7f3e..d72a474da1e4 100644 --- a/crates/polars-stream/src/async_executor/park_group.rs +++ b/crates/polars-stream/src/async_executor/park_group.rs @@ -149,7 +149,7 @@ pub struct ParkAttempt<'a> { worker: &'a mut ParkGroupWorker, } -impl<'a> ParkAttempt<'a> { +impl ParkAttempt<'_> { /// Actually park this worker. /// /// If there were calls to unpark between calling prepare_park() and park(), diff --git a/crates/polars-stream/src/async_executor/task.rs b/crates/polars-stream/src/async_executor/task.rs index 9991377eb718..1383da2edde8 100644 --- a/crates/polars-stream/src/async_executor/task.rs +++ b/crates/polars-stream/src/async_executor/task.rs @@ -118,9 +118,9 @@ where } } -impl<'a, F, S, M> Wake for Task +impl Wake for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -143,9 +143,9 @@ pub trait DynTask: Send + Sync { fn schedule(self: Arc); } -impl<'a, F, S, M> DynTask for Task +impl DynTask for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -202,9 +202,9 @@ trait Joinable: Send + Sync { fn poll_join(&self, ctx: &mut Context<'_>) -> Poll; } -impl<'a, F, S, M> Joinable for Task +impl Joinable for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -233,9 +233,9 @@ trait Cancellable: Send + Sync { fn cancel(&self); } -impl<'a, F, S, M> Cancellable for Task +impl Cancellable for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Send + Sync + 'static, M: Send + Sync + 'static, diff --git a/crates/polars-stream/src/async_primitives/connector.rs b/crates/polars-stream/src/async_primitives/connector.rs index 94999fff4e7a..8b53193b95f1 100644 --- a/crates/polars-stream/src/async_primitives/connector.rs +++ b/crates/polars-stream/src/async_primitives/connector.rs @@ -217,7 +217,7 @@ pin_project! { } } -unsafe impl<'a, T: Send> Send for SendFuture<'a, T> {} +unsafe impl Send for SendFuture<'_, T> {} impl Sender { /// Returns a future that when awaited will send the value to the [`Receiver`]. @@ -255,7 +255,7 @@ pin_project! { } } -unsafe impl<'a, T: Send> Send for RecvFuture<'a, T> {} +unsafe impl Send for RecvFuture<'_, T> {} impl Receiver { /// Returns a future that when awaited will return `Ok(value)` once the diff --git a/crates/polars-stream/src/async_primitives/task_parker.rs b/crates/polars-stream/src/async_primitives/task_parker.rs index 9e48b79e468b..d6cde679980b 100644 --- a/crates/polars-stream/src/async_primitives/task_parker.rs +++ b/crates/polars-stream/src/async_primitives/task_parker.rs @@ -43,7 +43,7 @@ pub struct TaskParkFuture<'a> { parker: &'a TaskParker, } -impl<'a> Future for TaskParkFuture<'a> { +impl Future for TaskParkFuture<'_> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/crates/polars-stream/src/async_primitives/wait_group.rs b/crates/polars-stream/src/async_primitives/wait_group.rs index 716363528505..e08f556d3b95 100644 --- a/crates/polars-stream/src/async_primitives/wait_group.rs +++ b/crates/polars-stream/src/async_primitives/wait_group.rs @@ -62,7 +62,7 @@ impl Future for WaitGroupFuture<'_> { } } -impl<'a> Drop for WaitGroupFuture<'a> { +impl Drop for WaitGroupFuture<'_> { fn drop(&mut self) { self.inner.is_waiting.store(false, Ordering::Relaxed); } diff --git a/crates/polars-stream/src/execute.rs b/crates/polars-stream/src/execute.rs index b199c0044e92..2d68cae2c90e 100644 --- a/crates/polars-stream/src/execute.rs +++ b/crates/polars-stream/src/execute.rs @@ -220,12 +220,21 @@ fn run_subgraph( } // Wait until all tasks are done. - polars_io::pl_async::get_runtime().block_on(async move { + // Only now do we turn on/off wait statistics tracking to reduce noise + // from task startup. + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + async_executor::track_task_wait_statistics(true); + } + let ret = polars_io::pl_async::get_runtime().block_on(async move { for handle in join_handles { handle.await?; } PolarsResult::Ok(()) - }) + }); + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + async_executor::track_task_wait_statistics(false); + } + ret })?; Ok(()) diff --git a/crates/polars-stream/src/nodes/filter.rs b/crates/polars-stream/src/nodes/filter.rs index 9f0b0301ef91..f89a53adbf23 100644 --- a/crates/polars-stream/src/nodes/filter.rs +++ b/crates/polars-stream/src/nodes/filter.rs @@ -27,14 +27,14 @@ impl ComputeNode for FilterNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/group_by.rs b/crates/polars-stream/src/nodes/group_by.rs new file mode 100644 index 000000000000..c534924d1433 --- /dev/null +++ b/crates/polars-stream/src/nodes/group_by.rs @@ -0,0 +1,290 @@ +use std::mem::ManuallyDrop; +use std::sync::Arc; + +use polars_core::prelude::IntoColumn; +use polars_core::schema::Schema; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_expr::groups::Grouper; +use polars_expr::reduce::GroupedReduction; +use polars_utils::itertools::Itertools; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use super::compute_node_prelude::*; +use crate::async_primitives::connector::Receiver; +use crate::expression::StreamExpr; +use crate::nodes::in_memory_source::InMemorySourceNode; + +struct LocalGroupBySinkState { + grouper: Box, + grouped_reductions: Vec>, +} + +struct GroupBySinkState { + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouper: Box, + grouped_reductions: Vec>, + local: Vec, +} + +impl GroupBySinkState { + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + receivers: Vec>, + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(receivers.len() >= self.local.len()); + self.local + .resize_with(receivers.len(), || LocalGroupBySinkState { + grouper: self.grouper.new_empty(), + grouped_reductions: self + .grouped_reductions + .iter() + .map(|r| r.new_empty()) + .collect(), + }); + for (mut recv, local) in receivers.into_iter().zip(&mut self.local) { + let key_selectors = &self.key_selectors; + let grouped_reduction_selectors = &self.grouped_reduction_selectors; + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + let mut group_idxs = Vec::new(); + while let Ok(morsel) = recv.recv().await { + // Compute group indices from key. + let df = morsel.into_df(); + let mut key_columns = Vec::new(); + for selector in key_selectors { + let s = selector.evaluate(&df, state).await?; + key_columns.push(s.into_column()); + } + let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; + local.grouper.insert_keys(&keys, &mut group_idxs); + + // Update reductions. + for (selector, reduction) in grouped_reduction_selectors + .iter() + .zip(&mut local.grouped_reductions) + { + unsafe { + // SAFETY: we resize the reduction to the number of groups beforehand. + reduction.resize(local.grouper.num_groups()); + reduction.update_groups( + &selector.evaluate(&df, state).await?, + &group_idxs, + )?; + } + } + } + Ok(()) + })); + } + } + + fn combine_locals( + output_schema: &Schema, + mut locals: Vec, + ) -> PolarsResult { + let mut group_idxs = Vec::new(); + let mut combined = locals.pop().unwrap(); + for local in locals { + combined.grouper.combine(&*local.grouper, &mut group_idxs); + for (l, r) in combined + .grouped_reductions + .iter_mut() + .zip(&local.grouped_reductions) + { + unsafe { + l.resize(combined.grouper.num_groups()); + l.combine(&**r, &group_idxs)?; + } + } + } + let mut out = combined.grouper.get_keys_in_group_order(); + let out_names = output_schema.iter_names().skip(out.width()); + for (mut r, name) in combined.grouped_reductions.into_iter().zip(out_names) { + unsafe { + out.with_column_unchecked(r.finalize()?.with_name(name.clone()).into_column()); + } + } + Ok(out) + } + + fn into_source_parallel(self, output_schema: &Schema) -> PolarsResult { + let num_partitions = self.local.len(); + let seed = 0xdeadbeef; + let partitioned_locals: Vec<_> = self + .local + .into_par_iter() + .with_max_len(1) + .map(|local| { + let mut partition_idxs = Vec::new(); + let p_groupers = local + .grouper + .partition(seed, num_partitions, &mut partition_idxs); + let partition_sizes = p_groupers.iter().map(|g| g.num_groups()).collect_vec(); + let grouped_reductions_p = local + .grouped_reductions + .into_iter() + .map(|r| unsafe { r.partition(&partition_sizes, &partition_idxs) }) + .collect_vec(); + (p_groupers, grouped_reductions_p) + }) + .collect(); + + let frames = unsafe { + let mut partitioned_locals = ManuallyDrop::new(partitioned_locals); + let partitioned_locals_ptr = SyncPtr::new(partitioned_locals.as_mut_ptr()); + (0..num_partitions) + .into_par_iter() + .with_max_len(1) + .map(|p| { + let locals_in_p = (0..num_partitions) + .map(|l| { + let partitioned_local = &*partitioned_locals_ptr.get().add(l); + let (p_groupers, grouped_reductions_p) = partitioned_local; + LocalGroupBySinkState { + grouper: p_groupers.as_ptr().add(p).read(), + grouped_reductions: grouped_reductions_p + .iter() + .map(|r| r.as_ptr().add(p).read()) + .collect(), + } + }) + .collect(); + Self::combine_locals(output_schema, locals_in_p) + }) + .collect::>>() + }; + + let df = accumulate_dataframes_vertical_unchecked(frames?); + let mut source_node = InMemorySourceNode::new(Arc::new(df)); + source_node.initialize(num_partitions); + Ok(source_node) + } + + fn into_source(self, output_schema: &Schema) -> PolarsResult { + if std::env::var("POLARS_PARALLEL_GROUPBY_FINALIZE").as_deref() == Ok("1") { + self.into_source_parallel(output_schema) + } else { + let num_pipelines = self.local.len(); + let df = Self::combine_locals(output_schema, self.local); + let mut source_node = InMemorySourceNode::new(Arc::new(df?)); + source_node.initialize(num_pipelines); + Ok(source_node) + } + } +} + +enum GroupByState { + Sink(GroupBySinkState), + Source(InMemorySourceNode), + Done, +} + +pub struct GroupByNode { + state: GroupByState, + output_schema: Arc, +} + +impl GroupByNode { + pub fn new( + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouped_reductions: Vec>, + grouper: Box, + output_schema: Arc, + ) -> Self { + Self { + state: GroupByState::Sink(GroupBySinkState { + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + local: Vec::new(), + }), + output_schema, + } + } +} + +impl ComputeNode for GroupByNode { + fn name(&self) -> &str { + "group_by" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 1 && send.len() == 1); + + // State transitions. + match &mut self.state { + // If the output doesn't want any more data, transition to being done. + _ if send[0] == PortState::Done => { + self.state = GroupByState::Done; + }, + // Input is done, transition to being a source. + GroupByState::Sink(_) if matches!(recv[0], PortState::Done) => { + let GroupByState::Sink(sink) = + core::mem::replace(&mut self.state, GroupByState::Done) + else { + unreachable!() + }; + self.state = GroupByState::Source(sink.into_source(&self.output_schema)?); + }, + // Defer to source node implementation. + GroupByState::Source(src) => { + src.update_state(&mut [], send)?; + if send[0] == PortState::Done { + self.state = GroupByState::Done; + } + }, + // Nothing to change. + GroupByState::Done | GroupByState::Sink(_) => {}, + } + + // Communicate our state. + match &self.state { + GroupByState::Sink { .. } => { + send[0] = PortState::Blocked; + recv[0] = PortState::Ready; + }, + GroupByState::Source(..) => { + recv[0] = PortState::Done; + send[0] = PortState::Ready; + }, + GroupByState::Done => { + recv[0] = PortState::Done; + send[0] = PortState::Done; + }, + } + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(send_ports.len() == 1 && recv_ports.len() == 1); + match &mut self.state { + GroupByState::Sink(sink) => { + assert!(send_ports[0].is_none()); + sink.spawn( + scope, + recv_ports[0].take().unwrap().parallel(), + state, + join_handles, + ) + }, + GroupByState::Source(source) => { + assert!(recv_ports[0].is_none()); + source.spawn(scope, &mut [], send_ports, state, join_handles); + }, + GroupByState::Done => unreachable!(), + } + } +} diff --git a/crates/polars-stream/src/nodes/in_memory_map.rs b/crates/polars-stream/src/nodes/in_memory_map.rs index 3a8bff496a18..27af6be9aa87 100644 --- a/crates/polars-stream/src/nodes/in_memory_map.rs +++ b/crates/polars-stream/src/nodes/in_memory_map.rs @@ -86,16 +86,16 @@ impl ComputeNode for InMemoryMapNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { match self { Self::Sink { sink_node, .. } => { - sink_node.spawn(scope, recv, &mut [], state, join_handles) + sink_node.spawn(scope, recv_ports, &mut [], state, join_handles) }, - Self::Source(source) => source.spawn(scope, &mut [], send, state, join_handles), + Self::Source(source) => source.spawn(scope, &mut [], send_ports, state, join_handles), Self::Done => unreachable!(), } } diff --git a/crates/polars-stream/src/nodes/in_memory_sink.rs b/crates/polars-stream/src/nodes/in_memory_sink.rs index afd6ccfd95cc..58d2f9e8ffe6 100644 --- a/crates/polars-stream/src/nodes/in_memory_sink.rs +++ b/crates/polars-stream/src/nodes/in_memory_sink.rs @@ -45,13 +45,13 @@ impl ComputeNode for InMemorySinkNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.is_empty()); - let receivers = recv[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.is_empty()); + let receivers = recv_ports[0].take().unwrap().parallel(); for mut recv in receivers { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs index 5ab6b0f75d50..c8dfec9d0032 100644 --- a/crates/polars-stream/src/nodes/in_memory_source.rs +++ b/crates/polars-stream/src/nodes/in_memory_source.rs @@ -60,13 +60,13 @@ impl ComputeNode for InMemorySourceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.is_empty() && send.len() == 1); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.is_empty() && send_ports.len() == 1); + let senders = send_ports[0].take().unwrap().parallel(); let source = self.source.as_ref().unwrap(); // TODO: can this just be serial, using the work distributor? diff --git a/crates/polars-stream/src/nodes/input_independent_select.rs b/crates/polars-stream/src/nodes/input_independent_select.rs index f1a9113d05d4..9df4c1ab5281 100644 --- a/crates/polars-stream/src/nodes/input_independent_select.rs +++ b/crates/polars-stream/src/nodes/input_independent_select.rs @@ -36,13 +36,13 @@ impl ComputeNode for InputIndependentSelectNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.is_empty() && send.len() == 1); - let mut sender = send[0].take().unwrap().serial(); + assert!(recv_ports.is_empty() && send_ports.len() == 1); + let mut sender = send_ports[0].take().unwrap().serial(); join_handles.push(scope.spawn_task(TaskPriority::Low, async move { let empty_df = DataFrame::empty(); diff --git a/crates/polars-stream/src/nodes/io_sinks/ipc.rs b/crates/polars-stream/src/nodes/io_sinks/ipc.rs new file mode 100644 index 000000000000..5587221d894c --- /dev/null +++ b/crates/polars-stream/src/nodes/io_sinks/ipc.rs @@ -0,0 +1,81 @@ +use std::fs::{File, OpenOptions}; +use std::path::Path; + +use polars_core::schema::SchemaRef; +use polars_error::PolarsResult; +use polars_expr::state::ExecutionState; +use polars_io::ipc::{BatchedWriter, IpcWriter, IpcWriterOptions}; +use polars_io::SerWriter; + +use crate::nodes::{ComputeNode, JoinHandle, PortState, TaskPriority, TaskScope}; +use crate::pipe::{RecvPort, SendPort}; + +pub struct IpcSinkNode { + is_finished: bool, + writer: BatchedWriter, +} + +impl IpcSinkNode { + pub fn new( + input_schema: SchemaRef, + path: &Path, + write_options: &IpcWriterOptions, + ) -> PolarsResult { + let file = OpenOptions::new().write(true).open(path)?; + let writer = IpcWriter::new(file) + .with_compression(write_options.compression) + .batched(&input_schema)?; + + Ok(Self { + is_finished: false, + writer, + }) + } +} + +impl ComputeNode for IpcSinkNode { + fn name(&self) -> &str { + "ipc_sink" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(send.is_empty()); + assert!(recv.len() == 1); + + if recv[0] == PortState::Done && !self.is_finished { + // @NOTE: This function can be called afterwards multiple times. So make sure to only + // finish the writer once. + self.is_finished = true; + self.writer.finish()?; + } + + // We are always ready to receive, unless the sender is done, then we're + // also done. + if recv[0] != PortState::Done { + recv[0] = PortState::Ready; + } + + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + _state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(send_ports.is_empty()); + assert!(recv_ports.len() == 1); + let mut receiver = recv_ports[0].take().unwrap().serial(); + + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + while let Ok(morsel) = receiver.recv().await { + self.writer.write_batch(&morsel.into_df())?; + } + + Ok(()) + })); + } +} diff --git a/crates/polars-stream/src/nodes/io_sinks/mod.rs b/crates/polars-stream/src/nodes/io_sinks/mod.rs new file mode 100644 index 000000000000..ce14ad3b0f7a --- /dev/null +++ b/crates/polars-stream/src/nodes/io_sinks/mod.rs @@ -0,0 +1 @@ +pub mod ipc; diff --git a/crates/polars-stream/src/nodes/map.rs b/crates/polars-stream/src/nodes/map.rs index 007dfa921672..c1994d1e4a9a 100644 --- a/crates/polars-stream/src/nodes/map.rs +++ b/crates/polars-stream/src/nodes/map.rs @@ -29,14 +29,14 @@ impl ComputeNode for MapNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index 82ad0f8293e9..559e4717c4e9 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -1,8 +1,10 @@ pub mod filter; +pub mod group_by; pub mod in_memory_map; pub mod in_memory_sink; pub mod in_memory_source; pub mod input_independent_select; +pub mod io_sinks; pub mod map; pub mod multiplexer; pub mod ordered_union; @@ -61,8 +63,8 @@ pub trait ComputeNode: Send { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ); diff --git a/crates/polars-stream/src/nodes/multiplexer.rs b/crates/polars-stream/src/nodes/multiplexer.rs index 65f2e752d28d..d4e4ac62cf01 100644 --- a/crates/polars-stream/src/nodes/multiplexer.rs +++ b/crates/polars-stream/src/nodes/multiplexer.rs @@ -92,13 +92,13 @@ impl ComputeNode for MultiplexerNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && !send.is_empty()); - assert!(self.buffers.len() == send.len()); + assert!(recv_ports.len() == 1 && !send_ports.is_empty()); + assert!(self.buffers.len() == send_ports.len()); enum Listener<'a> { Active(UnboundedSender), @@ -114,7 +114,7 @@ impl ComputeNode for MultiplexerNode { .enumerate() .map(|(port_idx, buffer)| { if let BufferedStream::Open(buf) = buffer { - if send[port_idx].is_some() { + if send_ports[port_idx].is_some() { // TODO: replace with a bounded channel and store data // out-of-core beyond a certain size. let (rx, tx) = unbounded_channel(); @@ -129,7 +129,7 @@ impl ComputeNode for MultiplexerNode { .unzip(); // TODO: parallel multiplexing. - if let Some(mut receiver) = recv[0].take().map(|r| r.serial()) { + if let Some(mut receiver) = recv_ports[0].take().map(|r| r.serial()) { let buffered_source_token = buffered_source_token.clone(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { loop { @@ -176,7 +176,7 @@ impl ComputeNode for MultiplexerNode { })); } - for (send_port, opt_buf_recv) in send.iter_mut().zip(buf_receivers) { + for (send_port, opt_buf_recv) in send_ports.iter_mut().zip(buf_receivers) { if let Some((buf, mut rx)) = opt_buf_recv { let mut sender = send_port.take().unwrap().serial(); diff --git a/crates/polars-stream/src/nodes/ordered_union.rs b/crates/polars-stream/src/nodes/ordered_union.rs index 3c72d9cc6e15..cb65175292e2 100644 --- a/crates/polars-stream/src/nodes/ordered_union.rs +++ b/crates/polars-stream/src/nodes/ordered_union.rs @@ -52,15 +52,15 @@ impl ComputeNode for OrderedUnionNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - let ready_count = recv.iter().filter(|r| r.is_some()).count(); - assert!(ready_count == 1 && send.len() == 1); - let receivers = recv[self.cur_input_idx].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + let ready_count = recv_ports.iter().filter(|r| r.is_some()).count(); + assert!(ready_count == 1 && send_ports.len() == 1); + let receivers = recv_ports[self.cur_input_idx].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); let mut inner_handles = Vec::new(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { diff --git a/crates/polars-stream/src/nodes/parquet_source/init.rs b/crates/polars-stream/src/nodes/parquet_source/init.rs index 3187bbe797e4..a722186ff497 100644 --- a/crates/polars-stream/src/nodes/parquet_source/init.rs +++ b/crates/polars-stream/src/nodes/parquet_source/init.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::future::Future; use std::sync::Arc; @@ -14,7 +15,6 @@ use super::{AsyncTaskData, ParquetSourceNode}; use crate::async_executor; use crate::async_primitives::connector::connector; use crate::async_primitives::wait_group::{WaitGroup, WaitToken}; -use crate::morsel::get_ideal_morsel_size; use crate::nodes::{MorselSeq, TaskPriority}; use crate::utils::task_handles_ext; @@ -118,6 +118,8 @@ impl ParquetSourceNode { let row_group_decoder = self.init_row_group_decoder(); let row_group_decoder = Arc::new(row_group_decoder); + let ideal_morsel_size = self.config.ideal_morsel_size; + // Distributes morsels across pipelines. This does not perform any CPU or I/O bound work - // it is purely a dispatch loop. let raw_morsel_distributor_task_handle = io_runtime.spawn(async move { @@ -191,25 +193,31 @@ impl ParquetSourceNode { ); let morsel_seq_ref = &mut MorselSeq::default(); - let mut dfs = vec![].into_iter(); + let mut dfs = VecDeque::with_capacity(1); 'main: loop { let Some(mut indexed_wait_group) = wait_groups.next().await else { break; }; - if dfs.len() == 0 { + while dfs.is_empty() { let Some(v) = df_stream.next().await else { - break; + break 'main; }; - let v = v?; - assert!(!v.is_empty()); + let df = v?; + + if df.is_empty() { + continue; + } - dfs = v.into_iter(); + let (iter, n) = split_to_morsels(&df, ideal_morsel_size); + + dfs.reserve(n); + dfs.extend(iter); } - let mut df = dfs.next().unwrap(); + let mut df = dfs.pop_front().unwrap(); let morsel_seq = *morsel_seq_ref; *morsel_seq_ref = morsel_seq.successor(); @@ -270,7 +278,6 @@ impl ParquetSourceNode { let projected_arrow_schema = self.projected_arrow_schema.clone().unwrap(); let row_index = self.file_options.row_index.clone(); let physical_predicate = self.physical_predicate.clone(); - let ideal_morsel_size = get_ideal_morsel_size(); let min_values_per_thread = self.config.min_values_per_thread; let mut use_prefiltered = physical_predicate.is_some() @@ -348,7 +355,6 @@ impl ParquetSourceNode { predicate_arrow_field_indices, non_predicate_arrow_field_indices, predicate_arrow_field_mask, - ideal_morsel_size, min_values_per_thread, } } @@ -402,6 +408,28 @@ fn filtered_range(exclude: &[usize], len: usize) -> Vec { .collect() } +/// Note: The 2nd return is an upper bound on the number of morsels rather than an exact count. +fn split_to_morsels( + df: &DataFrame, + ideal_morsel_size: usize, +) -> (impl Iterator + '_, usize) { + let n_morsels = if df.height() > 3 * ideal_morsel_size / 2 { + // num_rows > (1.5 * ideal_morsel_size) + (df.height() / ideal_morsel_size).max(2) + } else { + 1 + }; + + let rows_per_morsel = 1 + df.height() / n_morsels; + + ( + (0..i64::try_from(df.height()).unwrap()) + .step_by(rows_per_morsel) + .map(move |offset| df.slice(offset, rows_per_morsel)), + n_morsels, + ) +} + mod tests { #[test] diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index 746c517ce744..e3377036b908 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -141,7 +141,7 @@ impl ParquetSourceNode { } if allow_missing_columns { - ensure_matching_dtypes_if_found(&first_schema, &schema)?; + ensure_matching_dtypes_if_found(projected_arrow_schema.as_ref(), &schema)?; } else { ensure_schema_has_projected_fields( &schema, diff --git a/crates/polars-stream/src/nodes/parquet_source/mod.rs b/crates/polars-stream/src/nodes/parquet_source/mod.rs index 44fc4e1f1239..a5efa4cb3b89 100644 --- a/crates/polars-stream/src/nodes/parquet_source/mod.rs +++ b/crates/polars-stream/src/nodes/parquet_source/mod.rs @@ -18,7 +18,7 @@ use polars_plan::prelude::FileScanOptions; use super::compute_node_prelude::*; use super::{MorselSeq, TaskPriority}; use crate::async_primitives::wait_group::WaitToken; -use crate::morsel::SourceToken; +use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::utils::task_handles_ext; mod init; @@ -70,6 +70,7 @@ struct Config { /// Minimum number of values for a parallel spawned task to process to amortize /// parallelism overhead. min_values_per_thread: usize, + ideal_morsel_size: usize, } #[allow(clippy::too_many_arguments)] @@ -110,6 +111,7 @@ impl ParquetSourceNode { metadata_decode_ahead_size: 0, row_group_prefetch_size: 0, min_values_per_thread: 0, + ideal_morsel_size: 0, }, verbose, physical_predicate: None, @@ -142,6 +144,7 @@ impl ComputeNode for ParquetSourceNode { let min_values_per_thread = std::env::var("POLARS_MIN_VALUES_PER_THREAD") .map(|x| x.parse::().expect("integer").max(1)) .unwrap_or(16_777_216); + let ideal_morsel_size = get_ideal_morsel_size(); Config { num_pipelines, @@ -149,6 +152,7 @@ impl ComputeNode for ParquetSourceNode { metadata_decode_ahead_size, row_group_prefetch_size, min_values_per_thread, + ideal_morsel_size, } }; @@ -198,18 +202,18 @@ impl ComputeNode for ParquetSourceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { use std::sync::atomic::Ordering; - assert!(recv.is_empty()); - assert_eq!(send.len(), 1); + assert!(recv_ports.is_empty()); + assert_eq!(send_ports.len(), 1); assert!(!self.is_finished.load(Ordering::Relaxed)); - let morsel_senders = send[0].take().unwrap().parallel(); + let morsel_senders = send_ports[0].take().unwrap().parallel(); let mut async_task_data_guard = self.async_task_data.try_lock().unwrap(); let (raw_morsel_receivers, _) = async_task_data_guard.as_mut().unwrap(); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs index dfa4b11e3b02..52d3003de7ea 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs @@ -2,11 +2,12 @@ use std::future::Future; use std::sync::Arc; use polars_core::prelude::{ArrowSchema, InitHashMaps, PlHashMap}; +use polars_core::series::IsSorted; use polars_core::utils::operation_exceeded_idxsize_msg; use polars_error::{polars_err, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; -use polars_io::prelude::FileMetadata; use polars_io::prelude::_internal::read_this_row_group; +use polars_io::prelude::{create_sorting_map, FileMetadata}; use polars_io::utils::byte_source::{ByteSource, DynByteSource}; use polars_io::utils::slice::SplitSlicePosition; use polars_parquet::read::RowGroupMetadata; @@ -27,6 +28,7 @@ pub(super) struct RowGroupData { pub(super) slice: Option<(usize, usize)>, pub(super) file_max_row_group_height: usize, pub(super) row_group_metadata: RowGroupMetadata, + pub(super) sorting_map: PlHashMap, pub(super) shared_file_state: Arc>, } @@ -86,6 +88,7 @@ impl RowGroupDataFetcher { let current_row_group_idx = self.current_row_group_idx; let num_rows = row_group_metadata.num_rows(); + let sorting_map = create_sorting_map(&row_group_metadata); self.current_row_offset = current_row_offset.saturating_add(num_rows); self.current_row_group_idx += 1; @@ -246,6 +249,7 @@ impl RowGroupDataFetcher { slice, file_max_row_group_height: current_max_row_group_height, row_group_metadata, + sorting_map, shared_file_state: current_shared_file_state.clone(), }) }); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index dc8fe611eafd..d31f1e51f71e 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -11,6 +11,7 @@ use polars_error::{polars_bail, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; use polars_io::prelude::_internal::calc_prefilter_cost; pub use polars_io::prelude::_internal::PrefilterMaskSetting; +use polars_io::prelude::try_set_sorted_flag; use polars_io::RowIndex; use polars_plan::plans::hive::HivePartitions; use polars_plan::plans::ScanSources; @@ -37,7 +38,6 @@ pub(super) struct RowGroupDecoder { pub(super) non_predicate_arrow_field_indices: Vec, /// The nth bit is set to `true` if the field at that index is used in the predicate. pub(super) predicate_arrow_field_mask: Vec, - pub(super) ideal_morsel_size: usize, pub(super) min_values_per_thread: usize, } @@ -45,7 +45,7 @@ impl RowGroupDecoder { pub(super) async fn row_group_data_to_df( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { if self.use_prefiltered.is_some() { self.row_group_data_to_df_prefiltered(row_group_data).await } else { @@ -56,7 +56,7 @@ impl RowGroupDecoder { async fn row_group_data_to_df_impl( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { let row_group_data = Arc::new(row_group_data); let out_width = self.row_index.is_some() as usize @@ -108,24 +108,29 @@ impl RowGroupDecoder { out_columns.push(file_path_series.slice(0, projection_height)); } - let df = unsafe { DataFrame::new_no_checks(out_columns) }; + let df = unsafe { DataFrame::new_no_checks(projection_height, out_columns) }; let df = if let Some(predicate) = self.physical_predicate.as_deref() { let mask = predicate.evaluate_io(&df)?; let mask = mask.bool().unwrap(); - unsafe { - DataFrame::new_no_checks( - filter_cols(df.take_columns(), mask, self.min_values_per_thread).await?, - ) - } + let filtered = + unsafe { filter_cols(df.take_columns(), mask, self.min_values_per_thread) }.await?; + + let height = if let Some(fst) = filtered.first() { + fst.len() + } else { + mask.num_trues() + }; + + unsafe { DataFrame::new_no_checks(height, filtered) } } else { df }; assert_eq!(df.width(), out_width); // `out_width` should have been calculated correctly - Ok(self.split_to_morsels(df)) + Ok(df) } async fn shared_file_state_init_func(&self, row_group_data: &RowGroupData) -> SharedFileState { @@ -302,26 +307,6 @@ impl RowGroupDecoder { Ok(()) } - - fn split_to_morsels(&self, df: DataFrame) -> Vec { - let n_morsels = if df.height() > 3 * self.ideal_morsel_size / 2 { - // num_rows > (1.5 * ideal_morsel_size) - (df.height() / self.ideal_morsel_size).max(2) - } else { - 1 - } as u64; - - if n_morsels == 1 { - return vec![df]; - } - - let rows_per_morsel = 1 + df.height() / n_morsels as usize; - - (0..i64::try_from(df.height()).unwrap()) - .step_by(rows_per_morsel) - .map(|offset| df.slice(offset, rows_per_morsel)) - .collect::>() - } } fn decode_column( @@ -362,11 +347,20 @@ fn decode_column( assert_eq!(array.len(), expected_num_rows); - let series = Series::try_from((arrow_field, array))?; + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } // TODO: Also load in the metadata. - Ok(series.into()) + Ok(series.into_column()) } /// # Safety @@ -463,7 +457,7 @@ impl RowGroupDecoder { async fn row_group_data_to_df_prefiltered( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { debug_assert!(row_group_data.slice.is_none()); // Invariant of the optimizer. assert!(self.predicate_arrow_field_indices.len() <= self.projected_arrow_schema.len()); @@ -515,7 +509,9 @@ impl RowGroupDecoder { live_columns.push(s?); } - let live_df = unsafe { DataFrame::new_no_checks(live_columns) }; + let live_df = unsafe { + DataFrame::new_no_checks(row_group_data.row_group_metadata.num_rows(), live_columns) + }; let mask = self .physical_predicate .as_deref() @@ -523,12 +519,18 @@ impl RowGroupDecoder { .evaluate_io(&live_df)?; let mask = mask.bool().unwrap(); - let live_df_filtered = unsafe { - DataFrame::new_no_checks( - filter_cols(live_df.take_columns(), mask, self.min_values_per_thread).await?, - ) + let filtered = + unsafe { filter_cols(live_df.take_columns(), mask, self.min_values_per_thread) } + .await?; + + let height = if let Some(fst) = filtered.first() { + fst.len() + } else { + mask.num_trues() }; + let live_df_filtered = unsafe { DataFrame::new_no_checks(height, filtered) }; + let mask_bitmap = { let mut mask_bitmap = MutableBitmap::with_capacity(mask.len()); @@ -590,8 +592,8 @@ impl RowGroupDecoder { out_columns.extend(live_rem); // optional hive cols, file path col assert_eq!(dead_rem.len(), 0); - let df = unsafe { DataFrame::new_no_checks(out_columns) }; - Ok(self.split_to_morsels(df)) + let df = unsafe { DataFrame::new_no_checks(expected_num_rows, out_columns) }; + Ok(df) } } @@ -639,17 +641,26 @@ fn decode_column_prefiltered( deserialize_filter, )?; - let column = Series::try_from((arrow_field, array))?.into_column(); + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } - let column = if !prefilter { - column.filter(mask)? + let series = if !prefilter { + series.filter(mask)? } else { - column + series }; - assert_eq!(column.len(), expected_num_rows); + assert_eq!(series.len(), expected_num_rows); - Ok(column) + Ok(series.into_column()) } mod tests { diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs index 15048daba4f8..565854e97b81 100644 --- a/crates/polars-stream/src/nodes/reduce.rs +++ b/crates/polars-stream/src/nodes/reduce.rs @@ -1,7 +1,9 @@ use std::sync::Arc; +use polars_core::frame::column::ScalarColumn; +use polars_core::prelude::Column; use polars_core::schema::{Schema, SchemaExt}; -use polars_expr::reduce::{Reduction, ReductionState}; +use polars_expr::reduce::GroupedReduction; use polars_utils::itertools::Itertools; use super::compute_node_prelude::*; @@ -11,8 +13,7 @@ use crate::morsel::SourceToken; enum ReduceState { Sink { selectors: Vec, - reductions: Vec>, - reduction_states: Vec>, + reductions: Vec>, }, Source(Option), Done, @@ -26,15 +27,13 @@ pub struct ReduceNode { impl ReduceNode { pub fn new( selectors: Vec, - reductions: Vec>, + reductions: Vec>, output_schema: Arc, ) -> Self { - let reduction_states = reductions.iter().map(|r| r.new_reducer()).collect(); Self { state: ReduceState::Sink { selectors, reductions, - reduction_states, }, output_schema, } @@ -42,8 +41,7 @@ impl ReduceNode { fn spawn_sink<'env, 's>( selectors: &'env [StreamExpr], - reductions: &'env mut [Box], - reduction_states: &'env mut [Box], + reductions: &'env mut [Box], scope: &'s TaskScope<'s, 'env>, recv: RecvPort<'_>, state: &'s ExecutionState, @@ -53,14 +51,20 @@ impl ReduceNode { .parallel() .into_iter() .map(|mut recv| { - let mut local_reducers: Vec<_> = - reductions.iter().map(|d| d.new_reducer()).collect(); + let mut local_reducers: Vec<_> = reductions + .iter() + .map(|d| { + let mut r = d.new_empty(); + r.resize(1); + r + }) + .collect(); scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = recv.recv().await { for (reducer, selector) in local_reducers.iter_mut().zip(selectors) { let input = selector.evaluate(morsel.df(), state).await?; - reducer.update(&input)?; + reducer.update_group(&input, 0)?; } } @@ -72,8 +76,11 @@ impl ReduceNode { join_handles.push(scope.spawn_task(TaskPriority::High, async move { for task in parallel_tasks { let local_reducers = task.await?; - for (r1, r2) in reduction_states.iter_mut().zip(local_reducers) { - r1.combine(&*r2)?; + for (r1, r2) in reductions.iter_mut().zip(local_reducers) { + r1.resize(1); + unsafe { + r1.combine(&*r2, &[0])?; + } } } @@ -111,18 +118,14 @@ impl ComputeNode for ReduceNode { self.state = ReduceState::Done; }, // Input is done, transition to being a source. - ReduceState::Sink { - reduction_states, .. - } if matches!(recv[0], PortState::Done) => { - let columns = reduction_states + ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => { + let columns = reductions .iter_mut() .zip(self.output_schema.iter_fields()) .map(|(r, field)| { - r.finalize().map(|scalar| { - scalar - .into_column(field.name.clone()) - .cast(&field.dtype) - .unwrap() + r.finalize().map(|s| { + let s = s.with_name(field.name.clone()).cast(&field.dtype).unwrap(); + Column::Scalar(ScalarColumn::unit_scalar_from_series(s)) }) }) .try_collect_vec()?; @@ -159,33 +162,24 @@ impl ComputeNode for ReduceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(send.len() == 1 && recv.len() == 1); + assert!(send_ports.len() == 1 && recv_ports.len() == 1); match &mut self.state { ReduceState::Sink { selectors, reductions, - reduction_states, } => { - assert!(send[0].is_none()); - let recv_port = recv[0].take().unwrap(); - Self::spawn_sink( - selectors, - reductions, - reduction_states, - scope, - recv_port, - state, - join_handles, - ) + assert!(send_ports[0].is_none()); + let recv_port = recv_ports[0].take().unwrap(); + Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles) }, ReduceState::Source(df) => { - assert!(recv[0].is_none()); - let send_port = send[0].take().unwrap(); + assert!(recv_ports[0].is_none()); + let send_port = send_ports[0].take().unwrap(); Self::spawn_source(df, scope, send_port, join_handles) }, ReduceState::Done => unreachable!(), diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs index 3b060e78e654..bf12904ff12c 100644 --- a/crates/polars-stream/src/nodes/select.rs +++ b/crates/polars-stream/src/nodes/select.rs @@ -36,14 +36,14 @@ impl ComputeNode for SelectNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/simple_projection.rs b/crates/polars-stream/src/nodes/simple_projection.rs index 95f002df2889..00cd8ed55ad0 100644 --- a/crates/polars-stream/src/nodes/simple_projection.rs +++ b/crates/polars-stream/src/nodes/simple_projection.rs @@ -33,14 +33,14 @@ impl ComputeNode for SimpleProjectionNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/streaming_slice.rs b/crates/polars-stream/src/nodes/streaming_slice.rs index 950b39331588..5d9f5a003340 100644 --- a/crates/polars-stream/src/nodes/streaming_slice.rs +++ b/crates/polars-stream/src/nodes/streaming_slice.rs @@ -43,14 +43,14 @@ impl ComputeNode for StreamingSliceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let mut recv = recv[0].take().unwrap().serial(); - let mut send = send[0].take().unwrap().serial(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let mut recv = recv_ports[0].take().unwrap().serial(); + let mut send = send_ports[0].take().unwrap().serial(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { let stop_offset = self.start_offset + self.length; diff --git a/crates/polars-stream/src/nodes/with_row_index.rs b/crates/polars-stream/src/nodes/with_row_index.rs index 942d23219fec..fe075120963d 100644 --- a/crates/polars-stream/src/nodes/with_row_index.rs +++ b/crates/polars-stream/src/nodes/with_row_index.rs @@ -35,14 +35,14 @@ impl ComputeNode for WithRowIndexNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let mut receiver = recv[0].take().unwrap().serial(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let mut receiver = recv_ports[0].take().unwrap().serial(); + let senders = send_ports[0].take().unwrap().parallel(); let (mut distributor, distr_receivers) = distributor_channel(senders.len(), DEFAULT_DISTRIBUTOR_BUFFER_SIZE); diff --git a/crates/polars-stream/src/nodes/zip.rs b/crates/polars-stream/src/nodes/zip.rs index cd72a3567442..614c7b506128 100644 --- a/crates/polars-stream/src/nodes/zip.rs +++ b/crates/polars-stream/src/nodes/zip.rs @@ -205,20 +205,20 @@ impl ComputeNode for ZipNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(send.len() == 1); - assert!(!recv.is_empty()); - let mut sender = send[0].take().unwrap().serial(); + assert!(send_ports.len() == 1); + assert!(!recv_ports.is_empty()); + let mut sender = send_ports[0].take().unwrap().serial(); - let mut receivers = recv + let mut receivers = recv_ports .iter_mut() - .map(|r| { + .map(|recv_port| { // Add buffering to each receiver to reduce contention between input heads. - let mut serial_recv = r.take()?.serial(); + let mut serial_recv = recv_port.take()?.serial(); let (buf_send, buf_recv) = tokio::sync::mpsc::channel(DEFAULT_ZIP_HEAD_BUFFER_SIZE); join_handles.push(scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = serial_recv.recv().await { diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 21dd8e9dd634..ed0f08a0d48f 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -2,6 +2,7 @@ use std::fmt::Write; use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{AExpr, EscapeLabel, FileScan, ScanSourcesDisplay}; +use polars_plan::prelude::FileType; use polars_utils::arena::Arena; use polars_utils::itertools::Itertools; use slotmap::{Key, SecondaryMap, SlotMap}; @@ -95,6 +96,14 @@ fn visualize_plan_rec( from_ref(input), ), PhysNodeKind::InMemorySink { input } => ("in-memory-sink".to_string(), from_ref(input)), + PhysNodeKind::FileSink { + input, file_type, .. + } => match file_type { + FileType::Parquet(_) => ("parquet-sink".to_string(), from_ref(input)), + FileType::Ipc(_) => ("ipc-sink".to_string(), from_ref(input)), + FileType::Csv(_) => ("csv-sink".to_string(), from_ref(input)), + FileType::Json(_) => ("json-sink".to_string(), from_ref(input)), + }, PhysNodeKind::InMemoryMap { input, map: _ } => { ("in-memory-map".to_string(), from_ref(input)) }, @@ -183,6 +192,17 @@ fn visualize_plan_rec( (out, &[][..]) }, + PhysNodeKind::GroupBy { input, key, aggs } => { + let label = "group-by"; + ( + format!( + "{label}\\nkey:\\n{}\\naggs:\\n{}", + fmt_exprs(key, expr_arena), + fmt_exprs(aggs, expr_arena) + ), + from_ref(input), + ) + }, }; out.push(format!( diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 618ec358f209..3af80df16f9f 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -98,6 +98,7 @@ pub(crate) fn is_elementwise( match function { // Non-strict strptime must be done in-memory to ensure the format // is consistent across the entire dataframe. + #[cfg(feature = "strings")] FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)) => opts.strict, _ => { options.is_elementwise() @@ -348,7 +349,7 @@ fn build_fallback_node_with_ctx( expr, Context::Default, ctx.expr_arena, - Some(&ctx.phys_sm[input_node].output_schema), + &ctx.phys_sm[input_node].output_schema, &mut conv_state, ) }) @@ -573,7 +574,9 @@ fn lower_exprs_with_ctx( .. } | IRAggExpr::Sum(ref mut inner) - | IRAggExpr::Mean(ref mut inner) => { + | IRAggExpr::Mean(ref mut inner) + | IRAggExpr::Var(ref mut inner, _ /* ddof */) + | IRAggExpr::Std(ref mut inner, _ /* ddof */) => { let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[*inner], ctx)?; *inner = trans_exprs[0]; @@ -596,8 +599,6 @@ fn lower_exprs_with_ctx( | IRAggExpr::Implode(_) | IRAggExpr::Quantile { .. } | IRAggExpr::Count(_, _) - | IRAggExpr::Std(_, _) - | IRAggExpr::Var(_, _) | IRAggExpr::AggGroups(_) => { let out_name = unique_column_name(); fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone()))); @@ -664,7 +665,7 @@ fn lower_exprs_with_ctx( /// Computes the schema that selecting the given expressions on the input schema /// would result in. -fn compute_output_schema( +pub fn compute_output_schema( input_schema: &Schema, exprs: &[ExprIR], expr_arena: &Arena, diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index f2b532ca1ca3..485bbf03a7fe 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -2,16 +2,37 @@ use std::sync::Arc; use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; use polars_core::schema::Schema; -use polars_error::PolarsResult; +use polars_error::{polars_ensure, PolarsResult}; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, FunctionIR, IR}; -use polars_plan::prelude::SinkType; +use polars_plan::plans::{AExpr, FunctionIR, IRAggExpr, IR}; +use polars_plan::prelude::{FileType, SinkType}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use slotmap::SlotMap; use super::{PhysNode, PhysNodeKey, PhysNodeKind}; -use crate::physical_plan::lower_expr::{is_elementwise, ExprCache}; +use crate::physical_plan::lower_expr::{build_select_node, is_elementwise, lower_exprs, ExprCache}; + +fn build_slice_node( + input: PhysNodeKey, + offset: i64, + length: usize, + phys_sm: &mut SlotMap, +) -> PhysNodeKey { + if offset >= 0 { + let offset = offset as usize; + phys_sm.insert(PhysNode::new( + phys_sm[input].output_schema.clone(), + PhysNodeKind::StreamingSlice { + input, + offset, + length, + }, + )) + } else { + todo!() + } +} #[recursive::recursive] pub fn lower_ir( @@ -22,19 +43,26 @@ pub fn lower_ir( schema_cache: &mut PlHashMap>, expr_cache: &mut ExprCache, ) -> PolarsResult { - let ir_node = ir_arena.get(node); - let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); - let node_kind = match ir_node { - IR::SimpleProjection { input, columns } => { - let columns = columns.iter_names_cloned().collect::>(); - let phys_input = lower_ir( - *input, + // Helper macro to simplify recursive calls. + macro_rules! lower_ir { + ($input:expr) => { + lower_ir( + $input, ir_arena, expr_arena, phys_sm, schema_cache, expr_cache, - )?; + ) + }; + } + + let ir_node = ir_arena.get(node); + let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); + let node_kind = match ir_node { + IR::SimpleProjection { input, columns } => { + let columns = columns.iter_names_cloned().collect::>(); + let phys_input = lower_ir!(*input)?; PhysNodeKind::SimpleProjection { input: phys_input, columns, @@ -43,17 +71,8 @@ pub fn lower_ir( IR::Select { input, expr, .. } => { let selectors = expr.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + let phys_input = lower_ir!(*input)?; + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::HStack { input, exprs, .. } @@ -63,14 +82,7 @@ pub fn lower_ir( { // FIXME: constant literal columns should be broadcasted with hstack. let selectors = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; PhysNodeKind::Select { input: phys_input, selectors, @@ -84,14 +96,7 @@ pub fn lower_ir( // // FIXME: constant literal columns should be broadcasted with hstack. let exprs = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let input_schema = &phys_sm[phys_input].output_schema; let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len()); for name in input_schema.iter_names() { @@ -106,43 +111,19 @@ pub fn lower_ir( selectors.insert(expr.output_name().clone(), expr); } let selectors = selectors.into_values().collect_vec(); - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::Slice { input, offset, len } => { - if *offset >= 0 { - let offset = *offset as usize; - let length = *len as usize; - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - PhysNodeKind::StreamingSlice { - input: phys_input, - offset, - length, - } - } else { - todo!() - } + let offset = *offset; + let len = *len as usize; + let phys_input = lower_ir!(*input)?; + return Ok(build_slice_node(phys_input, offset, len, phys_sm)); }, IR::Filter { input, predicate } => { let predicate = predicate.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let cols_and_predicate = output_schema .iter_names() .cloned() @@ -154,7 +135,7 @@ pub fn lower_ir( }) .chain([predicate]) .collect_vec(); - let (trans_input, mut trans_cols_and_predicate) = super::lower_expr::lower_exprs( + let (trans_input, mut trans_cols_and_predicate) = lower_exprs( phys_input, &cols_and_predicate, expr_arena, @@ -170,7 +151,7 @@ pub fn lower_ir( let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); trans_cols_and_predicate.pop(); // Remove predicate. - return super::lower_expr::build_select_node( + return build_select_node( post_filter, &trans_cols_and_predicate, expr_arena, @@ -221,38 +202,40 @@ pub fn lower_ir( node_kind }, - IR::Sink { input, payload } => { - if *payload == SinkType::Memory { - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + IR::Sink { input, payload } => match payload { + SinkType::Memory => { + let phys_input = lower_ir!(*input)?; PhysNodeKind::InMemorySink { input: phys_input } - } else { - todo!() - } + }, + SinkType::File { path, file_type } => { + let path = path.clone(); + let file_type = file_type.clone(); + + match file_type { + FileType::Ipc(_) => { + let phys_input = lower_ir!(*input)?; + PhysNodeKind::FileSink { + path, + file_type, + input: phys_input, + } + }, + _ => todo!(), + } + }, + SinkType::Cloud { .. } => todo!(), }, IR::MapFunction { input, function } => { // MergeSorted uses a rechunk hack incompatible with the // streaming engine. + #[cfg(feature = "merge_sorted")] if let FunctionIR::MergeSorted { .. } = function { todo!() } let function = function.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; match function { FunctionIR::RowIndex { @@ -292,14 +275,7 @@ pub fn lower_ir( by_column: by_column.clone(), slice: *slice, sort_options: sort_options.clone(), - input: lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?, + input: lower_ir!(*input)?, }, IR::Union { inputs, options } => { @@ -310,16 +286,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::OrderedUnion { inputs } }, @@ -332,16 +299,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::Zip { inputs, @@ -377,7 +335,84 @@ pub fn lower_ir( IR::PythonScan { .. } => todo!(), IR::Reduce { .. } => todo!(), IR::Cache { .. } => todo!(), - IR::GroupBy { .. } => todo!(), + IR::GroupBy { + input, + keys, + aggs, + schema: _, + apply, + maintain_order, + options, + } => { + if apply.is_some() || *maintain_order { + todo!() + } + + #[cfg(feature = "dynamic_group_by")] + if options.dynamic.is_some() || options.rolling.is_some() { + todo!() + } + + let key = keys.clone(); + let mut aggs = aggs.clone(); + let options = options.clone(); + + polars_ensure!(!keys.is_empty(), ComputeError: "at least one key is required in a group_by operation"); + + // TODO: allow all aggregates. + let mut input_exprs = key.clone(); + for agg in &aggs { + match expr_arena.get(agg.node()) { + AExpr::Agg(expr) => match expr { + IRAggExpr::Min { input, .. } + | IRAggExpr::Max { input, .. } + | IRAggExpr::Mean(input) + | IRAggExpr::Sum(input) + | IRAggExpr::Var(input, ..) + | IRAggExpr::Std(input, ..) => { + if is_elementwise(*input, expr_arena, expr_cache) { + input_exprs.push(ExprIR::from_node(*input, expr_arena)); + } else { + todo!() + } + }, + _ => todo!(), + }, + AExpr::Len => input_exprs.push(key[0].clone()), // Hack, use the first key column for the length. + _ => todo!(), + } + } + + let phys_input = lower_ir!(*input)?; + let (trans_input, trans_exprs) = + lower_exprs(phys_input, &input_exprs, expr_arena, phys_sm, expr_cache)?; + let trans_key = trans_exprs[..key.len()].to_vec(); + let trans_aggs = aggs + .iter_mut() + .zip(trans_exprs.iter().skip(key.len())) + .map(|(agg, trans_expr)| { + let old_expr = expr_arena.get(agg.node()).clone(); + let new_expr = old_expr.replace_inputs(&[trans_expr.node()]); + ExprIR::new(expr_arena.add(new_expr), agg.output_name_inner().clone()) + }) + .collect(); + + let mut node = phys_sm.insert(PhysNode::new( + output_schema, + PhysNodeKind::GroupBy { + input: trans_input, + key: trans_key, + aggs: trans_aggs, + }, + )); + + // TODO: actually limit number of groups instead of computing full + // result and then slicing. + if let Some((offset, len)) = options.slice { + node = build_slice_node(node, offset, len, phys_sm); + } + return Ok(node); + }, IR::Join { .. } => todo!(), IR::Distinct { .. } => todo!(), IR::ExtContext { .. } => todo!(), diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index eddbc87bda99..3b4643100249 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::sync::Arc; use polars_core::frame::DataFrame; @@ -14,7 +15,7 @@ mod lower_ir; mod to_graph; pub use fmt::visualize_plan; -use polars_plan::prelude::FileScanOptions; +use polars_plan::prelude::{FileScanOptions, FileType}; use polars_utils::arena::{Arena, Node}; use polars_utils::pl_str::PlSmallStr; use slotmap::{Key, SecondaryMap, SlotMap}; @@ -93,6 +94,12 @@ pub enum PhysNodeKind { input: PhysNodeKey, }, + FileSink { + path: Arc, + file_type: FileType, + input: PhysNodeKey, + }, + InMemoryMap { input: PhysNodeKey, map: Arc, @@ -136,6 +143,12 @@ pub enum PhysNodeKind { scan_type: FileScan, file_options: FileScanOptions, }, + + GroupBy { + input: PhysNodeKey, + key: Vec, + aggs: Vec, + }, } #[recursive::recursive] @@ -176,10 +189,12 @@ fn insert_multiplexers( | PhysNodeKind::Filter { input, .. } | PhysNodeKind::SimpleProjection { input, .. } | PhysNodeKind::InMemorySink { input } + | PhysNodeKind::FileSink { input, .. } | PhysNodeKind::InMemoryMap { input, .. } | PhysNodeKind::Map { input, .. } | PhysNodeKind::Sort { input, .. } - | PhysNodeKind::Multiplexer { input } => { + | PhysNodeKind::Multiplexer { input } + | PhysNodeKind::GroupBy { input, .. } => { insert_multiplexers(*input, phys_sm, referenced); }, diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index a4d58847033e..d9253e48dfa5 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -1,7 +1,9 @@ use std::sync::Arc; use parking_lot::Mutex; +use polars_core::schema::{Schema, SchemaExt}; use polars_error::PolarsResult; +use polars_expr::groups::new_hash_grouper; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; @@ -9,7 +11,7 @@ use polars_mem_engine::create_physical_plan; use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{AExpr, ArenaExprIter, Context, IR}; -use polars_plan::prelude::FunctionFlags; +use polars_plan::prelude::{FileType, FunctionFlags}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use recursive::recursive; @@ -19,6 +21,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; +use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; fn has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { @@ -33,13 +36,14 @@ fn has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { fn create_stream_expr( expr_ir: &ExprIR, ctx: &mut GraphConversionContext<'_>, + schema: &Arc, ) -> PolarsResult { let reentrant = has_potential_recurring_entrance(expr_ir.node(), ctx.expr_arena); let phys = create_physical_expr( expr_ir, Context::Default, ctx.expr_arena, - None, + schema, &mut ctx.expr_conversion_state, )?; Ok(StreamExpr::new(phys, reentrant)) @@ -103,7 +107,8 @@ fn to_graph_rec<'a>( }, Filter { predicate, input } => { - let phys_predicate_expr = create_stream_expr(predicate, ctx)?; + let input_schema = &ctx.phys_sm[*input].output_schema; + let phys_predicate_expr = create_stream_expr(predicate, ctx, input_schema)?; let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( nodes::filter::FilterNode::new(phys_predicate_expr), @@ -116,9 +121,10 @@ fn to_graph_rec<'a>( input, extend_original, } => { + let input_schema = &ctx.phys_sm[*input].output_schema; let phys_selectors = selectors .iter() - .map(|selector| create_stream_expr(selector, ctx)) + .map(|selector| create_stream_expr(selector, ctx, input_schema)) .collect::>()?; let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( @@ -144,9 +150,10 @@ fn to_graph_rec<'a>( }, InputIndependentSelect { selectors } => { + let empty_schema = Default::default(); let phys_selectors = selectors .iter() - .map(|selector| create_stream_expr(selector, ctx)) + .map(|selector| create_stream_expr(selector, ctx, &empty_schema)) .collect::>()?; ctx.graph.add_node( nodes::input_independent_select::InputIndependentSelectNode::new(phys_selectors), @@ -165,8 +172,11 @@ fn to_graph_rec<'a>( let (red, input_node) = into_reduction(e.node(), ctx.expr_arena, input_schema)?; reductions.push(red); - let input_phys = - create_stream_expr(&ExprIR::from_node(input_node, ctx.expr_arena), ctx)?; + let input_phys = create_stream_expr( + &ExprIR::from_node(input_node, ctx.expr_arena), + ctx, + input_schema, + )?; inputs.push(input_phys) } @@ -194,6 +204,23 @@ fn to_graph_rec<'a>( ) }, + FileSink { + path, + file_type, + input, + } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); + let input_key = to_graph_rec(*input, ctx)?; + + match file_type { + FileType::Ipc(ipc_writer_options) => ctx.graph.add_node( + nodes::io_sinks::ipc::IpcSinkNode::new(input_schema, path, ipc_writer_options)?, + [input_key], + ), + _ => todo!(), + } + }, + InMemoryMap { input, map } => { let input_schema = ctx.phys_sm[*input].output_schema.clone(); let input_key = to_graph_rec(*input, ctx)?; @@ -304,7 +331,7 @@ fn to_graph_rec<'a>( &pred, Context::Default, ctx.expr_arena, - output_schema.as_ref(), + output_schema.as_ref().unwrap_or(&file_info.schema), &mut ctx.expr_conversion_state, ) }) @@ -341,6 +368,46 @@ fn to_graph_rec<'a>( } } }, + + GroupBy { input, key, aggs } => { + let input_key = to_graph_rec(*input, ctx)?; + + let input_schema = &ctx.phys_sm[*input].output_schema; + let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + let random_state = Default::default(); + let grouper = new_hash_grouper(Arc::new(key_schema), random_state); + + let key_selectors = key + .iter() + .map(|e| create_stream_expr(e, ctx, input_schema)) + .try_collect_vec()?; + + let mut grouped_reductions = Vec::new(); + let mut grouped_reduction_selectors = Vec::new(); + for agg in aggs { + let (reduction, input_node) = + into_reduction(agg.node(), ctx.expr_arena, input_schema)?; + let selector = create_stream_expr( + &ExprIR::from_node(input_node, ctx.expr_arena), + ctx, + input_schema, + )?; + grouped_reductions.push(reduction); + grouped_reduction_selectors.push(selector); + } + + ctx.graph.add_node( + nodes::group_by::GroupByNode::new( + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + node.output_schema.clone(), + ), + [input_key], + ) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/crates/polars-stream/src/pipe.rs b/crates/polars-stream/src/pipe.rs index 019d8779d18a..21b6a5672618 100644 --- a/crates/polars-stream/src/pipe.rs +++ b/crates/polars-stream/src/pipe.rs @@ -20,7 +20,7 @@ pub enum PhysicalPipe { pub struct SendPort<'a>(&'a mut PhysicalPipe); pub struct RecvPort<'a>(&'a mut PhysicalPipe); -impl<'a> RecvPort<'a> { +impl RecvPort<'_> { pub fn serial(self) -> Receiver { let PhysicalPipe::Uninit(num_pipelines) = self.0 else { unreachable!() @@ -41,7 +41,7 @@ impl<'a> RecvPort<'a> { } } -impl<'a> SendPort<'a> { +impl SendPort<'_> { #[allow(unused)] pub fn is_receiver_serial(&self) -> bool { matches!(self.0, PhysicalPipe::SerialReceiver(..)) diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index 20ca189de9e0..9516be3b902a 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -1,22 +1,24 @@ #![allow(unused)] // TODO: remove me +use std::cmp::Reverse; + use polars_core::prelude::*; use polars_core::POOL; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; -use polars_plan::plans::{Context, IRPlan, IR}; +use polars_plan::plans::{Context, IRPlan, IsStreamableContext, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::AExpr; use polars_utils::arena::{Arena, Node}; use slotmap::{SecondaryMap, SlotMap}; fn is_streamable(node: Node, arena: &Arena) -> bool { - polars_plan::plans::is_streamable(node, arena, Context::Default) + polars_plan::plans::is_streamable(node, arena, IsStreamableContext::new(Context::Default)) } pub fn run_query( node: Node, mut ir_arena: Arena, expr_arena: &mut Arena, -) -> PolarsResult { +) -> PolarsResult> { if let Ok(visual_path) = std::env::var("POLARS_VISUALIZE_IR") { let plan = IRPlan { lp_top: node, @@ -35,6 +37,16 @@ pub fn run_query( } let (mut graph, phys_to_graph) = crate::physical_plan::physical_plan_to_graph(root, &phys_sm, expr_arena)?; + crate::async_executor::clear_task_wait_statistics(); let mut results = crate::execute::execute_graph(&mut graph)?; - Ok(results.remove(phys_to_graph[root]).unwrap()) + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + let mut stats = crate::async_executor::get_task_wait_statistics(); + stats.sort_by_key(|(_l, w)| Reverse(*w)); + eprintln!("Time spent waiting for async tasks:"); + for (loc, wait_time) in stats { + eprintln!("{}:{} - {:?}", loc.file(), loc.line(), wait_time); + } + } + + Ok(results.remove(phys_to_graph[root])) } diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index d75d634d213d..6d878b0ba27f 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -23,6 +23,7 @@ now = { version = "0.1" } once_cell = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } +strum_macros = { workspace = true } [dev-dependencies] polars-ops = { workspace = true, features = ["abs"] } diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index 8f01d687fd83..57ad3eb2870c 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -111,25 +111,36 @@ pub(crate) fn datetime_range_i64( ComputeError: "`interval` must be positive" ); - let size: usize; - let offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult; - - match tu { - TimeUnit::Nanoseconds => { - size = ((end - start) / interval.duration_ns() + 1) as usize; - offset_fn = Duration::add_ns; - }, - TimeUnit::Microseconds => { - size = ((end - start) / interval.duration_us() + 1) as usize; - offset_fn = Duration::add_us; - }, - TimeUnit::Milliseconds => { - size = ((end - start) / interval.duration_ms() + 1) as usize; - offset_fn = Duration::add_ms; - }, + let duration = match tu { + TimeUnit::Nanoseconds => interval.duration_ns(), + TimeUnit::Microseconds => interval.duration_us(), + TimeUnit::Milliseconds => interval.duration_ms(), + }; + let time_zone_opt_string: Option = match tz { + #[cfg(feature = "timezones")] + Some(tz) => Some(tz.to_string()), + _ => None, + }; + if interval.is_constant_duration(time_zone_opt_string.as_deref()) { + // Fast path! + let step: usize = duration.try_into().map_err( + |_err| polars_err!(ComputeError: "Could not convert {:?} to usize", duration), + )?; + return match closed { + ClosedWindow::Both => Ok((start..=end).step_by(step).collect::>()), + ClosedWindow::None => Ok((start + duration..end).step_by(step).collect::>()), + ClosedWindow::Left => Ok((start..end).step_by(step).collect::>()), + ClosedWindow::Right => Ok((start + duration..=end).step_by(step).collect::>()), + }; } - let mut ts = Vec::with_capacity(size); + let size = ((end - start) / duration + 1) as usize; + let offset_fn = match tu { + TimeUnit::Nanoseconds => Duration::add_ns, + TimeUnit::Microseconds => Duration::add_us, + TimeUnit::Milliseconds => Duration::add_ms, + }; + let mut ts = Vec::with_capacity(size); let mut i = match closed { ClosedWindow::Both | ClosedWindow::Left => 0, ClosedWindow::Right | ClosedWindow::None => 1, diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 8a8d2312d580..3ff08ee4d308 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -789,12 +789,12 @@ mod test { let quantile = unsafe { a.as_materialized_series() - .agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) + .agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); assert_eq!(quantile, expected); - let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; + let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); assert_eq!(quantile, expected); diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index 4f300f733100..56ce3e4bdcd5 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -153,7 +153,7 @@ impl Duration { /// # Panics /// If the given str is invalid for any reason. pub fn parse(duration: &str) -> Self { - Self::_parse(duration, false) + Self::try_parse(duration).unwrap() } #[doc(hidden)] @@ -161,23 +161,31 @@ impl Duration { /// units (such as 'year', 'minutes', etc.) and whitespace, as /// well as being case-insensitive. pub fn parse_interval(interval: &str) -> Self { + Self::try_parse_interval(interval).unwrap() + } + + pub fn try_parse(duration: &str) -> PolarsResult { + Self::_parse(duration, false) + } + + pub fn try_parse_interval(interval: &str) -> PolarsResult { Self::_parse(&interval.to_ascii_lowercase(), true) } - fn _parse(s: &str, as_interval: bool) -> Self { + fn _parse(s: &str, as_interval: bool) -> PolarsResult { let s = if as_interval { s.trim_start() } else { s }; let parse_type = if as_interval { "interval" } else { "duration" }; let num_minus_signs = s.matches('-').count(); if num_minus_signs > 1 { - panic!("{} string can only have a single minus sign", parse_type) + polars_bail!(InvalidOperation: "{} string can only have a single minus sign", parse_type); } if num_minus_signs > 0 { if as_interval { // TODO: intervals need to support per-element minus signs - panic!("minus signs are not currently supported in interval strings") + polars_bail!(InvalidOperation: "minus signs are not currently supported in interval strings"); } else if !s.starts_with('-') { - panic!("only a single minus sign is allowed, at the front of the string") + polars_bail!(InvalidOperation: "only a single minus sign is allowed, at the front of the string"); } } let mut months = 0; @@ -211,12 +219,12 @@ impl Duration { while let Some((i, mut ch)) = iter.next() { if !ch.is_ascii_digit() { - let n = s[start..i].parse::().unwrap_or_else(|_| { - panic!( + let Ok(n) = s[start..i].parse::() else { + polars_bail!(InvalidOperation: "expected leading integer in the {} string, found {}", parse_type, ch - ) - }); + ); + }; loop { match ch { @@ -233,10 +241,10 @@ impl Duration { } } if unit.is_empty() { - panic!( + polars_bail!(InvalidOperation: "expected a unit to follow integer in the {} string '{}'", parse_type, s - ) + ); } match &*unit { // matches that are allowed for both duration/interval @@ -270,24 +278,25 @@ impl Duration { "year" | "years" => months += n * 12, _ => { let valid_units = "'year', 'month', 'quarter', 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond', 'nanosecond'"; - panic!("unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units) + polars_bail!(InvalidOperation: "unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units); }, }, _ => { - panic!("unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'") + polars_bail!(InvalidOperation: "unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'"); }, } unit.clear(); } } - Duration { + + Ok(Duration { nsecs: nsecs.abs(), days: days.abs(), weeks: weeks.abs(), months: months.abs(), negative, parsed_int, - } + }) } fn to_positive(v: i64) -> (bool, i64) { diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 9ba3a2d3dbc2..0a40b9af6fbc 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -8,11 +8,13 @@ use polars_core::POOL; use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::prelude::*; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedWindow { Left, Right, @@ -20,16 +22,18 @@ pub enum ClosedWindow { None, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum Label { Left, Right, DataPoint, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum StartBy { WindowBound, DataPoint, diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 90afe791e4d2..c7a29b846c58 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -316,7 +316,7 @@ impl<'a> BoundsIter<'a> { } } -impl<'a> Iterator for BoundsIter<'a> { +impl Iterator for BoundsIter<'_> { type Item = Bounds; fn next(&mut self) -> Option { diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index 442d319b7753..ef968918d4e8 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -21,6 +21,7 @@ libc = { workspace = true } memmap = { workspace = true, optional = true } num-traits = { workspace = true } once_cell = { workspace = true } +pyo3 = { workspace = true, optional = true } raw-cpuid = { workspace = true } rayon = { workspace = true } serde = { workspace = true, optional = true } @@ -39,3 +40,4 @@ bigidx = [] nightly = [] ir_serde = ["serde"] serde = ["dep:serde", "serde/derive"] +python = ["pyo3"] diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index c8b5823d695e..270a3fb9b835 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -7,11 +7,11 @@ use crate::error::*; use crate::slice::GetSaferUnchecked; unsafe fn index_of_unchecked(slice: &[T], item: &T) -> usize { - (item as *const _ as usize - slice.as_ptr() as usize) / std::mem::size_of::() + (item as *const _ as usize - slice.as_ptr() as usize) / size_of::() } fn index_of(slice: &[T], item: &T) -> Option { - debug_assert!(std::mem::size_of::() > 0); + debug_assert!(size_of::() > 0); let ptr = item as *const T; unsafe { if slice.as_ptr() < ptr && slice.as_ptr().add(slice.len()) > ptr { diff --git a/crates/polars-utils/src/clmul.rs b/crates/polars-utils/src/clmul.rs index c3467e6152e4..ac6fc541c400 100644 --- a/crates/polars-utils/src/clmul.rs +++ b/crates/polars-utils/src/clmul.rs @@ -57,7 +57,7 @@ pub fn portable_prefix_xorsum(x: u64) -> u64 { portable_prefix_xorsum_inclusive(x << 1) } -// Computes for each bit i the XOR of all less significant bits. +// Computes for each bit i the XOR of bits[0..i]. #[inline] pub fn prefix_xorsum(x: u64) -> u64 { #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] @@ -82,7 +82,7 @@ pub fn portable_prefix_xorsum_inclusive(mut x: u64) -> u64 { x } -// Computes for each bit i the XOR of all less significant bits. +// Computes for each bit i the XOR of bits[0..=i]. #[inline] pub fn prefix_xorsum_inclusive(x: u64) -> u64 { #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] diff --git a/crates/polars-utils/src/float.rs b/crates/polars-utils/src/float.rs index 30d084985782..30d47397c28e 100644 --- a/crates/polars-utils/src/float.rs +++ b/crates/polars-utils/src/float.rs @@ -1,6 +1,6 @@ /// # Safety /// unsafe code downstream relies on the correct is_float call -pub unsafe trait IsFloat: private::Sealed { +pub unsafe trait IsFloat: private::Sealed + Sized { fn is_float() -> bool { false } @@ -13,6 +13,10 @@ pub unsafe trait IsFloat: private::Sealed { false } + fn nan_value() -> Self { + unimplemented!() + } + #[allow(clippy::wrong_self_convention)] fn is_nan(&self) -> bool where @@ -78,6 +82,10 @@ macro_rules! impl_is_float { $is_f64 } + fn nan_value() -> Self { + Self::NAN + } + #[inline] fn is_nan(&self) -> bool { <$tp>::is_nan(*self) diff --git a/crates/polars-utils/src/hashing.rs b/crates/polars-utils/src/hashing.rs index 12e59bf52f26..63f4c661a2c3 100644 --- a/crates/polars-utils/src/hashing.rs +++ b/crates/polars-utils/src/hashing.rs @@ -2,6 +2,11 @@ use std::hash::{Hash, Hasher}; use crate::nulls::IsNull; +pub const fn folded_multiply(a: u64, b: u64) -> u64 { + let full = (a as u128).wrapping_mul(b as u128); + (full as u64) ^ ((full >> 64) as u64) +} + /// Contains a byte slice and a precomputed hash for that string. /// During rehashes, we will rehash the hash instead of the string, that makes /// rehashing cheap and allows cache coherent small hash tables. @@ -33,13 +38,13 @@ impl<'a> IsNull for BytesHash<'a> { } } -impl<'a> Hash for BytesHash<'a> { +impl Hash for BytesHash<'_> { fn hash(&self, state: &mut H) { state.write_u64(self.hash) } } -impl<'a> PartialEq for BytesHash<'a> { +impl PartialEq for BytesHash<'_> { #[inline] fn eq(&self, other: &Self) -> bool { (self.hash == other.hash) && (self.payload == other.payload) @@ -94,7 +99,7 @@ impl DirtyHash for i128 { } } -impl<'a> DirtyHash for BytesHash<'a> { +impl DirtyHash for BytesHash<'_> { fn dirty_hash(&self) -> u64 { self.hash } diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index 8bfdfafa2fd4..b3fbbed403f5 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -45,10 +45,7 @@ impl UnitVec { #[inline] pub fn new() -> Self { // This is optimized away, all const. - assert!( - std::mem::size_of::() <= std::mem::size_of::<*mut T>() - && std::mem::align_of::() <= std::mem::align_of::<*mut T>() - ); + assert!(size_of::() <= size_of::<*mut T>() && align_of::() <= align_of::<*mut T>()); Self { len: 0, capacity: NonZeroUsize::new(1).unwrap(), diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index eacd517d1254..5c302067e146 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -51,3 +51,6 @@ pub mod partitioned; pub use index::{IdxSize, NullableIdxSize}; pub use io::*; + +#[cfg(feature = "python")] +pub mod python_function; diff --git a/crates/polars-utils/src/mmap.rs b/crates/polars-utils/src/mmap.rs index 29651d5eb56a..0ac1a643d93d 100644 --- a/crates/polars-utils/src/mmap.rs +++ b/crates/polars-utils/src/mmap.rs @@ -61,6 +61,12 @@ mod private { } } + impl From> for MemSlice { + fn from(value: Vec) -> Self { + Self::from_vec(value) + } + } + impl MemSlice { pub const EMPTY: Self = Self::from_static(&[]); diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs new file mode 100644 index 000000000000..178f89aa2ca1 --- /dev/null +++ b/crates/polars-utils/src/python_function.rs @@ -0,0 +1,235 @@ +use polars_error::{polars_bail, PolarsError, PolarsResult}; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedBytes; +use pyo3::types::PyBytes; +#[cfg(feature = "serde")] +pub use serde_wrap::{ + PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION, + SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK, +}; + +use crate::flatten; + +#[derive(Clone, Debug)] +pub struct PythonFunction(pub PyObject); + +impl From for PythonFunction { + fn from(value: PyObject) -> Self { + Self(value) + } +} + +impl Eq for PythonFunction {} + +impl PartialEq for PythonFunction { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + let eq = self.0.getattr(py, "__eq__").unwrap(); + eq.call1(py, (other.0.clone_ref(py),)) + .unwrap() + .extract::(py) + // equality can be not implemented, so default to false + .unwrap_or(false) + }) + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PythonFunction { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + serializer.serialize_bytes( + self.try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))? + .as_slice(), + ) + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for PythonFunction { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + let bytes = Vec::::deserialize(deserializer)?; + Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string())) + } +} + +#[cfg(feature = "serde")] +impl TrySerializeToBytes for PythonFunction { + fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult> { + serialize_pyobject_with_cloudpickle_fallback(&self.0) + } + + fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult { + deserialize_pyobject_bytes_maybe_cloudpickle(bytes) + } +} + +pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult> { + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("dumps") + .unwrap(); + + let dumped = pickle.call1((py_object.clone_ref(py),)); + + let (dumped, used_cloudpickle) = if let Ok(v) = dumped { + (v, false) + } else { + let cloudpickle = PyModule::import_bound(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle + .call1((py_object.clone_ref(py),)) + .map_err(from_pyerr)?; + (dumped, true) + }; + + let py_bytes = dumped.extract::().map_err(from_pyerr)?; + + Ok(flatten( + &[&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()], + None, + )) + }) +} + +pub fn deserialize_pyobject_bytes_maybe_cloudpickle From>( + bytes: &[u8], +) -> PolarsResult { + // TODO: Actually deserialize with cloudpickle if it's set. + let [_used_cloudpickle @ 0 | _used_cloudpickle @ 1, b'C', rem @ ..] = bytes else { + polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes") + }; + + let bytes = rem; + + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("loads") + .unwrap(); + let arg = (PyBytes::new_bound(py, bytes),); + let pyany_bound = pickle.call1(arg).map_err(from_pyerr)?; + Ok(PyObject::from(pyany_bound).into()) + }) +} + +#[cfg(feature = "serde")] +mod serde_wrap { + use once_cell::sync::Lazy; + use polars_error::PolarsResult; + + use crate::flatten; + + pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes(); + /// [minor, micro] + pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version); + + /// Serializes a Python object without additional system metadata. This is intended to be used + /// together with `PySerializeWrap`, which attaches e.g. Python version metadata. + pub trait TrySerializeToBytes: Sized { + fn try_serialize_to_bytes(&self) -> PolarsResult>; + fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult; + } + + /// Serialization wrapper for T: TrySerializeToBytes that attaches Python + /// version metadata. + pub struct PySerializeWrap(pub T); + + impl serde::Serialize for PySerializeWrap<&T> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + let dumped = self + .0 + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + serializer.serialize_bytes( + flatten( + &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()], + None, + ) + .as_slice(), + ) + } + } + + impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + let bytes = Vec::::deserialize(deserializer)?; + + let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject version", + )); + }; + + if magic != SERDE_MAGIC_BYTE_MARK { + return Err(D::Error::custom( + "serialized pyobject did not begin with magic byte mark", + )); + } + + let bytes = rem; + + let [a, b, rem @ ..] = bytes else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject metadata", + )); + }; + + let py3_version = [*a, *b]; + + if py3_version != *PYTHON3_VERSION { + return Err(D::Error::custom(format!( + "python version that pyobject was serialized with {:?} \ + differs from system python version {:?}", + (3, py3_version[0], py3_version[1]), + (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]), + ))); + } + + let bytes = rem; + + T::try_deserialize_bytes(bytes) + .map(Self) + .map_err(|e| D::Error::custom(e.to_string())) + } + } +} + +/// Get the [minor, micro] Python3 version from the `sys` module. +fn get_python3_version() -> [u8; 2] { + Python::with_gil(|py| { + let version_info = PyModule::import_bound(py, "sys") + .unwrap() + .getattr("version_info") + .unwrap(); + + [ + version_info.getattr("minor").unwrap().extract().unwrap(), + version_info.getattr("micro").unwrap().extract().unwrap(), + ] + }) +} + +fn from_pyerr(e: PyErr) -> PolarsError { + PolarsError::ComputeError(format!("error raised in python: {e}").into()) +} diff --git a/crates/polars-utils/src/sort.rs b/crates/polars-utils/src/sort.rs index 780dc39d1b9c..0c83f2becddd 100644 --- a/crates/polars-utils/src/sort.rs +++ b/crates/polars-utils/src/sort.rs @@ -90,8 +90,8 @@ where Idx: FromPrimitive + Copy, { // Needed to be able to write back to back in the same buffer. - debug_assert_eq!(std::mem::align_of::(), std::mem::align_of::<(T, Idx)>()); - let size = std::mem::size_of::<(T, Idx)>(); + debug_assert_eq!(align_of::(), align_of::<(T, Idx)>()); + let size = size_of::<(T, Idx)>(); let upper_bound = size * n + size; scratch.reserve(upper_bound); let scratch_slice = unsafe { diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs index cfaa05f0141d..982dc707e3de 100644 --- a/crates/polars-utils/src/total_ord.rs +++ b/crates/polars-utils/src/total_ord.rs @@ -453,7 +453,7 @@ impl TotalOrd for (T, U) { } } -impl<'a> TotalHash for BytesHash<'a> { +impl TotalHash for BytesHash<'_> { #[inline(always)] fn tot_hash(&self, state: &mut H) where @@ -463,7 +463,7 @@ impl<'a> TotalHash for BytesHash<'a> { } } -impl<'a> TotalEq for BytesHash<'a> { +impl TotalEq for BytesHash<'_> { #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { self == other diff --git a/crates/polars-utils/src/vec.rs b/crates/polars-utils/src/vec.rs index 108e7d573d1c..9060a348230c 100644 --- a/crates/polars-utils/src/vec.rs +++ b/crates/polars-utils/src/vec.rs @@ -20,7 +20,7 @@ impl IntoRawParts for Vec { } } -/// Fill current allocation if if > 0 +/// Fill current allocation if > 0 /// otherwise realloc pub trait ResizeFaster { fn fill_or_alloc(&mut self, new_len: usize, value: T); diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 31bae99c6146..685ed71d8306 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -26,8 +26,7 @@ polars-utils = { workspace = true } [dev-dependencies] ahash = { workspace = true } apache-avro = { version = "0.17", features = ["snappy"] } -arrow = { workspace = true, features = ["arrow_rs"] } -arrow-buffer = { workspace = true } +arrow = { workspace = true } avro-schema = { workspace = true, features = ["async"] } either = { workspace = true } ethnum = "1" @@ -132,7 +131,13 @@ array_any_all = ["polars-lazy?/array_any_all", "dtype-array"] asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"] iejoin = ["polars-lazy?/iejoin"] binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] -bitwise = ["polars-core/bitwise", "polars-plan?/bitwise", "polars-ops/bitwise", "polars-lazy?/bitwise"] +bitwise = [ + "polars-core/bitwise", + "polars-plan?/bitwise", + "polars-ops/bitwise", + "polars-lazy?/bitwise", + "polars-sql?/bitwise", +] business = ["polars-lazy?/business", "polars-ops/business"] checked_arithmetic = ["polars-core/checked_arithmetic"] chunked_ids = ["polars-ops?/chunked_ids"] @@ -184,6 +189,7 @@ list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] list_sample = ["polars-lazy?/list_sample"] list_sets = ["polars-lazy?/list_sets"] list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +list_arithmetic = ["polars-core/list_arithmetic"] array_to_struct = ["polars-ops/array_to_struct", "polars-lazy?/array_to_struct"] log = ["polars-ops/log", "polars-lazy?/log"] merge_sorted = ["polars-lazy?/merge_sorted"] @@ -229,7 +235,7 @@ true_div = ["polars-lazy?/true_div"] unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"] zip_with = ["polars-core/zip_with"] -bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] +bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"] polars_cloud = ["polars-lazy?/polars_cloud"] ir_serde = ["polars-plan/ir_serde"] @@ -376,6 +382,8 @@ docs-selection = [ "is_last_distinct", "asof_join", "cross_join", + "semi_anti_join", + "iejoin", "concat_str", "string_reverse", "string_to_integer", diff --git a/crates/polars/src/docs/lazy.rs b/crates/polars/src/docs/lazy.rs index c77bf58d5cac..bfaa6ebd2569 100644 --- a/crates/polars/src/docs/lazy.rs +++ b/crates/polars/src/docs/lazy.rs @@ -257,7 +257,7 @@ //! "b" => [3.0f32, 5.1, 0.3] //! ]? //! .lazy() -//! .select([as_struct(&[col("a"), col("b")]).map( +//! .select([as_struct(vec![col("a"), col("b")]).map( //! |s| { //! let ca = s.struct_()?; //! diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 5ecc28c94c34..dba9bb39d46d 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -373,6 +373,7 @@ //! * `ASCII_BORDERS_ONLY_CONDENSED` //! * `ASCII_HORIZONTAL_ONLY` //! * `ASCII_MARKDOWN` +//! * `MARKDOWN` //! * `UTF8_FULL` //! * `UTF8_FULL_CONDENSED` //! * `UTF8_NO_BORDERS` diff --git a/crates/polars/tests/it/arrow/array/growable/mod.rs b/crates/polars/tests/it/arrow/array/growable/mod.rs index 4510fd0749cd..648e5203263a 100644 --- a/crates/polars/tests/it/arrow/array/growable/mod.rs +++ b/crates/polars/tests/it/arrow/array/growable/mod.rs @@ -60,6 +60,7 @@ fn test_make_growable_extension() { ); let array = StructArray::new( dtype.clone(), + 2, vec![Int32Array::from_slice([1, 2]).boxed()], None, ); diff --git a/crates/polars/tests/it/arrow/array/growable/struct_.rs b/crates/polars/tests/it/arrow/array/growable/struct_.rs index 07f0403ee294..2749fa88bb1c 100644 --- a/crates/polars/tests/it/arrow/array/growable/struct_.rs +++ b/crates/polars/tests/it/arrow/array/growable/struct_.rs @@ -29,7 +29,7 @@ fn some_values() -> (ArrowDataType, Vec>) { fn basic() { let (fields, values) = some_values(); - let array = StructArray::new(fields.clone(), values.clone(), None); + let array = StructArray::new(fields.clone(), values[0].len(), values.clone(), None); let mut a = GrowableStruct::new(vec![&array], false, 0); @@ -41,6 +41,7 @@ fn basic() { let expected = StructArray::new( fields, + 2, vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], None, ); @@ -51,7 +52,8 @@ fn basic() { fn offset() { let (fields, values) = some_values(); - let array = StructArray::new(fields.clone(), values.clone(), None).sliced(1, 3); + let array = + StructArray::new(fields.clone(), values[0].len(), values.clone(), None).sliced(1, 3); let mut a = GrowableStruct::new(vec![&array], false, 0); @@ -63,6 +65,7 @@ fn offset() { let expected = StructArray::new( fields, + 2, vec![values[0].sliced(2, 2), values[1].sliced(2, 2)], None, ); @@ -76,6 +79,7 @@ fn nulls() { let array = StructArray::new( fields.clone(), + values[0].len(), values.clone(), Some(Bitmap::from_u8_slice([0b00000010], 5)), ); @@ -90,6 +94,7 @@ fn nulls() { let expected = StructArray::new( fields, + 2, vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], Some(Bitmap::from_u8_slice([0b00000010], 5).sliced(1, 2)), ); @@ -101,7 +106,7 @@ fn nulls() { fn many() { let (fields, values) = some_values(); - let array = StructArray::new(fields.clone(), values.clone(), None); + let array = StructArray::new(fields.clone(), values[0].len(), values.clone(), None); let mut mutable = GrowableStruct::new(vec![&array, &array], true, 0); @@ -132,6 +137,7 @@ fn many() { let expected = StructArray::new( fields, + expected_string.len(), vec![expected_string, expected_int], Some(Bitmap::from([true, true, true, true, false])), ); diff --git a/crates/polars/tests/it/arrow/array/map/mod.rs b/crates/polars/tests/it/arrow/array/map/mod.rs index 34e880578659..44702f118a89 100644 --- a/crates/polars/tests/it/arrow/array/map/mod.rs +++ b/crates/polars/tests/it/arrow/array/map/mod.rs @@ -13,6 +13,7 @@ fn array() -> MapArray { let field = StructArray::new( dt(), + 3, vec![ Box::new(Utf8Array::::from_slice(["a", "aa", "aaa"])) as _, Box::new(Utf8Array::::from_slice(["b", "bb", "bbb"])), @@ -36,6 +37,7 @@ fn basics() { array.value(0), Box::new(StructArray::new( dt(), + 1, vec![ Box::new(Utf8Array::::from_slice(["a"])) as _, Box::new(Utf8Array::::from_slice(["b"])), @@ -49,6 +51,7 @@ fn basics() { sliced.value(0), Box::new(StructArray::new( dt(), + 1, vec![ Box::new(Utf8Array::::from_slice(["aa"])) as _, Box::new(Utf8Array::::from_slice(["bb"])), @@ -66,6 +69,7 @@ fn split_at() { lhs.value(0), Box::new(StructArray::new( dt(), + 1, vec![ Box::new(Utf8Array::::from_slice(["a"])) as _, Box::new(Utf8Array::::from_slice(["b"])), @@ -77,6 +81,7 @@ fn split_at() { rhs.value(0), Box::new(StructArray::new( dt(), + 1, vec![ Box::new(Utf8Array::::from_slice(["aa"])) as _, Box::new(Utf8Array::::from_slice(["bb"])), @@ -88,6 +93,7 @@ fn split_at() { rhs.value(1), Box::new(StructArray::new( dt(), + 1, vec![ Box::new(Utf8Array::::from_slice(["aaa"])) as _, Box::new(Utf8Array::::from_slice(["bbb"])), diff --git a/crates/polars/tests/it/arrow/array/struct_/iterator.rs b/crates/polars/tests/it/arrow/array/struct_/iterator.rs index e4b6a7691ad0..7e3430e355d9 100644 --- a/crates/polars/tests/it/arrow/array/struct_/iterator.rs +++ b/crates/polars/tests/it/arrow/array/struct_/iterator.rs @@ -14,6 +14,7 @@ fn test_simple_iter() { let array = StructArray::new( ArrowDataType::Struct(fields), + boolean.len(), vec![boolean.clone(), int.clone()], None, ); diff --git a/crates/polars/tests/it/arrow/array/struct_/mod.rs b/crates/polars/tests/it/arrow/array/struct_/mod.rs index bd1a1c83086c..9492a1b6bdba 100644 --- a/crates/polars/tests/it/arrow/array/struct_/mod.rs +++ b/crates/polars/tests/it/arrow/array/struct_/mod.rs @@ -1,5 +1,4 @@ mod iterator; -mod mutable; use arrow::array::*; use arrow::bitmap::Bitmap; @@ -16,6 +15,7 @@ fn array() -> StructArray { StructArray::new( ArrowDataType::Struct(fields), + boolean.len(), vec![boolean.clone(), int.clone()], Some(Bitmap::from([true, true, false, true])), ) diff --git a/crates/polars/tests/it/arrow/array/struct_/mutable.rs b/crates/polars/tests/it/arrow/array/struct_/mutable.rs deleted file mode 100644 index 4a526a76391b..000000000000 --- a/crates/polars/tests/it/arrow/array/struct_/mutable.rs +++ /dev/null @@ -1,31 +0,0 @@ -use arrow::array::*; -use arrow::datatypes::{ArrowDataType, Field}; - -#[test] -fn push() { - let c1 = Box::new(MutablePrimitiveArray::::new()) as Box; - let values = vec![c1]; - let dtype = ArrowDataType::Struct(vec![Field::new("f1".into(), ArrowDataType::Int32, true)]); - let mut a = MutableStructArray::new(dtype, values); - - a.value::>(0) - .unwrap() - .push(Some(1)); - a.push(true); - a.value::>(0).unwrap().push(None); - a.push(false); - a.value::>(0) - .unwrap() - .push(Some(2)); - a.push(true); - - assert_eq!(a.len(), 3); - assert!(a.is_valid(0)); - assert!(!a.is_valid(1)); - assert!(a.is_valid(2)); - - assert_eq!( - a.value::>(0).unwrap().values(), - &Vec::from([1, 0, 2]) - ); -} diff --git a/crates/polars/tests/it/arrow/bitmap/immutable.rs b/crates/polars/tests/it/arrow/bitmap/immutable.rs index 4f2b3f3748b0..336322534de2 100644 --- a/crates/polars/tests/it/arrow/bitmap/immutable.rs +++ b/crates/polars/tests/it/arrow/bitmap/immutable.rs @@ -76,28 +76,3 @@ fn debug() { "Bitmap { len: 7, offset: 2, bytes: [0b111110__, 0b_______1] }" ); } - -#[test] -fn from_arrow() { - use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; - let buffer = arrow_buffer::Buffer::from_iter(vec![true, true, true, false, false, false, true]); - let bools = BooleanBuffer::new(buffer, 0, 7); - let nulls = NullBuffer::new(bools); - assert_eq!(nulls.null_count(), 3); - - let bitmap = Bitmap::from_null_buffer(nulls.clone()); - assert_eq!(nulls.null_count(), bitmap.unset_bits()); - assert_eq!(nulls.len(), bitmap.len()); - let back = NullBuffer::from(bitmap); - assert_eq!(nulls, back); - - let nulls = nulls.slice(1, 3); - assert_eq!(nulls.null_count(), 1); - assert_eq!(nulls.len(), 3); - - let bitmap = Bitmap::from_null_buffer(nulls.clone()); - assert_eq!(nulls.null_count(), bitmap.unset_bits()); - assert_eq!(nulls.len(), bitmap.len()); - let back = NullBuffer::from(bitmap); - assert_eq!(nulls, back); -} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs index fa138cd528c9..08cfbf31c62e 100644 --- a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs +++ b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs @@ -2,7 +2,7 @@ use arrow::bitmap::utils::fmt; struct A<'a>(&'a [u8], usize, usize); -impl<'a> std::fmt::Debug for A<'a> { +impl std::fmt::Debug for A<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt(self.0, self.1, self.2, f) } diff --git a/crates/polars/tests/it/arrow/bitmap/utils/mod.rs b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs index 12af43e4e949..ebd8d983dec0 100644 --- a/crates/polars/tests/it/arrow/bitmap/utils/mod.rs +++ b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs @@ -16,22 +16,24 @@ fn get_bit_basics() { 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, 0b01000000, 0b11111111, ]; - for i in 0..8 { - assert!(!get_bit(input, i)); + unsafe { + for i in 0..8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 8)); + for i in 8 + 1..2 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 2 * 8 + 1)); + for i in 2 * 8 + 2..3 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 3 * 8 + 2)); + for i in 3 * 8 + 3..4 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 4 * 8 + 3)); } - assert!(get_bit(input, 8)); - for i in 8 + 1..2 * 8 { - assert!(!get_bit(input, i)); - } - assert!(get_bit(input, 2 * 8 + 1)); - for i in 2 * 8 + 2..3 * 8 { - assert!(!get_bit(input, i)); - } - assert!(get_bit(input, 3 * 8 + 2)); - for i in 3 * 8 + 3..4 * 8 { - assert!(!get_bit(input, i)); - } - assert!(get_bit(input, 4 * 8 + 3)); } #[test] diff --git a/crates/polars/tests/it/arrow/buffer/immutable.rs b/crates/polars/tests/it/arrow/buffer/immutable.rs index 9065b52fba35..cc8742ba73ae 100644 --- a/crates/polars/tests/it/arrow/buffer/immutable.rs +++ b/crates/polars/tests/it/arrow/buffer/immutable.rs @@ -43,73 +43,3 @@ fn from_vec() { assert_eq!(buffer.len(), 3); assert_eq!(buffer.as_slice(), &[0, 1, 2]); } - -#[test] -fn from_arrow() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 3); - assert_eq!(b.as_slice(), &[1, 2, 3]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = buffer.slice(4); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 2); - assert_eq!(b.as_slice(), &[2, 3]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i64, 2_i64]); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 4); - assert_eq!(b.as_slice(), &[1, 0, 2, 0]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = buffer.slice(4); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 3); - assert_eq!(b.as_slice(), &[0, 2, 0]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); -} - -#[test] -fn from_arrow_vec() { - // Zero-copy vec conversion in arrow-rs - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let back: Vec = buffer.into_vec().unwrap(); - - // Zero-copy vec conversion in arrow2 - let buffer = Buffer::::from(back); - let back: Vec = buffer.into_mut().unwrap_right(); - - let buffer = arrow_buffer::Buffer::from_vec(back); - let buffer = Buffer::::from(buffer); - - // But not possible after conversion between buffer representations - let _ = buffer.into_mut().unwrap_left(); - - let buffer = Buffer::::from(vec![1_i32]); - let buffer = arrow_buffer::Buffer::from(buffer); - - // But not possible after conversion between buffer representations - let _ = buffer.into_vec::().unwrap_err(); -} - -#[test] -#[should_panic(expected = "arrow_buffer::Buffer misaligned")] -fn from_arrow_misaligned() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]).slice(1); - let _ = Buffer::::from(buffer); -} - -#[test] -fn from_arrow_sliced() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let b = Buffer::::from(buffer); - let sliced = b.sliced(1, 2); - let back = arrow_buffer::Buffer::from(sliced); - assert_eq!(back.typed_data::(), &[2, 3]); -} diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs index 075e5179e1ca..942fd38cc4a7 100644 --- a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -5,7 +5,7 @@ use arrow::datatypes::{ArrowDataType, Field}; #[test] fn primitive() { let a = Int32Array::from_slice([1, 2, 3, 4, 5]); - assert_eq!(5 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(5 * size_of::(), estimated_bytes_size(&a)); } #[test] @@ -17,7 +17,7 @@ fn boolean() { #[test] fn utf8() { let a = Utf8Array::::from_slice(["aaa"]); - assert_eq!(3 + 2 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(3 + 2 * size_of::(), estimated_bytes_size(&a)); } #[test] @@ -28,5 +28,5 @@ fn fixed_size_list() { ); let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); let a = FixedSizeListArray::new(dtype, 2, values, None); - assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(6 * size_of::(), estimated_bytes_size(&a)); } diff --git a/crates/polars/tests/it/arrow/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs index 3dfd0aedd276..8004f2fc8eea 100644 --- a/crates/polars/tests/it/arrow/io/ipc/mod.rs +++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs @@ -62,7 +62,7 @@ fn prep_schema(array: &dyn Array) -> ArrowSchemaRef { fn write_boolean() -> PolarsResult<()> { let array = BooleanArray::from([Some(true), Some(false), None, Some(true)]).boxed(); let schema = prep_schema(array.as_ref()); - let columns = RecordBatchT::try_new(vec![array])?; + let columns = RecordBatchT::try_new(4, vec![array])?; round_trip(columns, schema, None, Some(Compression::ZSTD)) } @@ -72,7 +72,7 @@ fn write_sliced_utf8() -> PolarsResult<()> { .sliced(1, 1) .boxed(); let schema = prep_schema(array.as_ref()); - let columns = RecordBatchT::try_new(vec![array])?; + let columns = RecordBatchT::try_new(array.len(), vec![array])?; round_trip(columns, schema, None, Some(Compression::ZSTD)) } @@ -80,6 +80,6 @@ fn write_sliced_utf8() -> PolarsResult<()> { fn write_binview() -> PolarsResult<()> { let array = Utf8ViewArray::from_slice([Some("foo"), Some("bar"), None, Some("hamlet")]).boxed(); let schema = prep_schema(array.as_ref()); - let columns = RecordBatchT::try_new(vec![array])?; + let columns = RecordBatchT::try_new(array.len(), vec![array])?; round_trip(columns, schema, None, Some(Compression::ZSTD)) } diff --git a/crates/polars/tests/it/arrow/scalar/map.rs b/crates/polars/tests/it/arrow/scalar/map.rs index ee23cb47960f..3a12e8bffcd6 100644 --- a/crates/polars/tests/it/arrow/scalar/map.rs +++ b/crates/polars/tests/it/arrow/scalar/map.rs @@ -11,6 +11,7 @@ fn equal() { ]); let kv_array1 = StructArray::try_new( kv_dt.clone(), + 2, vec![ Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), BooleanArray::from_slice([true, false]).boxed(), @@ -20,6 +21,7 @@ fn equal() { .unwrap(); let kv_array2 = StructArray::try_new( kv_dt.clone(), + 2, vec![ Utf8Array::::from([Some("k1"), Some("k3")]).boxed(), BooleanArray::from_slice([true, true]).boxed(), @@ -47,6 +49,7 @@ fn basics() { ]); let kv_array = StructArray::try_new( kv_dt.clone(), + 2, vec![ Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), BooleanArray::from_slice([true, false]).boxed(), diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs index dac9adbfc9d0..57adac991d1a 100644 --- a/crates/polars/tests/it/io/avro/read.rs +++ b/crates/polars/tests/it/io/avro/read.rs @@ -110,19 +110,21 @@ pub(super) fn data() -> RecordBatchT> { array.into_box(), StructArray::new( ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + 2, vec![PrimitiveArray::::from_slice([1.0, 2.0]).boxed()], None, ) .boxed(), StructArray::new( ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + 2, vec![PrimitiveArray::::from_slice([1.0, 0.0]).boxed()], Some([true, false].into()), ) .boxed(), ]; - RecordBatchT::try_new(columns).unwrap() + RecordBatchT::try_new(2, columns).unwrap() } pub(super) fn write_avro(codec: Codec) -> Result, apache_avro::Error> { @@ -256,6 +258,7 @@ fn test_projected() -> PolarsResult<()> { let mut projection = vec![false; expected_schema.len()]; projection[i] = true; + let length = expected.first().map_or(0, |arr| arr.len()); let expected = expected .clone() .into_arrays() @@ -263,7 +266,7 @@ fn test_projected() -> PolarsResult<()> { .zip(projection.iter()) .filter_map(|x| if *x.1 { Some(x.0) } else { None }) .collect(); - let expected = RecordBatchT::new(expected); + let expected = RecordBatchT::new(length, expected); let expected_schema = expected_schema .clone() @@ -324,9 +327,10 @@ pub(super) fn data_list() -> RecordBatchT> { ); array.try_extend(data).unwrap(); + let length = array.len(); let columns = vec![array.into_box()]; - RecordBatchT::try_new(columns).unwrap() + RecordBatchT::try_new(length, columns).unwrap() } pub(super) fn write_list(codec: Codec) -> Result, apache_avro::Error> { diff --git a/crates/polars/tests/it/io/avro/read_async.rs b/crates/polars/tests/it/io/avro/read_async.rs index 049910c0ce28..d50fd7595c58 100644 --- a/crates/polars/tests/it/io/avro/read_async.rs +++ b/crates/polars/tests/it/io/avro/read_async.rs @@ -26,25 +26,16 @@ async fn test(codec: Codec) -> PolarsResult<()> { Ok(()) } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_without_codec() -> PolarsResult<()> { test(Codec::Null).await } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_deflate() -> PolarsResult<()> { test(Codec::Deflate).await } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_snappy() -> PolarsResult<()> { test(Codec::Snappy).await diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs index 43011eb7a2bf..48633f39ce94 100644 --- a/crates/polars/tests/it/io/avro/write.rs +++ b/crates/polars/tests/it/io/avro/write.rs @@ -102,7 +102,7 @@ pub(super) fn data() -> RecordBatchT> { )), ]; - RecordBatchT::new(columns) + RecordBatchT::new(2, columns) } pub(super) fn serialize_to_block>( @@ -197,7 +197,7 @@ fn large_format_data() -> RecordBatchT> { Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), Box::new(BinaryArray::::from([Some(b"foo"), None])), ]; - RecordBatchT::new(columns) + RecordBatchT::new(2, columns) } fn large_format_expected_schema() -> ArrowSchema { @@ -216,7 +216,7 @@ fn large_format_expected_data() -> RecordBatchT> { Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), Box::new(BinaryArray::::from([Some(b"foo"), None])), ]; - RecordBatchT::new(columns) + RecordBatchT::new(2, columns) } #[test] @@ -265,24 +265,29 @@ fn struct_data() -> RecordBatchT> { Field::new("item2".into(), ArrowDataType::Int32, true), ]); - RecordBatchT::new(vec![ - Box::new(StructArray::new( - struct_dt.clone(), - vec![ - Box::new(PrimitiveArray::::from_slice([1, 2])), - Box::new(PrimitiveArray::::from([None, Some(1)])), - ], - None, - )), - Box::new(StructArray::new( - struct_dt, - vec![ - Box::new(PrimitiveArray::::from_slice([1, 2])), - Box::new(PrimitiveArray::::from([None, Some(1)])), - ], - Some([true, false].into()), - )), - ]) + RecordBatchT::new( + 2, + vec![ + Box::new(StructArray::new( + struct_dt.clone(), + 2, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + None, + )), + Box::new(StructArray::new( + struct_dt, + 2, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + Some([true, false].into()), + )), + ], + ) } fn avro_record() -> Record { diff --git a/crates/polars/tests/it/io/avro/write_async.rs b/crates/polars/tests/it/io/avro/write_async.rs index 1be109e2733a..77cb212f89db 100644 --- a/crates/polars/tests/it/io/avro/write_async.rs +++ b/crates/polars/tests/it/io/avro/write_async.rs @@ -42,9 +42,6 @@ async fn roundtrip(compression: Option) -> PolarsResult<()> { Ok(()) } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn no_compression() -> PolarsResult<()> { roundtrip(None).await diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs index a54b5fcacb1c..0a573eb4a186 100644 --- a/crates/polars/tests/it/io/parquet/arrow/mod.rs +++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs @@ -19,6 +19,7 @@ use super::read::file::FileReader; fn new_struct( arrays: Vec>, + length: usize, names: Vec, validity: Option, ) -> StructArray { @@ -27,7 +28,7 @@ fn new_struct( .zip(arrays.iter()) .map(|(n, a)| Field::new(n.into(), a.dtype().clone(), true)) .collect(); - StructArray::new(ArrowDataType::Struct(fields), arrays, validity) + StructArray::new(ArrowDataType::Struct(fields), length, arrays, validity) } pub fn read_column(mut reader: R, column: &str) -> PolarsResult> { @@ -85,6 +86,7 @@ pub fn pyarrow_nested_edge(column: &str) -> Box { ); StructArray::new( ArrowDataType::Struct(vec![Field::new("f1".into(), a.dtype().clone(), true)]), + a.len(), vec![a.boxed()], None, ) @@ -260,8 +262,11 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { Some("e"), ]) .boxed(); + + let len = array.len(); new_struct( vec![array], + len, vec!["a".to_string()], Some( [ @@ -322,8 +327,10 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { ) .boxed(); + let len = array.len(); new_struct( vec![array], + len, vec!["a".to_string()], Some( [ @@ -416,7 +423,10 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { let array: ListArray = a.into(); Box::new(array) }, - "struct_list_nullable" => new_struct(vec![values], vec!["a".to_string()], None).boxed(), + "struct_list_nullable" => { + let len = values.len(); + new_struct(vec![values], len, vec!["a".to_string()], None).boxed() + }, _ => { let field = match column { "list_int64" => Field::new("item".into(), ArrowDataType::Int64, true), @@ -809,6 +819,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { distinct_count: new_list( new_struct( vec![UInt64Array::from([None]).boxed()], + 1, vec!["a".to_string()], None, ) @@ -819,6 +830,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { null_count: new_list( new_struct( vec![UInt64Array::from([Some(4)]).boxed()], + 1, vec!["a".to_string()], None, ) @@ -829,6 +841,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { min_value: new_list( new_struct( vec![Utf8ViewArray::from_slice([Some("a")]).boxed()], + 1, vec!["a".to_string()], None, ) @@ -839,6 +852,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { max_value: new_list( new_struct( vec![Utf8ViewArray::from_slice([Some("e")]).boxed()], + 1, vec!["a".to_string()], None, ) @@ -851,6 +865,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { distinct_count: new_list( new_struct( vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) @@ -861,6 +876,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { null_count: new_list( new_struct( vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) @@ -871,6 +887,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { min_value: new_list( new_struct( vec![new_list(Utf8ViewArray::from_slice([Some("a")]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) @@ -881,6 +898,7 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { max_value: new_list( new_struct( vec![new_list(Utf8ViewArray::from_slice([Some("d")]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) @@ -892,24 +910,28 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { "struct_list_nullable" => Statistics { distinct_count: new_struct( vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) .boxed(), null_count: new_struct( vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) .boxed(), min_value: new_struct( vec![new_list(Utf8ViewArray::from_slice([Some("")]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) .boxed(), max_value: new_struct( vec![new_list(Utf8ViewArray::from_slice([Some("ccc")]).boxed(), true).boxed()], + 1, vec!["a".to_string()], None, ) @@ -933,13 +955,13 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { ) }; - let new_struct = |arrays: Vec>, names: Vec| { + let new_struct = |arrays: Vec>, length: usize, names: Vec| { let fields = names .into_iter() .zip(arrays.iter()) .map(|(n, a)| Field::new(n.into(), a.dtype().clone(), true)) .collect(); - StructArray::new(ArrowDataType::Struct(fields), arrays, None) + StructArray::new(ArrowDataType::Struct(fields), length, arrays, None) }; let names = vec!["f1".to_string()]; @@ -960,20 +982,24 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { "struct_list_nullable" => Statistics { distinct_count: new_struct( vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + 1, names.clone(), ) .boxed(), null_count: new_struct( vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + 1, names.clone(), ) .boxed(), min_value: Box::new(new_struct( vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + 1, names.clone(), )), max_value: Box::new(new_struct( vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + 1, names, )), }, @@ -981,6 +1007,7 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { distinct_count: new_list( new_struct( vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + 1, names.clone(), ) .boxed(), @@ -989,6 +1016,7 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { null_count: new_list( new_struct( vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + 1, names.clone(), ) .boxed(), @@ -996,11 +1024,13 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { .boxed(), min_value: new_list(Box::new(new_struct( vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + 1, names.clone(), ))) .boxed(), max_value: new_list(Box::new(new_struct( vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + 1, names, ))) .boxed(), @@ -1045,12 +1075,23 @@ pub fn pyarrow_struct(column: &str) -> Box { Field::new("f2".into(), ArrowDataType::Boolean, true), ]; match column { - "struct" => { - StructArray::new(ArrowDataType::Struct(fields), vec![string, boolean], None).boxed() - }, + "struct" => StructArray::new( + ArrowDataType::Struct(fields), + string.len(), + vec![string, boolean], + None, + ) + .boxed(), "struct_nullable" => { + let len = string.len(); let values = vec![string, boolean]; - StructArray::new(ArrowDataType::Struct(fields), values, Some(mask.into())).boxed() + StructArray::new( + ArrowDataType::Struct(fields), + len, + values, + Some(mask.into()), + ) + .boxed() }, "struct_struct" => { let struct_ = pyarrow_struct("struct"); @@ -1059,6 +1100,7 @@ pub fn pyarrow_struct(column: &str) -> Box { Field::new("f1".into(), ArrowDataType::Struct(fields), true), Field::new("f2".into(), ArrowDataType::Boolean, true), ]), + struct_.len(), vec![struct_, boolean], None, )) @@ -1070,6 +1112,7 @@ pub fn pyarrow_struct(column: &str) -> Box { Field::new("f1".into(), ArrowDataType::Struct(fields), true), Field::new("f2".into(), ArrowDataType::Boolean, true), ]), + struct_.len(), vec![struct_, boolean], Some(mask.into()), )) @@ -1079,8 +1122,9 @@ pub fn pyarrow_struct(column: &str) -> Box { } pub fn pyarrow_struct_statistics(column: &str) -> Statistics { - let new_struct = - |arrays: Vec>, names: Vec| new_struct(arrays, names, None); + let new_struct = |arrays: Vec>, length: usize, names: Vec| { + new_struct(arrays, length, names, None) + }; let names = vec!["f1".to_string(), "f2".to_string()]; @@ -1091,6 +1135,7 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([None])), Box::new(UInt64Array::from([Some(2)])), ], + 1, names.clone(), ) .boxed(), @@ -1099,6 +1144,7 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([Some(4)])), Box::new(UInt64Array::from([Some(4)])), ], + 1, names.clone(), ) .boxed(), @@ -1107,6 +1153,7 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(Utf8ViewArray::from_slice([Some("")])), Box::new(BooleanArray::from_slice([false])), ], + 1, names.clone(), )), max_value: Box::new(new_struct( @@ -1114,6 +1161,7 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(Utf8ViewArray::from_slice([Some("def")])), Box::new(BooleanArray::from_slice([true])), ], + 1, names, )), }, @@ -1125,11 +1173,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([None])), Box::new(UInt64Array::from([Some(2)])), ], + 1, names.clone(), ) .boxed(), UInt64Array::from([None]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1140,11 +1190,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([Some(4)])), Box::new(UInt64Array::from([Some(4)])), ], + 1, names.clone(), ) .boxed(), UInt64Array::from([Some(4)]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1155,11 +1207,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Utf8ViewArray::from_slice([Some("")]).boxed(), BooleanArray::from_slice([false]).boxed(), ], + 1, names.clone(), ) .boxed(), BooleanArray::from_slice([false]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1170,11 +1224,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Utf8ViewArray::from_slice([Some("def")]).boxed(), BooleanArray::from_slice([true]).boxed(), ], + 1, names.clone(), ) .boxed(), BooleanArray::from_slice([true]).boxed(), ], + 1, names, ) .boxed(), @@ -1187,11 +1243,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([None])), Box::new(UInt64Array::from([None])), ], + 1, names.clone(), ) .boxed(), UInt64Array::from([None]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1202,11 +1260,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Box::new(UInt64Array::from([Some(5)])), Box::new(UInt64Array::from([Some(5)])), ], + 1, names.clone(), ) .boxed(), UInt64Array::from([Some(5)]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1217,11 +1277,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Utf8ViewArray::from_slice([Some("")]).boxed(), BooleanArray::from_slice([false]).boxed(), ], + 1, names.clone(), ) .boxed(), BooleanArray::from_slice([false]).boxed(), ], + 1, names.clone(), ) .boxed(), @@ -1232,11 +1294,13 @@ pub fn pyarrow_struct_statistics(column: &str) -> Statistics { Utf8ViewArray::from_slice([Some("def")]).boxed(), BooleanArray::from_slice([true]).boxed(), ], + 1, names.clone(), ) .boxed(), BooleanArray::from_slice([true]).boxed(), ], + 1, names, ) .boxed(), @@ -1357,20 +1421,23 @@ fn generic_data() -> PolarsResult<(ArrowSchema, RecordBatchT>)> { Field::new("a12".into(), array12.dtype().clone(), true), Field::new("a13".into(), array13.dtype().clone(), true), ]); - let chunk = RecordBatchT::try_new(vec![ - array1.boxed(), - array2.boxed(), - array3.boxed(), - array4.boxed(), - array6.boxed(), - array7.boxed(), - array8.boxed(), - array9.boxed(), - array10.boxed(), - array11.boxed(), - array12.boxed(), - array13.boxed(), - ])?; + let chunk = RecordBatchT::try_new( + array1.len(), + vec![ + array1.boxed(), + array2.boxed(), + array3.boxed(), + array4.boxed(), + array6.boxed(), + array7.boxed(), + array8.boxed(), + array9.boxed(), + array10.boxed(), + array11.boxed(), + array12.boxed(), + array13.boxed(), + ], + )?; Ok((schema, chunk)) } @@ -1385,12 +1452,13 @@ fn assert_roundtrip( let (new_schema, new_chunks) = integration_read(&r, limit)?; let expected = if let Some(limit) = limit { + let length = chunk.len().min(limit); let expected = chunk .into_arrays() .into_iter() .map(|x| x.sliced(0, limit)) .collect::>(); - RecordBatchT::new(expected) + RecordBatchT::new(length, expected) } else { chunk }; @@ -1451,7 +1519,7 @@ fn assert_array_roundtrip( ) -> PolarsResult<()> { let schema = ArrowSchema::from_iter([Field::new("a1".into(), array.dtype().clone(), is_nullable)]); - let chunk = RecordBatchT::try_new(vec![array])?; + let chunk = RecordBatchT::try_new(array.len(), vec![array])?; assert_roundtrip(schema, chunk, limit) } @@ -1580,7 +1648,7 @@ fn nested_dict_data( )?; let schema = ArrowSchema::from_iter([Field::new("c1".into(), values.dtype().clone(), true)]); - let chunk = RecordBatchT::try_new(vec![values.boxed()])?; + let chunk = RecordBatchT::try_new(values.len(), vec![values.boxed()])?; Ok((schema, chunk)) } @@ -1608,8 +1676,8 @@ fn nested_dict_limit() -> PolarsResult<()> { #[test] fn filter_chunk() -> PolarsResult<()> { - let chunk1 = RecordBatchT::new(vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]); - let chunk2 = RecordBatchT::new(vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]); + let chunk1 = RecordBatchT::new(2, vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]); + let chunk2 = RecordBatchT::new(2, vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]); let schema = ArrowSchema::from_iter([Field::new("c1".into(), ArrowDataType::Int16, true)]); let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?; diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs index 8863c068baff..9619a083ddcb 100644 --- a/crates/polars/tests/it/io/parquet/arrow/write.rs +++ b/crates/polars/tests/it/io/parquet/arrow/write.rs @@ -50,7 +50,7 @@ fn round_trip_opt_stats( data_page_size: None, }; - let iter = vec![RecordBatchT::try_new(vec![array.clone()])]; + let iter = vec![RecordBatchT::try_new(array.len(), vec![array.clone()])]; let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs index 9b2185656462..188adf8efeba 100644 --- a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs +++ b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs @@ -30,7 +30,7 @@ pub fn read( num_values: usize, _is_sorted: bool, ) -> ParquetResult> { - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let typed_size = num_values.wrapping_mul(size_of); diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs index 36fdb254420a..430df46d1239 100644 --- a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -10,7 +10,7 @@ use super::dictionary::PrimitivePageDict; use super::{hybrid_rle_iter, Array}; fn read_buffer(values: &[u8]) -> impl Iterator + '_ { - let chunks = values.chunks_exact(std::mem::size_of::()); + let chunks = values.chunks_exact(size_of::()); chunks.map(|chunk| { // unwrap is infalible due to the chunk size. let chunk: T::Bytes = match chunk.try_into() { @@ -179,7 +179,7 @@ pub struct DecoderIter<'a, T: Unpackable> { pub(crate) unpacked_end: usize, } -impl<'a, T: Unpackable> Iterator for DecoderIter<'a, T> { +impl Iterator for DecoderIter<'_, T> { type Item = T; fn next(&mut self) -> Option { @@ -203,7 +203,7 @@ impl<'a, T: Unpackable> Iterator for DecoderIter<'a, T> { } } -impl<'a, T: Unpackable> ExactSizeIterator for DecoderIter<'a, T> {} +impl ExactSizeIterator for DecoderIter<'_, T> {} impl<'a, T: Unpackable> DecoderIter<'a, T> { pub fn new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { diff --git a/crates/polars/tests/it/io/parquet/read/row_group.rs b/crates/polars/tests/it/io/parquet/read/row_group.rs index 80478a0da958..8008c594ed19 100644 --- a/crates/polars/tests/it/io/parquet/read/row_group.rs +++ b/crates/polars/tests/it/io/parquet/read/row_group.rs @@ -52,7 +52,8 @@ impl Iterator for RowGroupDeserializer { if self.remaining_rows == 0 { return None; } - let chunk = RecordBatchT::try_new(std::mem::take(&mut self.column_chunks)); + let length = self.column_chunks.first().map_or(0, |chunk| chunk.len()); + let chunk = RecordBatchT::try_new(length, std::mem::take(&mut self.column_chunks)); self.remaining_rows = self.remaining_rows.saturating_sub( chunk .as_ref() diff --git a/crates/polars/tests/it/io/parquet/read/utils.rs b/crates/polars/tests/it/io/parquet/read/utils.rs index 19feaee29534..14b664e0d962 100644 --- a/crates/polars/tests/it/io/parquet/read/utils.rs +++ b/crates/polars/tests/it/io/parquet/read/utils.rs @@ -197,13 +197,11 @@ pub fn native_cast(page: &DataPage) -> ParquetResult> { def: _, values, } = split_buffer(page)?; - if values.len() % std::mem::size_of::() != 0 { + if values.len() % size_of::() != 0 { panic!("A primitive page data's len must be a multiple of the type"); } - Ok(values - .chunks_exact(std::mem::size_of::()) - .map(decode::)) + Ok(values.chunks_exact(size_of::()).map(decode::)) } /// The deserialization state of a `DataPage` of `Primitive` parquet primitive type diff --git a/crates/polars/tests/it/io/parquet/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs index 6e105002fa53..d20551432ec0 100644 --- a/crates/polars/tests/it/io/parquet/roundtrip.rs +++ b/crates/polars/tests/it/io/parquet/roundtrip.rs @@ -28,7 +28,7 @@ fn round_trip( data_page_size: None, }; - let iter = vec![RecordBatchT::try_new(vec![array.clone()])]; + let iter = vec![RecordBatchT::try_new(array.len(), vec![array.clone()])]; let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index 85ded9c742d0..10c386037d17 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -26,7 +26,7 @@ fn test_lazy_agg() { col("rain").min().alias("min"), col("rain").sum().alias("sum"), col("rain") - .quantile(lit(0.5), QuantileInterpolOptions::default()) + .quantile(lit(0.5), QuantileMethod::default()) .alias("median_rain"), ]) .sort(["date"], Default::default()); @@ -35,33 +35,3 @@ fn test_lazy_agg() { let min = new.column("min").unwrap(); assert_eq!(min, &Column::new("min".into(), [0.1f64, 0.01, 0.1])); } - -#[test] -#[should_panic(expected = "hardcoded error")] -/// Test where apply_multiple returns an error -fn test_apply_multiple_error() { - fn issue() -> Expr { - apply_multiple( - move |_| polars_bail!(ComputeError: "hardcoded error"), - &[col("x"), col("y")], - GetOutput::from_type(DataType::Float64), - true, - ) - } - - let df = df![ - "rf" => ["App", "App", "Gg", "App"], - "x" => ["Hey", "There", "Ante", "R"], - "y" => [Some(-1.11), Some(2.),None, Some(3.4)], - "z" => [Some(-1.11), Some(2.),None, Some(3.4)], - ] - .unwrap(); - - let _res = df - .lazy() - .with_streaming(false) - .group_by_stable([col("rf")]) - .agg([issue()]) - .collect() - .unwrap(); -} diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs index f140a0461639..50e5c70be047 100644 --- a/crates/polars/tests/it/lazy/queries.rs +++ b/crates/polars/tests/it/lazy/queries.rs @@ -203,7 +203,7 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { .select([map_multiple( multiply, [col("A"), col("B")], - GetOutput::from_type(DataType::Float64), + GetOutput::from_type(DataType::Int32), )]) .collect()?; let out = out.column("A")?; @@ -219,7 +219,7 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { .agg([apply_multiple( multiply, [col("A"), col("B")], - GetOutput::from_type(DataType::Float64), + GetOutput::from_type(DataType::Int32), true, )]) .collect()?; diff --git a/crates/polars/tests/it/schema.rs b/crates/polars/tests/it/schema.rs index c791367f7546..9a464e27384b 100644 --- a/crates/polars/tests/it/schema.rs +++ b/crates/polars/tests/it/schema.rs @@ -544,3 +544,31 @@ fn test_set_dtype() { ), ); } + +#[test] +fn test_infer_schema() { + use polars_core::frame::row::infer_schema; + use DataType::{Int32, Null, String}; + + // Sample data as a vector of tuples (column name, value) + let data: Vec> = vec![ + vec![(PlSmallStr::from("a"), DataType::String)], + vec![(PlSmallStr::from("b"), DataType::Int32)], + vec![(PlSmallStr::from("c"), DataType::Null)], + ]; + + // Create an iterator over the sample data + let iter = data.into_iter(); + + // Infer the schema + let schema = infer_schema(iter, 3); + + let exp_fields = vec![ + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Null), + ]; + + // Check the inferred schema + assert_eq!(Schema::from_iter(exp_fields.clone()), schema); +} diff --git a/docs/source/_build/API_REFERENCE_LINKS.yml b/docs/source/_build/API_REFERENCE_LINKS.yml index 2a9bc80237dc..1e301f592cb1 100644 --- a/docs/source/_build/API_REFERENCE_LINKS.yml +++ b/docs/source/_build/API_REFERENCE_LINKS.yml @@ -26,6 +26,7 @@ python: is_duplicated: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_duplicated.html sample: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.sample.html head: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.head.html + glimpse: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.glimpse.html tail: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.tail.html describe: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.describe.html col: https://docs.pola.rs/api/python/stable/reference/expressions/col.html @@ -101,6 +102,7 @@ python: name: execute link: https://docs.pola.rs/api/python/stable/reference/sql/api/polars.SQLContext.execute.html join_asof: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_asof.html + join_where: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_where.html concat: https://docs.pola.rs/api/python/stable/reference/api/polars.concat.html pivot: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pivot.html unpivot: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.unpivot.html @@ -179,6 +181,11 @@ rust: link: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by_dynamic feature_flags: [dynamic_group_by] join: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join + join-semi_anti_join_flag: + name: join + link: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join + feature_flags: ["semi_anti_join"] + vstack: https://docs.pola.rs/api/rust/dev/polars_core/frame/struct.DataFrame.html#method.vstack concat: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/functions/fn.concat.html @@ -192,7 +199,18 @@ rust: pivot: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/pivot/fn.pivot.html unpivot: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.unpivot upsample: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.upsample - join_asof: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.AsofJoin.html#method.join_asof + join_asof_by: + name: join_asof_by + link: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.AsofJoinBy.html#method.join_asof_by + feature_flags: ['asof_join'] + join_where: + name: join_where + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.JoinBuilder.html#method.join_where + feature_flags: ["iejoin"] + cross_join: + name: cross_join + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyFrame.html#method.cross_join + feature_flags: [cross_join] unnest: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.unnest read_csv: diff --git a/docs/source/_build/scripts/macro.py b/docs/source/_build/scripts/macro.py index 3b8055074d44..651786b0044b 100644 --- a/docs/source/_build/scripts/macro.py +++ b/docs/source/_build/scripts/macro.py @@ -1,10 +1,12 @@ from collections import OrderedDict import os -from typing import List, Optional, Set +from typing import Any, List, Optional, Set import yaml import logging +from mkdocs_macros.plugin import MacrosPlugin + # Supported Languages and their metadata LANGUAGES = OrderedDict( python={ @@ -130,7 +132,7 @@ def code_tab( """ -def define_env(env): +def define_env(env: MacrosPlugin) -> None: @env.macro def code_header( language: str, section: str = [], api_functions: List[str] = [] @@ -154,7 +156,11 @@ def code_header( @env.macro def code_block( - path: str, section: str = None, api_functions: List[str] = None + path: str, + section: str = None, + api_functions: List[str] = None, + python_api_functions: List[str] = None, + rust_api_functions: List[str] = None, ) -> str: """Dynamically generate a code block for the code located under {language}/path @@ -170,8 +176,14 @@ def code_block( for language, info in LANGUAGES.items(): base_path = f"{language}/{path}{info['extension']}" full_path = "docs/source/src/" + base_path + if language == "python": + extras = python_api_functions or [] + else: + extras = rust_api_functions or [] # Check if file exists for the language if os.path.exists(full_path): - result.append(code_tab(base_path, section, info, api_functions)) + result.append( + code_tab(base_path, section, info, api_functions + extras) + ) return "\n".join(result) diff --git a/docs/source/development/contributing/code-style.md b/docs/source/development/contributing/code-style.md index 00ad8a8f726e..c22b9f3ed7ac 100644 --- a/docs/source/development/contributing/code-style.md +++ b/docs/source/development/contributing/code-style.md @@ -30,7 +30,7 @@ use polars::export::arrow::array::*; use polars::export::arrow::compute::arity::binary; use polars::export::arrow::types::NativeType; use polars::prelude::*; -use polars_core::utils::{align_chunks_binary, combine_validities_or}; +use polars_core::utils::{align_chunks_binary, combine_validities_and}; use polars_core::with_match_physical_numeric_polars_type; // Prefer to do the compute closest to the arrow arrays. @@ -45,7 +45,7 @@ where let validity_1 = arr_1.validity(); let validity_2 = arr_2.validity(); - let validity = combine_validities_or(validity_1, validity_2); + let validity = combine_validities_and(validity_1, validity_2); // process the numerical data as if there were no validities let values_1: &[T] = arr_1.values().as_slice(); diff --git a/docs/source/development/contributing/index.md b/docs/source/development/contributing/index.md index 30fb6ddc0ac9..c3175df9f5b2 100644 --- a/docs/source/development/contributing/index.md +++ b/docs/source/development/contributing/index.md @@ -268,6 +268,13 @@ df = pl.read_parquet("file.parquet") The snippet is delimited by `--8<-- [start:]` and `--8<-- [end:]`. The snippet name must match the name given in the second argument to `code_block` above. +In some cases, you may need to add links to different functions for the Python and Rust APIs. +When that is the case, you can use the two extra optional arguments that `code_block` accepts, that can be used to pass Python-only and Rust-only links: + +``` +{{code_block('path', 'snippet_name', ['common_api_links'], ['python_only_links'], ['rust_only_links'])}} +``` + #### Linting Before committing, install `dprint` (see above) and run `dprint fmt` from the `docs` directory to lint the markdown files. diff --git a/docs/source/src/python/user-guide/concepts/data-structures.py b/docs/source/src/python/user-guide/concepts/data-structures.py deleted file mode 100644 index edc1a2a25c3c..000000000000 --- a/docs/source/src/python/user-guide/concepts/data-structures.py +++ /dev/null @@ -1,42 +0,0 @@ -# --8<-- [start:series] -import polars as pl - -s = pl.Series("a", [1, 2, 3, 4, 5]) -print(s) -# --8<-- [end:series] - -# --8<-- [start:dataframe] -from datetime import datetime - -df = pl.DataFrame( - { - "integer": [1, 2, 3, 4, 5], - "date": [ - datetime(2022, 1, 1), - datetime(2022, 1, 2), - datetime(2022, 1, 3), - datetime(2022, 1, 4), - datetime(2022, 1, 5), - ], - "float": [4.0, 5.0, 6.0, 7.0, 8.0], - } -) - -print(df) -# --8<-- [end:dataframe] - -# --8<-- [start:head] -print(df.head(3)) -# --8<-- [end:head] - -# --8<-- [start:tail] -print(df.tail(3)) -# --8<-- [end:tail] - -# --8<-- [start:sample] -print(df.sample(2)) -# --8<-- [end:sample] - -# --8<-- [start:describe] -print(df.describe()) -# --8<-- [end:describe] diff --git a/docs/source/src/python/user-guide/concepts/data-types-and-structures.py b/docs/source/src/python/user-guide/concepts/data-types-and-structures.py new file mode 100644 index 000000000000..3d08edcbcec9 --- /dev/null +++ b/docs/source/src/python/user-guide/concepts/data-types-and-structures.py @@ -0,0 +1,60 @@ +# --8<-- [start:series] +import polars as pl + +s = pl.Series("ints", [1, 2, 3, 4, 5]) +print(s) +# --8<-- [end:series] + +# --8<-- [start:series-dtype] +s1 = pl.Series("ints", [1, 2, 3, 4, 5]) +s2 = pl.Series("uints", [1, 2, 3, 4, 5], dtype=pl.UInt64) +print(s1.dtype, s2.dtype) +# --8<-- [end:series-dtype] + +# --8<-- [start:df] +from datetime import date + +df = pl.DataFrame( + { + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + date(1997, 1, 10), + date(1985, 2, 15), + date(1983, 3, 22), + date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) + } +) + +print(df) +# --8<-- [end:df] + +# --8<-- [start:schema] +print(df.schema) +# --8<-- [end:schema] + +# --8<-- [start:head] +print(df.head(3)) +# --8<-- [end:head] + +# --8<-- [start:glimpse] +print(df.glimpse(return_as_string=True)) +# --8<-- [end:glimpse] + +# --8<-- [start:tail] +print(df.tail(3)) +# --8<-- [end:tail] + +# --8<-- [start:sample] +import random + +random.seed(42) # For reproducibility. + +print(df.sample(2)) +# --8<-- [end:sample] + +# --8<-- [start:describe] +print(df.describe()) +# --8<-- [end:describe] diff --git a/docs/source/src/python/user-guide/concepts/expressions.py b/docs/source/src/python/user-guide/concepts/expressions.py index c6b477ec7692..7a4cb0637ba3 100644 --- a/docs/source/src/python/user-guide/concepts/expressions.py +++ b/docs/source/src/python/user-guide/concepts/expressions.py @@ -1,16 +1,105 @@ +# --8<-- [start:expression] import polars as pl +pl.col("weight") / (pl.col("height") ** 2) +# --8<-- [end:expression] + +# --8<-- [start:print-expr] +bmi_expr = pl.col("weight") / (pl.col("height") ** 2) +print(bmi_expr) +# --8<-- [end:print-expr] + +# --8<-- [start:df] +from datetime import date + df = pl.DataFrame( { - "foo": [1, 2, 3, None, 5], - "bar": [1.5, 0.9, 2.0, 0.0, None], + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + date(1997, 1, 10), + date(1985, 2, 15), + date(1983, 3, 22), + date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) } ) -# --8<-- [start:example1] -pl.col("foo").sort().head(2) -# --8<-- [end:example1] +print(df) +# --8<-- [end:df] + +# --8<-- [start:select-1] +result = df.select( + bmi=bmi_expr, + avg_bmi=bmi_expr.mean(), + ideal_max_bmi=25, +) +print(result) +# --8<-- [end:select-1] + +# --8<-- [start:select-2] +result = df.select(deviation=(bmi_expr - bmi_expr.mean()) / bmi_expr.std()) +print(result) +# --8<-- [end:select-2] + +# --8<-- [start:with_columns-1] +result = df.with_columns( + bmi=bmi_expr, + avg_bmi=bmi_expr.mean(), + ideal_max_bmi=25, +) +print(result) +# --8<-- [end:with_columns-1] + +# --8<-- [start:filter-1] +result = df.filter( + pl.col("birthdate").is_between(date(1982, 12, 31), date(1996, 1, 1)), + pl.col("height") > 1.7, +) +print(result) +# --8<-- [end:filter-1] + +# --8<-- [start:group_by-1] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), +).agg(pl.col("name")) +print(result) +# --8<-- [end:group_by-1] + +# --8<-- [start:group_by-2] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + (pl.col("height") < 1.7).alias("short?"), +).agg(pl.col("name")) +print(result) +# --8<-- [end:group_by-2] + +# --8<-- [start:group_by-3] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + (pl.col("height") < 1.7).alias("short?"), +).agg( + pl.len(), + pl.col("height").max().alias("tallest"), + pl.col("weight", "height").mean().name.prefix("avg_"), +) +print(result) +# --8<-- [end:group_by-3] -# --8<-- [start:example2] -df.select(pl.col("foo").sort().head(2), pl.col("bar").filter(pl.col("foo") == 1).sum()) -# --8<-- [end:example2] +# --8<-- [start:expression-expansion-1] +expr = (pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +result = df.select(expr) +print(result) +# --8<-- [end:expression-expansion-1] + +# --8<-- [start:expression-expansion-2] +df2 = pl.DataFrame( + { + "ints": [1, 2, 3, 4], + "letters": ["A", "B", "C", "D"], + } +) +result = df2.select(expr) +print(result) +# --8<-- [end:expression-expansion-2] diff --git a/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py b/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py index ebd684cf1a1d..dd48c65b2378 100644 --- a/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py +++ b/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py @@ -1,5 +1,8 @@ +# --8<-- [start:import] import polars as pl +# --8<-- [end:import] + # --8<-- [start:eager] df = pl.read_csv("docs/assets/data/iris.csv") @@ -18,3 +21,25 @@ df = q.collect() # --8<-- [end:lazy] + +# --8<-- [start:explain] +print(q.explain()) +# --8<-- [end:explain] + +# --8<-- [start:explain-expression-expansion] +schema = pl.Schema( + { + "int_1": pl.Int16, + "int_2": pl.Int32, + "float_1": pl.Float64, + "float_2": pl.Float64, + "float_3": pl.Float64, + } +) + +print( + pl.LazyFrame(schema=schema) + .select((pl.col(pl.Float64) * 1.1).name.suffix("*1.1")) + .explain() +) +# --8<-- [end:explain-expression-expansion] diff --git a/docs/source/src/python/user-guide/io/cloud-storage.py b/docs/source/src/python/user-guide/io/cloud-storage.py index 73cf597ec84e..12b02df28e61 100644 --- a/docs/source/src/python/user-guide/io/cloud-storage.py +++ b/docs/source/src/python/user-guide/io/cloud-storage.py @@ -7,7 +7,16 @@ df = pl.read_parquet(source) # --8<-- [end:read_parquet] -# --8<-- [start:scan_parquet] +# --8<-- [start:scan_parquet_query] +import polars as pl + +source = "s3://bucket/*.parquet" + +df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() +# --8<-- [end:scan_parquet_query] + + +# --8<-- [start:scan_parquet_storage_options_aws] import polars as pl source = "s3://bucket/*.parquet" @@ -17,17 +26,42 @@ "aws_secret_access_key": "", "aws_region": "us-east-1", } -df = pl.scan_parquet(source, storage_options=storage_options) -# --8<-- [end:scan_parquet] +df = pl.scan_parquet(source, storage_options=storage_options).collect() +# --8<-- [end:scan_parquet_storage_options_aws] + +# --8<-- [start:credential_provider_class] +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=pl.CredentialProviderAWS( + profile_name="..." + assume_role={ + "RoleArn": f"...", + "RoleSessionName": "...", + } + ), +) -# --8<-- [start:scan_parquet_query] -import polars as pl +df = lf.collect() +# --8<-- [end:credential_provider_class] -source = "s3://bucket/*.parquet" +# --8<-- [start:credential_provider_custom_func] +def get_credentials() -> pl.CredentialProviderFunctionReturn: + expiry = None + return { + "aws_access_key_id": "...", + "aws_secret_access_key": "...", + "aws_session_token": "...", + }, expiry -df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() -# --8<-- [end:scan_parquet_query] + +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=get_credentials, +) + +df = lf.collect() +# --8<-- [end:credential_provider_custom_func] # --8<-- [start:scan_pyarrow_dataset] import polars as pl diff --git a/docs/source/src/python/user-guide/transformations/joins.py b/docs/source/src/python/user-guide/transformations/joins.py index a34ea310e614..e44cbdc560c1 100644 --- a/docs/source/src/python/user-guide/transformations/joins.py +++ b/docs/source/src/python/user-guide/transformations/joins.py @@ -1,117 +1,138 @@ -# --8<-- [start:setup] +# --8<-- [start:prep-data] +import pathlib +import requests + + +DATA = [ + ( + "https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/data/monopoly_props_groups.csv", + "docs/assets/data/monopoly_props_groups.csv", + ), + ( + "https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/data/monopoly_props_prices.csv", + "docs/assets/data/monopoly_props_prices.csv", + ), +] + + +for url, dest in DATA: + if pathlib.Path(dest).exists(): + continue + with open(dest, "wb") as f: + f.write(requests.get(url, timeout=10).content) +# --8<-- [end:prep-data] + +# --8<-- [start:props_groups] import polars as pl -from datetime import datetime - -# --8<-- [end:setup] - -# --8<-- [start:innerdf] -df_customers = pl.DataFrame( - { - "customer_id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - } -) -print(df_customers) -# --8<-- [end:innerdf] - -# --8<-- [start:innerdf2] -df_orders = pl.DataFrame( - { - "order_id": ["a", "b", "c"], - "customer_id": [1, 2, 2], - "amount": [100, 200, 300], - } -) -print(df_orders) -# --8<-- [end:innerdf2] - -# --8<-- [start:inner] -df_inner_customer_join = df_customers.join(df_orders, on="customer_id", how="inner") -print(df_inner_customer_join) -# --8<-- [end:inner] +props_groups = pl.read_csv("docs/assets/data/monopoly_props_groups.csv").head(5) +print(props_groups) +# --8<-- [end:props_groups] -# --8<-- [start:left] -df_left_join = df_customers.join(df_orders, on="customer_id", how="left") -print(df_left_join) -# --8<-- [end:left] +# --8<-- [start:props_prices] +props_prices = pl.read_csv("docs/assets/data/monopoly_props_prices.csv").head(5) +print(props_prices) +# --8<-- [end:props_prices] -# --8<-- [start:right] -df_right_join = df_orders.join(df_customers, on="customer_id", how="right") -print(df_right_join) -# --8<-- [end:right] +# --8<-- [start:equi-join] +result = props_groups.join(props_prices, on="property_name") +print(result) +# --8<-- [end:equi-join] -# --8<-- [start:full] -df_outer_join = df_customers.join(df_orders, on="customer_id", how="full") -print(df_outer_join) -# --8<-- [end:full] - -# --8<-- [start:full_coalesce] -df_outer_coalesce_join = df_customers.join( - df_orders, on="customer_id", how="full", coalesce=True +# --8<-- [start:props_groups2] +props_groups2 = props_groups.with_columns( + pl.col("property_name").str.to_lowercase(), ) -print(df_outer_coalesce_join) -# --8<-- [end:full_coalesce] +print(props_groups2) +# --8<-- [end:props_groups2] -# --8<-- [start:df3] -df_colors = pl.DataFrame( - { - "color": ["red", "blue", "green"], - } +# --8<-- [start:props_prices2] +props_prices2 = props_prices.select( + pl.col("property_name").alias("name"), pl.col("cost") ) -print(df_colors) -# --8<-- [end:df3] - -# --8<-- [start:df4] -df_sizes = pl.DataFrame( - { - "size": ["S", "M", "L"], - } +print(props_prices2) +# --8<-- [end:props_prices2] + +# --8<-- [start:join-key-expression] +result = props_groups2.join( + props_prices2, + left_on="property_name", + right_on=pl.col("name").str.to_lowercase(), +) +print(result) +# --8<-- [end:join-key-expression] + +# --8<-- [start:inner-join] +result = props_groups.join(props_prices, on="property_name", how="inner") +print(result) +# --8<-- [end:inner-join] + +# --8<-- [start:left-join] +result = props_groups.join(props_prices, on="property_name", how="left") +print(result) +# --8<-- [end:left-join] + +# --8<-- [start:right-join] +result = props_groups.join(props_prices, on="property_name", how="right") +print(result) +# --8<-- [end:right-join] + +# --8<-- [start:left-right-join-equals] +print( + result.equals( + props_prices.join( + props_groups, + on="property_name", + how="left", + # Reorder the columns to match the order from above. + ).select(pl.col("group"), pl.col("property_name"), pl.col("cost")) + ) ) -print(df_sizes) -# --8<-- [end:df4] +# --8<-- [end:left-right-join-equals] + +# --8<-- [start:full-join] +result = props_groups.join(props_prices, on="property_name", how="full") +print(result) +# --8<-- [end:full-join] + +# --8<-- [start:full-join-coalesce] +result = props_groups.join( + props_prices, + on="property_name", + how="full", + coalesce=True, +) +print(result) +# --8<-- [end:full-join-coalesce] -# --8<-- [start:cross] -df_cross_join = df_colors.join(df_sizes, how="cross") -print(df_cross_join) -# --8<-- [end:cross] +# --8<-- [start:semi-join] +result = props_groups.join(props_prices, on="property_name", how="semi") +print(result) +# --8<-- [end:semi-join] -# --8<-- [start:df5] -df_cars = pl.DataFrame( - { - "id": ["a", "b", "c"], - "make": ["ford", "toyota", "bmw"], - } -) -print(df_cars) -# --8<-- [end:df5] +# --8<-- [start:anti-join] +result = props_groups.join(props_prices, on="property_name", how="anti") +print(result) +# --8<-- [end:anti-join] -# --8<-- [start:df6] -df_repairs = pl.DataFrame( +# --8<-- [start:players] +players = pl.DataFrame( { - "id": ["c", "c"], - "cost": [100, 200], + "name": ["Alice", "Bob"], + "cash": [78, 135], } ) -print(df_repairs) -# --8<-- [end:df6] - -# --8<-- [start:inner2] -df_inner_join = df_cars.join(df_repairs, on="id", how="inner") -print(df_inner_join) -# --8<-- [end:inner2] +print(players) +# --8<-- [end:players] -# --8<-- [start:semi] -df_semi_join = df_cars.join(df_repairs, on="id", how="semi") -print(df_semi_join) -# --8<-- [end:semi] +# --8<-- [start:non-equi] +result = players.join_where(props_prices, pl.col("cash") > pl.col("cost")) +print(result) +# --8<-- [end:non-equi] -# --8<-- [start:anti] -df_anti_join = df_cars.join(df_repairs, on="id", how="anti") -print(df_anti_join) -# --8<-- [end:anti] +# --8<-- [start:df_trades] +from datetime import datetime -# --8<-- [start:df7] df_trades = pl.DataFrame( { "time": [ @@ -125,9 +146,9 @@ } ) print(df_trades) -# --8<-- [end:df7] +# --8<-- [end:df_trades] -# --8<-- [start:df8] +# --8<-- [start:df_quotes] df_quotes = pl.DataFrame( { "time": [ @@ -142,21 +163,23 @@ ) print(df_quotes) -# --8<-- [end:df8] - -# --8<-- [start:asofpre] -df_trades = df_trades.sort("time") -df_quotes = df_quotes.sort("time") # Set column as sorted -# --8<-- [end:asofpre] +# --8<-- [end:df_quotes] # --8<-- [start:asof] df_asof_join = df_trades.join_asof(df_quotes, on="time", by="stock") print(df_asof_join) # --8<-- [end:asof] -# --8<-- [start:asof2] +# --8<-- [start:asof-tolerance] df_asof_tolerance_join = df_trades.join_asof( df_quotes, on="time", by="stock", tolerance="1m" ) print(df_asof_tolerance_join) -# --8<-- [end:asof2] +# --8<-- [end:asof-tolerance] + +# --8<-- [start:cartesian-product] +tokens = pl.DataFrame({"monopoly_token": ["hat", "shoe", "boat"]}) + +result = players.select(pl.col("name")).join(tokens, how="cross") +print(result) +# --8<-- [end:cartesian-product] diff --git a/docs/source/src/rust/Cargo.toml b/docs/source/src/rust/Cargo.toml index 061c60d02948..8a6607d4aa84 100644 --- a/docs/source/src/rust/Cargo.toml +++ b/docs/source/src/rust/Cargo.toml @@ -31,8 +31,8 @@ path = "user-guide/getting-started.rs" required-features = ["polars/lazy", "polars/temporal", "polars/round_series", "polars/strings"] [[bin]] -name = "user-guide-concepts-data-structures" -path = "user-guide/concepts/data-structures.rs" +name = "user-guide-concepts-data-types-and-structures" +path = "user-guide/concepts/data-types-and-structures.rs" [[bin]] name = "user-guide-concepts-contexts" @@ -41,7 +41,7 @@ required-features = ["polars/lazy"] [[bin]] name = "user-guide-concepts-expressions" path = "user-guide/concepts/expressions.rs" -required-features = ["polars/lazy"] +required-features = ["polars/lazy", "polars/temporal", "polars/is_between"] [[bin]] name = "user-guide-concepts-lazy-vs-eager" path = "user-guide/concepts/lazy-vs-eager.rs" @@ -124,7 +124,7 @@ required-features = ["polars/lazy"] [[bin]] name = "user-guide-transformations-joins" path = "user-guide/transformations/joins.rs" -required-features = ["polars/lazy", "polars/asof_join"] +required-features = ["polars/lazy", "polars/strings", "polars/semi_anti_join", "polars/iejoin", "polars/cross_join"] [[bin]] name = "user-guide-transformations-unpivot" path = "user-guide/transformations/unpivot.rs" diff --git a/docs/source/src/rust/user-guide/concepts/data-structures.rs b/docs/source/src/rust/user-guide/concepts/data-structures.rs deleted file mode 100644 index d3cf7bd33d4f..000000000000 --- a/docs/source/src/rust/user-guide/concepts/data-structures.rs +++ /dev/null @@ -1,51 +0,0 @@ -fn main() { - // --8<-- [start:series] - use polars::prelude::*; - - let s = Series::new("a".into(), &[1, 2, 3, 4, 5]); - - println!("{}", s); - // --8<-- [end:series] - - // --8<-- [start:dataframe] - use chrono::NaiveDate; - - let df: DataFrame = df!( - "integer" => &[1, 2, 3, 4, 5], - "date" => &[ - NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), - NaiveDate::from_ymd_opt(2022, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), - NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), - NaiveDate::from_ymd_opt(2022, 1, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), - NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), - ], - "float" => &[4.0, 5.0, 6.0, 7.0, 8.0] - ) - .unwrap(); - - println!("{}", df); - // --8<-- [end:dataframe] - - // --8<-- [start:head] - let df_head = df.head(Some(3)); - - println!("{}", df_head); - // --8<-- [end:head] - - // --8<-- [start:tail] - let df_tail = df.tail(Some(3)); - - println!("{}", df_tail); - // --8<-- [end:tail] - - // --8<-- [start:sample] - let n = Series::new("".into(), &[2]); - let sampled_df = df.sample_n(&n, false, false, None).unwrap(); - - println!("{}", sampled_df); - // --8<-- [end:sample] - - // --8<-- [start:describe] - // Not available in Rust - // --8<-- [end:describe] -} diff --git a/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs b/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs new file mode 100644 index 000000000000..cc20f35db060 --- /dev/null +++ b/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs @@ -0,0 +1,62 @@ +fn main() { + // --8<-- [start:series] + use polars::prelude::*; + + let s = Series::new("ints".into(), &[1, 2, 3, 4, 5]); + + println!("{}", s); + // --8<-- [end:series] + + // --8<-- [start:series-dtype] + let s1 = Series::new("ints".into(), &[1, 2, 3, 4, 5]); + let s2 = Series::new("uints".into(), &[1, 2, 3, 4, 5]) + .cast(&DataType::UInt64) // Here, we actually cast after inference. + .unwrap(); + println!("{} {}", s1.dtype(), s2.dtype()); // i32 u64 + // --8<-- [end:series-dtype] + + // --8<-- [start:df] + use chrono::prelude::*; + + let df: DataFrame = df!( + "name" => ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1997, 1, 10).unwrap(), + NaiveDate::from_ymd_opt(1985, 2, 15).unwrap(), + NaiveDate::from_ymd_opt(1983, 3, 22).unwrap(), + NaiveDate::from_ymd_opt(1981, 4, 30).unwrap(), + ], + "weight" => [57.9, 72.5, 53.6, 83.1], // (kg) + "height" => [1.56, 1.77, 1.65, 1.75], // (m) + ) + .unwrap(); + println!("{}", df); + // --8<-- [end:df] + + // --8<-- [start:schema] + println!("{:?}", df.schema()); + // --8<-- [end:schema] + + // --8<-- [start:head] + let df_head = df.head(Some(3)); + + println!("{}", df_head); + // --8<-- [end:head] + + // --8<-- [start:tail] + let df_tail = df.tail(Some(3)); + + println!("{}", df_tail); + // --8<-- [end:tail] + + // --8<-- [start:sample] + let n = Series::new("".into(), &[2]); + let sampled_df = df.sample_n(&n, false, false, None).unwrap(); + + println!("{}", sampled_df); + // --8<-- [end:sample] + + // --8<-- [start:describe] + // Not available in Rust + // --8<-- [end:describe] +} diff --git a/docs/source/src/rust/user-guide/concepts/expressions.rs b/docs/source/src/rust/user-guide/concepts/expressions.rs index 2cd69f95b041..c74e9847e3fb 100644 --- a/docs/source/src/rust/user-guide/concepts/expressions.rs +++ b/docs/source/src/rust/user-guide/concepts/expressions.rs @@ -1,24 +1,135 @@ use polars::prelude::*; fn main() -> Result<(), Box> { - let df = df! ( - "foo" => &[Some(1), Some(2), Some(3), None, Some(5)], - "bar" => &[Some("foo"), Some("ham"), Some("spam"), Some("egg"), None], - )?; + // --8<-- [start:df] + use chrono::prelude::*; + use polars::prelude::*; - // --8<-- [start:example1] - let _ = col("foo").sort(Default::default()).head(Some(2)); - // --8<-- [end:example1] + let df: DataFrame = df!( + "name" => ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1997, 1, 10).unwrap(), + NaiveDate::from_ymd_opt(1985, 2, 15).unwrap(), + NaiveDate::from_ymd_opt(1983, 3, 22).unwrap(), + NaiveDate::from_ymd_opt(1981, 4, 30).unwrap(), + ], + "weight" => [57.9, 72.5, 53.6, 83.1], // (kg) + "height" => [1.56, 1.77, 1.65, 1.75], // (m) + ) + .unwrap(); + println!("{}", df); + // --8<-- [end:df] - // --8<-- [start:example2] - df.clone() + // --8<-- [start:select-1] + let bmi = col("weight") / col("height").pow(2); + let result = df + .clone() .lazy() .select([ - col("foo").sort(Default::default()).head(Some(2)), - col("bar").filter(col("foo").eq(lit(1))).sum(), + bmi.clone().alias("bmi"), + bmi.clone().mean().alias("avg_bmi"), + lit(25).alias("ideal_max_bmi"), ]) .collect()?; - // --8<-- [end:example2] + println!("{}", result); + // --8<-- [end:select-1] + + // --8<-- [start:select-2] + let result = df + .clone() + .lazy() + .select([((bmi.clone() - bmi.clone().mean()) / bmi.clone().std(1)).alias("deviation")]) + .collect()?; + println!("{}", result); + // --8<-- [end:select-2] + + // --8<-- [start:with_columns-1] + let result = df + .clone() + .lazy() + .with_columns([ + bmi.clone().alias("bmi"), + bmi.clone().mean().alias("avg_bmi"), + lit(25).alias("ideal_max_bmi"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:with_columns-1] + + // --8<-- [start:filter-1] + let result = df + .clone() + .lazy() + .filter( + col("birthdate") + .is_between( + lit(NaiveDate::from_ymd_opt(1982, 12, 31).unwrap()), + lit(NaiveDate::from_ymd_opt(1996, 1, 1).unwrap()), + ClosedInterval::Both, + ) + .and(col("height").gt(lit(1.7))), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:filter-1] + + // --8<-- [start:group_by-1] + let result = df + .clone() + .lazy() + .group_by([(col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade")]) + .agg([col("name")]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-1] + + // --8<-- [start:group_by-2] + let result = df + .clone() + .lazy() + .group_by([ + (col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade"), + (col("height").lt(lit(1.7)).alias("short?")), + ]) + .agg([col("name")]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-2] + + // --8<-- [start:group_by-3] + let result = df + .clone() + .lazy() + .group_by([ + (col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade"), + (col("height").lt(lit(1.7)).alias("short?")), + ]) + .agg([ + len(), + col("height").max().alias("tallest"), + cols(["weight", "height"]).mean().name().prefix("avg_"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-3] + + // --8<-- [start:expression-expansion-1] + let expr = (dtype_col(&DataType::Float64) * lit(1.1)) + .name() + .suffix("*1.1"); + let result = df.clone().lazy().select([expr.clone()]).collect()?; + println!("{}", result); + // --8<-- [end:expression-expansion-1] + + // --8<-- [start:expression-expansion-2] + let df2: DataFrame = df!( + "ints" => [1, 2, 3, 4], + "letters" => ["A", "B", "C", "D"], + ) + .unwrap(); + let result = df2.clone().lazy().select([expr.clone()]).collect()?; + println!("{}", result); + // --8<-- [end:expression-expansion-2] Ok(()) } diff --git a/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs index cbebb6a46a3f..955111ac2c11 100644 --- a/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs +++ b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs @@ -28,5 +28,15 @@ fn main() -> Result<(), Box> { println!("{}", df); // --8<-- [end:lazy] + // --8<-- [start:explain] + let q = LazyCsvReader::new("docs/assets/data/iris.csv") + .with_has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + println!("{:?}", q.explain(true)); + // --8<-- [end:explain] + Ok(()) } diff --git a/docs/source/src/rust/user-guide/io/cloud-storage.rs b/docs/source/src/rust/user-guide/io/cloud-storage.rs index 5c297739eeee..2df882a39c00 100644 --- a/docs/source/src/rust/user-guide/io/cloud-storage.rs +++ b/docs/source/src/rust/user-guide/io/cloud-storage.rs @@ -1,7 +1,3 @@ -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#![allow(clippy::needless_return)] - // --8<-- [start:read_parquet] use aws_config::BehaviorVersion; use polars::prelude::*; @@ -31,12 +27,18 @@ async fn main() { } // --8<-- [end:read_parquet] -// --8<-- [start:scan_parquet] -// --8<-- [end:scan_parquet] - // --8<-- [start:scan_parquet_query] // --8<-- [end:scan_parquet_query] +// --8<-- [start:scan_parquet_storage_options_aws] +// --8<-- [end:scan_parquet_storage_options_aws] + +// --8<-- [start:credential_provider_class] +// --8<-- [end:credential_provider_class] + +// --8<-- [start:credential_provider_custom_func] +// --8<-- [end:credential_provider_custom_func] + // --8<-- [start:scan_pyarrow_dataset] // --8<-- [end:scan_pyarrow_dataset] diff --git a/docs/source/src/rust/user-guide/transformations/joins.rs b/docs/source/src/rust/user-guide/transformations/joins.rs index 5caa0cc4ac18..5d1c50f733b1 100644 --- a/docs/source/src/rust/user-guide/transformations/joins.rs +++ b/docs/source/src/rust/user-guide/transformations/joins.rs @@ -3,218 +3,252 @@ use polars::prelude::*; // --8<-- [end:setup] fn main() -> Result<(), Box> { - // --8<-- [start:innerdf] - let df_customers = df! ( + // NOTE: This assumes the data has been downloaded and is available. + // See the corresponding Python script for the remote location of the data. - "customer_id" => &[1, 2, 3], - "name" => &["Alice", "Bob", "Charlie"], - )?; - - println!("{}", &df_customers); - // --8<-- [end:innerdf] + // --8<-- [start:props_groups] + let props_groups = CsvReadOptions::default() + .with_has_header(true) + .try_into_reader_with_file_path(Some( + "../../../assets/data/monopoly_props_groups.csv".into(), + ))? + .finish()? + .head(Some(5)); + println!("{}", props_groups); + // --8<-- [end:props_groups] - // --8<-- [start:innerdf2] - let df_orders = df!( - "order_id"=> &["a", "b", "c"], - "customer_id"=> &[1, 2, 2], - "amount"=> &[100, 200, 300], - )?; - println!("{}", &df_orders); - // --8<-- [end:innerdf2] + // --8<-- [start:props_prices] + let props_prices = CsvReadOptions::default() + .with_has_header(true) + .try_into_reader_with_file_path(Some( + "../../../assets/data/monopoly_props_prices.csv".into(), + ))? + .finish()? + .head(Some(5)); + println!("{}", props_prices); + // --8<-- [end:props_prices] - // --8<-- [start:inner] - let df_inner_customer_join = df_customers + // --8<-- [start:equi-join] + // In Rust, we cannot use the shorthand of specifying a common + // column name just once. + let result = props_groups .clone() .lazy() .join( - df_orders.clone().lazy(), - [col("customer_id")], - [col("customer_id")], - JoinArgs::new(JoinType::Inner), + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::default(), ) .collect()?; - println!("{}", &df_inner_customer_join); - // --8<-- [end:inner] + println!("{}", result); + // --8<-- [end:equi-join] - // --8<-- [start:left] - let df_left_join = df_customers + // --8<-- [start:props_groups2] + let props_groups2 = props_groups + .clone() + .lazy() + .with_column(col("property_name").str().to_lowercase()) + .collect()?; + println!("{}", props_groups2); + // --8<-- [end:props_groups2] + + // --8<-- [start:props_prices2] + let props_prices2 = props_prices + .clone() + .lazy() + .select([col("property_name").alias("name"), col("cost")]) + .collect()?; + println!("{}", props_prices2); + // --8<-- [end:props_prices2] + + // --8<-- [start:join-key-expression] + let result = props_groups2 .clone() .lazy() .join( - df_orders.clone().lazy(), - [col("customer_id")], - [col("customer_id")], - JoinArgs::new(JoinType::Left), + props_prices2.clone().lazy(), + [col("property_name")], + [col("name").str().to_lowercase()], + JoinArgs::default(), ) .collect()?; - println!("{}", &df_left_join); - // --8<-- [end:left] + println!("{}", result); + // --8<-- [end:join-key-expression] - // --8<-- [start:right] - let df_right_join = df_orders + // --8<-- [start:inner-join] + let result = props_groups .clone() .lazy() .join( - df_customers.clone().lazy(), - [col("customer_id")], - [col("customer_id")], - JoinArgs::new(JoinType::Right), + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Inner), ) .collect()?; - println!("{}", &df_right_join); - // --8<-- [end:right] + println!("{}", result); + // --8<-- [end:inner-join] - // --8<-- [start:full] - let df_full_join = df_customers + // --8<-- [start:left-join] + let result = props_groups .clone() .lazy() .join( - df_orders.clone().lazy(), - [col("customer_id")], - [col("customer_id")], - JoinArgs::new(JoinType::Full), + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Left), ) .collect()?; - println!("{}", &df_full_join); - // --8<-- [end:full] + println!("{}", result); + // --8<-- [end:left-join] - // --8<-- [start:full_coalesce] - let df_full_join = df_customers + // --8<-- [start:right-join] + let result = props_groups .clone() .lazy() .join( - df_orders.clone().lazy(), - [col("customer_id")], - [col("customer_id")], - JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Right), ) .collect()?; - println!("{}", &df_full_join); - // --8<-- [end:full_coalesce] + println!("{}", result); + // --8<-- [end:right-join] - // --8<-- [start:df3] - let df_colors = df!( - "color"=> &["red", "blue", "green"], - )?; - println!("{}", &df_colors); - // --8<-- [end:df3] - - // --8<-- [start:df4] - let df_sizes = df!( - "size"=> &["S", "M", "L"], - )?; - println!("{}", &df_sizes); - // --8<-- [end:df4] + // --8<-- [start:left-right-join-equals] + // `equals_missing` is needed instead of `equals` + // so that missing values compare as equal. + let dfs_match = result.equals_missing( + &props_prices + .clone() + .lazy() + .join( + props_groups.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Left), + ) + .select([ + // Reorder the columns to match the order of `result`. + col("group"), + col("property_name"), + col("cost"), + ]) + .collect()?, + ); + println!("{}", dfs_match); + // --8<-- [end:left-right-join-equals] - // --8<-- [start:cross] - let df_cross_join = df_colors + // --8<-- [start:full-join] + let result = props_groups .clone() .lazy() - .cross_join(df_sizes.clone().lazy(), None) + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Full), + ) .collect()?; - println!("{}", &df_cross_join); - // --8<-- [end:cross] + println!("{}", result); + // --8<-- [end:full-join] - // --8<-- [start:df5] - let df_cars = df!( - "id"=> &["a", "b", "c"], - "make"=> &["ford", "toyota", "bmw"], - )?; - println!("{}", &df_cars); - // --8<-- [end:df5] - - // --8<-- [start:df6] - let df_repairs = df!( - "id"=> &["c", "c"], - "cost"=> &[100, 200], - )?; - println!("{}", &df_repairs); - // --8<-- [end:df6] - - // --8<-- [start:inner2] - let df_inner_join = df_cars + // --8<-- [start:full-join-coalesce] + let result = props_groups .clone() .lazy() - .inner_join(df_repairs.clone().lazy(), col("id"), col("id")) + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + ) .collect()?; - println!("{}", &df_inner_join); - // --8<-- [end:inner2] + println!("{}", result); + // --8<-- [end:full-join-coalesce] - // --8<-- [start:semi] - let df_semi_join = df_cars + // --8<-- [start:semi-join] + let result = props_groups .clone() .lazy() .join( - df_repairs.clone().lazy(), - [col("id")], - [col("id")], + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], JoinArgs::new(JoinType::Semi), ) .collect()?; - println!("{}", &df_semi_join); - // --8<-- [end:semi] + println!("{}", result); + // --8<-- [end:semi-join] - // --8<-- [start:anti] - let df_anti_join = df_cars + // --8<-- [start:anti-join] + let result = props_groups .clone() .lazy() .join( - df_repairs.clone().lazy(), - [col("id")], - [col("id")], + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], JoinArgs::new(JoinType::Anti), ) .collect()?; - println!("{}", &df_anti_join); - // --8<-- [end:anti] + println!("{}", result); + // --8<-- [end:anti-join] + + // --8<-- [start:players] + let players = df!( + "name" => ["Alice", "Bob"], + "cash" => [78, 135], + )?; + println!("{}", players); + // --8<-- [end:players] + + // --8<-- [start:non-equi] + let result = players + .clone() + .lazy() + .join_builder() + .with(props_prices.clone().lazy()) + .join_where(vec![col("cash").cast(DataType::Int64).gt(col("cost"))]) + .collect()?; + println!("{}", result); + // --8<-- [end:non-equi] - // --8<-- [start:df7] + // --8<-- [start:df_trades] use chrono::prelude::*; + let df_trades = df!( - "time"=> &[ - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 3, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), - ], - "stock"=> &["A", "B", "B", "C"], - "trade"=> &[101, 299, 301, 500], + "time" => [ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 3, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock" => ["A", "B", "B", "C"], + "trade" => [101, 299, 301, 500], )?; - println!("{}", &df_trades); - // --8<-- [end:df7] + println!("{}", df_trades); + // --8<-- [end:df_trades] - // --8<-- [start:df8] + // --8<-- [start:df_quotes] let df_quotes = df!( - "time"=> &[ - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 0, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 2, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 4, 0).unwrap(), - NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), - ], - "stock"=> &["A", "B", "C", "A"], - "quote"=> &[100, 300, 501, 102], + "time" => [ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 2, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 4, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock" => ["A", "B", "C", "A"], + "quote" => [100, 300, 501, 102], )?; - - println!("{}", &df_quotes); - // --8<-- [end:df8] - - // --8<-- [start:asofpre] - let df_trades = df_trades - .sort( - ["time"], - SortMultipleOptions::default().with_maintain_order(true), - ) - .unwrap(); - let df_quotes = df_quotes - .sort( - ["time"], - SortMultipleOptions::default().with_maintain_order(true), - ) - .unwrap(); - // --8<-- [end:asofpre] + println!("{}", df_quotes); + // --8<-- [end:df_quotes] // --8<-- [start:asof] - let df_asof_join = df_trades.join_asof_by( + let result = df_trades.join_asof_by( &df_quotes, "time", "time", @@ -223,11 +257,11 @@ fn main() -> Result<(), Box> { AsofStrategy::Backward, None, )?; - println!("{}", &df_asof_join); + println!("{}", result); // --8<-- [end:asof] - // --8<-- [start:asof2] - let df_asof_tolerance_join = df_trades.join_asof_by( + // --8<-- [start:asof-tolerance] + let result = df_trades.join_asof_by( &df_quotes, "time", "time", @@ -236,8 +270,22 @@ fn main() -> Result<(), Box> { AsofStrategy::Backward, Some(AnyValue::Duration(60000, TimeUnit::Milliseconds)), )?; - println!("{}", &df_asof_tolerance_join); - // --8<-- [end:asof2] + println!("{}", result); + // --8<-- [end:asof-tolerance] + + // --8<-- [start:cartesian-product] + let tokens = df!( + "monopoly_token" => ["hat", "shoe", "boat"], + )?; + + let result = players + .clone() + .lazy() + .select([col("name")]) + .cross_join(tokens.clone().lazy(), None) + .collect()?; + println!("{}", result); + // --8<-- [end:cartesian-product] Ok(()) } diff --git a/docs/source/user-guide/concepts/streaming.md b/docs/source/user-guide/concepts/_streaming.md similarity index 73% rename from docs/source/user-guide/concepts/streaming.md rename to docs/source/user-guide/concepts/_streaming.md index 0dbafec6ec7b..e4427c10481a 100644 --- a/docs/source/user-guide/concepts/streaming.md +++ b/docs/source/user-guide/concepts/_streaming.md @@ -1,6 +1,8 @@ -# Streaming API +# Streaming -One additional benefit of the lazy API is that it allows queries to be executed in a streaming manner. Instead of processing the data all-at-once Polars can execute the query in batches allowing you to process datasets that are larger-than-memory. + + +One additional benefit of the lazy API is that it allows queries to be executed in a streaming manner. Instead of processing all the data at once, Polars can execute the query in batches allowing you to process datasets that do not fit in memory. To tell Polars we want to execute a query in streaming mode we pass the `streaming=True` argument to `collect` @@ -8,18 +10,18 @@ To tell Polars we want to execute a query in streaming mode we pass the `streami ## When is streaming available? -Streaming is still in development. We can ask Polars to execute any lazy query in streaming mode. However, not all lazy operations support streaming. If there is an operation for which streaming is not supported Polars will run the query in non-streaming mode. +Streaming is still in development. We can ask Polars to execute any lazy query in streaming mode. However, not all lazy operations support streaming. If there is an operation for which streaming is not supported, Polars will run the query in non-streaming mode. Streaming is supported for many operations including: -- `filter`,`slice`,`head`,`tail` -- `with_columns`,`select` +- `filter`, `slice`, `head`, `tail` +- `with_columns`, `select` - `group_by` - `join` - `unique` - `sort` -- `explode`,`unpivot` -- `scan_csv`,`scan_parquet`,`scan_ipc` +- `explode`, `unpivot` +- `scan_csv`, `scan_parquet`, `scan_ipc` This list is not exhaustive. Polars is in active development, and more operations can be added without explicit notice. diff --git a/docs/source/user-guide/concepts/contexts.md b/docs/source/user-guide/concepts/contexts.md deleted file mode 100644 index 2b0e004837f3..000000000000 --- a/docs/source/user-guide/concepts/contexts.md +++ /dev/null @@ -1,67 +0,0 @@ -# Contexts - -Polars has developed its own Domain Specific Language (DSL) for transforming data. The language is very easy to use and allows for complex queries that remain human readable. The two core components of the language are Contexts and Expressions, the latter we will cover in the next section. - -A context, as implied by the name, refers to the context in which an expression needs to be evaluated. There are three main contexts [^1]: - -1. Selection: `df.select(...)`, `df.with_columns(...)` -1. Filtering: `df.filter()` -1. Group by / Aggregation: `df.group_by(...).agg(...)` - -The examples below are performed on the following `DataFrame`: - -{{code_block('user-guide/concepts/contexts','dataframe',['DataFrame'])}} - -```python exec="on" result="text" session="user-guide/contexts" ---8<-- "python/user-guide/concepts/contexts.py:setup" ---8<-- "python/user-guide/concepts/contexts.py:dataframe" -``` - -## Selection - -The selection context applies expressions over columns. A `select` may produce new columns that are aggregations, combinations of expressions, or literals. - -The expressions in a selection context must produce `Series` that are all the same length or have a length of 1. Literals are treated as length-1 `Series`. - -When some expressions produce length-1 `Series` and some do not, the length-1 `Series` will be broadcast to match the length of the remaining `Series`. -Note that broadcasting can also occur within expressions: for instance, in `pl.col.value() / pl.col.value.sum()`, each element of the `value` column is divided by the column's sum. - -{{code_block('user-guide/concepts/contexts','select',['select'])}} - -```python exec="on" result="text" session="user-guide/contexts" ---8<-- "python/user-guide/concepts/contexts.py:select" -``` - -As you can see from the query, the selection context is very powerful and allows you to evaluate arbitrary expressions independent of (and in parallel to) each other. - -Similar to the `select` statement, the `with_columns` statement also enters into the selection context. The main difference between `with_columns` and `select` is that `with_columns` retains the original columns and adds new ones, whereas `select` drops the original columns. - -{{code_block('user-guide/concepts/contexts','with_columns',['with_columns'])}} - -```python exec="on" result="text" session="user-guide/contexts" ---8<-- "python/user-guide/concepts/contexts.py:with_columns" -``` - -## Filtering - -The filtering context filters a `DataFrame` based on one or more expressions that evaluate to the `Boolean` data type. - -{{code_block('user-guide/concepts/contexts','filter',['filter'])}} - -```python exec="on" result="text" session="user-guide/contexts" ---8<-- "python/user-guide/concepts/contexts.py:filter" -``` - -## Group by / aggregation - -In the `group_by` context, expressions work on groups and thus may yield results of any length (a group may have many members). - -{{code_block('user-guide/concepts/contexts','group_by',['group_by'])}} - -```python exec="on" result="text" session="user-guide/contexts" ---8<-- "python/user-guide/concepts/contexts.py:group_by" -``` - -As you can see from the result all expressions are applied to the group defined by the `group_by` context. Besides the standard `group_by`, `group_by_dynamic`, and `group_by_rolling` are also entrances to the group by context. - -[^1]: There are additional List and SQL contexts which are covered later in this guide. But for simplicity, we leave them out of scope for now. diff --git a/docs/source/user-guide/concepts/data-structures.md b/docs/source/user-guide/concepts/data-structures.md deleted file mode 100644 index 860ac9da99bb..000000000000 --- a/docs/source/user-guide/concepts/data-structures.md +++ /dev/null @@ -1,68 +0,0 @@ -# Data structures - -The core base data structures provided by Polars are `Series` and `DataFrame`. - -## Series - -Series are a 1-dimensional data structure. Within a series all elements have the same [Data Type](data-types/overview.md) . -The snippet below shows how to create a simple named `Series` object. - -{{code_block('user-guide/concepts/data-structures','series',['Series'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:series" -``` - -## DataFrame - -A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it can be seen as an abstraction of a collection (e.g. list) of `Series`. Operations that can be executed on a `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. - -{{code_block('user-guide/concepts/data-structures','dataframe',['DataFrame'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:dataframe" -``` - -### Viewing data - -This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` from the previous example as a starting point. - -#### Head - -The `head` function shows by default the first 5 rows of a `DataFrame`. You can specify the number of rows you want to see (e.g. `df.head(10)`). - -{{code_block('user-guide/concepts/data-structures','head',['head'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:head" -``` - -#### Tail - -The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`. - -{{code_block('user-guide/concepts/data-structures','tail',['tail'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:tail" -``` - -#### Sample - -If you want to get an impression of the data of your `DataFrame`, you can also use `sample`. With `sample` you get an _n_ number of random rows from the `DataFrame`. - -{{code_block('user-guide/concepts/data-structures','sample',['sample'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:sample" -``` - -#### Describe - -`Describe` returns summary statistics of your `DataFrame`. It will provide several quick statistics if possible. - -{{code_block('user-guide/concepts/data-structures','describe',['describe'])}} - -```python exec="on" result="text" session="user-guide/data-structures" ---8<-- "python/user-guide/concepts/data-structures.py:describe" -``` diff --git a/docs/source/user-guide/concepts/data-types-and-structures.md b/docs/source/user-guide/concepts/data-types-and-structures.md new file mode 100644 index 000000000000..2de8120f05a3 --- /dev/null +++ b/docs/source/user-guide/concepts/data-types-and-structures.md @@ -0,0 +1,176 @@ +# Data types and structures + +## Data types + +Polars supports a variety of data types that fall broadly under the following categories: + +- Numeric data types: signed integers, unsigned integers, floating point numbers, and decimals. +- Nested data types: lists, structs, and arrays. +- Temporal: dates, datetimes, times, and time deltas. +- Miscellaneous: strings, binary data, Booleans, categoricals, enums, and objects. + +All types support missing values represented by the special value `null`. +This is not to be conflated with the special value `NaN` in floating number data types; see the [section about floating point numbers](#floating-point-numbers) for more information. + +You can also find a [full table with all data types supported in the appendix](#appendix-full-data-types-table) with notes on when to use each data type and with links to relevant parts of the documentation. + +## Series + +The core base data structures provided by Polars are series and dataframes. +A series is a 1-dimensional homogeneous data structure. +By “homogeneous” we mean that all elements inside a series have the same data type. +The snippet below shows how to create a named series: + +{{code_block('user-guide/concepts/data-types-and-structures','series',['Series'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:series" +``` + +When creating a series, Polars will infer the data type from the values you provide. +You can specify a concrete data type to override the inference mechanism: + +{{code_block('user-guide/concepts/data-types-and-structures','series-dtype',['Series'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:series-dtype" +``` + +## Dataframe + +A dataframe is a 2-dimensional heterogeneous data structure that contains uniquely named series. +By holding your data in a dataframe you will be able to use the Polars API to write queries that manipulate your data. +You will be able to do this by using the [contexts and expressions provided by Polars](expressions-and-contexts.md) that we will talk about next. + +The snippet below shows how to create a dataframe from a dictionary of lists: + +{{code_block('user-guide/concepts/data-types-and-structures','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:df" +``` + +### Inspecting a dataframe + +In this subsection we will show some useful methods to quickly inspect a dataframe. +We will use the dataframe we created earlier as a starting point. + +#### Head + +The function `head` shows the first rows of a dataframe. +By default, you get the first 5 rows but you can also specify the number of rows you want: + +{{code_block('user-guide/concepts/data-types-and-structures','head',['head'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:head" +``` + +#### Glimpse + +The function `glimpse` is another function that shows the values of the first few rows of a dataframe, but formats the output differently from `head`. +Here, each line of the output corresponds to a single column, making it easier to take inspect wider dataframes: + +=== ":fontawesome-brands-python: Python" +[:material-api: `glimpse`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.glimpse.html) + +```python +--8<-- "python/user-guide/concepts/data-types-and-structures.py:glimpse" +``` + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:glimpse" +``` + +!!! info +`glimpse` is only available for Python users. + +#### Tail + +The function `tail` shows the last rows of a dataframe. +By default, you get the last 5 rows but you can also specify the number of rows you want, similar to how `head` works: + +{{code_block('user-guide/concepts/data-types-and-structures','tail',['tail'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:tail" +``` + +#### Sample + +If you think the first or last rows of your dataframe are not representative of your data, you can use `sample` to get an arbitrary number of randomly selected rows from the DataFrame. +Note that the rows are not necessarily returned in the same order as they appear in the dataframe: + +{{code_block('user-guide/concepts/data-types-and-structures','sample',['sample'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:sample" +``` + +#### Describe + +You can also use `describe` to compute summary statistics for all columns of your dataframe: + +{{code_block('user-guide/concepts/data-types-and-structures','describe',['describe'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:describe" +``` + +## Schema + +When talking about data (in a dataframe or otherwise) we can refer to its schema. +The schema is a mapping of column or series names to the data types of those same columns or series. + +Much like with series, Polars will infer the schema of a dataframe when you create it but you can override the inference system if needed. +You can check the schema of a dataframe with `schema`: + +{{code_block('user-guide/concepts/data-types-and-structures','schema',[])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema" +``` + +## Data types internals + +Polars utilizes the [Arrow Columnar Format](https://arrow.apache.org/docs/format/Columnar.html) for its data orientation. +Following this specification allows Polars to transfer data to/from other tools that also use the Arrow specification with little to no overhead. + +Polars gets most of its performance from its query engine, the optimizations it performs on your query plans, and from the parallelization that it employs when running [your expressions](expressions-and-contexts.md#expressions). + +## Floating point numbers + +Polars generally follows the IEEE 754 floating point standard for `Float32` and `Float64`, with some exceptions: + +- Any `NaN` compares equal to any other `NaN`, and greater than any non-`NaN` value. +- Operations do not guarantee any particular behavior on the sign of zero or `NaN`, + nor on the payload of `NaN` values. This is not just limited to arithmetic operations, + e.g. a sort or group by operation may canonicalize all zeroes to +0 and all `NaN`s + to a positive `NaN` without payload for efficient equality checks. + +Polars always attempts to provide reasonably accurate results for floating point computations but does not provide guarantees +on the error unless mentioned otherwise. Generally speaking 100% accurate results are infeasibly expensive to achieve (requiring +much larger internal representations than 64-bit floats), and thus some error is always to be expected. + +## Appendix: full data types table + +| Type(s) | Details | +| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `Boolean` | Boolean type that is bit packed efficiently. | +| `Int8`, `Int16`, `Int32`, `Int64` | Varying-precision signed integer types. | +| `UInt8`, `UInt16`, `UInt32`, `UInt64` | Varying-precision unsigned integer types. | +| `Float32`, `Float64` | Varying-precision signed floating point numbers. | +| `Decimal` | Decimal 128-bit type with optional precision and non-negative scale. Use this if you need fine-grained control over the precision of your floats and the operations you make on them. See [Python's `decimal.Decimal`](https://docs.python.org/3/library/decimal.html) for documentation on what a decimal data type is. | +| `String` | Variable length UTF-8 encoded string data, typically Human-readable. | +| `Binary` | Stores arbitrary, varying length raw binary data. | +| `Date` | Represents a calendar date. | +| `Time` | Represents a time of day. | +| `Datetime` | Represents a calendar date and time of day. | +| `Duration` | Represents a time duration. | +| `Array` | Arrays with a known, fixed shape per series; akin to numpy arrays. [Learn more about how arrays and lists differ and how to work with both](../expressions/lists.md). | +| `List` | Homogeneous 1D container with variable length. [Learn more about how arrays and lists differ and how to work with both](../expressions/lists.md). | +| `Object` | Wraps arbitrary Python objects. | +| `Categorical` | Efficient encoding of string data where the categories are inferred at runtime. [Learn more about how categoricals and enums differ and how to work with both](../expressions/categorical-data-and-enums.md). | +| `Enum` | Efficient ordered encoding of a set of predetermined string categories. [Learn more about how categoricals and enums differ and how to work with both](../expressions/categorical-data-and-enums.md). | +| `Struct` | Composite product type that can store multiple fields. [Learn more about the data type `Struct` in its dedicated documentation section.](../expressions/structs.md). | +| `Null` | Represents null values. | diff --git a/docs/source/user-guide/concepts/data-types/overview.md b/docs/source/user-guide/concepts/data-types/overview.md deleted file mode 100644 index fe7b75f99246..000000000000 --- a/docs/source/user-guide/concepts/data-types/overview.md +++ /dev/null @@ -1,47 +0,0 @@ -# Overview - -Polars is entirely based on Arrow data types and backed by Arrow memory arrays. This makes data processing -cache-efficient and well-supported for Inter Process Communication. Most data types follow the exact implementation -from Arrow, with the exception of `String` (this is actually `LargeUtf8`), `Categorical`, and `Object` (support is limited). The data types are: - -| Group | Type | Details | -| -------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------ | -| Numeric | `Int8` | 8-bit signed integer. | -| | `Int16` | 16-bit signed integer. | -| | `Int32` | 32-bit signed integer. | -| | `Int64` | 64-bit signed integer. | -| | `UInt8` | 8-bit unsigned integer. | -| | `UInt16` | 16-bit unsigned integer. | -| | `UInt32` | 32-bit unsigned integer. | -| | `UInt64` | 64-bit unsigned integer. | -| | `Float32` | 32-bit floating point. | -| | `Float64` | 64-bit floating point. | -| Nested | `Struct` | A struct array is represented as a `Vec` and is useful to pack multiple/heterogeneous values in a single column. | -| | `List` | A list array contains a child array containing the list values and an offset array. (this is actually Arrow `LargeList` internally). | -| | `Array` | A fixed-size multidimensional array. | -| Temporal | `Date` | Date representation, internally represented as days since UNIX epoch encoded by a 32-bit signed integer. | -| | `Datetime` | Datetime representation, internally represented as microseconds since UNIX epoch encoded by a 64-bit signed integer. | -| | `Duration` | A timedelta type, internally represented as microseconds. Created when subtracting `Date/Datetime`. | -| | `Time` | Time representation, internally represented as nanoseconds since midnight. | -| Other | `Boolean` | Boolean type effectively bit packed. | -| | `String` | String data (this is actually Arrow `LargeUtf8` internally). | -| | `Binary` | Store data as bytes. | -| | `Object` | A limited supported data type that can be any value. | -| | `Categorical` | A categorical encoding of a set of strings. | -| | `Enum` | A fixed categorical encoding of a set of strings. | - -To learn more about the internal representation of these data types, check the [Arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html). - -## Floating Point - -Polars generally follows the IEEE 754 floating point standard for `Float32` and `Float64`, with some exceptions: - -- Any NaN compares equal to any other NaN, and greater than any non-NaN value. -- Operations do not guarantee any particular behavior on the sign of zero or NaN, - nor on the payload of NaN values. This is not just limited to arithmetic operations, - e.g. a sort or group by operation may canonicalize all zeroes to +0 and all NaNs - to a positive NaN without payload for efficient equality checks. - -Polars always attempts to provide reasonably accurate results for floating point computations but does not provide guarantees -on the error unless mentioned otherwise. Generally speaking 100% accurate results are infeasibly expensive to acquire (requiring -much larger internal representations than 64-bit floats), and thus some error is always to be expected. diff --git a/docs/source/user-guide/concepts/expressions-and-contexts.md b/docs/source/user-guide/concepts/expressions-and-contexts.md new file mode 100644 index 000000000000..4ec537b71fb9 --- /dev/null +++ b/docs/source/user-guide/concepts/expressions-and-contexts.md @@ -0,0 +1,204 @@ +# Expressions and contexts + +Polars has developed its own Domain Specific Language (DSL) for transforming data. +The language is very easy to use and allows for complex queries that remain human readable. +Expressions and contexts, which will be introduced here, are very important in achieving this readability while also allowing the Polars query engine to optimize your queries to make them run as fast as possible. + +## Expressions + +In Polars, an _expression_ is a lazy representation of a data transformation. +Expressions are modular and flexible, which means you can use them as building blocks to build more complex expressions. +Here is an example of a Polars expression: + +```python +--8<-- "python/user-guide/concepts/expressions.py:expression" +``` + +As you might be able to guess, this expression takes a column named “weight” and divides its values by the square of the values in a column “height”, computing a person's BMI. + +The code above expresses an abstract computation that we can save in a variable, manipulate further, or just print: + +```python +--8<-- "python/user-guide/concepts/expressions.py:print-expr" +``` + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression" +--8<-- "python/user-guide/concepts/expressions.py:print-expr" +``` + +Because expressions are lazy, no computations have taken place yet. +That's what we need contexts for. + +## Contexts + +Polars expressions need a _context_ in which they are executed to produce a result. +Depending on the context it is used in, the same Polars expression can produce different results. +In this section, we will learn about the four most common contexts that Polars provides[^1]: + +1. `select` +2. `with_columns` +3. `filter` +4. `group_by` + +We use the dataframe below to show how each of the contexts works. + +{{code_block('user-guide/concepts/expressions','df',[])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:df" +``` + +### `select` + +The selection context `select` applies expressions over columns. +The context `select` may produce new columns that are aggregations, combinations of other columns, or literals: + +{{code_block('user-guide/concepts/expressions','select-1',['select'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:select-1" +``` + +The expressions in a context `select` must produce series that are all the same length or they must produce a scalar. +Scalars will be broadcast to match the length of the remaining series. +Literals, like the number used above, are also broadcast. + +Note that broadcasting can also occur within expressions. +For instance, consider the expression below: + +{{code_block('user-guide/concepts/expressions','select-2',['select'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:select-2" +``` + +Both the subtraction and the division use broadcasting within the expression because the subexpressions that compute the mean and the standard deviation evaluate to single values. + +The context `select` is very flexible and powerful and allows you to evaluate arbitrary expressions independent of, and in parallel to, each other. +This is also true of the other contexts that we will see next. + +### `with_columns` + +The context `with_columns` is very similar to the context `select`. +The main difference between the two is that the context `with_columns` creates a new dataframe that contains the columns from the original dataframe and the new columns according to its input expressions, whereas the context `select` only includes the columns selected by its input expressions: + +{{code_block('user-guide/concepts/expressions','with_columns-1',['with_columns'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:with_columns-1" +``` + +Because of this difference between `select` and `with_columns`, the expressions used in a context `with_columns` must produce series that have the same length as the original columns in the dataframe, whereas it is enough for the expressions in the context `select` to produce series that have the same length among them. + +### `filter` + +The context `filter` filters the rows of a dataframe based on one or more expressions that evaluate to the Boolean data type. + +{{code_block('user-guide/concepts/expressions','filter-1',['filter'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:filter-1" +``` + +### `group_by` and aggregations + +In the context `group_by`, rows are grouped according to the unique values of the grouping expressions. +You can then apply expressions to the resulting groups, which may be of variable lengths. + +When using the context `group_by`, you can use an expression to compute the groupings dynamically: + +{{code_block('user-guide/concepts/expressions','group_by-1',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-1" +``` + +After using `group_by` we use `agg` to apply aggregating expressions to the groups. +Since in the example above we only specified the name of a column, we get the groups of that column as lists. + +We can specify as many grouping expressions as we'd like and the context `group_by` will group the rows according to the distinct values across the expressions specified. +Here, we group by a combination of decade of birth and whether the person is shorter than 1.7 metres: + +{{code_block('user-guide/concepts/expressions','group_by-2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-2" +``` + +The resulting dataframe, after applying aggregating expressions, contains one column per each grouping expression on the left and then as many columns as needed to represent the results of the aggregating expressions. +In turn, we can specify as many aggregating expressions as we want: + +{{code_block('user-guide/concepts/expressions','group_by-3',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-3" +``` + +See also `group_by_dynamic` and `group_by_rolling` for other grouping contexts. + +## Expression expansion + +The last example contained two grouping expressions and three aggregating expressions, and yet the resulting dataframe contained six columns instead of five. +If we look closely, the last aggregating expression mentioned two different columns: “weight” and “height”. + +Polars expressions support a feature called _expression expansion_. +Expression expansion is like a shorthand notation for when you want to apply the same transform to multiple columns. +As we have seen, the expression + +```python +pl.col("weight", "height").mean().name.prefix("avg_") +``` + +will compute the mean value of the columns “weight” and “height” and will rename them as “avg_weight” and “avg_height”, respectively. +In fact, the expression above is equivalent to using the two following expressions: + +```python +[ + pl.col("weight").mean().alias("avg_weight"), + pl.col("height").mean().alias("avg_height"), +] +``` + +In this case, this expression expands into two independent expressions that Polars can execute in parallel. +In other cases, we may not be able to know in advance how many independent expressions an expression will unfold into. + +Consider this simple but elucidative example: + +```python +(pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +``` + +This expression will multiply all columns with data type `Float64` by `1.1`. +The number of columns this applies to depends on the schema of each dataframe. +In the case of the dataframe we have been using, it applies to two columns: + +{{code_block('user-guide/concepts/expressions','expression-expansion-1',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression-expansion-1" +``` + +In the case of the dataframe `df2` below, the same expression expands to 0 columns because no column has the data type `Float64`: + +{{code_block('user-guide/concepts/expressions','expression-expansion-2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression-expansion-2" +``` + +It is equally easy to imagine a scenario where the same expression would expand to dozens of columns. + +Next, you will learn about [the lazy API and the function `explain`](lazy-api.md#previewing-the-query-plan), which you can use to preview what an expression will expand to given a schema. + +## Conclusion + +Because expressions are lazy, when you use an expression inside a context Polars can try to simplify your expression before running the data transformation it expresses. +Separate expressions within a context are embarrassingly parallel and Polars will take advantage of that, while also parallelizing expression execution when using expression expansion. +Further performance gains can be obtained when using [the lazy API of Polars](lazy-api.md), which is introduced next. + +We have only scratched the surface of the capabilities of expressions. +There are a ton more expressions and they can be combined in a variety of ways. +See the [section on expressions](../expressions/index.md) for a deeper dive on the different types of expressions available. + +[^1]: There are additional List and SQL contexts which are covered later in this guide. But for simplicity, we leave them out of scope for now. diff --git a/docs/source/user-guide/concepts/expressions.md b/docs/source/user-guide/concepts/expressions.md deleted file mode 100644 index 9b857bb90aa0..000000000000 --- a/docs/source/user-guide/concepts/expressions.md +++ /dev/null @@ -1,55 +0,0 @@ -# Expressions - -Polars has a powerful concept called expressions that is central to its very fast performance. - -Expressions are at the core of many data science operations: - -- taking a sample of rows from a column -- multiplying values in a column -- extracting a column of years from dates -- convert a column of strings to lowercase -- and so on! - -However, expressions are also used within other operations: - -- taking the mean of a group in a `group_by` operation -- calculating the size of groups in a `group_by` operation -- taking the sum horizontally across columns - -Polars performs these core data transformations very quickly by: - -- automatic query optimization on each expression -- automatic parallelization of expressions on many columns - -An expression is a tree of operations that describe how to construct one or more -Series. As the outputs are Series, it is straightforward to -apply a sequence of expressions (similar to method chaining in pandas) each of which -transforms the output from the previous step. - -If this seems abstract and confusing - don't worry! People quickly develop an intuition for expressions -just by looking at a few examples. We'll do that next! - -## Examples - -The following is an expression: - -{{code_block('user-guide/concepts/expressions','example1',['col','sort','head'])}} - -The snippet above says: - -1. Select column "foo" -1. Then sort the column (not in reversed order) -1. Then take the first two values of the sorted output - -The power of expressions is that every expression produces a new expression, and that they -can be _piped_ together. You can run an expression by passing them to one of Polars execution contexts. - -Here we run two expressions by running `df.select`: - -{{code_block('user-guide/concepts/expressions','example2',['select'])}} - -All expressions are run in parallel, meaning that separate Polars expressions are **embarrassingly parallel**. Note that within an expression there may be more parallelization going on. - -## Conclusion - -This is the tip of the iceberg in terms of possible expressions. There are a ton more, and they can be combined in a variety of ways. This page is intended to get you familiar with the concept of expressions, in the section on [expressions](../expressions/operators.md) we will dive deeper. diff --git a/docs/source/user-guide/concepts/index.md b/docs/source/user-guide/concepts/index.md index 63a2ebeabe44..c4b28e50721f 100644 --- a/docs/source/user-guide/concepts/index.md +++ b/docs/source/user-guide/concepts/index.md @@ -1,11 +1,7 @@ # Concepts -The `Concepts` chapter describes the core concepts of the Polars API. Understanding these will help you optimise your queries on a daily basis. We will cover the following topics: +This chapter describes the core concepts of the Polars API. Understanding these will help you optimise your queries on a daily basis. We will cover the following topics: -- [Data Types: Overview](data-types/overview.md) -- [Data Types: Categoricals](data-types/categoricals.md) -- [Data structures](data-structures.md) -- [Contexts](contexts.md) -- [Expressions](expressions.md) -- [Lazy vs eager](lazy-vs-eager.md) -- [Streaming](streaming.md) +- [Data types and structures](data-types-and-structures.md) +- [Expressions and contexts](expressions-and-contexts.md) +- [Lazy API](lazy-api.md) diff --git a/docs/source/user-guide/concepts/lazy-api.md b/docs/source/user-guide/concepts/lazy-api.md new file mode 100644 index 000000000000..85b985e1a74c --- /dev/null +++ b/docs/source/user-guide/concepts/lazy-api.md @@ -0,0 +1,65 @@ +# Lazy API + +Polars supports two modes of operation: lazy and eager. The examples so far have used the eager API, in which the query is executed immediately. +In the lazy API, the query is only evaluated once it is _collected_. Deferring the execution to the last minute can have significant performance advantages and is why the lazy API is preferred in most cases. Let us demonstrate this with an example: + +{{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} + +In this example we use the eager API to: + +1. Read the iris [dataset](https://archive.ics.uci.edu/dataset/53/iris). +1. Filter the dataset based on sepal length. +1. Calculate the mean of the sepal width per species. + +Every step is executed immediately returning the intermediate results. This can be very wasteful as we might do work or load extra data that is not being used. If we instead used the lazy API and waited on execution until all the steps are defined then the query planner could perform various optimizations. In this case: + +- Predicate pushdown: Apply filters as early as possible while reading the dataset, thus only reading rows with sepal length greater than 5. +- Projection pushdown: Select only the columns that are needed while reading the dataset, thus removing the need to load additional columns (e.g., petal length and petal width). + +{{code_block('user-guide/concepts/lazy-vs-eager','lazy',['scan_csv'])}} + +These will significantly lower the load on memory & CPU thus allowing you to fit bigger datasets in memory and process them faster. Once the query is defined you call `collect` to inform Polars that you want to execute it. You can [learn more about the lazy API in its dedicated chapter](../lazy/index.md). + +!!! info "Eager API" + + In many cases the eager API is actually calling the lazy API under the hood and immediately collecting the result. This has the benefit that within the query itself optimization(s) made by the query planner can still take place. + +## When to use which + +In general, the lazy API should be preferred unless you are either interested in the intermediate results or are doing exploratory work and don't know yet what your query is going to look like. + +## Previewing the query plan + +When using the lazy API you can use the function `explain` to ask Polars to create a description of the query plan that will be executed once you collect the results. +This can be useful if you want to see what types of optimizations Polars performs on your queries. +We can ask Polars to explain the query `q` we defined above: + +{{code_block('user-guide/concepts/lazy-vs-eager','explain',['explain'])}} + +```python exec="on" result="text" session="user-guide/concepts/lazy-api" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:import" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:lazy" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain" +``` + +Immediately, we can see in the explanation that Polars did apply predicate pushdown, as it is only reading rows where the sepal length is greater than 5, and it did apply projection pushdown, as it is only reading the columns that are needed by the query. + +The function `explain` can also be used to see how expression expansion will unfold in the context of a given schema. +Consider the example expression from the [section on expression expansion](expressions-and-contexts.md#expression-expansion): + +```python +(pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +``` + +We can use `explain` to see how this expression would evaluate against an arbitrary schema: + +=== ":fontawesome-brands-python: Python" +[:material-api: `explain`](https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.explain.html) + +```python +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain-expression-expansion" +``` + +```python exec="on" result="text" session="user-guide/concepts/lazy-api" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain-expression-expansion" +``` diff --git a/docs/source/user-guide/concepts/lazy-vs-eager.md b/docs/source/user-guide/concepts/lazy-vs-eager.md deleted file mode 100644 index 4822f81a5d1d..000000000000 --- a/docs/source/user-guide/concepts/lazy-vs-eager.md +++ /dev/null @@ -1,28 +0,0 @@ -# Lazy / eager API - -Polars supports two modes of operation: lazy and eager. In the eager API the query is executed immediately while in the lazy API the query is only evaluated once it is 'needed'. Deferring the execution to the last minute can have significant performance advantages and is why the Lazy API is preferred in most cases. Let us demonstrate this with an example: - -{{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} - -In this example we use the eager API to: - -1. Read the iris [dataset](https://archive.ics.uci.edu/dataset/53/iris). -1. Filter the dataset based on sepal length -1. Calculate the mean of the sepal width per species - -Every step is executed immediately returning the intermediate results. This can be very wasteful as we might do work or load extra data that is not being used. If we instead used the lazy API and waited on execution until all the steps are defined then the query planner could perform various optimizations. In this case: - -- Predicate pushdown: Apply filters as early as possible while reading the dataset, thus only reading rows with sepal length greater than 5. -- Projection pushdown: Select only the columns that are needed while reading the dataset, thus removing the need to load additional columns (e.g. petal length & petal width) - -{{code_block('user-guide/concepts/lazy-vs-eager','lazy',['scan_csv'])}} - -These will significantly lower the load on memory & CPU thus allowing you to fit bigger datasets in memory and process faster. Once the query is defined you call `collect` to inform Polars that you want to execute it. In the section on Lazy API we will go into more details on its implementation. - -!!! info "Eager API" - - In many cases the eager API is actually calling the lazy API under the hood and immediately collecting the result. This has the benefit that within the query itself optimization(s) made by the query planner can still take place. - -### When to use which - -In general the lazy API should be preferred unless you are either interested in the intermediate results or are doing exploratory work and don't know yet what your query is going to look like. diff --git a/docs/source/user-guide/ecosystem.md b/docs/source/user-guide/ecosystem.md index 21f1dbc2ba60..9a8f96c4f72f 100644 --- a/docs/source/user-guide/ecosystem.md +++ b/docs/source/user-guide/ecosystem.md @@ -20,25 +20,7 @@ On this page you can find a non-exhaustive list of libraries and tools that supp ### Data visualisation -#### hvPlot - -[hvPlot](https://hvplot.holoviz.org/) is available as the default plotting backend for Polars making it simple to create interactive and static visualisations. You can use hvPlot by using the feature flag `plot` during installing. - -```python -pip install 'polars[plot]' -``` - -#### Matplotlib - -[Matplotlib](https://matplotlib.org/) is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible. - -#### Plotly - -[Plotly](https://plotly.com/python/) is an interactive, open-source, and browser-based graphing library for Python. Built on top of plotly.js, it ships with over 30 chart types, including scientific charts, 3D graphs, statistical charts, SVG maps, financial charts, and more. - -#### [Seaborn](https://seaborn.pydata.org/) - -Seaborn is a Python data visualization library based on Matplotlib. It provides a high-level interface for drawing attractive and informative statistical graphics. +See the [dedicated visualization section](misc/visualization.md). ### IO @@ -71,3 +53,7 @@ With [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html #### Mage [Mage](https://www.mage.ai) is an open-source data pipeline tool for transforming and integrating data. Learn about integration between Polars and Mage at [docs.mage.ai](https://docs.mage.ai/integrations/polars). + +#### marimo + +[marimo](https://marimo.io) is a reactive notebook for Python and SQL that models notebooks as dataflow graphs. It offers built-in support for Polars, allowing seamless integration of Polars dataframes in an interactive, reactive environment - such as displaying rich Polars tables, no-code transformations of Polars dataframes, or selecting points on a Polars-backed reactive chart. diff --git a/docs/source/user-guide/expressions/casting.md b/docs/source/user-guide/expressions/casting.md index 6deddaecb684..f0c625d19f28 100644 --- a/docs/source/user-guide/expressions/casting.md +++ b/docs/source/user-guide/expressions/casting.md @@ -1,6 +1,7 @@ # Casting -Casting converts the underlying [`DataType`](../concepts/data-types/overview.md) of a column to a new one. Polars uses Arrow to manage the data in memory and relies on the compute kernels in the [Rust implementation](https://github.com/jorgecarleitao/arrow2) to do the conversion. Casting is available with the `cast()` method. +Casting converts the [underlying `DataType` of a column](../concepts/data-types-and-structures.md) to a new one. +Casting is available with the `cast()` method. The `cast` method includes a `strict` parameter that determines how Polars behaves when it encounters a value that can't be converted from the source `DataType` to the target `DataType`. By default, `strict=True`, which means that Polars will throw an error to notify the user of the failed conversion and provide details on the values that couldn't be cast. On the other hand, if `strict=False`, any values that can't be converted to the target `DataType` will be quietly converted to `null`. diff --git a/docs/source/user-guide/concepts/data-types/categoricals.md b/docs/source/user-guide/expressions/categorical-data-and-enums.md similarity index 95% rename from docs/source/user-guide/concepts/data-types/categoricals.md rename to docs/source/user-guide/expressions/categorical-data-and-enums.md index e5b469f45ad6..9d1e5eee8905 100644 --- a/docs/source/user-guide/concepts/data-types/categoricals.md +++ b/docs/source/user-guide/expressions/categorical-data-and-enums.md @@ -1,4 +1,4 @@ -# Categorical data +# Categorical data and enums Categorical data represents string data where the values in the column have a finite set of values (usually way smaller than the length of the column). You can think about columns on gender, countries, currency pairings, etc. Storing these values as plain strings is a waste of memory and performance as we will be repeating the same string over and over again. Additionally, in the case of joins we are stuck with expensive string comparisons. @@ -106,7 +106,7 @@ In Polars a categorical is defined as a string column which is encoded by a dict The physical `0` in this case encodes (or maps) to the value 'Polar Bear', the value `1` encodes to 'Panda Bear' and the value `2` to 'Brown Bear'. This encoding has the benefit of only storing the string values once. Additionally, when we perform operations (e.g. sorting, counting) we can work directly on the physical representation which is much faster than the working with string data. -### `Enum` vs `Categorical` +## `Enum` vs `Categorical` Polars supports two different DataTypes for working with categorical data: `Enum` and `Categorical`. When the categories are known up front use `Enum`. When you don't know the categories or they are not fixed then you use `Categorical`. In case your requirements change along the way you can always cast from one to the other. @@ -114,7 +114,7 @@ Polars supports two different DataTypes for working with categorical data: `Enum From the code block above you can see that the `Enum` data type requires the upfront while the categorical data type infers the categories. -#### `Categorical` data type +### `Categorical` data type The `Categorical` data type is a flexible one. Polars will add categories on the fly if it sees them. This sounds like a strictly better version compared to the `Enum` data type as we can simply infer the categories, however inferring comes at a cost. The main cost here is we have no control over our encodings. @@ -240,7 +240,7 @@ Polars encodes the string values in order as they appear. So the series would lo Combining the `Series` becomes a non-trivial task which is expensive as the physical value of `0` represents something different in both `Series`. Polars does support these types of operations for convenience, however in general these should be avoided due to its slower performance as it requires making both encodings compatible first before doing any merge operations. -##### Using the global string cache +#### Using the global string cache One way to handle this problem is to enable a `StringCache`. When you enable the `StringCache` strings are no longer encoded in the order they appear on a per-column basis. Instead, the string cache ensures a single encoding for each string. The string `Polar` will always map the same physical for all categorical columns made under the string cache. Merge operations (e.g. appends, joins) are cheap as there is no need to make the encodings compatible first, solving the problem we had above. @@ -249,12 +249,14 @@ Merge operations (e.g. appends, joins) are cheap as there is no need to make the However, the string cache does come at a small performance hit during construction of the `Series` as we need to look up / insert the string value in the cache. Therefore, it is preferred to use the `Enum` Data Type if you know your categories in advance. -#### `Enum data type` +### `Enum data type` In the `Enum` data type we specify the categories in advance. This way we ensure categoricals from different columns or different datasets have the same encoding and there is no need for expensive re-encoding or cache lookups. {{code_block('user-guide/concepts/data-types/categoricals','enum_append',[])}} + + Polars will raise an `OutOfBounds` error when a value is encountered which is not specified in the `Enum`. {{code_block('user-guide/concepts/data-types/categoricals','enum_error',[])}} @@ -264,14 +266,16 @@ Polars will raise an `OutOfBounds` error when a value is encountered which is no --8<-- "python/user-guide/concepts/data-types/categoricals.py:enum_error" ``` -### Comparisons +## Comparisons + + The following types of comparisons operators are allowed for categorical data: - Categorical vs Categorical - Categorical vs String -#### `Categorical` Type +### `Categorical` Type For the `Categorical` type comparisons are valid if they have the same global cache set or if they have the same underlying categories in the same order. @@ -296,7 +300,7 @@ For `Categorical` vs `String` comparisons Polars uses lexical ordering to determ --8<-- "python/user-guide/concepts/data-types/categoricals.py:str_compare" ``` -#### `Enum` Type +### `Enum` Type For `Enum` type comparisons are valid if they have the same categories. diff --git a/docs/source/user-guide/expressions/plugins.md b/docs/source/user-guide/expressions/plugins.md index e679b09ee180..9ef5633cfcd0 100644 --- a/docs/source/user-guide/expressions/plugins.md +++ b/docs/source/user-guide/expressions/plugins.md @@ -37,7 +37,7 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "*" } -pyo3 = { version = "*", features = ["extension-module", "abi-py38"] } +pyo3 = { version = "*", features = ["extension-module", "abi3-py38"] } pyo3-polars = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] } ``` diff --git a/docs/source/user-guide/getting-started.md b/docs/source/user-guide/getting-started.md index 69aa8a02d8da..e571ea71cca1 100644 --- a/docs/source/user-guide/getting-started.md +++ b/docs/source/user-guide/getting-started.md @@ -22,7 +22,7 @@ This chapter is here to help you get started with Polars. It covers all the fund ## Reading & writing -Polars supports reading and writing for common file formats (e.g. csv, json, parquet), cloud storage (S3, Azure Blob, BigQuery) and databases (e.g. postgres, mysql). Below, we create a small dataframe and show how to write it to disk and read it back. +Polars supports reading and writing for common file formats (e.g., csv, json, parquet), cloud storage (S3, Azure Blob, BigQuery) and databases (e.g., postgres, mysql). Below, we create a small dataframe and show how to write it to disk and read it back. {{code_block('user-guide/getting-started','df',['DataFrame'])}} @@ -60,7 +60,7 @@ Below, we will show examples of Polars expressions inside different contexts: - `filter` - `group_by` -For a more detailed exploration of contexts and expressions see the respective user guide sections: [Contexts](concepts/contexts.md) and [Expressions](concepts/expressions.md). +For a more [detailed exploration of expressions and contexts see the respective user guide section](concepts/expressions-and-contexts.md). ### `select` @@ -155,7 +155,7 @@ In the example below we combine some of the contexts we have seen so far to crea Polars provides a number of tools to combine two dataframes. In this section, we show an example of a join and an example of a concatenation. -### Joinining dataframes +### Joining dataframes Polars provides many different join algorithms. The example below shows how to use a left outer join to combine two dataframes when a column can be used as a unique identifier to establish a correspondence between rows across the dataframes: diff --git a/docs/source/user-guide/installation.md b/docs/source/user-guide/installation.md index 791362403456..fdfe83d49dee 100644 --- a/docs/source/user-guide/installation.md +++ b/docs/source/user-guide/installation.md @@ -21,9 +21,10 @@ Polars is a library and installation is as simple as invoking the package manage polars = { version = "x", features = ["lazy", ...]} ``` -### Big Index +## Big Index -By default, polars is limited to 2^32 (~4.2 billion rows). To increase this limit 2^64 (~18 quintillion) by enabling big index: +By default, Polars dataframes are limited to 232 rows (~4.3 billion). +Increase this limit to 264 (~18 quintillion) by enabling the big index extension: === ":fontawesome-brands-python: Python" @@ -41,9 +42,9 @@ By default, polars is limited to 2^32 (~4.2 billion rows). To increase this limi polars = { version = "x", features = ["bigidx", ...] } ``` -### Legacy CPU +## Legacy CPU -To install polars on an old CPU without [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support: +To install Polars for Python on an old CPU without [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, run: === ":fontawesome-brands-python: Python" @@ -53,7 +54,7 @@ To install polars on an old CPU without [AVX](https://en.wikipedia.org/wiki/Adva ## Importing -To use the library import it into your project +To use the library, simply import it into your project: === ":fontawesome-brands-python: Python" @@ -73,7 +74,7 @@ By using the above command you install the core of Polars onto your system. However, depending on your use case, you might want to install the optional dependencies as well. These are made optional to minimize the footprint. The flags are different depending on the programming language. -Throughout the user guide we will mention when a functionality is used that requires an additional dependency. +Throughout the user guide we will mention when a functionality used requires an additional dependency. ### Python @@ -84,15 +85,15 @@ pip install 'polars[numpy,fsspec]' #### All -| Tag | Description | -| --- | --------------------------------- | -| all | Install all optional dependencies | +| Tag | Description | +| --- | ---------------------------------- | +| all | Install all optional dependencies. | #### GPU -| Tag | Description | -| --- | -------------------------- | -| gpu | Run queries on NVIDIA GPUs | +| Tag | Description | +| --- | --------------------------- | +| gpu | Run queries on NVIDIA GPUs. | !!! note @@ -101,59 +102,59 @@ pip install 'polars[numpy,fsspec]' support](gpu-support.md) for more detailed instructions and prerequisites. -#### Interop +#### Interoperability -| Tag | Description | -| -------- | ------------------------------------------------- | -| pandas | Convert data to and from pandas DataFrames/Series | -| numpy | Convert data to and from NumPy arrays | -| pyarrow | Convert data to and from PyArrow tables/arrays | -| pydantic | Convert data from Pydantic models to Polars | +| Tag | Description | +| -------- | -------------------------------------------------- | +| pandas | Convert data to and from pandas dataframes/series. | +| numpy | Convert data to and from NumPy arrays. | +| pyarrow | Convert data to and from PyArrow tables/arrays. | +| pydantic | Convert data from Pydantic models to Polars. | #### Excel -| Tag | Description | -| ---------- | ----------------------------------------------- | -| calamine | Read from Excel files with the calamine engine | -| openpyxl | Read from Excel files with the openpyxl engine | -| xlsx2csv | Read from Excel files with the xlsx2csv engine | -| xlsxwriter | Write to Excel files with the XlsxWriter engine | -| excel | Install all supported Excel engines | +| Tag | Description | +| ---------- | ------------------------------------------------ | +| calamine | Read from Excel files with the calamine engine. | +| openpyxl | Read from Excel files with the openpyxl engine. | +| xlsx2csv | Read from Excel files with the xlsx2csv engine. | +| xlsxwriter | Write to Excel files with the XlsxWriter engine. | +| excel | Install all supported Excel engines. | #### Database -| Tag | Description | -| ---------- | ----------------------------------------------------------------------------------- | -| adbc | Read from and write to databases with the Arrow Database Connectivity (ADBC) engine | -| connectorx | Read from databases with the ConnectorX engine | -| sqlalchemy | Write to databases with the SQLAlchemy engine | -| database | Install all supported database engines | +| Tag | Description | +| ---------- | ------------------------------------------------------------------------------------ | +| adbc | Read from and write to databases with the Arrow Database Connectivity (ADBC) engine. | +| connectorx | Read from databases with the ConnectorX engine. | +| sqlalchemy | Write to databases with the SQLAlchemy engine. | +| database | Install all supported database engines. | #### Cloud -| Tag | Description | -| ------ | ------------------------------------------ | -| fsspec | Read from and write to remote file systems | +| Tag | Description | +| ------ | ------------------------------------------- | +| fsspec | Read from and write to remote file systems. | #### Other I/O -| Tag | Description | -| --------- | ----------------------------------- | -| deltalake | Read from and write to Delta tables | -| iceberg | Read from Apache Iceberg tables | +| Tag | Description | +| --------- | ------------------------------------ | +| deltalake | Read from and write to Delta tables. | +| iceberg | Read from Apache Iceberg tables. | #### Other -| Tag | Description | -| ----------- | ---------------------------------------------- | -| async | Collect LazyFrames asynchronously | -| cloudpickle | Serialize user-defined functions | -| graph | Visualize LazyFrames as a graph | -| plot | Plot DataFrames through the `plot` namespace | -| style | Style DataFrames through the `style` namespace | -| timezone | Timezone support* | +| Tag | Description | +| ----------- | ----------------------------------------------- | +| async | Collect LazyFrames asynchronously. | +| cloudpickle | Serialize user-defined functions. | +| graph | Visualize LazyFrames as a graph. | +| plot | Plot dataframes through the `plot` namespace. | +| style | Style dataframes through the `style` namespace. | +| timezone | Timezone support[^note]. | -_* Only needed if 1. you are on Python < 3.9 and/or 2. you are on Windows_ +[^note]: Only needed if you are on Python < 3.9 or you are on Windows. ### Rust @@ -178,94 +179,93 @@ The opt-in features are: - `dtype-u16` - `dtype-categorical` - `dtype-struct` -- `lazy` - Lazy API - - `regex` - Use regexes in [column selection](crate::lazy::dsl::col) +- `lazy` - Lazy API: + - `regex` - Use regexes in [column selection](crate::lazy::dsl::col). - `dot_diagram` - Create dot diagrams from lazy logical plans. -- `sql` - Pass SQL queries to polars. +- `sql` - Pass SQL queries to Polars. - `streaming` - Be able to process datasets that are larger than RAM. - `random` - Generate arrays with randomly sampled values - `ndarray`- Convert from `DataFrame` to `ndarray` - `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types - `timezones` - Activate timezone support. -- `strings` - Extra string utilities for `StringChunked` - - `string_pad` - `pad_start`, `pad_end`, `zfill` - - `string_to_integer` - `parse_int` +- `strings` - Extra string utilities for `StringChunked`: + - `string_pad` - for `pad_start`, `pad_end`, `zfill`. + - `string_to_integer` - for `parse_int`. - `object` - Support for generic ChunkedArrays called `ObjectChunked` (generic over `T`). These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. - Performance related: - `nightly` - Several nightly only features such as SIMD and specialization. - `performant` - more fast paths, slower compile times. - - `bigidx` - Activate this feature if you expect >> 2^32 rows. This has not been needed by anyone. + - `bigidx` - Activate this feature if you expect >> 232 rows. This allows polars to scale up way beyond that by using `u64` as an index. Polars will be a bit slower with this feature activated as many data structures are less cache efficient. - - `cse` - Activate common subplan elimination optimization + - `cse` - Activate common subplan elimination optimization. - IO related: - `serde` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. Can be used for JSON and more serde supported serialization formats. - `serde-lazy` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. Can be used for JSON and more serde supported serialization formats. - - `parquet` - Read Apache Parquet format - - `json` - JSON serialization - - `ipc` - Arrow's IPC format serialization + - `parquet` - Read Apache Parquet format. + - `json` - JSON serialization. + - `ipc` - Arrow's IPC format serialization. - `decompress` - Automatically infer compression of csvs and decompress them. Supported compressions: - zip - gzip - -- `DataFrame` operations: +- Dataframe operations: - `dynamic_group_by` - Group by based on a time window instead of predefined keys. Also activates rolling window group by operations. - - `sort_multiple` - Allow sorting a `DataFrame` on multiple columns - - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. - And activates `pivot` and `transpose` operations + - `sort_multiple` - Allow sorting a dataframe on multiple columns. + - `rows` - Create dataframe from rows and extract rows from `dataframes`. + Also activates `pivot` and `transpose` operations. - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. - - `cross_join` - Create the Cartesian product of two DataFrames. + - `cross_join` - Create the Cartesian product of two dataframes. - `semi_anti_join` - SEMI and ANTI joins. - - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked - - `diagonal_concat` - Concat diagonally thereby combining different schemas. - - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series) - - `partition_by` - Split into multiple DataFrames partitioned by groups. -- `Series`/`Expression` operations: - - `is_in` - [Check for membership in `Series`](crate::chunked_array::ops::IsIn) - - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip) - - `round_series` - round underlying float types of `Series`. - - `repeat_by` - [Repeat element in an Array N times, where N is given by another array. + - `row_hash` - Utility to hash dataframe rows to `UInt64Chunked`. + - `diagonal_concat` - Diagonal concatenation thereby combining different schemas. + - `dataframe_arithmetic` - Arithmetic between dataframes and other dataframes or series. + - `partition_by` - Split into multiple dataframes partitioned by groups. +- Series/expression operations: + - `is_in` - [Check for membership in series](crate::chunked_array::ops::IsIn) + - `zip_with` - [Zip two `Series` / `ChunkedArray`s](crate::chunked_array::ops::ChunkZip) + - `round_series` - round underlying float types of series. + - `repeat_by` - Repeat element in an array a number of times specified by another array. - `is_first_distinct` - Check if element is first unique value. - `is_last_distinct` - Check if element is last unique value. - - `checked_arithmetic` - checked arithmetic/ returning `None` on invalid operations. - - `dot_product` - Dot/inner product on Series and Expressions. - - `concat_str` - Concat string data in linear time. - - `reinterpret` - Utility to reinterpret bits to signed/unsigned - - `take_opt_iter` - Take from a Series with `Iterator>` - - `mode` - [Return the most occurring value(s)](crate::chunked_array::ops::ChunkUnique::mode) - - `cum_agg` - cum_sum, cum_min, cum_max aggregation. - - `rolling_window` - rolling window functions, like rolling_mean - - `interpolate` [interpolate None values](crate::chunked_array::ops::Interpolate) - - `extract_jsonpath` - [Run jsonpath queries on StringChunked](https://goessner.net/articles/JsonPath/) - - `list` - List utils. - - `list_gather` take sublist by multiple indices + - `checked_arithmetic` - checked arithmetic returning `None` on invalid operations. + - `dot_product` - Dot/inner product on series and expressions. + - `concat_str` - Concatenate string data in linear time. + - `reinterpret` - Utility to reinterpret bits to signed/unsigned. + - `take_opt_iter` - Take from a series with `Iterator>`. + - `mode` - [Return the most frequently occurring value(s)](crate::chunked_array::ops::ChunkUnique::mode). + - `cum_agg` - `cum_sum`, `cum_min`, and `cum_max`, aggregations. + - `rolling_window` - rolling window functions, like `rolling_mean`. + - `interpolate` - [interpolate `None` values](crate::chunked_array::ops::Interpolate). + - `extract_jsonpath` - [Run `jsonpath` queries on `StringChunked`](https://goessner.net/articles/JsonPath/). + - `list` - List utils: + - `list_gather` - take sublist by multiple indices. - `rank` - Ranking algorithms. - - `moment` - kurtosis and skew statistics - - `ewma` - Exponential moving average windows - - `abs` - Get absolute values of Series - - `arange` - Range operation on Series - - `product` - Compute the product of a Series. + - `moment` - Kurtosis and skew statistics. + - `ewma` - Exponential moving average windows. + - `abs` - Get absolute values of series. + - `arange` - Range operation on series. + - `product` - Compute the product of a series. - `diff` - `diff` operation. - `pct_change` - Compute change percentages. - `unique_counts` - Count unique values in expressions. - - `log` - Logarithms for `Series`. - - `list_to_struct` - Convert `List` to `Struct` dtypes. + - `log` - Logarithms for series. + - `list_to_struct` - Convert `List` to `Struct` data types. - `list_count` - Count elements in lists. - `list_eval` - Apply expressions over list elements. - `cumulative_eval` - Apply expressions over cumulatively increasing windows. - `arg_where` - Get indices where condition holds. - `search_sorted` - Find indices where elements should be inserted to maintain order. - - `offset_by` Add an offset to dates that take months and leap years into account. - - `trigonometry` Trigonometric functions. - - `sign` Compute the element-wise sign of a Series. - - `propagate_nans` NaN propagating min/max aggregations. -- `DataFrame` pretty printing - - `fmt` - Activate DataFrame formatting + - `offset_by` - Add an offset to dates that take months and leap years into account. + - `trigonometry` - Trigonometric functions. + - `sign` - Compute the element-wise sign of a series. + - `propagate_nans` - `NaN`-propagating min/max aggregations. +- Dataframe pretty printing: + - `fmt` - Activate dataframe formatting. diff --git a/docs/source/user-guide/io/cloud-storage.md b/docs/source/user-guide/io/cloud-storage.md index ba686a5a0f11..f3b5d7a8fb09 100644 --- a/docs/source/user-guide/io/cloud-storage.md +++ b/docs/source/user-guide/io/cloud-storage.md @@ -18,23 +18,39 @@ To read from cloud storage, additional dependencies may be needed depending on t ## Reading from cloud storage -Polars can read a CSV, IPC or Parquet file in eager mode from cloud storage. +Polars supports reading Parquet, CSV, IPC and NDJSON files from cloud storage: {{code_block('user-guide/io/cloud-storage','read_parquet',['read_parquet','read_csv','read_ipc'])}} -This eager query downloads the file to a buffer in memory and creates a `DataFrame` from there. Polars uses `fsspec` to manage this download internally for all cloud storage providers. - ## Scanning from cloud storage with query optimisation -Polars can scan a Parquet file in lazy mode from cloud storage. We may need to provide further details beyond the source url such as authentication details or storage region. Polars looks for these as environment variables but we can also do this manually by passing a `dict` as the `storage_options` argument. +Using `pl.scan_*` functions to read from cloud storage can benefit from [predicate and projection pushdowns](../lazy/optimizations.md), where the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. -{{code_block('user-guide/io/cloud-storage','scan_parquet',['scan_parquet'])}} +{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} -This query creates a `LazyFrame` without downloading the file. In the `LazyFrame` we have access to file metadata such as the schema. Polars uses the `object_store.rs` library internally to manage the interface with the cloud storage providers and so no extra dependencies are required in Python to scan a cloud Parquet file. +## Cloud authentication -If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. +Polars is able to automatically load default credential configurations for some cloud providers. For +cases when this does not happen, it is possible to manually configure the credentials for Polars to +use for authentication. This can be done in a few ways: -{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} +### Using `storage_options`: + +- Credentials can be passed as configuration keys in a dict with the `storage_options` parameter: + +{{code_block('user-guide/io/cloud-storage','scan_parquet_storage_options_aws',['scan_parquet'])}} + +### Using one of the available `CredentialProvider*` utility classes + +- There may be a utility class `pl.CredentialProvider*` that provides the required authentication functionality. For example, `pl.CredentialProviderAWS` supports selecting AWS profiles, as well as assuming an IAM role: + +{{code_block('user-guide/io/cloud-storage','credential_provider_class',['scan_parquet'])}} + +### Using a custom `credential_provider` function + +- Some environments may require custom authentication logic (e.g. AWS IAM role-chaining). For these cases a Python function can be provided for Polars to use to retrieve credentials: + +{{code_block('user-guide/io/cloud-storage','credential_provider_custom_func',['scan_parquet'])}} ## Scanning with PyArrow diff --git a/docs/source/user-guide/io/csv.md b/docs/source/user-guide/io/csv.md index 2d7772b45f1f..f654d970ac81 100644 --- a/docs/source/user-guide/io/csv.md +++ b/docs/source/user-guide/io/csv.md @@ -18,4 +18,4 @@ file and instead returns a lazy computation holder called a `LazyFrame`. {{code_block('user-guide/io/csv','scan',['scan_csv'])}} If you want to know why this is desirable, you can read more about these Polars -optimizations [here](../concepts/lazy-vs-eager.md). +optimizations [here](../concepts/lazy-api.md). diff --git a/docs/source/user-guide/io/parquet.md b/docs/source/user-guide/io/parquet.md index da35ee96f476..e04c2bdde2e7 100644 --- a/docs/source/user-guide/io/parquet.md +++ b/docs/source/user-guide/io/parquet.md @@ -20,6 +20,6 @@ Polars allows you to _scan_ a `Parquet` input. Scanning delays the actual parsin {{code_block('user-guide/io/parquet','scan',['scan_parquet'])}} -If you want to know why this is desirable, you can read more about those Polars optimizations [here](../concepts/lazy-vs-eager.md). +If you want to know why this is desirable, you can read more about those Polars optimizations [here](../concepts/lazy-api.md). When we scan a `Parquet` file stored in the cloud, we can also apply predicate and projection pushdowns. This can significantly reduce the amount of data that needs to be downloaded. For scanning a Parquet file in the cloud, see [Cloud storage](cloud-storage.md/#scanning-from-cloud-storage-with-query-optimisation). diff --git a/docs/source/user-guide/migration/pandas.md b/docs/source/user-guide/migration/pandas.md index 5fa435278949..3d1f0996bdad 100644 --- a/docs/source/user-guide/migration/pandas.md +++ b/docs/source/user-guide/migration/pandas.md @@ -50,8 +50,7 @@ eager evaluation. The lazy evaluation mode is powerful because Polars carries ou automatic query optimization when it examines the query plan and looks for ways to accelerate the query or reduce memory usage. -`Dask` also supports lazy evaluation when it generates a query plan. However, `Dask` -does not carry out query optimization on the query plan. +`Dask` also supports lazy evaluation when it generates a query plan. ## Key syntax differences diff --git a/docs/source/user-guide/misc/visualization.md b/docs/source/user-guide/misc/visualization.md index 3f7574c07a2e..5832fa0c9e5f 100644 --- a/docs/source/user-guide/misc/visualization.md +++ b/docs/source/user-guide/misc/visualization.md @@ -27,7 +27,7 @@ This is shorthand for: import altair as alt ( - alt.Chart(df).mark_point().encode( + alt.Chart(df).mark_point(tooltip=True).encode( x="sepal_length", y="sepal_width", color="species", diff --git a/docs/source/user-guide/sql/intro.md b/docs/source/user-guide/sql/intro.md index 08918e4e6404..0b762f16cdd9 100644 --- a/docs/source/user-guide/sql/intro.md +++ b/docs/source/user-guide/sql/intro.md @@ -1,13 +1,13 @@ # Introduction While Polars supports interaction with SQL, it's recommended that users familiarize themselves with -the [expression syntax](../concepts/expressions.md) to produce more readable and expressive code. As the DataFrame +the [expression syntax](../concepts/expressions-and-contexts.md#expressions) to produce more readable and expressive code. As the DataFrame interface is primary, new features are typically added to the expression API first. However, if you already have an existing SQL codebase or prefer the use of SQL, Polars does offers support for this. !!! note Execution - There is no separate SQL engine because Polars translates SQL queries into [expressions](../concepts/expressions.md), which are then executed using its own engine. This approach ensures that Polars maintains its performance and scalability advantages as a native DataFrame library, while still providing users with the ability to work with SQL. + There is no separate SQL engine because Polars translates SQL queries into [expressions](../concepts/expressions-and-contexts.md#expressions), which are then executed using its own engine. This approach ensures that Polars maintains its performance and scalability advantages as a native DataFrame library, while still providing users with the ability to work with SQL. ## Context diff --git a/docs/source/user-guide/transformations/joins.md b/docs/source/user-guide/transformations/joins.md index 7cf07e680503..b135a45f53d3 100644 --- a/docs/source/user-guide/transformations/joins.md +++ b/docs/source/user-guide/transformations/joins.md @@ -1,229 +1,273 @@ # Joins -## Join strategies +A join operation combines columns from one or more dataframes into a new dataframe. +The different “joining strategies” and matching criteria used by the different types of joins influence how columns are combined and also what rows are included in the result of the join operation. -Polars supports the following join strategies by specifying the `how` argument: +The most common type of join is an “equi join”, in which rows are matched by a key expression. +Polars supports several joining strategies for equi joins, which determine exactly how we handle the matching of rows. +Polars also supports “non-equi joins”, a type of join where the matching criterion is not an equality, and a type of join where rows are matched by key proximity, called “asof join”. -| Strategy | Description | -| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `inner` | Returns row with matching keys in _both_ frames. Non-matching rows in either the left or right frame are discarded. | -| `left` | Returns all rows in the left dataframe, whether or not a match in the right-frame is found. Non-matching rows have their right columns null-filled. | -| `right` | Returns all rows in the right dataframe, whether or not a match in the left-frame is found. Non-matching rows have their left columns null-filled. | -| `full` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. | -| `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicates rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. | -| `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. | -| `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. | +## Quick reference table -A separate `coalesce` parameter determines whether to merge key columns with the same name from the left and right -frames. +The table below acts as a quick reference for people who know what they are looking for. +If you want to learn about joins in general and how to work with them in Polars, feel free to skip the table and keep reading below. -### Inner join +=== ":fontawesome-brands-python: Python" + + [:material-api: `join`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html) + [:material-api: `join_where`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_asof.html) + [:material-api: `join_asof`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_where.html) + +=== ":fontawesome-brands-rust: Rust" -An `inner` join produces a `DataFrame` that contains only the rows where the join key exists in both `DataFrames`. Let's -take for example the following two `DataFrames`: + [:material-api: `join`](https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join) + ([:material-flag-plus: semi_anti_join](/user-guide/installation/#feature-flags "Enable the feature flag semi_anti_join for semi and for anti joins"){.feature-flag} needed for some options.) + [:material-api: `join_asof_by`](https://docs.pola.rs/api/rust/dev/polars/prelude/trait.AsofJoin.html#method.join_asof) + [:material-flag-plus: Available on feature asof_join](/user-guide/installation/#feature-flags "To use this functionality enable the feature flag asof_join"){.feature-flag} + [:material-api: `join_where`](https://docs.rs/polars/latest/polars/prelude/struct.JoinBuilder.html#method.join_where) + [:material-flag-plus: Available on feature iejoin](/user-guide/installation/#feature-flags "To use this functionality enable the feature flag iejoin"){.feature-flag} -{{code_block('user-guide/transformations/joins','innerdf',['DataFrame'])}} +| Type | Function | Brief description | +| --------------------- | -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Equi inner join | `join(..., how="inner")` | Keeps rows that matched both on the left and right. | +| Equi left outer join | `join(..., how="left")` | Keeps all rows from the left plus matching rows from the right. Non-matching rows from the left have their right columns filled with `null`. | +| Equi right outer join | `join(..., how="right")` | Keeps all rows from the right plus matching rows from the left. Non-matching rows from the right have their left columns filled with `null`. | +| Equi full join | `join(..., how="full")` | Keeps all rows from either dataframe, regardless of whether they match or not. Non-matching rows from one side have the columns from the other side filled with `null`. | +| Equi semi join | `join(..., how="semi")` | Keeps rows from the left that have a match on the right. | +| Equi anti join | `join(..., how="anti")` | Keeps rows from the left that do not have a match on the right. | +| Non-equi inner join | `join_where` | Finds all possible pairings of rows from the left and right that satisfy the given predicate(s). | +| Asof join | `join_asof`/`join_asof_by` | Like a left outer join, but matches on the nearest key instead of on exact key matches. | +| Cartesian product | `join(..., how="cross")` | Computes the [Cartesian product](https://en.wikipedia.org/wiki/Cartesian_product) of the two dataframes. | -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:setup" ---8<-- "python/user-guide/transformations/joins.py:innerdf" +## Equi joins + +In an equi join, rows are matched by checking equality of a key expression. +You can do an equi join with the function `join` by specifying the name of the column to be used as key. +For the examples, we will be loading some (modified) Monopoly property data. + +First, we load a dataframe that contains property names and their colour group in the game: + +{{code_block('user-guide/transformations/joins','props_groups',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:prep-data" +--8<-- "python/user-guide/transformations/joins.py:props_groups" ``` -

+Next, we load a dataframe that contains property names and their price in the game: -{{code_block('user-guide/transformations/joins','innerdf2',['DataFrame'])}} +{{code_block('user-guide/transformations/joins','props_prices',[])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:innerdf2" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_prices" ``` -To get a `DataFrame` with the orders and their associated customer we can do an `inner` join on the `customer_id` -column: +Now, we join both dataframes to create a dataframe that contains property names, colour groups, and prices: -{{code_block('user-guide/transformations/joins','inner',['join'])}} +{{code_block('user-guide/transformations/joins','equi-join',['join'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:inner" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:equi-join" ``` -### Left join +The result has four rows but both dataframes used in the operation had five rows. +Polars uses a joining strategy to determine what happens with rows that have multiple matches or with rows that have no match at all. +By default, Polars computes an “inner join” but there are [other join strategies that we show next](#join-strategies). -The `left` outer join produces a `DataFrame` that contains all the rows from the left `DataFrame` and only the rows from -the right `DataFrame` where the join key exists in the left `DataFrame`. If we now take the example from above and want -to have a `DataFrame` with all the customers and their associated orders (regardless of whether they have placed an -order or not) we can do a `left` join: +In the example above, the two dataframes conveniently had the column we wish to use as key with the same name and with the values in the exact same format. +Suppose, for the sake of argument, that one of the dataframes had a differently named column and the other had the property names in lower case: -{{code_block('user-guide/transformations/joins','left',['join'])}} +{{code_block('user-guide/transformations/joins','props_groups2',['Expr.str'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:left" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_groups2" ``` -Notice, that the fields for the customer with the `customer_id` of `3` are null, as there are no orders for this -customer. +{{code_block('user-guide/transformations/joins','props_prices2',[])}} -### Right join +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_prices2" +``` -The `right` outer join produces a `DataFrame` that contains all the rows from the right `DataFrame` and only the rows from -the left `DataFrame` where the join key exists in the right `DataFrame`. If we now take the example from above and want -to have a `DataFrame` with all the customers and their associated orders (regardless of whether they have placed an -order or not) we can do a `right` join: +In a situation like this, where we may want to perform the same join as before, we can leverage `join`'s flexibility and specify arbitrary expressions to compute the joining key on the left and on the right, allowing one to compute row keys dynamically: -{{code_block('user-guide/transformations/joins','right',['join'])}} +{{code_block('user-guide/transformations/joins', 'join-key-expression', ['join', 'Expr.str'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:right" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:join-key-expression" ``` -Notice, that the fields for the customer with the `customer_id` of `3` are null, as there are no orders for this -customer. +Because we are joining on the right with an expression, Polars preserves the column “property_name” from the left and the column “name” from the right so we can have access to the original values that the key expressions were applied to. -### Outer join +## Join strategies -The `full` outer join produces a `DataFrame` that contains all the rows from both `DataFrames`. Columns are null, if the -join key does not exist in the source `DataFrame`. Doing a `full` outer join on the two `DataFrames` from above produces -a similar `DataFrame` to the `left` join: +When computing a join with `df1.join(df2, ...)`, we can specify one of many different join strategies. +A join strategy specifies what rows to keep from each dataframe based on whether they match rows from the other dataframe. + +### Inner join -{{code_block('user-guide/transformations/joins','full',['join'])}} +In an inner join the resulting dataframe only contains the rows from the left and right dataframes that matched. +That is the default strategy used by `join` and above we can see an example of that. +We repeat the example here and explicitly specify the join strategy: -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:full" +{{code_block('user-guide/transformations/joins','inner-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner-join" ``` -{{code_block('user-guide/transformations/joins','full_coalesce',['join'])}} +The result does not include the row from `props_groups` that contains “The Shire” and the result also does not include the row from `props_prices` that contains “Sesame Street”. + +### Left join + +A left outer join is a join where the result contains all the rows from the left dataframe and the rows of the right dataframe that matched any rows from the left dataframe. -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:full_coalesce" +{{code_block('user-guide/transformations/joins','left-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left-join" ``` -### Cross join +If there are any rows from the left dataframe that have no matching rows on the right dataframe, they get the value `null` on the new columns. + +### Right join -A `cross` join is a Cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is -joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible -combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. +Computationally speaking, a right outer join is exactly the same as a left outer join, but with the arguments swapped. +Here is an example: -{{code_block('user-guide/transformations/joins','df3',['DataFrame'])}} +{{code_block('user-guide/transformations/joins','right-join',['join'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df3" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:right-join" ``` -

+We show that `df1.join(df2, how="right", ...)` is the same as `df2.join(df1, how="left", ...)`, up to the order of the columns of the result, with the computation below: -{{code_block('user-guide/transformations/joins','df4',['DataFrame'])}} +{{code_block('user-guide/transformations/joins','left-right-join-equals',['join'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df4" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left-right-join-equals" ``` -We can now create a `DataFrame` containing all possible combinations of the colors and sizes with a `cross` join: +### Full join -{{code_block('user-guide/transformations/joins','cross',['join'])}} +A full outer join will keep all of the rows from the left and right dataframes, even if they don't have matching rows in the other dataframe: -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:cross" +{{code_block('user-guide/transformations/joins','full-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:full-join" ``` -
+In this case, we see that we get two columns `property_name` and `property_name_right` to make up for the fact that we are matching on the column `property_name` of both dataframes and there are some names for which there are no matches. +The two columns help differentiate the source of each row data. +If we wanted to force `join` to coalesce the two columns `property_name` into a single column, we could set `coalesce=True` explicitly: -The `inner`, `left`, `right`, `full` and `cross` join strategies are standard amongst dataframe libraries. We provide more -details on the less familiar `semi`, `anti` and `asof` join strategies below. +{{code_block('user-guide/transformations/joins','full-join-coalesce',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:full-join-coalesce" +``` + +When not set, the parameter `coalesce` is determined automatically from the join strategy and the key(s) specified, which is why the inner, left, and right, joins acted as if `coalesce=True`, even though we didn't set it. ### Semi join -The `semi` join returns all rows from the left frame in which the join key is also present in the right frame. Consider -the following scenario: a car rental company has a `DataFrame` showing the cars that it owns with each car having a -unique `id`. +A semi join will return the rows of the left dataframe that have a match in the right dataframe, but we do not actually join the matching rows: -{{code_block('user-guide/transformations/joins','df5',['DataFrame'])}} +{{code_block('user-guide/transformations/joins', 'semi-join', [], ['join'], ['join-semi_anti_join_flag'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df5" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:semi-join" ``` -The company has another `DataFrame` showing each repair job carried out on a vehicle. +A semi join acts as a sort of row filter based on a second dataframe. + +### Anti join + +Conversely, an anti join will return the rows of the left dataframe that do not have a match in the right dataframe: -{{code_block('user-guide/transformations/joins','df6',['DataFrame'])}} +{{code_block('user-guide/transformations/joins', 'anti-join', [], ['join'], ['join-semi_anti_join_flag'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df6" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:anti-join" ``` -You want to answer this question: which of the cars have had repairs carried out? +## Non-equi joins + +In a non-equi join matches between the left and right dataframes are computed differently. +Instead of looking for matches on key expressions, we provide a single predicate that determines what rows of the left dataframe can be paired up with what rows of the right dataframe. -An inner join does not answer this question directly as it produces a `DataFrame` with multiple rows for each car that -has had multiple repair jobs: +For example, consider the following Monopoly players and their current cash: -{{code_block('user-guide/transformations/joins','inner2',['join'])}} +{{code_block('user-guide/transformations/joins','players',[])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:inner2" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:players" ``` -However, a semi join produces a single row for each car that has had a repair job carried out. +Using a non-equi join we can easily build a dataframe with all the possible properties that each player could be interested in buying. +We use the function `join_where` to compute a non-equi join: -{{code_block('user-guide/transformations/joins','semi',['join'])}} +{{code_block('user-guide/transformations/joins','non-equi',['join_where'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:semi" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:non-equi" ``` -### Anti join +You can provide multiple expressions as predicates but they all must use comparison operators that evaluate to a Boolean result and must refer to columns from both dataframes. -Continuing this example, an alternative question might be: which of the cars have **not** had a repair job carried out? -An anti join produces a `DataFrame` showing all the cars from `df_cars` where the `id` is not present in -the `df_repairs` `DataFrame`. +!!! note -{{code_block('user-guide/transformations/joins','anti',['join'])}} - -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:anti" -``` + `join_where` is still experimental and doesn't yet support arbitrary Boolean expressions as predicates. ## Asof join An `asof` join is like a left join except that we match on nearest key rather than equal keys. In Polars we can do an asof join with the `join_asof` method. -Consider the following scenario: a stock market broker has a `DataFrame` called `df_trades` showing transactions it has -made for different stocks. +For the asof join we will consider a scenario inspired by the stock market. +Suppose a stock market broker has a dataframe called `df_trades` showing transactions it has made for different stocks. -{{code_block('user-guide/transformations/joins','df7',['DataFrame'])}} +{{code_block('user-guide/transformations/joins','df_trades',[])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df7" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df_trades" ``` -The broker has another `DataFrame` called `df_quotes` showing prices it has quoted for these stocks. +The broker has another dataframe called `df_quotes` showing prices it has quoted for these stocks: -{{code_block('user-guide/transformations/joins','df8',['DataFrame'])}} +{{code_block('user-guide/transformations/joins','df_quotes',[])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:df8" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df_quotes" ``` -You want to produce a `DataFrame` showing for each trade the most recent quote provided _before_ the trade. You do this -with `join_asof` (using the default `strategy = "backward"`). -To avoid joining between trades on one stock with a quote on another you must specify an exact preliminary join on the -stock column with `by="stock"`. +You want to produce a dataframe showing for each trade the most recent quote provided _before_ the trade. You do this with `join_asof` (using the default `strategy = "backward"`). +To avoid joining between trades on one stock with a quote on another you must specify an exact preliminary join on the stock column with `by="stock"`. -{{code_block('user-guide/transformations/joins','asof',['join_asof'])}} +{{code_block('user-guide/transformations/joins','asof', [], ['join_asof'], ['join_asof_by'])}} -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:asofpre" +```python exec="on" result="text" session="transformations/joins" --8<-- "python/user-guide/transformations/joins.py:asof" ``` -If you want to make sure that only quotes within a certain time range are joined to the trades you can specify -the `tolerance` argument. In this case we want to make sure that the last preceding quote is within 1 minute of the -trade so we set `tolerance = "1m"`. +If you want to make sure that only quotes within a certain time range are joined to the trades you can specify the `tolerance` argument. +In this case we want to make sure that the last preceding quote is within 1 minute of the trade so we set `tolerance = "1m"`. -=== ":fontawesome-brands-python: Python" +{{code_block('user-guide/transformations/joins','asof-tolerance', [], ['join_asof'], ['join_asof_by'])}} -```python ---8<-- "python/user-guide/transformations/joins.py:asof2" +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:asof-tolerance" ``` -```python exec="on" result="text" session="user-guide/transformations/joins" ---8<-- "python/user-guide/transformations/joins.py:asof2" +## Cartesian product + +Polars allows you to compute the [Cartesian product](https://en.wikipedia.org/wiki/Cartesian_product) of two dataframes, producing a dataframe where all rows of the left dataframe are paired up with all the rows of the right dataframe. +To compute the Cartesian product of two dataframes, you can pass the strategy `how="cross"` to the function `join` without specifying any of `on`, `left_on`, and `right_on`: + +{{code_block('user-guide/transformations/joins','cartesian-product',[],['join'],['cross_join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:cartesian-product" ``` diff --git a/examples/python_rust_compiled_function/Cargo.toml b/examples/python_rust_compiled_function/Cargo.toml deleted file mode 100644 index 94982fe498ef..000000000000 --- a/examples/python_rust_compiled_function/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "python_rust_compiled_function" -version = "0.1.0" -edition = "2021" - -[lib] -name = "my_polars_functions" -crate-type = ["cdylib"] - -[dependencies] -arrow = { workspace = true } -polars = { path = "../../crates/polars" } - -pyo3 = { workspace = true, features = ["extension-module"] } - -[build-dependencies] -pyo3-build-config = "0.21" diff --git a/examples/python_rust_compiled_function/README.md b/examples/python_rust_compiled_function/README.md deleted file mode 100644 index e958fc2fc24f..000000000000 --- a/examples/python_rust_compiled_function/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Compile Custom Rust functions and use in Python Polars - -## Compile a development binary in your current environment - -```sh -pip install -U maturin && maturin develop -``` - -## Run - -```sh -python example.py -``` - -## Compile a **release** build - -```sh -maturin develop --release -``` diff --git a/examples/python_rust_compiled_function/build.rs b/examples/python_rust_compiled_function/build.rs deleted file mode 100644 index dace4a9ba9f8..000000000000 --- a/examples/python_rust_compiled_function/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - pyo3_build_config::add_extension_module_link_args(); -} diff --git a/examples/python_rust_compiled_function/example.py b/examples/python_rust_compiled_function/example.py deleted file mode 100644 index 3b26d63975b1..000000000000 --- a/examples/python_rust_compiled_function/example.py +++ /dev/null @@ -1,19 +0,0 @@ -import polars as pl -from my_polars_functions import hamming_distance - -a = pl.Series("a", ["foo", "bar"]) -b = pl.Series("b", ["fooy", "ham"]) - -dist = hamming_distance(a, b) -expected = pl.Series("", [None, 2], dtype=pl.UInt32) - -# run on 2 Series -print("hamming distance: ", hamming_distance(a, b)) -assert dist.series_equal(expected, null_equal=True) - -# or use in polars expressions -print( - pl.DataFrame([a, b]).select( - pl.map(["a", "b"], lambda series: hamming_distance(series[0], series[1])) - ) -) diff --git a/examples/python_rust_compiled_function/pyproject.toml b/examples/python_rust_compiled_function/pyproject.toml deleted file mode 100644 index c0e2db718795..000000000000 --- a/examples/python_rust_compiled_function/pyproject.toml +++ /dev/null @@ -1,7 +0,0 @@ -[build-system] -requires = ["maturin~=1.2.1"] -build-backend = "maturin" - -[project] -name = "my_polars_functions" -version = "0.1.0" diff --git a/examples/python_rust_compiled_function/src/ffi.rs b/examples/python_rust_compiled_function/src/ffi.rs deleted file mode 100644 index 3597d1f83a03..000000000000 --- a/examples/python_rust_compiled_function/src/ffi.rs +++ /dev/null @@ -1,84 +0,0 @@ -use arrow::ffi; -use polars::prelude::*; -use pyo3::exceptions::PyValueError; -use pyo3::ffi::Py_uintptr_t; -use pyo3::prelude::*; -use pyo3::{PyAny, PyObject, PyResult}; - -/// Take an arrow array from python and convert it to a rust arrow array. -/// This operation does not copy data. -fn array_to_rust(arrow_array: &Bound) -> PyResult { - // prepare a pointer to receive the Array struct - let array = Box::new(ffi::ArrowArray::empty()); - let schema = Box::new(ffi::ArrowSchema::empty()); - - let array_ptr = &*array as *const ffi::ArrowArray; - let schema_ptr = &*schema as *const ffi::ArrowSchema; - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - arrow_array.call_method1( - "_export_to_c", - (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), - )?; - - unsafe { - let field = ffi::import_field_from_c(schema.as_ref()).unwrap(); - let array = ffi::import_array_from_c(*array, field.dtype).unwrap(); - Ok(array) - } -} - -/// Arrow array to Python. -pub(crate) fn to_py_array(py: Python, pyarrow: &Bound, array: ArrayRef) -> PyResult { - let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( - "", - array.dtype().clone(), - true, - ))); - let array = Box::new(ffi::export_array_to_c(array)); - - let schema_ptr: *const ffi::ArrowSchema = &*schema; - let array_ptr: *const ffi::ArrowArray = &*array; - - let array = pyarrow.getattr("Array")?.call_method1( - "_import_from_c", - (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), - )?; - - Ok(array.to_object(py)) -} - -pub fn py_series_to_rust_series(series: &Bound) -> PyResult { - // rechunk series so that they have a single arrow array - let series = series.call_method0("rechunk")?; - - let name = series.getattr("name")?.extract::()?; - - // retrieve pyarrow array - let array = series.call_method0("to_arrow")?; - - // retrieve rust arrow array - let array = array_to_rust(&array)?; - - Series::try_from((name.as_str(), array)).map_err(|e| PyValueError::new_err(format!("{}", e))) -} - -pub fn rust_series_to_py_series(series: &Series) -> PyResult { - // ensure we have a single chunk - let series = series.rechunk(); - let array = series.to_arrow(0, false); - - Python::with_gil(|py| { - // import pyarrow - let pyarrow = py.import_bound("pyarrow")?; - - // pyarrow array - let pyarrow_array = to_py_array(py, &pyarrow, array)?; - - // import polars - let polars = py.import_bound("polars")?; - let out = polars.call_method1("from_arrow", (pyarrow_array,))?; - Ok(out.to_object(py)) - }) -} diff --git a/examples/python_rust_compiled_function/src/lib.rs b/examples/python_rust_compiled_function/src/lib.rs deleted file mode 100644 index f8c2caec2123..000000000000 --- a/examples/python_rust_compiled_function/src/lib.rs +++ /dev/null @@ -1,50 +0,0 @@ -mod ffi; - -use polars::prelude::*; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -#[pyfunction] -fn hamming_distance(series_a: &Bound, series_b: &Bound) -> PyResult { - let series_a = ffi::py_series_to_rust_series(series_a)?; - let series_b = ffi::py_series_to_rust_series(series_b)?; - - let out = hamming_distance_impl(&series_a, &series_b) - .map_err(|e| PyValueError::new_err(format!("Something went wrong: {:?}", e)))?; - ffi::rust_series_to_py_series(&out.into_series()) -} - -/// This function iterates over 2 `StringChunked` arrays and computes the hamming distance between the values . -fn hamming_distance_impl(a: &Series, b: &Series) -> PolarsResult { - Ok(a.str()? - .into_iter() - .zip(b.str()?) - .map(|(lhs, rhs)| hamming_distance_strs(lhs, rhs)) - .collect()) -} - -/// Compute the hamming distance between 2 string values. -fn hamming_distance_strs(a: Option<&str>, b: Option<&str>) -> Option { - match (a, b) { - (None, _) => None, - (_, None) => None, - (Some(a), Some(b)) => { - if a.len() != b.len() { - None - } else { - Some( - a.chars() - .zip(b.chars()) - .map(|(a_char, b_char)| (a_char != b_char) as u32) - .sum::(), - ) - } - }, - } -} - -#[pymodule] -fn my_polars_functions(_py: Python, m: &Bound) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(hamming_distance)).unwrap(); - Ok(()) -} diff --git a/examples/read_csv/Cargo.toml b/examples/read_csv/Cargo.toml deleted file mode 100644 index f30de21d1ae0..000000000000 --- a/examples/read_csv/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "read_csv" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "csv", "ipc"] } - -[features] -write_output = ["polars/ipc", "polars/parquet"] -default = ["write_output"] diff --git a/examples/read_csv/src/main.rs b/examples/read_csv/src/main.rs deleted file mode 100644 index 877fc6483635..000000000000 --- a/examples/read_csv/src/main.rs +++ /dev/null @@ -1,31 +0,0 @@ -use polars::io::mmap::MmapBytesReader; -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let file = std::fs::File::open("/home/ritchie46/Downloads/pdsh/tables_scale_100/lineitem.tbl") - .unwrap(); - let file = Box::new(file) as Box; - let _df = CsvReader::new(file) - .with_separator(b'|') - .has_header(false) - .with_chunk_size(10) - .batched(None) - .unwrap(); - - // write_other_formats(&mut df)?; - Ok(()) -} - -fn _write_other_formats(df: &mut DataFrame) -> PolarsResult<()> { - let parquet_out = "../datasets/foods1.parquet"; - if std::fs::metadata(parquet_out).is_err() { - let f = std::fs::File::create(parquet_out).unwrap(); - ParquetWriter::new(f).with_statistics(true).finish(df)?; - } - let ipc_out = "../datasets/foods1.ipc"; - if std::fs::metadata(ipc_out).is_err() { - let f = std::fs::File::create(ipc_out).unwrap(); - IpcWriter::new(f).finish(df)? - } - Ok(()) -} diff --git a/examples/read_json/Cargo.toml b/examples/read_json/Cargo.toml deleted file mode 100644 index d9e66858c2c0..000000000000 --- a/examples/read_json/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "read_json" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["json"] } diff --git a/examples/read_json/src/main.rs b/examples/read_json/src/main.rs deleted file mode 100644 index f16c2abfdc69..000000000000 --- a/examples/read_json/src/main.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::io::Cursor; - -use polars::prelude::*; - -fn main() { - let data = r#"[ - {"date": "1996-12-16T00:00:00.000", "open": 16.86, "close": 16.86, "high": 16.86, "low": 16.86, "volume": 62442.0, "turnover": 105277000.0}, - {"date": "1996-12-17T00:00:00.000", "open": 15.17, "close": 15.17, "high": 16.79, "low": 15.17, "volume": 463675.0, "turnover": 718902016.0}, - {"date": "1996-12-18T00:00:00.000", "open": 15.28, "close": 16.69, "high": 16.69, "low": 15.18, "volume": 445380.0, "turnover": 719400000.0}, - {"date": "1996-12-19T00:00:00.000", "open": 17.01, "close": 16.4, "high": 17.9, "low": 15.99, "volume": 572946.0, "turnover": 970124992.0} - ]"#; - - let res = JsonReader::new(Cursor::new(data)).finish(); - println!("{:?}", res); - assert!(res.is_ok()); - let df = res.unwrap(); - println!("{:?}", df); -} diff --git a/examples/read_parquet/Cargo.toml b/examples/read_parquet/Cargo.toml deleted file mode 100644 index 0e88fb7b62ca..000000000000 --- a/examples/read_parquet/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "read_parquet" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "parquet"] } diff --git a/examples/read_parquet/src/main.rs b/examples/read_parquet/src/main.rs deleted file mode 100644 index 5023a95c2368..000000000000 --- a/examples/read_parquet/src/main.rs +++ /dev/null @@ -1,15 +0,0 @@ -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let df = LazyFrame::scan_parquet("../datasets/foods1.parquet", ScanArgsParquet::default())? - .select([ - // select all columns - all(), - // and do some aggregations - cols(["fats_g", "sugars_g"]).sum().name().suffix("_summed"), - ]) - .collect()?; - - println!("{}", df); - Ok(()) -} diff --git a/examples/read_parquet_cloud/Cargo.toml b/examples/read_parquet_cloud/Cargo.toml deleted file mode 100644 index bbb43403bd95..000000000000 --- a/examples/read_parquet_cloud/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "read_parquet_cloud" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet"] } - -aws-creds = "0.36.0" diff --git a/examples/read_parquet_cloud/src/main.rs b/examples/read_parquet_cloud/src/main.rs deleted file mode 100644 index 367575bbdd30..000000000000 --- a/examples/read_parquet_cloud/src/main.rs +++ /dev/null @@ -1,30 +0,0 @@ -use awscreds::Credentials; -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -// Login to your aws account and then copy the ../datasets/foods1.parquet file to your own bucket. -// Adjust the link below. -const TEST_S3: &str = "s3://lov2test/polars/datasets/*.parquet"; - -fn main() -> PolarsResult<()> { - let cred = Credentials::default().unwrap(); - - // Propagate the credentials and other cloud options. - let mut args = ScanArgsParquet::default(); - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, &cred.access_key.unwrap()), - (Key::SecretAccessKey, &cred.secret_key.unwrap()), - (Key::Region, &"us-west-2".into()), - ]); - args.cloud_options = Some(cloud_options); - let df = LazyFrame::scan_parquet(TEST_S3, args)? - .with_streaming(true) - .select([ - // select all columns - all(), - ]) - .collect()?; - - println!("{}", df); - Ok(()) -} diff --git a/examples/string_filter/Cargo.toml b/examples/string_filter/Cargo.toml deleted file mode 100644 index a7738dd2db7f..000000000000 --- a/examples/string_filter/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "string_filter" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["strings", "lazy"] } diff --git a/examples/string_filter/src/main.rs b/examples/string_filter/src/main.rs deleted file mode 100644 index 5364f9b5b4d2..000000000000 --- a/examples/string_filter/src/main.rs +++ /dev/null @@ -1,13 +0,0 @@ -use polars::lazy::prelude::*; -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let a = Series::new("a", [1, 2, 3, 4]); - let b = Series::new("b", ["one", "two", "three", "four"]); - let df = DataFrame::new(vec![a, b])? - .lazy() - .filter(col("b").str().starts_with(lit("t"))) - .collect()?; - println!("{df:?}"); - Ok(()) -} diff --git a/examples/write_ipc_cloud/Cargo.toml b/examples/write_ipc_cloud/Cargo.toml deleted file mode 100644 index 764f67ed4504..000000000000 --- a/examples/write_ipc_cloud/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "write_ipc_cloud" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -aws-creds = "0.36.0" -polars = { path = "../../crates/polars", features = ["lazy", "aws", "ipc", "cloud_write", "streaming"] } - -[workspace] diff --git a/examples/write_ipc_cloud/src/main.rs b/examples/write_ipc_cloud/src/main.rs deleted file mode 100644 index 3da5c95fc362..000000000000 --- a/examples/write_ipc_cloud/src/main.rs +++ /dev/null @@ -1,30 +0,0 @@ -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -const TEST_S3_LOCATION: &str = "s3://test-bucket/test-writes/polars_write_example_cloud.ipc"; - -fn main() -> PolarsResult<()> { - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, "test".to_string()), - (Key::SecretAccessKey, "test".to_string()), - (Key::Endpoint, "http://localhost:4566".to_string()), - (Key::Region, "us-east-1".to_string()), - ]); - let cloud_options = Some(cloud_options); - - let df = df!( - "foo" => &[1, 2, 3], - "bar" => &[None, Some("bak"), Some("baz")], - ) - .unwrap(); - - df.lazy() - .sink_ipc_cloud( - TEST_S3_LOCATION.to_string(), - cloud_options, - Default::default(), - ) - .unwrap(); - - Ok(()) -} diff --git a/examples/write_parquet_cloud/Cargo.toml b/examples/write_parquet_cloud/Cargo.toml deleted file mode 100644 index 1e79b04739b2..000000000000 --- a/examples/write_parquet_cloud/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "write_parquet_cloud" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -aws-creds = "0.36.0" -polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet", "cloud_write", "streaming"] } - -[workspace] diff --git a/examples/write_parquet_cloud/src/main.rs b/examples/write_parquet_cloud/src/main.rs deleted file mode 100644 index c928ceb4c765..000000000000 --- a/examples/write_parquet_cloud/src/main.rs +++ /dev/null @@ -1,62 +0,0 @@ -use awscreds::Credentials; -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -// Adjust the link below. -const TEST_S3_LOCATION: &str = "s3://polarstesting/polars_write_example_cloud.parquet"; - -fn main() -> PolarsResult<()> { - sink_file(); - sink_cloud_local(); - sink_aws(); - - Ok(()) -} - -fn sink_file() { - let df = example_dataframe(); - - // Writing to a local file: - let path = "/tmp/polars_write_example.parquet".into(); - df.lazy().sink_parquet(path, Default::default()).unwrap(); -} - -fn sink_cloud_local() { - let df = example_dataframe(); - - // Writing to a location that might be in the cloud: - let uri = "file:///tmp/polars_write_example_cloud.parquet".to_string(); - df.lazy() - .sink_parquet_cloud(uri, None, Default::default()) - .unwrap(); -} - -fn sink_aws() { - let cred = Credentials::default().unwrap(); - - // Propagate the credentials and other cloud options. - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, &cred.access_key.unwrap()), - (Key::SecretAccessKey, &cred.secret_key.unwrap()), - (Key::Region, &"eu-central-1".into()), - ]); - let cloud_options = Some(cloud_options); - - let df = example_dataframe(); - - df.lazy() - .sink_parquet_cloud( - TEST_S3_LOCATION.to_string(), - cloud_options, - Default::default(), - ) - .unwrap(); -} - -fn example_dataframe() -> DataFrame { - df!( - "foo" => &[1, 2, 3], - "bar" => &[None, Some("bak"), Some("baz")], - ) - .unwrap() -} diff --git a/mkdocs.yml b/mkdocs.yml index 37fbe465973f..c180bbfc6b8e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -16,14 +16,9 @@ nav: - user-guide/installation.md - Concepts: - user-guide/concepts/index.md - - Data types: - - user-guide/concepts/data-types/overview.md - - user-guide/concepts/data-types/categoricals.md - - user-guide/concepts/data-structures.md - - user-guide/concepts/contexts.md - - user-guide/concepts/expressions.md - - user-guide/concepts/lazy-vs-eager.md - - user-guide/concepts/streaming.md + - user-guide/concepts/data-types-and-structures.md + - user-guide/concepts/expressions-and-contexts.md + - user-guide/concepts/lazy-api.md - Expressions: - user-guide/expressions/index.md - user-guide/expressions/operators.md @@ -31,6 +26,7 @@ nav: - user-guide/expressions/functions.md - user-guide/expressions/casting.md - user-guide/expressions/strings.md + - user-guide/expressions/categorical-data-and-enums.md - user-guide/expressions/aggregation.md - user-guide/expressions/missing-data.md - user-guide/expressions/window.md diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 63863959b720..fc3e520e5ecc 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "1.9.0" +version = "1.12.0" edition = "2021" [lib] @@ -104,6 +104,7 @@ all = [ "parquet", "ipc", "polars-python/all", + "performant", ] default = ["all", "nightly"] diff --git a/py-polars/Makefile b/py-polars/Makefile index 3c98adab08cb..3d9e5d7ffddc 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -23,39 +23,23 @@ requirements-all: .venv ## Install/refresh all Python requirements (including t @$(MAKE) -s -C .. $@ .PHONY: build -build: .venv ## Compile and install Polars for development - @$(MAKE) -s -C .. $@ - -.PHONY: build-debug-opt -build-debug-opt: .venv ## Compile and install Polars with minimal optimizations turned on - @$(MAKE) -s -C .. $@ - -.PHONY: build-debug-opt-subset -build-debug-opt-subset: .venv ## Compile and install Polars with minimal optimizations turned on and no default features - @$(MAKE) -s -C .. $@ - -.PHONY: build-opt -build-opt: .venv ## Compile and install Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on +build: .venv ## Compile and install Python Polars for development @$(MAKE) -s -C .. $@ .PHONY: build-release -build-release: .venv ## Compile and install a faster Polars binary with full optimizations - @$(MAKE) -s -C .. $@ - -.PHONY: build-native -build-native: .venv ## Same as build, except with native CPU optimizations turned on +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols @$(MAKE) -s -C .. $@ -.PHONY: build-debug-opt-native -build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) @$(MAKE) -s -C .. $@ -.PHONY: build-opt-native -build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) @$(MAKE) -s -C .. $@ -.PHONY: build-release-native -build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution @$(MAKE) -s -C .. $@ .PHONY: lint diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index a0cde717f0da..7c1358b480f6 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -16,6 +16,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.decode Expr.str.encode Expr.str.ends_with + Expr.str.escape_regex Expr.str.explode Expr.str.extract Expr.str.extract_all diff --git a/py-polars/docs/source/reference/expressions/struct.rst b/py-polars/docs/source/reference/expressions/struct.rst index 958436e4066b..cd081477b23b 100644 --- a/py-polars/docs/source/reference/expressions/struct.rst +++ b/py-polars/docs/source/reference/expressions/struct.rst @@ -10,6 +10,7 @@ The following methods are available under the `expr.struct` attribute. :template: autosummary/accessor_method.rst Expr.struct.field + Expr.struct.unnest Expr.struct.json_encode Expr.struct.rename_fields Expr.struct.with_fields diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst index c672aaa77eac..33ee296844db 100644 --- a/py-polars/docs/source/reference/functions.rst +++ b/py-polars/docs/source/reference/functions.rst @@ -25,6 +25,7 @@ Miscellaneous align_frames concat + escape_regex Parallelization ~~~~~~~~~~~~~~~ diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index 1f088958a3c0..bdfe93ee9ebe 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ -117,3 +117,14 @@ Connect to pyarrow datasets. :toctree: api/ scan_pyarrow_dataset + +Cloud Credentials +~~~~~~~~~~~~~~~~~ +Configuration for cloud credential provisioning. + +.. autosummary:: + :toctree: api/ + + CredentialProvider + CredentialProviderAWS + CredentialProviderGCP diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 5ab80238a54d..8cdb8fe152fa 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -15,6 +15,7 @@ Computation Series.arctanh Series.arg_true Series.arg_unique + Series.approx_n_unique Series.bitwise_count_ones Series.bitwise_count_zeros Series.bitwise_leading_ones @@ -42,10 +43,12 @@ Computation Series.ewm_std Series.ewm_var Series.exp + Series.first Series.hash Series.hist Series.is_between Series.kurtosis + Series.last Series.log Series.log10 Series.log1p diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index 9e44b0d9b105..85dcf4b1b2d6 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -16,6 +16,7 @@ The following methods are available under the `Series.str` attribute. Series.str.decode Series.str.encode Series.str.ends_with + Series.str.escape_regex Series.str.explode Series.str.extract Series.str.extract_all diff --git a/py-polars/docs/source/reference/sql/functions/aggregate.rst b/py-polars/docs/source/reference/sql/functions/aggregate.rst index 0244af235175..07081c72454c 100644 --- a/py-polars/docs/source/reference/sql/functions/aggregate.rst +++ b/py-polars/docs/source/reference/sql/functions/aggregate.rst @@ -21,6 +21,11 @@ Aggregate - Returns the median element from the grouping. * - :ref:`MIN ` - Returns the smallest (minimum) of all the elements in the grouping. + * - :ref:`QUANTILE_CONT ` + - Returns the continuous quantile element from the grouping (interpolated value between two closest values). + * - :ref:`QUANTILE_DISC ` + - Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, and returns the + value associated with the subinterval where the quantile value falls. * - :ref:`STDDEV ` - Returns the standard deviation of all the elements in the grouping. * - :ref:`SUM ` @@ -198,6 +203,64 @@ Returns the smallest (minimum) of all the elements in the grouping. # │ 10 │ # └─────────┘ + +.. _quantile_cont: + +QUANTILE_CONT +------------- +Returns the continuous quantile element from the grouping (interpolated value between two closest values). + +**Example:** + +.. code-block:: python + + df = pl.DataFrame({"foo": [5, 20, 10, 30, 70, 40, 10, 90]}) + df.sql(""" + SELECT + QUANTILE_CONT(foo, 0.25) AS foo_q25, + QUANTILE_CONT(foo, 0.50) AS foo_q50, + QUANTILE_CONT(foo, 0.75) AS foo_q75, + FROM self + """) + # shape: (1, 3) + # ┌─────────┬─────────┬─────────┐ + # │ foo_q25 ┆ foo_q50 ┆ foo_q75 │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞═════════╪═════════╪═════════╡ + # │ 10.0 ┆ 25.0 ┆ 47.5 │ + # └─────────┴─────────┴─────────┘ + + +.. _quantile_disc: + +QUANTILE_DISC +------------- +Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, and +returns the value associated with the subinterval where the quantile value falls. + +**Example:** + +.. code-block:: python + + df = pl.DataFrame({"foo": [5, 20, 10, 30, 70, 40, 10, 90]}) + df.sql(""" + SELECT + QUANTILE_DISC(foo, 0.25) AS foo_q25, + QUANTILE_DISC(foo, 0.50) AS foo_q50, + QUANTILE_DISC(foo, 0.75) AS foo_q75, + FROM self + """) + # shape: (1, 3) + # ┌─────────┬─────────┬─────────┐ + # │ foo_q25 ┆ foo_q50 ┆ foo_q75 │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞═════════╪═════════╪═════════╡ + # │ 10.0 ┆ 20.0 ┆ 40.0 │ + # └─────────┴─────────┴─────────┘ + + .. _stddev: STDDEV diff --git a/py-polars/docs/source/reference/sql/functions/bitwise.rst b/py-polars/docs/source/reference/sql/functions/bitwise.rst new file mode 100644 index 000000000000..bd66c3810df1 --- /dev/null +++ b/py-polars/docs/source/reference/sql/functions/bitwise.rst @@ -0,0 +1,151 @@ +Temporal +======== + +.. list-table:: + :header-rows: 1 + :widths: 20 60 + + * - Function + - Description + + * - :ref:`BIT_AND ` + - Returns the bitwise AND of the given values. + * - :ref:`BIT_COUNT ` + - Returns the number of bits set to 1 in the binary representation of the given value. + * - :ref:`BIT_OR ` + - Returns the bitwise OR of the given values. + * - :ref:`BIT_XOR ` + - Returns the bitwise XOR of the given values. + + +.. _bit_and: + +BIT_AND +------- +Returns the bitwise AND of the given values. +Also available as the `&` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i & j AS i_bitand_op_j, + BIT_AND(i, j) AS i_bitand_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬───────────────┬────────────┐ + # │ i ┆ j ┆ i_bitand_op_j ┆ i_bitand_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪═══════════════╪════════════╡ + # │ 3 ┆ 4 ┆ 0 ┆ 0 │ + # │ 10 ┆ 7 ┆ 2 ┆ 2 │ + # │ 4 ┆ 9 ┆ 0 ┆ 0 │ + # │ 8 ┆ 10 ┆ 8 ┆ 8 │ + # └─────┴─────┴───────────────┴────────────┘ + +.. _bit_count: + +BIT_COUNT +--------- +Returns the number of bits set to 1 in the binary representation of the given value. + +.. code-block:: python + + df = pl.DataFrame({"i": [16, 10, 55, 127]}) + df.sql(""" + SELECT + i, + BIT_COUNT(i) AS i_bitcount + FROM self + """) + # shape: (4, 2) + # ┌─────┬────────────┐ + # │ i ┆ i_bitcount │ + # │ --- ┆ --- │ + # │ i64 ┆ u32 │ + # ╞═════╪════════════╡ + # │ 16 ┆ 1 │ + # │ 10 ┆ 2 │ + # │ 55 ┆ 5 │ + # │ 127 ┆ 7 │ + # └─────┴────────────┘ + +.. _bit_or: + +BIT_OR +------ +Returns the bitwise OR of the given values. +Also available as the `|` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i | j AS i_bitor_op_j, + BIT_OR(i, j) AS i_bitor_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬──────────────┬───────────┐ + # │ i ┆ j ┆ i_bitor_op_j ┆ i_bitor_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪══════════════╪═══════════╡ + # │ 3 ┆ 4 ┆ 7 ┆ 7 │ + # │ 10 ┆ 7 ┆ 15 ┆ 15 │ + # │ 4 ┆ 9 ┆ 13 ┆ 13 │ + # │ 8 ┆ 10 ┆ 10 ┆ 10 │ + # └─────┴─────┴──────────────┴───────────┘ + +.. _bit_xor: + +BIT_XOR +------- +Returns the bitwise XOR of the given values. +Also available as the `XOR` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i XOR j AS i_bitxor_op_j, + BIT_XOR(i, j) AS i_bitxor_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬───────────────┬────────────┐ + # │ i ┆ j ┆ i_bitxor_op_j ┆ i_bitxor_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪═══════════════╪════════════╡ + # │ 3 ┆ 4 ┆ 7 ┆ 7 │ + # │ 10 ┆ 7 ┆ 13 ┆ 13 │ + # │ 4 ┆ 9 ┆ 13 ┆ 13 │ + # │ 8 ┆ 10 ┆ 2 ┆ 2 │ + # └─────┴─────┴───────────────┴────────────┘ diff --git a/py-polars/docs/source/reference/sql/functions/index.rst b/py-polars/docs/source/reference/sql/functions/index.rst index cf6247e41a10..3473a0741a91 100644 --- a/py-polars/docs/source/reference/sql/functions/index.rst +++ b/py-polars/docs/source/reference/sql/functions/index.rst @@ -32,6 +32,18 @@ SQL Functions array + .. grid-item-card:: + + **Bitwise** + ^^^^^^^^^^^ + + .. toctree:: + :maxdepth: 2 + + bitwise + +.. grid:: + .. grid-item-card:: **Conditional** @@ -52,8 +64,6 @@ SQL Functions math -.. grid:: - .. grid-item-card:: **String** @@ -64,6 +74,8 @@ SQL Functions string +.. grid:: + .. grid-item-card:: **Temporal** diff --git a/py-polars/docs/source/reference/sql/functions/trigonometry.rst b/py-polars/docs/source/reference/sql/functions/trigonometry.rst index a2ee47ad3d06..dcd7ba4fe18d 100644 --- a/py-polars/docs/source/reference/sql/functions/trigonometry.rst +++ b/py-polars/docs/source/reference/sql/functions/trigonometry.rst @@ -18,9 +18,9 @@ Trigonometry * - :ref:`ATAND ` - Compute inverse tangent of the input column (in degrees). * - :ref:`ATAN2 ` - - Compute the inverse tangent of column_2/column_1 (in radians). + - Compute the inverse tangent of column_1/column_2 (in radians). * - :ref:`ATAN2D ` - - Compute the inverse tangent of column_2/column_1 (in degrees). + - Compute the inverse tangent of column_1/column_2 (in degrees). * - :ref:`COT ` - Compute the cotangent of the input column (in radians). * - :ref:`COTD ` @@ -187,7 +187,7 @@ Compute inverse tangent of the input column (in degrees). ATAN2 ----- -Compute the inverse tangent of column_2/column_1 (in radians). +Compute the inverse tangent of column_1/column_2 (in radians). **Example:** @@ -216,7 +216,7 @@ Compute the inverse tangent of column_2/column_1 (in radians). ATAN2D ------ -Compute the inverse tangent of column_2/column_1 (in degrees). +Compute the inverse tangent of column_1/column_2 (in degrees). **Example:** @@ -224,22 +224,22 @@ Compute the inverse tangent of column_2/column_1 (in degrees). df = pl.DataFrame( { - "a": [0, 90, 180, 360], - "b": [360, 270, 180, 90], + "a": [-1.0, 0.0, 1.0, 1.0], + "b": [1.0, 1.0, 0.0, -1.0], } ) df.sql("SELECT a, b, ATAN2D(a, b) AS atan2d_ab FROM self") # shape: (4, 3) - # ┌─────┬─────┬───────────┐ - # │ a ┆ b ┆ atan2d_ab │ - # │ --- ┆ --- ┆ --- │ - # │ i64 ┆ i64 ┆ f64 │ - # ╞═════╪═════╪═══════════╡ - # │ 0 ┆ 360 ┆ 0.0 │ - # │ 90 ┆ 270 ┆ 18.434949 │ - # │ 180 ┆ 180 ┆ 45.0 │ - # │ 360 ┆ 90 ┆ 75.963757 │ - # └─────┴─────┴───────────┘ + # ┌──────┬──────┬───────────┐ + # │ a ┆ b ┆ atan2d_ab │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞══════╪══════╪═══════════╡ + # │ -1 ┆ 1.0 ┆ 135.0 │ + # │ 0.0 ┆ 1.0 ┆ 90.0 │ + # │ 1.0 ┆ 0.0 ┆ 0.0 │ + # │ 1.0 ┆ -1.0 ┆ -45.0 │ + # └──────┴──────┴───────────┘ .. _cot: diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 10f0ee54228b..83ea52acc822 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -104,6 +104,7 @@ datetime_ranges, duration, element, + escape_regex, exclude, field, first, @@ -176,6 +177,13 @@ scan_parquet, scan_pyarrow_dataset, ) +from polars.io.cloud import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) from polars.lazyframe import GPUEngine, LazyFrame from polars.meta import ( build_info, @@ -266,6 +274,12 @@ "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", + # polars.io.cloud + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", # polars.stringcache "StringCache", "disable_string_cache", @@ -290,6 +304,7 @@ "time_range", "time_ranges", "zeros", + "escape_regex", # polars.functions.aggregation "all", "all_horizontal", diff --git a/py-polars/polars/_cpu_check.py b/py-polars/polars/_cpu_check.py index c71029c303f7..487b5f02590e 100644 --- a/py-polars/polars/_cpu_check.py +++ b/py-polars/polars/_cpu_check.py @@ -44,13 +44,16 @@ _POLARS_FEATURE_FLAGS = "" # Set to True during the build process if we are building a LTS CPU version. -# The risk of the CPU check failing is then higher than a CPU not being supported. _POLARS_LTS_CPU = False _IS_WINDOWS = os.name == "nt" _IS_64BIT = ctypes.sizeof(ctypes.c_void_p) == 8 +def get_lts_cpu() -> bool: + return _POLARS_LTS_CPU + + def _open_posix_libc() -> ctypes.CDLL: # Avoid importing ctypes.util if possible. try: @@ -234,11 +237,7 @@ def _read_cpu_flags() -> dict[str, bool]: def check_cpu_flags() -> None: - if ( - not _POLARS_FEATURE_FLAGS - or _POLARS_LTS_CPU - or os.environ.get("POLARS_SKIP_CPU_CHECK") - ): + if not _POLARS_FEATURE_FLAGS or os.environ.get("POLARS_SKIP_CPU_CHECK"): return expected_cpu_flags = [f.lstrip("+") for f in _POLARS_FEATURE_FLAGS.split(",")] diff --git a/py-polars/polars/_reexport.py b/py-polars/polars/_reexport.py index 408fead781de..10818f473166 100644 --- a/py-polars/polars/_reexport.py +++ b/py-polars/polars/_reexport.py @@ -3,12 +3,14 @@ from polars.dataframe import DataFrame from polars.expr import Expr, When from polars.lazyframe import LazyFrame +from polars.schema import Schema from polars.series import Series __all__ = [ "DataFrame", "Expr", "LazyFrame", + "Schema", "Series", "When", ] diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 1670b08aeb2f..67a06a2c689e 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -1,7 +1,9 @@ from __future__ import annotations from collections.abc import Collection, Iterable, Mapping, Sequence +from pathlib import Path from typing import ( + IO, TYPE_CHECKING, Any, Literal, @@ -158,9 +160,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: RollingInterpolationMethod: TypeAlias = Literal[ "nearest", "higher", "lower", "midpoint", "linear" ] # QuantileInterpolOptions -ToStructStrategy: TypeAlias = Literal[ - "first_non_null", "max_width" -] # ListToStructWidthStrategy +ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"] # The following have no equivalent on the Rust side ConcatMethod = Literal[ @@ -294,3 +294,14 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: # LazyFrame engine selection EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"] + +ScanSource: TypeAlias = Union[ + str, + Path, + IO[bytes], + bytes, + list[str], + list[Path], + list[IO[bytes]], + list[bytes], +] diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py index 121378b6d873..f8b700badc20 100644 --- a/py-polars/polars/_utils/construction/series.py +++ b/py-polars/polars/_utils/construction/series.py @@ -36,6 +36,7 @@ Object, Struct, Time, + UInt32, Unknown, dtype_to_py_type, is_polars_dtype, @@ -57,9 +58,10 @@ from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa +from polars.functions.eager import concat with contextlib.suppress(ImportError): # Module not available when building docs - from polars.polars import PySeries + from polars.polars import PySeries, get_index_type if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -454,27 +456,48 @@ def numpy_to_pyseries( return constructor( name, values, nan_to_null if dtype in (np.float32, np.float64) else strict ) - elif sum(values.shape) == 0: - # Optimize by ingesting 1D and reshaping in Rust - original_shape = values.shape - values = values.reshape(-1) - py_s = numpy_to_pyseries( - name, - values, - strict=strict, - nan_to_null=nan_to_null, - ) - return wrap_s(py_s).reshape(original_shape)._s else: original_shape = values.shape - values = values.reshape(-1) - py_s = numpy_to_pyseries( - name, - values, - strict=strict, - nan_to_null=nan_to_null, - ) - return wrap_s(py_s).reshape(original_shape)._s + values_1d = values.reshape(-1) + + if get_index_type() == UInt32: + limit = 2**32 - 1 + else: + limit = 2**64 - 1 + + if values.size <= limit: + py_s = numpy_to_pyseries( + name, + values_1d, + strict=strict, + nan_to_null=nan_to_null, + ) + return wrap_s(py_s).reshape(original_shape)._s + else: + # Process in chunk, so we don't trigger ROWS_LIMIT + offset = 0 + chunks = [] + + # Tuples are immutable, so convert to list + original_shape_chunk = list(original_shape) + # Rows size is now changed, so infer + original_shape_chunk[0] = -1 + original_shape_chunk_t = tuple(original_shape_chunk) + while True: + chunk = values_1d[offset : offset + limit] + offset += limit + if chunk.shape[0] == 0: + break + + py_s = numpy_to_pyseries( + name, + chunk, + strict=strict, + nan_to_null=nan_to_null, + ) + chunks.append(wrap_s(py_s).reshape(original_shape_chunk_t)) + + return concat(chunks)._s def series_to_pyseries( diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 0ad1dfb985f9..dc5060c4e4c3 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -3,7 +3,7 @@ import contextlib import os from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, get_args +from typing import TYPE_CHECKING, Literal, TypedDict, get_args from polars._utils.various import normalize_filepath from polars.dependencies import json @@ -20,9 +20,9 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 11): - from typing import Self + from typing import Self, Unpack else: - from typing_extensions import Self + from typing_extensions import Self, Unpack __all__ = ["Config"] @@ -35,6 +35,7 @@ "ASCII_BORDERS_ONLY_CONDENSED", "ASCII_HORIZONTAL_ONLY", "ASCII_MARKDOWN", + "MARKDOWN", "UTF8_FULL", "UTF8_FULL_CONDENSED", "UTF8_NO_BORDERS", @@ -87,6 +88,60 @@ } +class ConfigParameters(TypedDict, total=False): + """Parameters supported by the polars Config.""" + + ascii_tables: bool | None + auto_structify: bool | None + decimal_separator: str | None + thousands_separator: str | bool | None + float_precision: int | None + fmt_float: FloatFmt | None + fmt_str_lengths: int | None + fmt_table_cell_list_len: int | None + streaming_chunk_size: int | None + tbl_cell_alignment: Literal["LEFT", "CENTER", "RIGHT"] | None + tbl_cell_numeric_alignment: Literal["LEFT", "CENTER", "RIGHT"] | None + tbl_cols: int | None + tbl_column_data_type_inline: bool | None + tbl_dataframe_shape_below: bool | None + tbl_formatting: TableFormatNames | None + tbl_hide_column_data_types: bool | None + tbl_hide_column_names: bool | None + tbl_hide_dtype_separator: bool | None + tbl_hide_dataframe_shape: bool | None + tbl_rows: int | None + tbl_width_chars: int | None + trim_decimal_zeros: bool | None + verbose: bool | None + expr_depth_warning: int + + set_ascii_tables: bool | None + set_auto_structify: bool | None + set_decimal_separator: str | None + set_thousands_separator: str | bool | None + set_float_precision: int | None + set_fmt_float: FloatFmt | None + set_fmt_str_lengths: int | None + set_fmt_table_cell_list_len: int | None + set_streaming_chunk_size: int | None + set_tbl_cell_alignment: Literal["LEFT", "CENTER", "RIGHT"] | None + set_tbl_cell_numeric_alignment: Literal["LEFT", "CENTER", "RIGHT"] | None + set_tbl_cols: int | None + set_tbl_column_data_type_inline: bool | None + set_tbl_dataframe_shape_below: bool | None + set_tbl_formatting: TableFormatNames | None + set_tbl_hide_column_data_types: bool | None + set_tbl_hide_column_names: bool | None + set_tbl_hide_dtype_separator: bool | None + set_tbl_hide_dataframe_shape: bool | None + set_tbl_rows: int | None + set_tbl_width_chars: int | None + set_trim_decimal_zeros: bool | None + set_verbose: bool | None + set_expr_depth_warning: int + + class Config(contextlib.ContextDecorator): """ Configure polars; offers options for table formatting and more. @@ -118,7 +173,9 @@ class Config(contextlib.ContextDecorator): _original_state: str = "" - def __init__(self, *, restore_defaults: bool = False, **options: Any) -> None: + def __init__( + self, *, restore_defaults: bool = False, **options: Unpack[ConfigParameters] + ) -> None: """ Initialise a Config object instance for context manager usage. @@ -139,7 +196,7 @@ def __init__(self, *, restore_defaults: bool = False, **options: Any) -> None: >>> df = pl.DataFrame({"abc": [1.0, 2.5, 5.0], "xyz": [True, False, True]}) >>> with pl.Config( ... # these options will be set for scope duration - ... tbl_formatting="ASCII_MARKDOWN", + ... tbl_formatting="MARKDOWN", ... tbl_hide_dataframe_shape=True, ... tbl_rows=10, ... ): @@ -981,7 +1038,8 @@ def set_tbl_formatting( * "ASCII_BORDERS_ONLY": ASCII, borders only. * "ASCII_BORDERS_ONLY_CONDENSED": ASCII, borders only, dense row spacing. * "ASCII_HORIZONTAL_ONLY": ASCII, horizontal lines only. - * "ASCII_MARKDOWN": ASCII, Markdown compatible. + * "ASCII_MARKDOWN": Markdown format (ascii ellipses for truncated values). + * "MARKDOWN": Markdown format (utf8 ellipses for truncated values). * "UTF8_FULL": UTF8, with all borders and lines, including row dividers. * "UTF8_FULL_CONDENSED": Same as UTF8_FULL, but with dense row spacing. * "UTF8_NO_BORDERS": UTF8, no borders. @@ -1004,7 +1062,7 @@ def set_tbl_formatting( ... {"abc": [-2.5, 5.0], "mno": ["hello", "world"], "xyz": [True, False]} ... ) >>> with pl.Config( - ... tbl_formatting="ASCII_MARKDOWN", + ... tbl_formatting="MARKDOWN", ... tbl_hide_column_data_types=True, ... tbl_hide_dataframe_shape=True, ... ): diff --git a/py-polars/polars/convert/general.py b/py-polars/polars/convert/general.py index cee4b925e9e9..f80526e5a2c9 100644 --- a/py-polars/polars/convert/general.py +++ b/py-polars/polars/convert/general.py @@ -4,7 +4,7 @@ import itertools import re from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Literal, overload import polars._reexport as pl from polars import functions as F @@ -98,7 +98,7 @@ def from_dict( def from_dicts( - data: Sequence[dict[str, Any]], + data: Iterable[dict[str, Any]], schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, @@ -487,15 +487,26 @@ def from_pandas( @overload def from_pandas( - data: pd.Series[Any] | pd.Index[Any], + data: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, schema_overrides: SchemaDict | None = ..., rechunk: bool = ..., nan_to_null: bool = ..., - include_index: bool = ..., + include_index: Literal[False] = ..., ) -> Series: ... +@overload +def from_pandas( + data: pd.Series[Any], + *, + schema_overrides: SchemaDict | None = ..., + rechunk: bool = ..., + nan_to_null: bool = ..., + include_index: Literal[True] = ..., +) -> DataFrame: ... + + def from_pandas( data: pd.DataFrame | pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, @@ -525,8 +536,8 @@ def from_pandas( Load any non-default pandas indexes as columns. .. note:: - If the input is a pandas ``Series`` or ``DataFrame`` and has a nameless - index which just enumerates the rows, then it will not be included in the + If the input is a pandas ``DataFrame`` and has a nameless index + which just enumerates the rows, then it will not be included in the result, regardless of this parameter. If you want to be sure to include it, please call ``.reset_index()`` prior to calling this function. @@ -566,6 +577,9 @@ def from_pandas( 3 ] """ + if include_index and isinstance(data, pd.Series): + data = data.reset_index() + if isinstance(data, (pd.Series, pd.Index, pd.DatetimeIndex)): return wrap_s(pandas_to_pyseries("", data, nan_to_null=nan_to_null)) elif isinstance(data, pd.DataFrame): @@ -724,6 +738,7 @@ def _from_dataframe_repr(m: re.Match[str]) -> DataFrame: if schema and data and (n_extend_cols := (len(schema) - len(data))) > 0: empty_data = [None] * len(data[0]) data.extend((pl.Series(empty_data, dtype=String)) for _ in range(n_extend_cols)) + for dtype in set(schema.values()): if dtype in (List, Struct, Object): msg = ( diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index eca65a22c8c8..721a02c7fec6 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -73,6 +73,7 @@ Float64, Int32, Int64, + Null, Object, String, Struct, @@ -627,17 +628,17 @@ def plot(self) -> DataFramePlot: - `df.plot.line(**kwargs)` is shorthand for - `alt.Chart(df).mark_line().encode(**kwargs).interactive()` + `alt.Chart(df).mark_line(tooltip=True).encode(**kwargs).interactive()` - `df.plot.point(**kwargs)` is shorthand for - `alt.Chart(df).mark_point().encode(**kwargs).interactive()` (and + `alt.Chart(df).mark_point(tooltip=True).encode(**kwargs).interactive()` (and `plot.scatter` is provided as an alias) - `df.plot.bar(**kwargs)` is shorthand for - `alt.Chart(df).mark_bar().encode(**kwargs).interactive()` + `alt.Chart(df).mark_bar(tooltip=True).encode(**kwargs).interactive()` - for any other attribute `attr`, `df.plot.attr(**kwargs)` is shorthand for - `alt.Chart(df).mark_attr().encode(**kwargs).interactive()` + `alt.Chart(df).mark_attr(tooltip=True).encode(**kwargs).interactive()` Examples -------- @@ -911,7 +912,7 @@ def schema(self) -> Schema: >>> df.schema Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return Schema(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes), check_dtypes=False) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None @@ -1073,6 +1074,7 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame: other = DataFrame([s.alias(f"n{i}") for i in range(len(self.columns))]) orig_dtypes = other.dtypes + # TODO: Dispatch to a native floordiv other = self._cast_all_from_to(other, INTEGER_DTYPES, Float64) df = self._from_pydf(self._df.div_df(other._df)) @@ -1085,7 +1087,8 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame: int_casts = [ col(column).cast(tp) for i, (column, tp) in enumerate(self.schema.items()) - if tp.is_integer() and orig_dtypes[i].is_integer() + if tp.is_integer() + and (orig_dtypes[i].is_integer() or orig_dtypes[i] == Null) ] if int_casts: return df.with_columns(int_casts) @@ -1977,7 +1980,7 @@ def to_jax( Create the Array on a specific GPU device: - >>> gpu_device = jax.devices("gpu")[1]) # doctest: +SKIP + >>> gpu_device = jax.devices("gpu")[1] # doctest: +SKIP >>> a = df.to_jax(device=gpu_device) # doctest: +SKIP >>> a.device() # doctest: +SKIP GpuDevice(id=1, process_index=0) @@ -3121,8 +3124,9 @@ def write_excel( If the table has headers, provide autofilter capability. autofit : bool Calculate individual column widths from the data. - hidden_columns : list - A list or selector representing table columns to hide in the worksheet. + hidden_columns : str | list + A column name, list of column names, or a selector representing table + columns to mark as hidden in the output worksheet. hide_gridlines : bool Do not display any gridlines on the output worksheet. sheet_zoom : int @@ -3442,10 +3446,15 @@ def write_excel( include_header=include_header, format_cache=fmt_cache, ) + # additional column-level properties if hidden_columns is None: - hidden_columns = () - hidden_columns = _expand_selectors(df, hidden_columns) + hidden = set() + elif isinstance(hidden_columns, str): + hidden = {hidden_columns} + else: + hidden = set(_expand_selectors(df, hidden_columns)) + if isinstance(column_widths, int): column_widths = dict.fromkeys(df.columns, column_widths) else: @@ -3455,9 +3464,8 @@ def write_excel( column_widths = _unpack_multi_column_dict(column_widths or {}) # type: ignore[assignment] for column in df.columns: - col_idx, options = table_start[1] + df.get_column_index(column), {} - if column in hidden_columns: - options = {"hidden": True} + options = {"hidden": True} if column in hidden else {} + col_idx = table_start[1] + df.get_column_index(column) if column in column_widths: # type: ignore[operator] ws.set_column_pixels( col_idx, @@ -3466,6 +3474,8 @@ def write_excel( None, options, ) + elif options: + ws.set_column(col_idx, col_idx, None, None, options) # finally, inject any sparklines into the table for column, params in (sparklines or {}).items(): @@ -3883,11 +3893,14 @@ def write_database( Select the engine to use for writing frame data; only necessary when supplying a URI string (defaults to 'sqlalchemy' if unset) engine_options - Additional options to pass to the engine's associated insert method: - - * "sqlalchemy" - currently inserts using Pandas' `to_sql` method, though - this will eventually be phased out in favor of a native solution. - * "adbc" - inserts using the ADBC cursor's `adbc_ingest` method. + Additional options to pass to the insert method associated with the engine + specified by the option `engine`. + + * Setting `engine` to "sqlalchemy" currently inserts using Pandas' `to_sql` + method (though this will eventually be phased out in favor of a native + solution). + * Setting `engine` to "adbc" inserts using the ADBC cursor's `adbc_ingest` + method. Examples -------- diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py index 11828cd54ebc..2bf134e3c67e 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting.py @@ -30,21 +30,6 @@ Encodings: TypeAlias = dict[str, Encoding] -def _maybe_extract_shorthand(encoding: Encoding) -> Encoding: - if isinstance(encoding, alt.SchemaBase): - # e.g. for `alt.X('x:Q', axis=alt.Axis(labelAngle=30))`, return `'x:Q'` - return getattr(encoding, "shorthand", encoding) - return encoding - - -def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None: - if "tooltip" not in kwargs: - encodings["tooltip"] = [ - *[_maybe_extract_shorthand(x) for x in encodings.values()], - *[_maybe_extract_shorthand(x) for x in kwargs.values()], # type: ignore[arg-type] - ] # type: ignore[assignment] - - class DataFramePlot: """DataFrame.plot namespace.""" @@ -107,8 +92,11 @@ def bar( encodings["y"] = y if color is not None: encodings["color"] = color - _add_tooltip(encodings, **kwargs) - return self._chart.mark_bar().encode(**encodings, **kwargs).interactive() + return ( + self._chart.mark_bar(tooltip=True) + .encode(**encodings, **kwargs) + .interactive() + ) def line( self, @@ -169,8 +157,11 @@ def line( encodings["color"] = color if order is not None: encodings["order"] = order - _add_tooltip(encodings, **kwargs) - return self._chart.mark_line().encode(**encodings, **kwargs).interactive() + return ( + self._chart.mark_line(tooltip=True) + .encode(**encodings, **kwargs) + .interactive() + ) def point( self, @@ -231,9 +222,8 @@ def point( encodings["color"] = color if size is not None: encodings["size"] = size - _add_tooltip(encodings, **kwargs) return ( - self._chart.mark_point() + self._chart.mark_point(tooltip=True) .encode( **encodings, **kwargs, @@ -252,7 +242,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: encodings: Encodings = {} def func(**kwargs: EncodeKwds) -> alt.Chart: - _add_tooltip(encodings, **kwargs) - return method().encode(**encodings, **kwargs).interactive() + return method(tooltip=True).encode(**encodings, **kwargs).interactive() return func diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d46d8c111581..f773cc28b6ca 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -65,10 +65,14 @@ def is_polars_dtype( - dtype: Any, *, include_unknown: bool = False + dtype: Any, + *, + include_unknown: bool = False, + require_instantiated: bool = False, ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + check_classes = DataType if require_instantiated else (DataType, DataTypeClass) + is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type] if not include_unknown: return is_dtype and dtype != Unknown diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index 9aaed1352d09..4628ee2a9c15 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -6,7 +6,7 @@ import polars._reexport as pl from polars import functions as F from polars._utils.convert import parse_as_duration_string -from polars._utils.deprecation import deprecate_function +from polars._utils.deprecation import deprecate_function, deprecate_nonkeyword_arguments from polars._utils.parse import parse_into_expression from polars._utils.unstable import unstable from polars._utils.wrap import wrap_expr @@ -35,6 +35,7 @@ class ExprDateTimeNameSpace: def __init__(self, expr: Expr) -> None: self._pyexpr = expr._pyexpr + @deprecate_nonkeyword_arguments(allowed_args=["self", "n"], version="1.12.0") def add_business_days( self, n: int | IntoExpr, @@ -97,7 +98,9 @@ def add_business_days( You can pass a custom weekend - for example, if you only take Sunday off: >>> week_mask = (True, True, True, True, True, True, False) - >>> df.with_columns(result=pl.col("start").dt.add_business_days(5, week_mask)) + >>> df.with_columns( + ... result=pl.col("start").dt.add_business_days(5, week_mask=week_mask) + ... ) shape: (2, 2) ┌────────────┬────────────┐ │ start ┆ result │ diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 98f638b0846a..4613aabd4beb 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -407,9 +407,16 @@ def to_physical(self) -> Expr: - :func:`polars.datatypes.Duration` -> :func:`polars.datatypes.Int64` - :func:`polars.datatypes.Categorical` -> :func:`polars.datatypes.UInt32` - `List(inner)` -> `List(physical of inner)` + - `Array(inner)` -> `Struct(physical of inner)` + - `Struct(fields)` -> `Array(physical of fields)` Other data types will be left unchanged. + Warning + ------- + The physical representations are an implementation detail + and not guaranteed to be stable. + Examples -------- Replicating the pandas @@ -2815,7 +2822,7 @@ def fill_nan(self, value: int | float | Expr | None) -> Expr: def forward_fill(self, limit: int | None = None) -> Expr: """ - Fill missing values with the latest seen values. + Fill missing values with the last non-null value. Parameters ---------- @@ -2851,7 +2858,7 @@ def forward_fill(self, limit: int | None = None) -> Expr: def backward_fill(self, limit: int | None = None) -> Expr: """ - Fill missing values with the next to be seen values. + Fill missing values with the next non-null value. Parameters ---------- @@ -4195,7 +4202,6 @@ def filter( Filter expressions can also take constraints as keyword arguments. - >>> import polars.selectors as cs >>> df = pl.DataFrame( ... { ... "key": ["a", "a", "a", "a", "b", "b", "b", "b", "b"], diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 48b4d1da9c49..4d239460d6b5 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -16,8 +16,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) @@ -1092,7 +1092,7 @@ def to_array(self, width: int) -> Expr: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Sequence[str] | Callable[[int], str] | None = None, upper_bound: int = 0, ) -> Expr: @@ -1180,9 +1180,8 @@ def to_struct( [{'n': {'one': 0, 'two': 1}}, {'n': {'one': 2, 'two': 3}}] """ if isinstance(fields, Sequence): - field_names = list(fields) - pyexpr = self._pyexpr.list_to_struct(n_field_strategy, None, upper_bound) - return wrap_expr(pyexpr).struct.rename_fields(field_names) + pyexpr = self._pyexpr.list_to_struct_fixed_width(fields) + return wrap_expr(pyexpr) else: pyexpr = self._pyexpr.list_to_struct(n_field_strategy, fields, upper_bound) return wrap_expr(pyexpr) diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 7582758d5921..e94f995ee700 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2781,6 +2781,28 @@ def concat( delimiter = "-" return self.join(delimiter, ignore_nulls=ignore_nulls) + def escape_regex(self) -> Expr: + r""" + Returns string values with all regular expression meta characters escaped. + + Examples + -------- + >>> df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + >>> df.with_columns(pl.col("text").str.escape_regex().alias("escaped")) + shape: (4, 2) + ┌──────────┬──────────────┐ + │ text ┆ escaped │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞══════════╪══════════════╡ + │ abc ┆ abc │ + │ def ┆ def │ + │ null ┆ null │ + │ abc(\w+) ┆ abc\(\\w\+\) │ + └──────────┴──────────────┘ + """ + return wrap_expr(self._pyexpr.str_escape_regex()) + def _validate_format_argument(format: str | None) -> None: if format is not None and ".%f" in format: diff --git a/py-polars/polars/expr/struct.py b/py-polars/polars/expr/struct.py index 57b8b6eddfb3..1df3a16f4411 100644 --- a/py-polars/polars/expr/struct.py +++ b/py-polars/polars/expr/struct.py @@ -152,6 +152,43 @@ def field(self, name: str | list[str], *more_names: str) -> Expr: return wrap_expr(self._pyexpr.struct_field_by_name(name)) + def unnest(self) -> Expr: + """ + Expand the struct into its individual fields. + + Alias for `Expr.struct.field("*")`. + + >>> df = pl.DataFrame( + ... { + ... "aaa": [1, 2], + ... "bbb": ["ab", "cd"], + ... "ccc": [True, None], + ... "ddd": [[1, 2], [3]], + ... } + ... ).select(pl.struct("aaa", "bbb", "ccc", "ddd").alias("struct_col")) + >>> df + shape: (2, 1) + ┌──────────────────────┐ + │ struct_col │ + │ --- │ + │ struct[4] │ + ╞══════════════════════╡ + │ {1,"ab",true,[1, 2]} │ + │ {2,"cd",null,[3]} │ + └──────────────────────┘ + >>> df.select(pl.col("struct_col").struct.unnest()) + shape: (2, 4) + ┌─────┬─────┬──────┬───────────┐ + │ aaa ┆ bbb ┆ ccc ┆ ddd │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ bool ┆ list[i64] │ + ╞═════╪═════╪══════╪═══════════╡ + │ 1 ┆ ab ┆ true ┆ [1, 2] │ + │ 2 ┆ cd ┆ null ┆ [3] │ + └─────┴─────┴──────┴───────────┘ + """ + return self.field("*") + def rename_fields(self, names: Sequence[str]) -> Expr: """ Rename the fields of the struct. diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index fedd0ac2bff0..32fbe4578059 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -26,6 +26,7 @@ from polars.functions.business import business_day_count from polars.functions.col import col from polars.functions.eager import align_frames, concat +from polars.functions.escape_regex import escape_regex from polars.functions.lazy import ( approx_n_unique, arctan2, @@ -170,4 +171,6 @@ # polars.functions.whenthen "when", "sql_expr", + # polars.functions.escape_regex + "escape_regex", ] diff --git a/py-polars/polars/functions/escape_regex.py b/py-polars/polars/functions/escape_regex.py new file mode 100644 index 000000000000..1c038347e8af --- /dev/null +++ b/py-polars/polars/functions/escape_regex.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import contextlib + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr +import polars._reexport as pl + + +def escape_regex(s: str) -> str: + r""" + Escapes string regex meta characters. + + Parameters + ---------- + s + The string that all of its meta characters will be escaped. + + """ + if isinstance(s, pl.Expr): + msg = "escape_regex function is unsupported for `Expr`, you may want use `Expr.str.escape_regex` instead" + raise TypeError(msg) + elif not isinstance(s, str): + msg = f"escape_regex function supports only `str` type, got `{type(s)}`" + raise TypeError(msg) + + return plr.escape_regex(s) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 61cead9871d8..7a473048bc23 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1014,6 +1014,7 @@ def map_groups( ... function=lambda list_of_series: list_of_series[0] ... / list_of_series[0].sum() ... + list_of_series[1], + ... return_dtype=pl.Float64, ... ).alias("my_custom_aggregation") ... ) ... ).sort("group") diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index 68d4b604d6a6..527f1da01240 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -7,7 +7,11 @@ from pathlib import Path from typing import IO, TYPE_CHECKING, Any, overload -from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath +from polars._utils.various import ( + is_int_sequence, + is_str_sequence, + normalize_filepath, +) from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError diff --git a/py-polars/polars/io/cloud/__init__.py b/py-polars/polars/io/cloud/__init__.py new file mode 100644 index 000000000000..f5ef9c5fd0bf --- /dev/null +++ b/py-polars/polars/io/cloud/__init__.py @@ -0,0 +1,15 @@ +from polars.io.cloud.credential_provider import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) + +__all__ = [ + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", +] diff --git a/py-polars/polars/io/cloud/_utils.py b/py-polars/polars/io/cloud/_utils.py new file mode 100644 index 000000000000..7279838aa005 --- /dev/null +++ b/py-polars/polars/io/cloud/_utils.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from pathlib import Path +from typing import IO + +from polars._utils.various import is_path_or_str_sequence + + +def _first_scan_path( + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], +) -> str | Path | None: + if isinstance(source, (str, Path)): + return source + elif is_path_or_str_sequence(source) and source: + return source[0] + + return None + + +def _get_path_scheme(path: str | Path) -> str | None: + splitted = str(path).split("://", maxsplit=1) + + return None if not splitted else splitted[0] + + +def _is_aws_cloud(scheme: str) -> bool: + return any(scheme == x for x in ["s3", "s3a"]) + + +def _is_gcp_cloud(scheme: str) -> bool: + return any(scheme == x for x in ["gs", "gcp", "gcs"]) diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider.py new file mode 100644 index 000000000000..26e8ebf6826e --- /dev/null +++ b/py-polars/polars/io/cloud/credential_provider.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import abc +import importlib.util +import os +import sys +import zoneinfo +from typing import IO, TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from pathlib import Path + +from polars._utils.unstable import issue_unstable_warning + +# These typedefs are here to avoid circular import issues, as +# `CredentialProviderFunction` specifies "CredentialProvider" +CredentialProviderFunctionReturn: TypeAlias = tuple[ + dict[str, Optional[str]], Optional[int] +] + +CredentialProviderFunction: TypeAlias = Union[ + Callable[[], CredentialProviderFunctionReturn], "CredentialProvider" +] + + +class AWSAssumeRoleKWArgs(TypedDict): + """Parameters for [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role).""" + + RoleArn: str + RoleSessionName: str + PolicyArns: list[dict[str, str]] + Policy: str + DurationSeconds: int + Tags: list[dict[str, str]] + TransitiveTagKeys: list[str] + ExternalId: str + SerialNumber: str + TokenCode: str + SourceIdentity: str + ProvidedContexts: list[dict[str, str]] + + +class CredentialProvider(abc.ABC): + """ + Base class for credential providers. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + @abc.abstractmethod + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetches the credentials.""" + + +class CredentialProviderAWS(CredentialProvider): + """ + AWS Credential Provider. + + Using this requires the `boto3` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__( + self, + *, + profile_name: str | None = None, + assume_role: AWSAssumeRoleKWArgs | None = None, + ) -> None: + """ + Initialize a credential provider for AWS. + + Parameters + ---------- + profile_name : str + Profile name to use from credentials file. + assume_role : AWSAssumeRoleKWArgs | None + Configure a role to assume. These are passed as kwarg parameters to + [STS.client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role) + """ + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + self.profile_name = profile_name + self.assume_role = assume_role + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import boto3 + + session = boto3.Session(profile_name=self.profile_name) + + if self.assume_role is not None: + return self._finish_assume_role(session) + + creds = session.get_credentials() + + if creds is None: + msg = "unexpected None value returned from boto3.Session.get_credentials()" + raise ValueError(msg) + + return { + "aws_access_key_id": creds.access_key, + "aws_secret_access_key": creds.secret_key, + "aws_session_token": creds.token, + }, None + + def _finish_assume_role(self, session: Any) -> CredentialProviderFunctionReturn: + client = session.client("sts") + + sts_response = client.assume_role(**self.assume_role) + creds = sts_response["Credentials"] + + expiry = creds["Expiration"] + + if expiry.tzinfo is None: + msg = "expiration time in STS response did not contain timezone information" + raise ValueError(msg) + + return { + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds["SecretAccessKey"], + "aws_session_token": creds["SessionToken"], + }, int(expiry.timestamp()) + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("boto3") is None: + msg = "boto3 must be installed to use `CredentialProviderAWS`" + raise ImportError(msg) + + +class CredentialProviderGCP(CredentialProvider): + """ + GCP Credential Provider. + + Using this requires the `google-auth` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__(self) -> None: + """Initialize a credential provider for Google Cloud (GCP).""" + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + + import google.auth + import google.auth.credentials + + # CI runs with both `mypy` and `mypy --allow-untyped-calls` depending on + # Python version. If we add a `type: ignore[no-untyped-call]`, then the + # check that runs with `--allow-untyped-calls` will complain about an + # unused "type: ignore" comment. And if we don't add the ignore, then + # he check that runs `mypy` will complain. + # + # So we just bypass it with a __dict__[] (because ruff complains about + # getattr) :| + creds, _ = google.auth.__dict__["default"]() + self.creds = creds + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import google.auth.transport.requests + + self.creds.refresh(google.auth.transport.requests.__dict__["Request"]()) + + return {"bearer_token": self.creds.token}, ( + int( + ( + expiry.replace(tzinfo=zoneinfo.ZoneInfo("UTC")) + if expiry.tzinfo is None + else expiry + ).timestamp() + ) + if (expiry := self.creds.expiry) is not None + else None + ) + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("google.auth") is None: + msg = "google-auth must be installed to use `CredentialProviderGCP`" + raise ImportError(msg) + + +def _maybe_init_credential_provider( + credential_provider: CredentialProviderFunction | Literal["auto"] | None, + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], + storage_options: dict[str, Any] | None, + caller_name: str, +) -> CredentialProviderFunction | CredentialProvider | None: + from polars.io.cloud._utils import ( + _first_scan_path, + _get_path_scheme, + _is_aws_cloud, + _is_gcp_cloud, + ) + + if credential_provider is not None: + msg = f"The `credential_provider` parameter of `{caller_name}` is considered unstable." + issue_unstable_warning(msg) + + if credential_provider != "auto": + return credential_provider + + if storage_options is not None: + return None + + verbose = os.getenv("POLARS_VERBOSE") == "1" + + if (path := _first_scan_path(source)) is None: + return None + + if (scheme := _get_path_scheme(path)) is None: + return None + + provider = None + + try: + provider = ( + CredentialProviderAWS() + if _is_aws_cloud(scheme) + else CredentialProviderGCP() + if _is_gcp_cloud(scheme) + else None + ) + except ImportError as e: + if verbose: + msg = f"Unable to auto-select credential provider: {e}" + print(msg, file=sys.stderr) + + if provider is not None and verbose: + msg = f"Auto-selected credential provider: {type(provider).__name__}" + print(msg, file=sys.stderr) + + return provider diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 7ded05836f90..daebcf452c1e 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable +from typing import IO, TYPE_CHECKING, Any, Callable, Literal import polars._reexport as pl import polars.functions as F @@ -24,6 +24,7 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider from polars.io.csv._utils import _check_arg_is_1byte, _update_columns from polars.io.csv.batched_reader import BatchedCsvReader @@ -35,6 +36,7 @@ from polars import DataFrame, LazyFrame from polars._typing import CsvEncoding, PolarsDataType, SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31") @@ -1034,6 +1036,7 @@ def scan_csv( decimal_comma: bool = False, glob: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -1154,6 +1157,14 @@ def scan_csv( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -1259,6 +1270,10 @@ def with_column_names(cols: list[str]) -> list[str]: if not infer_schema: infer_schema_length = 0 + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_csv" + ) + return _scan_csv_impl( source, has_header=has_header, @@ -1289,6 +1304,7 @@ def with_column_names(cols: list[str]) -> list[str]: glob=glob, retries=retries, storage_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, include_file_paths=include_file_paths, ) @@ -1332,6 +1348,7 @@ def _scan_csv_impl( decimal_comma: bool = False, glob: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -1384,6 +1401,7 @@ def _scan_csv_impl( glob=glob, schema=schema, cloud_options=storage_options, + credential_provider=credential_provider, retries=retries, file_cache_ttl=file_cache_ttl, include_file_paths=include_file_paths, diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index 278e3e8e0738..cb6b6b92ff1b 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -338,7 +338,7 @@ def _inject_type_overrides( @staticmethod def _is_alchemy_async(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is SQLALchemy async.""" try: from sqlalchemy.ext.asyncio import ( AsyncConnection, @@ -352,7 +352,7 @@ def _is_alchemy_async(conn: Any) -> bool: @staticmethod def _is_alchemy_engine(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is a SQLAlchemy Engine.""" from sqlalchemy.engine import Engine if isinstance(conn, Engine): @@ -364,9 +364,14 @@ def _is_alchemy_engine(conn: Any) -> bool: except ImportError: return False + @staticmethod + def _is_alchemy_object(conn: Any) -> bool: + """Check if the given connection is a SQLAlchemy object (of any kind).""" + return type(conn).__module__.split(".", 1)[0] == "sqlalchemy" + @staticmethod def _is_alchemy_session(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is a SQLAlchemy Session object.""" from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, sessionmaker @@ -392,7 +397,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor: return conn.engine.raw_connection().cursor() elif conn.engine.driver == "duckdb_engine": self.driver_name = "duckdb" - return conn.engine.raw_connection().driver_connection + return conn elif self._is_alchemy_engine(conn): # note: if we create it, we can close it self.can_close_cursor = True @@ -482,7 +487,7 @@ def execute( options = options or {} - if self.driver_name == "sqlalchemy": + if self._is_alchemy_object(self.cursor): cursor_execute, options, query = self._sqlalchemy_setup(query, options) else: cursor_execute = self.cursor.execute @@ -505,8 +510,11 @@ def execute( ) result = cursor_execute(query, *positional_options) - # note: some cursors execute in-place + # note: some cursors execute in-place, some access results via a property result = self.cursor if result is None else result + if self.driver_name == "duckdb": + result = result.cursor + self.result = result return self diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index ac5dcaaac3e7..21e436dc0557 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -25,10 +25,12 @@ except ImportError: Selectable: TypeAlias = Any # type: ignore[no-redef] + from sqlalchemy.sql.elements import TextClause + @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[False] = ..., @@ -41,7 +43,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[True], @@ -54,7 +56,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool, @@ -66,7 +68,7 @@ def read_database( def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool = False, @@ -263,6 +265,51 @@ def read_database( ) +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["adbc"], + schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: list[str] | str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["connectorx"] | None = None, + schema_overrides: SchemaDict | None = None, + execute_options: None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: DbReadEngine | None = None, + schema_overrides: None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + def read_database_uri( query: list[str] | str, uri: str, diff --git a/py-polars/polars/io/delta.py b/py-polars/polars/io/delta.py index 9b0219fa8dff..b2b27c74a20f 100644 --- a/py-polars/polars/io/delta.py +++ b/py-polars/polars/io/delta.py @@ -1,10 +1,12 @@ from __future__ import annotations +import warnings from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse +from polars import DataFrame from polars.convert import from_arrow from polars.datatypes import Null, Time from polars.datatypes.convert import unpack_dtypes @@ -12,11 +14,13 @@ from polars.io.pyarrow_dataset import scan_pyarrow_dataset if TYPE_CHECKING: - from polars import DataFrame, DataType, LazyFrame + from deltalake import DeltaTable + + from polars import DataType, LazyFrame def read_delta( - source: str, + source: str | DeltaTable, *, version: int | str | datetime | None = None, columns: list[str] | None = None, @@ -31,7 +35,7 @@ def read_delta( Parameters ---------- source - Path or URI to the root of the Delta lake table. + DeltaTable or a Path or URI to the root of the Delta lake table. Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. @@ -138,22 +142,23 @@ def read_delta( if pyarrow_options is None: pyarrow_options = {} - resolved_uri = _resolve_delta_lake_uri(source) - dl_tbl = _get_delta_lake_table( - table_path=resolved_uri, + table_path=source, version=version, storage_options=storage_options, delta_table_options=delta_table_options, ) - return from_arrow( - dl_tbl.to_pyarrow_table(columns=columns, **pyarrow_options), rechunk=rechunk - ) # type: ignore[return-value] + return cast( + DataFrame, + from_arrow( + dl_tbl.to_pyarrow_table(columns=columns, **pyarrow_options), rechunk=rechunk + ), + ) def scan_delta( - source: str, + source: str | DeltaTable, *, version: int | str | datetime | None = None, storage_options: dict[str, Any] | None = None, @@ -166,7 +171,7 @@ def scan_delta( Parameters ---------- source - Path or URI to the root of the Delta lake table. + DeltaTable or a Path or URI to the root of the Delta lake table. Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. @@ -274,9 +279,8 @@ def scan_delta( if pyarrow_options is None: pyarrow_options = {} - resolved_uri = _resolve_delta_lake_uri(source) dl_tbl = _get_delta_lake_table( - table_path=resolved_uri, + table_path=source, version=version, storage_options=storage_options, delta_table_options=delta_table_options, @@ -299,7 +303,7 @@ def _resolve_delta_lake_uri(table_uri: str, *, strict: bool = True) -> str: def _get_delta_lake_table( - table_path: str, + table_path: str | DeltaTable, version: int | str | datetime | None = None, storage_options: dict[str, Any] | None = None, delta_table_options: dict[str, Any] | None = None, @@ -314,12 +318,27 @@ def _get_delta_lake_table( """ _check_if_delta_available() + if isinstance(table_path, deltalake.DeltaTable): + if any( + [ + version is not None, + storage_options is not None, + delta_table_options is not None, + ] + ): + warnings.warn( + """When supplying a DeltaTable directly, `version`, `storage_options`, and `delta_table_options` are ignored. + To silence this warning, don't supply those parameters.""", + RuntimeWarning, + stacklevel=1, + ) + return table_path if delta_table_options is None: delta_table_options = {} - + resolved_uri = _resolve_delta_lake_uri(table_path) if not isinstance(version, (str, datetime)): dl_tbl = deltalake.DeltaTable( - table_path, + resolved_uri, version=version, storage_options=storage_options, **delta_table_options, diff --git a/py-polars/polars/io/iceberg.py b/py-polars/polars/io/iceberg.py index d816703f56ac..5d47eb2f389e 100644 --- a/py-polars/polars/io/iceberg.py +++ b/py-polars/polars/io/iceberg.py @@ -42,6 +42,7 @@ def scan_iceberg( source: str | Table, *, + snapshot_id: int | None = None, storage_options: dict[str, Any] | None = None, ) -> LazyFrame: """ @@ -54,6 +55,8 @@ def scan_iceberg( Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. + snapshot_id + The snapshot ID to scan from. storage_options Extra options for the storage backends supported by `pyiceberg`. For cloud storages, this may include configurations for authentication etc. @@ -126,6 +129,12 @@ def scan_iceberg( >>> pl.scan_iceberg( ... table_path, storage_options=storage_options ... ).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table using a specific snapshot ID. + + >>> table_path = "/path/to/iceberg-table/metadata.json" + >>> snapshot_id = 7051579356916758811 + >>> pl.scan_iceberg(table_path, snapshot_id=snapshot_id).collect() # doctest: +SKIP """ from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.table import StaticTable @@ -135,7 +144,13 @@ def scan_iceberg( metadata_location=source, properties=storage_options or {} ) - func = partial(_scan_pyarrow_dataset_impl, source) + if snapshot_id is not None: + snapshot = source.snapshot_by_id(snapshot_id) + if snapshot is None: + msg = f"Snapshot ID not found: {snapshot_id}" + raise ValueError(msg) + + func = partial(_scan_pyarrow_dataset_impl, source, snapshot_id=snapshot_id) arrow_schema = schema_to_pyarrow(source.schema()) return pl.LazyFrame._scan_python_function(arrow_schema, func, pyarrow=True) @@ -145,6 +160,7 @@ def _scan_pyarrow_dataset_impl( with_columns: list[str] | None = None, predicate: str = "", n_rows: int | None = None, + snapshot_id: int | None = None, **kwargs: Any, ) -> DataFrame | Series: """ @@ -160,6 +176,8 @@ def _scan_pyarrow_dataset_impl( pyarrow expression that can be evaluated with eval n_rows: Materialize only n rows from the arrow dataset. + snapshot_id: + The snapshot ID to scan from. batch_size The maximum row count for scanned pyarrow record batches. kwargs: @@ -171,7 +189,7 @@ def _scan_pyarrow_dataset_impl( """ from polars import from_arrow - scan = tbl.scan(limit=n_rows) + scan = tbl.scan(limit=n_rows, snapshot_id=snapshot_id) if with_columns is not None: scan = scan.select(*with_columns) diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 6d64a560d094..1348134fc0d6 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -3,7 +3,7 @@ import contextlib import os from pathlib import Path -from typing import IO, TYPE_CHECKING, Any +from typing import IO, TYPE_CHECKING, Any, Literal import polars._reexport as pl import polars.functions as F @@ -22,6 +22,7 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -32,6 +33,7 @@ from polars import DataFrame, DataType, LazyFrame from polars._typing import SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @@ -362,6 +364,7 @@ def scan_ipc( row_index_name: str | None = None, row_index_offset: int = 0, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, memory_map: bool = True, retries: int = 2, file_cache_ttl: int | None = None, @@ -407,6 +410,15 @@ def scan_ipc( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + memory_map Try to memory map the file. This can greatly improve performance on repeated queries as the OS may cache pages. @@ -451,6 +463,16 @@ def scan_ipc( # Memory Mapping is now a no-op _ = memory_map + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_parquet" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + pylf = PyLazyFrame.new_from_ipc( source, sources, @@ -459,6 +481,7 @@ def scan_ipc( rechunk, parse_row_index_args(row_index_name, row_index_offset), cloud_options=storage_options, + credential_provider=credential_provider, retries=retries, file_cache_ttl=file_cache_ttl, hive_partitioning=hive_partitioning, diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 7a5fb2c0d1e6..983b8cddcfe1 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -4,13 +4,14 @@ from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any +from typing import IO, TYPE_CHECKING, Any, Literal from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.various import is_path_or_str_sequence, normalize_filepath from polars._utils.wrap import wrap_df, wrap_ldf from polars.datatypes import N_INFER_DEFAULT from polars.io._utils import parse_row_index_args +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -20,6 +21,7 @@ from polars import DataFrame, LazyFrame from polars._typing import SchemaDefinition + from polars.io.cloud import CredentialProviderFunction def read_ndjson( @@ -36,6 +38,7 @@ def read_ndjson( row_index_offset: int = 0, ignore_errors: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -96,6 +99,14 @@ def read_ndjson( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -147,6 +158,10 @@ def read_ndjson( return df + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "read_ndjson" + ) + return scan_ndjson( source, schema=schema, @@ -162,6 +177,7 @@ def read_ndjson( include_file_paths=include_file_paths, retries=retries, storage_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, ).collect() @@ -190,6 +206,7 @@ def scan_ndjson( row_index_offset: int = 0, ignore_errors: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -249,6 +266,14 @@ def scan_ndjson( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -276,6 +301,10 @@ def scan_ndjson( msg = "'infer_schema_length' should be positive" raise ValueError(msg) + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_ndjson" + ) + if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] else: @@ -297,6 +326,7 @@ def scan_ndjson( include_file_paths=include_file_paths, retries=retries, cloud_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, ) return wrap_ldf(pylf) diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index c27ee99ca94b..16ff7f614349 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -21,27 +21,24 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): from polars.polars import PyLazyFrame from polars.polars import read_parquet_schema as _read_parquet_schema if TYPE_CHECKING: + from typing import Literal + from polars import DataFrame, DataType, LazyFrame - from polars._typing import ParallelStrategy, SchemaDict + from polars._typing import ParallelStrategy, ScanSource, SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_parquet( - source: str - | Path - | IO[bytes] - | bytes - | list[str] - | list[Path] - | list[IO[bytes]] - | list[bytes], + source: ScanSource, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -57,6 +54,7 @@ def read_parquet( rechunk: bool = False, low_memory: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, use_pyarrow: bool = False, pyarrow_options: dict[str, Any] | None = None, @@ -139,6 +137,14 @@ def read_parquet( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. use_pyarrow @@ -203,16 +209,9 @@ def read_parquet( rechunk=rechunk, ) - # Read file and bytes inputs using `read_parquet` - if isinstance(source, bytes): - source = io.BytesIO(source) - elif isinstance(source, list) and len(source) > 0 and isinstance(source[0], bytes): - assert all(isinstance(s, bytes) for s in source) - source = [io.BytesIO(s) for s in source] # type: ignore[arg-type, assignment] - # For other inputs, defer to `scan_parquet` lf = scan_parquet( - source, # type: ignore[arg-type] + source, n_rows=n_rows, row_index_name=row_index_name, row_index_offset=row_index_offset, @@ -226,6 +225,7 @@ def read_parquet( low_memory=low_memory, cache=False, storage_options=storage_options, + credential_provider=credential_provider, retries=retries, glob=glob, include_file_paths=include_file_paths, @@ -322,7 +322,7 @@ def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, Dat @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_parquet( - source: str | Path | IO[bytes] | list[str] | list[Path] | list[IO[bytes]], + source: ScanSource, *, n_rows: int | None = None, row_index_name: str | None = None, @@ -338,6 +338,7 @@ def scan_parquet( low_memory: bool = False, cache: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, include_file_paths: str | None = None, allow_missing_columns: bool = False, @@ -426,6 +427,14 @@ def scan_parquet( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. include_file_paths @@ -474,6 +483,10 @@ def scan_parquet( normalize_filepath(source, check_not_directory=False) for source in source ] + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_parquet" + ) + return _scan_parquet_impl( source, # type: ignore[arg-type] n_rows=n_rows, @@ -483,6 +496,7 @@ def scan_parquet( row_index_name=row_index_name, row_index_offset=row_index_offset, storage_options=storage_options, + credential_provider=credential_provider, low_memory=low_memory, use_statistics=use_statistics, hive_partitioning=hive_partitioning, @@ -506,6 +520,7 @@ def _scan_parquet_impl( row_index_name: str | None = None, row_index_offset: int = 0, storage_options: dict[str, object] | None = None, + credential_provider: CredentialProviderFunction | None = None, low_memory: bool = False, use_statistics: bool = True, hive_partitioning: bool | None = None, @@ -539,6 +554,7 @@ def _scan_parquet_impl( parse_row_index_args(row_index_name, row_index_offset), low_memory, cloud_options=storage_options, + credential_provider=credential_provider, use_statistics=use_statistics, hive_partitioning=hive_partitioning, schema=schema, diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 2b04fc1ac4dc..1d4bb5fe90b6 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -55,6 +55,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -72,6 +73,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -89,6 +91,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> NoReturn: ... @@ -108,6 +111,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -125,6 +129,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -142,6 +147,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -160,6 +166,7 @@ def read_excel( columns: Sequence[int] | Sequence[str] | None = None, schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, + drop_empty_rows: bool = True, raise_if_empty: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: """ @@ -229,6 +236,8 @@ def read_excel( entire dataset is scanned to determine the dtypes, which can slow parsing for large workbooks. Note that only the "calamine" and "xlsx2csv" engines support this parameter. + drop_empty_rows + Indicate whether to omit empty rows when reading data into the DataFrame. raise_if_empty When there is no data in the sheet,`NoDataError` is raised. If this parameter is set to False, an empty DataFrame (with no columns) is returned instead. @@ -299,6 +308,7 @@ def read_excel( raise_if_empty=raise_if_empty, has_header=has_header, columns=columns, + drop_empty_rows=drop_empty_rows, ) @@ -312,6 +322,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -326,6 +337,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -340,6 +352,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> NoReturn: ... @@ -354,6 +367,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -368,6 +382,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -382,6 +397,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., + drop_empty_rows: bool = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -395,6 +411,7 @@ def read_ods( columns: Sequence[int] | Sequence[str] | None = None, schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, + drop_empty_rows: bool = True, raise_if_empty: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: """ @@ -429,6 +446,8 @@ def read_ods( The maximum number of rows to scan for schema inference. If set to `None`, the entire dataset is scanned to determine the dtypes, which can slow parsing for large workbooks. + drop_empty_rows + Indicate whether to omit empty rows when reading data into the DataFrame. raise_if_empty When there is no data in the sheet,`NoDataError` is raised. If this parameter is set to False, an empty DataFrame (with no columns) is returned instead. @@ -470,6 +489,7 @@ def read_ods( schema_overrides=schema_overrides, infer_schema_length=infer_schema_length, raise_if_empty=raise_if_empty, + drop_empty_rows=drop_empty_rows, has_header=has_header, columns=columns, ) @@ -530,6 +550,7 @@ def _read_spreadsheet( columns: Sequence[int] | Sequence[str] | None = None, has_header: bool = True, raise_if_empty: bool = True, + drop_empty_rows: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: if isinstance(source, (str, Path)): source = normalize_filepath(source) @@ -544,7 +565,7 @@ def _read_spreadsheet( infer_schema_length=infer_schema_length, ) engine_options = (engine_options or {}).copy() - schema_overrides = dict(schema_overrides or {}) + schema_overrides = pl.Schema(schema_overrides or {}) # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( @@ -561,6 +582,7 @@ def _read_spreadsheet( read_options=read_options, raise_if_empty=raise_if_empty, columns=columns, + drop_empty_rows=drop_empty_rows, ) for name in sheet_names } @@ -749,6 +771,7 @@ def _csv_buffer_to_frame( read_options: dict[str, Any], schema_overrides: SchemaDict | None, raise_if_empty: bool, + drop_empty_rows: bool, ) -> pl.DataFrame: """Translate StringIO buffer containing delimited data as a DataFrame.""" # handle (completely) empty sheet data @@ -782,11 +805,19 @@ def _csv_buffer_to_frame( separator=separator, **read_options, ) - return _drop_null_data(df, raise_if_empty=raise_if_empty) + return _drop_null_data( + df, raise_if_empty=raise_if_empty, drop_empty_rows=drop_empty_rows + ) -def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: - """If DataFrame contains columns/rows that contain only nulls, drop them.""" +def _drop_null_data( + df: pl.DataFrame, *, raise_if_empty: bool, drop_empty_rows: bool = True +) -> pl.DataFrame: + """ + If DataFrame contains columns/rows that contain only nulls, drop them. + + If `drop_empty_rows` is set to `False`, empty rows are not dropped. + """ null_cols = [] for col_name in df.columns: # note that if multiple unnamed columns are found then all but the first one @@ -807,8 +838,9 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: if len(df) == 0 and len(df.columns) == 0: return _empty_frame(raise_if_empty) - - return df.filter(~F.all_horizontal(F.all().is_null())) + if drop_empty_rows: + return df.filter(~F.all_horizontal(F.all().is_null())) + return df def _empty_frame(raise_if_empty: bool) -> pl.DataFrame: # noqa: FBT001 @@ -840,6 +872,7 @@ def _read_spreadsheet_openpyxl( schema_overrides: SchemaDict | None, columns: Sequence[int] | Sequence[str] | None, raise_if_empty: bool, + drop_empty_rows: bool, ) -> pl.DataFrame: """Use the 'openpyxl' library to read data from the given worksheet.""" infer_schema_length = read_options.pop("infer_schema_length", None) @@ -896,7 +929,9 @@ def _read_spreadsheet_openpyxl( strict=False, ) - df = _drop_null_data(df, raise_if_empty=raise_if_empty) + df = _drop_null_data( + df, raise_if_empty=raise_if_empty, drop_empty_rows=drop_empty_rows + ) df = _reorder_columns(df, columns) return df @@ -909,6 +944,7 @@ def _read_spreadsheet_calamine( schema_overrides: SchemaDict | None, columns: Sequence[int] | Sequence[str] | None, raise_if_empty: bool, + drop_empty_rows: bool, ) -> pl.DataFrame: # if we have 'schema_overrides' and a more recent version of `fastexcel` # we can pass translated dtypes to the engine to refine the initial parse @@ -966,7 +1002,9 @@ def _read_spreadsheet_calamine( if schema_overrides: df = df.cast(dtypes=schema_overrides) - df = _drop_null_data(df, raise_if_empty=raise_if_empty) + df = _drop_null_data( + df, raise_if_empty=raise_if_empty, drop_empty_rows=drop_empty_rows + ) # standardise on string dtype for null columns in empty frame if df.is_empty(): @@ -1009,6 +1047,7 @@ def _read_spreadsheet_xlsx2csv( schema_overrides: SchemaDict | None, columns: Sequence[int] | Sequence[str] | None, raise_if_empty: bool, + drop_empty_rows: bool, ) -> pl.DataFrame: """Use the 'xlsx2csv' library to read data from the given worksheet.""" csv_buffer = StringIO() @@ -1031,6 +1070,7 @@ def _read_spreadsheet_xlsx2csv( read_options=read_options, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, + drop_empty_rows=drop_empty_rows, ) if cast_to_boolean: df = df.with_columns(*cast_to_boolean) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index d5c56c93ca0a..64608ae825fb 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1369,7 +1369,12 @@ def sort( └──────┴─────┴─────┘ """ # Fast path for sorting by a single existing column - if isinstance(by, str) and not more_by: + if ( + isinstance(by, str) + and not more_by + and isinstance(descending, bool) + and isinstance(nulls_last, bool) + ): return self._from_pyldf( self._ldf.sort( by, descending, nulls_last, maintain_order, multithreaded @@ -2258,7 +2263,7 @@ def collect_schema(self) -> Schema: >>> schema.len() 3 """ - return Schema(self._ldf.collect_schema()) + return Schema(self._ldf.collect_schema(), check_dtypes=False) @unstable() def sink_parquet( diff --git a/py-polars/polars/meta/versions.py b/py-polars/polars/meta/versions.py index 6788d25a68ea..425f01d91a85 100644 --- a/py-polars/polars/meta/versions.py +++ b/py-polars/polars/meta/versions.py @@ -2,6 +2,7 @@ import sys +from polars._cpu_check import get_lts_cpu from polars._utils.polars_version import get_polars_version from polars.meta.index_type import get_index_type @@ -18,6 +19,7 @@ def show_versions() -> None: Index type: UInt32 Platform: macOS-14.4.1-arm64-arm-64bit Python: 3.11.8 (main, Feb 6 2024, 21:21:21) [Clang 15.0.0 (clang-1500.1.0.2.5)] + LTS CPU: False ----Optional dependencies---- adbc_driver_manager: 0.11.0 altair: 5.4.0 @@ -45,7 +47,7 @@ def show_versions() -> None: import platform deps = _get_dependency_list() - core_properties = ("Polars", "Index type", "Platform", "Python") + core_properties = ("Polars", "Index type", "Platform", "Python", "LTS CPU") keylen = max(len(x) for x in [*core_properties, *deps]) + 1 print("--------Version info---------") @@ -53,6 +55,7 @@ def show_versions() -> None: print(f"{'Index type:':{keylen}s} {get_index_type()}") print(f"{'Platform:':{keylen}s} {platform.platform()}") print(f"{'Python:':{keylen}s} {sys.version}") + print(f"{'LTS CPU:':{keylen}s} {get_lts_cpu()}") print("\n----Optional dependencies----") for name in deps: diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 72eb8b86d25e..81ade5a6b206 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -1,23 +1,54 @@ from __future__ import annotations +import sys from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union -from polars.datatypes import DataType +from polars._typing import PythonDataType +from polars.datatypes import DataType, DataTypeClass, is_polars_dtype from polars.datatypes._parse import parse_into_dtype -BaseSchema = OrderedDict[str, DataType] - if TYPE_CHECKING: from collections.abc import Iterable - from polars._typing import PythonDataType + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + +if sys.version_info >= (3, 10): + + def _required_init_args(tp: DataTypeClass) -> bool: + # note: this check is ~20% faster than the check for a + # custom "__init__", below, but is not available on py39 + return bool(tp.__annotations__) +else: + + def _required_init_args(tp: DataTypeClass) -> bool: + # indicates override of the default __init__ + # (eg: this type requires specific args) + return "__init__" in tp.__dict__ + + +BaseSchema = OrderedDict[str, DataType] +SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType] __all__ = ["Schema"] +def _check_dtype(tp: DataType | DataTypeClass) -> DataType: + if not isinstance(tp, DataType): + # note: if nested/decimal, or has signature params, this implies required args + if tp.is_nested() or tp.is_decimal() or _required_init_args(tp): + msg = f"dtypes must be fully-specified, got: {tp!r}" + raise TypeError(msg) + tp = tp() + return tp # type: ignore[return-value] + + class Schema(BaseSchema): """ Ordered mapping of column names to their data type. @@ -54,18 +85,42 @@ class Schema(BaseSchema): def __init__( self, schema: ( - Mapping[str, DataType | PythonDataType] - | Iterable[tuple[str, DataType | PythonDataType]] + Mapping[str, SchemaInitDataType] + | Iterable[tuple[str, SchemaInitDataType]] | None ) = None, + *, + check_dtypes: bool = True, ) -> None: input = ( schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) ) - super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] - - def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: - super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] + for name, tp in input: # type: ignore[misc] + if not check_dtypes: + super().__setitem__(name, tp) # type: ignore[assignment] + elif is_polars_dtype(tp): + super().__setitem__(name, _check_dtype(tp)) + else: + self[name] = tp + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + if len(self) != len(other): + return False + for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()): + if nm1 != nm2 or not tp1.is_(tp2): + return False + return True + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __setitem__( + self, name: str, dtype: DataType | DataTypeClass | PythonDataType + ) -> None: + dtype = _check_dtype(parse_into_dtype(dtype)) + super().__setitem__(name, dtype) def names(self) -> list[str]: """Get the column names of the schema.""" @@ -81,7 +136,7 @@ def len(self) -> int: def to_python(self) -> dict[str, type]: """ - Return Schema as a dictionary of column names and their Python types. + Return a dictionary of column names and Python types. Examples -------- diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 2631f222612a..4cb11506b3f6 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -508,7 +508,7 @@ def as_expr(self) -> Expr: def _re_string(string: str | Collection[str], *, escape: bool = True) -> str: """Return escaped regex, potentially representing multiple string fragments.""" if isinstance(string, str): - rx = f"{re_escape(string)}" if escape else string + rx = re_escape(string) if escape else string else: strings: list[str] = [] for st in string: diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index dcf65ff15312..3b0e905b84fc 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -97,7 +97,7 @@ def add_business_days( You can pass a custom weekend - for example, if you only take Sunday off: >>> week_mask = (True, True, True, True, True, True, False) - >>> s.dt.add_business_days(5, week_mask) + >>> s.dt.add_business_days(5, week_mask=week_mask) shape: (2,) Series: 'start' [date] [ diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index cf70f5225f56..0c4b08982606 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -14,8 +14,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) from polars.polars import PySeries @@ -855,7 +855,7 @@ def to_array(self, width: int) -> Series: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Callable[[int], str] | Sequence[str] | None = None, ) -> Series: """ diff --git a/py-polars/polars/series/plotting.py b/py-polars/polars/series/plotting.py index 5430d55c6ff3..cc08599b1720 100644 --- a/py-polars/polars/series/plotting.py +++ b/py-polars/polars/series/plotting.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Callable -from polars.dataframe.plotting import _add_tooltip from polars.dependencies import altair as alt if TYPE_CHECKING: @@ -42,7 +41,7 @@ def hist( `Altair `_. `s.plot.hist(**kwargs)` is shorthand for - `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`, + `alt.Chart(s.to_frame()).mark_bar(tooltip=True).encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`, and is provided for convenience - for full customisatibility, use a plotting library directly. @@ -69,9 +68,11 @@ def hist( "x": alt.X(f"{self._series_name}:Q", bin=True), "y": "count()", } - _add_tooltip(encodings, **kwargs) return ( - alt.Chart(self._df).mark_bar().encode(**encodings, **kwargs).interactive() + alt.Chart(self._df) + .mark_bar(tooltip=True) + .encode(**encodings, **kwargs) + .interactive() ) def kde( @@ -86,7 +87,7 @@ def kde( `Altair `_. `s.plot.kde(**kwargs)` is shorthand for - `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`, + `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area(tooltip=True).encode(x=s.name, y='density:Q', **kwargs).interactive()`, and is provided for convenience - for full customisatibility, use a plotting library directly. @@ -110,11 +111,10 @@ def kde( msg = "Cannot use `plot.kde` when Series name is `'density'`" raise ValueError(msg) encodings: Encodings = {"x": self._series_name, "y": "density:Q"} - _add_tooltip(encodings, **kwargs) return ( alt.Chart(self._df) .transform_density(self._series_name, as_=[self._series_name, "density"]) - .mark_area() + .mark_area(tooltip=True) .encode(**encodings, **kwargs) .interactive() ) @@ -131,7 +131,7 @@ def line( `Altair `_. `s.plot.line(**kwargs)` is shorthand for - `alt.Chart(s.to_frame().with_row_index()).mark_line().encode(x='index', y=s.name, **kwargs).interactive()`, + `alt.Chart(s.to_frame().with_row_index()).mark_line(tooltip=True).encode(x='index', y=s.name, **kwargs).interactive()`, and is provided for convenience - for full customisatibility, use a plotting library directly. @@ -155,10 +155,9 @@ def line( msg = "Cannot call `plot.line` when Series name is 'index'" raise ValueError(msg) encodings: Encodings = {"x": "index", "y": self._series_name} - _add_tooltip(encodings, **kwargs) return ( alt.Chart(self._df.with_row_index()) - .mark_line() + .mark_line(tooltip=True) .encode(**encodings, **kwargs) .interactive() ) @@ -177,7 +176,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: encodings: Encodings = {"x": "index", "y": self._series_name} def func(**kwargs: EncodeKwds) -> alt.Chart: - _add_tooltip(encodings, **kwargs) - return method().encode(**encodings, **kwargs).interactive() + return method(tooltip=True).encode(**encodings, **kwargs).interactive() return func diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index f57f976df70a..8e27b3470b16 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1084,13 +1084,18 @@ def __truediv__(self, other: Any) -> Series | Expr: msg = "first cast to integer before dividing datelike dtypes" raise TypeError(msg) - # this branch is exactly the floordiv function without rounding the floats - if self.dtype.is_float() or self.dtype == Decimal: - as_float = self - else: - as_float = self._recursive_cast_to_dtype(Float64()) + self = ( + self._recursive_cast_to_dtype(Float64()) + if not ( + self.dtype.is_float() + or self.dtype.is_decimal() + or isinstance(self.dtype, List) + or (isinstance(other, Series) and isinstance(other.dtype, List)) + ) + else self + ) - return as_float._arithmetic(other, "div", "div_<>") + return self._arithmetic(other, "div", "div_<>") @overload def __floordiv__(self, other: Expr) -> Expr: ... @@ -4042,8 +4047,15 @@ def to_physical(self) -> Series: - :func:`polars.datatypes.Duration` -> :func:`polars.datatypes.Int64` - :func:`polars.datatypes.Categorical` -> :func:`polars.datatypes.UInt32` - `List(inner)` -> `List(physical of inner)` + - `Array(inner)` -> `Array(physical of inner)` + - `Struct(fields)` -> `Struct(physical of fields)` - Other data types will be left unchanged. + Warning + ------- + The physical representations are an implementation detail + and not guaranteed to be stable. + Examples -------- Replicating the pandas @@ -7477,13 +7489,13 @@ def plot(self) -> SeriesPlot: - `s.plot.hist(**kwargs)` is shorthand for - `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()` + `alt.Chart(s.to_frame()).mark_bar(tooltip=True).encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()` - `s.plot.kde(**kwargs)` is shorthand for - `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()` + `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area(tooltip=True).encode(x=s.name, y='density:Q', **kwargs).interactive()` - for any other attribute `attr`, `s.plot.attr(**kwargs)` is shorthand for - `alt.Chart(s.to_frame().with_row_index()).mark_attr().encode(x='index', y=s.name, **kwargs).interactive()` + `alt.Chart(s.to_frame().with_row_index()).mark_attr(tooltip=True).encode(x='index', y=s.name, **kwargs).interactive()` Examples -------- diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index af9ce66850c4..97f3e373e98f 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -2077,3 +2077,25 @@ def concat( null ] """ + + def escape_regex(self) -> Series: + r""" + Returns string values with all regular expression meta characters escaped. + + Returns + ------- + Series + Series of data type :class:`String`. + + Examples + -------- + >>> pl.Series(["abc", "def", None, "abc(\\w+)"]).str.escape_regex() + shape: (4,) + Series: '' [str] + [ + "abc" + "def" + null + "abc\(\\w\+\)" + ] + """ diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index e8137a23be32..a04d254808d3 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -107,7 +107,7 @@ def schema(self) -> Schema: return Schema({}) schema = self._s.dtype().to_schema() - return Schema(schema) + return Schema(schema, check_dtypes=False) def unnest(self) -> DataFrame: """ diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 9d0a2df292cb..e89a8e19c0b6 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -74,3 +74,4 @@ flask-cors # Stub files pandas-stubs boto3-stubs +google-auth-stubs diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 1c645738102a..859609828d19 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -275,6 +275,10 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::set_random_seed)) .unwrap(); + // Functions - escape_regex + m.add_wrapped(wrap_pyfunction!(functions::escape_regex)) + .unwrap(); + // Exceptions - Errors m.add( "PolarsError", @@ -377,6 +381,9 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { #[cfg(feature = "polars_cloud")] m.add_wrapped(wrap_pyfunction!(cloud::prepare_cloud_plan)) .unwrap(); + #[cfg(feature = "polars_cloud")] + m.add_wrapped(wrap_pyfunction!(cloud::_execute_ir_plan_with_gpu)) + .unwrap(); // Build info m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index d340433ddf10..3ab507f31fa8 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -150,7 +150,7 @@ def test_init_dict() -> None: data={"dt": dates, "dtm": datetimes}, schema=coldefs, ) - assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime} + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} assert df.rows() == list(zip(py_dates, py_datetimes)) # Overriding dict column names/types @@ -251,7 +251,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical, + "tk": pl.Categorical(ordering="physical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } @@ -284,7 +284,6 @@ class PageView(BaseModel): models = adapter.validate_json(data_json) result = pl.DataFrame(models) - expected = pl.DataFrame( { "user_id": ["x"], diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 015b64ce6303..d8910cda4fb2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -303,10 +303,9 @@ def test_dataframe_membership_operator() -> None: def test_sort() -> None: df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) - assert_frame_equal(df.sort("a"), pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]})) - assert_frame_equal( - df.sort(["a", "b"]), pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]}) - ) + expected = pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]}) + assert_frame_equal(df.sort("a"), expected) + assert_frame_equal(df.sort(["a", "b"]), expected) def test_sort_multi_output_exprs_01() -> None: @@ -761,7 +760,7 @@ def test_to_dummies() -> None: "i": [1, 2, 3], "category": ["dog", "cat", "cat"], }, - schema={"i": pl.Int32, "category": pl.Categorical}, + schema={"i": pl.Int32, "category": pl.Categorical("lexical")}, ) expected = pl.DataFrame( { diff --git a/py-polars/tests/unit/dataframe/test_getitem.py b/py-polars/tests/unit/dataframe/test_getitem.py index 5d112ad67528..ab618a944b80 100644 --- a/py-polars/tests/unit/dataframe/test_getitem.py +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -479,3 +479,9 @@ def test_df_getitem_5343() -> None: assert df[4, 5] == 1024 assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]})) assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]})) + + +def test_no_deadlock_19358() -> None: + s = pl.Series(["text"] * 100 + [1] * 100, dtype=pl.Object) + result = s.to_frame()[[0, -1]] + assert result[""].to_list() == ["text", 1] diff --git a/py-polars/tests/unit/dataframe/test_null_count.py b/py-polars/tests/unit/dataframe/test_null_count.py index a9b1141a2a67..507bf0269517 100644 --- a/py-polars/tests/unit/dataframe/test_null_count.py +++ b/py-polars/tests/unit/dataframe/test_null_count.py @@ -23,9 +23,6 @@ def test_null_count(df: pl.DataFrame) -> None: # note: the zero-row and zero-col cases are always passed as explicit examples null_count, ncols = df.null_count(), len(df.columns) - if ncols == 0: - assert null_count.shape == (0, 0) - else: - assert null_count.shape == (1, ncols) - for idx, count in enumerate(null_count.rows()[0]): - assert count == sum(v is None for v in df.to_series(idx).to_list()) + assert null_count.shape == (1, ncols) + for idx, count in enumerate(null_count.rows()[0]): + assert count == sum(v is None for v in df.to_series(idx).to_list()) diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index c5888abee67e..f8750f9e2a95 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -864,3 +864,15 @@ def test_nested_categorical_concat( with pytest.raises(pl.exceptions.StringCacheMismatchError): pl.concat([a, b]) + + +def test_perfect_group_by_19452() -> None: + n = 40 + df2 = pl.DataFrame( + { + "a": pl.int_range(n, eager=True).cast(pl.String).cast(pl.Categorical), + "b": pl.int_range(n, eager=True), + } + ) + + assert df2.with_columns(a=(pl.col("b")).over(pl.col("a")))["a"].is_sorted() diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 9ad384cd1bab..9bd4a49ddd1d 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -536,3 +536,21 @@ def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None: assert s.to_list() == ["a", "b", "c"] expected_s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) assert_series_equal(s, expected_s) + + +def test_enum_19269() -> None: + en = pl.Enum(["X", "Z", "Y"]) + df = pl.DataFrame( + {"test": pl.Series(["X", "Y", "Z"], dtype=en), "group": [1, 2, 2]} + ) + out = ( + df.group_by("group", maintain_order=True) + .agg(pl.col("test").mode()) + .select( + a=pl.col("test").list.max(), + b=pl.col("test").list.min(), + ) + ) + + assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]} + assert out.dtypes == [en, en] diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 8c5502d698fd..f7774fd70191 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -49,7 +49,7 @@ def test_dtype() -> None: "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), - "dtm": pl.List(pl.Datetime), + "dtm": pl.List(pl.Datetime("us")), } assert all(tp.is_nested() for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None: assert df.to_dict(as_series=False) == expected df = pl.DataFrame(schema=[("col", pl.List)]) - assert df.schema == {"col": pl.List} + assert df.schema == {"col": pl.List(pl.Null)} assert df.rows() == [] diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 5caa7c338e10..706ae904775b 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -1,8 +1,9 @@ from __future__ import annotations +import io from dataclasses import dataclass from datetime import datetime, time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import pandas as pd import pyarrow as pa @@ -210,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: # struct column df = build_struct_df([{"struct_col": {"inner": 1}}]) assert df.columns == ["struct_col"] - assert df.schema == {"struct_col": pl.Struct} + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} assert df["struct_col"].struct.field("inner").to_list() == [1] # struct in struct @@ -619,7 +620,7 @@ def test_struct_categorical_5843() -> None: def test_empty_struct() -> None: # List df = pl.DataFrame({"a": [[{}]]}) - assert df.to_dict(as_series=False) == {"a": [[None]]} + assert df.to_dict(as_series=False) == {"a": [[{}]]} # Struct one not empty df = pl.DataFrame({"a": [[{}, {"a": 10}]]}) @@ -627,7 +628,7 @@ def test_empty_struct() -> None: # Empty struct df = pl.DataFrame({"a": [{}]}) - assert df.to_dict(as_series=False) == {"a": [None]} + assert df.to_dict(as_series=False) == {"a": [{}]} @pytest.mark.parametrize( @@ -1048,3 +1049,140 @@ def test_struct_null_zip() -> None: df.select(pl.when(pl.Series([True])).then(pl.col.int).otherwise(pl.col.int)), pl.Series("int", [], dtype=pl.Struct({"x": pl.Int64})).to_frame(), ) + + +@pytest.mark.parametrize("size", [0, 1, 2, 5, 9, 13, 42]) +def test_zfs_construction(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + assert a.len() == size + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_unnest(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).struct.unnest() + assert a.height == size + assert a.width == 0 + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_equality(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + b = pl.Series("a", [{}] * size, pl.Struct([])) + + assert_series_equal(a, b) + + assert_frame_equal( + a.to_frame(), + b.to_frame(), + ) + + +def test_zfs_nullable_when_otherwise() -> None: + a = pl.Series("a", [{}, None, {}, {}, None], pl.Struct([])) + b = pl.Series("b", [None, {}, None, {}, None], pl.Struct([])) + + df = pl.DataFrame([a, b]) + + df = df.select( + x=pl.when(pl.col.a.is_not_null()).then(pl.col.a).otherwise(pl.col.b), + y=pl.when(pl.col.a.is_null()).then(pl.col.a).otherwise(pl.col.b), + ) + + assert_series_equal(df["x"], pl.Series("x", [{}, {}, {}, {}, None], pl.Struct([]))) + assert_series_equal( + df["y"], pl.Series("y", [None, None, None, {}, None], pl.Struct([])) + ) + + +def test_zfs_struct_fns() -> None: + a = pl.Series("a", [{}], pl.Struct([])) + + assert a.struct.fields == [] + + # @TODO: This should really throw an error as per #19132 + assert a.struct.rename_fields(["a"]).struct.unnest().shape == (1, 0) + assert a.struct.rename_fields([]).struct.unnest().shape == (1, 0) + + assert_series_equal(a.struct.json_encode(), pl.Series("a", ["{}"], pl.String)) + + +@pytest.mark.parametrize("format", ["binary", "json"]) +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_serialization_roundtrip(format: pl.SerializationFormat, size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.BytesIO() + a.serialize(f, format=format) + + f.seek(0) + assert_frame_equal( + a, + pl.DataFrame.deserialize(f, format=format), + ) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_row_encoding(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + + df = pl.DataFrame([a, pl.Series("x", list(range(size)), pl.Int8)]) + + gb = df.lazy().group_by(["a", "x"]).agg(pl.all().min()).collect(streaming=True) + + # We need to ignore the order because the group_by is non-deterministic + assert_frame_equal(gb, df, check_row_order=False) + + +@pytest.mark.may_fail_auto_streaming +def test_list_to_struct_19208() -> None: + df = pl.DataFrame( + { + "nested": [ + [{"a": 1}], + [], + [{"a": 3}], + ] + } + ) + assert pl.concat([df[0], df[1], df[2]]).select( + pl.col("nested").list.to_struct() + ).to_dict(as_series=False) == { + "nested": [{"field_0": {"a": 1}}, {"field_0": None}, {"field_0": {"a": 3}}] + } + + +def test_struct_reverse_outer_validity_19445() -> None: + assert_series_equal( + pl.Series([{"a": 1}, None]).reverse(), + pl.Series([None, {"a": 1}]), + ) + + +@pytest.mark.parametrize("maybe_swap", [lambda a, b: (a, b), lambda a, b: (b, a)]) +def test_struct_eq_missing_outer_validity_19156( + maybe_swap: Callable[[pl.Series, pl.Series], tuple[pl.Series, pl.Series]], +) -> None: + # Ensure that lit({'x': NULL}).eq_missing(lit(NULL)) => False + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None, {"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, False])) + assert_series_equal(l.ne_missing(r), pl.Series([True, True])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, True])) + assert_series_equal(l.ne_missing(r), pl.Series([True, False])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([{"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([True, False])) + assert_series_equal(l.ne_missing(r), pl.Series([False, True])) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a925c6f18781..042a0fca786b 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -632,16 +632,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075", ] - ticker = [ - "GOOG", - "MSFT", - "MSFT", - "MSFT", - "GOOG", - "AAPL", - "GOOG", - "MSFT", - ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] quotes = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -656,13 +647,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048", ] - ticker = [ - "MSFT", - "MSFT", - "GOOG", - "GOOG", - "AAPL", - ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] trades = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -678,11 +663,11 @@ def test_asof_join() -> None: out = trades.join_asof(quotes, on="dates", strategy="backward") assert out.schema == { - "bid": pl.Float64, - "bid_right": pl.Float64, "dates": pl.Datetime("ms"), "ticker": pl.String, + "bid": pl.Float64, "ticker_right": pl.String, + "bid_right": pl.Float64, } assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] assert (out["dates"].cast(int)).to_list() == [ diff --git a/py-polars/tests/unit/functions/as_datatype/test_duration.py b/py-polars/tests/unit/functions/as_datatype/test_duration.py index 8bef49e1f074..a3fc5dc658ef 100644 --- a/py-polars/tests/unit/functions/as_datatype/test_duration.py +++ b/py-polars/tests/unit/functions/as_datatype/test_duration.py @@ -181,3 +181,12 @@ def test_duration_time_unit_ms() -> None: result = pl.duration(milliseconds=4) expected = pl.duration(milliseconds=4, time_unit="us") assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in pl.duration + # https://github.com/pola-rs/polars/issues/19007 + df = df = pl.DataFrame({"a": [1], "b": [2]}) + assert df.select(pl.duration(hours=pl.all()).name.keep()).to_dict( + as_series=False + ) == {"a": [timedelta(seconds=3600)], "b": [timedelta(seconds=7200)]} diff --git a/py-polars/tests/unit/functions/range/test_date_range.py b/py-polars/tests/unit/functions/range/test_date_range.py index a881d30c1e41..a88287bedf41 100644 --- a/py-polars/tests/unit/functions/range/test_date_range.py +++ b/py-polars/tests/unit/functions/range/test_date_range.py @@ -7,7 +7,7 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, PanicException +from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -21,7 +21,7 @@ def test_date_range() -> None: def test_date_range_invalid_time_unit() -> None: - with pytest.raises(PanicException, match="'x' not supported"): + with pytest.raises(InvalidOperationError, match="'x' not supported"): pl.date_range( start=date(2021, 12, 16), end=date(2021, 12, 18), @@ -312,6 +312,7 @@ def test_date_ranges_datetime_input() -> None: assert_series_equal(result, expected) +@pytest.mark.may_fail_auto_streaming def test_date_range_with_subclass_18470_18447() -> None: class MyAmazingDate(date): pass diff --git a/py-polars/tests/unit/functions/range/test_datetime_range.py b/py-polars/tests/unit/functions/range/test_datetime_range.py index 22a57c567396..ce99cd27b802 100644 --- a/py-polars/tests/unit/functions/range/test_datetime_range.py +++ b/py-polars/tests/unit/functions/range/test_datetime_range.py @@ -3,11 +3,13 @@ from datetime import date, datetime, timedelta from typing import TYPE_CHECKING +import hypothesis.strategies as st import pytest +from hypothesis import given, settings import polars as pl from polars.datatypes import DTYPE_TEMPORAL_UNITS -from polars.exceptions import ComputeError, PanicException, SchemaError +from polars.exceptions import ComputeError, InvalidOperationError, SchemaError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -94,7 +96,7 @@ def test_datetime_range_precision( def test_datetime_range_invalid_time_unit() -> None: - with pytest.raises(PanicException, match="'x' not supported"): + with pytest.raises(InvalidOperationError, match="'x' not supported"): pl.datetime_range( start=datetime(2021, 12, 16), end=datetime(2021, 12, 16, 3), @@ -579,3 +581,42 @@ def test_datetime_range_specifying_ambiguous_11713() -> None: "datetime", [datetime(2023, 10, 29, 2), datetime(2023, 10, 29, 3)] ).dt.replace_time_zone("Europe/Madrid", ambiguous=pl.Series(["latest", "raise"])) assert_series_equal(result, expected) + + +@given( + closed=st.sampled_from(["none", "left", "right", "both"]), + time_unit=st.sampled_from(["ms", "us", "ns"]), + n=st.integers(1, 10), + size=st.integers(8, 10), + unit=st.sampled_from(["s", "m", "h", "d", "mo"]), + start=st.datetimes(datetime(1965, 1, 1), datetime(2100, 1, 1)), +) +@settings(max_examples=50) +@pytest.mark.benchmark +def test_datetime_range_fast_slow_paths( + closed: ClosedInterval, + time_unit: TimeUnit, + n: int, + size: int, + unit: str, + start: datetime, +) -> None: + end = pl.select(pl.lit(start).dt.offset_by(f"{n*size}{unit}")).item() + result_slow = pl.datetime_range( + start, + end, + closed=closed, + time_unit=time_unit, + interval=f"{n}{unit}", + time_zone="Asia/Kathmandu", + eager=True, + ).dt.replace_time_zone(None) + result_fast = pl.datetime_range( + start, + end, + closed=closed, + time_unit=time_unit, + interval=f"{n}{unit}", + eager=True, + ) + assert_series_equal(result_slow, result_fast) diff --git a/py-polars/tests/unit/functions/test_functions.py b/py-polars/tests/unit/functions/test_functions.py index de7e49574393..05bd11976fd8 100644 --- a/py-polars/tests/unit/functions/test_functions.py +++ b/py-polars/tests/unit/functions/test_functions.py @@ -538,3 +538,22 @@ def test_head_tail(fruits_cars: pl.DataFrame) -> None: res_expr = fruits_cars.select(pl.tail("A", 2)) expected = pl.Series("A", [4, 5]) assert_series_equal(res_expr.to_series(), expected) + + +def test_escape_regex() -> None: + result = pl.escape_regex("abc(\\w+)") + expected = "abc\\(\\\\w\\+\\)" + assert result == expected + + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + with pytest.raises( + TypeError, + match="escape_regex function is unsupported for `Expr`, you may want use `Expr.str.escape_regex` instead", + ): + df.with_columns(escaped=pl.escape_regex(pl.col("text"))) # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="escape_regex function supports only `str` type, got ``", + ): + pl.escape_regex(3) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py index f48c6034f7fd..81a89ab342f2 100644 --- a/py-polars/tests/unit/functions/test_lit.py +++ b/py-polars/tests/unit/functions/test_lit.py @@ -195,3 +195,11 @@ def test_lit_decimal_parametric(s: pl.Series) -> None: assert df.dtypes[0] == pl.Decimal(None, scale) assert result == value + + +@pytest.mark.parametrize( + "item", + [{}, {"foo": 1}], +) +def test_lit_structs(item: Any) -> None: + assert pl.select(pl.lit(item)).to_dict(as_series=False) == {"literal": [item]} diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index dc458086c943..6e79c874a01b 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -690,3 +690,49 @@ def test_when_then_chunked_structs_18673() -> None: df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))), pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}), ) + + +some_scalar = pl.Series("a", [{"x": 2}], pl.Struct) +none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64})) +column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct) + + +@pytest.mark.parametrize( + "values", + [ + (some_scalar, some_scalar), + (some_scalar, pl.col.a), + (some_scalar, none_scalar), + (some_scalar, column), + (none_scalar, pl.col.a), + (none_scalar, none_scalar), + (none_scalar, column), + (pl.col.a, pl.col.a), + (pl.col.a, column), + (column, column), + ], +) +def test_struct_when_then_broadcasting_combinations_19122( + values: tuple[Any, Any], +) -> None: + lv, rv = values + + df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame() + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a") + ), + ) + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a") + ), + ) diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index c22d8abafd21..50ef4f2ac0ad 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -60,7 +60,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical, + "strings-cat": pl.Categorical(ordering="physical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), @@ -190,6 +190,18 @@ def test_from_pandas_include_indexes() -> None: assert df.to_dict(as_series=False) == data +def test_from_pandas_series_include_indexes() -> None: + # no default index + pd_series = pd.Series({"a": 1, "b": 2}, name="number").rename_axis(["letter"]) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"letter": ["a", "b"], "number": [1, 2]} + + # default index + pd_series = pd.Series(range(2)) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"index": [0, 1], "0": [0, 1]} + + def test_duplicate_cols_diff_types() -> None: df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=["0", 0, "1", 1]) with pytest.raises( diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index 5768787c22c4..b69a10671ca7 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -79,6 +79,32 @@ def test_arrow_list_chunked_array() -> None: assert s.dtype == pl.List +# Test that polars convert Arrays of logical types correctly to arrow +def test_arrow_array_logical() -> None: + # cast to large string and uint32 indices because polars converts to those + pa_data1 = ( + pa.array(["a", "b", "c", "d"]) + .dictionary_encode() + .cast(pa.dictionary(pa.uint32(), pa.large_string())) + ) + pa_array_logical1 = pa.FixedSizeListArray.from_arrays(pa_data1, 2) + + s1 = pl.Series( + values=[["a", "b"], ["c", "d"]], + dtype=pl.Array(pl.Enum(["a", "b", "c", "d"]), shape=2), + ) + assert s1.to_arrow() == pa_array_logical1 + + pa_data2 = pa.array([date(2024, 1, 1), date(2024, 1, 2)]) + pa_array_logical2 = pa.FixedSizeListArray.from_arrays(pa_data2, 1) + + s2 = pl.Series( + values=[[date(2024, 1, 1)], [date(2024, 1, 2)]], + dtype=pl.Array(pl.Date, shape=1), + ) + assert s2.to_arrow() == pa_array_logical2 + + def test_from_dict() -> None: data = {"a": [1, 2], "b": [3, 4]} df = pl.from_dict(data) @@ -96,7 +122,7 @@ def test_from_dict_struct() -> None: assert df.shape == (2, 2) assert df["a"][0] == {"b": 1, "c": 2} assert df["a"][1] == {"b": 3, "c": 4} - assert df.schema == {"a": pl.Struct, "d": pl.Int64} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} def test_from_dicts() -> None: @@ -371,7 +397,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical, + "c": pl.Categorical(ordering="physical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py index f943ab5e2c26..e9b56b6b9f15 100644 --- a/py-polars/tests/unit/io/cloud/test_cloud.py +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -1,3 +1,7 @@ +import io +import sys +from typing import Any + import pytest import polars as pl @@ -23,3 +27,89 @@ def test_scan_nonexistent_cloud_path_17444(format: str) -> None: # Upon collection, it should fail with pytest.raises(ComputeError): result.collect() + + +@pytest.mark.parametrize( + "io_func", + [ + *[pl.scan_parquet, pl.read_parquet], + pl.scan_csv, + *[pl.scan_ndjson, pl.read_ndjson], + pl.scan_ipc, + ], +) +def test_scan_credential_provider( + io_func: Any, monkeypatch: pytest.MonkeyPatch +) -> None: + err_magic = "err_magic_3" + + def raises(*_: None, **__: None) -> None: + raise AssertionError(err_magic) + + monkeypatch.setattr(pl.CredentialProviderAWS, "__init__", raises) + + with pytest.raises(AssertionError, match=err_magic): + io_func("s3://bucket/path", credential_provider="auto") + + # We can't test these with the `read_` functions as they end up executing + # the query + if io_func.__name__.startswith("scan_"): + # Passing `None` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func("s3://bucket/path", credential_provider=None) + # Passing `storage_options` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func("s3://bucket/path", credential_provider="auto", storage_options={}) + + err_magic = "err_magic_7" + + def raises_2() -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + # Note to reader: It is converted to a ComputeError as it is being called + # from Rust. + with pytest.raises(ComputeError, match=err_magic): + io_func("s3://bucket/path", credential_provider=raises_2).collect() + + +def test_scan_credential_provider_serialization() -> None: + err_magic = "err_magic_3" + + class ErrCredentialProvider(pl.CredentialProvider): + def __call__(self) -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=ErrCredentialProvider() + ) + + serialized = lf.serialize() + + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + + with pytest.raises(ComputeError, match=err_magic): + lf.collect() + + +def test_scan_credential_provider_serialization_pyversion() -> None: + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=pl.CredentialProviderAWS() + ) + + serialized = lf.serialize() + serialized = bytearray(serialized) + + # We can't monkeypatch sys.python_version so we just mutate the output + # instead. + + v = b"PLPYFN" + i = serialized.index(v) + len(v) + a, b = serialized[i:][:2] + serialized_pyver = (a, b) + assert serialized_pyver == (sys.version_info.minor, sys.version_info.micro) + # Note: These are loaded as u8's + serialized[i] = 255 + serialized[i + 1] = 254 + + with pytest.raises(ComputeError, match=r"python version.*(3, 255, 254).*differs.*"): + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index deb44a5a79f4..69e7853172a1 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -12,7 +12,7 @@ import pyarrow as pa import pytest import sqlalchemy -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast @@ -292,17 +292,18 @@ def test_read_database( tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": + connect_using = cast("DbReadEngine", connect_using) # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) df_empty = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data WHERE name LIKE '%polars%'", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: @@ -382,6 +383,39 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: assert_frame_equal(batches[0], expected) +def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + # establish sqlalchemy "textclause" and validate usage + textclause_query = text(""" + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < 0 + """) + + expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(textclause_query, connection=conn), + expected, + ) + + batches = list( + pl.read_database( + textclause_query, + connection=conn, + iter_batches=True, + batch_size=1, + ) + ) + assert len(batches) == 1 + assert_frame_equal(batches[0], expected) + + def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") diff --git a/py-polars/tests/unit/io/files/test_empty_rows.xlsx b/py-polars/tests/unit/io/files/test_empty_rows.xlsx new file mode 100644 index 000000000000..1fc27ee4c3e4 Binary files /dev/null and b/py-polars/tests/unit/io/files/test_empty_rows.xlsx differ diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index bf2035b6d961..22df34dad668 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1833,9 +1833,9 @@ class TemporalFormats(TypedDict): ) assert df.write_csv(quote_style="necessary", **temporal_formats) == ( "float,string,int,bool,date,datetime,time,decimal\n" - '1.0,a,1,true,2077-07-05,,03:01:00,"1.0"\n' - '2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,"2.0"\n' - ',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,""\n' + "1.0,a,1,true,2077-07-05,,03:01:00,1.0\n" + '2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,2.0\n' + ',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,\n' ) assert df.write_csv(quote_style="never", **temporal_formats) == ( "float,string,int,bool,date,datetime,time,decimal\n" @@ -1847,9 +1847,9 @@ class TemporalFormats(TypedDict): quote_style="non_numeric", quote_char="8", **temporal_formats ) == ( "8float8,8string8,8int8,8bool8,8date8,8datetime8,8time8,8decimal8\n" - "1.0,8a8,1,8true8,82077-07-058,,803:01:008,81.08\n" - "2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,82.08\n" - ',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,88\n' + "1.0,8a8,1,8true8,82077-07-058,,803:01:008,1.0\n" + "2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,2.0\n" + ',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,\n' ) @@ -2064,10 +2064,15 @@ def test_read_csv_single_column(columns: list[str] | str) -> None: def test_csv_invalid_escape_utf8_14960() -> None: - with pytest.raises(ComputeError, match=r"field is not properly escaped"): + with pytest.raises(ComputeError, match=r"Field .* is not properly escaped"): pl.read_csv('col1\n""•'.encode()) +def test_csv_invalid_escape() -> None: + with pytest.raises(ComputeError): + pl.read_csv(b'col1,col2\n"a,b') + + @pytest.mark.slow @pytest.mark.write_disk def test_read_csv_only_loads_selected_columns( @@ -2294,3 +2299,11 @@ def test_read_csv_cast_unparsable_later( df.write_csv(f) f.seek(0) assert df.equals(pl.read_csv(f, schema={"x": dtype})) + + +def test_csv_double_new_line() -> None: + assert pl.read_csv(b"a,b,c\n\n", has_header=False).to_dict(as_series=False) == { + "column_1": ["a", None], + "column_2": ["b", None], + "column_3": ["c", None], + } diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index 33c6b052ffdf..213988964de3 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -516,3 +516,11 @@ def test_read_parquet_respects_rechunk_16982( rechunk, expected_chunks = rechunk_and_expected_chunks result = pl.read_delta(str(tmp_path), rechunk=rechunk) assert result.n_chunks() == expected_chunks + + +def test_scan_delta_DT_input(delta_table_path: Path) -> None: + DT = DeltaTable(str(delta_table_path), version=0) + ldf = pl.scan_delta(DT) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index a01a2ef6e59d..9e9213ac9bd4 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -779,3 +779,45 @@ def test_hive_predicate_dates_14712( ) pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect() assert "hive partitioning: skipped 1 files" in capfd.readouterr().err + + +@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths") +@pytest.mark.write_disk +def test_hive_windows_splits_on_forward_slashes(tmp_path: Path) -> None: + # Note: This needs to be an absolute path. + tmp_path = tmp_path.resolve() + path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1" + Path(path).mkdir(exist_ok=True, parents=True) + + df = pl.DataFrame({"x": "x"}) + df.write_parquet(f"{path}/data.parquet") + + expect = pl.DataFrame( + [ + s.new_from_index(0, 5) + for s in pl.DataFrame( + { + "x": "x", + "a": 1, + "b": 1, + "c": 1, + "d": 1, + "e": 1, + } + ) + ] + ) + + assert_frame_equal( + pl.scan_parquet( + [ + f"{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet", + f"{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet", + f"{tmp_path}\\a=1/b=1/c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1\\c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*", + ], + hive_partitioning=True, + ).collect(), + expect, + ) diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py index 5a5f6769e3a5..59ba92559459 100644 --- a/py-polars/tests/unit/io/test_iceberg.py +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -44,6 +44,19 @@ def test_scan_iceberg_plain(self, iceberg_path: str) -> None: "ts": pl.Datetime(time_unit="us", time_zone=None), } + def test_scan_iceberg_snapshot_id(self, iceberg_path: str) -> None: + df = pl.scan_iceberg(iceberg_path, snapshot_id=7051579356916758811) + assert len(df.collect()) == 3 + assert df.collect_schema() == { + "id": pl.Int32, + "str": pl.String, + "ts": pl.Datetime(time_unit="us", time_zone=None), + } + + def test_scan_iceberg_snapshot_id_not_found(self, iceberg_path: str) -> None: + with pytest.raises(ValueError, match="Snapshot ID not found"): + pl.scan_iceberg(iceberg_path, snapshot_id=1234567890) + def test_scan_iceberg_filter_on_partition(self, iceberg_path: str) -> None: ts1 = datetime(2023, 3, 1, 18, 15) ts2 = datetime(2023, 3, 1, 19, 25) diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 37a583d929df..84e6436cb10e 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -354,3 +354,55 @@ def test_ipc_variadic_buffers_categorical_binview_18636() -> None: df.write_ipc(b) b.seek(0) assert_frame_equal(pl.read_ipc(b), df) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_ipc_chunked_roundtrip(size: int) -> None: + a = pl.Series("a", [{"x": 1}] * size, pl.Struct({"x": pl.Int8})).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = io.BytesIO() + c.write_ipc(f) + + f.seek(0) + assert_frame_equal(c, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_ipc_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.BytesIO() + a.write_ipc(f) + + f.seek(0) + assert_frame_equal(a, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_ipc_chunked_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = io.BytesIO() + c.write_ipc(f) + + f.seek(0) + assert_frame_equal(c, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +@pytest.mark.parametrize("value", [{}, {"x": 1}]) +@pytest.mark.write_disk +def test_memmap_ipc_chunked_structs( + size: int, value: dict[str, int], tmp_path: Path +) -> None: + a = pl.Series("a", [value] * size, pl.Struct).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = tmp_path / "f.ipc" + c.write_ipc(f) + assert_frame_equal(c, pl.read_ipc(f)) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index 93780e79293d..3b003ace6b1e 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -18,6 +18,7 @@ import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal @@ -68,9 +69,10 @@ def test_write_json_decimal() -> None: def test_json_infer_schema_length_11148() -> None: response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 - result = pl.read_json(json.dumps(response).encode(), infer_schema_length=2) - with pytest.raises(AssertionError): - assert set(result.columns) == {"col1", "col2"} + with pytest.raises( + pl.exceptions.ComputeError, match="extra key in struct data: col2" + ): + pl.read_json(json.dumps(response).encode(), infer_schema_length=2) response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 result = pl.read_json(json.dumps(response).encode(), infer_schema_length=3) @@ -160,10 +162,8 @@ def test_ndjson_nested_null() -> None: # 'bar' represents an empty list of structs; check the schema is correct (eg: picks # up that it IS a list of structs), but confirm that list is empty (ref: #11301) # We don't support empty structs yet. So Null is closest. - assert df.schema == { - "foo": pl.Struct([pl.Field("bar", pl.List(pl.Struct({"": pl.Null})))]) - } - assert df.to_dict(as_series=False) == {"foo": [{"bar": []}]} + assert df.schema == {"foo": pl.Struct([pl.Field("bar", pl.List(pl.Struct({})))])} + assert df.to_dict(as_series=False) == {"foo": [{"bar": [{}]}]} def test_ndjson_nested_string_int() -> None: @@ -289,7 +289,7 @@ def test_ndjson_null_buffer() -> None: ("id", pl.Int64), ("zero_column", pl.Int64), ("empty_array_column", pl.List(pl.Null)), - ("empty_object_column", pl.Struct([pl.Field("", pl.Null)])), + ("empty_object_column", pl.Struct([])), ("null_column", pl.Null), ] ) @@ -388,7 +388,7 @@ def test_empty_json() -> None: df = pl.read_json(b'{"j":{}}') assert df.dtypes == [pl.Struct([])] - assert df.shape == (0, 1) + assert df.shape == (1, 1) def test_compressed_json() -> None: @@ -435,6 +435,77 @@ def test_empty_list_json() -> None: def test_json_infer_3_dtypes() -> None: # would SO before df = pl.DataFrame({"a": ["{}", "1", "[1, 2]"]}) - out = df.select(pl.col("a").str.json_decode()) + + with pytest.raises(pl.exceptions.ComputeError): + df.select(pl.col("a").str.json_decode()) + + df = pl.DataFrame({"a": [None, "1", "[1, 2]"]}) + out = df.select(pl.col("a").str.json_decode(dtype=pl.List(pl.String))) assert out["a"].to_list() == [None, ["1"], ["1", "2"]] assert out.dtypes[0] == pl.List(pl.String) + + +# NOTE: This doesn't work for 0, but that is normal +@pytest.mark.parametrize("size", [1, 2, 13]) +def test_zfs_json_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.StringIO() + a.write_json(f) + + f.seek(0) + assert_frame_equal(a, pl.read_json(f)) + + +def test_read_json_raise_on_data_type_mismatch() -> None: + with pytest.raises(ComputeError): + pl.read_json( + b"""\ +[ + {"a": null}, + {"a": 1} +] +""", + infer_schema_length=1, + ) + + +def test_read_json_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra key in struct data: b"): + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=1, + ) + + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=2, + ), + pl.DataFrame({"a": [1, 2], "b": [None, 2]}), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + schema={"a": pl.Int64}, + ), + pl.DataFrame({"a": [1, 2]}), + ) diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 792ddc42b02a..49a842386ea2 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -235,12 +235,12 @@ def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) captured = capfd.readouterr().err assert ( - "parquet file must be read, statistics not sufficient for predicate." + "parquet row group must be read, statistics not sufficient for predicate." in captured ) assert ( - "parquet file can be skipped, the statistics were sufficient" - " to apply the predicate." in captured + "parquet row group can be skipped, the statistics were sufficient to apply the predicate." + in captured ) @@ -710,10 +710,18 @@ def test_parquet_schema_arg( schema: dict[str, type[pl.DataType]] = {"a": pl.Int64} # type: ignore[no-redef] - lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + for allow_missing_columns in [True, False]: + lf = pl.scan_parquet( + paths, + parallel=parallel, + schema=schema, + allow_missing_columns=allow_missing_columns, + ) - with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"): - lf.collect(streaming=streaming) + with pytest.raises( + pl.exceptions.SchemaError, match="file contained extra columns" + ): + lf.collect(streaming=streaming) lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a") @@ -731,3 +739,29 @@ def test_parquet_schema_arg( match="data type mismatch for column b: expected: i8, found: i64", ): lf.collect(streaming=streaming) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("allow_missing_columns", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249( + tmp_path: Path, + allow_missing_columns: bool, + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet( + paths[0] + ) + pl.DataFrame( + {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64} + ).write_parquet(paths[1]) + + assert_frame_equal( + pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns) + .select("a") + .collect(streaming=streaming), + pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}), + ) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 8431c659cce0..bf9d8ac4fad8 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1,9 +1,10 @@ from __future__ import annotations +import decimal import io from datetime import datetime, time, timezone from decimal import Decimal -from typing import IO, TYPE_CHECKING, Any, Literal, cast +from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast import fsspec import numpy as np @@ -625,6 +626,10 @@ def test_parquet_rle_non_nullable_12814() -> None: pq.write_table(table, f, data_page_size=1) f.seek(0) + print(pq.read_table(f)) + + f.seek(0) + expect = pl.DataFrame(table).tail(10) actual = pl.read_parquet(f).tail(10) @@ -1088,7 +1093,7 @@ def test_hybrid_rle() -> None: pl.Boolean, ], min_size=1, - max_size=5000, + max_size=500, ) ) @pytest.mark.slow @@ -1961,3 +1966,150 @@ def test_allow_missing_columns( .collect(streaming=streaming), expected, ) + + +def test_nested_nonnullable_19158() -> None: + # Bug is based on the top-level struct being nullable and the inner list + # not being nullable. + tbl = pa.table( + { + "a": [{"x": [1]}, None, {"x": [1, 2]}, None], + }, + schema=pa.schema( + [ + pa.field( + "a", + pa.struct([pa.field("x", pa.list_(pa.int8()), nullable=False)]), + nullable=True, + ) + ] + ), + ) + + f = io.BytesIO() + pq.write_table(tbl, f) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), pl.DataFrame(tbl)) + + +D = Decimal + + +@pytest.mark.parametrize("precision", range(1, 37, 2)) +@pytest.mark.parametrize( + "nesting", + [ + # Struct + lambda t: ([{"x": None}, None], pl.Struct({"x": t})), + lambda t: ([None, {"x": None}], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, None], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, {"x": D("4.8")}], pl.Struct({"x": t})), + # Array + lambda t: ([[None, None, D("8.2")], None], pl.Array(t, 3)), + lambda t: ([None, [None, D("8.9"), None]], pl.Array(t, 3)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], None], pl.Array(t, 3)), + lambda t: ( + [[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("5.2"), D("8.9")]], + pl.Array(t, 3), + ), + # List + lambda t: ([[None, D("8.2")], None], pl.List(t)), + lambda t: ([None, [D("8.9"), None]], pl.List(t)), + lambda t: ([[D("1.5"), D("4.1")], None], pl.List(t)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("8.9")]], pl.List(t)), + ], +) +def test_decimal_precision_nested_roundtrip( + nesting: Callable[[pl.DataType], tuple[list[Any], pl.DataType]], + precision: int, +) -> None: + # Limit the context as to not disturb any other tests + with decimal.localcontext() as ctx: + ctx.prec = precision + + decimal_dtype = pl.Decimal(precision=precision) + values, dtype = nesting(decimal_dtype) + + df = pl.Series("a", values, dtype).to_frame() + + test_round_trip(df) + + +@pytest.mark.parametrize("parallel", ["prefiltered", "columns", "row_groups", "auto"]) +def test_conserve_sortedness( + monkeypatch: Any, capfd: Any, parallel: pl.ParallelStrategy +) -> None: + f = io.BytesIO() + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, None], + "b": [1.0, 2.0, 3.0, 4.0, 5.0, None], + "c": [None, 5, 4, 3, 2, 1], + "d": [None, 5.0, 4.0, 3.0, 2.0, 1.0], + "a_nosort": [1, 2, 3, 4, 5, None], + "f": range(6), + } + ) + + pq.write_table( + df.to_arrow(), + f, + sorting_columns=[ + pq.SortingColumn(0, False, False), + pq.SortingColumn(1, False, False), + pq.SortingColumn(2, True, True), + pq.SortingColumn(3, True, True), + ], + ) + + f.seek(0) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.scan_parquet(f, parallel=parallel).filter(pl.col.f > 1).collect() + + captured = capfd.readouterr().err + + # @NOTE: We don't conserve sortedness for anything except integers at the + # moment. + assert captured.count("Parquet conserved SortingColumn for column chunk of") == 2 + assert ( + "Parquet conserved SortingColumn for column chunk of 'a' to Ascending" + in captured + ) + assert ( + "Parquet conserved SortingColumn for column chunk of 'c' to Descending" + in captured + ) + + +def test_f16() -> None: + values = [float("nan"), 0.0, 0.5, 1.0, 1.5] + + table = pa.Table.from_pydict( + { + "x": pa.array(np.array(values, dtype=np.float16), type=pa.float16()), + } + ) + + df = pl.Series("x", values, pl.Float32).to_frame() + + f = io.BytesIO() + pq.write_table(table, f) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col.x > 0.5).collect(), + df.filter(pl.col.x > 0.5), + ) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(1, 3).collect(), + df.slice(1, 3), + ) diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 799c4953cbf6..30af7b830ff8 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -801,3 +801,37 @@ def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None: schema={"index": pl.UInt32, "a": pl.Int64}, ), ) + + +def test_scan_include_file_paths_respects_projection_pushdown() -> None: + q = pl.scan_csv(b"a,b,c\na1,b1,c1", include_file_paths="path_name").select( + ["a", "b"] + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": "a1", "b": "b1"})) + + +def test_streaming_scan_csv_include_file_paths_18257(io_files_path: Path) -> None: + lf = pl.scan_csv( + io_files_path / "foods1.csv", + include_file_paths="path", + ).select("category", "path") + + assert lf.collect(streaming=True).columns == ["category", "path"] + + +def test_streaming_scan_csv_with_row_index_19172(io_files_path: Path) -> None: + lf = ( + pl.scan_csv(io_files_path / "foods1.csv", infer_schema=False) + .with_row_index() + .select("calories", "index") + .head(1) + ) + + assert_frame_equal( + lf.collect(streaming=True), + pl.DataFrame( + {"calories": "45", "index": 0}, + schema={"calories": pl.String, "index": pl.UInt32}, + ), + ) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 7483f371b51c..b7b03a0bd02e 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -21,7 +21,7 @@ from polars._typing import ExcelSpreadsheetEngine, SelectorType -pytestmark = pytest.mark.slow() +# pytestmark = pytest.mark.slow() @pytest.fixture @@ -83,6 +83,11 @@ def path_ods_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.ods" +@pytest.fixture +def path_empty_rows_excel(io_files_path: Path) -> Path: + return io_files_path / "test_empty_rows.xlsx" + + @pytest.mark.parametrize( ("read_spreadsheet", "source", "engine_params"), [ @@ -227,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: xls = BytesIO() df.write_excel(xls, position="C5") - schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean} + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()} df_compare = df.with_columns( pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() ) @@ -317,13 +322,12 @@ def test_read_mixed_dtype_columns( ) -> None: spreadsheet_path = request.getfixturevalue(source) schema_overrides = { - "Employee ID": pl.Utf8, - "Employee Name": pl.Utf8, - "Date": pl.Date, - "Details": pl.Categorical, - "Asset ID": pl.Utf8, + "Employee ID": pl.Utf8(), + "Employee Name": pl.Utf8(), + "Date": pl.Date(), + "Details": pl.Categorical("lexical"), + "Asset ID": pl.Utf8(), } - df = read_spreadsheet( spreadsheet_path, sheet_id=0, @@ -1060,3 +1064,38 @@ def test_identify_workbook( bytesio_data = BytesIO(f.read()) assert _identify_workbook(bytesio_data) == file_type assert isinstance(pl.read_excel(bytesio_data, engine="calamine"), pl.DataFrame) + + +def test_drop_empty_rows(path_empty_rows_excel: Path) -> None: + df1 = pl.read_excel(source=path_empty_rows_excel, engine="xlsx2csv") + assert df1.shape == (8, 4) + df2 = pl.read_excel( + source=path_empty_rows_excel, engine="xlsx2csv", drop_empty_rows=True + ) + assert df2.shape == (8, 4) + df3 = pl.read_excel( + source=path_empty_rows_excel, engine="xlsx2csv", drop_empty_rows=False + ) + assert df3.shape == (10, 4) + + df4 = pl.read_excel(source=path_empty_rows_excel, engine="openpyxl") + assert df4.shape == (8, 4) + df5 = pl.read_excel( + source=path_empty_rows_excel, engine="openpyxl", drop_empty_rows=True + ) + assert df5.shape == (8, 4) + df6 = pl.read_excel( + source=path_empty_rows_excel, engine="openpyxl", drop_empty_rows=False + ) + assert df6.shape == (10, 4) + + df7 = pl.read_excel(source=path_empty_rows_excel, engine="calamine") + assert df7.shape == (8, 4) + df8 = pl.read_excel( + source=path_empty_rows_excel, engine="calamine", drop_empty_rows=True + ) + assert df8.shape == (8, 4) + df9 = pl.read_excel( + source=path_empty_rows_excel, engine="calamine", drop_empty_rows=False + ) + assert df9.shape == (10, 4) diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index a82e389b4583..8ddcbfafd6f6 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -116,3 +116,33 @@ def test_lf_serde_scan(tmp_path: Path) -> None: result = pl.LazyFrame.deserialize(io.BytesIO(ser)) assert_frame_equal(result, lf) assert_frame_equal(result.collect(), df) + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) + + +def custom_function(x: pl.Series) -> pl.Series: + return x + 1 + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_named_function( + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/meta/test_versions.py b/py-polars/tests/unit/meta/test_versions.py index 944f921ac1fb..36504c7fe4c5 100644 --- a/py-polars/tests/unit/meta/test_versions.py +++ b/py-polars/tests/unit/meta/test_versions.py @@ -12,4 +12,5 @@ def test_show_versions(capsys: Any) -> None: out, _ = capsys.readouterr() assert "Python" in out assert "Polars" in out + assert "LTS CPU" in out assert "Optional dependencies" in out diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 360def065ca1..f989bc4681e2 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -21,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError, SchemaError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -284,8 +284,8 @@ def test_operator_arithmetic_with_nulls(op: Any, dtype: pl.DataType) -> None: df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr)) ) - assert_frame_equal(df_expected, op(df, None)) - assert_series_equal(s_expected, op(s, None)) + assert_frame_equal(op(df, None), df_expected) + assert_series_equal(op(s, None), s_expected) @pytest.mark.parametrize( @@ -598,7 +598,6 @@ def test_array_arithmetic_same_size( pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)), ] ) - print(df.select(expr(pl.col(column_names[0]), pl.col(column_names[1])))) # Expr-based arithmetic: assert_frame_equal( df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), @@ -611,110 +610,6 @@ def test_array_arithmetic_same_size( ) -@pytest.mark.parametrize( - ("expected", "expr", "column_names"), - [ - ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), - ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), - ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), - ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), - ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), - ( - [[3, 4], [7]], - lambda a, b: a + b, - ("a", "uint8"), - ), - ( - [[[2, 4]], [[6]]], - lambda a, b: a + b, - ("nested", "nested"), - ), - ( - [[[2, 4]], [[6]]], - lambda a, b: a + b, - ("nested", "nested_uint8"), - ), - ], -) -def test_list_arithmetic_same_size( - expected: Any, - expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], - column_names: tuple[str, str], -) -> None: - df = pl.DataFrame( - [ - pl.Series("a", [[1, 2], [3]]), - pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), - pl.Series("nested", [[[1, 2]], [[3]]]), - pl.Series( - "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) - ), - ] - ) - # Expr-based arithmetic: - assert_frame_equal( - df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), - pl.Series(column_names[0], expected).to_frame(), - ) - # Direct arithmetic on the Series: - assert_series_equal( - expr(df[column_names[0]], df[column_names[1]]), - pl.Series(column_names[0], expected), - ) - - -@pytest.mark.parametrize( - ("a", "b", "expected"), - [ - ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), - ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), - ([[[2]], [None], [[4]]], [[[3]], [[6]], [[8]]], [[[5]], [None], [[12]]]), - ], -) -def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: - series_a = pl.Series(a) - series_b = pl.Series(b) - series_expected = pl.Series(expected) - - # Same dtype: - assert_series_equal(series_a + series_b, series_expected) - - # Different dtype: - assert_series_equal( - series_a._recursive_cast_to_dtype(pl.Int32()) - + series_b._recursive_cast_to_dtype(pl.Int64()), - series_expected._recursive_cast_to_dtype(pl.Int64()), - ) - - -def test_list_arithmetic_error_cases() -> None: - # Different series length: - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) - with pytest.raises( - InvalidOperationError, match="Series of the same size; got 1 and 2" - ): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) - - # Different list length: - with pytest.raises(InvalidOperationError, match="lists of the same size"): - _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) - with pytest.raises( - InvalidOperationError, match="lists of the same size; got 2 and 1" - ): - _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) - - # Wrong types: - with pytest.raises(InvalidOperationError, match="cannot cast List type"): - _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) - - # Different nesting: - with pytest.raises(SchemaError, match="failed to determine supertype"): - _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) - - def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]}) @@ -891,5 +786,10 @@ def test_date_datetime_sub() -> None: def test_raise_invalid_shape() -> None: - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises(InvalidOperationError): pl.DataFrame([[1, 2], [3, 4]]) * pl.DataFrame([1, 2, 3]) + + +def test_integer_divide_scalar_zero_lhs_19142() -> None: + assert_series_equal(pl.Series([0]) // pl.Series([1, 0]), pl.Series([0, None])) + assert_series_equal(pl.Series([0]) % pl.Series([1, 0]), pl.Series([0, None])) diff --git a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py new file mode 100644 index 000000000000..0e7af02d8290 --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py @@ -0,0 +1,1058 @@ +from __future__ import annotations + +import operator +from typing import Any, Callable + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError, ShapeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def exec_op_with_series(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + v: pl.Series = op(lhs, rhs) + return v + + +def build_expr_op_exec( + type_coercion: bool, +) -> Callable[[pl.Series, pl.Series, Any], pl.Series]: + def func(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + return ( + pl.select(lhs) + .lazy() + .select(op(pl.first(), rhs)) + .collect(type_coercion=type_coercion) + .to_series() + ) + + return func + + +def build_series_broadcaster( + side: str, +) -> Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] +]: + length = 3 + + if side == "left": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l.new_from_index(0, length), r, o.new_from_index(0, length) + elif side == "right": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r.new_from_index(0, length), o.new_from_index(0, length) + elif side == "both": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return ( + l.new_from_index(0, length), + r.new_from_index(0, length), + o.new_from_index(0, length), + ) + elif side == "none": + + def func( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r, o + else: + raise ValueError(side) + + return func + + +BROADCAST_SERIES_COMBINATIONS = [ + build_series_broadcaster("left"), + build_series_broadcaster("right"), + build_series_broadcaster("both"), + build_series_broadcaster("none"), +] + +EXEC_OP_COMBINATIONS = [ + exec_op_with_series, + build_expr_op_exec(True), + build_expr_op_exec(False), +] + + +@pytest.mark.parametrize( + "list_side", ["left", "left3", "both", "right3", "right", "none"] +) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_arithmetic_values( + list_side: str, + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + """ + Tests value correctness. + + This test checks for output value correctness (a + b == c) across different + codepaths, by wrapping the values (a, b, c) in different combinations of + list / primitive columns. + """ + import operator as op + + dtypes: list[Any] = [pl.Null, pl.Null, pl.Null] + dtype: Any = pl.Null + + def materialize_list(v: Any) -> pl.Series: + return pl.Series( + [[None, v, None]], + dtype=pl.List(dtype), + ) + + def materialize_list3(v: Any) -> pl.Series: + return pl.Series( + [[[[None, v], None], None]], + dtype=pl.List(pl.List(pl.List(dtype))), + ) + + def materialize_primitive(v: Any) -> pl.Series: + return pl.Series([v], dtype=dtype) + + def materialize_series( + l: Any, # noqa: E741 + r: Any, + o: Any, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + nonlocal dtype + + dtype = dtypes[0] + l = { # noqa: E741 + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_primitive, + "right3": materialize_primitive, + "none": materialize_primitive, + }[list_side](l) # fmt: skip + + dtype = dtypes[1] + r = { + "left": materialize_primitive, + "left3": materialize_primitive, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](r) # fmt: skip + + dtype = dtypes[2] + o = { + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](o) # fmt: skip + + assert l.len() == 1 + assert r.len() == 1 + assert o.len() == 1 + + return broadcast_series(l, r, o) + + # Signed + dtypes = [pl.Int8, pl.Int8, pl.Int8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(-5, 127, 124) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(-5, 127, -123) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(-5, 3, -2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Unsigned + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(2, 3, 255) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(2, 128, 0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(5, 2, 2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(5, 2, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Floats. Note we pick Float32 to ensure there is no accidental upcasting + # to Float64. + dtypes = [pl.Float32, pl.Float32, pl.Float32] + l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for zero behavior + # + + # Integer + + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(1, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Float + + dtypes = [pl.Float32, pl.Float32, pl.Float32] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for NULL behavior + # + + for dtype, truediv_dtype in [ # type: ignore[misc] + [pl.Int8, pl.Float64], + [pl.Float32, pl.Float32], + ]: + for vals in [ + [None, None, None], + [0, None, None], + [None, 0, None], + [0, None, None], + [None, 0, None], + [3, None, None], + [None, 3, None], + ]: + dtypes = 3 * [dtype] + + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + dtypes[2] = truediv_dtype # type: ignore[has-type] + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Type upcasting for Boolean and Null + + # Check boolean upcasting + dtypes = [pl.Boolean, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(True, 3, 4) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(True, 3, 254) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(True, 3, 3) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(True, 3, 0) # noqa: E741 + if list_side != "none": + # TODO: FIXME: We get an error on non-lists with this: + # "floor_div operation not supported for dtype `bool`" + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(True, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Boolean, pl.UInt8, pl.Float64] + l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Check Null upcasting + dtypes = [pl.Null, pl.UInt8, pl.UInt8] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + if list_side != "none": + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Null, pl.UInt8, pl.Float64] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_supertype( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8)) + b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.slow +def test_list_numeric_op_validity_combination( + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64)) + # expected result + e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + e, + ) + + a = pl.Series("a", [[1]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [None], dtype=pl.Int64) + e = pl.Series("a", [[None]], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [1], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [0], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.floordiv), e) + + +def test_list_add_alignment() -> None: + a = pl.Series("a", [[1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1, 1], [1, 1]]) + + df = pl.DataFrame([a, b]) + + with pytest.raises(ShapeError): + df.select(x=pl.col("a") + pl.col("b")) + + # Test masking and slicing + a = pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]]) + c = pl.Series("c", [1, 1, 1, 1]) + p = pl.Series("p", [True, True, False, False]) + + df = pl.DataFrame([a, b, c, p]).filter("p").slice(1) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2]]) + ) + + df = df.vstack(df) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]]) + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_add_empty_lists( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + l = pl.Series( # noqa: E741 + "x", + [[[[]], []], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], []], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + l = pl.Series( # noqa: E741 + "x", + [[[[]], None], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_to_list_arithmetic_double_nesting_raises_error( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series(dtype=pl.List(pl.List(pl.Int32))) + + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + exec_op(s, s, operator.add) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_height_mismatch( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32)) + + # TODO: Make the error type consistently a ShapeError + with pytest.raises( + (ShapeError, InvalidOperationError), + match="length", + ): + exec_op(s, pl.Series([1, 1]), operator.add) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.mod, + operator.truediv, + ], +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_date_to_numeric_arithmetic_raises_error( + op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series] +) -> None: + l = pl.Series([1], dtype=pl.Date) # noqa: E741 + r = pl.Series([[1]], dtype=pl.List(pl.Int32)) + + exec_op(l.to_physical(), r, op) + + # TODO(_): Ideally this always raises InvalidOperationError. The TypeError + # is being raised by checks on the Python side that should be moved to Rust. + with pytest.raises((InvalidOperationError, TypeError)): + exec_op(l, r, op) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), + ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), + ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("a", "uint8"), + ), + ], +) +def test_list_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2], [3]]), + pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), + pl.Series("nested", [[[1, 2]], [[3]]]), + pl.Series( + "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) + ), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), + ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), + ], +) +def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None]) + + # Different list length: + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]]) + + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) + + # Different nesting: + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + # All 5 arithmetic operations: + ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")), + ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")), + ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")), + ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")), + ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")), + # Different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("list", "uint8"), + ), + # Extra nesting + different types: + ( + [[[3, 4]], [[8]]], + lambda a, b: a + b, + ("nested", "int64"), + ), + # Primitive numeric on the left; only addition and multiplication are + # supported: + ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")), + ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")), + # Primitive numeric on the left with different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("uint8", "list"), + ), + ( + [[2, 4], [12]], + lambda a, b: a * b, + ("uint8", "list"), + ), + ], +) +def test_list_and_numeric_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("list", [[1, 2], [3]]), + pl.Series("int64", [2, 3], dtype=pl.Int64()), + pl.Series("uint8", [2, 4], dtype=pl.UInt8()), + pl.Series("nested", [[[1, 2]], [[5]]]), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + # Null on numeric on the right: + ([[1, 2], [3]], [1, None], [[2, 3], [None]]), + # Null on list on the left: + ([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]), + # Extra nesting: + ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]), + ], +) +def test_list_and_numeric_arithmetic_nulls( + a: list[Any], b: list[Any], expected: list[Any] +) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected, dtype=series_a.dtype) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + # Swap sides: + assert_series_equal(series_b + series_a, series_expected) + assert_series_equal( + series_b._recursive_cast_to_dtype(pl.Int32()) + + series_a._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_and_numeric_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2]) + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + _ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"]) + + +@pytest.mark.parametrize("broadcast", [True, False]) +@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()]) +def test_list_arithmetic_div_ops_zero_denominator( + broadcast: bool, dtype: pl.DataType +) -> None: + # Notes + # * truediv (/) on integers upcasts to Float64 + # * Otherwise, we test floordiv (//) and module/rem (%) + # * On integers, 0-denominator is expected to output NULL + # * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending + # on a few factors (e.g. whether the numerator is also 0). + + s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype)) + + n = 1 if broadcast else s.len() + + # list<->primitive + + # truediv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([1]).new_from_index(0, n), + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + # floordiv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([1]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + # list<->list + + # truediv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([[0]]).new_from_index(0, n), + pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=pl.List(pl.Float64) + ), + ) + + # floordiv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + +def test_list_to_primitive_arithmetic() -> None: + # Input data + # * List type: List(List(List(Int16))) (triple-nested) + # * Numeric type: Int32 + # + # Tests run + # Broadcast Operation + # | L | R | + # * list<->primitive | | | floor_div + # * primitive<->list | | | floor_div + # * list<->primitive | | * | subtract + # * primitive<->list | * | | subtract + # * list<->primitive | * | | subtract + # * primitive<->list | | * | subtract + # + # Notes + # * In floor_div, we check that results from a 0 denominator are masked out + # * We choose floor_div and subtract as they emit different results when + # sides are swapped + + # Create some non-zero start offsets and masked out rows. + lhs = ( + pl.Series( + [ + [[[None, None, None, None, None]]], # sliced out + # Nulls at every level XO + [[[3, 7]], [[-3], [None], [], [], None], [], None], + [[[1, 2, 3, 4, 5]]], # masked out + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ) + .slice(1) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first())) + .to_series() + ) + + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[-3], [None], [], [], None], [], None], + None, + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + class _: + # Floor div, no broadcasting + rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32) + + assert len(lhs) == len(rhs) + + expect = pl.Series( + [ + [[[0, 1]], [[-1], [None], [], [], None], [], None], + None, + [[[None, None]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( + pl.select(l=lhs, r=rhs) + .select(pl.col("l") // pl.col("r")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[1, 0]], [[-2], [None], [], [], None], [], None], + None, + [[[0, 0]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( # noqa: PIE794 + pl.select(l=lhs, r=rhs) + .select(pl.col("r") // pl.col("l")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + class _: # type: ignore[no-redef] + # Subtraction with broadcasting + rhs = pl.Series([1], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-4], [None], [], [], None], [], None], + None, + [[[2, 6]], [[-1], [None], [], [], None]], + [[[2, 6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[4], [None], [], [], None], [], None], + None, + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-2, -6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect) + + # Test broadcasting of the list side + lhs = lhs.slice(2, 1) + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[0], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + assert len(lhs) == 1 + + class _: # type: ignore[no-redef] + rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-1], [None], [], [], None]], + [[[1, 5]], [[-2], [None], [], [], None]], + [[[0, 4]], [[-3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[-2, 2]], [[-5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-1, -5]], [[2], [None], [], [], None]], + [[[0, -4]], [[3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[2, -2]], [[5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect) diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py index 772f4d088249..cffc78c93ef1 100644 --- a/py-polars/tests/unit/operations/map/test_map_groups.py +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -86,6 +86,7 @@ def test_map_groups_none() -> None: pl.map_groups( exprs=["a", pl.col("b") ** 4, pl.col("a") / 4], function=lambda x: x[0] * x[1] + x[2].sum(), + return_dtype=pl.Float64, ).alias("multiple") ) )["multiple"] @@ -127,7 +128,9 @@ def __init__(self, payload: Any) -> None: result = df.group_by("groups").agg( pl.map_groups( - [pl.col("dates"), pl.col("names")], lambda s: Foo(dict(zip(s[0], s[1]))) + [pl.col("dates"), pl.col("names")], + lambda s: Foo(dict(zip(s[0], s[1]))), + return_dtype=pl.Object, ) ) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 966fee3ea5ac..1264a5ed8773 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -7,7 +7,11 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError +from polars.exceptions import ( + ComputeError, + OutOfBoundsError, + SchemaError, +) from polars.testing import assert_frame_equal, assert_series_equal @@ -246,6 +250,16 @@ def test_list_contains_invalid_datatype() -> None: df.select(pl.col("a").list.contains(2)) +def test_list_contains_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.contains + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.contains(3)).to_dict(as_series=False) == { + "a": [False], + "b": [True], + } + + def test_list_concat() -> None: df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]}) @@ -643,6 +657,26 @@ def test_list_to_struct() -> None: {"n": {"one": 0, "two": 1, "three": None}}, ] + q = df.lazy().select( + pl.col("n").list.to_struct(fields=["a", "b"]).struct.field("a") + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": [0, 0]})) + + # Check that: + # * Specifying an upper bound calls the field name getter function to + # retrieve the lazy schema + # * The upper bound is respected during execution + q = df.lazy().select( + pl.col("n").list.to_struct(fields=str, upper_bound=2).struct.unnest() + ) + assert q.collect_schema() == {"0": pl.Int64, "1": pl.Int64} + assert_frame_equal(q.collect(), pl.DataFrame({"0": [0, 0], "1": [1, 1]})) + + assert df.lazy().select(pl.col("n").list.to_struct()).collect_schema() == { + "n": pl.Unknown + } + def test_select_from_list_to_struct_11143() -> None: ldf = pl.LazyFrame({"some_col": [[1.0, 2.0], [1.5, 3.0]]}) @@ -686,6 +720,16 @@ def test_list_count_matches_boolean_nulls_9141() -> None: assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] +def test_list_count_matches_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.count_match + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.count_matches(3)).to_dict(as_series=False) == { + "a": [0], + "b": [1], + } + + def test_list_gather_oob_10079() -> None: df = pl.DataFrame( { @@ -885,3 +929,29 @@ def test_list_get_with_null() -> None: def test_list_sum_bool_schema() -> None: q = pl.LazyFrame({"x": [[True, True, False]]}) assert q.select(pl.col("x").list.sum()).collect_schema()["x"] == pl.UInt32 + + +def test_list_concat_struct_19279() -> None: + df = pl.select( + pl.struct(s=pl.lit("abcd").str.split("").explode(), i=pl.int_range(0, 4)) + ) + df = pl.concat([df[:2], df[-2:]]) + assert df.select(pl.concat_list("s")).to_dict(as_series=False) == { + "s": [ + [{"s": "a", "i": 0}], + [{"s": "b", "i": 1}], + [{"s": "c", "i": 2}], + [{"s": "d", "i": 3}], + ] + } + + +def test_list_eval_element_schema_19345() -> None: + assert_frame_equal( + ( + pl.LazyFrame({"a": [[{"a": 1}]]}) + .select(pl.col("a").list.eval(pl.element().struct.field("a"))) + .collect() + ), + pl.DataFrame({"a": [[1]]}), + ) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index 842b0fd141a5..fcca5c5987b1 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -1006,6 +1006,66 @@ def test_replace_all() -> None: ) +def test_replace_all_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + df2 = pl.DataFrame( + { + "text": ["I found yesterday.", "I lost yesterday."], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert df2.get_column("text2")[1] == "I lost $2 yesterday." + + +def test_replace_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + # A string shorter than 32 chars, + # and one longer than 32 chars to test both sub-paths + df2 = pl.DataFrame( + { + "text": [ + "I found yesterday.", + "I lost yesterday and this string is longer than 32 characters.", + ], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert ( + df2.get_column("text2")[1] + == "I lost $2 yesterday and this string is longer than 32 characters." + ) + + def test_replace_expressions() -> None: df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"], "value": ["A", "B"]}) out = df.select([pl.col("foo").str.replace(pl.col("foo").first(), pl.col("value"))]) @@ -1727,3 +1787,55 @@ def test_extract_many() -> None: assert df.select(pl.col("values").str.extract_many("patterns")).to_dict( as_series=False ) == {"values": [["disco"], ["rhap", "ody"]]} + + +def test_json_decode_raise_on_data_type_mismatch_13061() -> None: + assert_series_equal( + pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1), + pl.Series([None, None]), + ) + + with pytest.raises(ComputeError): + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=1) + + assert_series_equal( + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=2), + pl.Series([None, 1]), + ) + + +def test_json_decode_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra key in struct data: b"): + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=1 + ) + + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=2 + ), + pl.Series([{"a": 1, "b": None}, {"a": 2, "b": 2}]), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + dtype=pl.Struct({"a": pl.Int64}) + ), + pl.Series([{"a": 1}, {"a": 2}]), + ) + + +def test_escape_regex() -> None: + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + result_df = df.with_columns(pl.col("text").str.escape_regex().alias("escaped")) + expected_df = pl.DataFrame( + { + "text": ["abc", "def", None, "abc(\\w+)"], + "escaped": ["abc", "def", None, "abc\\(\\\\w\\+\\)"], + } + ) + + assert_frame_equal(result_df, expected_df) + assert_series_equal(result_df["escaped"], expected_df["escaped"]) diff --git a/py-polars/tests/unit/operations/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py index 789f2a0974d8..7ffda7dbed83 100644 --- a/py-polars/tests/unit/operations/namespaces/test_plot.py +++ b/py-polars/tests/unit/operations/namespaces/test_plot.py @@ -28,11 +28,7 @@ def test_dataframe_plot_tooltip() -> None: } ) result = df.plot.line(x="length", y="width", color="species").to_dict() - assert result["encoding"]["tooltip"] == [ - {"field": "length", "type": "quantitative"}, - {"field": "width", "type": "quantitative"}, - {"field": "species", "type": "nominal"}, - ] + assert result["mark"]["tooltip"] is True result = df.plot.line( x="length", y="width", color="species", tooltip=["length", "width"] ).to_dict() @@ -54,10 +50,7 @@ def test_series_plot() -> None: def test_series_plot_tooltip() -> None: s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) result = s.plot.line().to_dict() - assert result["encoding"]["tooltip"] == [ - {"field": "index", "type": "quantitative"}, - {"field": "a", "type": "quantitative"}, - ] + assert result["mark"]["tooltip"] is True result = s.plot.line(tooltip=["a"]).to_dict() assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}] @@ -73,4 +66,4 @@ def test_nameless_series() -> None: def test_x_with_axis_18830() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict() - assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}] + assert result["mark"]["tooltip"] is True diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index e01912237a19..dca9eeb3e767 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -461,12 +461,6 @@ def test_cast_temporal( "expected_value", ), [ - (str(2**7 - 1).encode(), pl.Binary, pl.Int8, 2**7 - 1), - (str(2**15 - 1).encode(), pl.Binary, pl.Int16, 2**15 - 1), - (str(2**31 - 1).encode(), pl.Binary, pl.Int32, 2**31 - 1), - (str(2**63 - 1).encode(), pl.Binary, pl.Int64, 2**63 - 1), - (b"1.0", pl.Binary, pl.Float32, 1.0), - (b"1.0", pl.Binary, pl.Float64, 1.0), (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1), (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1), (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1), @@ -478,13 +472,9 @@ def test_cast_temporal( (str(2**15), pl.String, pl.Int16, None), (str(2**31), pl.String, pl.Int32, None), (str(2**63), pl.String, pl.Int64, None), - (str(2**7).encode(), pl.Binary, pl.Int8, None), - (str(2**15).encode(), pl.Binary, pl.Int16, None), - (str(2**31).encode(), pl.Binary, pl.Int32, None), - (str(2**63).encode(), pl.Binary, pl.Int64, None), ], ) -def test_cast_string_and_binary( +def test_cast_string( value: int, from_dtype: PolarsDataType, to_dtype: PolarsDataType, @@ -522,12 +512,6 @@ def test_cast_string_and_binary( "expected_value", ), [ - (str(2**7 - 1).encode(), pl.Binary, pl.Int8, True, 2**7 - 1), - (str(2**15 - 1).encode(), pl.Binary, pl.Int16, True, 2**15 - 1), - (str(2**31 - 1).encode(), pl.Binary, pl.Int32, True, 2**31 - 1), - (str(2**63 - 1).encode(), pl.Binary, pl.Int64, True, 2**63 - 1), - (b"1.0", pl.Binary, pl.Float32, True, 1.0), - (b"1.0", pl.Binary, pl.Float64, True, 1.0), (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1), (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1), (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1), @@ -539,13 +523,9 @@ def test_cast_string_and_binary( (str(2**15), pl.String, pl.Int16, False, None), (str(2**31), pl.String, pl.Int32, False, None), (str(2**63), pl.String, pl.Int64, False, None), - (str(2**7).encode(), pl.Binary, pl.Int8, False, None), - (str(2**15).encode(), pl.Binary, pl.Int16, False, None), - (str(2**31).encode(), pl.Binary, pl.Int32, False, None), - (str(2**63).encode(), pl.Binary, pl.Int64, False, None), ], ) -def test_strict_cast_string_and_binary( +def test_strict_cast_string( value: int, from_dtype: PolarsDataType, to_dtype: PolarsDataType, @@ -692,3 +672,9 @@ def test_cast_consistency() -> None: assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String) ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]} + + +def test_cast_int_to_string_unsets_sorted_flag_19424() -> None: + s = pl.Series([1, 2]).set_sorted() + assert s.flags["SORTED_ASC"] + assert not s.cast(pl.String).flags["SORTED_ASC"] diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 3807a6b29ef5..14aefa93c3c1 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -405,14 +405,14 @@ def test_fast_explode_merge_left_16923() -> None: @pytest.mark.parametrize( ("values", "exploded"), [ - (["foobar", None], ["f", "o", "o", "b", "a", "r", ""]), - ([None, "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), + (["foobar", None], ["f", "o", "o", "b", "a", "r", None]), + ([None, "foo", "bar"], [None, "f", "o", "o", "b", "a", "r"]), ( [None, "foo", "bar", None, "ham"], - ["", "f", "o", "o", "b", "a", "r", "", "h", "a", "m"], + [None, "f", "o", "o", "b", "a", "r", None, "h", "a", "m"], ), (["foo", "bar", "ham"], ["f", "o", "o", "b", "a", "r", "h", "a", "m"]), - (["", None, "foo", "bar"], ["", "", "f", "o", "o", "b", "a", "r"]), + (["", None, "foo", "bar"], ["", None, "f", "o", "o", "b", "a", "r"]), (["", "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), ], ) @@ -421,9 +421,6 @@ def test_series_str_explode_deprecated( ) -> None: with pytest.deprecated_call(): result = pl.Series(values).str.explode() - if result.to_list() != exploded: - print(result.to_list()) - print(exploded) assert result.to_list() == exploded diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index d56e5d80858d..4c74f60a7ff9 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import polars as pl @@ -156,3 +157,47 @@ def test_gather_str_col_18099() -> None: "foo": [1, 1, 2], "idx": [0, 0, 1], } + + +def test_gather_list_19243() -> None: + df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]}) + assert df.with_columns(pl.lit([0]).alias("c")).with_columns( + gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True) + ).to_dict(as_series=False) == { + "a": [[0.1, 0.2, 0.3]], + "c": [[0]], + "gather": [[0.1]], + } + + +def test_gather_array_list_null_19302() -> None: + data = pl.DataFrame( + {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))} + ) + assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == { + "data": [None] + } + + +def test_gather_array() -> None: + a = np.arange(16).reshape(-1, 2, 2) + s = pl.Series(a) + + for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]: + assert (s.gather(idx).to_numpy() == a[idx]).all() + + v = s[[0, 1, None, 3]] # type: ignore[list-item] + assert v[2] is None + + +def test_gather_array_outer_validity_19482() -> None: + s = ( + pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1)) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first())) + .to_series() + ) + + expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1)) + assert_series_equal(s, expect) + assert_series_equal(s.gather([0, 1]), expect) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index af9dc9a180e2..cff43b43274c 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical, "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: @@ -1146,3 +1146,10 @@ def test_positional_by_with_list_or_tuple_17540() -> None: pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"]) with pytest.raises(TypeError, match="Hint: if you"): pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"]) + + +def test_group_by_agg_19173() -> None: + df = pl.DataFrame({"x": [1.0], "g": [0]}) + out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2) + assert out.to_dict(as_series=False) == {"g": [], "x": []} + assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))]) diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 872361197a8d..891ac32fa0ba 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -594,3 +594,17 @@ def test_join_on_strings() -> None: "a_right": ["a", "a", "b", "a", "b", "c"], "b_right": ["b", "b", "b", "b", "b", "b"], } + + +def test_join_partial_column_name_overlap_19119() -> None: + left = pl.LazyFrame({"a": [1], "b": [2]}) + right = pl.LazyFrame({"a": [2], "d": [0]}) + + q = left.join_where(right, pl.col("a") > pl.col("d")) + + assert q.collect().to_dict(as_series=False) == { + "a": [1], + "b": [2], + "a_right": [2], + "d": [0], + } diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 68220cf81551..87ad4dfc5a9b 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -1082,3 +1082,20 @@ def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None: msg = "cross join should not pass join keys" with pytest.raises(ValueError, match=msg): df1.join(df2, how="cross", **on_args) # type: ignore[arg-type] + + +@pytest.mark.parametrize("set_sorted", [True, False]) +def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None: + left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]}) + right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]}) + + if set_sorted: + # The data isn't actually sorted on purpose to ensure we default to a + # hash join unless we set the sorted flag here, in case there is new + # code in the future that automatically identifies sortedness during + # Series construction from Python. + left = left.set_sorted("k") + right = right.set_sorted("k") + + q = left.join(right, on="k", how="left").head(5) + assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]})) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 57dbec1a13ee..b19d008556e1 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -413,6 +413,7 @@ def test_sort_by_in_over_5499() -> None: } +@pytest.mark.may_fail_auto_streaming def test_merge_sorted() -> None: df_a = ( pl.datetime_range( diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py index ada642c294ae..434c2fdc3af9 100644 --- a/py-polars/tests/unit/operations/test_unpivot.py +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -2,6 +2,7 @@ import polars as pl import polars.selectors as cs +from polars import StringCache from polars.testing import assert_frame_equal @@ -94,3 +95,21 @@ def test_unpivot_empty_18170() -> None: assert pl.DataFrame().unpivot().schema == pl.Schema( {"variable": pl.String(), "value": pl.Null()} ) + + +@StringCache() +def test_unpivot_categorical_global() -> None: + df = pl.DataFrame( + { + "index": [0, 1], + "1": pl.Series(["a", "b"], dtype=pl.Categorical), + "2": pl.Series(["b", "c"], dtype=pl.Categorical), + } + ) + out = df.unpivot(["1", "2"], index="index") + assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="physical")] + assert out.to_dict(as_series=False) == { + "index": [0, 1, 0, 1], + "variable": ["1", "1", "2", "2"], + "value": ["a", "b", "b", "c"], + } diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py index 479a52ca2f9a..b50edd981e5f 100644 --- a/py-polars/tests/unit/operations/unique/test_unique.py +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -154,3 +154,11 @@ def test_unique_with_null() -> None: {"a": [1, 2, 3, 4], "b": ["a", "b", "c", "c"], "c": [None, None, None, None]} ) assert_frame_equal(df.unique(maintain_order=True), expected_df) + + +def test_categorical_unique_19409() -> None: + df = pl.DataFrame({"x": [str(n % 50) for n in range(127)]}).cast(pl.Categorical) + uniq = df.unique() + assert uniq.height == 50 + assert uniq.null_count().item() == 0 + assert set(uniq["x"]) == set(df["x"]) diff --git a/py-polars/tests/unit/series/test_equals.py b/py-polars/tests/unit/series/test_equals.py index 989554656253..e60f9b9428bd 100644 --- a/py-polars/tests/unit/series/test_equals.py +++ b/py-polars/tests/unit/series/test_equals.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Callable import pytest @@ -105,3 +106,194 @@ def test_series_equals_strict_deprecated() -> None: s2 = pl.Series("a", [1, 2, None], pl.Int64) with pytest.deprecated_call(): assert not s1.equals(s2, strict=True) # type: ignore[call-arg] + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)]) +@pytest.mark.parametrize( + ("cmp_eq", "cmp_ne"), + [ + # We parametrize the comparison sides as the impl looks like this: + # match (left.len(), right.len()) { + # (1, _) => ..., + # (_, 1) => ..., + # (_, _) => ..., + # } + (pl.Series.eq, pl.Series.ne), + ( + lambda a, b: pl.Series.eq(b, a), + lambda a, b: pl.Series.ne(b, a), + ), + ], +) +def test_eq_lists_arrays( + dtype: pl.DataType, + cmp_eq: Callable[[pl.Series, pl.Series], pl.Series], + cmp_ne: Callable[[pl.Series, pl.Series], pl.Series], +) -> None: + # Broadcast NULL + assert_series_equal( + cmp_eq( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + # Non-broadcast full-NULL + assert_series_equal( + cmp_eq( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + # Broadcast valid + assert_series_equal( + cmp_eq( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, True, False], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, False, True], dtype=pl.Boolean), + ) + + # Non-broadcast mixed + assert_series_equal( + cmp_eq( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, False, True], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, True, False], dtype=pl.Boolean), + ) + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)]) +@pytest.mark.parametrize( + ("cmp_eq_missing", "cmp_ne_missing"), + [ + (pl.Series.eq_missing, pl.Series.ne_missing), + ( + lambda a, b: pl.Series.eq_missing(b, a), + lambda a, b: pl.Series.ne_missing(b, a), + ), + ], +) +def test_eq_missing_lists_arrays_19153( + dtype: pl.DataType, + cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series], + cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series], +) -> None: + def assert_series_equal( + left: pl.Series, + right: pl.Series, + *, + assert_series_equal_impl: Callable[[pl.Series, pl.Series], None] = globals()[ + "assert_series_equal" + ], + ) -> None: + # `assert_series_equal` also uses `ne_missing` underneath so we have + # some extra checks here to be sure. + assert_series_equal_impl(left, right) + assert left.to_list() == right.to_list() + assert left.null_count() == 0 + assert right.null_count() == 0 + + # Broadcast NULL + assert_series_equal( + cmp_eq_missing( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, True]), + ) + + # Non-broadcast full-NULL + assert_series_equal( + cmp_eq_missing( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, True]), + ) + + # Broadcast valid + assert_series_equal( + cmp_eq_missing( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, True]), + ) + + # Non-broadcast mixed + assert_series_equal( + cmp_eq_missing( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, True]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, False]), + ) diff --git a/py-polars/tests/unit/series/test_scatter.py b/py-polars/tests/unit/series/test_scatter.py index c3c0b38d6805..95e4aa5b31e3 100644 --- a/py-polars/tests/unit/series/test_scatter.py +++ b/py-polars/tests/unit/series/test_scatter.py @@ -43,7 +43,7 @@ def test_scatter() -> None: assert s.to_list() == ["a", "x", "x"] assert s.scatter([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"] - # set multiple values values + # set multiple values s = pl.Series(["z", "z", "z"]) assert s.scatter([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"] s = pl.Series([True, False, True]) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index f1a858f62add..0ed478b9aa83 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -23,6 +23,7 @@ Unknown, ) from polars.exceptions import ( + DuplicateError, InvalidOperationError, PolarsInefficientMapWarning, ShapeError, @@ -1356,6 +1357,13 @@ def test_to_dummies_drop_first() -> None: assert_frame_equal(result, expected) +def test_to_dummies_null_clash_19096() -> None: + with pytest.raises( + DuplicateError, match="column with name '_null' has more than one occurrence" + ): + pl.Series([None, "null"]).to_dummies() + + def test_chunk_lengths() -> None: s = pl.Series("a", [1, 2, 2, 3]) # this is a Series with one chunk, of length 4 diff --git a/py-polars/tests/unit/sql/test_bitwise.py b/py-polars/tests/unit/sql/test_bitwise.py new file mode 100644 index 000000000000..7bba8cfde5c5 --- /dev/null +++ b/py-polars/tests/unit/sql/test_bitwise.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [20, 32, 50, 88, 128], + "y": [-128, 0, 10, -1, None], + } + ) + + +def test_bitwise_and(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x & y AS x_bitand_op_y, + BITAND(y, x) AS y_bitand_x, + BIT_AND(x, y) AS x_bitand_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitand_op_y": [0, 0, 2, 88, None], + "y_bitand_x": [0, 0, 2, 88, None], + "x_bitand_y": [0, 0, 2, 88, None], + } + + +def test_bitwise_count(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + BITCOUNT(x) AS x_bits_set, + BIT_COUNT(y) AS y_bits_set, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bits_set": [2, 1, 3, 3, 1], + "y_bits_set": [57, 0, 2, 64, None], + } + + +def test_bitwise_or(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x | y AS x_bitor_op_y, + BITOR(y, x) AS y_bitor_x, + BIT_OR(x, y) AS x_bitor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitor_op_y": [-108, 32, 58, -1, None], + "y_bitor_x": [-108, 32, 58, -1, None], + "x_bitor_y": [-108, 32, 58, -1, None], + } + + +def test_bitwise_xor(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x XOR y AS x_bitxor_op_y, + BITXOR(y, x) AS y_bitxor_x, + BIT_XOR(x, y) AS x_bitxor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitxor_op_y": [-108, 32, 56, -89, None], + "y_bitxor_x": [-108, 32, 56, -89, None], + "x_bitxor_y": [-108, 32, 56, -89, None], + } diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 08e4b236c833..71fa1572831c 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -238,3 +238,9 @@ def test_group_by_errors() -> None: match=r"'a' should participate in the GROUP BY clause or an aggregate function", ): df.sql("SELECT a, SUM(b) FROM self GROUP BY b") + + with pytest.raises( + SQLSyntaxError, + match=r"HAVING clause not valid outside of GROUP BY", + ): + df.sql("SELECT a, COUNT(a) AS n FROM self HAVING n > 1") diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 95ba8461bebe..f7d0615e13c6 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -7,7 +7,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal if TYPE_CHECKING: @@ -362,3 +362,26 @@ def test_global_variable_inference_17398() -> None: eager=True, ) assert_frame_equal(res, users) + + +@pytest.mark.parametrize( + "query", + [ + "SELECT invalid_column FROM self", + "SELECT key, invalid_column FROM self", + "SELECT invalid_column * 2 FROM self", + "SELECT * FROM self ORDER BY invalid_column", + "SELECT * FROM self WHERE invalid_column = 200", + "SELECT * FROM self WHERE invalid_column = '200'", + "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column", + ], +) +def test_invalid_cols(query: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "xx", "yy"], + "n": ["100", "200", "300"], + } + ) + with pytest.raises(ColumnNotFoundError, match="invalid_column"): + df.sql(query) diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index 80730273672e..225c0b97553c 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -19,6 +19,7 @@ pytestmark = pytest.mark.xdist_group("streaming") +@pytest.mark.may_fail_auto_streaming def test_streaming_categoricals_5921() -> None: with pl.StringCache(): out_lazy = ( @@ -74,6 +75,7 @@ def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: @pytest.mark.slow +@pytest.mark.may_fail_auto_streaming def test_cross_join_stack() -> None: a = pl.Series(np.arange(100_000)).to_frame().lazy() t0 = time.time() diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py index cd0491b5435f..02397afcf9d4 100644 --- a/py-polars/tests/unit/test_config.py +++ b/py-polars/tests/unit/test_config.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from textwrap import dedent from typing import TYPE_CHECKING, Any import pytest @@ -497,6 +498,16 @@ def test_shape_format_for_big_numbers() -> None: "╰─────────┴───╯" ) + pl.Config.set_tbl_formatting("ASCII_FULL_CONDENSED") + assert ( + str(df) == "shape: (1, 1_000)\n" + "+---------+-----+\n" + "| 0 (i64) | ... |\n" + "+===============+\n" + "| 1 | ... |\n" + "+---------+-----+" + ) + def test_numeric_right_alignment() -> None: pl.Config.set_tbl_cell_numeric_alignment("RIGHT") @@ -739,7 +750,7 @@ def test_config_scope() -> None: def test_config_raise_error_if_not_exist() -> None: - with pytest.raises(AttributeError), pl.Config(i_do_not_exist=True): + with pytest.raises(AttributeError), pl.Config(i_do_not_exist=True): # type: ignore[call-arg] pass @@ -771,6 +782,79 @@ def test_set_fmt_str_lengths_invalid_length() -> None: cfg.set_fmt_str_lengths(-2) +def test_truncated_rows_cols_values_ascii() -> None: + df = pl.DataFrame({f"c{n}": list(range(-n, 100 - n)) for n in range(10)}) + + pl.Config.set_tbl_formatting("UTF8_BORDERS_ONLY", rounded_corners=True) + assert ( + str(df) == "shape: (100, 10)\n" + "╭───────────────────────────────────────────────────╮\n" + "│ c0 c1 c2 c3 … c6 c7 c8 c9 │\n" + "│ --- --- --- --- --- --- --- --- │\n" + "│ i64 i64 i64 i64 i64 i64 i64 i64 │\n" + "╞═══════════════════════════════════════════════════╡\n" + "│ 0 -1 -2 -3 … -6 -7 -8 -9 │\n" + "│ 1 0 -1 -2 … -5 -6 -7 -8 │\n" + "│ 2 1 0 -1 … -4 -5 -6 -7 │\n" + "│ 3 2 1 0 … -3 -4 -5 -6 │\n" + "│ 4 3 2 1 … -2 -3 -4 -5 │\n" + "│ … … … … … … … … … │\n" + "│ 95 94 93 92 … 89 88 87 86 │\n" + "│ 96 95 94 93 … 90 89 88 87 │\n" + "│ 97 96 95 94 … 91 90 89 88 │\n" + "│ 98 97 96 95 … 92 91 90 89 │\n" + "│ 99 98 97 96 … 93 92 91 90 │\n" + "╰───────────────────────────────────────────────────╯" + ) + with pl.Config(tbl_formatting="ASCII_FULL_CONDENSED"): + assert ( + str(df) == "shape: (100, 10)\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n" + "| c0 | c1 | c2 | c3 | ... | c6 | c7 | c8 | c9 |\n" + "| --- | --- | --- | --- | | --- | --- | --- | --- |\n" + "| i64 | i64 | i64 | i64 | | i64 | i64 | i64 | i64 |\n" + "+=====================================================+\n" + "| 0 | -1 | -2 | -3 | ... | -6 | -7 | -8 | -9 |\n" + "| 1 | 0 | -1 | -2 | ... | -5 | -6 | -7 | -8 |\n" + "| 2 | 1 | 0 | -1 | ... | -4 | -5 | -6 | -7 |\n" + "| 3 | 2 | 1 | 0 | ... | -3 | -4 | -5 | -6 |\n" + "| 4 | 3 | 2 | 1 | ... | -2 | -3 | -4 | -5 |\n" + "| ... | ... | ... | ... | ... | ... | ... | ... | ... |\n" + "| 95 | 94 | 93 | 92 | ... | 89 | 88 | 87 | 86 |\n" + "| 96 | 95 | 94 | 93 | ... | 90 | 89 | 88 | 87 |\n" + "| 97 | 96 | 95 | 94 | ... | 91 | 90 | 89 | 88 |\n" + "| 98 | 97 | 96 | 95 | ... | 92 | 91 | 90 | 89 |\n" + "| 99 | 98 | 97 | 96 | ... | 93 | 92 | 91 | 90 |\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+" + ) + + with pl.Config(tbl_formatting="MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |---------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0… |""").lstrip() + ) + + with pl.Config(tbl_formatting="ASCII_MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |-----------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0... |""").lstrip() + ) + + def test_warn_unstable(recwarn: pytest.WarningsRecorder) -> None: issue_unstable_warning() assert len(recwarn) == 0 diff --git a/py-polars/tests/unit/test_cpu_check.py b/py-polars/tests/unit/test_cpu_check.py index fdfa5965f6ff..efc72fe4bc0e 100644 --- a/py-polars/tests/unit/test_cpu_check.py +++ b/py-polars/tests/unit/test_cpu_check.py @@ -59,18 +59,6 @@ def test_check_cpu_flags_skipped_no_flags(monkeypatch: pytest.MonkeyPatch) -> No assert mock_read_cpu_flags.call_count == 0 -@pytest.mark.usefixtures("_feature_flags") -def test_check_cpu_flags_skipped_lts_cpu(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(_cpu_check, "_POLARS_LTS_CPU", True) - - mock_read_cpu_flags = Mock() - monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) - - check_cpu_flags() - - assert mock_read_cpu_flags.call_count == 0 - - @pytest.mark.usefixtures("_feature_flags") def test_check_cpu_flags_skipped_env_var(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("POLARS_SKIP_CPU_CHECK", "1") diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index d6fdb8976c5b..5ac102efe5cb 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -788,3 +788,34 @@ def test_eager_cse_during_struct_expansion_18411() -> None: df.select(pl.col("foo").replace(classes, counts)) == df.select(pl.col("foo").replace(classes, counts)) )["foo"].all() + + +def test_cse_skip_as_struct_19253() -> None: + df = pl.LazyFrame({"x": [1, 2], "y": [4, 5]}) + + assert ( + df.with_columns( + q1=pl.struct(pl.col.x - pl.col.y.mean()), + q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")), + ).collect() + ).to_dict(as_series=False) == { + "x": [1, 2], + "y": [4, 5], + "q1": [{"x": -3.5}, {"x": -2.5}], + "q2": [{"x": -3.0}, {"x": -3.0}], + } + + +def test_cse_union_19227() -> None: + lf = pl.LazyFrame({"A": [1], "B": [2]}) + lf_1 = lf.select(C="A", B="B") + lf_2 = lf.select(C="A", A="B") + + direct = lf_2.join(lf, on=["A"]).select("C", "A", "B") + + indirect = lf_1.join(direct, on=["C", "B"]).select("C", "A", "B") + + out = pl.concat([direct, indirect]) + assert out.collect().schema == pl.Schema( + [("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)] + ) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index c730ee8d30a7..0ca504d20e20 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -16,7 +16,6 @@ ComputeError, InvalidOperationError, OutOfBoundsError, - PanicException, SchemaError, SchemaFieldNotFoundError, ShapeError, @@ -116,7 +115,7 @@ def test_string_numeric_comp_err() -> None: def test_panic_error() -> None: with pytest.raises( - PanicException, + InvalidOperationError, match="unit: 'k' not supported", ): pl.datetime_range( @@ -696,7 +695,7 @@ def test_no_panic_pandas_nat() -> None: def test_list_to_struct_invalid_type() -> None: - with pytest.raises(pl.exceptions.SchemaError): + with pytest.raises(pl.exceptions.InvalidOperationError): pl.DataFrame({"a": 1}).select(pl.col("a").list.to_struct()) @@ -708,3 +707,15 @@ def test_raise_invalid_agg() -> None: .group_by("index") .agg(pl.col("foo").filter(pl.col("i_do_not_exist"))) ).collect() + + +def test_err_mean_horizontal_lists() -> None: + df = pl.DataFrame( + { + "experiment_id": [1, 2], + "sensor1": [[1, 2, 3], [7, 8, 9]], + "sensor2": [[4, 5, 6], [10, 11, 12]], + } + ) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns(pl.mean_horizontal("sensor1", "sensor2").alias("avg_sensor")) diff --git a/py-polars/tests/unit/test_expansion.py b/py-polars/tests/unit/test_expansion.py index 795d5511cc50..b01fc625f5cc 100644 --- a/py-polars/tests/unit/test_expansion.py +++ b/py-polars/tests/unit/test_expansion.py @@ -138,6 +138,20 @@ def test_struct_field_expand_star() -> None: assert_frame_equal(struct_df.select(pl.col("struct_col").struct.field("*")), df) +def test_struct_unnest() -> None: + """Same as test_struct_field_expand_star but using the unnest alias.""" + df = pl.DataFrame( + { + "aaa": [1, 2], + "bbb": ["ab", "cd"], + "ccc": [True, None], + "ddd": [[1, 2], [3]], + } + ) + struct_df = df.select(pl.struct(["aaa", "bbb", "ccc", "ddd"]).alias("struct_col")) + assert_frame_equal(struct_df.select(pl.col("struct_col").struct.unnest()), df) + + def test_struct_field_expand_rewrite() -> None: df = pl.DataFrame({"A": [1], "B": [2]}) assert df.select( diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 48b3077423e0..e075fdb6acc2 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -582,3 +582,33 @@ def test_projections_collapse_17781() -> None: else: lf = lf.join(lfj, on="index", how="left") assert "SELECT " not in lf.explain() # type: ignore[union-attr] + + +def test_with_columns_projection_pushdown() -> None: + # # Summary + # `process_hstack` in projection PD incorrectly took a fast-path meant for + # LP nodes that don't add new columns to the schema, which stops projection + # PD if it sees that the schema lengths on the upper node matches. + # + # To trigger this, we drop the same number of columns before and after + # the with_columns, and in the with_columns we also add the same number of + # columns. + lf = ( + pl.scan_csv( + b"""\ +a,b,c,d,e +1,1,1,1,1 +""", + include_file_paths="path", + ) + .drop("a", "b") + .with_columns(pl.lit(1).alias(x) for x in ["x", "y"]) + .drop("c", "d") + ) + + plan = lf.explain().strip() + + assert plan.startswith("WITH_COLUMNS:") + # [dyn int: 1.alias("x"), dyn int: 1.alias("y")] + # Csv SCAN [20 in-mem bytes] + assert plan.endswith("PROJECT 1/6 COLUMNS") diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 7b6a1583ef11..9c5848382f42 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,6 +1,8 @@ import pickle from datetime import datetime +import pytest + import polars as pl @@ -13,14 +15,40 @@ def test_schema() -> None: assert s.names() == ["foo", "bar"] assert s.dtypes() == [pl.Int8(), pl.String()] + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) + + +def test_schema_equality() -> None: + s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) + s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) + + assert s1 == s1 + assert s2 == s2 + assert s3 == s3 + assert s1 != s2 + assert s1 != s3 + assert s2 != s3 + + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 -def test_schema_parse_nonpolars_dtypes() -> None: + +def test_schema_parse_python_dtypes() -> None: cardinal_directions = pl.Enum(["north", "south", "east", "west"]) - s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] s["ham"] = datetime - assert s["foo"] == pl.List + assert s["foo"] == pl.List(pl.Int32) assert s["bar"] == pl.Int64 assert s["baz"] == cardinal_directions assert s["ham"] == pl.Datetime("us") @@ -33,19 +61,6 @@ def test_schema_parse_nonpolars_dtypes() -> None: assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] -def test_schema_equality() -> None: - s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) - s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) - s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) - - assert s1 == s1 - assert s2 == s2 - assert s3 == s3 - assert s1 != s2 - assert s1 != s3 - assert s2 != s3 - - def test_schema_picklable() -> None: s = pl.Schema( { @@ -88,7 +103,6 @@ def test_schema_in_map_elements_returns_scalar() -> None: "amounts": [100.0, -110.0] * 2, } ) - q = ldf.group_by("portfolio").agg( pl.col("amounts") .map_elements( @@ -112,7 +126,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) ) - assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index bf44ff87bac5..dd2c415c9a13 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -515,7 +515,7 @@ def test_selector_temporal(df: pl.DataFrame) -> None: assert df.select(cs.temporal()).schema == { "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), } all_columns = set(df.columns) @@ -611,7 +611,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: "eee": pl.Boolean, "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), "qqR": pl.String, } @@ -629,7 +629,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) @@ -639,7 +639,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: ).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 179cd9e58d86..90221c3e2edb 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-09-29" +channel = "nightly-2024-10-28"