diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 95545adb4c4e..a99a2770491c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,6 +4,7 @@ /crates/polars-sql/ @ritchie46 @orlp @c-peters @alexander-beedie /crates/polars-parquet/ @ritchie46 @orlp @c-peters @coastalwhite /crates/polars-time/ @ritchie46 @orlp @c-peters @MarcoGorelli +/crates/polars-python/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa +/crates/polars-python/src/lazyframe/visit.rs @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- +/crates/polars-python/src/lazyframe/visitor/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- /py-polars/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa -/py-polars/src/lazyframe/visit.rs @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- -/py-polars/src/lazyframe/visitor/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- diff --git a/.github/workflows/benchmark-remote.yml b/.github/workflows/benchmark-remote.yml index f21898f9689f..ab7d829897a2 100644 --- a/.github/workflows/benchmark-remote.yml +++ b/.github/workflows/benchmark-remote.yml @@ -2,19 +2,22 @@ name: Remote Benchmark on: workflow_dispatch: + push: + branches: + - 'main' pull_request: types: [ labeled ] concurrency: group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + cancel-in-progress: ${{ github.event.label.name == 'needs-bench' }} env: SCALE_FACTOR: '10.0' jobs: main: - if: ${{ github.event.label.name == 'needs-bench' }} + if: ${{ github.ref == 'refs/heads/main' || github.event.label.name == 'needs-bench' }} runs-on: self-hosted steps: - uses: actions/checkout@v4 @@ -63,3 +66,9 @@ jobs: working-directory: polars-benchmark run: | make run-polars-no-env + + - name: Cache the Polars build + if: ${{ github.ref == 'refs/heads/main' }} + working-directory: py-polars + run: | + "$HOME/py-polars-cache/cache-build.sh" "$PWD/polars" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4da3d3bad8e3..afa9219231a3 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -71,7 +71,7 @@ jobs: env: RUSTFLAGS: -C embed-bitcode -D warnings working-directory: py-polars - run: maturin develop --features new_streaming --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + run: maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native - name: Run benchmark tests uses: CodSpeedHQ/action@v3 diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index 06b7f096524e..4981710d4773 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -194,7 +194,7 @@ jobs: command: build target: ${{ steps.target.outputs.target }} args: > - --release + --profile dist-release --manifest-path py-polars/Cargo.toml --out dist manylinux: ${{ matrix.architecture == 'aarch64' && '2_24' || 'auto' }} diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 6185b905f7c7..a117f5b7fe96 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -82,7 +82,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Install Polars - run: maturin develop --features new_streaming + run: maturin develop - name: Run doctests if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' diff --git a/Cargo.lock b/Cargo.lock index 7d308d3988ae..967c4f3562cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 4 [[package]] name = "addr2line" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -30,7 +30,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "const-random", "getrandom", "once_cell", "version_check", @@ -90,15 +89,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "apache-avro" @@ -168,51 +167,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "arrow-array" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd2bf348cf9f02a5975c5962c7fa6dee107a2009a7b41ac5fb1a027e12dc033f" -dependencies = [ - "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "hashbrown 0.14.5", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3092e37715f168976012ce52273c3989b5793b0db5f06cbaa246be25e5f0924d" -dependencies = [ - "bytes", - "half", - "num", -] - -[[package]] -name = "arrow-data" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4ac0c4ee79150afe067dc4857154b3ee9c1cd52b5f40d59a77306d0ed18d65" -dependencies = [ - "arrow-buffer", - "arrow-schema", - "half", - "num", -] - -[[package]] -name = "arrow-schema" -version = "53.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85320a3a2facf2b2822b57aa9d6d9d55edb8aee0b6b5d3b8df158e503d10858" - [[package]] name = "arrow2" version = "0.17.4" @@ -252,7 +206,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -263,7 +217,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -311,9 +265,9 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.5.7" +version = "1.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8191fb3091fa0561d1379ef80333c3c7191c6f0435d986e85821bcf7acbd1126" +checksum = "2d6448cfb224dd6a9b9ac734f58622dd0d4751f3589f3b777345745f46b2eb14" dependencies = [ "aws-credential-types", "aws-runtime", @@ -414,9 +368,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.44.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b90cfe6504115e13c41d3ea90286ede5aa14da294f3fe077027a6e83850843c" +checksum = "a8776850becacbd3a82a4737a9375ddb5c6832a51379f24443a98e61513f852c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -436,9 +390,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.45.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "167c0fad1f212952084137308359e8e4c4724d1c643038ce163f06de9662c1d0" +checksum = "0007b5b8004547133319b6c4e87193eee2a0bcb3e4c18c75d09febe9dab7b383" dependencies = [ "aws-credential-types", "aws-runtime", @@ -458,9 +412,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.44.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb5f98188ec1435b68097daa2a37d74b9d17c9caa799466338a8d1544e71b9d" +checksum = "9fffaa356e7f1c725908b75136d53207fa714e348f365671df14e95a60530ad3" dependencies = [ "aws-credential-types", "aws-runtime", @@ -481,9 +435,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -593,9 +547,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -608,7 +562,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -637,9 +591,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" dependencies = [ "base64-simd", "bytes", @@ -735,9 +689,9 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bigdecimal" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" +checksum = "8f850665a0385e070b64c38d2354e6c104c8479c59868d1e48a0c13ee2c7a1c1" dependencies = [ "autocfg", "libm", @@ -825,22 +779,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -851,9 +805,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "bytes-utils" @@ -891,9 +845,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.24" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812acba72f0a070b003d3697490d2b55b837230ae7c6c6497f05cc2ddbb8d938" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", @@ -970,18 +924,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstyle", "clap_lex", @@ -1044,26 +998,6 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" -[[package]] -name = "const-random" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" -dependencies = [ - "const-random-macro", -] - -[[package]] -name = "const-random-macro" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom", - "once_cell", - "tiny-keccak", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -1278,9 +1212,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "der" @@ -1374,7 +1308,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -1456,9 +1390,9 @@ dependencies = [ [[package]] name = "float-cmp" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" dependencies = [ "num-traits", ] @@ -1502,9 +1436,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1517,9 +1451,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1527,15 +1461,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1544,38 +1478,38 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1624,9 +1558,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1691,7 +1625,6 @@ checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ "cfg-if", "crunchy", - "num-traits", ] [[package]] @@ -1859,9 +1792,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -1883,9 +1816,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -1909,7 +1842,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -1925,9 +1858,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", - "rustls 0.23.13", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -1946,7 +1879,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", @@ -2012,9 +1945,9 @@ checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" [[package]] name = "ipnet" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is-terminal" @@ -2088,9 +2021,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2114,9 +2047,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" @@ -2174,9 +2107,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "a00419de735aac21d53b0de5ce2c03bd3627277cf471300f27ebc89f7d828047" [[package]] name = "libmimalloc-sys" @@ -2190,9 +2123,9 @@ dependencies = [ [[package]] name = "libz-ng-sys" -version = "1.1.16" +version = "1.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4436751a01da56f1277f323c80d584ffad94a3d14aecd959dd0dff75aa73a438" +checksum = "8f0f7295a34685977acb2e8cc8b08ee4a8dffd6cf278eeccddbe1ed55ba815d5" dependencies = [ "cmake", "libc", @@ -2222,11 +2155,11 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lru" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.0", ] [[package]] @@ -2390,20 +2323,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - [[package]] name = "num-bigint" version = "0.4.6" @@ -2439,28 +2358,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -2473,9 +2370,8 @@ dependencies = [ [[package]] name = "numpy" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf314fca279e6e6ac2126a4ff98f26d88aa4ad06bc68fb6ae5cf4bd706758311" +version = "0.21.0" +source = "git+https://github.com/stinodego/rust-numpy.git?rev=9ba9962ae57ba26e35babdce6f179edf5fe5b9c8#9ba9962ae57ba26e35babdce6f179edf5fe5b9c8" dependencies = [ "libc", "ndarray", @@ -2587,9 +2483,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.4" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] @@ -2606,7 +2502,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.0", "itertools 0.13.0", "md-5", "parking_lot", @@ -2627,12 +2523,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.1" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" -dependencies = [ - "portable-atomic", -] +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -2686,16 +2579,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "parquet-format-safe" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" -dependencies = [ - "async-trait", - "futures", -] - [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -2751,9 +2634,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2816,11 +2699,10 @@ dependencies = [ [[package]] name = "polars" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "apache-avro", - "arrow-buffer", "avro-schema", "either", "ethnum", @@ -2846,13 +2728,10 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "async-stream", "atoi", "atoi_simd", "avro-schema", @@ -2894,6 +2773,7 @@ dependencies = [ "simdutf8", "streaming-iterator", "strength_reduce", + "strum_macros", "tokio", "tokio-util", "version_check", @@ -2914,7 +2794,7 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.43.1" +version = "0.44.1" dependencies = [ "bytemuck", "either", @@ -2929,10 +2809,9 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", - "arrow-array", "bincode", "bitflags", "bytemuck", @@ -2958,6 +2837,7 @@ dependencies = [ "regex", "serde", "serde_json", + "strum_macros", "thiserror", "version_check", "xxhash-rust", @@ -2965,7 +2845,7 @@ dependencies = [ [[package]] name = "polars-doc-examples" -version = "0.43.1" +version = "0.44.1" dependencies = [ "aws-config", "aws-sdk-s3", @@ -2979,7 +2859,7 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.43.1" +version = "0.44.1" dependencies = [ "avro-schema", "object_store", @@ -2991,10 +2871,11 @@ dependencies = [ [[package]] name = "polars-expr" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", + "hashbrown 0.15.0", "num-traits", "once_cell", "polars-arrow", @@ -3004,14 +2885,16 @@ dependencies = [ "polars-json", "polars-ops", "polars-plan", + "polars-row", "polars-time", "polars-utils", + "rand", "rayon", ] [[package]] name = "polars-ffi" -version = "0.43.1" +version = "0.44.1" dependencies = [ "polars-arrow", "polars-core", @@ -3019,7 +2902,7 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "async-trait", @@ -3050,6 +2933,7 @@ dependencies = [ "polars-schema", "polars-time", "polars-utils", + "pyo3", "rayon", "regex", "reqwest", @@ -3067,7 +2951,7 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "chrono", @@ -3087,7 +2971,7 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", @@ -3115,7 +2999,7 @@ dependencies = [ [[package]] name = "polars-mem-engine" -version = "0.43.1" +version = "0.44.1" dependencies = [ "futures", "memmap2", @@ -3136,7 +3020,7 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "aho-corasick", @@ -3163,15 +3047,17 @@ dependencies = [ "rand_distr", "rayon", "regex", + "regex-syntax 0.8.5", "serde", "serde_json", + "strum_macros", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "async-stream", @@ -3186,10 +3072,10 @@ dependencies = [ "lz4", "lz4_flex", "num-traits", - "parquet-format-safe", "polars-arrow", "polars-compute", "polars-error", + "polars-parquet-format", "polars-utils", "rand", "serde", @@ -3200,9 +3086,19 @@ dependencies = [ "zstd", ] +[[package]] +name = "polars-parquet-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" +dependencies = [ + "async-trait", + "futures", +] + [[package]] name = "polars-pipe" -version = "0.43.1" +version = "0.44.1" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -3227,7 +3123,7 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bitflags", @@ -3241,6 +3137,7 @@ dependencies = [ "hashbrown 0.15.0", "libloading", "memmap2", + "num-traits", "once_cell", "percent-encoding", "polars-arrow", @@ -3264,7 +3161,7 @@ dependencies = [ [[package]] name = "polars-python" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "arboard", @@ -3281,8 +3178,10 @@ dependencies = [ "polars", "polars-core", "polars-error", + "polars-expr", "polars-io", "polars-lazy", + "polars-mem-engine", "polars-ops", "polars-parquet", "polars-plan", @@ -3298,7 +3197,7 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.43.1" +version = "0.44.1" dependencies = [ "bytemuck", "polars-arrow", @@ -3308,7 +3207,7 @@ dependencies = [ [[package]] name = "polars-schema" -version = "0.43.1" +version = "0.44.1" dependencies = [ "indexmap", "polars-error", @@ -3319,7 +3218,7 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.43.1" +version = "0.44.1" dependencies = [ "hex", "once_cell", @@ -3339,7 +3238,7 @@ dependencies = [ [[package]] name = "polars-stream" -version = "0.43.1" +version = "0.44.1" dependencies = [ "atomic-waker", "crossbeam-deque", @@ -3366,7 +3265,7 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.43.1" +version = "0.44.1" dependencies = [ "atoi", "bytemuck", @@ -3381,11 +3280,12 @@ dependencies = [ "polars-utils", "regex", "serde", + "strum_macros", ] [[package]] name = "polars-utils" -version = "0.43.1" +version = "0.44.1" dependencies = [ "ahash", "bytemuck", @@ -3398,6 +3298,7 @@ dependencies = [ "num-traits", "once_cell", "polars-error", + "pyo3", "rand", "raw-cpuid", "rayon", @@ -3430,9 +3331,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -3487,7 +3388,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "1.9.0" +version = "1.12.0" dependencies = [ "jemallocator", "libc", @@ -3499,9 +3400,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "chrono", @@ -3509,7 +3410,7 @@ dependencies = [ "inventory", "libc", "memoffset", - "once_cell", + "parking_lot", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -3519,9 +3420,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -3529,9 +3430,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -3539,27 +3440,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "pyo3-macros-backend" -version = "0.22.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ - "heck 0.5.0", + "heck 0.4.1", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -3600,7 +3501,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.15", "socket2", "thiserror", "tokio", @@ -3617,7 +3518,7 @@ dependencies = [ "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.13", + "rustls 0.23.15", "slab", "thiserror", "tinyvec", @@ -3757,7 +3658,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -3786,14 +3687,14 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -3845,7 +3746,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3856,7 +3757,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.13", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3937,9 +3838,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags", "errno", @@ -3962,9 +3863,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "once_cell", "ring", @@ -4019,9 +3920,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -4046,9 +3947,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -4113,9 +4014,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ "windows-sys 0.59.0", ] @@ -4181,9 +4082,9 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] @@ -4199,20 +4100,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "indexmap", "itoa", @@ -4282,9 +4183,9 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f0b376aada35f30a0012f5790e50aed62f91804a0682669aefdbe81c7fcb91" +checksum = "b1df0290e9bfe79ddd5ff8798ca887cd107b75353d2957efe9777296e17f26b5" dependencies = [ "ahash", "getrandom", @@ -4452,7 +4353,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4474,9 +4375,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -4532,22 +4433,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4580,15 +4481,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -4616,9 +4508,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", @@ -4639,7 +4531,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4658,7 +4550,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.13", + "rustls 0.23.15", "rustls-pki-types", "tokio", ] @@ -4702,7 +4594,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4747,7 +4639,7 @@ checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -4835,9 +4727,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -4845,9 +4737,9 @@ dependencies = [ [[package]] name = "value-trait" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcaa56177466248ba59d693a048c0959ddb67f1151b963f904306312548cf392" +checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" dependencies = [ "float-cmp", "halfbrown", @@ -4894,9 +4786,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4905,24 +4797,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4932,9 +4824,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4942,28 +4834,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -4974,9 +4866,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -5052,7 +4944,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -5063,7 +4955,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -5291,7 +5183,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 884e4f109e62..11a8665e5b68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,18 +3,14 @@ resolver = "2" members = [ "crates/*", "docs/source/src/rust", - # "examples/*", "py-polars", ] default-members = [ "crates/*", ] -# exclude = [ -# "examples/datasets", -# ] [workspace.package] -version = "0.43.1" +version = "0.44.1" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -25,10 +21,6 @@ repository = "https://github.com/pola-rs/polars" ahash = ">=0.8.5" aho-corasick = "1.1" arboard = { version = "3.4.0", default-features = false } -arrow-array = { version = ">=41", default-features = false } -arrow-buffer = { version = ">=41", default-features = false } -arrow-data = { version = ">=41", default-features = false } -arrow-schema = { version = ">=41", default-features = false } atoi = "2" atoi_simd = "0.15.5" atomic-waker = "1" @@ -64,19 +56,19 @@ memmap = { package = "memmap2", version = "0.7" } multiversion = "0.7" ndarray = { version = "0.15", default-features = false } num-traits = "0.2" -numpy = "0.22" object_store = { version = "0.10", default-features = false } once_cell = "1" parking_lot = "0.12" percent-encoding = "2.3" pin-project-lite = "0.2" -pyo3 = "0.22" +pyo3 = "0.21" rand = "0.8" rand_distr = "0.4" raw-cpuid = "11" rayon = "1.9" recursive = "0.1" regex = "1.9" +regex-syntax = "0.8.5" reqwest = { version = "0.12", default-features = false } ryu = "1.0.13" serde = { version = "1.0.188", features = ["derive", "rc"] } @@ -99,27 +91,27 @@ version_check = "0.9.4" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" -polars = { version = "0.43.1", path = "crates/polars", default-features = false } -polars-compute = { version = "0.43.1", path = "crates/polars-compute", default-features = false } -polars-core = { version = "0.43.1", path = "crates/polars-core", default-features = false } -polars-error = { version = "0.43.1", path = "crates/polars-error", default-features = false } -polars-expr = { version = "0.43.1", path = "crates/polars-expr", default-features = false } -polars-ffi = { version = "0.43.1", path = "crates/polars-ffi", default-features = false } -polars-io = { version = "0.43.1", path = "crates/polars-io", default-features = false } -polars-json = { version = "0.43.1", path = "crates/polars-json", default-features = false } -polars-lazy = { version = "0.43.1", path = "crates/polars-lazy", default-features = false } -polars-mem-engine = { version = "0.43.1", path = "crates/polars-mem-engine", default-features = false } -polars-ops = { version = "0.43.1", path = "crates/polars-ops", default-features = false } -polars-parquet = { version = "0.43.1", path = "crates/polars-parquet", default-features = false } -polars-pipe = { version = "0.43.1", path = "crates/polars-pipe", default-features = false } -polars-plan = { version = "0.43.1", path = "crates/polars-plan", default-features = false } -polars-python = { version = "0.43.1", path = "crates/polars-python", default-features = false } -polars-row = { version = "0.43.1", path = "crates/polars-row", default-features = false } -polars-schema = { version = "0.43.1", path = "crates/polars-schema", default-features = false } -polars-sql = { version = "0.43.1", path = "crates/polars-sql", default-features = false } -polars-stream = { version = "0.43.1", path = "crates/polars-stream", default-features = false } -polars-time = { version = "0.43.1", path = "crates/polars-time", default-features = false } -polars-utils = { version = "0.43.1", path = "crates/polars-utils", default-features = false } +polars = { version = "0.44.1", path = "crates/polars", default-features = false } +polars-compute = { version = "0.44.1", path = "crates/polars-compute", default-features = false } +polars-core = { version = "0.44.1", path = "crates/polars-core", default-features = false } +polars-error = { version = "0.44.1", path = "crates/polars-error", default-features = false } +polars-expr = { version = "0.44.1", path = "crates/polars-expr", default-features = false } +polars-ffi = { version = "0.44.1", path = "crates/polars-ffi", default-features = false } +polars-io = { version = "0.44.1", path = "crates/polars-io", default-features = false } +polars-json = { version = "0.44.1", path = "crates/polars-json", default-features = false } +polars-lazy = { version = "0.44.1", path = "crates/polars-lazy", default-features = false } +polars-mem-engine = { version = "0.44.1", path = "crates/polars-mem-engine", default-features = false } +polars-ops = { version = "0.44.1", path = "crates/polars-ops", default-features = false } +polars-parquet = { version = "0.44.1", path = "crates/polars-parquet", default-features = false } +polars-pipe = { version = "0.44.1", path = "crates/polars-pipe", default-features = false } +polars-plan = { version = "0.44.1", path = "crates/polars-plan", default-features = false } +polars-python = { version = "0.44.1", path = "crates/polars-python", default-features = false } +polars-row = { version = "0.44.1", path = "crates/polars-row", default-features = false } +polars-schema = { version = "0.44.1", path = "crates/polars-schema", default-features = false } +polars-sql = { version = "0.44.1", path = "crates/polars-sql", default-features = false } +polars-stream = { version = "0.44.1", path = "crates/polars-stream", default-features = false } +polars-time = { version = "0.44.1", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.44.1", path = "crates/polars-utils", default-features = false } [workspace.dependencies.arrow-format] package = "polars-arrow-format" @@ -127,7 +119,7 @@ version = "0.1.0" [workspace.dependencies.arrow] package = "polars-arrow" -version = "0.43.1" +version = "0.44.1" path = "crates/polars-arrow" default-features = false features = [ @@ -144,17 +136,20 @@ features = [ # packed_simd_2 = { git = "https://github.com/rust-lang/packed_simd", rev = "e57c7ba11386147e6d2cbad7c88f376aab4bdc86" } # simd-json = { git = "https://github.com/ritchie46/simd-json", branch = "alignment" } -[profile.opt-dev] -inherits = "dev" -opt-level = 1 +[profile.release] +lto = "thin" +debug = "line-tables-only" + +[profile.nodebug-release] +inherits = "release" +debug = false [profile.debug-release] inherits = "release" debug = true -incremental = true -codegen-units = 16 -lto = "thin" -[profile.release] +[profile.dist-release] +inherits = "release" codegen-units = 1 +debug = false lto = "fat" diff --git a/Makefile b/Makefile index 5dd746aa5b7a..c793d10279f4 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,56 @@ else VENV_BIN=$(VENV)/bin endif +# Detect CPU architecture. +ifeq ($(OS),Windows_NT) + ifeq ($(PROCESSOR_ARCHITECTURE),AMD64) + ARCH := amd64 + else ifeq ($(PROCESSOR_ARCHITECTURE),x86) + ARCH := x86 + else ifeq ($(PROCESSOR_ARCHITECTURE),ARM64) + ARCH := arm64 + else + ARCH := unknown + endif +else + UNAME_P := $(shell uname -p) + ifeq ($(UNAME_P),x86_64) + ARCH := amd64 + else ifneq ($(filter %86,$(UNAME_P)),) + ARCH := x86 + else ifneq ($(filter arm%,$(UNAME_P)),) + ARCH := arm64 + else + ARCH := unknown + endif +endif + +# Ensure boolean arguments are normalized to 1/0 to prevent surprises. +ifdef LTS_CPU + ifeq ($(LTS_CPU),0) + else ifeq ($(LTS_CPU),1) + else +$(error LTS_CPU must be 0 or 1 (or undefined, default to 0)) + endif +endif + +# Define RUSTFLAGS and CFLAGS appropriate for the architecture. +# Keep synchronized with .github/workflows/release-python.yml. +ifeq ($(ARCH),amd64) + ifeq ($(LTS_CPU),1) + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 + else + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq,+movbe -Z tune-cpu=skylake + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -mavx -mavx2 -mfma -mbmi -mbmi2 -mlzcnt -mpclmul -mmovbe -mtune=skylake + endif +endif + +override RUSTFLAGS+=$(FEAT_RUSTFLAGS) +override CFLAGS+=$(FEAT_CFLAGS) +export RUSTFLAGS +export CFLAGS + # Define command to filter pip warnings when running maturin FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS[0]} -eq 0 @@ -35,55 +85,31 @@ requirements-all: .venv ## Install/refresh all Python requirements (including t .PHONY: build build: .venv ## Compile and install Python Polars for development @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-debug-opt -build-debug-opt: .venv ## Compile and install Python Polars with minimal optimizations turned on +.PHONY: build-release +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-debug-opt-subset -build-debug-opt-subset: .venv ## Compile and install Python Polars with minimal optimizations turned on and no default features +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --no-default-features --profile opt-dev \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile nodebug-release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-opt -build-opt: .venv ## Compile and install Python Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release $(ARGS) \ $(FILTER_PIP_WARNINGS) -.PHONY: build-release -build-release: .venv ## Compile and install a faster Python Polars binary with full optimizations +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution @unset CONDA_PREFIX \ - && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-native -build-native: .venv ## Same as build, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-debug-opt-native -build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-opt-native -build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ - $(FILTER_PIP_WARNINGS) - -.PHONY: build-release-native -build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on - @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ - $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile dist-release $(ARGS) \ $(FILTER_PIP_WARNINGS) .PHONY: check @@ -121,3 +147,6 @@ clean: ## Clean up caches, build artifacts, and the venv help: ## Display this help screen @echo -e "\033[1mAvailable commands:\033[0m" @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort + @echo + @echo The build commands support LTS_CPU=1 for building for older CPUs, and ARGS which is passed through to maturin. + @echo 'For example to build without default features use: make build ARGS="--no-default-features".' diff --git a/README.md b/README.md index 2f7c5b290ad8..43ac43596813 100644 --- a/README.md +++ b/README.md @@ -233,13 +233,14 @@ This can be done by going through the following steps in sequence: 1. Install the latest [Rust compiler](https://www.rust-lang.org/tools/install) 2. Install [maturin](https://maturin.rs/): `pip install maturin` 3. `cd py-polars` and choose one of the following: - - `make build-release`, fastest binary, very long compile times - - `make build-opt`, fast binary with debug symbols, long compile times - - `make build-debug-opt`, medium-speed binary with debug assertions and symbols, medium compile times - `make build`, slow binary with debug assertions and symbols, fast compile times + - `make build-release`, fast binary without debug assertions, minimal debug symbols, long compile times + - `make build-nodebug-release`, same as build-release but without any debug symbols, slightly faster to compile + - `make build-debug-release`, same as build-release but with full debug symbols, slightly slower to compile + - `make build-dist-release`, fastest binary, extreme compile times - Append `-native` (e.g. `make build-release-native`) to enable further optimizations specific to - your CPU. This produces a non-portable binary/wheel however. +By default the binary is compiled with optimizations turned on for a modern CPU. Specify `LTS_CPU=1` +with the command if your CPU is older and does not support e.g. AVX2. Note that the Rust crate implementing the Python bindings is called `py-polars` to distinguish from the wrapped Rust crate `polars` itself. However, both the Python package and the Python module are named `polars`, so you diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 379a9b090d18..5ce6d5deb9e7 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -14,7 +14,7 @@ description = "Minimal implementation of the Arrow specification forked from arr [dependencies] atoi = { workspace = true, optional = true } -bytemuck = { workspace = true } +bytemuck = { workspace = true, features = ["must_cast"] } chrono = { workspace = true } # for timezone support chrono-tz = { workspace = true, optional = true } @@ -67,12 +67,11 @@ multiversion = { workspace = true, optional = true } # Faster hashing ahash = { workspace = true } -# Support conversion to/from arrow-rs -arrow-array = { workspace = true, optional = true } -arrow-buffer = { workspace = true, optional = true } -arrow-data = { workspace = true, optional = true } -arrow-schema = { workspace = true, optional = true } -tokio = { workspace = true, optional = true } +# For async arrow flight conversion +async-stream = { version = "0.3", optional = true } +tokio = { workspace = true, optional = true, features = ["io-util"] } + +strum_macros = { workspace = true } [dev-dependencies] criterion = "0.5" @@ -100,7 +99,6 @@ getrandom = { version = "0.2", features = ["js"] } [features] default = [] full = [ - "arrow_rs", "io_ipc", "io_flight", "io_ipc_compression", @@ -113,10 +111,9 @@ full = [ # parses timezones used in timestamp conversions "chrono-tz", ] -arrow_rs = ["arrow-buffer", "arrow-schema", "arrow-data", "arrow-array"] io_ipc = ["arrow-format", "polars-error/arrow-format"] io_ipc_compression = ["lz4", "zstd", "io_ipc"] -io_flight = ["io_ipc", "arrow-format/flight-data"] +io_flight = ["io_ipc", "arrow-format/flight-data", "async-stream", "futures", "tokio"] io_avro = ["avro-schema", "polars-error/avro-schema"] io_avro_compression = [ @@ -157,7 +154,7 @@ timezones = [ ] dtype-array = [] dtype-decimal = ["atoi", "itoap"] -bigidx = [] +bigidx = ["polars-utils/bigidx"] nightly = [] performant = [] strings = [] diff --git a/crates/polars-arrow/src/array/binary/data.rs b/crates/polars-arrow/src/array/binary/data.rs deleted file mode 100644 index 2c08d94eb1b0..000000000000 --- a/crates/polars-arrow/src/array/binary/data.rs +++ /dev/null @@ -1,43 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, BinaryArray}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for BinaryArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.offsets().len_proxy()) - .buffers(vec![ - self.offsets.clone().into_inner().into(), - self.values.clone().into(), - ]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let buffers = data.buffers(); - - // SAFETY: ArrayData is valid - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: buffers[1].clone().into(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/binary/mod.rs b/crates/polars-arrow/src/array/binary/mod.rs index 9cd06adaaabf..b590a4554597 100644 --- a/crates/polars-arrow/src/array/binary/mod.rs +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -21,9 +21,6 @@ mod mutable; pub use mutable::*; use polars_error::{polars_bail, PolarsResult}; -#[cfg(feature = "arrow_rs")] -mod data; - /// A [`BinaryArray`] is Arrow's semantically equivalent of an immutable `Vec>>`. /// It implements [`Array`]. /// diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 67334a53aa17..15a744f804e9 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -11,7 +11,7 @@ use polars_utils::total_ord::{TotalEq, TotalOrd}; use crate::buffer::Buffer; use crate::datatypes::PrimitiveType; -use crate::types::NativeType; +use crate::types::{Bytes16Alignment4, NativeType}; // We use this instead of u128 because we want alignment of <= 8 bytes. /// A reference to a set of bytes. @@ -346,7 +346,9 @@ impl MinMax for View { impl NativeType for View { const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128; + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment4; #[inline] fn to_le_bytes(&self) -> Self::Bytes { diff --git a/crates/polars-arrow/src/array/boolean/data.rs b/crates/polars-arrow/src/array/boolean/data.rs deleted file mode 100644 index 6c497896775c..000000000000 --- a/crates/polars-arrow/src/array/boolean/data.rs +++ /dev/null @@ -1,36 +0,0 @@ -use arrow_buffer::{BooleanBuffer, NullBuffer}; -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, BooleanArray}; -use crate::bitmap::Bitmap; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for BooleanArray { - fn to_data(&self) -> ArrayData { - let buffer = NullBuffer::from(self.values.clone()); - - let builder = ArrayDataBuilder::new(arrow_schema::DataType::Boolean) - .len(buffer.len()) - .offset(buffer.offset()) - .buffers(vec![buffer.into_inner().into_inner()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - assert_eq!(data.data_type(), &arrow_schema::DataType::Boolean); - - let buffers = data.buffers(); - let buffer = BooleanBuffer::new(buffers[0].clone(), data.offset(), data.len()); - // Use NullBuffer to compute set count - let values = Bitmap::from_null_buffer(NullBuffer::new(buffer)); - - Self { - dtype: ArrowDataType::Boolean, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs index 5cd9870fdbf4..c1a17c0f27f3 100644 --- a/crates/polars-arrow/src/array/boolean/mod.rs +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -7,8 +7,6 @@ use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::{ArrowDataType, PhysicalType}; use crate::trusted_len::TrustedLen; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from; diff --git a/crates/polars-arrow/src/array/dictionary/data.rs b/crates/polars-arrow/src/array/dictionary/data.rs deleted file mode 100644 index a5eda5a0fd73..000000000000 --- a/crates/polars-arrow/src/array/dictionary/data.rs +++ /dev/null @@ -1,49 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{ - from_data, to_data, Arrow2Arrow, DictionaryArray, DictionaryKey, PrimitiveArray, -}; -use crate::datatypes::{ArrowDataType, PhysicalType}; - -impl Arrow2Arrow for DictionaryArray { - fn to_data(&self) -> ArrayData { - let keys = self.keys.to_data(); - let builder = keys - .into_builder() - .data_type(self.dtype.clone().into()) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Dictionary is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let key = match data.data_type() { - arrow_schema::DataType::Dictionary(k, _) => k.as_ref(), - d => panic!("unsupported dictionary type {d}"), - }; - - let dtype = ArrowDataType::from(data.data_type().clone()); - assert_eq!( - dtype.to_physical_type(), - PhysicalType::Dictionary(K::KEY_TYPE) - ); - - let key_builder = ArrayDataBuilder::new(key.clone()) - .buffers(vec![data.buffers()[0].clone()]) - .offset(data.offset()) - .len(data.len()) - .nulls(data.nulls().cloned()); - - // SAFETY: Dictionary is valid - let key_data = unsafe { key_builder.build_unchecked() }; - let keys = PrimitiveArray::from_data(&key_data); - let values = from_data(&data.child_data()[0]); - - Self { - dtype, - keys, - values, - } - } -} diff --git a/crates/polars-arrow/src/array/dictionary/iterator.rs b/crates/polars-arrow/src/array/dictionary/iterator.rs index 68e95ca86fed..af6ef539572d 100644 --- a/crates/polars-arrow/src/array/dictionary/iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/iterator.rs @@ -21,7 +21,7 @@ impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> { } } -impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { +impl Iterator for DictionaryValuesIter<'_, K> { type Item = Box; #[inline] @@ -40,9 +40,9 @@ impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { } } -unsafe impl<'a, K: DictionaryKey> TrustedLen for DictionaryValuesIter<'a, K> {} +unsafe impl TrustedLen for DictionaryValuesIter<'_, K> {} -impl<'a, K: DictionaryKey> DoubleEndedIterator for DictionaryValuesIter<'a, K> { +impl DoubleEndedIterator for DictionaryValuesIter<'_, K> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs index d53970dacd98..f23c409c48a9 100644 --- a/crates/polars-arrow/src/array/dictionary/mod.rs +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -8,8 +8,6 @@ use crate::scalar::{new_scalar, Scalar}; use crate::trusted_len::TrustedLen; use crate::types::NativeType; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs index 5257bde2cae0..d7e7637bf28d 100644 --- a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs @@ -117,11 +117,9 @@ impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped< } } -unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {} +unsafe impl TrustedLen for DictionaryValuesIterTyped<'_, K, V> {} -impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator - for DictionaryValuesIterTyped<'a, K, V> -{ +impl DoubleEndedIterator for DictionaryValuesIterTyped<'_, K, V> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { @@ -181,9 +179,9 @@ impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryIterTyped<'a, K, } } -unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryIterTyped<'a, K, V> {} +unsafe impl TrustedLen for DictionaryIterTyped<'_, K, V> {} -impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator for DictionaryIterTyped<'a, K, V> { +impl DoubleEndedIterator for DictionaryIterTyped<'_, K, V> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/fixed_size_binary/data.rs b/crates/polars-arrow/src/array/fixed_size_binary/data.rs deleted file mode 100644 index f04be9883f64..000000000000 --- a/crates/polars-arrow/src/array/fixed_size_binary/data.rs +++ /dev/null @@ -1,37 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, FixedSizeBinaryArray}; -use crate::bitmap::Bitmap; -use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for FixedSizeBinaryArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.values.clone().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - let size = match dtype { - ArrowDataType::FixedSizeBinary(size) => size, - _ => unreachable!("must be FixedSizeBinary"), - }; - - let mut values: Buffer = data.buffers()[0].clone().into(); - values.slice(data.offset() * size, data.len() * size); - - Self { - size, - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs index ec3f96626c14..f8f5a1760d45 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -3,8 +3,6 @@ use crate::bitmap::Bitmap; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs deleted file mode 100644 index c1f353db691a..000000000000 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, FixedSizeListArray}; -use crate::bitmap::Bitmap; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for FixedSizeListArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - let length = data.len() - data.offset(); - let size = match dtype { - ArrowDataType::FixedSizeList(_, size) => size, - _ => unreachable!("must be FixedSizeList type"), - }; - - let mut values = from_data(&data.child_data()[0]); - values.slice(data.offset() * size, data.len() * size); - - Self { - size, - length, - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 4f1622819813..32267cc5a4b7 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -1,9 +1,7 @@ -use super::{new_empty_array, new_null_array, Array, Splitable}; +use super::{new_empty_array, new_null_array, Array, ArrayRef, Splitable}; use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; @@ -11,8 +9,11 @@ mod iterator; mod mutable; pub use mutable::*; use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_utils::format_tuple; use polars_utils::pl_str::PlSmallStr; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. /// Cloning and slicing this struct is `O(1)`. #[derive(Clone)] @@ -122,6 +123,108 @@ impl FixedSizeListArray { let values = new_null_array(field.dtype().clone(), length * size); Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } + + pub fn from_shape( + leaf_array: ArrayRef, + dimensions: &[ReshapeDimension], + ) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + let size = leaf_array.len(); + + let mut total_dim_size = 1; + let mut num_infers = 0; + for &dim in dimensions { + match dim { + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, + } + } + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + if size == 0 { + polars_ensure!( + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(prev_array); + } + + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // We pop the outer dimension as that is the height of the series. + for dim in dimensions[1..].iter().rev() { + // Infer dimension if needed + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = FixedSizeListArray::new( + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, + prev_array, + None, + ) + .boxed(); + } + Ok(prev_array) + } + + pub fn get_dims(&self) -> Vec { + let mut dims = vec![ + Dimension::new(self.length as _), + Dimension::new(self.size as _), + ]; + + let mut prev_array = &self.values; + + while let Some(a) = prev_array.as_any().downcast_ref::() { + dims.push(Dimension::new(a.size as _)); + prev_array = &a.values; + } + dims + } } // must use @@ -146,6 +249,7 @@ impl FixedSizeListArray { /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); self.validity = self .validity .take() diff --git a/crates/polars-arrow/src/array/growable/list.rs b/crates/polars-arrow/src/array/growable/list.rs index 90e4f15020a6..095f39522da4 100644 --- a/crates/polars-arrow/src/array/growable/list.rs +++ b/crates/polars-arrow/src/array/growable/list.rs @@ -14,7 +14,7 @@ unsafe fn extend_offset_values( start: usize, len: usize, ) { - let array = growable.arrays[index]; + let array = growable.arrays.get_unchecked_release(index); let offsets = array.offsets(); growable diff --git a/crates/polars-arrow/src/array/growable/null.rs b/crates/polars-arrow/src/array/growable/null.rs index c0b92e132819..e663fc31b8b4 100644 --- a/crates/polars-arrow/src/array/growable/null.rs +++ b/crates/polars-arrow/src/array/growable/null.rs @@ -23,7 +23,7 @@ impl GrowableNull { } } -impl<'a> Growable<'a> for GrowableNull { +impl Growable<'_> for GrowableNull { unsafe fn extend(&mut self, _: usize, _: usize, len: usize) { self.length += len; } diff --git a/crates/polars-arrow/src/array/iterator.rs b/crates/polars-arrow/src/array/iterator.rs index 46b585ef2a36..5009442d5718 100644 --- a/crates/polars-arrow/src/array/iterator.rs +++ b/crates/polars-arrow/src/array/iterator.rs @@ -117,3 +117,12 @@ impl<'a, A: ArrayAccessor<'a> + ?Sized> Iterator for NonNullValuesIter<'a, A> { } unsafe impl<'a, A: ArrayAccessor<'a> + ?Sized> TrustedLen for NonNullValuesIter<'a, A> {} + +impl Clone for NonNullValuesIter<'_, A> { + fn clone(&self) -> Self { + Self { + accessor: self.accessor, + idxs: self.idxs.clone(), + } + } +} diff --git a/crates/polars-arrow/src/array/list/data.rs b/crates/polars-arrow/src/array/list/data.rs deleted file mode 100644 index 0d28583df125..000000000000 --- a/crates/polars-arrow/src/array/list/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, ListArray}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for ListArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.offsets.clone().into_inner().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.values.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: from_data(&data.child_data()[0]), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs index 3c2bb6b41f98..87f7b709f14b 100644 --- a/crates/polars-arrow/src/array/list/mod.rs +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -4,8 +4,6 @@ use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; use crate::offset::{Offset, Offsets, OffsetsBuffer}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; @@ -130,6 +128,49 @@ impl ListArray { impl_sliced!(); impl_mut_validity!(); impl_into_array!(); + + pub fn trim_to_normalized_offsets_recursive(&self) -> Self { + let offsets = self.offsets(); + let values = self.values(); + + let first_idx = *offsets.first(); + let len = offsets.range().to_usize(); + + if first_idx.to_usize() == 0 && values.len() == len { + return self.clone(); + } + + let offsets = if first_idx.to_usize() == 0 { + offsets.clone() + } else { + let v = offsets.iter().map(|x| *x - first_idx).collect::>(); + unsafe { OffsetsBuffer::::new_unchecked(v.into()) } + }; + + let values = values.sliced(first_idx.to_usize(), len); + + let values = match values.dtype() { + ArrowDataType::List(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + ArrowDataType::LargeList(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + _ => values, + }; + + assert_eq!(offsets.first().to_usize(), 0); + assert_eq!(values.len(), offsets.range().to_usize()); + + Self::new( + self.dtype().clone(), + offsets, + values, + self.validity().cloned(), + ) + } } // Accessors diff --git a/crates/polars-arrow/src/array/map/data.rs b/crates/polars-arrow/src/array/map/data.rs deleted file mode 100644 index b5530886d817..000000000000 --- a/crates/polars-arrow/src/array/map/data.rs +++ /dev/null @@ -1,38 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, MapArray}; -use crate::bitmap::Bitmap; -use crate::offset::OffsetsBuffer; - -impl Arrow2Arrow for MapArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.offsets.clone().into_inner().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(vec![to_data(self.field.as_ref())]); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype: data.data_type().clone().into(), - offsets, - field: from_data(&data.child_data()[0]), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/map/iterator.rs b/crates/polars-arrow/src/array/map/iterator.rs index 558405ddc8de..79fc630cc520 100644 --- a/crates/polars-arrow/src/array/map/iterator.rs +++ b/crates/polars-arrow/src/array/map/iterator.rs @@ -22,7 +22,7 @@ impl<'a> MapValuesIter<'a> { } } -impl<'a> Iterator for MapValuesIter<'a> { +impl Iterator for MapValuesIter<'_> { type Item = Box; #[inline] @@ -43,9 +43,9 @@ impl<'a> Iterator for MapValuesIter<'a> { } } -unsafe impl<'a> TrustedLen for MapValuesIter<'a> {} +unsafe impl TrustedLen for MapValuesIter<'_> {} -impl<'a> DoubleEndedIterator for MapValuesIter<'a> { +impl DoubleEndedIterator for MapValuesIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs index 5497c1d7342b..1018c21c830a 100644 --- a/crates/polars-arrow/src/array/map/mod.rs +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -4,8 +4,6 @@ use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; use crate::offset::OffsetsBuffer; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index 08702e8021d3..a2acd7164f6a 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -189,7 +189,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { new } - /// Clones this [`Array`] with a new new assigned bitmap. + /// Clones this [`Array`] with a new assigned bitmap. /// # Panic /// This function panics iff `validity.len() != self.len()`. fn with_validity(&self, validity: Option) -> Box; @@ -409,115 +409,6 @@ pub fn new_null_array(dtype: ArrowDataType, length: usize) -> Box { } } -/// Trait providing bi-directional conversion between polars_arrow [`Array`] and arrow-rs [`ArrayData`] -/// -/// [`ArrayData`]: arrow_data::ArrayData -#[cfg(feature = "arrow_rs")] -pub trait Arrow2Arrow: Array { - /// Convert this [`Array`] into [`ArrayData`] - fn to_data(&self) -> arrow_data::ArrayData; - - /// Create this [`Array`] from [`ArrayData`] - fn from_data(data: &arrow_data::ArrayData) -> Self; -} - -#[cfg(feature = "arrow_rs")] -macro_rules! to_data_dyn { - ($array:expr, $ty:ty) => {{ - let f = |x: &$ty| x.to_data(); - general_dyn!($array, $ty, f) - }}; -} - -#[cfg(feature = "arrow_rs")] -impl From> for arrow_array::ArrayRef { - fn from(value: Box) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&dyn Array> for arrow_array::ArrayRef { - fn from(value: &dyn Array) -> Self { - arrow_array::make_array(to_data(value)) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Box { - fn from(value: arrow_array::ArrayRef) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&dyn arrow_array::Array> for Box { - fn from(value: &dyn arrow_array::Array) -> Self { - from_data(&value.to_data()) - } -} - -/// Convert an polars_arrow [`Array`] to [`arrow_data::ArrayData`] -#[cfg(feature = "arrow_rs")] -pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { - use crate::datatypes::PhysicalType::*; - match array.dtype().to_physical_type() { - Null => to_data_dyn!(array, NullArray), - Boolean => to_data_dyn!(array, BooleanArray), - Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - to_data_dyn!(array, PrimitiveArray<$T>) - }), - Binary => to_data_dyn!(array, BinaryArray), - LargeBinary => to_data_dyn!(array, BinaryArray), - FixedSizeBinary => to_data_dyn!(array, FixedSizeBinaryArray), - Utf8 => to_data_dyn!(array, Utf8Array::), - LargeUtf8 => to_data_dyn!(array, Utf8Array::), - List => to_data_dyn!(array, ListArray::), - LargeList => to_data_dyn!(array, ListArray::), - FixedSizeList => to_data_dyn!(array, FixedSizeListArray), - Struct => to_data_dyn!(array, StructArray), - Union => to_data_dyn!(array, UnionArray), - Dictionary(key_type) => { - match_integer_type!(key_type, |$T| { - to_data_dyn!(array, DictionaryArray::<$T>) - }) - }, - Map => to_data_dyn!(array, MapArray), - BinaryView | Utf8View => todo!(), - } -} - -/// Convert an [`arrow_data::ArrayData`] to polars_arrow [`Array`] -#[cfg(feature = "arrow_rs")] -pub fn from_data(data: &arrow_data::ArrayData) -> Box { - use crate::datatypes::PhysicalType::*; - let dtype: ArrowDataType = data.data_type().clone().into(); - match dtype.to_physical_type() { - Null => Box::new(NullArray::from_data(data)), - Boolean => Box::new(BooleanArray::from_data(data)), - Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - Box::new(PrimitiveArray::<$T>::from_data(data)) - }), - Binary => Box::new(BinaryArray::::from_data(data)), - LargeBinary => Box::new(BinaryArray::::from_data(data)), - FixedSizeBinary => Box::new(FixedSizeBinaryArray::from_data(data)), - Utf8 => Box::new(Utf8Array::::from_data(data)), - LargeUtf8 => Box::new(Utf8Array::::from_data(data)), - List => Box::new(ListArray::::from_data(data)), - LargeList => Box::new(ListArray::::from_data(data)), - FixedSizeList => Box::new(FixedSizeListArray::from_data(data)), - Struct => Box::new(StructArray::from_data(data)), - Union => Box::new(UnionArray::from_data(data)), - Dictionary(key_type) => { - match_integer_type!(key_type, |$T| { - Box::new(DictionaryArray::<$T>::from_data(data)) - }) - }, - Map => Box::new(MapArray::from_data(data)), - BinaryView | Utf8View => todo!(), - } -} - macro_rules! clone_dyn { ($array:expr, $ty:ty) => {{ let f = |x: &$ty| Box::new(x.clone()); diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 4960b263667c..e6e840d86860 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -213,24 +213,3 @@ impl FromFfi for NullArray { Self::try_new(dtype, array.array().len()) } } - -#[cfg(feature = "arrow_rs")] -mod arrow { - use arrow_data::{ArrayData, ArrayDataBuilder}; - - use super::*; - impl NullArray { - /// Convert this array into [`arrow_data::ArrayData`] - pub fn to_data(&self) -> ArrayData { - let builder = ArrayDataBuilder::new(arrow_schema::DataType::Null).len(self.len()); - - // SAFETY: safe by construction - unsafe { builder.build_unchecked() } - } - - /// Create this array from [`ArrayData`] - pub fn from_data(data: &ArrayData) -> Self { - Self::new(ArrowDataType::Null, data.len()) - } - } -} diff --git a/crates/polars-arrow/src/array/primitive/data.rs b/crates/polars-arrow/src/array/primitive/data.rs deleted file mode 100644 index 56a94107cb89..000000000000 --- a/crates/polars-arrow/src/array/primitive/data.rs +++ /dev/null @@ -1,33 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, PrimitiveArray}; -use crate::bitmap::Bitmap; -use crate::buffer::Buffer; -use crate::types::NativeType; - -impl Arrow2Arrow for PrimitiveArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .buffers(vec![self.values.clone().into()]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - let mut values: Buffer = data.buffers()[0].clone().into(); - values.slice(data.offset(), data.len()); - - Self { - dtype, - values, - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs index 831a10c372cc..8accc161faf2 100644 --- a/crates/polars-arrow/src/array/primitive/mod.rs +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -11,8 +11,6 @@ use crate::datatypes::*; use crate::trusted_len::TrustedLen; use crate::types::{days_ms, f16, i256, months_days_ns, NativeType}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from_natural; @@ -459,8 +457,8 @@ impl PrimitiveArray { // SAFETY: this is fine, we checked size and alignment, and NativeType // is always Pod. - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - assert_eq!(std::mem::align_of::(), std::mem::align_of::()); + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); let new_values = unsafe { std::mem::transmute::, Buffer>(values) }; PrimitiveArray::new(U::PRIMITIVE.into(), new_values, validity) } diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs index 296d93502abe..9ff5ceb39361 100644 --- a/crates/polars-arrow/src/array/static_array_collect.rs +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -417,10 +417,10 @@ impl IntoBytes for T { } } impl TrivialIntoBytes for Vec {} -impl<'a> TrivialIntoBytes for Cow<'a, [u8]> {} -impl<'a> TrivialIntoBytes for &'a [u8] {} +impl TrivialIntoBytes for Cow<'_, [u8]> {} +impl TrivialIntoBytes for &[u8] {} impl TrivialIntoBytes for String {} -impl<'a> TrivialIntoBytes for &'a str {} +impl TrivialIntoBytes for &str {} impl<'a> IntoBytes for Cow<'a, str> { type AsRefT = Cow<'a, [u8]>; fn into_bytes(self) -> Cow<'a, [u8]> { @@ -590,8 +590,8 @@ unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { trait StrIntoBytes: IntoBytes {} impl StrIntoBytes for String {} -impl<'a> StrIntoBytes for &'a str {} -impl<'a> StrIntoBytes for Cow<'a, str> {} +impl StrIntoBytes for &str {} +impl StrIntoBytes for Cow<'_, str> {} impl ArrayFromIter for Utf8ViewArray { #[inline] diff --git a/crates/polars-arrow/src/array/struct_/data.rs b/crates/polars-arrow/src/array/struct_/data.rs deleted file mode 100644 index a65b491bfe77..000000000000 --- a/crates/polars-arrow/src/array/struct_/data.rs +++ /dev/null @@ -1,29 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, StructArray}; -use crate::bitmap::Bitmap; - -impl Arrow2Arrow for StructArray { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype.clone().into(); - - let builder = ArrayDataBuilder::new(dtype) - .len(self.len()) - .nulls(self.validity.as_ref().map(|b| b.clone().into())) - .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - - Self { - dtype, - length: data.len(), - values: data.child_data().iter().map(from_data).collect(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/struct_/iterator.rs b/crates/polars-arrow/src/array/struct_/iterator.rs index 4e89af3a6a7f..38a49f274cde 100644 --- a/crates/polars-arrow/src/array/struct_/iterator.rs +++ b/crates/polars-arrow/src/array/struct_/iterator.rs @@ -20,7 +20,7 @@ impl<'a> StructValueIter<'a> { } } -impl<'a> Iterator for StructValueIter<'a> { +impl Iterator for StructValueIter<'_> { type Item = Vec>; #[inline] @@ -48,9 +48,9 @@ impl<'a> Iterator for StructValueIter<'a> { } } -unsafe impl<'a> TrustedLen for StructValueIter<'a> {} +unsafe impl TrustedLen for StructValueIter<'_> {} -impl<'a> DoubleEndedIterator for StructValueIter<'a> { +impl DoubleEndedIterator for StructValueIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.index == self.end { diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index 11d0f2de200f..eeaac519bb0d 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -2,8 +2,6 @@ use super::{new_empty_array, new_null_array, Array, Splitable}; use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/union/data.rs b/crates/polars-arrow/src/array/union/data.rs deleted file mode 100644 index 869fdcfc248d..000000000000 --- a/crates/polars-arrow/src/array/union/data.rs +++ /dev/null @@ -1,70 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{from_data, to_data, Arrow2Arrow, UnionArray}; -use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; - -impl Arrow2Arrow for UnionArray { - fn to_data(&self) -> ArrayData { - let dtype = arrow_schema::DataType::from(self.dtype.clone()); - let len = self.len(); - - let builder = match self.offsets.clone() { - Some(offsets) => ArrayDataBuilder::new(dtype) - .len(len) - .buffers(vec![self.types.clone().into(), offsets.into()]) - .child_data(self.fields.iter().map(|x| to_data(x.as_ref())).collect()), - None => ArrayDataBuilder::new(dtype) - .len(len) - .buffers(vec![self.types.clone().into()]) - .child_data( - self.fields - .iter() - .map(|x| to_data(x.as_ref()).slice(self.offset, len)) - .collect(), - ), - }; - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype: ArrowDataType = data.data_type().clone().into(); - - let fields = data.child_data().iter().map(from_data).collect(); - let buffers = data.buffers(); - let mut types: Buffer = buffers[0].clone().into(); - types.slice(data.offset(), data.len()); - let offsets = match buffers.len() == 2 { - true => { - let mut offsets: Buffer = buffers[1].clone().into(); - offsets.slice(data.offset(), data.len()); - Some(offsets) - }, - false => None, - }; - - // Map from type id to array index - let map = match &dtype { - ArrowDataType::Union(_, Some(ids), _) => { - let mut map = [0; 127]; - for (pos, &id) in ids.iter().enumerate() { - map[id as usize] = pos; - } - Some(map) - }, - ArrowDataType::Union(_, None, _) => None, - _ => unreachable!("must be Union type"), - }; - - Self { - types, - map, - fields, - offsets, - dtype, - offset: data.offset(), - } - } -} diff --git a/crates/polars-arrow/src/array/union/iterator.rs b/crates/polars-arrow/src/array/union/iterator.rs index bdcf5825af6c..e93223e46c43 100644 --- a/crates/polars-arrow/src/array/union/iterator.rs +++ b/crates/polars-arrow/src/array/union/iterator.rs @@ -15,7 +15,7 @@ impl<'a> UnionIter<'a> { } } -impl<'a> Iterator for UnionIter<'a> { +impl Iterator for UnionIter<'_> { type Item = Box; #[inline] @@ -54,6 +54,6 @@ impl<'a> UnionArray { } } -impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {} +impl std::iter::ExactSizeIterator for UnionIter<'_> {} -unsafe impl<'a> TrustedLen for UnionIter<'a> {} +unsafe impl TrustedLen for UnionIter<'_> {} diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs index e42d268f5c06..f8007a485ed5 100644 --- a/crates/polars-arrow/src/array/union/mod.rs +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -6,8 +6,6 @@ use crate::buffer::Buffer; use crate::datatypes::{ArrowDataType, Field, UnionMode}; use crate::scalar::{new_scalar, Scalar}; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod iterator; diff --git a/crates/polars-arrow/src/array/utf8/data.rs b/crates/polars-arrow/src/array/utf8/data.rs deleted file mode 100644 index 37f73a089aa6..000000000000 --- a/crates/polars-arrow/src/array/utf8/data.rs +++ /dev/null @@ -1,42 +0,0 @@ -use arrow_data::{ArrayData, ArrayDataBuilder}; - -use crate::array::{Arrow2Arrow, Utf8Array}; -use crate::bitmap::Bitmap; -use crate::offset::{Offset, OffsetsBuffer}; - -impl Arrow2Arrow for Utf8Array { - fn to_data(&self) -> ArrayData { - let dtype = self.dtype().clone().into(); - let builder = ArrayDataBuilder::new(dtype) - .len(self.offsets().len_proxy()) - .buffers(vec![ - self.offsets.clone().into_inner().into(), - self.values.clone().into(), - ]) - .nulls(self.validity.as_ref().map(|b| b.clone().into())); - - // SAFETY: Array is valid - unsafe { builder.build_unchecked() } - } - - fn from_data(data: &ArrayData) -> Self { - let dtype = data.data_type().clone().into(); - if data.is_empty() { - // Handle empty offsets - return Self::new_empty(dtype); - } - - let buffers = data.buffers(); - - // SAFETY: ArrayData is valid - let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; - offsets.slice(data.offset(), data.len() + 1); - - Self { - dtype, - offsets, - values: buffers[1].clone().into(), - validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), - } - } -} diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs index ebec52b78d28..fffa36ba2f8f 100644 --- a/crates/polars-arrow/src/array/utf8/mod.rs +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -11,8 +11,6 @@ use crate::datatypes::ArrowDataType; use crate::offset::{Offset, Offsets, OffsetsBuffer}; use crate::trusted_len::TrustedLen; -#[cfg(feature = "arrow_rs")] -mod data; mod ffi; pub(super) mod fmt; mod from; diff --git a/crates/polars-arrow/src/bitmap/aligned.rs b/crates/polars-arrow/src/bitmap/aligned.rs index b5eab49432ca..ad1baf06631e 100644 --- a/crates/polars-arrow/src/bitmap/aligned.rs +++ b/crates/polars-arrow/src/bitmap/aligned.rs @@ -55,7 +55,7 @@ impl<'a, T: BitChunk> AlignedBitmapSlice<'a, T> { /// The length (in bits) of the portion of the bitmap found in bulk. #[inline(always)] pub fn bulk_bitlen(&self) -> usize { - 8 * std::mem::size_of::() * self.bulk.len() + 8 * size_of::() * self.bulk.len() } /// The length (in bits) of the portion of the bitmap found in suffix. @@ -77,7 +77,7 @@ impl<'a, T: BitChunk> AlignedBitmapSlice<'a, T> { offset %= 8; // Fast-path: fits entirely in one chunk. - let chunk_len = std::mem::size_of::(); + let chunk_len = size_of::(); let chunk_len_bits = 8 * chunk_len; if offset + len <= chunk_len_bits { let mut prefix = load_chunk_le::(bytes) >> offset; diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index a3edb658be4e..ec9814e4fad3 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -12,7 +12,7 @@ pub(crate) fn push_bitchunk(buffer: &mut Vec, value: T) { /// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. pub fn chunk_iter_to_vec>(iter: I) -> Vec { - let cap = iter.size_hint().0 * std::mem::size_of::(); + let cap = iter.size_hint().0 * size_of::(); let mut buffer = Vec::with_capacity(cap); for v in iter { push_bitchunk(&mut buffer, v) @@ -24,7 +24,7 @@ fn chunk_iter_to_vec_and_remainder>( iter: I, remainder: T, ) -> Vec { - let cap = (iter.size_hint().0 + 1) * std::mem::size_of::(); + let cap = (iter.size_hint().0 + 1) * size_of::(); let mut buffer = Vec::with_capacity(cap); for v in iter { push_bitchunk(&mut buffer, v) @@ -338,7 +338,7 @@ impl PartialEq for Bitmap { } } -impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { +impl<'b> BitOr<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitor(self, rhs: &'b Bitmap) -> Bitmap { @@ -346,7 +346,7 @@ impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { } } -impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { +impl<'b> BitAnd<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitand(self, rhs: &'b Bitmap) -> Bitmap { @@ -354,7 +354,7 @@ impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { } } -impl<'a, 'b> BitXor<&'b Bitmap> for &'a Bitmap { +impl<'b> BitXor<&'b Bitmap> for &Bitmap { type Output = Bitmap; fn bitxor(self, rhs: &'b Bitmap) -> Bitmap { diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs new file mode 100644 index 000000000000..c507df97c5ba --- /dev/null +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -0,0 +1,100 @@ +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::storage::SharedStorage; + +/// Used to build bitmaps bool-by-bool in sequential order. +#[derive(Default, Clone)] +pub struct BitmapBuilder { + buf: u64, + len: usize, + cap: usize, + set_bits: usize, + bytes: Vec, +} + +impl BitmapBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn capacity(&self) -> usize { + self.cap + } + + pub fn with_capacity(bits: usize) -> Self { + let bytes = Vec::with_capacity(bits.div_ceil(64) * 8); + let words_available = bytes.capacity() / 8; + Self { + buf: 0, + len: 0, + cap: words_available * 64, + set_bits: 0, + bytes, + } + } + + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + if self.len + additional > self.cap { + self.reserve_slow(additional) + } + } + + #[cold] + #[inline(never)] + fn reserve_slow(&mut self, additional: usize) { + let bytes_needed = (self.len + additional).div_ceil(64) * 8; + self.bytes.reserve(bytes_needed - self.bytes.capacity()); + let words_available = self.bytes.capacity() / 8; + self.cap = words_available * 64; + } + + #[inline(always)] + pub fn push(&mut self, x: bool) { + self.reserve(1); + unsafe { self.push_unchecked(x) } + } + + /// # Safety + /// self.len() < self.capacity() must hold. + #[inline(always)] + pub unsafe fn push_unchecked(&mut self, x: bool) { + debug_assert!(self.len < self.cap); + self.buf |= (x as u64) << (self.len % 64); + self.len += 1; + if self.len % 64 == 0 { + let p = self.bytes.as_mut_ptr().add(self.bytes.len()).cast::(); + p.write_unaligned(self.buf.to_le()); + self.bytes.set_len(self.bytes.len() + 8); + self.set_bits += self.buf.count_ones() as usize; + self.buf = 0; + } + } + + /// # Safety + /// May only be called once at the end. + unsafe fn finish(&mut self) { + if self.len % 64 != 0 { + self.bytes.extend_from_slice(&self.buf.to_le_bytes()); + self.set_bits += self.buf.count_ones() as usize; + } + } + + pub fn into_mut(mut self) -> MutableBitmap { + unsafe { + self.finish(); + MutableBitmap::from_vec(self.bytes, self.len) + } + } + + pub fn freeze(mut self) -> Bitmap { + unsafe { + self.finish(); + let storage = SharedStorage::from_vec(self.bytes); + Bitmap::from_inner_unchecked(storage, 0, self.len, Some(self.len - self.set_bits)) + } + } +} diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 2ba89e68568a..a896651467d2 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -595,22 +595,6 @@ impl Bitmap { ) -> std::result::Result { Ok(MutableBitmap::try_from_trusted_len_iter_unchecked(iterator)?.into()) } - - /// Create a new [`Bitmap`] from an arrow [`NullBuffer`] - /// - /// [`NullBuffer`]: arrow_buffer::buffer::NullBuffer - #[cfg(feature = "arrow_rs")] - pub fn from_null_buffer(value: arrow_buffer::buffer::NullBuffer) -> Self { - let offset = value.offset(); - let length = value.len(); - let unset_bits = value.null_count(); - Self { - storage: SharedStorage::from_arrow_buffer(value.buffer().clone()), - offset, - length, - unset_bit_count_cache: AtomicU64::new(unset_bits as u64), - } - } } impl<'a> IntoIterator for &'a Bitmap { @@ -631,17 +615,6 @@ impl IntoIterator for Bitmap { } } -#[cfg(feature = "arrow_rs")] -impl From for arrow_buffer::buffer::NullBuffer { - fn from(value: Bitmap) -> Self { - let null_count = value.unset_bits(); - let buffer = value.storage.into_arrow_buffer(); - let buffer = arrow_buffer::buffer::BooleanBuffer::new(buffer, value.offset, value.length); - // SAFETY: null count is accurate - unsafe { arrow_buffer::buffer::NullBuffer::new_unchecked(buffer, null_count) } - } -} - impl Splitable for Bitmap { #[inline(always)] fn check_bound(&self, offset: usize) -> bool { diff --git a/crates/polars-arrow/src/bitmap/iterator.rs b/crates/polars-arrow/src/bitmap/iterator.rs index 63c61afd83f7..84e0a2d7a985 100644 --- a/crates/polars-arrow/src/bitmap/iterator.rs +++ b/crates/polars-arrow/src/bitmap/iterator.rs @@ -25,6 +25,7 @@ fn calc_iters_remaining(length: usize, min_length_for_iter: usize, consume: usiz 1 + obvious_iters // Thus always exactly 1 more iter. } +#[derive(Clone)] pub struct TrueIdxIter<'a> { mask: BitMask<'a>, first_unknown: usize, @@ -57,7 +58,7 @@ impl<'a> TrueIdxIter<'a> { } } -impl<'a> Iterator for TrueIdxIter<'a> { +impl Iterator for TrueIdxIter<'_> { type Item = usize; #[inline] @@ -92,7 +93,7 @@ impl<'a> Iterator for TrueIdxIter<'a> { } } -unsafe impl<'a> TrustedLen for TrueIdxIter<'a> {} +unsafe impl TrustedLen for TrueIdxIter<'_> {} pub struct FastU32BitmapIter<'a> { bytes: &'a [u8], @@ -142,7 +143,7 @@ impl<'a> FastU32BitmapIter<'a> { } } -impl<'a> Iterator for FastU32BitmapIter<'a> { +impl Iterator for FastU32BitmapIter<'_> { type Item = u32; #[inline] @@ -170,7 +171,7 @@ impl<'a> Iterator for FastU32BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU32BitmapIter<'a> {} +unsafe impl TrustedLen for FastU32BitmapIter<'_> {} pub struct FastU56BitmapIter<'a> { bytes: &'a [u8], @@ -221,7 +222,7 @@ impl<'a> FastU56BitmapIter<'a> { } } -impl<'a> Iterator for FastU56BitmapIter<'a> { +impl Iterator for FastU56BitmapIter<'_> { type Item = u64; #[inline] @@ -251,7 +252,7 @@ impl<'a> Iterator for FastU56BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU56BitmapIter<'a> {} +unsafe impl TrustedLen for FastU56BitmapIter<'_> {} pub struct FastU64BitmapIter<'a> { bytes: &'a [u8], @@ -316,7 +317,7 @@ impl<'a> FastU64BitmapIter<'a> { } } -impl<'a> Iterator for FastU64BitmapIter<'a> { +impl Iterator for FastU64BitmapIter<'_> { type Item = u64; #[inline] @@ -348,7 +349,7 @@ impl<'a> Iterator for FastU64BitmapIter<'a> { } } -unsafe impl<'a> TrustedLen for FastU64BitmapIter<'a> {} +unsafe impl TrustedLen for FastU64BitmapIter<'_> {} /// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. #[derive(Debug, Clone)] diff --git a/crates/polars-arrow/src/bitmap/mod.rs b/crates/polars-arrow/src/bitmap/mod.rs index e7ed5fa363e8..6d518bf596b4 100644 --- a/crates/polars-arrow/src/bitmap/mod.rs +++ b/crates/polars-arrow/src/bitmap/mod.rs @@ -19,3 +19,6 @@ pub use assign_ops::*; pub mod utils; pub mod bitmask; + +mod builder; +pub use builder::*; diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs index 7bc12e22898e..50f4e023ac1f 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -17,7 +17,7 @@ impl<'a, T: BitChunk> BitChunksExact<'a, T> { #[inline] pub fn new(bitmap: &'a [u8], length: usize) -> Self { assert!(length <= bitmap.len() * 8); - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bitmap = &bitmap[..length.saturating_add(7) / 8]; diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs index 81e08df0059e..680d3bf96fa4 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs @@ -23,7 +23,7 @@ where // expected = [n5, n6, n7, c0, c1, c2, c3, c4] // 1. unset most significants of `next` up to `offset` - let inverse_offset = std::mem::size_of::() * 8 - offset; + let inverse_offset = size_of::() * 8 - offset; next <<= inverse_offset; // next = [n5, n6, n7, 0 , 0 , 0 , 0 , 0 ] diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs index 8a1668a37d1f..7c387d65b10d 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -44,7 +44,7 @@ fn copy_with_merge(dst: &mut T::Bytes, bytes: &[u8], bit_offset: us bytes .windows(2) .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref())) - .take(std::mem::size_of::()) + .take(size_of::()) .enumerate() .for_each(|(i, w)| { let val = merge_reversed(w[0], w[1], bit_offset); @@ -59,7 +59,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { let slice = &slice[offset / 8..]; let bit_offset = offset % 8; - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bytes_len = len / 8; let bytes_upper_len = (len + bit_offset + 7) / 8; @@ -120,7 +120,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { // all remaining bytes self.remainder_bytes .iter() - .take(std::mem::size_of::()) + .take(size_of::()) .enumerate() .for_each(|(i, val)| remainder[i] = *val); @@ -137,7 +137,7 @@ impl<'a, T: BitChunk> BitChunks<'a, T> { /// Returns the remainder bits in [`BitChunks::remainder`]. pub fn remainder_len(&self) -> usize { - self.len - (std::mem::size_of::() * ((self.len / 8) / std::mem::size_of::()) * 8) + self.len - (size_of::() * ((self.len / 8) / size_of::()) * 8) } } diff --git a/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs index 7a5a91a12805..04a9f8b661a7 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs @@ -4,7 +4,7 @@ use super::BitChunk; /// /// # Safety /// The slices returned by this iterator are guaranteed to have length equal to -/// `std::mem::size_of::()`. +/// `size_of::()`. #[derive(Debug)] pub struct BitChunksExactMut<'a, T: BitChunk> { chunks: std::slice::ChunksExactMut<'a, u8>, @@ -18,7 +18,7 @@ impl<'a, T: BitChunk> BitChunksExactMut<'a, T> { #[inline] pub fn new(bitmap: &'a mut [u8], length: usize) -> Self { assert!(length <= bitmap.len() * 8); - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let bitmap = &mut bitmap[..length.saturating_add(7) / 8]; diff --git a/crates/polars-arrow/src/bitmap/utils/iterator.rs b/crates/polars-arrow/src/bitmap/utils/iterator.rs index c386b2d02db4..243372599687 100644 --- a/crates/polars-arrow/src/bitmap/utils/iterator.rs +++ b/crates/polars-arrow/src/bitmap/utils/iterator.rs @@ -176,7 +176,7 @@ impl<'a> BitmapIter<'a> { let num_words = n / 64; if num_words > 0 { - assert!(self.bytes.len() >= num_words * std::mem::size_of::()); + assert!(self.bytes.len() >= num_words * size_of::()); bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize); @@ -189,7 +189,7 @@ impl<'a> BitmapIter<'a> { return; } - assert!(self.bytes.len() >= std::mem::size_of::()); + assert!(self.bytes.len() >= size_of::()); self.word_len = usize::min(self.rest_len, 64); self.rest_len -= self.word_len; @@ -205,7 +205,7 @@ impl<'a> BitmapIter<'a> { } } -impl<'a> Iterator for BitmapIter<'a> { +impl Iterator for BitmapIter<'_> { type Item = bool; #[inline] @@ -238,7 +238,7 @@ impl<'a> Iterator for BitmapIter<'a> { } } -impl<'a> DoubleEndedIterator for BitmapIter<'a> { +impl DoubleEndedIterator for BitmapIter<'_> { #[inline] fn next_back(&mut self) -> Option { if self.rest_len > 0 { diff --git a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs index f3083ad0b141..9f43a3dfe89a 100644 --- a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs +++ b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs @@ -74,7 +74,7 @@ impl<'a> SlicesIterator<'a> { } } -impl<'a> Iterator for SlicesIterator<'a> { +impl Iterator for SlicesIterator<'_> { type Item = (usize, usize); #[inline] diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs index 1dfe805ffc57..1c6e5b5aa4ff 100644 --- a/crates/polars-arrow/src/buffer/immutable.rs +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -288,24 +288,6 @@ impl IntoIterator for Buffer { } } -#[cfg(feature = "arrow_rs")] -impl From for Buffer { - fn from(value: arrow_buffer::Buffer) -> Self { - Self::from_storage(SharedStorage::from_arrow_buffer(value)) - } -} - -#[cfg(feature = "arrow_rs")] -impl From> for arrow_buffer::Buffer { - fn from(value: Buffer) -> Self { - let offset = value.offset(); - value.storage.into_arrow_buffer().slice_with_length( - offset * std::mem::size_of::(), - value.length * std::mem::size_of::(), - ) - } -} - unsafe impl<'a, T: 'a> ArrayAccessor<'a> for Buffer { type Item = &'a T; diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs index bd4ba7ab6384..1cf7512bbbef 100644 --- a/crates/polars-arrow/src/compute/aggregate/memory.rs +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -18,7 +18,7 @@ macro_rules! dyn_binary { let values_end = offsets[offsets.len() - 1] as usize; values_end - values_start - + offsets.len() * std::mem::size_of::<$o>() + + offsets.len() * size_of::<$o>() + validity_size(array.validity()) }}; } @@ -50,7 +50,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { }, Primitive(PrimitiveType::DaysMs) => { let array = array.as_any().downcast_ref::().unwrap(); - array.values().len() * std::mem::size_of::() * 2 + validity_size(array.validity()) + array.values().len() * size_of::() * 2 + validity_size(array.validity()) }, Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array = array @@ -58,7 +58,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { .downcast_ref::>() .unwrap(); - array.values().len() * std::mem::size_of::<$T>() + validity_size(array.validity()) + array.values().len() * size_of::<$T>() + validity_size(array.validity()) }), Binary => dyn_binary!(array, BinaryArray, i32), FixedSizeBinary => { @@ -74,7 +74,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { List => { let array = array.as_any().downcast_ref::>().unwrap(); estimated_bytes_size(array.values().as_ref()) - + array.offsets().len_proxy() * std::mem::size_of::() + + array.offsets().len_proxy() * size_of::() + validity_size(array.validity()) }, FixedSizeList => { @@ -84,7 +84,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { LargeList => { let array = array.as_any().downcast_ref::>().unwrap(); estimated_bytes_size(array.values().as_ref()) - + array.offsets().len_proxy() * std::mem::size_of::() + + array.offsets().len_proxy() * size_of::() + validity_size(array.validity()) }, Struct => { @@ -99,11 +99,11 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { }, Union => { let array = array.as_any().downcast_ref::().unwrap(); - let types = array.types().len() * std::mem::size_of::(); + let types = array.types().len() * size_of::(); let offsets = array .offsets() .as_ref() - .map(|x| x.len() * std::mem::size_of::()) + .map(|x| x.len() * size_of::()) .unwrap_or_default(); let fields = array .fields() @@ -124,7 +124,7 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { BinaryView => binview_size::<[u8]>(array.as_any().downcast_ref().unwrap()), Map => { let array = array.as_any().downcast_ref::().unwrap(); - let offsets = array.offsets().len_proxy() * std::mem::size_of::(); + let offsets = array.offsets().len_proxy() * size_of::(); offsets + estimated_bytes_size(array.field().as_ref()) + validity_size(array.validity()) }, } diff --git a/crates/polars-arrow/src/compute/aggregate/mod.rs b/crates/polars-arrow/src/compute/aggregate/mod.rs index 9528f833a67e..481194c1551c 100644 --- a/crates/polars-arrow/src/compute/aggregate/mod.rs +++ b/crates/polars-arrow/src/compute/aggregate/mod.rs @@ -1,4 +1,4 @@ -//! Contains different aggregation functions +/// ! Contains different aggregation functions #[cfg(feature = "compute_aggregate")] mod sum; #[cfg(feature = "compute_aggregate")] diff --git a/crates/polars-arrow/src/compute/aggregate/sum.rs b/crates/polars-arrow/src/compute/aggregate/sum.rs index 9fbed5f8b1b6..e2098d969e03 100644 --- a/crates/polars-arrow/src/compute/aggregate/sum.rs +++ b/crates/polars-arrow/src/compute/aggregate/sum.rs @@ -6,7 +6,7 @@ use polars_error::PolarsResult; use crate::array::{Array, PrimitiveArray}; use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; use crate::bitmap::Bitmap; -use crate::datatypes::{ArrowDataType, PhysicalType, PrimitiveType}; +use crate::datatypes::PhysicalType; use crate::scalar::*; use crate::types::simd::*; use crate::types::NativeType; @@ -102,19 +102,6 @@ where } } -/// Whether [`sum`] supports `dtype` -pub fn can_sum(dtype: &ArrowDataType) -> bool { - if let PhysicalType::Primitive(primitive) = dtype.to_physical_type() { - use PrimitiveType::*; - matches!( - primitive, - Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 - ) - } else { - false - } -} - /// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical /// and logical types as `array`. /// # Error diff --git a/crates/polars-arrow/src/compute/arity.rs b/crates/polars-arrow/src/compute/arity.rs index c5b397f6faac..2670cfb4d031 100644 --- a/crates/polars-arrow/src/compute/arity.rs +++ b/crates/polars-arrow/src/compute/arity.rs @@ -1,10 +1,7 @@ //! Defines kernels suitable to perform operations to primitive arrays. -use polars_error::PolarsResult; - use super::utils::{check_same_len, combine_validities_and}; use crate::array::PrimitiveArray; -use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::ArrowDataType; use crate::types::NativeType; @@ -29,104 +26,6 @@ where PrimitiveArray::::new(dtype, values.into(), array.validity().cloned()) } -/// Version of unary that checks for errors in the closure used to create the -/// buffer -pub fn try_unary( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> PolarsResult> -where - I: NativeType, - O: NativeType, - F: Fn(I) -> PolarsResult, -{ - let values = array - .values() - .iter() - .map(|v| op(*v)) - .collect::>>()? - .into(); - - Ok(PrimitiveArray::::new( - dtype, - values, - array.validity().cloned(), - )) -} - -/// Version of unary that returns an array and bitmap. Used when working with -/// overflowing operations -pub fn unary_with_bitmap( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> (PrimitiveArray, Bitmap) -where - I: NativeType, - O: NativeType, - F: Fn(I) -> (O, bool), -{ - let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); - - let values = array - .values() - .iter() - .map(|v| { - let (res, over) = op(*v); - mut_bitmap.push(over); - res - }) - .collect::>() - .into(); - - ( - PrimitiveArray::::new(dtype, values, array.validity().cloned()), - mut_bitmap.into(), - ) -} - -/// Version of unary that creates a mutable bitmap that is used to keep track -/// of checked operations. The resulting bitmap is compared with the array -/// bitmap to create the final validity array. -pub fn unary_checked( - array: &PrimitiveArray, - op: F, - dtype: ArrowDataType, -) -> PrimitiveArray -where - I: NativeType, - O: NativeType, - F: Fn(I) -> Option, -{ - let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); - - let values = array - .values() - .iter() - .map(|v| match op(*v) { - Some(val) => { - mut_bitmap.push(true); - val - }, - None => { - mut_bitmap.push(false); - O::default() - }, - }) - .collect::>() - .into(); - - // The validity has to be checked against the bitmap created during the - // creation of the values with the iterator. If an error was found during - // the iteration, then the validity is changed to None to mark the value - // as Null - let bitmap: Bitmap = mut_bitmap.into(); - let validity = combine_validities_and(array.validity(), Some(&bitmap)); - - PrimitiveArray::::new(dtype, values, validity) -} - /// Applies a binary operations to two primitive arrays. /// /// This is the fastest way to perform an operation on two primitive array when the benefits of a @@ -169,115 +68,3 @@ where PrimitiveArray::::new(dtype, values, validity) } - -/// Version of binary that checks for errors in the closure used to create the -/// buffer -pub fn try_binary( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> PolarsResult> -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> PolarsResult, -{ - check_same_len(lhs, rhs)?; - - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| op(*l, *r)) - .collect::>>()? - .into(); - - Ok(PrimitiveArray::::new(dtype, values, validity)) -} - -/// Version of binary that returns an array and bitmap. Used when working with -/// overflowing operations -pub fn binary_with_bitmap( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> (PrimitiveArray, Bitmap) -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> (T, bool), -{ - check_same_len(lhs, rhs).unwrap(); - - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - let (res, over) = op(*l, *r); - mut_bitmap.push(over); - res - }) - .collect::>() - .into(); - - ( - PrimitiveArray::::new(dtype, values, validity), - mut_bitmap.into(), - ) -} - -/// Version of binary that creates a mutable bitmap that is used to keep track -/// of checked operations. The resulting bitmap is compared with the array -/// bitmap to create the final validity array. -pub fn binary_checked( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - dtype: ArrowDataType, - op: F, -) -> PrimitiveArray -where - T: NativeType, - D: NativeType, - F: Fn(T, D) -> Option, -{ - check_same_len(lhs, rhs).unwrap(); - - let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| match op(*l, *r) { - Some(val) => { - mut_bitmap.push(true); - val - }, - None => { - mut_bitmap.push(false); - T::default() - }, - }) - .collect::>() - .into(); - - let bitmap: Bitmap = mut_bitmap.into(); - let validity = combine_validities_and(lhs.validity(), rhs.validity()); - - // The validity has to be checked against the bitmap created during the - // creation of the values with the iterator. If an error was found during - // the iteration, then the validity is changed to None to mark the value - // as Null - let validity = combine_validities_and(validity.as_ref(), Some(&bitmap)); - - PrimitiveArray::::new(dtype, values, validity) -} diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index e14a03040522..5d2bd3e0b6d9 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -92,19 +92,6 @@ pub fn binary_to_utf8( ) } -/// Conversion to utf8 -/// # Errors -/// This function errors if the values are not valid utf8 -pub fn binary_to_large_utf8( - from: &BinaryArray, - to_dtype: ArrowDataType, -) -> PolarsResult> { - let values = from.values().clone(); - let offsets = from.offsets().into(); - - Utf8Array::::try_new(to_dtype, offsets, values, from.validity().cloned()) -} - /// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. pub(super) fn binary_to_primitive( from: &BinaryArray, @@ -212,7 +199,7 @@ pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewAr // This is NOT equal to MAX_BYTES_PER_BUFFER because of integer division let split_point = num_elements_per_buffer * size; - // This is zero-copy for the buffer since split just increases the the data since + // This is zero-copy for the buffer since split just increases the data since let mut buffer = from.values().clone(); let mut buffers = Vec::with_capacity(num_buffers); for _ in 0..num_buffers - 1 { diff --git a/crates/polars-arrow/src/compute/cast/dictionary_to.rs b/crates/polars-arrow/src/compute/cast/dictionary_to.rs index d67a116ca0de..3f6211ad544e 100644 --- a/crates/polars-arrow/src/compute/cast/dictionary_to.rs +++ b/crates/polars-arrow/src/compute/cast/dictionary_to.rs @@ -1,6 +1,6 @@ use polars_error::{polars_bail, PolarsResult}; -use super::{primitive_as_primitive, primitive_to_primitive, CastOptionsImpl}; +use super::{primitive_to_primitive, CastOptionsImpl}; use crate::array::{Array, DictionaryArray, DictionaryKey}; use crate::compute::cast::cast; use crate::datatypes::ArrowDataType; @@ -23,101 +23,6 @@ macro_rules! key_cast { }}; } -/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] by keeping the -/// keys and casting the values to `values_type`. -/// # Errors -/// This function errors if the values are not castable to `values_type` -pub fn dictionary_to_dictionary_values( - from: &DictionaryArray, - values_type: &ArrowDataType, -) -> PolarsResult> { - let keys = from.keys(); - let values = from.values(); - let length = values.len(); - - let values = cast(values.as_ref(), values_type, CastOptionsImpl::default())?; - - assert_eq!(values.len(), length); // this is guaranteed by `cast` - unsafe { - DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) - } -} - -/// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped -pub fn wrapping_dictionary_to_dictionary_values( - from: &DictionaryArray, - values_type: &ArrowDataType, -) -> PolarsResult> { - let keys = from.keys(); - let values = from.values(); - let length = values.len(); - - let values = cast( - values.as_ref(), - values_type, - CastOptionsImpl { - wrapped: true, - partial: false, - }, - )?; - assert_eq!(values.len(), length); // this is guaranteed by `cast` - unsafe { - DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) - } -} - -/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] backed by a -/// different physical type of the keys, while keeping the values equal. -/// # Errors -/// Errors if any of the old keys' values is larger than the maximum value -/// supported by the new physical type. -pub fn dictionary_to_dictionary_keys( - from: &DictionaryArray, -) -> PolarsResult> -where - K1: DictionaryKey + num_traits::NumCast, - K2: DictionaryKey + num_traits::NumCast, -{ - let keys = from.keys(); - let values = from.values(); - let is_ordered = from.is_ordered(); - - let casted_keys = primitive_to_primitive::(keys, &K2::PRIMITIVE.into()); - - if casted_keys.null_count() > keys.null_count() { - polars_bail!(ComputeError: "overflow") - } else { - let dtype = - ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); - // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` - unsafe { DictionaryArray::try_new_unchecked(dtype, casted_keys, values.clone()) } - } -} - -/// Similar to dictionary_to_dictionary_keys, but overflowing cast is wrapped -pub fn wrapping_dictionary_to_dictionary_keys( - from: &DictionaryArray, -) -> PolarsResult> -where - K1: DictionaryKey + num_traits::AsPrimitive, - K2: DictionaryKey, -{ - let keys = from.keys(); - let values = from.values(); - let is_ordered = from.is_ordered(); - - let casted_keys = primitive_as_primitive::(keys, &K2::PRIMITIVE.into()); - - if casted_keys.null_count() > keys.null_count() { - polars_bail!(ComputeError: "overflow") - } else { - let dtype = - ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); - // some of the values may not fit in `usize` and thus this needs to be checked - DictionaryArray::try_new(dtype, casted_keys, values.clone()) - } -} - pub(super) fn dictionary_cast_dyn( array: &dyn Array, to_type: &ArrowDataType, diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 2a544fc7209f..f34d9ebba2a5 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -15,7 +15,7 @@ use binview_to::binview_to_primitive_dyn; pub use binview_to::utf8view_to_utf8; pub use boolean_to::*; pub use decimal_to::*; -pub use dictionary_to::*; +use dictionary_to::*; use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult}; use polars_utils::IdxSize; pub use primitive_to::*; @@ -338,16 +338,6 @@ pub fn cast( array.as_any().downcast_ref().unwrap(), ) .boxed()), - UInt8 => binview_to_primitive_dyn::(array, to_type, options), - UInt16 => binview_to_primitive_dyn::(array, to_type, options), - UInt32 => binview_to_primitive_dyn::(array, to_type, options), - UInt64 => binview_to_primitive_dyn::(array, to_type, options), - Int8 => binview_to_primitive_dyn::(array, to_type, options), - Int16 => binview_to_primitive_dyn::(array, to_type, options), - Int32 => binview_to_primitive_dyn::(array, to_type, options), - Int64 => binview_to_primitive_dyn::(array, to_type, options), - Float32 => binview_to_primitive_dyn::(array, to_type, options), - Float64 => binview_to_primitive_dyn::(array, to_type, options), LargeList(inner) if matches!(inner.dtype, ArrowDataType::UInt8) => { let bin_array = view_to_binary::(array.as_any().downcast_ref().unwrap()); Ok(binary_to_list(&bin_array, to_type.clone()).boxed()) @@ -404,17 +394,16 @@ pub fn cast( match to_type { BinaryView => Ok(arr.to_binview().boxed()), LargeUtf8 => Ok(binview_to::utf8view_to_utf8::(arr).boxed()), - UInt8 - | UInt16 - | UInt32 - | UInt64 - | Int8 - | Int16 - | Int32 - | Int64 - | Float32 - | Float64 - | Decimal(_, _) => cast(&arr.to_binview(), to_type, options), + UInt8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), Timestamp(time_unit, None) => { utf8view_to_naive_timestamp_dyn(array, time_unit.to_owned()) }, diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index d017b0a8e212..b0d84d81a3dd 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -8,10 +8,10 @@ use super::CastOptionsImpl; use crate::array::*; use crate::bitmap::Bitmap; use crate::compute::arity::unary; -use crate::datatypes::{ArrowDataType, IntervalUnit, TimeUnit}; +use crate::datatypes::{ArrowDataType, TimeUnit}; use crate::offset::{Offset, Offsets}; use crate::temporal_conversions::*; -use crate::types::{days_ms, f16, months_days_ns, NativeType}; +use crate::types::{f16, NativeType}; pub trait SerPrimitive { fn write(f: &mut Vec, val: Self) -> usize @@ -525,169 +525,6 @@ pub fn timestamp_to_timestamp( } } -fn timestamp_to_utf8_impl( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone: T, -) -> Utf8Array -where - T::Offset: std::fmt::Display, -{ - match time_unit { - TimeUnit::Nanosecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_ns_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Microsecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_us_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Millisecond => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_ms_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Second => { - let iter = from.iter().map(|x| { - x.map(|x| { - let datetime = timestamp_s_to_datetime(*x); - let offset = timezone.offset_from_utc_datetime(&datetime); - chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() - }) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - } -} - -#[cfg(feature = "chrono-tz")] -#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] -fn chrono_tz_timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - let timezone = parse_offset_tz(timezone_str)?; - Ok(timestamp_to_utf8_impl::( - from, time_unit, timezone, - )) -} - -#[cfg(not(feature = "chrono-tz"))] -fn chrono_tz_timestamp_to_utf8( - _: &PrimitiveArray, - _: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - panic!( - "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", - timezone_str - ) -} - -/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. -pub fn timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, - timezone_str: &str, -) -> PolarsResult> { - let timezone = parse_offset(timezone_str); - - if let Ok(timezone) = timezone { - Ok(timestamp_to_utf8_impl::( - from, time_unit, timezone, - )) - } else { - chrono_tz_timestamp_to_utf8(from, time_unit, timezone_str) - } -} - -/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. -pub fn naive_timestamp_to_utf8( - from: &PrimitiveArray, - time_unit: TimeUnit, -) -> Utf8Array { - match time_unit { - TimeUnit::Nanosecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_ns_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Microsecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_us_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Millisecond => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_ms_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - TimeUnit::Second => { - let iter = from.iter().map(|x| { - x.copied() - .map(timestamp_s_to_datetime) - .map(|x| x.to_string()) - }); - Utf8Array::from_trusted_len_iter(iter) - }, - } -} - -#[inline] -fn days_ms_to_months_days_ns_scalar(from: days_ms) -> months_days_ns { - months_days_ns::new(0, from.days(), from.milliseconds() as i64 * 1000) -} - -/// Casts [`days_ms`]s to [`months_days_ns`]. This operation is infalible and lossless. -pub fn days_ms_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { - unary( - from, - days_ms_to_months_days_ns_scalar, - ArrowDataType::Interval(IntervalUnit::MonthDayNano), - ) -} - -#[inline] -fn months_to_months_days_ns_scalar(from: i32) -> months_days_ns { - months_days_ns::new(from, 0, 0) -} - -/// Casts months represented as [`i32`]s to [`months_days_ns`]. This operation is infalible and lossless. -pub fn months_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { - unary( - from, - months_to_months_days_ns_scalar, - ArrowDataType::Interval(IntervalUnit::MonthDayNano), - ) -} - /// Casts f16 into f32 pub fn f16_to_f32(from: &PrimitiveArray) -> PrimitiveArray { unary(from, |x| x.to_f32(), ArrowDataType::Float32) diff --git a/crates/polars-arrow/src/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/compute/take/fixed_size_list.rs index 9eccd4bc043b..2a52a1ae3fd1 100644 --- a/crates/polars-arrow/src/compute/take/fixed_size_list.rs +++ b/crates/polars-arrow/src/compute/take/fixed_size_list.rs @@ -15,22 +15,34 @@ // specific language governing permissions and limitations // under the License. +use std::mem::ManuallyDrop; + +use polars_utils::itertools::Itertools; +use polars_utils::IdxSize; + use super::Index; use crate::array::growable::{Growable, GrowableFixedSizeList}; -use crate::array::{FixedSizeListArray, PrimitiveArray}; +use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray}; +use crate::bitmap::MutableBitmap; +use crate::compute::take::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; +use crate::compute::utils::combine_validities_and; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::legacy::prelude::FromData; +use crate::with_match_primitive_type; -/// `take` implementation for FixedSizeListArrays -pub(super) unsafe fn take_unchecked( +pub(super) unsafe fn take_unchecked_slow( values: &FixedSizeListArray, indices: &PrimitiveArray, ) -> FixedSizeListArray { + let take_len = std::cmp::min(values.len(), 1); let mut capacity = 0; let arrays = indices .values() .iter() .map(|index| { let index = index.to_usize(); - let slice = values.clone().sliced(index, 1); + let slice = values.clone().sliced_unchecked(index, take_len); capacity += slice.len(); slice }) @@ -61,3 +73,261 @@ pub(super) unsafe fn take_unchecked( growable.into() } } + +fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) { + if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype { + get_stride_and_leaf_type(inner.dtype(), *size_inner * size) + } else { + (size, dtype) + } +} + +fn get_leaves(array: &FixedSizeListArray) -> &dyn Array { + if let Some(array) = array.values().as_any().downcast_ref::() { + get_leaves(array) + } else { + &**array.values() + } +} + +fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) { + match array.dtype().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let arr = array.as_any().downcast_ref::>().unwrap(); + let values = arr.values(); + (bytemuck::cast_slice(values), size_of::<$T>()) + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn from_buffer(mut buf: ManuallyDrop>, dtype: &ArrowDataType) -> ArrayRef { + match dtype.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let ptr = buf.as_mut_ptr(); + let len_units = buf.len(); + let cap_units = buf.capacity(); + + let buf = Vec::from_raw_parts( + ptr as *mut $T, + len_units / size_of::<$T>(), + cap_units / size_of::<$T>(), + ); + + PrimitiveArray::from_data_default(buf.into(), None).boxed() + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn aligned_vec(dt: &ArrowDataType, n_bytes: usize) -> Vec { + match dt.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let n_units = (n_bytes / size_of::<$T>()) + 1; + + let mut aligned: Vec<$T> = Vec::with_capacity(n_units); + + let ptr = aligned.as_mut_ptr(); + let len_units = aligned.len(); + let cap_units = aligned.capacity(); + + std::mem::forget(aligned); + + Vec::from_raw_parts( + ptr as *mut u8, + len_units * size_of::<$T>(), + cap_units * size_of::<$T>(), + ) + + }), + _ => { + unimplemented!() + }, + } +} + +fn arr_no_validities_recursive(arr: &dyn Array) -> bool { + arr.validity().is_none() + && arr + .as_any() + .downcast_ref::() + .map_or(true, |x| arr_no_validities_recursive(x.values().as_ref())) +} + +/// `take` implementation for FixedSizeListArrays +pub(super) unsafe fn take_unchecked( + values: &FixedSizeListArray, + indices: &PrimitiveArray, +) -> ArrayRef { + let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); + if leaf_type.to_physical_type().is_primitive() + && arr_no_validities_recursive(values.values().as_ref()) + { + let leaves = get_leaves(values); + + let (leaves_buf, leave_size) = get_buffer_and_size(leaves); + let bytes_per_element = leave_size * stride; + + let n_idx = indices.len(); + let total_bytes = bytes_per_element * n_idx; + + let mut buf = ManuallyDrop::new(aligned_vec(leaves.dtype(), total_bytes)); + let dst = buf.spare_capacity_mut(); + + let mut count = 0; + let outer_validity = if indices.null_count() == 0 { + for i in indices.values().iter() { + let i = i.to_usize(); + + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + count += 1; + } + None + } else { + let mut new_validity = MutableBitmap::with_capacity(indices.len()); + new_validity.extend_constant(indices.len(), true); + for i in indices.iter() { + if let Some(i) = i { + let i = i.to_usize(); + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + } else { + new_validity.set_unchecked(count, false); + std::ptr::write_bytes( + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + 0, + bytes_per_element, + ); + } + + count += 1; + } + Some(new_validity.freeze()) + }; + + assert_eq!(count * bytes_per_element, total_bytes); + buf.set_len(total_bytes); + + let outer_validity = combine_validities_and( + outer_validity.as_ref(), + values + .validity() + .map(|x| { + if indices.has_nulls() { + take_bitmap_nulls_unchecked(x, indices) + } else { + take_bitmap_unchecked(x, indices.as_slice().unwrap()) + } + }) + .as_ref(), + ); + + let leaves = from_buffer(buf, leaves.dtype()); + let mut shape = values.get_dims(); + shape[0] = Dimension::new(indices.len() as _); + let shape = shape + .into_iter() + .map(ReshapeDimension::Specified) + .collect_vec(); + + FixedSizeListArray::from_shape(leaves.clone(), &shape) + .unwrap() + .with_validity(outer_validity) + } else { + take_unchecked_slow(values, indices).boxed() + } +} + +#[cfg(test)] +mod tests { + use crate::array::StaticArray; + use crate::datatypes::ArrowDataType; + + /// Test gather for FixedSizeListArray with outer validity but no inner validities. + #[test] + fn test_arr_gather_nulls_outer_validity_19482() { + use polars_utils::IdxSize; + + use super::take_unchecked; + use crate::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use crate::bitmap::Bitmap; + use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::from_slice([1, 2, 3, 4])), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap() + .with_validity(Some(Bitmap::from_iter([true, false]))); // FixedSizeListArray[[1, 2], None] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [true, false] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } + + #[test] + fn test_arr_gather_nulls_inner_validity() { + use polars_utils::IdxSize; + + use super::take_unchecked; + use crate::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::full_null(4, ArrowDataType::Int64)), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap(); // FixedSizeListArray[[None, None], [None, None]] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [false, true] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } +} diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index aed14823af1e..bdd782a1d609 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -68,7 +68,7 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { let array = values.as_any().downcast_ref().unwrap(); - Box::new(fixed_size_list::take_unchecked(array, indices)) + fixed_size_list::take_unchecked(array, indices) }, BinaryView => { take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed() diff --git a/crates/polars-arrow/src/compute/temporal.rs b/crates/polars-arrow/src/compute/temporal.rs index 309493fbbbdb..98bc920caeb3 100644 --- a/crates/polars-arrow/src/compute/temporal.rs +++ b/crates/polars-arrow/src/compute/temporal.rs @@ -75,8 +75,6 @@ macro_rules! date_like { } /// Extracts the years of a temporal array as [`PrimitiveArray`]. -/// -/// Use [`can_year`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn year(array: &dyn Array) -> PolarsResult> { date_like!(year, array, ArrowDataType::Int32) } @@ -84,7 +82,6 @@ pub fn year(array: &dyn Array) -> PolarsResult> { /// Extracts the months of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 12. -/// Use [`can_month`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn month(array: &dyn Array) -> PolarsResult> { date_like!(month, array, ArrowDataType::Int8) } @@ -92,7 +89,6 @@ pub fn month(array: &dyn Array) -> PolarsResult> { /// Extracts the days of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 32 (Last day depends on month). -/// Use [`can_day`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn day(array: &dyn Array) -> PolarsResult> { date_like!(day, array, ArrowDataType::Int8) } @@ -100,7 +96,6 @@ pub fn day(array: &dyn Array) -> PolarsResult> { /// Extracts weekday of a temporal array as [`PrimitiveArray`]. /// /// Monday is 1, Tuesday is 2, ..., Sunday is 7. -/// Use [`can_weekday`] to check if this operation is supported for the target [`ArrowDataType`] pub fn weekday(array: &dyn Array) -> PolarsResult> { date_like!(i8_weekday, array, ArrowDataType::Int8) } @@ -108,7 +103,6 @@ pub fn weekday(array: &dyn Array) -> PolarsResult> { /// Extracts ISO week of a temporal array as [`PrimitiveArray`]. /// /// Value ranges from 1 to 53 (Last week depends on the year). -/// Use [`can_iso_week`] to check if this operation is supported for the target [`ArrowDataType`] pub fn iso_week(array: &dyn Array) -> PolarsResult> { date_like!(i8_iso_week, array, ArrowDataType::Int8) } @@ -345,86 +339,3 @@ where }, } } - -/// Checks if an array of type `datatype` can perform year operation -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::temporal::can_year; -/// use polars_arrow::datatypes::{ArrowDataType}; -/// -/// assert_eq!(can_year(&ArrowDataType::Date32), true); -/// assert_eq!(can_year(&ArrowDataType::Int8), false); -/// ``` -pub fn can_year(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `datatype` can perform month operation -pub fn can_month(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `datatype` can perform day operation -pub fn can_day(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `dtype` can perform weekday operation -pub fn can_weekday(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -/// Checks if an array of type `dtype` can perform ISO week operation -pub fn can_iso_week(dtype: &ArrowDataType) -> bool { - can_date(dtype) -} - -fn can_date(dtype: &ArrowDataType) -> bool { - matches!( - dtype, - ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, _) - ) -} - -/// Checks if an array of type `datatype` can perform hour operation -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::temporal::can_hour; -/// use polars_arrow::datatypes::{ArrowDataType, TimeUnit}; -/// -/// assert_eq!(can_hour(&ArrowDataType::Time32(TimeUnit::Second)), true); -/// assert_eq!(can_hour(&ArrowDataType::Int8), false); -/// ``` -pub fn can_hour(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform minute operation -pub fn can_minute(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform second operation -pub fn can_second(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -/// Checks if an array of type `datatype` can perform nanosecond operation -pub fn can_nanosecond(dtype: &ArrowDataType) -> bool { - can_time(dtype) -} - -fn can_time(dtype: &ArrowDataType) -> bool { - matches!( - dtype, - ArrowDataType::Time32(TimeUnit::Second) - | ArrowDataType::Time32(TimeUnit::Millisecond) - | ArrowDataType::Time64(TimeUnit::Microsecond) - | ArrowDataType::Time64(TimeUnit::Nanosecond) - | ArrowDataType::Date32 - | ArrowDataType::Date64 - | ArrowDataType::Timestamp(_, _) - ) -} diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index 0b8e1ecd69f4..2402283aa4ef 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -38,6 +38,7 @@ pub fn combine_validities_or(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> _ => None, } } + pub fn combine_validities_and_not( opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>, diff --git a/crates/polars-arrow/src/datatypes/field.rs b/crates/polars-arrow/src/datatypes/field.rs index 8bf18af82f46..b1a5baf5c0ee 100644 --- a/crates/polars-arrow/src/datatypes/field.rs +++ b/crates/polars-arrow/src/datatypes/field.rs @@ -60,60 +60,3 @@ impl Field { &self.dtype } } - -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::Field { - fn from(value: Field) -> Self { - Self::new( - value.name.to_string(), - value.dtype.into(), - value.is_nullable, - ) - .with_metadata( - value - .metadata - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(), - ) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Field { - fn from(value: arrow_schema::Field) -> Self { - (&value).into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&arrow_schema::Field> for Field { - fn from(value: &arrow_schema::Field) -> Self { - let dtype = value.data_type().clone().into(); - let metadata = value - .metadata() - .iter() - .map(|(k, v)| (PlSmallStr::from_str(k), PlSmallStr::from_str(v))) - .collect(); - Self::new( - PlSmallStr::from_str(value.name().as_str()), - dtype, - value.is_nullable(), - ) - .with_metadata(metadata) - } -} - -#[cfg(feature = "arrow_rs")] -impl From for Field { - fn from(value: arrow_schema::FieldRef) -> Self { - value.as_ref().into() - } -} - -#[cfg(feature = "arrow_rs")] -impl From<&arrow_schema::FieldRef> for Field { - fn from(value: &arrow_schema::FieldRef) -> Self { - value.as_ref().into() - } -} diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 8f2226c709e6..0c0b7024bc71 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -2,6 +2,7 @@ mod field; mod physical_type; +pub mod reshape; mod schema; use std::collections::BTreeMap; @@ -175,143 +176,6 @@ pub enum ArrowDataType { Unknown, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::DataType { - fn from(value: ArrowDataType) -> Self { - use arrow_schema::{Field as ArrowField, UnionFields}; - - match value { - ArrowDataType::Null => Self::Null, - ArrowDataType::Boolean => Self::Boolean, - ArrowDataType::Int8 => Self::Int8, - ArrowDataType::Int16 => Self::Int16, - ArrowDataType::Int32 => Self::Int32, - ArrowDataType::Int64 => Self::Int64, - ArrowDataType::UInt8 => Self::UInt8, - ArrowDataType::UInt16 => Self::UInt16, - ArrowDataType::UInt32 => Self::UInt32, - ArrowDataType::UInt64 => Self::UInt64, - ArrowDataType::Float16 => Self::Float16, - ArrowDataType::Float32 => Self::Float32, - ArrowDataType::Float64 => Self::Float64, - ArrowDataType::Timestamp(unit, tz) => { - Self::Timestamp(unit.into(), tz.map(|x| Arc::::from(x.as_str()))) - }, - ArrowDataType::Date32 => Self::Date32, - ArrowDataType::Date64 => Self::Date64, - ArrowDataType::Time32(unit) => Self::Time32(unit.into()), - ArrowDataType::Time64(unit) => Self::Time64(unit.into()), - ArrowDataType::Duration(unit) => Self::Duration(unit.into()), - ArrowDataType::Interval(unit) => Self::Interval(unit.into()), - ArrowDataType::Binary => Self::Binary, - ArrowDataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), - ArrowDataType::LargeBinary => Self::LargeBinary, - ArrowDataType::Utf8 => Self::Utf8, - ArrowDataType::LargeUtf8 => Self::LargeUtf8, - ArrowDataType::List(f) => Self::List(Arc::new((*f).into())), - ArrowDataType::FixedSizeList(f, size) => { - Self::FixedSizeList(Arc::new((*f).into()), size as _) - }, - ArrowDataType::LargeList(f) => Self::LargeList(Arc::new((*f).into())), - ArrowDataType::Struct(f) => Self::Struct(f.into_iter().map(ArrowField::from).collect()), - ArrowDataType::Union(fields, Some(ids), mode) => { - let ids = ids.into_iter().map(|x| x as _); - let fields = fields.into_iter().map(ArrowField::from); - Self::Union(UnionFields::new(ids, fields), mode.into()) - }, - ArrowDataType::Union(fields, None, mode) => { - let ids = 0..fields.len() as i8; - let fields = fields.into_iter().map(ArrowField::from); - Self::Union(UnionFields::new(ids, fields), mode.into()) - }, - ArrowDataType::Map(f, ordered) => Self::Map(Arc::new((*f).into()), ordered), - ArrowDataType::Dictionary(key, value, _) => Self::Dictionary( - Box::new(ArrowDataType::from(key).into()), - Box::new((*value).into()), - ), - ArrowDataType::Decimal(precision, scale) => { - Self::Decimal128(precision as _, scale as _) - }, - ArrowDataType::Decimal256(precision, scale) => { - Self::Decimal256(precision as _, scale as _) - }, - ArrowDataType::Extension(_, d, _) => (*d).into(), - ArrowDataType::BinaryView | ArrowDataType::Utf8View => { - panic!("view datatypes not supported by arrow-rs") - }, - ArrowDataType::Unknown => unimplemented!(), - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for ArrowDataType { - fn from(value: arrow_schema::DataType) -> Self { - use arrow_schema::DataType; - match value { - DataType::Null => Self::Null, - DataType::Boolean => Self::Boolean, - DataType::Int8 => Self::Int8, - DataType::Int16 => Self::Int16, - DataType::Int32 => Self::Int32, - DataType::Int64 => Self::Int64, - DataType::UInt8 => Self::UInt8, - DataType::UInt16 => Self::UInt16, - DataType::UInt32 => Self::UInt32, - DataType::UInt64 => Self::UInt64, - DataType::Float16 => Self::Float16, - DataType::Float32 => Self::Float32, - DataType::Float64 => Self::Float64, - DataType::Timestamp(unit, tz) => { - Self::Timestamp(unit.into(), tz.map(|x| PlSmallStr::from_str(x.as_ref()))) - }, - DataType::Date32 => Self::Date32, - DataType::Date64 => Self::Date64, - DataType::Time32(unit) => Self::Time32(unit.into()), - DataType::Time64(unit) => Self::Time64(unit.into()), - DataType::Duration(unit) => Self::Duration(unit.into()), - DataType::Interval(unit) => Self::Interval(unit.into()), - DataType::Binary => Self::Binary, - DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), - DataType::LargeBinary => Self::LargeBinary, - DataType::Utf8 => Self::Utf8, - DataType::LargeUtf8 => Self::LargeUtf8, - DataType::List(f) => Self::List(Box::new(f.into())), - DataType::FixedSizeList(f, size) => Self::FixedSizeList(Box::new(f.into()), size as _), - DataType::LargeList(f) => Self::LargeList(Box::new(f.into())), - DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()), - DataType::Union(fields, mode) => { - let ids = fields.iter().map(|(x, _)| x as _).collect(); - let fields = fields.iter().map(|(_, f)| f.into()).collect(); - Self::Union(fields, Some(ids), mode.into()) - }, - DataType::Map(f, ordered) => Self::Map(Box::new(f.into()), ordered), - DataType::Dictionary(key, value) => { - let key = match *key { - DataType::Int8 => IntegerType::Int8, - DataType::Int16 => IntegerType::Int16, - DataType::Int32 => IntegerType::Int32, - DataType::Int64 => IntegerType::Int64, - DataType::UInt8 => IntegerType::UInt8, - DataType::UInt16 => IntegerType::UInt16, - DataType::UInt32 => IntegerType::UInt32, - DataType::UInt64 => IntegerType::UInt64, - d => panic!("illegal dictionary key type: {d}"), - }; - Self::Dictionary(key, Box::new((*value).into()), false) - }, - DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _), - DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), - DataType::RunEndEncoded(_, _) => { - panic!("Run-end encoding not supported by polars_arrow") - }, - // This ensures that it doesn't fail to compile when new variants are added to Arrow - #[allow(unreachable_patterns)] - dtype => unimplemented!("unsupported datatype: {dtype}"), - } - } -} - /// Mode of [`ArrowDataType::Union`] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -322,26 +186,6 @@ pub enum UnionMode { Sparse, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::UnionMode { - fn from(value: UnionMode) -> Self { - match value { - UnionMode::Dense => Self::Dense, - UnionMode::Sparse => Self::Sparse, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for UnionMode { - fn from(value: arrow_schema::UnionMode) -> Self { - match value { - arrow_schema::UnionMode::Dense => Self::Dense, - arrow_schema::UnionMode::Sparse => Self::Sparse, - } - } -} - impl UnionMode { /// Constructs a [`UnionMode::Sparse`] if the input bool is true, /// or otherwise constructs a [`UnionMode::Dense`] @@ -378,30 +222,6 @@ pub enum TimeUnit { Nanosecond, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::TimeUnit { - fn from(value: TimeUnit) -> Self { - match value { - TimeUnit::Nanosecond => Self::Nanosecond, - TimeUnit::Millisecond => Self::Millisecond, - TimeUnit::Microsecond => Self::Microsecond, - TimeUnit::Second => Self::Second, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for TimeUnit { - fn from(value: arrow_schema::TimeUnit) -> Self { - match value { - arrow_schema::TimeUnit::Nanosecond => Self::Nanosecond, - arrow_schema::TimeUnit::Millisecond => Self::Millisecond, - arrow_schema::TimeUnit::Microsecond => Self::Microsecond, - arrow_schema::TimeUnit::Second => Self::Second, - } - } -} - /// Interval units defined in Arrow #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -415,28 +235,6 @@ pub enum IntervalUnit { MonthDayNano, } -#[cfg(feature = "arrow_rs")] -impl From for arrow_schema::IntervalUnit { - fn from(value: IntervalUnit) -> Self { - match value { - IntervalUnit::YearMonth => Self::YearMonth, - IntervalUnit::DayTime => Self::DayTime, - IntervalUnit::MonthDayNano => Self::MonthDayNano, - } - } -} - -#[cfg(feature = "arrow_rs")] -impl From for IntervalUnit { - fn from(value: arrow_schema::IntervalUnit) -> Self { - match value { - arrow_schema::IntervalUnit::YearMonth => Self::YearMonth, - arrow_schema::IntervalUnit::DayTime => Self::DayTime, - arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano, - } - } -} - impl ArrowDataType { /// the [`PhysicalType`] of this [`ArrowDataType`]. pub fn to_physical_type(&self) -> PhysicalType { @@ -568,6 +366,25 @@ impl ArrowDataType { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } + pub fn is_numeric(&self) -> bool { + use ArrowDataType as D; + matches!( + self, + D::Int8 + | D::Int16 + | D::Int32 + | D::Int64 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Float32 + | D::Float64 + | D::Decimal(_, _) + | D::Decimal256(_, _) + ) + } + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { ArrowDataType::FixedSizeList( Box::new(Field::new( diff --git a/crates/polars-arrow/src/datatypes/physical_type.rs b/crates/polars-arrow/src/datatypes/physical_type.rs index 174c0401ca3f..732a129055a6 100644 --- a/crates/polars-arrow/src/datatypes/physical_type.rs +++ b/crates/polars-arrow/src/datatypes/physical_type.rs @@ -57,6 +57,10 @@ impl PhysicalType { false } } + + pub fn is_primitive(&self) -> bool { + matches!(self, Self::Primitive(_)) + } } /// the set of valid indices types of a dictionary-encoded Array. diff --git a/crates/polars-core/src/datatypes/reshape.rs b/crates/polars-arrow/src/datatypes/reshape.rs similarity index 100% rename from crates/polars-core/src/datatypes/reshape.rs rename to crates/polars-arrow/src/datatypes/reshape.rs diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs index 7f7e9c409782..5e179be6cca7 100644 --- a/crates/polars-arrow/src/ffi/array.rs +++ b/crates/polars-arrow/src/ffi/array.rs @@ -104,6 +104,7 @@ impl ArrowArray { ArrowDataType::BinaryView | ArrowDataType::Utf8View ); + #[allow(unused_mut)] let (offset, mut buffers, children, dictionary) = offset_buffers_children_dictionary(array.as_ref()); @@ -224,11 +225,7 @@ unsafe fn get_buffer_ptr( ); } - if array - .buffers - .align_offset(std::mem::align_of::<*mut *const u8>()) - != 0 - { + if array.buffers.align_offset(align_of::<*mut *const u8>()) != 0 { polars_bail!( ComputeError: "an ArrowArray of type {dtype:?} must have buffer {index} aligned to type {}", @@ -293,7 +290,7 @@ unsafe fn create_buffer( // We have to check alignment. // This is the zero-copy path. - if ptr.align_offset(std::mem::align_of::()) == 0 { + if ptr.align_offset(align_of::()) == 0 { let storage = SharedStorage::from_internal_arrow_array(ptr, len, owner); Ok(Buffer::from_storage(storage).sliced(offset, len - offset)) } @@ -623,7 +620,7 @@ pub struct ArrowArrayChild<'a> { parent: InternalArrowArray, } -impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { +impl ArrowArrayRef for ArrowArrayChild<'_> { /// the dtype as declared in the schema fn dtype(&self) -> &ArrowDataType { &self.dtype diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs index 7c626f2b74ee..f2f8af90c167 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -195,9 +195,8 @@ fn deserialize_value<'a>( array.push(Some(value)) }, PrimitiveType::Float32 => { - let value = - f32::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); - block = &block[std::mem::size_of::()..]; + let value = f32::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; let array = array .as_mut_any() .downcast_mut::>() @@ -205,9 +204,8 @@ fn deserialize_value<'a>( array.push(Some(value)) }, PrimitiveType::Float64 => { - let value = - f64::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); - block = &block[std::mem::size_of::()..]; + let value = f64::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; let array = array .as_mut_any() .downcast_mut::>() @@ -404,10 +402,10 @@ fn skip_item<'a>( let _ = util::zigzag_i64(&mut block)?; }, PrimitiveType::Float32 => { - block = &block[std::mem::size_of::()..]; + block = &block[size_of::()..]; }, PrimitiveType::Float64 => { - block = &block[std::mem::size_of::()..]; + block = &block[size_of::()..]; }, PrimitiveType::MonthDayNano => { block = &block[12..]; diff --git a/crates/polars-arrow/src/io/flight/mod.rs b/crates/polars-arrow/src/io/flight/mod.rs deleted file mode 100644 index c02a4889f7bb..000000000000 --- a/crates/polars-arrow/src/io/flight/mod.rs +++ /dev/null @@ -1,241 +0,0 @@ -//! Serialization and deserialization to Arrow's flight protocol - -use arrow_format::flight::data::{FlightData, SchemaResult}; -use arrow_format::ipc; -use arrow_format::ipc::planus::ReadAsRoot; -use polars_error::{polars_bail, polars_err, PolarsResult}; - -use super::ipc::read::Dictionaries; -pub use super::ipc::write::default_ipc_fields; -use super::ipc::{IpcField, IpcSchema}; -use crate::array::Array; -use crate::datatypes::*; -pub use crate::io::ipc::write::common::WriteOptions; -use crate::io::ipc::write::common::{encode_chunk, DictionaryTracker, EncodedData}; -use crate::io::ipc::{read, write}; -use crate::record_batch::RecordBatchT; - -/// Serializes [`RecordBatchT`] to a vector of [`FlightData`] representing the serialized dictionaries -/// and a [`FlightData`] representing the batch. -/// # Errors -/// This function errors iff `fields` is not consistent with `columns` -pub fn serialize_batch( - chunk: &RecordBatchT>, - fields: &[IpcField], - options: &WriteOptions, -) -> PolarsResult<(Vec, FlightData)> { - if fields.len() != chunk.arrays().len() { - polars_bail!(oos = "The argument `fields` must be consistent with the columns' schema. Use e.g. &polars_arrow::io::flight::default_ipc_fields(&schema.fields)"); - } - - let mut dictionary_tracker = DictionaryTracker { - dictionaries: Default::default(), - cannot_replace: false, - }; - - let (encoded_dictionaries, encoded_batch) = - encode_chunk(chunk, fields, &mut dictionary_tracker, options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - let flight_batch = encoded_batch.into(); - - Ok((flight_dictionaries, flight_batch)) -} - -impl From for FlightData { - fn from(data: EncodedData) -> Self { - FlightData { - data_header: data.ipc_message, - data_body: data.arrow_data, - ..Default::default() - } - } -} - -/// Serializes a [`ArrowSchema`] to [`SchemaResult`]. -pub fn serialize_schema_to_result( - schema: &ArrowSchema, - ipc_fields: Option<&[IpcField]>, -) -> SchemaResult { - SchemaResult { - schema: _serialize_schema(schema, ipc_fields), - } -} - -/// Serializes a [`ArrowSchema`] to [`FlightData`]. -pub fn serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> FlightData { - FlightData { - data_header: _serialize_schema(schema, ipc_fields), - ..Default::default() - } -} - -/// Convert a [`ArrowSchema`] to bytes in the format expected in [`arrow_format::flight::data::FlightInfo`]. -pub fn serialize_schema_to_info( - schema: &ArrowSchema, - ipc_fields: Option<&[IpcField]>, -) -> PolarsResult> { - let encoded_data = if let Some(ipc_fields) = ipc_fields { - schema_as_encoded_data(schema, ipc_fields) - } else { - let ipc_fields = default_ipc_fields(schema.iter_values()); - schema_as_encoded_data(schema, &ipc_fields) - }; - - let mut schema = vec![]; - write::common_sync::write_message(&mut schema, &encoded_data)?; - Ok(schema) -} - -fn _serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> Vec { - if let Some(ipc_fields) = ipc_fields { - write::schema_to_bytes(schema, ipc_fields) - } else { - let ipc_fields = default_ipc_fields(schema.iter_values()); - write::schema_to_bytes(schema, &ipc_fields) - } -} - -fn schema_as_encoded_data(schema: &ArrowSchema, ipc_fields: &[IpcField]) -> EncodedData { - EncodedData { - ipc_message: write::schema_to_bytes(schema, ipc_fields), - arrow_data: vec![], - } -} - -/// Deserialize an IPC message into [`ArrowSchema`], [`IpcSchema`]. -/// Use to deserialize [`FlightData::data_header`] and [`SchemaResult::schema`]. -pub fn deserialize_schemas(bytes: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema)> { - read::deserialize_schema(bytes) -} - -/// Deserializes [`FlightData`] representing a record batch message to [`RecordBatchT`]. -pub fn deserialize_batch( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &read::Dictionaries, -) -> PolarsResult>> { - // check that the data_header is a record batch message - let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header) - .map_err(|_err| polars_err!(oos = "Unable to get root as message: {err:?}"))?; - - let length = data.data_body.len(); - let mut reader = std::io::Cursor::new(&data.data_body); - - match message.header()?.ok_or_else(|| { - polars_err!(oos = "Unable to convert flight data header to a record batch".to_string()) - })? { - ipc::MessageHeaderRef::RecordBatch(batch) => read::read_record_batch( - batch, - fields, - ipc_schema, - None, - None, - dictionaries, - message.version()?, - &mut reader, - 0, - length as u64, - &mut Default::default(), - ), - _ => polars_bail!(oos = "flight currently only supports reading RecordBatch messages"), - } -} - -/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`. -pub fn deserialize_dictionary( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &mut read::Dictionaries, -) -> PolarsResult<()> { - let message = ipc::MessageRef::read_as_root(&data.data_header)?; - - let chunk = if let ipc::MessageHeaderRef::DictionaryBatch(chunk) = message - .header()? - .ok_or_else(|| polars_err!(oos = "Header is required"))? - { - chunk - } else { - return Ok(()); - }; - - let length = data.data_body.len(); - let mut reader = std::io::Cursor::new(&data.data_body); - read::read_dictionary( - chunk, - fields, - ipc_schema, - dictionaries, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - - Ok(()) -} - -/// Deserializes [`FlightData`] into either a [`RecordBatchT`] (when the message is a record batch) -/// or by upserting into `dictionaries` (when the message is a dictionary) -pub fn deserialize_message( - data: &FlightData, - fields: &ArrowSchema, - ipc_schema: &IpcSchema, - dictionaries: &mut Dictionaries, -) -> PolarsResult>>> { - let FlightData { - data_header, - data_body, - .. - } = data; - - let message = arrow_format::ipc::MessageRef::read_as_root(data_header)?; - let header = message - .header()? - .ok_or_else(|| polars_err!(oos = "IPC Message must contain a header"))?; - - match header { - ipc::MessageHeaderRef::RecordBatch(batch) => { - let length = data_body.len(); - let mut reader = std::io::Cursor::new(data_body); - - let chunk = read::read_record_batch( - batch, - fields, - ipc_schema, - None, - None, - dictionaries, - arrow_format::ipc::MetadataVersion::V5, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - - Ok(chunk.into()) - }, - ipc::MessageHeaderRef::DictionaryBatch(dict_batch) => { - let length = data_body.len(); - let mut reader = std::io::Cursor::new(data_body); - - read::read_dictionary( - dict_batch, - fields, - ipc_schema, - dictionaries, - &mut reader, - 0, - length as u64, - &mut Default::default(), - )?; - Ok(None) - }, - t => polars_bail!(ComputeError: - "Reading types other than record batches not yet supported, unable to read {t:?}" - ), - } -} diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index fbb9155149bd..6b893c0e8ce3 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -42,7 +42,7 @@ impl<'a, A, I: Iterator> ProjectionIter<'a, A, I> { } } -impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { +impl> Iterator for ProjectionIter<'_, A, I> { type Item = ProjectionResult; fn next(&mut self) -> Option { diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index 6c831064d5a1..a83e1b758d80 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -9,7 +9,7 @@ use polars_utils::aliases::{InitHashMaps, PlHashMap}; use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; use super::common::*; use super::schema::fb_to_schema; -use super::{Dictionaries, OutOfSpecKind}; +use super::{Dictionaries, OutOfSpecKind, SendableIterator}; use crate::array::Array; use crate::datatypes::ArrowSchemaRef; use crate::io::ipc::IpcSchema; @@ -129,14 +129,7 @@ pub fn read_file_dictionaries( Ok(dictionaries) } -/// Reads the footer's length and magic number in footer -fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> { - // read footer length and magic number in footer - let end = reader.seek(SeekFrom::End(-10))? + 10; - - let mut footer: [u8; 10] = [0; 10]; - - reader.read_exact(&mut footer)?; +pub(super) fn decode_footer_len(footer: [u8; 10], end: u64) -> PolarsResult<(u64, usize)> { let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC_V2 { @@ -152,6 +145,17 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> Ok((end, footer_len)) } +/// Reads the footer's length and magic number in footer +fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10))? + 10; + + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer)?; + decode_footer_len(footer, end) +} + fn read_footer(reader: &mut R, footer_len: usize) -> PolarsResult> { // read footer reader.seek(SeekFrom::End(-10 - footer_len as i64))?; @@ -187,29 +191,61 @@ fn deserialize_footer_blocks( Ok((footer, blocks)) } -pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { - let (footer, blocks) = deserialize_footer_blocks(footer_data)?; +pub(super) fn deserialize_footer_ref(footer_data: &[u8]) -> PolarsResult { + arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err))) +} - let ipc_schema = footer +pub(super) fn deserialize_schema_ref_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult { + footer .schema() .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferSchema(err)))? - .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingSchema))?; - let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingSchema)) +} + +/// Get the IPC blocks from the footer containing record batches +pub(super) fn iter_recordbatch_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_> { + let blocks = footer + .record_batches() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingRecordBatches))?; + + Ok(blocks.iter().map(|block| { + block + .try_into() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + })) +} +pub(super) fn iter_dictionary_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_>> +{ let dictionaries = footer .dictionaries() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))? - .map(|dictionaries| { - dictionaries - .into_iter() - .map(|block| { - block.try_into().map_err(|err| { - polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) - }) - }) - .collect::>>() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))?; + + Ok(dictionaries.map(|dicts| { + dicts.into_iter().map(|block| { + block.try_into().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) }) + })) +} + +pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { + let footer = deserialize_footer_ref(footer_data)?; + let blocks = iter_recordbatch_blocks_from_footer(footer)?.collect::>>()?; + let dictionaries = iter_dictionary_blocks_from_footer(footer)? + .map(|dicts| dicts.collect::>>()) .transpose()?; + let ipc_schema = deserialize_schema_ref_from_footer(footer)?; + let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; Ok(FileMetadata { schema: Arc::new(schema), diff --git a/crates/polars-arrow/src/io/ipc/read/flight.rs b/crates/polars-arrow/src/io/ipc/read/flight.rs new file mode 100644 index 000000000000..8c35fa1cfd60 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/flight.rs @@ -0,0 +1,457 @@ +use std::io::SeekFrom; +use std::pin::Pin; +use std::sync::Arc; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef}; +use futures::{Stream, StreamExt}; +use polars_error::{polars_bail, polars_err, PolarsResult}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; + +use crate::datatypes::ArrowSchema; +use crate::io::ipc::read::common::read_record_batch; +use crate::io::ipc::read::file::{ + decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer, + iter_recordbatch_blocks_from_footer, +}; +use crate::io::ipc::read::schema::deserialize_stream_metadata; +use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata}; +use crate::io::ipc::write::common::EncodedData; +use crate::mmap::{mmap_dictionary_from_batch, mmap_record}; +use crate::record_batch::RecordBatch; + +async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>( + reader: &mut R, + block: &arrow_format::ipc::Block, + scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + reader.seek(SeekFrom::Start(offset)).await?; + read_ipc_message(reader, scratch).await +} + +/// Read an encapsulated IPC Message from the reader +async fn read_ipc_message<'a, R: AsyncRead + Unpin>( + reader: &mut R, + scratch: &'a mut Vec, +) -> PolarsResult> { + let mut message_size: [u8; 4] = [0; 4]; + + reader.read_exact(&mut message_size).await?; + if message_size == crate::io::ipc::CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + }; + let message_length = i32::from_le_bytes(message_size); + + let message_length: usize = message_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + scratch.clear(); + scratch.try_reserve(message_length)?; + reader + .take(message_length as u64) + .read_to_end(scratch) + .await?; + + arrow_format::ipc::MessageRef::read_as_root(scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +async fn read_footer_len( + reader: &mut R, +) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10)).await? + 10; + + let mut footer: [u8; 10] = [0; 10]; + reader.read_exact(&mut footer).await?; + + decode_footer_len(footer, end) +} + +async fn read_footer( + reader: &mut R, + footer_len: usize, +) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64)).await?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + + reader + .take(footer_len as u64) + .read_to_end(&mut serialized_footer) + .await?; + Ok(serialized_footer) +} + +fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData { + // Turn the IPC schema into an encapsulated message + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + // Assumed the conversion is infallible. + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new( + schema.try_into().unwrap(), + ))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = arrow_format::ipc::planus::Builder::new(); + let header = builder.finish(&message, None).to_vec(); + + // Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use + // `data_header` and `data_body`. + EncodedData { + ipc_message: header, + arrow_data: vec![], + } +} + +async fn block_to_raw_message<'a, R>( + reader: &mut R, + block: &arrow_format::ipc::Block, + encoded_data: &mut EncodedData, +) -> PolarsResult<()> +where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, +{ + debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty()); + let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?; + + let block_length: u64 = message + .body_length() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + reader + .take(block_length) + .read_to_end(&mut encoded_data.arrow_data) + .await?; + + Ok(()) +} + +pub async fn into_flight_stream( + reader: &mut R, +) -> PolarsResult> + '_> { + Ok(async_stream::try_stream! { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + let data_blocks = iter_recordbatch_blocks_from_footer(footer)?; + let dict_blocks = iter_dictionary_blocks_from_footer(footer)?; + + let schema_ref = deserialize_schema_ref_from_footer(footer)?; + let schema = schema_to_raw_message(schema_ref); + + yield schema; + + if let Some(dict_blocks_iter) = dict_blocks { + for d in dict_blocks_iter { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }; + + for d in data_blocks { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }) +} + +pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> { + footer: Option<*const FooterRef<'static>>, + footer_data: Vec, + dict_blocks: Option>>>, + data_blocks: Option>>>, + reader: &'a mut R, +} + +impl Drop for FlightStreamProducer<'_, R> { + fn drop(&mut self) { + if let Some(p) = self.footer { + unsafe { + let _ = Box::from_raw(p as *mut FooterRef<'static>); + } + } + } +} + +unsafe impl Send for FlightStreamProducer<'_, R> {} + +impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { + pub async fn new(reader: &'a mut R) -> PolarsResult>> { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + + Ok(Box::pin(Self { + footer: None, + footer_data, + dict_blocks: None, + data_blocks: None, + reader, + })) + } + + pub fn init(self: &mut Pin>) -> PolarsResult<()> { + let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let footer = Box::new(footer); + + #[allow(clippy::unnecessary_cast)] + let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>; + + self.footer = Some(ptr); + let footer = &unsafe { **self.footer.as_ref().unwrap() }; + + self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?) + as Box>); + self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)? + .map(|i| Box::new(i) as Box>); + + Ok(()) + } + + pub fn get_schema(self: &Pin>) -> PolarsResult { + let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") }; + + let schema_ref = deserialize_schema_ref_from_footer(*footer)?; + let schema = schema_to_raw_message(schema_ref); + + Ok(schema) + } + + pub async fn next_dict( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + assert!(self.data_blocks.is_some(), "init must be called first"); + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + if let Some(iter) = &mut self.dict_blocks { + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } else { + Ok(None) + } + } + + pub async fn next_data( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + let iter = self + .data_blocks + .as_mut() + .expect("init must be called first"); + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } +} + +pub struct FlightConsumer { + dictionaries: Dictionaries, + md: StreamMetadata, + scratch: Vec, +} + +impl FlightConsumer { + pub fn new(first: EncodedData) -> PolarsResult { + let md = deserialize_stream_metadata(&first.ipc_message)?; + Ok(Self { + dictionaries: Default::default(), + md, + scratch: vec![], + }) + } + + pub fn schema(&self) -> &ArrowSchema { + &self.md.schema + } + + pub fn consume(&mut self, msg: EncodedData) -> PolarsResult> { + // Parse the header + let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + + // Either append to the dictionaries and return None or return Some(ArrowChunk) + match header { + MessageHeaderRef::Schema(_) => { + polars_bail!(ComputeError: "Unexpected schema message while parsing Stream"); + }, + // Add to dictionary state and continue iteration + MessageHeaderRef::DictionaryBatch(batch) => unsafe { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + mmap_dictionary_from_batch( + &self.md.schema, + &self.md.ipc_schema.fields, + &arrow_data, + batch, + &mut self.dictionaries, + 0, + ) + .map(|_| None) + }, + // Return Batch + MessageHeaderRef::RecordBatch(batch) => { + if batch.compression()?.is_some() { + let data_size = msg.arrow_data.len() as u64; + let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice()); + read_record_batch( + batch, + &self.md.schema, + &self.md.ipc_schema, + None, + None, + &self.dictionaries, + self.md.version, + &mut reader, + 0, + data_size, + &mut self.scratch, + ) + .map(Some) + } else { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + unsafe { + mmap_record( + &self.md.schema, + &self.md.ipc_schema.fields, + arrow_data.clone(), + batch, + 0, + &self.dictionaries, + ) + .map(Some) + } + } + }, + _ => unimplemented!(), + } + } +} + +pub struct FlightstreamConsumer> + Unpin> { + inner: FlightConsumer, + stream: S, +} + +impl> + Unpin> FlightstreamConsumer { + pub async fn new(mut stream: S) -> PolarsResult { + let Some(first) = stream.next().await else { + polars_bail!(ComputeError: "expected the schema") + }; + let first = first?; + + Ok(FlightstreamConsumer { + inner: FlightConsumer::new(first)?, + stream, + }) + } + + pub async fn next_batch(&mut self) -> PolarsResult> { + while let Some(msg) = self.stream.next().await { + let msg = msg?; + let option_recordbatch = self.inner.consume(msg)?; + if option_recordbatch.is_some() { + return Ok(option_recordbatch); + } + } + Ok(None) + } +} + +#[cfg(test)] +mod test { + use std::path::{Path, PathBuf}; + + use tokio::fs::File; + + use super::*; + use crate::record_batch::RecordBatch; + + fn get_file_path() -> PathBuf { + let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc") + } + + fn read_file(path: &Path) -> RecordBatch { + let mut file = std::fs::File::open(path).unwrap(); + let md = crate::io::ipc::read::read_file_metadata(&mut file).unwrap(); + let mut ipc_reader = crate::io::ipc::read::FileReader::new(&mut file, md, None, None); + ipc_reader.next().unwrap().unwrap() + } + + #[tokio::test] + async fn test_file_flight_simple() { + let path = &get_file_path(); + let mut file = tokio::fs::File::open(path).await.unwrap(); + let stream = into_flight_stream(&mut file).await.unwrap(); + + let mut c = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } + + #[tokio::test] + async fn test_file_flight_amortized() { + let path = &get_file_path(); + let mut file = File::open(path).await.unwrap(); + let mut p = FlightStreamProducer::new(&mut file).await.unwrap(); + p.init().unwrap(); + + let mut batches = vec![]; + + let schema = p.get_schema().unwrap(); + batches.push(schema); + + let mut ed = EncodedData::default(); + if p.next_dict(&mut ed).await.unwrap().is_some() { + batches.push(ed); + } + + let mut ed = EncodedData::default(); + p.next_data(&mut ed).await.unwrap(); + batches.push(ed); + + let mut c = + FlightstreamConsumer::new(Box::pin(futures::stream::iter(batches.into_iter().map(Ok)))) + .await + .unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs index 68e806ca8946..88411f9b905f 100644 --- a/crates/polars-arrow/src/io/ipc/read/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -11,14 +11,14 @@ mod common; mod deserialize; mod error; pub(crate) mod file; +#[cfg(feature = "io_flight")] +mod flight; mod read_basic; mod reader; mod schema; mod stream; pub(crate) use common::first_dict_field; -#[cfg(feature = "io_flight")] -pub(crate) use common::{read_dictionary, read_record_batch}; pub use error::OutOfSpecKind; pub use file::{ deserialize_footer, get_row_count, read_batch, read_file_dictionaries, read_file_metadata, @@ -36,3 +36,10 @@ pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; pub(crate) type Version = arrow_format::ipc::MetadataVersion; + +#[cfg(feature = "io_flight")] +pub use flight::*; + +pub trait SendableIterator: Send + Iterator {} + +impl SendableIterator for T {} diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs index 09005ea4222e..9721137adf5c 100644 --- a/crates/polars-arrow/src/io/ipc/read/read_basic.rs +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -17,10 +17,10 @@ fn read_swapped( is_little_endian: bool, ) -> PolarsResult<()> { // slow case where we must reverse bits - let mut slice = vec![0u8; length * std::mem::size_of::()]; + let mut slice = vec![0u8; length * size_of::()]; reader.read_exact(&mut slice)?; - let chunks = slice.chunks_exact(std::mem::size_of::()); + let chunks = slice.chunks_exact(size_of::()); if !is_little_endian { // machine is little endian, file is big endian buffer @@ -67,7 +67,7 @@ fn read_uncompressed_buffer( length: usize, is_little_endian: bool, ) -> PolarsResult> { - let required_number_of_bytes = length.saturating_mul(std::mem::size_of::()); + let required_number_of_bytes = length.saturating_mul(size_of::()); if required_number_of_bytes > buffer_length { polars_bail!( oos = OutOfSpecKind::InvalidBuffer { diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index a49c7fdcd790..6fa1b5f0d8c4 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -482,7 +482,7 @@ pub struct Record<'a> { fields: Option>, } -impl<'a> Record<'a> { +impl Record<'_> { /// Get the IPC fields for this record. pub fn fields(&self) -> Option<&[IpcField]> { self.fields.as_deref() diff --git a/crates/polars-arrow/src/io/ipc/write/mod.rs b/crates/polars-arrow/src/io/ipc/write/mod.rs index 99f6fcc3f355..2291448d3012 100644 --- a/crates/polars-arrow/src/io/ipc/write/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/mod.rs @@ -5,7 +5,7 @@ mod serialize; mod stream; pub(crate) mod writer; -pub use common::{Compression, Record, WriteOptions}; +pub use common::{Compression, EncodedData, Record, WriteOptions}; pub use schema::schema_to_bytes; pub use serialize::write; use serialize::write_dictionary; diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs index f13098477d4d..43af5add63a6 100644 --- a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs @@ -283,7 +283,7 @@ fn _write_buffer_from_iter>( is_little_endian: bool, ) { let len = buffer.size_hint().0; - arrow_data.reserve(len * std::mem::size_of::()); + arrow_data.reserve(len * size_of::()); if is_little_endian { buffer .map(|x| T::to_le_bytes(&x)) @@ -303,7 +303,7 @@ fn _write_compressed_buffer_from_iter>( compression: Compression, ) { let len = buffer.size_hint().0; - let mut swapped = Vec::with_capacity(len * std::mem::size_of::()); + let mut swapped = Vec::with_capacity(len * size_of::()); if is_little_endian { buffer .map(|x| T::to_le_bytes(&x)) diff --git a/crates/polars-arrow/src/io/mod.rs b/crates/polars-arrow/src/io/mod.rs index a2b178304df0..1ae39a8ef766 100644 --- a/crates/polars-arrow/src/io/mod.rs +++ b/crates/polars-arrow/src/io/mod.rs @@ -1,15 +1,7 @@ -#![forbid(unsafe_code)] -//! Contains modules to interface with other formats such as [`csv`], -//! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. - #[cfg(feature = "io_ipc")] #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] pub mod ipc; -#[cfg(feature = "io_flight")] -#[cfg_attr(docsrs, doc(cfg(feature = "io_flight")))] -pub mod flight; - #[cfg(feature = "io_avro")] #[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))] pub mod avro; diff --git a/crates/polars-arrow/src/legacy/kernels/atan2.rs b/crates/polars-arrow/src/legacy/kernels/atan2.rs deleted file mode 100644 index 40d3d527b24a..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/atan2.rs +++ /dev/null @@ -1,12 +0,0 @@ -use num_traits::Float; - -use crate::array::PrimitiveArray; -use crate::compute::arity::binary; -use crate::types::NativeType; - -pub fn atan2(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray -where - T: Float + NativeType, -{ - binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| a.atan2(b)) -} diff --git a/crates/polars-arrow/src/legacy/kernels/float.rs b/crates/polars-arrow/src/legacy/kernels/float.rs deleted file mode 100644 index 22413fd0b4c5..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/float.rs +++ /dev/null @@ -1,54 +0,0 @@ -use num_traits::Float; - -use crate::array::{ArrayRef, BooleanArray, PrimitiveArray}; -use crate::bitmap::Bitmap; -use crate::legacy::array::default_arrays::FromData; -use crate::types::NativeType; - -pub fn is_nan(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_nan())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_not_nan(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| !v.is_nan())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_finite(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_finite())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} - -pub fn is_infinite(arr: &PrimitiveArray) -> ArrayRef -where - T: NativeType + Float, -{ - let values = Bitmap::from_trusted_len_iter(arr.values().iter().map(|v| v.is_infinite())); - - Box::new(BooleanArray::from_data_default( - values, - arr.validity().cloned(), - )) -} diff --git a/crates/polars-arrow/src/legacy/kernels/mod.rs b/crates/polars-arrow/src/legacy/kernels/mod.rs index cdb7cb9c2057..89b31684beed 100644 --- a/crates/polars-arrow/src/legacy/kernels/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/mod.rs @@ -2,15 +2,12 @@ use std::iter::Enumerate; use crate::array::BooleanArray; use crate::bitmap::utils::BitChunks; -pub mod atan2; pub mod concatenate; pub mod ewm; #[cfg(feature = "compute_take")] pub mod fixed_size_list; -pub mod float; #[cfg(feature = "compute_take")] pub mod list; -pub mod pow; pub mod rolling; pub mod set; pub mod sort_partition; @@ -62,8 +59,7 @@ impl<'a> MaskedSlicesIterator<'a> { pub(crate) fn new(mask: &'a BooleanArray) -> Self { let chunks = mask.values().chunks::(); - let chunk_bits = 8 * std::mem::size_of::(); - let chunk_len = mask.len() / chunk_bits; + let chunk_len = mask.len() / 64; let remainder_len = chunks.remainder_len(); let remainder_mask = chunks.remainder(); @@ -141,7 +137,7 @@ impl<'a> MaskedSlicesIterator<'a> { } } -impl<'a> Iterator for MaskedSlicesIterator<'a> { +impl Iterator for MaskedSlicesIterator<'_> { type Item = (usize, usize); fn next(&mut self) -> Option { @@ -213,7 +209,7 @@ impl<'a> BinaryMaskedSliceIterator<'a> { } } -impl<'a> Iterator for BinaryMaskedSliceIterator<'a> { +impl Iterator for BinaryMaskedSliceIterator<'_> { type Item = (usize, usize, bool); fn next(&mut self) -> Option { diff --git a/crates/polars-arrow/src/legacy/kernels/pow.rs b/crates/polars-arrow/src/legacy/kernels/pow.rs deleted file mode 100644 index a790e6193129..000000000000 --- a/crates/polars-arrow/src/legacy/kernels/pow.rs +++ /dev/null @@ -1,13 +0,0 @@ -use num_traits::pow::Pow; - -use crate::array::PrimitiveArray; -use crate::compute::arity::binary; -use crate::types::NativeType; - -pub fn pow(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray -where - T: Pow + NativeType, - F: NativeType, -{ - binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| Pow::pow(a, b)) -} diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs index 8f8004c570e8..51f3c95d2a56 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs @@ -93,5 +93,5 @@ pub struct RollingVarParams { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingQuantileParams { pub prob: f64, - pub interpol: QuantileInterpolOptions, + pub method: QuantileMethod, } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 1b9695358dcb..7abe2455e61f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, Num, NumCast}; pub use quantile::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub use sum::*; pub use variance::*; @@ -69,17 +70,22 @@ where Ok(Box::new(arr)) } -#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum QuantileInterpolOptions { +#[strum(serialize_all = "snake_case")] +pub enum QuantileMethod { #[default] Nearest, Lower, Higher, Midpoint, Linear, + Equiprobable, } +#[deprecated(note = "use QuantileMethod instead")] +pub type QuantileInterpolOptions = QuantileMethod; + pub(super) fn rolling_apply_weights( values: &[T], window_size: usize, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index ab3919b9aaaa..bf0ad01e79c3 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -2,13 +2,13 @@ use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; -use super::QuantileInterpolOptions::*; +use super::QuantileMethod::*; use super::*; pub struct QuantileWindow<'a, T: NativeType> { sorted: SortedBuf<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -34,7 +34,7 @@ impl< Self { sorted: SortedBuf::new(slice, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -42,7 +42,7 @@ impl< let vals = self.sorted.update(start, end); let length = vals.len(); - let idx = match self.interpol { + let idx = match self.method { Linear => { // Maybe add a fast path for median case? They could branch depending on odd/even. let length_f = length as f64; @@ -92,6 +92,7 @@ impl< let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; std::cmp::min(idx, length - 1) }, + Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize, }; // SAFETY: @@ -134,7 +135,7 @@ where unreachable!("expected Quantile params"); }; let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>( - params.interpol, + params.method, min_periods, window_size, values, @@ -170,7 +171,7 @@ where Ok(rolling_apply_weighted_quantile( values, params.prob, - params.interpol, + params.method, window_size, min_periods, offset_fn, @@ -182,7 +183,7 @@ where } #[inline] -fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T +fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T where T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, { @@ -201,7 +202,7 @@ where (s_old, v_old, vk) = (s, vk, v); s += w; } - match (h == s_old, interp) { + match (h == s_old, method) { (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter (_, Lower) => v_old, (_, Higher) => vk, @@ -212,6 +213,14 @@ where vk } }, + (_, Equiprobable) => { + let threshold = (wsum * p).ceil() - 1.0; + if s > threshold { + vk + } else { + v_old + } + }, (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(), // This is seemingly the canonical way to do it. (_, Linear) => { @@ -224,7 +233,7 @@ where fn rolling_apply_weighted_quantile( values: &[T], p: f64, - interpolation: QuantileInterpolOptions, + method: QuantileMethod, window_size: usize, min_periods: usize, det_offsets_fn: Fo, @@ -252,7 +261,7 @@ where .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w)); } buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0)); - compute_wq(&buf, p, wsum, interpolation) + compute_wq(&buf, p, wsum, method) }) .collect_trusted::>(); @@ -273,7 +282,7 @@ mod test { let values = &[1.0, 2.0, 3.0, 4.0]; let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: Linear, + method: Linear, })); let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); @@ -305,18 +314,19 @@ mod test { fn test_rolling_quantile_limits() { let values = &[1.0f64, 2.0, 3.0, 4.0]; - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -328,7 +338,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs index 52039fe77572..7f0c65d42bd7 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs @@ -25,7 +25,7 @@ pub struct SortedMinMax<'a, T: NativeType> { null_count: usize, } -impl<'a, T: NativeType> SortedMinMax<'a, T> { +impl SortedMinMax<'_, T> { fn count_nulls(&self, start: usize, end: usize) -> usize { let (bytes, offset, _) = self.validity.as_slice(); count_zeros(bytes, offset + start, end - start) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 259316513fe5..3d5dd664bd34 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray; pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { sorted: SortedBufNulls<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -39,7 +39,7 @@ impl< Self { sorted: SortedBufNulls::new(slice, validity, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -53,21 +53,22 @@ impl< let values = &values[null_count..]; let length = values.len(); - let mut idx = match self.interpol { - QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => { + let mut idx = match self.method { + QuantileMethod::Nearest => ((length as f64) * self.prob) as usize, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { ((length as f64 - 1.0) * self.prob).floor() as usize }, - QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Equiprobable => { + ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize + }, }; idx = std::cmp::min(idx, length - 1); // we can unwrap because we sliced of the nulls - match self.interpol { - QuantileInterpolOptions::Midpoint => { + match self.method { + QuantileMethod::Midpoint => { let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; Some( (values.get_unchecked_release(idx).unwrap() @@ -75,7 +76,7 @@ impl< / T::from::(2.0f64).unwrap(), ) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let float_idx = (length as f64 - 1.0) * self.prob; let top_idx = f64::ceil(float_idx) as usize; @@ -136,7 +137,7 @@ where }; let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>( - params.interpol, + params.method, min_periods, window_size, arr.clone(), @@ -171,7 +172,7 @@ mod test { ); let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: QuantileInterpolOptions::Linear, + method: QuantileMethod::Linear, })); let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone()); @@ -210,18 +211,19 @@ mod test { Some(Bitmap::from(&[true, false, false, true, true])), ); - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -233,7 +235,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs index d1392ef7cb50..599bdb241b07 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs @@ -9,7 +9,7 @@ pub struct SumWindow<'a, T> { pub(super) null_count: usize, } -impl<'a, T: NativeType + IsFloat + Add + Sub> SumWindow<'a, T> { +impl + Sub> SumWindow<'_, T> { // compute sum from the entire window unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { let mut sum = None; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs index ee97d4cb15a3..8252c8931c4f 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs @@ -9,8 +9,8 @@ pub(super) struct SumSquaredWindow<'a, T> { null_count: usize, } -impl<'a, T: NativeType + IsFloat + Add + Sub + Mul> - SumSquaredWindow<'a, T> +impl + Sub + Mul> + SumSquaredWindow<'_, T> { // compute sum from the entire window unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs index 40a464e6f5bc..c616310ee568 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs @@ -11,7 +11,7 @@ use polars_utils::slice::{GetSaferUnchecked, SliceAble}; use polars_utils::sort::arg_sort_ascending; use polars_utils::total_ord::TotalOrd; -use crate::legacy::prelude::QuantileInterpolOptions; +use crate::legacy::prelude::QuantileMethod; use crate::pushable::Pushable; use crate::types::NativeType; @@ -32,7 +32,7 @@ struct Block<'a, A> { nulls_in_window: usize, } -impl<'a, A> Debug for Block<'a, A> +impl Debug for Block<'_, A> where A: Indexable, A::Item: Debug + Copy, @@ -443,7 +443,7 @@ where } } -impl<'a, A> LenGet for BlockUnion<'a, A> +impl LenGet for BlockUnion<'_, A> where A: Indexable + Bounded + NullCount + Clone, ::Item: TotalOrd + Copy + Debug, @@ -573,7 +573,7 @@ struct QuantileUpdate { inner: M, quantile: f64, min_periods: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl QuantileUpdate @@ -581,12 +581,12 @@ where M: LenGet, ::Item: Default + IsNull + Copy + FinishLinear + Debug, { - fn new(interpol: QuantileInterpolOptions, min_periods: usize, quantile: f64, inner: M) -> Self { + fn new(method: QuantileMethod, min_periods: usize, quantile: f64, inner: M) -> Self { Self { min_periods, quantile, inner, - interpol, + method, } } @@ -602,8 +602,8 @@ where let valid_length_f = valid_length as f64; - use QuantileInterpolOptions::*; - match self.interpol { + use QuantileMethod::*; + match self.method { Linear => { let float_idx_top = (valid_length_f - 1.0) * self.quantile; let idx = float_idx_top.floor() as usize; @@ -623,6 +623,10 @@ where let idx = std::cmp::min(idx, valid_length - 1); self.inner.get(idx + null_count) }, + Equiprobable => { + let idx = ((valid_length_f * self.quantile).ceil() - 1.0).max(0.0) as usize; + self.inner.get(idx + null_count) + }, Midpoint => { let idx = (valid_length_f * self.quantile) as usize; let idx = std::cmp::min(idx, valid_length - 1); @@ -651,7 +655,7 @@ where } pub(super) fn rolling_quantile::Item>>( - interpol: QuantileInterpolOptions, + method: QuantileMethod, min_periods: usize, k: usize, values: A, @@ -709,7 +713,7 @@ where // SAFETY: bounded by capacity unsafe { block_left.undelete(i) }; - let mut mu = QuantileUpdate::new(interpol, min_periods, quantile, &mut block_left); + let mut mu = QuantileUpdate::new(method, min_periods, quantile, &mut block_left); out.push(mu.quantile()); } for i in 1..n_blocks + 1 { @@ -747,7 +751,7 @@ where let mut union = BlockUnion::new(&mut *ptr_left, &mut *ptr_right); union.set_state(j); let q: ::Item = - QuantileUpdate::new(interpol, min_periods, quantile, union).quantile(); + QuantileUpdate::new(method, min_periods, quantile, union).quantile(); out.push(q); } } @@ -1062,22 +1066,22 @@ mod test { 2.0, 8.0, 5.0, 9.0, 1.0, 2.0, 4.0, 2.0, 4.0, 8.1, -1.0, 2.9, 1.2, 23.0, ] .as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 8.0, 5.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 2.9, 1.2, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 5, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 5, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 5.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 7, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 7, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 3.5, 4.0, 4.0, 4.0, 4.0, 2.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 4, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 4, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 6.5, 3.5, 3.0, 2.0, 3.0, 4.0, 3.0, 3.45, 2.05, 2.05, ]; @@ -1087,7 +1091,7 @@ mod test { #[test] fn test_median_2() { let values = [10, 10, 15, 13, 9, 5, 3, 13, 19, 15, 19].as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [10, 10, 10, 13, 13, 9, 5, 5, 13, 15, 19]; assert_eq!(out, expected); } diff --git a/crates/polars-arrow/src/legacy/kernels/time.rs b/crates/polars-arrow/src/legacy/kernels/time.rs index 08bc285c7ffe..73caabb7587b 100644 --- a/crates/polars-arrow/src/legacy/kernels/time.rs +++ b/crates/polars-arrow/src/legacy/kernels/time.rs @@ -9,6 +9,7 @@ use polars_error::PolarsResult; use polars_error::{polars_bail, PolarsError}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; pub enum Ambiguous { Earliest, @@ -32,8 +33,9 @@ impl FromStr for Ambiguous { } } -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum NonExistent { Null, Raise, diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs index 88b2dd48bbea..6afeb0c6c9be 100644 --- a/crates/polars-arrow/src/legacy/prelude.rs +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array}; pub use crate::legacy::array::default_arrays::*; pub use crate::legacy::array::*; pub use crate::legacy::index::*; -pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; +pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod; pub use crate::legacy::kernels::rolling::{ RollingFnParams, RollingQuantileParams, RollingVarParams, }; @@ -11,3 +11,6 @@ pub use crate::legacy::kernels::{Ambiguous, NonExistent}; pub type LargeStringArray = Utf8Array; pub type LargeBinaryArray = BinaryArray; pub type LargeListArray = ListArray; + +#[allow(deprecated)] +pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; diff --git a/crates/polars-arrow/src/mmap/array.rs b/crates/polars-arrow/src/mmap/array.rs index 8822824858c6..060ab745fcc0 100644 --- a/crates/polars-arrow/src/mmap/array.rs +++ b/crates/polars-arrow/src/mmap/array.rs @@ -35,7 +35,7 @@ fn check_bytes_len_and_is_aligned( bytes: &[u8], expected_len: usize, ) -> PolarsResult { - if bytes.len() < std::mem::size_of::() * expected_len { + if bytes.len() < size_of::() * expected_len { polars_bail!(ComputeError: "buffer's length is too small in mmap") }; @@ -281,7 +281,7 @@ fn mmap_primitive>( let bytes = get_bytes(data_ref, block_offset, buffers)?; let is_aligned = check_bytes_len_and_is_aligned::

(bytes, num_rows)?; - let out = if is_aligned || std::mem::size_of::() <= 8 { + let out = if is_aligned || size_of::() <= 8 { assert!( is_aligned, "primitive type with size <= 8 bytes should have been aligned" diff --git a/crates/polars-arrow/src/mmap/mod.rs b/crates/polars-arrow/src/mmap/mod.rs index 37baba59e9df..6ad0ca776d7d 100644 --- a/crates/polars-arrow/src/mmap/mod.rs +++ b/crates/polars-arrow/src/mmap/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; mod array; use arrow_format::ipc::planus::ReadAsRoot; -use arrow_format::ipc::{Block, MessageRef, RecordBatchRef}; +use arrow_format::ipc::{Block, DictionaryBatchRef, MessageRef, RecordBatchRef}; use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult}; use polars_utils::pl_str::PlSmallStr; @@ -71,7 +71,7 @@ fn get_buffers_nodes(batch: RecordBatchRef) -> PolarsResult<(VecDeque Ok((buffers, field_nodes)) } -unsafe fn mmap_record>( +pub(crate) unsafe fn mmap_record>( fields: &ArrowSchema, ipc_fields: &[IpcField], data: Arc, @@ -146,19 +146,29 @@ pub unsafe fn mmap_unchecked>( } unsafe fn mmap_dictionary>( - metadata: &FileMetadata, + schema: &ArrowSchema, + ipc_fields: &[IpcField], data: Arc, block: Block, dictionaries: &mut Dictionaries, ) -> PolarsResult<()> { let (message, offset) = read_message(data.as_ref().as_ref(), block)?; let batch = get_dictionary_batch(&message)?; + mmap_dictionary_from_batch(schema, ipc_fields, &data, batch, dictionaries, offset) +} +pub(crate) unsafe fn mmap_dictionary_from_batch>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + data: &Arc, + batch: DictionaryBatchRef, + dictionaries: &mut Dictionaries, + offset: usize, +) -> PolarsResult<()> { let id = batch .id() .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferId(err)))?; - let (first_field, first_ipc_field) = - first_dict_field(id, &metadata.schema, &metadata.ipc_schema.fields)?; + let (first_field, first_ipc_field) = first_dict_field(id, schema, ipc_fields)?; let batch = batch .data() @@ -199,7 +209,21 @@ pub unsafe fn mmap_dictionaries_unchecked>( metadata: &FileMetadata, data: Arc, ) -> PolarsResult { - let blocks = if let Some(blocks) = &metadata.dictionaries { + mmap_dictionaries_unchecked2( + metadata.schema.as_ref(), + &metadata.ipc_schema.fields, + metadata.dictionaries.as_ref(), + data, + ) +} + +pub(crate) unsafe fn mmap_dictionaries_unchecked2>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + dictionaries: Option<&Vec>, + data: Arc, +) -> PolarsResult { + let blocks = if let Some(blocks) = &dictionaries { blocks } else { return Ok(Default::default()); @@ -207,9 +231,8 @@ pub unsafe fn mmap_dictionaries_unchecked>( let mut dictionaries = Default::default(); - blocks - .iter() - .cloned() - .try_for_each(|block| mmap_dictionary(metadata, data.clone(), block, &mut dictionaries))?; + blocks.iter().cloned().try_for_each(|block| { + mmap_dictionary(schema, ipc_fields, data.clone(), block, &mut dictionaries) + })?; Ok(dictionaries) } diff --git a/crates/polars-arrow/src/storage.rs b/crates/polars-arrow/src/storage.rs index 7cab3235b25a..ddde815b5b10 100644 --- a/crates/polars-arrow/src/storage.rs +++ b/crates/polars-arrow/src/storage.rs @@ -1,37 +1,93 @@ use std::marker::PhantomData; use std::mem::ManuallyDrop; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; use std::sync::atomic::{AtomicU64, Ordering}; +use bytemuck::Pod; + +// Allows us to transmute between types while also keeping the original +// stats and drop method of the Vec around. +struct VecVTable { + size: usize, + align: usize, + drop_buffer: unsafe fn(*mut (), usize), +} + +impl VecVTable { + const fn new() -> Self { + unsafe fn drop_buffer(ptr: *mut (), cap: usize) { + unsafe { drop(Vec::from_raw_parts(ptr.cast::(), 0, cap)) } + } + + Self { + size: size_of::(), + align: align_of::(), + drop_buffer: drop_buffer::, + } + } + + fn new_static() -> &'static Self { + const { &Self::new::() } + } +} + use crate::ffi::InternalArrowArray; enum BackingStorage { Vec { - capacity: usize, + original_capacity: usize, // Elements, not bytes. + vtable: &'static VecVTable, }, InternalArrowArray(InternalArrowArray), - #[cfg(feature = "arrow_rs")] - ArrowBuffer(arrow_buffer::Buffer), } struct SharedStorageInner { ref_count: AtomicU64, ptr: *mut T, - length: usize, + length_in_bytes: usize, backing: Option, // https://github.com/rust-lang/rfcs/blob/master/text/0769-sound-generic-drop.md#phantom-data phantom: PhantomData, } +impl SharedStorageInner { + pub fn from_vec(mut v: Vec) -> Self { + let length_in_bytes = v.len() * size_of::(); + let original_capacity = v.capacity(); + let ptr = v.as_mut_ptr(); + core::mem::forget(v); + Self { + ref_count: AtomicU64::new(1), + ptr, + length_in_bytes, + backing: Some(BackingStorage::Vec { + original_capacity, + vtable: VecVTable::new_static::(), + }), + phantom: PhantomData, + } + } +} + impl Drop for SharedStorageInner { fn drop(&mut self) { match self.backing.take() { Some(BackingStorage::InternalArrowArray(a)) => drop(a), - #[cfg(feature = "arrow_rs")] - Some(BackingStorage::ArrowBuffer(b)) => drop(b), - Some(BackingStorage::Vec { capacity }) => unsafe { - drop(Vec::from_raw_parts(self.ptr, self.length, capacity)) + Some(BackingStorage::Vec { + original_capacity, + vtable, + }) => unsafe { + // Drop the elements in our slice. + if std::mem::needs_drop::() { + core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut( + self.ptr, + self.length_in_bytes / size_of::(), + )); + } + + // Free the buffer. + (vtable.drop_buffer)(self.ptr.cast(), original_capacity); }, None => {}, } @@ -48,12 +104,13 @@ unsafe impl Sync for SharedStorage {} impl SharedStorage { pub fn from_static(slice: &'static [T]) -> Self { - let length = slice.len(); + #[expect(clippy::manual_slice_size_calculation)] + let length_in_bytes = slice.len() * size_of::(); let ptr = slice.as_ptr().cast_mut(); let inner = SharedStorageInner { ref_count: AtomicU64::new(2), // Never used, but 2 so it won't pass exclusivity tests. ptr, - length, + length_in_bytes, backing: None, phantom: PhantomData, }; @@ -63,20 +120,9 @@ impl SharedStorage { } } - pub fn from_vec(mut v: Vec) -> Self { - let length = v.len(); - let capacity = v.capacity(); - let ptr = v.as_mut_ptr(); - core::mem::forget(v); - let inner = SharedStorageInner { - ref_count: AtomicU64::new(1), - ptr, - length, - backing: Some(BackingStorage::Vec { capacity }), - phantom: PhantomData, - }; + pub fn from_vec(v: Vec) -> Self { Self { - inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), + inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(), phantom: PhantomData, } } @@ -85,7 +131,7 @@ impl SharedStorage { let inner = SharedStorageInner { ref_count: AtomicU64::new(1), ptr: ptr.cast_mut(), - length: len, + length_in_bytes: len * size_of::(), backing: Some(BackingStorage::InternalArrowArray(arr)), phantom: PhantomData, }; @@ -96,39 +142,40 @@ impl SharedStorage { } } -#[cfg(feature = "arrow_rs")] -impl SharedStorage { - pub fn from_arrow_buffer(buffer: arrow_buffer::Buffer) -> Self { - let ptr = buffer.as_ptr(); - let align_offset = ptr.align_offset(std::mem::align_of::()); - assert_eq!(align_offset, 0, "arrow_buffer::Buffer misaligned"); - let length = buffer.len() / std::mem::size_of::(); +pub struct SharedStorageAsVecMut<'a, T> { + ss: &'a mut SharedStorage, + vec: ManuallyDrop>, +} - let inner = SharedStorageInner { - ref_count: AtomicU64::new(1), - ptr: ptr as *mut T, - length, - backing: Some(BackingStorage::ArrowBuffer(buffer)), - phantom: PhantomData, - }; - Self { - inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), - phantom: PhantomData, - } +impl Deref for SharedStorageAsVecMut<'_, T> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.vec + } +} + +impl DerefMut for SharedStorageAsVecMut<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.vec } +} - pub fn into_arrow_buffer(self) -> arrow_buffer::Buffer { - let ptr = NonNull::new(self.as_ptr() as *mut u8).unwrap(); - let len = self.len() * std::mem::size_of::(); - let arc = std::sync::Arc::new(self); - unsafe { arrow_buffer::Buffer::from_custom_allocation(ptr, len, arc) } +impl Drop for SharedStorageAsVecMut<'_, T> { + fn drop(&mut self) { + unsafe { + // Restore the SharedStorage. + let vec = ManuallyDrop::take(&mut self.vec); + let inner = self.ss.inner.as_ptr(); + inner.write(SharedStorageInner::from_vec(vec)); + } } } impl SharedStorage { #[inline(always)] pub fn len(&self) -> usize { - self.inner().length + self.inner().length_in_bytes / size_of::() } #[inline(always)] @@ -156,21 +203,55 @@ impl SharedStorage { pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> { self.is_exclusive().then(|| { let inner = self.inner(); - unsafe { core::slice::from_raw_parts_mut(inner.ptr, inner.length) } + let len = inner.length_in_bytes / size_of::(); + unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) } }) } - pub fn try_into_vec(mut self) -> Result, Self> { - let Some(BackingStorage::Vec { capacity }) = self.inner().backing else { - return Err(self); + /// Try to take the vec backing this SharedStorage, leaving this as an empty slice. + pub fn try_take_vec(&mut self) -> Option> { + // We may only go back to a Vec if we originally came from a Vec + // where the desired size/align matches the original. + let Some(BackingStorage::Vec { + original_capacity, + vtable, + }) = self.inner().backing + else { + return None; }; - if self.is_exclusive() { - let slf = ManuallyDrop::new(self); - let inner = slf.inner(); - Ok(unsafe { Vec::from_raw_parts(inner.ptr, inner.length, capacity) }) - } else { - Err(self) + + if vtable.size != size_of::() || vtable.align != align_of::() { + return None; + } + + // If there are other references we can't get an exclusive reference. + if !self.is_exclusive() { + return None; } + + let ret; + unsafe { + let inner = &mut *self.inner.as_ptr(); + let len = inner.length_in_bytes / size_of::(); + ret = Vec::from_raw_parts(inner.ptr, len, original_capacity); + inner.length_in_bytes = 0; + inner.backing = None; + } + Some(ret) + } + + /// Attempts to call the given function with this SharedStorage as a + /// reference to a mutable Vec. If this SharedStorage can't be converted to + /// a Vec the function is not called and instead returned as an error. + pub fn try_as_mut_vec(&mut self) -> Option> { + Some(SharedStorageAsVecMut { + vec: ManuallyDrop::new(self.try_take_vec()?), + ss: self, + }) + } + + pub fn try_into_vec(mut self) -> Result, Self> { + self.try_take_vec().ok_or(self) } #[inline(always)] @@ -186,6 +267,39 @@ impl SharedStorage { } } +impl SharedStorage { + fn try_transmute(self) -> Result, Self> { + let inner = self.inner(); + + // The length of the array in bytes must be a multiple of the target size. + // We can skip this check if the size of U divides the size of T. + if size_of::() % size_of::() != 0 && inner.length_in_bytes % size_of::() != 0 { + return Err(self); + } + + // The pointer must be properly aligned for U. + // We can skip this check if the alignment of U divides the alignment of T. + if align_of::() % align_of::() != 0 && !inner.ptr.cast::().is_aligned() { + return Err(self); + } + + Ok(SharedStorage { + inner: self.inner.cast(), + phantom: PhantomData, + }) + } +} + +impl SharedStorage { + /// Create a [`SharedStorage`][SharedStorage] from a [`Vec`] of [`Pod`]. + pub fn bytes_from_pod_vec(v: Vec) -> Self { + // This can't fail, bytes is compatible with everything. + SharedStorage::from_vec(v) + .try_transmute::() + .unwrap_or_else(|_| unreachable!()) + } +} + impl Deref for SharedStorage { type Target = [T]; @@ -193,7 +307,8 @@ impl Deref for SharedStorage { fn deref(&self) -> &Self::Target { unsafe { let inner = self.inner(); - core::slice::from_raw_parts(inner.ptr, inner.length) + let len = inner.length_in_bytes / size_of::(); + core::slice::from_raw_parts(inner.ptr, len) } } } diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs index b5672f6dd626..1cf8dc10846a 100644 --- a/crates/polars-arrow/src/temporal_conversions.rs +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -3,10 +3,14 @@ use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; use polars_error::{polars_err, PolarsResult}; +#[cfg(feature = "compute_cast")] use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "compute_cast")] use crate::array::{PrimitiveArray, Utf8ViewArray}; -use crate::datatypes::{ArrowDataType, TimeUnit}; +#[cfg(feature = "compute_cast")] +use crate::datatypes::ArrowDataType; +use crate::datatypes::TimeUnit; /// Number of seconds in a day pub const SECONDS_IN_DAY: i64 = 86_400; @@ -316,6 +320,7 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> .ok() } +#[cfg(feature = "compute_cast")] fn utf8view_to_timestamp_impl( array: &Utf8ViewArray, fmt: &str, @@ -365,6 +370,7 @@ fn chrono_tz_utf_to_timestamp( } #[cfg(not(feature = "chrono-tz"))] +#[cfg(feature = "compute_cast")] fn chrono_tz_utf_to_timestamp( _: &Utf8ViewArray, _: &str, @@ -387,6 +393,7 @@ fn chrono_tz_utf_to_timestamp( /// # Error /// /// This function errors iff `timezone` is not parsable to an offset. +#[cfg(feature = "compute_cast")] pub(crate) fn utf8view_to_timestamp( array: &Utf8ViewArray, fmt: &str, @@ -408,6 +415,7 @@ pub(crate) fn utf8view_to_timestamp( /// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. /// Timezones are ignored. /// Null elements remain null; non-parsable elements are set to null. +#[cfg(feature = "compute_cast")] pub(crate) fn utf8view_to_naive_timestamp( array: &Utf8ViewArray, fmt: &str, diff --git a/crates/polars-arrow/src/types/aligned_bytes.rs b/crates/polars-arrow/src/types/aligned_bytes.rs new file mode 100644 index 000000000000..2c9bf9aed977 --- /dev/null +++ b/crates/polars-arrow/src/types/aligned_bytes.rs @@ -0,0 +1,112 @@ +use bytemuck::{Pod, Zeroable}; + +use super::{days_ms, f16, i256, months_days_ns}; +use crate::array::View; + +/// Define that a type has the same byte alignment and size as `B`. +/// +/// # Safety +/// +/// This is safe to implement if both types have the same alignment and size. +pub unsafe trait AlignedBytesCast: Pod {} + +/// A representation of a type as raw bytes with the same alignment as the original type. +pub trait AlignedBytes: Pod + Zeroable + Copy + Default + Eq { + const ALIGNMENT: usize; + const SIZE: usize; + + type Unaligned: AsRef<[u8]> + + AsMut<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default + + IntoIterator + + Pod; + + fn to_unaligned(&self) -> Self::Unaligned; + fn from_unaligned(unaligned: Self::Unaligned) -> Self; + + /// Safely cast a mutable reference to a [`Vec`] of `T` to a mutable reference of `Self`. + fn cast_vec_ref_mut>(vec: &mut Vec) -> &mut Vec { + if cfg!(debug_assertions) { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + } + + // SAFETY: SameBytes guarantees that T: + // 1. has the same size + // 2. has the same alignment + // 3. is Pod (therefore has no life-time issues) + unsafe { std::mem::transmute(vec) } + } +} + +macro_rules! impl_aligned_bytes { + ( + $(($name:ident, $size:literal, $alignment:literal, [$($eq_type:ty),*]),)+ + ) => { + $( + /// Bytes with a size and alignment. + /// + /// This is used to reduce the monomorphizations for routines that solely rely on the size + /// and alignment of types. + #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Pod, Zeroable)] + #[repr(C, align($alignment))] + pub struct $name([u8; $size]); + + impl AlignedBytes for $name { + const ALIGNMENT: usize = $alignment; + const SIZE: usize = $size; + + type Unaligned = [u8; $size]; + + #[inline(always)] + fn to_unaligned(&self) -> Self::Unaligned { + self.0 + } + #[inline(always)] + fn from_unaligned(unaligned: Self::Unaligned) -> Self { + Self(unaligned) + } + } + + impl AsRef<[u8; $size]> for $name { + #[inline(always)] + fn as_ref(&self) -> &[u8; $size] { + &self.0 + } + } + + $( + impl From<$eq_type> for $name { + #[inline(always)] + fn from(value: $eq_type) -> Self { + bytemuck::must_cast(value) + } + } + impl From<$name> for $eq_type { + #[inline(always)] + fn from(value: $name) -> Self { + bytemuck::must_cast(value) + } + } + unsafe impl AlignedBytesCast<$name> for $eq_type {} + )* + )+ + } +} + +impl_aligned_bytes! { + (Bytes1Alignment1, 1, 1, [u8, i8]), + (Bytes2Alignment2, 2, 2, [u16, i16, f16]), + (Bytes4Alignment4, 4, 4, [u32, i32, f32]), + (Bytes8Alignment8, 8, 8, [u64, i64, f64]), + (Bytes8Alignment4, 8, 4, [days_ms]), + (Bytes12Alignment4, 12, 4, [[u32; 3]]), + (Bytes16Alignment4, 16, 4, [View]), + (Bytes16Alignment8, 16, 8, [months_days_ns]), + (Bytes16Alignment16, 16, 16, [u128, i128]), + (Bytes32Alignment16, 32, 16, [i256]), +} diff --git a/crates/polars-arrow/src/types/bit_chunk.rs b/crates/polars-arrow/src/types/bit_chunk.rs index be4445a5d77a..60892689060f 100644 --- a/crates/polars-arrow/src/types/bit_chunk.rs +++ b/crates/polars-arrow/src/types/bit_chunk.rs @@ -72,7 +72,7 @@ impl BitChunkIter { /// Creates a new [`BitChunkIter`] with `len` bits. #[inline] pub fn new(value: T, len: usize) -> Self { - assert!(len <= std::mem::size_of::() * 8); + assert!(len <= size_of::() * 8); Self { value, remaining: len, diff --git a/crates/polars-arrow/src/types/mod.rs b/crates/polars-arrow/src/types/mod.rs index 49b4d315408e..c6f653a32311 100644 --- a/crates/polars-arrow/src/types/mod.rs +++ b/crates/polars-arrow/src/types/mod.rs @@ -20,6 +20,8 @@ //! Finally, this module contains traits used to compile code based on [`NativeType`] optimized //! for SIMD, at [`mod@simd`]. +mod aligned_bytes; +pub use aligned_bytes::*; mod bit_chunk; pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes}; mod index; diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index 6f869df32602..9cf9a6b56c46 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -1,11 +1,13 @@ +use std::hash::{Hash, Hasher}; use std::ops::Neg; use std::panic::RefUnwindSafe; use bytemuck::{Pod, Zeroable}; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -use polars_utils::total_ord::{TotalEq, TotalOrd}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap}; +use super::aligned_bytes::*; use super::PrimitiveType; /// Sealed trait implemented by all physical types that can be allocated, @@ -26,6 +28,7 @@ pub trait NativeType: + TotalOrd + IsNull + MinMax + + AlignedBytesCast { /// The corresponding variant of [`PrimitiveType`]. const PRIMITIVE: PrimitiveType; @@ -41,6 +44,11 @@ pub trait NativeType: + Default + IntoIterator; + /// Type denoting its representation as aligned bytes. + /// + /// This is `[u8; N]` where `N = size_of::` and has alignment `align_of::`. + type AlignedBytes: AlignedBytes + From + Into; + /// To bytes in little endian fn to_le_bytes(&self) -> Self::Bytes; @@ -55,11 +63,13 @@ pub trait NativeType: } macro_rules! native_type { - ($type:ty, $primitive_type:expr) => { + ($type:ty, $aligned:ty, $primitive_type:expr) => { impl NativeType for $type { const PRIMITIVE: PrimitiveType = $primitive_type; type Bytes = [u8; std::mem::size_of::()]; + type AlignedBytes = $aligned; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { Self::to_le_bytes(*self) @@ -83,18 +93,18 @@ macro_rules! native_type { }; } -native_type!(u8, PrimitiveType::UInt8); -native_type!(u16, PrimitiveType::UInt16); -native_type!(u32, PrimitiveType::UInt32); -native_type!(u64, PrimitiveType::UInt64); -native_type!(i8, PrimitiveType::Int8); -native_type!(i16, PrimitiveType::Int16); -native_type!(i32, PrimitiveType::Int32); -native_type!(i64, PrimitiveType::Int64); -native_type!(f32, PrimitiveType::Float32); -native_type!(f64, PrimitiveType::Float64); -native_type!(i128, PrimitiveType::Int128); -native_type!(u128, PrimitiveType::UInt128); +native_type!(u8, Bytes1Alignment1, PrimitiveType::UInt8); +native_type!(u16, Bytes2Alignment2, PrimitiveType::UInt16); +native_type!(u32, Bytes4Alignment4, PrimitiveType::UInt32); +native_type!(u64, Bytes8Alignment8, PrimitiveType::UInt64); +native_type!(i8, Bytes1Alignment1, PrimitiveType::Int8); +native_type!(i16, Bytes2Alignment2, PrimitiveType::Int16); +native_type!(i32, Bytes4Alignment4, PrimitiveType::Int32); +native_type!(i64, Bytes8Alignment8, PrimitiveType::Int64); +native_type!(f32, Bytes4Alignment4, PrimitiveType::Float32); +native_type!(f64, Bytes8Alignment8, PrimitiveType::Float64); +native_type!(i128, Bytes16Alignment16, PrimitiveType::Int128); +native_type!(u128, Bytes16Alignment16, PrimitiveType::UInt128); /// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type. #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)] @@ -150,7 +160,10 @@ impl MinMax for days_ms { impl NativeType for days_ms { const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs; + type Bytes = [u8; 8]; + type AlignedBytes = Bytes8Alignment4; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { let days = self.0.to_le_bytes(); @@ -288,7 +301,10 @@ impl MinMax for months_days_ns { impl NativeType for months_days_ns { const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano; + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment8; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { let months = self.months().to_le_bytes(); @@ -434,6 +450,44 @@ impl PartialEq for f16 { } } +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f16(x: f16) -> f16 { + // zero out the sign bit if the f16 is zero. + let convert_zero = f16(x.0 & (0x7FFF | (u16::from(x.0 & 0x7FFF == 0) << 15))); + if convert_zero.is_nan() { + f16::from_bits(0x7c00) // Canonical quiet NaN. + } else { + convert_zero + } +} + +impl TotalHash for f16 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f16(*self).to_bits().hash(state) + } +} + +impl ToTotalOrd for f16 { + type TotalOrdItem = TotalOrdWrap; + type SourceItem = f16; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + impl IsNull for f16 { const HAS_NULLS: bool = false; type Inner = f16; @@ -619,7 +673,10 @@ impl MinMax for f16 { impl NativeType for f16 { const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + type Bytes = [u8; 2]; + type AlignedBytes = Bytes2Alignment2; + #[inline] fn to_le_bytes(&self) -> Self::Bytes { self.0.to_le_bytes() @@ -719,6 +776,7 @@ impl NativeType for i256 { const PRIMITIVE: PrimitiveType = PrimitiveType::Int256; type Bytes = [u8; 32]; + type AlignedBytes = Bytes32Alignment16; #[inline] fn to_le_bytes(&self) -> Self::Bytes { diff --git a/crates/polars-arrow/src/util/macros.rs b/crates/polars-arrow/src/util/macros.rs index b09a9d5d5473..fb5bd61ebba0 100644 --- a/crates/polars-arrow/src/util/macros.rs +++ b/crates/polars-arrow/src/util/macros.rs @@ -13,6 +13,7 @@ macro_rules! with_match_primitive_type {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Int128 => __with_ty__! { i128 }, Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, _ => panic!("operator does not support primitive `{:?}`", diff --git a/crates/polars-compute/src/arity.rs b/crates/polars-compute/src/arity.rs index 33c8b0eb0584..7f99752e7afe 100644 --- a/crates/polars-compute/src/arity.rs +++ b/crates/polars-compute/src/arity.rs @@ -52,9 +52,7 @@ where let len = arr.len(); // Reuse memory if possible. - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(values) = arr.get_mut_values() { let ptr = values.as_mut_ptr(); // SAFETY: checked same size & alignment I/O, NativeType is always Pod. @@ -93,9 +91,7 @@ where let validity = combine_validities_and(lhs.validity(), rhs.validity()); // Reuse memory if possible. - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(lv) = lhs.get_mut_values() { let lp = lv.as_mut_ptr(); let rp = rhs.values().as_ptr(); @@ -106,9 +102,7 @@ where return lhs.transmute::().with_validity(validity); } } - if std::mem::size_of::() == std::mem::size_of::() - && std::mem::align_of::() == std::mem::align_of::() - { + if size_of::() == size_of::() && align_of::() == align_of::() { if let Some(rv) = rhs.get_mut_values() { let lp = lhs.values().as_ptr(); let rp = rv.as_mut_ptr(); diff --git a/crates/polars-compute/src/cardinality.rs b/crates/polars-compute/src/cardinality.rs new file mode 100644 index 000000000000..d28efa9d051e --- /dev/null +++ b/crates/polars-compute/src/cardinality.rs @@ -0,0 +1,159 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, PrimitiveArray, + Utf8Array, Utf8ViewArray, +}; +use arrow::datatypes::PhysicalType; +use arrow::types::Offset; +use arrow::with_match_primitive_type_full; +use polars_utils::total_ord::ToTotalOrd; + +use crate::hyperloglogplus::HyperLogLog; + +/// Get an estimate for the *cardinality* of the array (i.e. the number of unique values) +/// +/// This is not currently implemented for nested types. +pub fn estimate_cardinality(array: &dyn Array) -> usize { + if array.is_empty() { + return 0; + } + + if array.null_count() == array.len() { + return 1; + } + + // Estimate the cardinality with HyperLogLog + use PhysicalType as PT; + match array.dtype().to_physical_type() { + PT::Null => 1, + + PT::Boolean => { + let mut cardinality = 0; + + let array = array.as_any().downcast_ref::().unwrap(); + + cardinality += usize::from(array.has_nulls()); + + if let Some(unset_bits) = array.values().lazy_unset_bits() { + cardinality += 1 + usize::from(unset_bits != array.len()); + } else { + cardinality += 2; + } + + cardinality + }, + + PT::Primitive(primitive_type) => with_match_primitive_type_full!(primitive_type, |$T| { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.copied().unwrap_or_default(); + hll.add(&v.to_total_ord()); + } + } else { + for v in array.values_iter() { + hll.add(&v.to_total_ord()); + } + } + + hll.count() + }), + PT::FixedSizeBinary => { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() + }, + PT::Binary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::LargeBinary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::Utf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::LargeUtf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::BinaryView => { + binary_view_array_estimate(array.as_any().downcast_ref::().unwrap()) + }, + PT::Utf8View => binary_view_array_estimate( + &array + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(), + ), + PT::List => unimplemented!(), + PT::FixedSizeList => unimplemented!(), + PT::LargeList => unimplemented!(), + PT::Struct => unimplemented!(), + PT::Union => unimplemented!(), + PT::Map => unimplemented!(), + PT::Dictionary(_) => unimplemented!(), + } +} + +fn binary_offset_array_estimate(array: &BinaryArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} + +fn binary_view_array_estimate(array: &BinaryViewArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} diff --git a/crates/polars-compute/src/comparisons/simd.rs b/crates/polars-compute/src/comparisons/simd.rs index f855ed4ad1c0..95ea9707744a 100644 --- a/crates/polars-compute/src/comparisons/simd.rs +++ b/crates/polars-compute/src/comparisons/simd.rs @@ -17,7 +17,7 @@ where T: NativeType, F: FnMut(&[T; N], &[T; N]) -> M, { - assert_eq!(N, std::mem::size_of::() * 8); + assert_eq!(N, size_of::() * 8); assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -29,7 +29,7 @@ where let rhs_rest = rhs_chunks.remainder(); let num_masks = n.div_ceil(N); - let mut v: Vec = Vec::with_capacity(num_masks * std::mem::size_of::()); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); let mut p = v.as_mut_ptr() as *mut M; for (l, r) in lhs_chunks.zip(rhs_chunks) { unsafe { @@ -53,7 +53,7 @@ where } unsafe { - v.set_len(num_masks * std::mem::size_of::()); + v.set_len(num_masks * size_of::()); } Bitmap::from_u8_vec(v, n) @@ -64,7 +64,7 @@ where T: NativeType, F: FnMut(&[T; N]) -> M, { - assert_eq!(N, std::mem::size_of::() * 8); + assert_eq!(N, size_of::() * 8); let n = arg.len(); let arg_buf = arg.values().as_slice(); @@ -72,7 +72,7 @@ where let arg_rest = arg_chunks.remainder(); let num_masks = n.div_ceil(N); - let mut v: Vec = Vec::with_capacity(num_masks * std::mem::size_of::()); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); let mut p = v.as_mut_ptr() as *mut M; for a in arg_chunks { unsafe { @@ -91,7 +91,7 @@ where } unsafe { - v.set_len(num_masks * std::mem::size_of::()); + v.set_len(num_masks * size_of::()); } Bitmap::from_u8_vec(v, n) diff --git a/crates/polars-compute/src/filter/avx512.rs b/crates/polars-compute/src/filter/avx512.rs index 11466b137be8..237aed8ed483 100644 --- a/crates/polars-compute/src/filter/avx512.rs +++ b/crates/polars-compute/src/filter/avx512.rs @@ -5,7 +5,7 @@ use core::arch::x86_64::*; // structured functions. macro_rules! simd_filter { ($values: ident, $mask_bytes: ident, $out: ident, |$subchunk: ident, $m: ident: $MaskT: ty| $body:block) => {{ - const MASK_BITS: usize = std::mem::size_of::<$MaskT>() * 8; + const MASK_BITS: usize = <$MaskT>::BITS as usize; // Do a 64-element loop for sparse fast path. let chunks = $values.chunks_exact(64); diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs index 9cc542b60978..4671fa4ff592 100644 --- a/crates/polars-compute/src/filter/primitive.rs +++ b/crates/polars-compute/src/filter/primitive.rs @@ -19,7 +19,7 @@ fn nop_filter<'a, T: Pod>( } pub fn filter_values(values: &[T], mask: &Bitmap) -> Vec { - match (std::mem::size_of::(), std::mem::align_of::()) { + match (size_of::(), align_of::()) { (1, 1) => cast_vec(filter_values_u8(cast_slice(values), mask)), (2, 2) => cast_vec(filter_values_u16(cast_slice(values), mask)), (4, 4) => cast_vec(filter_values_u32(cast_slice(values), mask)), diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index da56c65983db..a0957daeafcc 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -1,4 +1,6 @@ #![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "simd", feature(core_intrinsics))] // For fadd_algebraic. +#![cfg_attr(feature = "simd", allow(internal_features))] #![cfg_attr(feature = "simd", feature(avx512_target_feature))] #![cfg_attr( all(feature = "simd", target_arch = "x86_64"), @@ -10,6 +12,8 @@ use arrow::types::NativeType; pub mod arithmetic; pub mod arity; pub mod bitwise; +#[cfg(feature = "approx_unique")] +pub mod cardinality; pub mod comparisons; pub mod filter; pub mod float_sum; @@ -19,6 +23,7 @@ pub mod if_then_else; pub mod min_max; pub mod size; pub mod unique; +pub mod var_cov; // Trait to enable the scalar blanket implementation. pub trait NotSimdPrimitive: NativeType {} diff --git a/crates/polars-compute/src/var_cov.rs b/crates/polars-compute/src/var_cov.rs new file mode 100644 index 000000000000..d6c0267faec6 --- /dev/null +++ b/crates/polars-compute/src/var_cov.rs @@ -0,0 +1,327 @@ +// Some formulae: +// mean_x = sum(weight[i] * x[i]) / sum(weight) +// dp_xy = weighted sum of deviation products of variables x, y, written in +// the paper as simply XY. +// dp_xy = sum(weight[i] * (x[i] - mean_x) * (y[i] - mean_y)) +// +// cov(x, y) = dp_xy / sum(weight) +// var(x) = cov(x, x) +// +// Algorithms from: +// Numerically stable parallel computation of (co-)variance. +// Schubert, E., & Gertz, M. (2018). +// +// Key equations from the paper: +// (17) for mean update, (23) for dp update (and also Table 1). + +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; +use num_traits::AsPrimitive; + +const CHUNK_SIZE: usize = 128; + +#[inline(always)] +fn alg_add(a: f64, b: f64) -> f64 { + #[cfg(feature = "simd")] + { + std::intrinsics::fadd_algebraic(a, b) + } + #[cfg(not(feature = "simd"))] + { + a + b + } +} + +fn alg_sum(it: impl IntoIterator) -> f64 { + it.into_iter().fold(0.0, alg_add) +} + +#[derive(Default, Clone)] +pub struct VarState { + weight: f64, + mean: f64, + dp: f64, +} + +#[derive(Default, Clone)] +pub struct CovState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xy: f64, +} + +#[derive(Default, Clone)] +pub struct PearsonState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xx: f64, + dp_xy: f64, + dp_yy: f64, +} + +impl VarState { + fn new(x: &[f64]) -> Self { + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let mean = alg_sum(x.iter().copied()) / weight; + Self { + weight, + mean, + dp: alg_sum(x.iter().map(|&xi| (xi - mean) * (xi - mean))), + } + } + + pub fn add_one(&mut self, x: f64) { + // Just a specialized version of + // self.combine(&Self { weight: 1.0, mean: x, dp: 0.0 }) + let new_weight = self.weight + 1.0; + let delta_mean = self.mean - x; + let new_mean = self.mean - delta_mean / new_weight; + self.dp += (new_mean - x) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean = self.mean - other.mean; + let new_mean = self.mean - delta_mean * other_weight_frac; + self.dp += other.dp + other.weight * (new_mean - other.mean) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + Some(self.dp / (self.weight - ddof as f64)) + } + } +} + +impl CovState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum(x.iter().copied()) * inv_weight; + let mean_y = alg_sum(y.iter().copied()) * inv_weight; + Self { + weight, + mean_x, + mean_y, + dp_xy: alg_sum( + x.iter() + .zip(y) + .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y)), + ), + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = self.mean_x - other.mean_x; + let delta_mean_y = self.mean_y - other.mean_y; + let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac; + self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + Some(self.dp_xy / (self.weight - ddof as f64)) + } + } +} + +impl PearsonState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum(x.iter().copied()) * inv_weight; + let mean_y = alg_sum(y.iter().copied()) * inv_weight; + let mut dp_xx = 0.0; + let mut dp_xy = 0.0; + let mut dp_yy = 0.0; + for (xi, yi) in x.iter().zip(y.iter()) { + dp_xx = alg_add(dp_xx, (xi - mean_x) * (xi - mean_x)); + dp_xy = alg_add(dp_xy, (xi - mean_x) * (yi - mean_y)); + dp_yy = alg_add(dp_yy, (yi - mean_y) * (yi - mean_y)); + } + Self { + weight, + mean_x, + mean_y, + dp_xx, + dp_xy, + dp_yy, + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = self.mean_x - other.mean_x; + let delta_mean_y = self.mean_y - other.mean_y; + let new_mean_x = self.mean_x - delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y - delta_mean_y * other_weight_frac; + self.dp_xx += other.dp_xx + other.weight * (new_mean_x - other.mean_x) * delta_mean_x; + self.dp_xy += other.dp_xy + other.weight * (new_mean_x - other.mean_x) * delta_mean_y; + self.dp_yy += other.dp_yy + other.weight * (new_mean_y - other.mean_y) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self, _ddof: u8) -> f64 { + // The division by sample_weight - ddof on both sides cancels out. + let denom = (self.dp_xx * self.dp_yy).sqrt(); + if denom == 0.0 { + f64::NAN + } else { + self.dp_xy / denom + } + } +} + +fn chunk_as_float(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64]), +{ + let mut chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for val in it { + if i >= CHUNK_SIZE { + f(&chunk); + i = 0; + } + chunk[i] = val.as_(); + i += 1; + } + if i > 0 { + f(&chunk[..i]); + } +} + +fn chunk_as_float_binary(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64], &[f64]), +{ + let mut left_chunk = [0.0; CHUNK_SIZE]; + let mut right_chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for (l, r) in it { + if i >= CHUNK_SIZE { + f(&left_chunk, &right_chunk); + i = 0; + } + left_chunk[i] = l.as_(); + right_chunk[i] = r.as_(); + i += 1; + } + if i > 0 { + f(&left_chunk[..i], &right_chunk[..i]); + } +} + +pub fn var(arr: &PrimitiveArray) -> VarState +where + T: NativeType + AsPrimitive, +{ + let mut out = VarState::default(); + if arr.has_nulls() { + chunk_as_float(arr.non_null_values_iter(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } else { + chunk_as_float(arr.values().iter().copied(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } + out +} + +pub fn cov(x: &PrimitiveArray, y: &PrimitiveArray) -> CovState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = CovState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&CovState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&CovState::new(l, r)), + ); + } + out +} + +pub fn pearson_corr(x: &PrimitiveArray, y: &PrimitiveArray) -> PearsonState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = PearsonState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } + out +} diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index c9d68f0ee173..bb5cdc85cdac 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -17,7 +17,6 @@ polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true, optional = true } bitflags = { workspace = true } bytemuck = { workspace = true } chrono = { workspace = true, optional = true } @@ -37,6 +36,7 @@ regex = { workspace = true, optional = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } thiserror = { workspace = true } xxhash-rust = { workspace = true } @@ -100,7 +100,7 @@ partition_by = ["algorithm_group_by"] describe = [] timezones = ["temporal", "chrono", "chrono-tz", "arrow/chrono-tz", "arrow/timezones"] dynamic_group_by = ["dtype-datetime", "dtype-date"] -arrow_rs = ["arrow-array", "arrow/arrow_rs"] +list_arithmetic = [] # opt-in datatypes for Series dtype-date = ["temporal"] @@ -148,6 +148,7 @@ docs-selection = [ "describe", "partition_by", "algorithm_group_by", + "list_arithmetic", ] [package.metadata.docs.rs] diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 98ae4962b1ae..fc8f993400d1 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -206,10 +206,9 @@ where // - remain signed // - unsigned -> signed // this may still fail with overflow? - let dtype = self.dtype(); - let to_signed = dtype.is_signed_integer(); - let unsigned2unsigned = dtype.is_unsigned_integer() && dtype.is_unsigned_integer(); + let unsigned2unsigned = + self.dtype().is_unsigned_integer() && dtype.is_unsigned_integer(); let allowed = to_signed || unsigned2unsigned; if (allowed) diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index ed8e18a27e43..e9b9efd2cae0 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -772,7 +772,7 @@ fn struct_helper( b: &StructChunked, op: F, reduce: R, - value: bool, + op_is_ne: bool, is_missing: bool, ) -> BooleanChunked where @@ -783,7 +783,7 @@ where let len_b = b.len(); let broadcasts = len_a == 1 || len_b == 1; if (a.len() != b.len() && !broadcasts) || a.struct_fields().len() != b.struct_fields().len() { - BooleanChunked::full(PlSmallStr::EMPTY, value, a.len()) + BooleanChunked::full(PlSmallStr::EMPTY, op_is_ne, a.len()) } else { let (a, b) = align_chunks_binary(a, b); @@ -792,8 +792,34 @@ where .iter() .zip(b.fields_as_series().iter()) .map(|(l, r)| op(l, r)) - .reduce(reduce) - .unwrap_or_else(|| BooleanChunked::full(PlSmallStr::EMPTY, !value, a.len())); + .reduce(&reduce) + .unwrap_or_else(|| BooleanChunked::full(PlSmallStr::EMPTY, !op_is_ne, a.len())); + + if is_missing && (a.has_nulls() || b.has_nulls()) { + // Do some allocations so that we can use the Series dispatch, it otherwise + // gets complicated dealing with combinations of ==, != and broadcasting. + let default = || { + BooleanChunked::with_chunk(PlSmallStr::EMPTY, BooleanArray::from_slice([true])) + .into_series() + }; + let validity_to_series = |x| unsafe { + BooleanChunked::with_chunk( + PlSmallStr::EMPTY, + BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None), + ) + .into_series() + }; + + out = reduce( + out, + op( + &a.rechunk_validity() + .map_or_else(default, validity_to_series), + &b.rechunk_validity() + .map_or_else(default, validity_to_series), + ), + ) + } if !is_missing && (a.null_count() > 0 || b.null_count() > 0) { let mut a = a.into_owned(); diff --git a/crates/polars-core/src/chunked_array/float.rs b/crates/polars-core/src/chunked_array/float.rs index 8376629cc403..5d9bae240062 100644 --- a/crates/polars-core/src/chunked_array/float.rs +++ b/crates/polars-core/src/chunked_array/float.rs @@ -1,4 +1,3 @@ -use arrow::legacy::kernels::float::*; use arrow::legacy::kernels::set::set_at_nulls; use num_traits::Float; use polars_utils::total_ord::{canonical_f32, canonical_f64}; @@ -12,16 +11,16 @@ where T::Native: Float, { pub fn is_nan(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_nan::) + unary_elementwise_values(self, |x| x.is_nan()) } pub fn is_not_nan(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_not_nan::) + unary_elementwise_values(self, |x| !x.is_nan()) } pub fn is_finite(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_finite) + unary_elementwise_values(self, |x| x.is_finite()) } pub fn is_infinite(&self) -> BooleanChunked { - self.apply_kernel_cast(&is_infinite) + unary_elementwise_values(self, |x| x.is_infinite()) } #[must_use] diff --git a/crates/polars-core/src/chunked_array/from_iterator.rs b/crates/polars-core/src/chunked_array/from_iterator.rs index 5d784fb51fd9..ba9e8d1e6ccc 100644 --- a/crates/polars-core/src/chunked_array/from_iterator.rs +++ b/crates/polars-core/src/chunked_array/from_iterator.rs @@ -81,12 +81,12 @@ impl PolarsAsRef for &str {} // &["foo", "bar"] impl PolarsAsRef for &&str {} -impl<'a> PolarsAsRef for Cow<'a, str> {} +impl PolarsAsRef for Cow<'_, str> {} impl PolarsAsRef<[u8]> for Vec {} impl PolarsAsRef<[u8]> for &[u8] {} // TODO: remove! impl PolarsAsRef<[u8]> for &&[u8] {} -impl<'a> PolarsAsRef<[u8]> for Cow<'a, [u8]> {} +impl PolarsAsRef<[u8]> for Cow<'_, [u8]> {} impl FromIterator for StringChunked where diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index 728ffc5a8cff..a87d888968f3 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -25,7 +25,7 @@ pub trait PolarsIterator: ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen { } -unsafe impl<'a, I> TrustedLen for Box + 'a> {} +unsafe impl TrustedLen for Box + '_> {} /// Implement [`PolarsIterator`] for every iterator that implements the needed traits. impl PolarsIterator for T where @@ -79,7 +79,7 @@ impl<'a> BoolIterNoNull<'a> { } } -impl<'a> Iterator for BoolIterNoNull<'a> { +impl Iterator for BoolIterNoNull<'_> { type Item = bool; fn next(&mut self) -> Option { @@ -100,7 +100,7 @@ impl<'a> Iterator for BoolIterNoNull<'a> { } } -impl<'a> DoubleEndedIterator for BoolIterNoNull<'a> { +impl DoubleEndedIterator for BoolIterNoNull<'_> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -112,7 +112,7 @@ impl<'a> DoubleEndedIterator for BoolIterNoNull<'a> { } /// all arrays have known size. -impl<'a> ExactSizeIterator for BoolIterNoNull<'a> {} +impl ExactSizeIterator for BoolIterNoNull<'_> {} impl BooleanChunked { #[allow(clippy::wrong_self_convention)] @@ -339,7 +339,7 @@ impl<'a> FixedSizeListIterNoNull<'a> { } #[cfg(feature = "dtype-array")] -impl<'a> Iterator for FixedSizeListIterNoNull<'a> { +impl Iterator for FixedSizeListIterNoNull<'_> { type Item = Series; fn next(&mut self) -> Option { @@ -367,7 +367,7 @@ impl<'a> Iterator for FixedSizeListIterNoNull<'a> { } #[cfg(feature = "dtype-array")] -impl<'a> DoubleEndedIterator for FixedSizeListIterNoNull<'a> { +impl DoubleEndedIterator for FixedSizeListIterNoNull<'_> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -388,7 +388,7 @@ impl<'a> DoubleEndedIterator for FixedSizeListIterNoNull<'a> { /// all arrays have known size. #[cfg(feature = "dtype-array")] -impl<'a> ExactSizeIterator for FixedSizeListIterNoNull<'a> {} +impl ExactSizeIterator for FixedSizeListIterNoNull<'_> {} #[cfg(feature = "dtype-array")] impl ArrayChunked { diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index b575dcdf5a65..2c48da805171 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -18,7 +18,7 @@ pub struct AmortizedListIter<'a, I: Iterator>> { inner_dtype: DataType, } -impl<'a, I: Iterator>> AmortizedListIter<'a, I> { +impl>> AmortizedListIter<'_, I> { pub(crate) unsafe fn new( len: usize, series_container: Series, @@ -37,7 +37,7 @@ impl<'a, I: Iterator>> AmortizedListIter<'a, I> { } } -impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a, I> { +impl>> Iterator for AmortizedListIter<'_, I> { type Item = Option; fn next(&mut self) -> Option { @@ -106,8 +106,8 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // # Safety // we correctly implemented size_hint -unsafe impl<'a, I: Iterator>> TrustedLen for AmortizedListIter<'a, I> {} -impl<'a, I: Iterator>> ExactSizeIterator for AmortizedListIter<'a, I> {} +unsafe impl>> TrustedLen for AmortizedListIter<'_, I> {} +impl>> ExactSizeIterator for AmortizedListIter<'_, I> {} impl ListChunked { /// This is an iterator over a [`ListChunked`] that saves allocations. @@ -152,7 +152,7 @@ impl ListChunked { let (s, ptr) = unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; - // SAFETY: ptr belongs the the Series.. + // SAFETY: ptr belongs the Series.. unsafe { AmortizedListIter::new( self.len(), diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 8b730966b1bc..7aff61172f04 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -46,25 +46,6 @@ impl ListChunked { } } - /// Returns an iterator over the offsets of this chunked array. - /// - /// The offsets are returned as though the array consisted of a single chunk. - pub fn iter_offsets(&self) -> impl Iterator + '_ { - let mut offsets = self.downcast_iter().map(|arr| arr.offsets().iter()); - let first_iter = offsets.next().unwrap(); - - // The first offset doesn't have to be 0, it can be sliced to `n` in the array. - // So we must correct for this. - let correction = first_iter.clone().next().unwrap(); - - OffsetsIterator { - current_offsets_iter: first_iter, - current_adjusted_offset: 0, - offset_adjustment: -correction, - offsets_iters: offsets, - } - } - /// Ignore the list indices and apply `func` to the inner type as [`Series`]. pub fn apply_to_inner( &self, @@ -110,33 +91,14 @@ impl ListChunked { ) }) } -} - -pub struct OffsetsIterator<'a, N> -where - N: Iterator>, -{ - offsets_iters: N, - current_offsets_iter: std::slice::Iter<'a, i64>, - current_adjusted_offset: i64, - offset_adjustment: i64, -} -impl<'a, N> Iterator for OffsetsIterator<'a, N> -where - N: Iterator>, -{ - type Item = i64; - - fn next(&mut self) -> Option { - if let Some(offset) = self.current_offsets_iter.next() { - self.current_adjusted_offset = offset + self.offset_adjustment; - Some(self.current_adjusted_offset) - } else { - self.current_offsets_iter = self.offsets_iters.next()?; - let first = self.current_offsets_iter.next().unwrap(); - self.offset_adjustment = self.current_adjusted_offset - first; - self.next() - } + pub fn rechunk_and_trim_to_normalized_offsets(&self) -> Self { + Self::with_chunk( + self.name().clone(), + self.rechunk() + .downcast_get(0) + .unwrap() + .trim_to_normalized_offsets_recursive(), + ) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index a59ff68e40d9..8ccd455e4bd0 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -203,6 +203,24 @@ impl CategoricalChunked { } } + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// # Safety + /// It's not checked that the indices are in-bounds or that the dtype is + /// correct. + pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self { + debug_assert!(matches!( + dtype, + DataType::Enum { .. } | DataType::Categorical { .. } + )); + let mut logical = Logical::::new_logical::(idx); + logical.2 = Some(dtype); + Self { + physical: logical, + bit_settings: Default::default(), + } + } + /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. /// /// # Safety @@ -425,7 +443,7 @@ pub struct CatIter<'a> { iter: Box> + 'a>, } -unsafe impl<'a> TrustedLen for CatIter<'a> {} +unsafe impl TrustedLen for CatIter<'_> {} impl<'a> Iterator for CatIter<'a> { type Item = Option<&'a str>; @@ -445,7 +463,7 @@ impl<'a> Iterator for CatIter<'a> { } } -impl<'a> ExactSizeIterator for CatIter<'a> {} +impl ExactSizeIterator for CatIter<'_> {} #[cfg(test)] mod test { diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs index e143a59a7f7b..5279099f1cd7 100644 --- a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs @@ -85,16 +85,12 @@ impl EnumChunkedBuilder { let length = arr.len() as IdxSize; let ca = unsafe { UInt32Chunked::new_with_dims( - Arc::new(Field::new( - self.name, - DataType::Enum(Some(self.rev.clone()), self.ordering), - )), + Arc::new(Field::new(self.name, DataType::UInt32)), vec![Box::new(arr)], length, null_count, ) }; - // SAFETY: keys and values are in bounds unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering) diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 45f63847e97f..9383634cdcac 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -80,7 +80,7 @@ pub(crate) fn get_object_type() -> DataType { Box::new(ObjectChunkedBuilder::::new(name, capacity)) as Box }); - let object_size = std::mem::size_of::(); + let object_size = size_of::(); let physical_dtype = ArrowDataType::FixedSizeBinary(object_size); let registry = ObjectRegistry::new(object_builder, physical_dtype); diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 89ccd65a7c1a..f9167b200211 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -29,7 +29,7 @@ pub fn set_polars_allow_extension(toggle: bool) { /// `n_t_vals` must represent the correct number of `T` values in that allocation unsafe fn create_drop(mut ptr: *const u8, n_t_vals: usize) -> Box { Box::new(move || { - let t_size = std::mem::size_of::() as isize; + let t_size = size_of::() as isize; for _ in 0..n_t_vals { let _ = std::ptr::read_unaligned(ptr as *const T); ptr = ptr.offset(t_size) @@ -55,7 +55,7 @@ impl Drop for ExtensionSentinel { // https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d // not entirely sure if padding bytes in T are initialized or not. unsafe fn any_as_u8_slice(p: &T) -> &[u8] { - std::slice::from_raw_parts((p as *const T) as *const u8, std::mem::size_of::()) + std::slice::from_raw_parts((p as *const T) as *const u8, size_of::()) } /// Create an extension Array that can be sent to arrow and (once wrapped in `[PolarsExtension]` will @@ -67,8 +67,8 @@ pub(crate) fn create_extension> + TrustedLen, T: Si if !(POLARS_ALLOW_EXTENSION.load(Ordering::Relaxed) || std::env::var(env).is_ok()) { panic!("creating extension types not allowed - try setting the environment variable {env}") } - let t_size = std::mem::size_of::(); - let t_alignment = std::mem::align_of::(); + let t_size = size_of::(); + let t_alignment = align_of::(); let n_t_vals = iter.size_hint().1.unwrap(); let mut buf = Vec::with_capacity(n_t_vals * t_size); diff --git a/crates/polars-core/src/chunked_array/object/iterator.rs b/crates/polars-core/src/chunked_array/object/iterator.rs index 7a5c6e00b590..7abb9c46f4ee 100644 --- a/crates/polars-core/src/chunked_array/object/iterator.rs +++ b/crates/polars-core/src/chunked_array/object/iterator.rs @@ -54,7 +54,7 @@ impl<'a, T: PolarsObject> std::iter::Iterator for ObjectIter<'a, T> { } } -impl<'a, T: PolarsObject> std::iter::DoubleEndedIterator for ObjectIter<'a, T> { +impl std::iter::DoubleEndedIterator for ObjectIter<'_, T> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -75,7 +75,7 @@ impl<'a, T: PolarsObject> std::iter::DoubleEndedIterator for ObjectIter<'a, T> { } /// all arrays have known size. -impl<'a, T: PolarsObject> std::iter::ExactSizeIterator for ObjectIter<'a, T> {} +impl std::iter::ExactSizeIterator for ObjectIter<'_, T> {} impl<'a, T: PolarsObject> IntoIterator for &'a ObjectArray { type Item = Option<&'a T>; diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index a7e3d2f9952d..8f4711976856 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -164,7 +164,7 @@ where } fn dtype(&self) -> &ArrowDataType { - &ArrowDataType::FixedSizeBinary(std::mem::size_of::()) + &ArrowDataType::FixedSizeBinary(size_of::()) } fn slice(&mut self, offset: usize, length: usize) { @@ -275,7 +275,7 @@ impl StaticArray for ObjectArray { impl ParameterFreeDtypeStaticArray for ObjectArray { fn get_dtype() -> ArrowDataType { - ArrowDataType::FixedSizeBinary(std::mem::size_of::()) + ArrowDataType::FixedSizeBinary(size_of::()) } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 071073460ff3..0a059eb54274 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -11,6 +11,7 @@ use num_traits::{Float, One, ToPrimitive, Zero}; use polars_compute::float_sum; use polars_compute::min_max::MinMaxKernel; use polars_utils::min_max::MinMax; +use polars_utils::sync::SyncPtr; pub use quantile::*; pub use var::*; @@ -369,12 +370,8 @@ where ::Simd: Add::Simd> + compute::aggregate::Sum, { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -385,12 +382,8 @@ where } impl QuantileAggSeries for Float32Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float32, v.into())) } @@ -401,12 +394,8 @@ impl QuantileAggSeries for Float32Chunked { } impl QuantileAggSeries for Float64Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -553,12 +542,54 @@ impl CategoricalChunked { #[cfg(feature = "dtype-categorical")] impl ChunkAggSeries for CategoricalChunked { fn min_reduce(&self) -> Scalar { - let av: AnyValue = self.min_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().min() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.min_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } fn max_reduce(&self) -> Scalar { - let av: AnyValue = self.max_categorical().into(); - Scalar::new(DataType::String, av.into_static()) + match self.dtype() { + DataType::Enum(r, _) => match self.physical().max() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(_, _) => { + let av: AnyValue = self.max_categorical().into(); + Scalar::new(DataType::String, av.into_static()) + }, + _ => unreachable!(), + } } } @@ -735,19 +766,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_f64.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i64.quantile(0.9, interpol).unwrap(), None); + for method in methods { + assert_eq!(test_f32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_f64.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i64.quantile(0.9, method).unwrap(), None); } } @@ -758,19 +790,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i64.quantile(0.5, interpol).unwrap(), Some(1.0)); + for method in methods { + assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0)); } } @@ -793,37 +826,38 @@ mod test { &[None, Some(1i64), Some(5i64), Some(1i64)], ); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.0, interpol).unwrap(), test_f32.min()); - assert_eq!(test_f32.quantile(1.0, interpol).unwrap(), test_f32.max()); + for method in methods { + assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min()); + assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max()); assert_eq!( - test_i32.quantile(0.0, interpol).unwrap().unwrap(), + test_i32.quantile(0.0, method).unwrap().unwrap(), test_i32.min().unwrap() as f64 ); assert_eq!( - test_i32.quantile(1.0, interpol).unwrap().unwrap(), + test_i32.quantile(1.0, method).unwrap().unwrap(), test_i32.max().unwrap() as f64 ); - assert_eq!(test_f64.quantile(0.0, interpol).unwrap(), test_f64.min()); - assert_eq!(test_f64.quantile(1.0, interpol).unwrap(), test_f64.max()); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), test_f64.median()); + assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min()); + assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max()); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median()); assert_eq!( - test_i64.quantile(0.0, interpol).unwrap().unwrap(), + test_i64.quantile(0.0, method).unwrap().unwrap(), test_i64.min().unwrap() as f64 ); assert_eq!( - test_i64.quantile(1.0, interpol).unwrap().unwrap(), + test_i64.quantile(1.0, method).unwrap().unwrap(), test_i64.max().unwrap() as f64 ); } @@ -837,72 +871,56 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(5.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(3.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(3.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(3.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert!( + (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001 + ); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); - assert!( - (ca.quantile(0.6, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - - 3.4) - .abs() - < 0.0000001 + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(3.0) ); let ca = UInt32Chunked::new( @@ -922,68 +940,54 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(2.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(6.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(5.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(6.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(7.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(6.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.6) + ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(6.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(5.0) ); } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index d6218e81d463..f7716c864559 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -4,11 +4,7 @@ pub trait QuantileAggSeries { /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. fn median_reduce(&self) -> Scalar; /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult; + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult; } /// helper @@ -16,18 +12,23 @@ fn quantile_idx( quantile: f64, length: usize, null_count: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> (usize, f64, usize) { - let float_idx = ((length - null_count) as f64 - 1.0) * quantile + null_count as f64; - let mut base_idx = match interpol { - QuantileInterpolOptions::Nearest => { + let nonnull_count = (length - null_count) as f64; + let float_idx = (nonnull_count - 1.0) * quantile + null_count as f64; + let mut base_idx = match method { + QuantileMethod::Nearest => { let idx = float_idx.round() as usize; - return (float_idx.round() as usize, 0.0, idx); + return (idx, 0.0, idx); + }, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + float_idx as usize + }, + QuantileMethod::Higher => float_idx.ceil() as usize, + QuantileMethod::Equiprobable => { + let idx = ((nonnull_count * quantile).ceil() - 1.0).max(0.0) as usize + null_count; + return (idx, 0.0, idx); }, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => float_idx as usize, - QuantileInterpolOptions::Higher => float_idx.ceil() as usize, }; base_idx = base_idx.clamp(0, length - 1); @@ -57,7 +58,7 @@ fn midpoint_interpol(lower: T, upper: T) -> T { fn quantile_slice( vals: &mut [T], quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { polars_ensure!((0.0..=1.0).contains(&quantile), ComputeError: "quantile should be between 0.0 and 1.0", @@ -68,21 +69,21 @@ fn quantile_slice( if vals.len() == 1 { return Ok(vals[0].to_f64()); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, method); let (_lhs, lower, rhs) = vals.select_nth_unstable_by(idx, TotalOrd::tot_cmp); if idx == top_idx { Ok(lower.to_f64()) } else { - match interpol { - QuantileInterpolOptions::Midpoint => { + match method { + QuantileMethod::Midpoint => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(Some(midpoint_interpol( lower.to_f64().unwrap(), upper.to_f64().unwrap(), ))) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(linear_interpol( lower.to_f64().unwrap(), @@ -100,7 +101,7 @@ fn quantile_slice( fn generic_quantile( ca: ChunkedArray, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> where T: PolarsNumericType, @@ -117,12 +118,12 @@ where return Ok(None); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, method); let sorted = ca.sort(false); let lower = sorted.get(idx).map(|v| v.to_f64().unwrap()); - let opt = match interpol { - QuantileInterpolOptions::Midpoint => { + let opt = match method { + QuantileMethod::Midpoint => { if top_idx == idx { lower } else { @@ -130,7 +131,7 @@ where midpoint_interpol(lower.unwrap(), upper.unwrap()).to_f64() } }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { if top_idx == idx { lower } else { @@ -149,22 +150,18 @@ where T: PolarsIntegerType, T::Native: TotalOrd, { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -177,61 +174,52 @@ where pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } impl ChunkQuantile for Float32Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let out = if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) }; out.map(|v| v.map(|v| v as f32)) } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } impl ChunkQuantile for Float64Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -239,20 +227,19 @@ impl Float64Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } @@ -260,20 +247,19 @@ impl Float32Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol).map(|v| v.map(|v| v as f32)) + quantile_slice(slice, quantile, method).map(|v| v.map(|v| v as f32)) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs index 1ca04cc2b30e..ea332f0cc432 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs @@ -1,4 +1,4 @@ -use arity::unary_elementwise_values; +use polars_compute::var_cov::VarState; use super::*; @@ -15,20 +15,11 @@ where ChunkedArray: ChunkAgg, { fn var(&self, ddof: u8) -> Option { - let n_values = self.len() - self.null_count(); - if n_values <= ddof as usize { - return None; + let mut out = VarState::default(); + for arr in self.downcast_iter() { + out.combine(&polars_compute::var_cov::var(arr)) } - - let mean = self.mean()?; - let squared: Float64Chunked = unary_elementwise_values(self, |value| { - let tmp = value.to_f64().unwrap() - mean; - tmp * tmp - }); - - squared - .sum() - .map(|sum| sum / (n_values as f64 - ddof as f64)) + out.finalize(ddof) } fn std(&self, ddof: u8) -> Option { diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 383c76d63600..1cbf0da390e7 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -132,7 +132,7 @@ where impl ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, for<'a> T::Physical<'a>: TotalOrd, { /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index 7b20d77e2444..7236fde9993c 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -8,8 +8,8 @@ use crate::series::BitRepr; fn reinterpret_chunked_array( ca: &ChunkedArray, ) -> ChunkedArray { - assert!(std::mem::size_of::() == std::mem::size_of::()); - assert!(std::mem::align_of::() == std::mem::align_of::()); + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); let chunks = ca.downcast_iter().map(|array| { let buf = array.values().clone(); @@ -29,8 +29,8 @@ fn reinterpret_chunked_array( fn reinterpret_list_chunked( ca: &ListChunked, ) -> ListChunked { - assert!(std::mem::size_of::() == std::mem::size_of::()); - assert!(std::mem::align_of::() == std::mem::align_of::()); + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); let chunks = ca.downcast_iter().map(|array| { let inner_arr = array @@ -105,7 +105,7 @@ where T: PolarsNumericType, { fn to_bit_repr(&self) -> BitRepr { - let is_large = std::mem::size_of::() == 8; + let is_large = size_of::() == 8; if is_large { if matches!(self.dtype(), DataType::UInt64) { @@ -118,7 +118,7 @@ where BitRepr::Large(reinterpret_chunked_array(self)) } else { - BitRepr::Small(if std::mem::size_of::() == 4 { + BitRepr::Small(if size_of::() == 4 { if matches!(self.dtype(), DataType::UInt32) { let ca = self.clone(); // Convince the compiler we are this type. This preserves flags. diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index cb24305f75f6..fc162626bc27 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -143,7 +143,7 @@ unsafe fn gather_idx_array_unchecked( impl + ?Sized> ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &I) -> Self { @@ -178,7 +178,7 @@ pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> impl ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { @@ -312,3 +312,47 @@ impl IdxCa { f(&ca) } } + +#[cfg(feature = "dtype-array")] +impl ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +#[cfg(feature = "dtype-array")] +impl + ?Sized> ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} + +impl ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +impl + ?Sized> ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 33f43d530e45..c0daaa72bdf6 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -34,6 +34,7 @@ pub(crate) mod nulls; mod reverse; #[cfg(feature = "rolling_window")] pub(crate) mod rolling_window; +pub mod row_encode; pub mod search_sorted; mod set; mod shift; @@ -278,11 +279,7 @@ pub trait ChunkQuantile { } /// Aggregate a given quantile of the ChunkedArray. /// Returns `None` if the array is empty or only contains null values. - fn quantile( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult> { Ok(None) } } diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs new file mode 100644 index 000000000000..5ac627327389 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -0,0 +1,220 @@ +use arrow::compute::utils::combine_validities_and_many; +use polars_row::{convert_columns, EncodingField, RowsEncoded}; +use rayon::prelude::*; + +use crate::prelude::*; +use crate::utils::_split_offsets; +use crate::POOL; + +pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => s.rechunk(), + Binary | Boolean => s.clone(), + BinaryOffset => s.clone(), + String => s.str().unwrap().as_binary().into_series(), + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = s.struct_().unwrap(); + let new_fields = ca + .fields_as_series() + .iter() + .map(convert_series_for_row_encoding) + .collect::>>()?; + let mut out = + StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?; + out.zip_outer_validity(ca); + out.into_series() + }, + // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => s.clone(), + List(inner) if !inner.is_nested() => s.clone(), + Null => s.clone(), + _ => { + let phys = s.to_physical_repr().into_owned(); + polars_ensure!( + phys.dtype().is_numeric(), + InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() + ); + phys + }, + }; + Ok(out) +} + +pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { + let by = convert_series_for_row_encoding(by)?; + let by = by.rechunk(); + + let out = match by.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + let ca = by.categorical().unwrap(); + if ca.uses_lexical_ordering() { + by.to_arrow(0, CompatLevel::newest()) + } else { + ca.physical().chunks[0].clone() + } + }, + // Take physical + _ => by.chunks()[0].clone(), + }; + Ok(out) +} + +pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +// Almost the same but broadcast nulls to the row-encoded array. +pub fn encode_rows_vertical_par_unordered_broadcast_nulls( + by: &[Series], +) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + + let validities = sliced + .iter() + .flat_map(|s| { + let s = s.rechunk(); + #[allow(clippy::unnecessary_to_owned)] + s.chunks() + .to_vec() + .into_iter() + .map(|arr| arr.validity().cloned()) + }) + .collect::>(); + + let validity = combine_validities_and_many(&validities); + Ok(rows.into_array().with_validity_typed(validity)) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +pub fn encode_rows_unordered(by: &[Series]) -> PolarsResult { + let rows = _get_rows_encoded_unordered(by)?; + Ok(BinaryOffsetChunked::with_chunk( + PlSmallStr::EMPTY, + rows.into_array(), + )) +} + +pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + for by in by { + let arr = _get_rows_encoded_compat_array(by)?; + let field = EncodingField::new_unsorted(); + match arr.dtype() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + for arr in arr.values() { + cols.push(arr.clone() as ArrayRef); + fields.push(field) + } + }, + _ => { + cols.push(arr); + fields.push(field) + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + debug_assert_eq!(by.len(), descending.len()); + debug_assert_eq!(by.len(), nulls_last.len()); + + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + + for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { + let by = by.as_materialized_series(); + let arr = _get_rows_encoded_compat_array(by)?; + let sort_field = EncodingField { + descending: *desc, + nulls_last: *null_last, + no_order: false, + }; + match arr.dtype() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + let arr = arr.propagate_nulls(); + for value_arr in arr.values() { + cols.push(value_arr.clone() as ArrayRef); + fields.push(sort_field); + } + }, + _ => { + cols.push(arr); + fields.push(sort_field); + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded_ca( + name: PlSmallStr, + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + _get_rows_encoded(by, descending, nulls_last) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} + +pub fn _get_rows_encoded_arr( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult> { + _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) +} + +pub fn _get_rows_encoded_ca_unordered( + name: PlSmallStr, + by: &[Series], +) -> PolarsResult { + _get_rows_encoded_unordered(by) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs index cad95d6b1d10..7f257f23f59e 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -1,6 +1,7 @@ use polars_utils::itertools::Itertools; use super::*; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; #[derive(Eq)] struct CompareRow<'a> { diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 5653039ff02e..2291cc2306e1 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,10 +1,8 @@ -use arrow::compute::utils::combine_validities_and_many; use compare_inner::NullOrderCmp; -use polars_row::{convert_columns, EncodingField, RowsEncoded}; use polars_utils::itertools::Itertools; use super::*; -use crate::utils::_split_offsets; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; pub(crate) fn args_validate( ca: &ChunkedArray, @@ -86,181 +84,6 @@ pub(crate) fn arg_sort_multiple_impl( Ok(ca.into_inner()) } -pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { - let by = convert_sort_column_multi_sort(by)?; - let by = by.rechunk(); - - let out = match by.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - let ca = by.categorical().unwrap(); - if ca.uses_lexical_ordering() { - by.to_arrow(0, CompatLevel::newest()) - } else { - ca.physical().chunks[0].clone() - } - }, - // Take physical - _ => by.chunks()[0].clone(), - }; - Ok(out) -} - -pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - Ok(rows.into_array()) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -// Almost the same but broadcast nulls to the row-encoded array. -pub fn encode_rows_vertical_par_unordered_broadcast_nulls( - by: &[Series], -) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - - let validities = sliced - .iter() - .flat_map(|s| { - let s = s.rechunk(); - #[allow(clippy::unnecessary_to_owned)] - s.chunks() - .to_vec() - .into_iter() - .map(|arr| arr.validity().cloned()) - }) - .collect::>(); - - let validity = combine_validities_and_many(&validities); - Ok(rows.into_array().with_validity_typed(validity)) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -pub(crate) fn encode_rows_unordered(by: &[Series]) -> PolarsResult { - let rows = _get_rows_encoded_unordered(by)?; - Ok(BinaryOffsetChunked::with_chunk( - PlSmallStr::EMPTY, - rows.into_array(), - )) -} - -pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - for by in by { - let arr = _get_rows_encoded_compat_array(by)?; - let field = EncodingField::new_unsorted(); - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for arr in arr.values() { - cols.push(arr.clone() as ArrayRef); - fields.push(field) - } - }, - _ => { - cols.push(arr); - fields.push(field) - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - debug_assert_eq!(by.len(), descending.len()); - debug_assert_eq!(by.len(), nulls_last.len()); - - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - - for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { - let by = by.as_materialized_series(); - let arr = _get_rows_encoded_compat_array(by)?; - let sort_field = EncodingField { - descending: *desc, - nulls_last: *null_last, - no_order: false, - }; - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - let arr = arr.propagate_nulls(); - for value_arr in arr.values() { - cols.push(value_arr.clone() as ArrayRef); - fields.push(sort_field); - } - }, - _ => { - cols.push(arr); - fields.push(sort_field); - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded_ca( - name: PlSmallStr, - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - _get_rows_encoded(by, descending, nulls_last) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - -pub fn _get_rows_encoded_arr( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult> { - _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) -} - -pub fn _get_rows_encoded_ca_unordered( - name: PlSmallStr, - by: &[Series], -) -> PolarsResult { - _get_rows_encoded_unordered(by) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - pub(crate) fn argsort_multiple_row_fmt( by: &[Column], mut descending: Vec, diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index d68583f1dc5c..727f2ace15a8 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -18,6 +18,9 @@ use compare_inner::NonNull; use rayon::prelude::*; pub use slice::*; +use crate::chunked_array::ops::row_encode::{ + _get_rows_encoded_ca, convert_series_for_row_encoding, +}; use crate::prelude::compare_inner::TotalOrdInner; use crate::prelude::sort::arg_sort_multiple::*; use crate::prelude::*; @@ -708,44 +711,6 @@ impl ChunkSort for BooleanChunked { } } -pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult { - use DataType::*; - let out = match s.dtype() { - #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => s.rechunk(), - Binary | Boolean => s.clone(), - BinaryOffset => s.clone(), - String => s.str().unwrap().as_binary().into_series(), - #[cfg(feature = "dtype-struct")] - Struct(_) => { - let ca = s.struct_().unwrap(); - let new_fields = ca - .fields_as_series() - .iter() - .map(convert_sort_column_multi_sort) - .collect::>>()?; - let mut out = - StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?; - out.zip_outer_validity(ca); - out.into_series() - }, - // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here - #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => s.clone(), - List(inner) if !inner.is_nested() => s.clone(), - Null => s.clone(), - _ => { - let phys = s.to_physical_repr().into_owned(); - polars_ensure!( - phys.dtype().is_numeric(), - InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() - ); - phys - }, - }; - Ok(out) -} - pub fn _broadcast_bools(n_cols: usize, values: &mut Vec) { if n_cols > values.len() && values.len() == 1 { while n_cols != values.len() { @@ -763,7 +728,7 @@ pub(crate) fn prepare_arg_sort( let mut columns = columns .iter() .map(Column::as_materialized_series) - .map(convert_sort_column_multi_sort) + .map(convert_series_for_row_encoding) .map(|s| s.map(Column::from)) .collect::>>()?; diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 05ac32424d5e..625da8881117 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -10,8 +10,8 @@ use polars_utils::aliases::PlHashMap; use polars_utils::itertools::Itertools; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::chunked_array::ChunkedArray; -use crate::prelude::sort::arg_sort_multiple::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::prelude::*; use crate::series::Series; use crate::utils::Container; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index edf76969e976..a1a87e3001bd 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -854,13 +854,13 @@ impl AnyValue<'_> { } } -impl<'a> Hash for AnyValue<'a> { +impl Hash for AnyValue<'_> { fn hash(&self, state: &mut H) { self.hash_impl(state, false) } } -impl<'a> Eq for AnyValue<'a> {} +impl Eq for AnyValue<'_> {} impl<'a, T> From> for AnyValue<'a> where diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index e18dd9026a4d..30d96649762f 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; #[cfg(feature = "dtype-array")] use polars_utils::format_tuple; +use polars_utils::itertools::Itertools; use super::*; #[cfg(feature = "object")] @@ -115,9 +116,10 @@ impl PartialEq for DataType { use DataType::*; { match (self, other) { - // Don't include rev maps in comparisons #[cfg(feature = "dtype-categorical")] - (Categorical(_, _), Categorical(_, _)) => true, + // Don't include rev maps in comparisons + // TODO: include ordering in comparison + (Categorical(_, _ordering_l), Categorical(_, _ordering_r)) => true, #[cfg(feature = "dtype-categorical")] // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)` (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true, @@ -183,6 +185,8 @@ impl DataType { pub fn is_known(&self) -> bool { match self { DataType::List(inner) => inner.is_known(), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => inner.is_known(), #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => fields.iter().all(|fld| fld.dtype.is_known()), DataType::Unknown(_) => false, @@ -190,6 +194,35 @@ impl DataType { } } + /// Materialize this datatype if it is unknown. All other datatypes + /// are left unchanged. + pub fn materialize_unknown(&self) -> PolarsResult { + match self { + DataType::Unknown(u) => u + .materialize() + .ok_or_else(|| polars_err!(SchemaMismatch: "failed to materialize unknown type")), + DataType::List(inner) => Ok(DataType::List(Box::new(inner.materialize_unknown()?))), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, size) => Ok(DataType::Array( + Box::new(inner.materialize_unknown()?), + *size, + )), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => Ok(DataType::Struct( + fields + .iter() + .map(|f| { + PolarsResult::Ok(Field::new( + f.name().clone(), + f.dtype().materialize_unknown()?, + )) + }) + .try_collect_vec()?, + )), + _ => Ok(self.clone()), + } + } + #[cfg(feature = "dtype-array")] /// Get the full shape of a multidimensional array. pub fn get_shape(&self) -> Option> { @@ -665,6 +698,10 @@ impl DataType { pub fn matches_schema_type(&self, schema_type: &DataType) -> PolarsResult { match (self, schema_type) { (DataType::List(l), DataType::List(r)) => l.matches_schema_type(r), + #[cfg(feature = "dtype-array")] + (DataType::Array(l, sl), DataType::Array(r, sr)) => { + Ok(l.matches_schema_type(r)? && sl == sr) + }, #[cfg(feature = "dtype-struct")] (DataType::Struct(l), DataType::Struct(r)) => { let mut must_cast = false; diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 712466482ce2..8d84d47be978 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -13,7 +13,6 @@ mod any_value; mod dtype; mod field; mod into_scalar; -mod reshape; #[cfg(feature = "object")] mod static_array_collect; mod time_unit; @@ -26,6 +25,7 @@ use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; pub use aliases::*; pub use any_value::*; pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; +pub use arrow::datatypes::reshape::*; #[cfg(feature = "dtype-categorical")] use arrow::datatypes::IntegerType; pub use arrow::datatypes::{ArrowDataType, TimeUnit as ArrowTimeUnit}; @@ -42,7 +42,6 @@ use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -pub use reshape::*; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] @@ -300,7 +299,7 @@ unsafe impl PolarsDataType for ObjectType { type OwnedPhysical = T; type ZeroablePhysical<'a> = Option<&'a T>; type Array = ObjectArray; - type IsNested = TrueT; + type IsNested = FalseT; type HasViews = FalseT; type IsStruct = FalseT; type IsObject = TrueT; diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index c930b9e94da7..88fbfae96701 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -125,6 +125,26 @@ fn get_str_len_limit() -> usize { fn get_list_len_limit() -> usize { parse_env_var_limit(FMT_TABLE_CELL_LIST_LEN, DEFAULT_LIST_LEN_LIMIT) } +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn get_ellipsis() -> &'static str { + match std::env::var(FMT_TABLE_FORMATTING).as_deref().unwrap_or("") { + preset if preset.starts_with("ASCII") => "...", + _ => "…", + } +} + +fn estimate_string_width(s: &str) -> usize { + // get a slightly more accurate estimate of a string's screen + // width, accounting (very roughly) for multibyte characters + let n_chars = s.chars().count(); + let n_bytes = s.len(); + if n_bytes == n_chars { + n_chars + } else { + let adjust = n_bytes as f64 / n_chars as f64; + std::cmp::min(n_chars * 2, (n_chars as f64 * adjust).ceil() as usize) + } +} macro_rules! format_array { ($f:ident, $a:expr, $dtype:expr, $name:expr, $array_type:expr) => {{ @@ -424,7 +444,7 @@ impl Debug for DataFrame { } } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] -fn make_str_val(v: &str, truncate: usize) -> String { +fn make_str_val(v: &str, truncate: usize, ellipsis: &String) -> String { let v_trunc = &v[..v .char_indices() .take(truncate) @@ -434,14 +454,19 @@ fn make_str_val(v: &str, truncate: usize) -> String { if v == v_trunc { v.to_string() } else { - format!("{v_trunc}…") + format!("{v_trunc}{ellipsis}") } } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] -fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { - let name = make_str_val(f.name(), str_truncate); - let name_length = name.len(); +fn field_to_str( + f: &Field, + str_truncate: usize, + ellipsis: &String, + padding: usize, +) -> (String, usize) { + let name = make_str_val(f.name(), str_truncate, ellipsis); + let name_length = estimate_string_width(name.as_str()); let mut column_name = name; if env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) { column_name = "".to_string(); @@ -473,11 +498,11 @@ fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { format!("{column_name}{separator}{column_dtype}") }; let mut s_len = std::cmp::max(name_length, dtype_length); - let separator_length = separator.trim().len(); + let separator_length = estimate_string_width(separator.trim()); if s_len < separator_length { s_len = separator_length; } - (s, s_len + 2) + (s, s_len + padding) } #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] @@ -487,27 +512,29 @@ fn prepare_row( n_last: usize, str_truncate: usize, max_elem_lengths: &mut [usize], + ellipsis: &String, + padding: usize, ) -> Vec { let reduce_columns = n_first + n_last < row.len(); let n_elems = n_first + n_last + reduce_columns as usize; let mut row_strings = Vec::with_capacity(n_elems); for (idx, v) in row[0..n_first].iter().enumerate() { - let elem_str = make_str_val(v, str_truncate); - let elem_len = elem_str.len() + 2; + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; if max_elem_lengths[idx] < elem_len { max_elem_lengths[idx] = elem_len; }; row_strings.push(elem_str); } if reduce_columns { - row_strings.push("…".to_string()); - max_elem_lengths[n_first] = 3; + row_strings.push(ellipsis.to_string()); + max_elem_lengths[n_first] = ellipsis.chars().count() + padding; } let elem_offset = n_first + reduce_columns as usize; for (idx, v) in row[row.len() - n_last..].iter().enumerate() { - let elem_str = make_str_val(v, str_truncate); - let elem_len = elem_str.len() + 2; + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; let elem_idx = elem_offset + idx; if max_elem_lengths[elem_idx] < elem_len { max_elem_lengths[elem_idx] = elem_len; @@ -542,16 +569,36 @@ impl Display for DataFrame { "The column lengths in the DataFrame are not equal." ); + let table_style = std::env::var(FMT_TABLE_FORMATTING).unwrap_or("DEFAULT".to_string()); + let is_utf8 = !table_style.starts_with("ASCII"); + let preset = match table_style.as_str() { + "ASCII_FULL" => ASCII_FULL, + "ASCII_FULL_CONDENSED" => ASCII_FULL_CONDENSED, + "ASCII_NO_BORDERS" => ASCII_NO_BORDERS, + "ASCII_BORDERS_ONLY" => ASCII_BORDERS_ONLY, + "ASCII_BORDERS_ONLY_CONDENSED" => ASCII_BORDERS_ONLY_CONDENSED, + "ASCII_HORIZONTAL_ONLY" => ASCII_HORIZONTAL_ONLY, + "ASCII_MARKDOWN" | "MARKDOWN" => ASCII_MARKDOWN, + "UTF8_FULL" => UTF8_FULL, + "UTF8_FULL_CONDENSED" => UTF8_FULL_CONDENSED, + "UTF8_NO_BORDERS" => UTF8_NO_BORDERS, + "UTF8_BORDERS_ONLY" => UTF8_BORDERS_ONLY, + "UTF8_HORIZONTAL_ONLY" => UTF8_HORIZONTAL_ONLY, + "NOTHING" => NOTHING, + _ => UTF8_FULL_CONDENSED, + }; + let ellipsis = get_ellipsis().to_string(); + let ellipsis_len = ellipsis.chars().count(); let max_n_cols = get_col_limit(); let max_n_rows = get_row_limit(); let str_truncate = get_str_len_limit(); + let padding = 2; // eg: one char either side of the value let (n_first, n_last) = if self.width() > max_n_cols { ((max_n_cols + 1) / 2, max_n_cols / 2) } else { (self.width(), 0) }; - let reduce_columns = n_first + n_last < self.width(); let n_tbl_cols = n_first + n_last + reduce_columns as usize; let mut names = Vec::with_capacity(n_tbl_cols); @@ -559,39 +606,19 @@ impl Display for DataFrame { let fields = self.fields(); for field in fields[0..n_first].iter() { - let (s, l) = field_to_str(field, str_truncate); + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); names.push(s); name_lengths.push(l); } if reduce_columns { - names.push("…".into()); - name_lengths.push(3); + names.push(ellipsis.clone()); + name_lengths.push(ellipsis_len); } for field in fields[self.width() - n_last..].iter() { - let (s, l) = field_to_str(field, str_truncate); + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); names.push(s); name_lengths.push(l); } - let (preset, is_utf8) = match std::env::var(FMT_TABLE_FORMATTING) - .as_deref() - .unwrap_or("DEFAULT") - { - "ASCII_FULL" => (ASCII_FULL, false), - "ASCII_FULL_CONDENSED" => (ASCII_FULL_CONDENSED, false), - "ASCII_NO_BORDERS" => (ASCII_NO_BORDERS, false), - "ASCII_BORDERS_ONLY" => (ASCII_BORDERS_ONLY, false), - "ASCII_BORDERS_ONLY_CONDENSED" => (ASCII_BORDERS_ONLY_CONDENSED, false), - "ASCII_HORIZONTAL_ONLY" => (ASCII_HORIZONTAL_ONLY, false), - "ASCII_MARKDOWN" => (ASCII_MARKDOWN, false), - "UTF8_FULL" => (UTF8_FULL, true), - "UTF8_FULL_CONDENSED" => (UTF8_FULL_CONDENSED, true), - "UTF8_NO_BORDERS" => (UTF8_NO_BORDERS, true), - "UTF8_BORDERS_ONLY" => (UTF8_BORDERS_ONLY, true), - "UTF8_HORIZONTAL_ONLY" => (UTF8_HORIZONTAL_ONLY, true), - "NOTHING" => (NOTHING, false), - "DEFAULT" => (UTF8_FULL_CONDENSED, true), - _ => (UTF8_FULL_CONDENSED, true), - }; let mut table = Table::new(); table @@ -601,7 +628,6 @@ impl Display for DataFrame { if is_utf8 && env_is_true(FMT_TABLE_ROUNDED_CORNERS) { table.apply_modifier(UTF8_ROUND_CORNERS); } - let mut constraints = Vec::with_capacity(n_tbl_cols); let mut max_elem_lengths: Vec = vec![0; n_tbl_cols]; @@ -610,7 +636,6 @@ impl Display for DataFrame { // Truncate the table if we have more rows than the // configured maximum number of rows let mut rows = Vec::with_capacity(std::cmp::max(max_n_rows, 2)); - let half = max_n_rows / 2; let rest = max_n_rows % 2; @@ -621,13 +646,20 @@ impl Display for DataFrame { .map(|c| c.str_value(i).unwrap()) .collect(); - let row_strings = - prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); - + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); rows.push(row_strings); } - let dots = rows[0].iter().map(|_| "…".to_string()).collect(); + let dots = vec![ellipsis.clone(); rows[0].len()]; rows.push(dots); + for i in (height - half)..height { let row = self .get_columns() @@ -635,8 +667,15 @@ impl Display for DataFrame { .map(|c| c.str_value(i).unwrap()) .collect(); - let row_strings = - prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); rows.push(row_strings); } table.add_rows(rows); @@ -654,6 +693,8 @@ impl Display for DataFrame { n_last, str_truncate, &mut max_elem_lengths, + &ellipsis, + padding, ); table.add_row(row_strings); } else { @@ -662,10 +703,9 @@ impl Display for DataFrame { } } } else if height > 0 { - let dots: Vec = self.columns.iter().map(|_| "…".to_string()).collect(); + let dots: Vec = vec![ellipsis.clone(); self.columns.len()]; table.add_row(dots); } - let tbl_fallback_width = 100; let tbl_width = std::env::var("POLARS_TABLE_WIDTH") .map(|s| { @@ -683,10 +723,10 @@ impl Display for DataFrame { lower: Width::Fixed(l as u16), upper: Width::Fixed(u as u16), }; - let min_col_width = 5; + let min_col_width = std::cmp::max(5, 3 + padding); for (idx, elem_len) in max_elem_lengths.iter().enumerate() { let mx = std::cmp::min( - str_truncate + 3, // (3 = 2 space chars of padding + ellipsis char) + str_truncate + ellipsis_len + padding, std::cmp::max(name_lengths[idx], *elem_len), ); if mx <= min_col_width { @@ -1011,7 +1051,7 @@ fn format_blob(f: &mut Formatter<'_>, bytes: &[u8]) -> fmt::Result { } } if bytes.len() > width { - write!(f, "\"...")?; + write!(f, "\"…")?; } else { write!(f, "\"")?; } @@ -1138,9 +1178,7 @@ impl Series { if self.is_empty() { return "[]".to_owned(); } - let max_items = get_list_len_limit(); - match max_items { 0 => "[…]".to_owned(), _ if max_items >= self.len() => { diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index d83b91b78cff..e66c1ad12875 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -560,12 +560,12 @@ impl Column { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { // @scalar-opt unsafe { self.as_materialized_series() - .agg_quantile(groups, quantile, interpol) + .agg_quantile(groups, quantile, method) } .into() } diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index fe71148cd49b..aaf24a470969 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -236,7 +236,7 @@ impl Series { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { @@ -247,13 +247,12 @@ impl Series { use DataType::*; match s.dtype() { - Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol), - Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol), + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method), dt if dt.is_numeric() || dt.is_temporal() => { let ca = s.to_physical_repr(); let physical_type = ca.dtype(); - let s = - apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol); + let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method); if dt.is_logical() { // back to physical and then // back to logical type diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 092d660fb4d2..19b8d5c2d061 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -13,7 +13,7 @@ use arrow::legacy::kernels::rolling::no_nulls::{ }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; use arrow::legacy::kernels::take_agg::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; use arrow::legacy::trusted_len::TrustedLenPush; use arrow::types::NativeType; use num_traits::pow::Pow; @@ -295,8 +295,7 @@ impl_take_extremum!(float: f64); /// This trait will ensure the specific dispatch works without complicating /// the trait bounds. trait QuantileDispatcher { - fn _quantile(self, quantile: f64, interpol: QuantileInterpolOptions) - -> PolarsResult>; + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult>; fn _median(self) -> Option; } @@ -307,12 +306,8 @@ where T::Native: Ord, ChunkedArray: IntoSeries, { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -320,24 +315,16 @@ where } impl QuantileDispatcher for Float32Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() } } impl QuantileDispatcher for Float64Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -348,7 +335,7 @@ unsafe fn agg_quantile_generic( ca: &ChunkedArray, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series where T: PolarsNumericType, @@ -371,7 +358,7 @@ where } let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check - take._quantile(quantile, interpol).unwrap_unchecked() + take._quantile(quantile, method).unwrap_unchecked() }) }, GroupsProxy::Slice { groups, .. } => { @@ -390,7 +377,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ), Some(validity) => { @@ -400,7 +387,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ) }, @@ -418,7 +405,7 @@ where let arr_group = _slice_from_offsets(ca, first, len); // unwrap checked with invalid quantile check arr_group - ._quantile(quantile, interpol) + ._quantile(quantile, method) .unwrap_unchecked() .map(|flt| NumCast::from(flt).unwrap_unchecked()) }, @@ -450,7 +437,7 @@ where }) }, GroupsProxy::Slice { .. } => { - agg_quantile_generic::(ca, groups, 0.5, QuantileInterpolOptions::Linear) + agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) }, } } @@ -977,9 +964,9 @@ impl Float32Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float32Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float32Type>(self, groups) @@ -990,9 +977,9 @@ impl Float64Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) @@ -1184,9 +1171,9 @@ where &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index bdaa439a1232..12b4b27de7e2 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -3,8 +3,8 @@ use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered; use crate::config::verbose; -use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; use crate::series::BitRepr; use crate::utils::flatten::flatten_par; diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 9c56f7c49122..9dee1e1f411a 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -21,7 +21,7 @@ mod proxy; pub use into_groups::*; pub use proxy::*; -use crate::prelude::sort::arg_sort_multiple::{ +use crate::chunked_array::ops::row_encode::{ encode_rows_unordered, encode_rows_vertical_par_unordered, }; @@ -594,18 +594,14 @@ impl<'df> GroupBy<'df> { /// /// ```rust /// # use polars_core::prelude::*; - /// # use arrow::legacy::prelude::QuantileInterpolOptions; + /// # use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> PolarsResult { - /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileInterpolOptions::default()) + /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default()) /// } /// ``` #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] - pub fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { polars_ensure!( (0.0..=1.0).contains(&quantile), ComputeError: "`quantile` should be within 0.0 and 1.0" @@ -614,9 +610,9 @@ impl<'df> GroupBy<'df> { for agg_col in agg_cols { let new_name = fmt_group_by_column( agg_col.name().as_str(), - GroupByMethod::Quantile(quantile, interpol), + GroupByMethod::Quantile(quantile, method), ); - let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, interpol) }; + let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) }; agg.rename(new_name); cols.push(agg); } @@ -868,7 +864,7 @@ pub enum GroupByMethod { Sum, Groups, NUnique, - Quantile(f64, QuantileInterpolOptions), + Quantile(f64, QuantileMethod), Count { include_nulls: bool, }, diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 86d952828c85..ef228981f7e5 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::mem::MaybeUninit; use num_traits::{FromPrimitive, ToPrimitive}; use polars_utils::idx_vec::IdxVec; @@ -17,162 +18,136 @@ where T: PolarsIntegerType, T::Native: ToPrimitive + FromPrimitive + Debug, { - // Use the indexes as perfect groups - pub fn group_tuples_perfect( + /// Use the indexes as perfect groups. + /// + /// # Safety + /// This ChunkedArray must contain each value in [0..num_groups) at least + /// once, and nothing outside this range. + pub unsafe fn group_tuples_perfect( &self, - max: usize, + num_groups: usize, mut multithreaded: bool, group_capacity: usize, ) -> GroupsProxy { multithreaded &= POOL.current_num_threads() > 1; + // The latest index will be used for the null sentinel. let len = if self.null_count() > 0 { - // we add one to store the null sentinel group - max + 2 + // We add one to store the null sentinel group. + num_groups + 1 } else { - max + 1 + num_groups }; - - // the latest index will be used for the null sentinel let null_idx = len.saturating_sub(1); - let n_threads = POOL.current_num_threads(); + let n_threads = POOL.current_num_threads(); let chunk_size = len / n_threads; let (groups, first) = if multithreaded && chunk_size > 1 { - let mut groups: Vec = unsafe { aligned_vec(len) }; + let mut groups: Vec = Vec::new(); groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); - let mut first: Vec = unsafe { aligned_vec(len) }; - - // ensure we keep aligned to cache lines - let chunk_size = (chunk_size * std::mem::size_of::()).next_multiple_of(64); - let chunk_size = chunk_size / std::mem::size_of::(); - - let mut cache_line_offsets = Vec::with_capacity(n_threads + 1); - cache_line_offsets.push(0); - let mut current_offset = chunk_size; - - while current_offset <= len { - cache_line_offsets.push(current_offset); - current_offset += chunk_size; + let mut first: Vec = Vec::with_capacity(len); + + // Round up offsets to nearest cache line for groups to reduce false sharing. + let groups_start = groups.as_ptr(); + let mut per_thread_offsets = Vec::with_capacity(n_threads + 1); + per_thread_offsets.push(0); + for t in 0..n_threads { + let ideal_offset = (t + 1) * chunk_size; + let cache_aligned_offset = + ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128); + if t == n_threads - 1 { + per_thread_offsets.push(len); + } else { + per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len)); + } } - cache_line_offsets.push(current_offset); let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) }; let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) }; - - // The number of threads is dependent on the number of categoricals/ unique values - // as every at least writes to a single cache line - // lower bound per thread: - // 32bit: 16 - // 64bit: 8 POOL.install(|| { - (0..cache_line_offsets.len() - 1) - .into_par_iter() - .for_each(|thread_no| { - let mut row_nr = 0 as IdxSize; - let start = cache_line_offsets[thread_no]; - let start = T::Native::from_usize(start).unwrap(); - let end = cache_line_offsets[thread_no + 1]; - let end = T::Native::from_usize(end).unwrap(); - - // SAFETY: we don't alias - let groups = - unsafe { std::slice::from_raw_parts_mut(groups_ptr.get(), len) }; - let first = unsafe { std::slice::from_raw_parts_mut(first_ptr.get(), len) }; - - for arr in self.downcast_iter() { - if arr.null_count() == 0 { - for &cat in arr.values().as_slice() { + (0..n_threads).into_par_iter().for_each(|thread_no| { + // We use raw pointers because the slices would overlap. + // However, each thread has its own range it is responsible for. + let groups = groups_ptr.get(); + let first = first_ptr.get(); + let start = per_thread_offsets[thread_no]; + let start = T::Native::from_usize(start).unwrap(); + let end = per_thread_offsets[thread_no + 1]; + let end = T::Native::from_usize(end).unwrap(); + + if start == end && thread_no != n_threads - 1 { + return; + }; + + let push_to_group = |cat, row_nr| unsafe { + debug_assert!(cat < len); + let buf = &mut *groups.add(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first.add(cat) = row_nr; + } + }; + + let mut row_nr = 0 as IdxSize; + for arr in self.downcast_iter() { + if arr.null_count() == 0 { + for &cat in arr.values().as_slice() { + if cat >= start && cat < end { + push_to_group(cat.to_usize().unwrap(), row_nr); + } + + row_nr += 1; + } + } else { + for opt_cat in arr.iter() { + if let Some(&cat) = opt_cat { if cat >= start && cat < end { - let cat = cat.to_usize().unwrap(); - let buf = unsafe { groups.get_unchecked_release_mut(cat) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - // SAFETY: we just pushed - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(cat) = *first_value - } - } + push_to_group(cat.to_usize().unwrap(), row_nr); } - row_nr += 1; + } else if thread_no == n_threads - 1 { + // Last thread handles null values. + push_to_group(null_idx, row_nr); } - } else { - for opt_cat in arr.iter() { - if let Some(&cat) = opt_cat { - // cannot factor out due to bchk - if cat >= start && cat < end { - let cat = cat.to_usize().unwrap(); - let buf = - unsafe { groups.get_unchecked_release_mut(cat) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - // SAFETY: we just pushed - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(cat) = - *first_value - } - } - } - } - // last thread handles null values - else if thread_no == cache_line_offsets.len() - 2 { - let buf = - unsafe { groups.get_unchecked_release_mut(null_idx) }; - buf.push(row_nr); - unsafe { - if buf.len() == 1 { - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(null_idx) = - *first_value - } - } - } - row_nr += 1; - } + row_nr += 1; } } - }); + } + }); }); unsafe { - groups.set_len(len); first.set_len(len); } (groups, first) } else { let mut groups = Vec::with_capacity(len); - let mut first = vec![IdxSize::MAX; len]; + let mut first = Vec::with_capacity(len); + let first_out = first.spare_capacity_mut(); groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); + let mut push_to_group = |cat, row_nr| unsafe { + let buf: &mut IdxVec = groups.get_unchecked_release_mut(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first_out.get_unchecked_release_mut(cat) = MaybeUninit::new(row_nr); + } + }; + let mut row_nr = 0 as IdxSize; for arr in self.downcast_iter() { for opt_cat in arr.iter() { if let Some(cat) = opt_cat { - let group_id = cat.to_usize().unwrap(); - let buf = unsafe { groups.get_unchecked_release_mut(group_id) }; - buf.push(row_nr); - - unsafe { - if buf.len() == 1 { - *first.get_unchecked_release_mut(group_id) = row_nr; - } - } + push_to_group(cat.to_usize().unwrap(), row_nr); } else { - let buf = unsafe { groups.get_unchecked_release_mut(null_idx) }; - buf.push(row_nr); - unsafe { - let first_value = buf.get_unchecked(0); - *first.get_unchecked_release_mut(null_idx) = *first_value - } + push_to_group(null_idx, row_nr); } row_nr += 1; } } + unsafe { + first.set_len(len); + } (groups, first) }; @@ -201,7 +176,7 @@ impl CategoricalChunked { } // on relative small tables this isn't much faster than the default strategy // but on huge tables, this can be > 2x faster - cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0) + unsafe { cats.group_tuples_perfect(cached.len(), multithreaded, 0) } } else { self.physical().group_tuples(multithreaded, sorted).unwrap() } @@ -220,26 +195,3 @@ impl CategoricalChunked { out } } - -#[repr(C, align(64))] -struct AlignTo64([u8; 64]); - -/// There are no guarantees that the [`Vec`] will remain aligned if you reallocate the data. -/// This means that you cannot reallocate so you will need to know how big to allocate up front. -unsafe fn aligned_vec(n: usize) -> Vec { - assert!(std::mem::align_of::() <= 64); - let n_units = (n * std::mem::size_of::() / std::mem::size_of::()) + 1; - - let mut aligned: Vec = Vec::with_capacity(n_units); - - let ptr = aligned.as_mut_ptr(); - let cap_units = aligned.capacity(); - - std::mem::forget(aligned); - - Vec::from_raw_parts( - ptr as *mut T, - 0, - cap_units * std::mem::size_of::() / std::mem::size_of::(), - ) -} diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index d1c04162b7b9..63b1a8022108 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -546,7 +546,7 @@ pub enum GroupsIndicator<'a> { Slice([IdxSize; 2]), } -impl<'a> GroupsIndicator<'a> { +impl GroupsIndicator<'_> { pub fn len(&self) -> usize { match self { GroupsIndicator::Idx(g) => g.1.len(), diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index b82d43365010..0d4230ff3f91 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -33,6 +33,7 @@ use arrow::record_batch::RecordBatch; use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] @@ -49,8 +50,9 @@ pub enum NullStrategy { Propagate, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum UniqueKeepStrategy { /// Keep the first unique row. First, @@ -307,16 +309,6 @@ impl DataFrame { /// Converts a sequence of columns into a DataFrame, broadcasting length-1 /// columns to match the other columns. pub fn new_with_broadcast(columns: Vec) -> PolarsResult { - ensure_names_unique(&columns, |s| s.name().as_str())?; - unsafe { Self::new_with_broadcast_no_checks(columns) } - } - - /// Converts a sequence of columns into a DataFrame, broadcasting length-1 - /// columns to match the other columns. - /// - /// # Safety - /// Does not check that the column names are unique (which they must be). - pub unsafe fn new_with_broadcast_no_checks(mut columns: Vec) -> PolarsResult { // The length of the longest non-unit length column determines the // broadcast length. If all columns are unit-length the broadcast length // is one. @@ -326,17 +318,42 @@ impl DataFrame { .filter(|l| *l != 1) .max() .unwrap_or(1); + Self::new_with_broadcast_len(columns, broadcast_len) + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to broadcast_len. + pub fn new_with_broadcast_len( + columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { + ensure_names_unique(&columns, |s| s.name().as_str())?; + unsafe { Self::new_with_broadcast_no_namecheck(columns, broadcast_len) } + } + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + /// + /// # Safety + /// Does not check that the column names are unique (which they must be). + pub unsafe fn new_with_broadcast_no_namecheck( + mut columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { for col in &mut columns { // Length not equal to the broadcast len, needs broadcast or is an error. let len = col.len(); if len != broadcast_len { if len != 1 { let name = col.name().to_owned(); - let longest_column = columns.iter().max_by_key(|c| c.len()).unwrap().name(); + let extra_info = + if let Some(c) = columns.iter().find(|c| c.len() == broadcast_len) { + format!(" (matching column '{}')", c.name()) + } else { + String::new() + }; polars_bail!( - ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", - name, len, longest_column, broadcast_len + ShapeMismatch: "could not create a new DataFrame: series {name:?} has length {len} while trying to broadcast to length {broadcast_len}{extra_info}", ); } *col = col.new_from_index(0, broadcast_len); @@ -2834,7 +2851,8 @@ impl DataFrame { dtype.is_numeric() || matches!(dtype, DataType::Boolean) }) .cloned() - .collect(); + .collect::>(); + polars_ensure!(!columns.is_empty(), InvalidOperation: "'horizontal_mean' expected at least 1 numerical column"); let numeric_df = unsafe { DataFrame::_new_no_checks_impl(self.height(), columns) }; let sum = || numeric_df.sum_horizontal(null_strategy); @@ -3292,7 +3310,7 @@ pub struct RecordBatchIter<'a> { parallel: bool, } -impl<'a> Iterator for RecordBatchIter<'a> { +impl Iterator for RecordBatchIter<'_> { type Item = RecordBatch; fn next(&mut self) -> Option { diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index 2311194303db..ad8831ebda54 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -206,7 +206,7 @@ pub fn rows_to_schema_first_non_null( .iter_values() .enumerate() .filter_map(|(i, dtype)| { - // double check struct and list types types + // double check struct and list types // nested null values can be wrongly inferred by front ends match dtype { DataType::Null | DataType::List(_) => Some(i), diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 7dfb07c64d58..e00e45f1ede8 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -1,4 +1,5 @@ use arrow::bitmap::utils::get_bit_unchecked; +use polars_utils::hashing::folded_multiply; use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use rayon::prelude::*; use xxhash_rust::xxh3::xxh3_64_with_seed; @@ -30,11 +31,6 @@ pub trait VecHash { } } -pub(crate) const fn folded_multiply(s: u64, by: u64) -> u64 { - let result = (s as u128).wrapping_mul(by as u128); - ((result & 0xffff_ffff_ffff_ffff) as u64) ^ ((result >> 64) as u64) -} - pub(crate) fn get_null_hash_value(random_state: &PlRandomState) -> u64 { // we just start with a large prime number and hash that twice // to get a constant hash value for null/None diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index d100cf91172f..38fe0377f9c4 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -20,6 +20,8 @@ pub trait SchemaExt { fn iter_fields(&self) -> impl ExactSizeIterator + '_; fn to_supertype(&mut self, other: &Schema) -> PolarsResult; + + fn materialize_unknown_dtypes(&self) -> PolarsResult; } impl SchemaExt for Schema { @@ -88,6 +90,13 @@ impl SchemaExt for Schema { } Ok(changed) } + + /// Materialize all unknown dtypes in this schema. + fn materialize_unknown_dtypes(&self) -> PolarsResult { + self.iter() + .map(|(name, dtype)| Ok((name.clone(), dtype.materialize_unknown()?))) + .collect() + } } pub trait SchemaNamesAndDtypes { diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index f59492f71cb9..b96b78687c13 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -1,16 +1,9 @@ //! Allow arithmetic operations for ListChunked. +//! use polars_error::{feature_gated, PolarsResult}; -use arrow::bitmap::Bitmap; -use arrow::compute::utils::combine_validities_and; -use arrow::offset::OffsetsBuffer; -use either::Either; -use num_traits::Zero; -use polars_compute::arithmetic::pl_num::PlNumArithmetic; -use polars_compute::arithmetic::ArithmeticKernel; -use polars_compute::comparisons::TotalEqKernel; -use polars_utils::float::IsFloat; +use polars_error::{feature_gated, PolarsResult}; -use super::*; +use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series}; impl NumOpsDispatchInner for ListType { fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { @@ -45,136 +38,163 @@ pub enum NumericListOp { } impl NumericListOp { - fn name(&self) -> &'static str { - match self { - Self::Add => "add", - Self::Sub => "sub", - Self::Mul => "mul", - Self::Div => "div", - Self::Rem => "rem", - Self::FloorDiv => "floor_div", - } - } - - pub fn try_get_leaf_supertype( - &self, - prim_dtype_lhs: &DataType, - prim_dtype_rhs: &DataType, - ) -> PolarsResult { - let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?; - - Ok(if matches!(self, Self::Div) { - if dtype.is_float() { - dtype - } else { - DataType::Float64 - } - } else { - dtype - }) - } - + #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))] pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult { - // Ideally we only need to rechunk the leaf array, but getting the - // list offsets of a ListChunked triggers a rechunk anyway, so we just - // do it here. - let lhs = lhs.rechunk(); - let rhs = rhs.rechunk(); - - let binary_op_exec = match BinaryListNumericOpHelper::try_new( - self.clone(), - lhs.name().clone(), - lhs.dtype(), - rhs.dtype(), - lhs.len(), - rhs.len(), - { - let (a, b) = lhs.list_offsets_and_validities_recursive(); - (a, b, lhs.clone()) - }, - { - let (a, b) = rhs.list_offsets_and_validities_recursive(); - (a, b, rhs.clone()) - }, - lhs.rechunk_validity(), - rhs.rechunk_validity(), - )? { - Either::Left(v) => v, - Either::Right(ca) => return Ok(ca.into_series()), - }; - - Ok(binary_op_exec.finish()?.into_series()) + feature_gated!("list_arithmetic", { + use either::Either; + + // `trim_to_normalized_offsets` ensures we don't perform excessive + // memory allocation / compute on memory regions that have been + // sliced out. + let lhs = lhs.list_rechunk_and_trim_to_normalized_offsets(); + let rhs = rhs.list_rechunk_and_trim_to_normalized_offsets(); + + let binary_op_exec = match BinaryListNumericOpHelper::try_new( + self.clone(), + lhs.name().clone(), + lhs.dtype(), + rhs.dtype(), + lhs.len(), + rhs.len(), + { + let (a, b) = lhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, lhs.clone()) + }, + { + let (a, b) = rhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, rhs.clone()) + }, + lhs.rechunk_validity(), + rhs.rechunk_validity(), + )? { + Either::Left(v) => v, + Either::Right(ca) => return Ok(ca.into_series()), + }; + + Ok(binary_op_exec.finish()?.into_series()) + }) } +} - /// For operations that perform divisions on integers, sets the validity to NULL on rows where - /// the denominator is 0. - fn prepare_numeric_op_side_validities( - &self, - lhs: &mut PrimitiveArray, - rhs: &mut PrimitiveArray, - swapped: bool, - ) where - PrimitiveArray: polars_compute::comparisons::TotalEqKernel, - T::Native: Zero + IsFloat, - { - if !T::Native::is_float() { +#[cfg(feature = "list_arithmetic")] +use inner::BinaryListNumericOpHelper; + +#[cfg(feature = "list_arithmetic")] +mod inner { + use arrow::bitmap::Bitmap; + use arrow::compute::utils::combine_validities_and; + use arrow::offset::OffsetsBuffer; + use either::Either; + use num_traits::Zero; + use polars_compute::arithmetic::pl_num::PlNumArithmetic; + use polars_compute::arithmetic::ArithmeticKernel; + use polars_compute::comparisons::TotalEqKernel; + use polars_utils::float::IsFloat; + + use super::super::*; + + impl NumericListOp { + fn name(&self) -> &'static str { match self { - Self::Div | Self::Rem | Self::FloorDiv => { - let target = if swapped { lhs } else { rhs }; - let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero()); - let validity = combine_validities_and(target.validity(), Some(&ne_0)); - target.set_validity(validity); - }, - _ => {}, + Self::Add => "add", + Self::Sub => "sub", + Self::Mul => "mul", + Self::Div => "div", + Self::Rem => "rem", + Self::FloorDiv => "floor_div", } } - } - /// For list<->primitive where the primitive is broadcasted, we can dispatch to - /// `ArithmeticKernel`, which can have optimized codepaths for when one side is - /// a scalar. - fn apply_array_to_scalar( - &self, - arr_lhs: PrimitiveArray, - r: T::Native, - swapped: bool, - ) -> PrimitiveArray { - match self { - Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r), - Self::Sub => { - if swapped { - ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs) - } else { - ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r) - } - }, - Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r), - Self::Div => { - if swapped { - ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs) - } else { - ArithmeticKernel::legacy_div_scalar(arr_lhs, r) - } - }, - Self::Rem => { - if swapped { - ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs) + fn try_get_leaf_supertype( + &self, + prim_dtype_lhs: &DataType, + prim_dtype_rhs: &DataType, + ) -> PolarsResult { + let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + Ok(if matches!(self, Self::Div) { + if dtype.is_float() { + dtype } else { - ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r) + DataType::Float64 } - }, - Self::FloorDiv => { - if swapped { - ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs) - } else { - ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r) + } else { + dtype + }) + } + + /// For operations that perform divisions on integers, sets the validity to NULL on rows where + /// the denominator is 0. + fn prepare_numeric_op_side_validities( + &self, + lhs: &mut PrimitiveArray, + rhs: &mut PrimitiveArray, + swapped: bool, + ) where + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + if !T::Native::is_float() { + match self { + Self::Div | Self::Rem | Self::FloorDiv => { + let target = if swapped { lhs } else { rhs }; + let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero()); + let validity = combine_validities_and(target.validity(), Some(&ne_0)); + target.set_validity(validity); + }, + _ => {}, } - }, + } + } + + /// For list<->primitive where the primitive is broadcasted, we can dispatch to + /// `ArithmeticKernel`, which can have optimized codepaths for when one side is + /// a scalar. + fn apply_array_to_scalar( + &self, + arr_lhs: PrimitiveArray, + r: T::Native, + swapped: bool, + ) -> PrimitiveArray { + match self { + Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r), + Self::Sub => { + if swapped { + ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r) + } + }, + Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r), + Self::Div => { + if swapped { + ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::legacy_div_scalar(arr_lhs, r) + } + }, + Self::Rem => { + if swapped { + ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r) + } + }, + Self::FloorDiv => { + if swapped { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r) + } + }, + } } } -} -macro_rules! with_match_numeric_list_op { + macro_rules! with_match_numeric_list_op { ($op:expr, $swapped:expr, | $_:tt $OP:tt | $($body:tt)* ) => ({ macro_rules! __with_func__ {( $_ $OP:tt ) => ( $($body)* )} @@ -213,814 +233,826 @@ macro_rules! with_match_numeric_list_op { }) } -#[derive(Debug)] -enum BinaryOpApplyType { - ListToList, - ListToPrimitive, - PrimitiveToList, -} - -#[derive(Debug)] -enum Broadcast { - Left, - Right, - #[allow(clippy::enum_variant_names)] - NoBroadcast, -} + #[derive(Debug)] + enum BinaryOpApplyType { + ListToList, + ListToPrimitive, + PrimitiveToList, + } -/// Utility to perform a binary operation between the primitive values of -/// 2 columns, where at least one of the columns is a `ListChunked` type. -struct BinaryListNumericOpHelper { - op: NumericListOp, - output_name: PlSmallStr, - op_apply_type: BinaryOpApplyType, - broadcast: Broadcast, - output_dtype: DataType, - output_primitive_dtype: DataType, - output_len: usize, - /// Outer validity of the result, we always materialize this to reduce the - /// amount of code paths we need. - outer_validity: Bitmap, - // The series are stored as they are used for list broadcasting. - data_lhs: (Vec>, Vec>, Series), - data_rhs: (Vec>, Vec>, Series), - list_to_prim_lhs: Option<(Box, usize)>, - swapped: bool, -} + #[derive(Debug)] + enum Broadcast { + Left, + Right, + #[allow(clippy::enum_variant_names)] + NoBroadcast, + } -/// This lets us separate some logic into `new()` to reduce the amount of -/// monomorphized code. -impl BinaryListNumericOpHelper { - /// Checks that: - /// * Dtypes are compatible: - /// * list<->primitive | primitive<->list - /// * list<->list both contain primitives (e.g. List) - /// * Primitive dtypes match - /// * Lengths are compatible: - /// * 1<->n | n<->1 - /// * n<->n - /// * Both sides have at least 1 non-NULL outer row. - /// - /// Does not check: - /// * Whether the offsets are aligned for list<->list, this will be checked during execution. - /// - /// This returns an `Either` which may contain the final result to simplify - /// the implementation. - #[allow(clippy::too_many_arguments)] - fn try_new( + /// Utility to perform a binary operation between the primitive values of + /// 2 columns, where at least one of the columns is a `ListChunked` type. + pub(super) struct BinaryListNumericOpHelper { op: NumericListOp, output_name: PlSmallStr, - dtype_lhs: &DataType, - dtype_rhs: &DataType, - len_lhs: usize, - len_rhs: usize, + op_apply_type: BinaryOpApplyType, + broadcast: Broadcast, + output_dtype: DataType, + output_primitive_dtype: DataType, + output_len: usize, + /// Outer validity of the result, we always materialize this to reduce the + /// amount of code paths we need. + outer_validity: Bitmap, + // The series are stored as they are used for list broadcasting. data_lhs: (Vec>, Vec>, Series), data_rhs: (Vec>, Vec>, Series), - validity_lhs: Option, - validity_rhs: Option, - ) -> PolarsResult> { - let prim_dtype_lhs = dtype_lhs.leaf_dtype(); - let prim_dtype_rhs = dtype_rhs.leaf_dtype(); - - let output_primitive_dtype = op.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; - - let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) { - (l @ DataType::List(a), r @ DataType::List(b)) => { - // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user - // directly adds 2 series together it bypasses the DSL. - // This is currently duplicated code and should be replaced one day with an assert after Series ops get - // checked properly. - if ![a, b] - .into_iter() - .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) + list_to_prim_lhs: Option<(Box, usize)>, + swapped: bool, + } + + /// This lets us separate some logic into `new()` to reduce the amount of + /// monomorphized code. + impl BinaryListNumericOpHelper { + /// Checks that: + /// * Dtypes are compatible: + /// * list<->primitive | primitive<->list + /// * list<->list both contain primitives (e.g. List) + /// * Primitive dtypes match + /// * Lengths are compatible: + /// * 1<->n | n<->1 + /// * n<->n + /// * Both sides have at least 1 non-NULL outer row. + /// + /// Does not check: + /// * Whether the offsets are aligned for list<->list, this will be checked during execution. + /// + /// This returns an `Either` which may contain the final result to simplify + /// the implementation. + #[allow(clippy::too_many_arguments)] + pub(super) fn try_new( + op: NumericListOp, + output_name: PlSmallStr, + dtype_lhs: &DataType, + dtype_rhs: &DataType, + len_lhs: usize, + len_rhs: usize, + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + validity_lhs: Option, + validity_rhs: Option, + ) -> PolarsResult> { + let prim_dtype_lhs = dtype_lhs.leaf_dtype(); + let prim_dtype_rhs = dtype_rhs.leaf_dtype(); + + let output_primitive_dtype = + op.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) { + (l @ DataType::List(a), r @ DataType::List(b)) => { + // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user + // directly adds 2 series together it bypasses the DSL. + // This is currently duplicated code and should be replaced one day with an assert after Series ops get + // checked properly. + if ![a, b] + .into_iter() + .all(|x| x.is_numeric() || x.is_bool() || x.is_null()) + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op.name(), l, r, + ); + } + (BinaryOpApplyType::ListToList, l) + }, + (list_dtype @ DataType::List(_), x) + if x.is_numeric() || x.is_bool() || x.is_null() => { - polars_bail!( - InvalidOperation: - "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", - op.name(), l, r, - ); - } - (BinaryOpApplyType::ListToList, l) - }, - (list_dtype @ DataType::List(_), x) if x.is_numeric() || x.is_bool() || x.is_null() => { - (BinaryOpApplyType::ListToPrimitive, list_dtype) - }, - (x, list_dtype @ DataType::List(_)) if x.is_numeric() || x.is_bool() || x.is_null() => { - (BinaryOpApplyType::PrimitiveToList, list_dtype) - }, - (l, r) => polars_bail!( - InvalidOperation: - "{} operation not supported for dtypes: {} != {}", - op.name(), l, r, - ), - }; - - let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); - - let (broadcast, output_len) = match (len_lhs, len_rhs) { - (l, r) if l == r => (Broadcast::NoBroadcast, l), - (1, v) => (Broadcast::Left, v), - (v, 1) => (Broadcast::Right, v), - (l, r) => polars_bail!( - ShapeMismatch: - "cannot {} two columns of differing lengths: {} != {}", - op.name(), l, r - ), - }; - - let DataType::List(output_inner_dtype) = &output_dtype else { - unreachable!() - }; - - // # NULL semantics - // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]] - // * Essentially as if the NULL primitive was added to every primitive in the row of the list column. - // * NULL (List[Int64]) + 1 (Int64) => NULL - // * NULL (List[Int64]) + [1] (List[Int64]) => NULL - - if output_len == 0 - || (len_lhs == 1 - && matches!( - &op_apply_type, - BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive - ) - && validity_lhs.as_ref().map_or(false, |x| { - !x.get_bit(0) // is not valid - })) - || (len_rhs == 1 - && matches!( - &op_apply_type, - BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + (BinaryOpApplyType::ListToPrimitive, list_dtype) + }, + (x, list_dtype @ DataType::List(_)) + if x.is_numeric() || x.is_bool() || x.is_null() => + { + (BinaryOpApplyType::PrimitiveToList, list_dtype) + }, + (l, r) => polars_bail!( + InvalidOperation: + "{} operation not supported for dtypes: {} != {}", + op.name(), l, r, + ), + }; + + let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); + + let (broadcast, output_len) = match (len_lhs, len_rhs) { + (l, r) if l == r => (Broadcast::NoBroadcast, l), + (1, v) => (Broadcast::Left, v), + (v, 1) => (Broadcast::Right, v), + (l, r) => polars_bail!( + ShapeMismatch: + "cannot {} two columns of differing lengths: {} != {}", + op.name(), l, r + ), + }; + + let DataType::List(output_inner_dtype) = &output_dtype else { + unreachable!() + }; + + // # NULL semantics + // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]] + // * Essentially as if the NULL primitive was added to every primitive in the row of the list column. + // * NULL (List[Int64]) + 1 (Int64) => NULL + // * NULL (List[Int64]) + [1] (List[Int64]) => NULL + + if output_len == 0 + || (len_lhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive + ) + && validity_lhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + || (len_rhs == 1 + && matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + ) + && validity_rhs.as_ref().map_or(false, |x| { + !x.get_bit(0) // is not valid + })) + { + return Ok(Either::Right(ListChunked::full_null_with_dtype( + output_name.clone(), + output_len, + output_inner_dtype.as_ref(), + ))); + } + + // At this point: + // * All unit length list columns have a valid outer value. + + // The outer validity is just the validity of any non-broadcasting lists. + let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) { + // Both lists with same length, we combine the validity. + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => { + combine_validities_and(l.as_ref(), r.as_ref()) + }, + // Match all other combinations that have non-broadcasting lists. + ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive, + Broadcast::NoBroadcast | Broadcast::Right, + v, + _, ) - && validity_rhs.as_ref().map_or(false, |x| { - !x.get_bit(0) // is not valid - })) - { - return Ok(Either::Right(ListChunked::full_null_with_dtype( - output_name.clone(), + | ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList, + Broadcast::NoBroadcast | Broadcast::Left, + _, + v, + ) => v, + _ => None, + } + .unwrap_or_else(|| Bitmap::new_with_value(true, output_len)); + + Ok(Either::Left(Self { + op, + output_name, + op_apply_type, + broadcast, + output_dtype: output_dtype.clone(), + output_primitive_dtype, output_len, - output_inner_dtype.as_ref(), - ))); + outer_validity, + data_lhs, + data_rhs, + list_to_prim_lhs: None, + swapped: false, + })) } - // At this point: - // * All unit length list columns have a valid outer value. - - // The outer validity is just the validity of any non-broadcasting lists. - let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) { - // Both lists with same length, we combine the validity. - (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => { - combine_validities_and(l.as_ref(), r.as_ref()) - }, - // Match all other combinations that have non-broadcasting lists. - ( - BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive, - Broadcast::NoBroadcast | Broadcast::Right, - v, - _, - ) - | ( - BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList, - Broadcast::NoBroadcast | Broadcast::Left, - _, - v, - ) => v, - _ => None, - } - .unwrap_or_else(|| Bitmap::new_with_value(true, output_len)); - - Ok(Either::Left(Self { - op, - output_name, - op_apply_type, - broadcast, - output_dtype: output_dtype.clone(), - output_primitive_dtype, - output_len, - outer_validity, - data_lhs, - data_rhs, - list_to_prim_lhs: None, - swapped: false, - })) - } + pub(super) fn finish(mut self) -> PolarsResult { + // We have physical codepaths for a subset of the possible combinations of broadcasting and + // column types. The remaining combinations are handled by dispatching to the physical + // codepaths after operand swapping and/or materialized broadcasting. + // + // # Physical impl table + // Legend + // * | N | // impl "N" + // * | [N] | // dispatches to impl "N" + // + // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight + // ListToList | [1] | 0 | 1 | + // ListToPrimitive | [2] | 2 | 3 | // list broadcasting just materializes and dispatches to NoBroadcast + // PrimitiveToList | [3] | [2] | [2] | + + self.swapped = true; + + match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToList, Broadcast::Right) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { + // We materialize the list columns with `new_from_index`, as otherwise we'd have to + // implement logic that broadcasts the offsets and validities across multiple levels + // of nesting. But we will re-use the materialized memory to store the result. + + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_rhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::NoBroadcast; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToList, Broadcast::Left) => { + self.broadcast = Broadcast::Right; - fn finish(mut self) -> PolarsResult { - // We have physical codepaths for a subset of the possible combinations of broadcasting and - // column types. The remaining combinations are handled by dispatching to the physical - // codepaths after operand swapping and/or materialized broadcasting. - // - // # Physical impl table - // Legend - // * | N | // impl "N" - // * | [N] | // dispatches to impl "N" - // - // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight - // ListToList | [1] | 0 | 1 | - // ListToPrimitive | [2] | 2 | 3 | // list broadcasting just materializes and dispatches to NoBroadcast - // PrimitiveToList | [3] | [2] | [2] | - - self.swapped = true; - - match (&self.op_apply_type, &self.broadcast) { - (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) - | (BinaryOpApplyType::ListToList, Broadcast::Right) - | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) - | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { - self.swapped = false; - self._finish_impl_dispatch() - }, - (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { - // We materialize the list columns with `new_from_index`, as otherwise we'd have to - // implement logic that broadcasts the offsets and validities across multiple levels - // of nesting. But we will re-use the materialized memory to store the result. - - self.list_to_prim_lhs - .replace(Self::materialize_broadcasted_list( - &mut self.data_rhs, - self.output_len, - &self.output_primitive_dtype, - )); - - self.op_apply_type = BinaryOpApplyType::ListToPrimitive; - self.broadcast = Broadcast::NoBroadcast; - core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); - - self._finish_impl_dispatch() - }, - (BinaryOpApplyType::ListToList, Broadcast::Left) => { - self.broadcast = Broadcast::Right; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); - core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_lhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.broadcast = Broadcast::NoBroadcast; + + // This does not swap! We are just dispatching to `NoBroadcast` + // after materializing the broadcasted list array. + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Right; - self._finish_impl_dispatch() - }, - (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { - self.list_to_prim_lhs - .replace(Self::materialize_broadcasted_list( - &mut self.data_lhs, - self.output_len, - &self.output_primitive_dtype, - )); - - self.broadcast = Broadcast::NoBroadcast; - - // This does not swap! We are just dispatching to `NoBroadcast` - // after materializing the broadcasted list array. - self.swapped = false; - self._finish_impl_dispatch() - }, - (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { - self.op_apply_type = BinaryOpApplyType::ListToPrimitive; - self.broadcast = Broadcast::Right; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); - core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; - self._finish_impl_dispatch() - }, - (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { - self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); - core::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + } + } - self._finish_impl_dispatch() - }, + fn _finish_impl_dispatch(&mut self) -> PolarsResult { + let output_dtype = self.output_dtype.clone(); + let output_len = self.output_len; + + let prim_lhs = self + .data_lhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + let prim_rhs = self + .data_rhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + + debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); + let prim_dtype = prim_lhs.dtype(); + debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); + + // Safety: Leaf dtypes have been checked to be numeric by `try_new()` + let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { + self._finish_impl::<$T>(prim_lhs, prim_rhs) + })?; + + debug_assert_eq!(out.dtype(), &output_dtype); + assert_eq!(out.len(), output_len); + + Ok(out) } - } - fn _finish_impl_dispatch(&mut self) -> PolarsResult { - let output_dtype = self.output_dtype.clone(); - let output_len = self.output_len; - - let prim_lhs = self - .data_lhs - .2 - .get_leaf_array() - .cast(&self.output_primitive_dtype)? - .rechunk(); - let prim_rhs = self - .data_rhs - .2 - .get_leaf_array() - .cast(&self.output_primitive_dtype)? - .rechunk(); - - debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); - let prim_dtype = prim_lhs.dtype(); - debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); - - // Safety: Leaf dtypes have been checked to be numeric by `try_new()` - let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { - self._finish_impl::<$T>(prim_lhs, prim_rhs) - })?; - - debug_assert_eq!(out.dtype(), &output_dtype); - assert_eq!(out.len(), output_len); - - Ok(out) - } + /// Internal use only - contains physical impls. + fn _finish_impl( + &mut self, + prim_s_lhs: Series, + prim_s_rhs: Series, + ) -> PolarsResult + where + T::Native: PlNumArithmetic, + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + #[inline(never)] + fn check_mismatch_pos( + mismatch_pos: usize, + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + ) -> PolarsResult<()> { + if mismatch_pos < offsets_lhs.len_proxy() { + // RHS could be broadcasted + let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 { + 0 + } else { + mismatch_pos + }); + polars_bail!( + ShapeMismatch: + "list lengths differed at index {}: {} != {}", + mismatch_pos, + offsets_lhs.length_at(mismatch_pos), len_r + ) + } + Ok(()) + } - /// Internal use only - contains physical impls. - fn _finish_impl( - &mut self, - prim_s_lhs: Series, - prim_s_rhs: Series, - ) -> PolarsResult - where - T::Native: PlNumArithmetic, - PrimitiveArray: polars_compute::comparisons::TotalEqKernel, - T::Native: Zero + IsFloat, - { - #[inline(never)] - fn check_mismatch_pos( - mismatch_pos: usize, - offsets_lhs: &OffsetsBuffer, - offsets_rhs: &OffsetsBuffer, - ) -> PolarsResult<()> { - if mismatch_pos < offsets_lhs.len_proxy() { - // RHS could be broadcasted - let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 { - 0 - } else { - mismatch_pos - }); - polars_bail!( - ShapeMismatch: - "list lengths differed at index {}: {} != {}", - mismatch_pos, - offsets_lhs.length_at(mismatch_pos), len_r - ) + let mut arr_lhs = { + let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + let mut arr_rhs = { + let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + match (&self.op_apply_type, &self.broadcast) { + // We skip for this because it dispatches to `ArithmeticKernel`, which handles the + // validities for us. + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {}, + _ if self.list_to_prim_lhs.is_none() => { + self.op.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {}, + _ => unreachable!(), } - Ok(()) - } - let mut arr_lhs = { - let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); - assert_eq!(ca.chunks().len(), 1); - ca.downcast_get(0).unwrap().clone() - }; - - let mut arr_rhs = { - let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); - assert_eq!(ca.chunks().len(), 1); - ca.downcast_get(0).unwrap().clone() - }; - - match (&self.op_apply_type, &self.broadcast) { - // We skip for this because it dispatches to `ArithmeticKernel`, which handles the - // validities for us. - (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {}, - _ if self.list_to_prim_lhs.is_none() => self - .op - .prepare_numeric_op_side_validities::(&mut arr_lhs, &mut arr_rhs, self.swapped), - (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {}, - _ => unreachable!(), - } + // + // General notes + // * Lists can be: + // * Sliced, in which case the primitive/leaf array needs to be indexed starting from an + // offset instead of 0. + // * Masked, in which case the masked rows are permitted to have non-matching widths. + // - // - // General notes - // * Lists can be: - // * Sliced, in which case the primitive/leaf array needs to be indexed starting from an - // offset instead of 0. - // * Masked, in which case the masked rows are permitted to have non-matching widths. - // - - let out = match (&self.op_apply_type, &self.broadcast) { - (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => { - let offsets_lhs = &self.data_lhs.0[0]; - let offsets_rhs = &self.data_rhs.0[0]; - - assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy()); - - // Output primitive (and optional validity) are aligned to the LHS input. - let n_values = arr_lhs.len(); - let mut out_vec: Vec = Vec::with_capacity(n_values); - let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); - - // Counter that stops being incremented at the first row position with mismatching - // list lengths. - let mut mismatch_pos = 0; - - with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { - for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs - .offset_and_length_iter() - .zip(offsets_rhs.offset_and_length_iter()) - .enumerate() - { - if - (mismatch_pos == i) - & ( - (lhs_len == rhs_len) - | unsafe { !self.outer_validity.get_bit_unchecked(i) } - ) - { - mismatch_pos += 1; - } + let out = match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; - // Both sides are lists, we restrict the index to the min length to avoid - // OOB memory access. - let len: usize = lhs_len.min(rhs_len); + assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy()); - for i in 0..len { - let l_idx = i + lhs_start; - let r_idx = i + rhs_start; + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); - let l = unsafe { arr_lhs.value_unchecked(l_idx) }; - let r = unsafe { arr_rhs.value_unchecked(r_idx) }; - let v = $OP(l, r); + // Counter that stops being incremented at the first row position with mismatching + // list lengths. + let mut mismatch_pos = 0; - unsafe { out_ptr.add(l_idx).write(v) }; - } - } - }); - - check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; - - unsafe { out_vec.set_len(n_values) }; - - /// Reduce monomorphization - #[inline(never)] - fn combine_validities_list_to_list_no_broadcast( - offsets_lhs: &OffsetsBuffer, - offsets_rhs: &OffsetsBuffer, - validity_lhs: Option<&Bitmap>, - validity_rhs: Option<&Bitmap>, - len_lhs: usize, - ) -> Option { - match (validity_lhs, validity_rhs) { - (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), - (Some(v), None) => return Some(v.clone()), - (None, Some(v)) => { - Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) - }, - (None, None) => None, - } - .map(|(mut validity_out, validity_rhs)| { - for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs .offset_and_length_iter() .zip(offsets_rhs.offset_and_length_iter()) + .enumerate() { + if + (mismatch_pos == i) + & ( + (lhs_len == rhs_len) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + ) + { + mismatch_pos += 1; + } + + // Both sides are lists, we restrict the index to the min length to avoid + // OOB memory access. let len: usize = lhs_len.min(rhs_len); for i in 0..len { let l_idx = i + lhs_start; let r_idx = i + rhs_start; - let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; - let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; - let is_valid = l_valid & r_valid; + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); - // Size and alignment of validity vec are based on LHS. - unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + unsafe { out_ptr.add(l_idx).write(v) }; } } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + /// Reduce monomorphization + #[inline(never)] + fn combine_validities_list_to_list_no_broadcast( + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + { + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } - validity_out.freeze() - }) - } - - let leaf_validity = combine_validities_list_to_list_no_broadcast( - offsets_lhs, - offsets_rhs, - arr_lhs.validity(), - arr_rhs.validity(), - arr_lhs.len(), - ); - - let arr = - PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); - - let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); - assert_eq!(offsets.len(), 1); - - self.finish_offsets_and_validities(Box::new(arr), offsets, validities) - }, - (BinaryOpApplyType::ListToList, Broadcast::Right) => { - let offsets_lhs = &self.data_lhs.0[0]; - let offsets_rhs = &self.data_rhs.0[0]; + validity_out.freeze() + }) + } - // Output primitive (and optional validity) are aligned to the LHS input. - let n_values = arr_lhs.len(); - let mut out_vec: Vec = Vec::with_capacity(n_values); - let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + let leaf_validity = combine_validities_list_to_list_no_broadcast( + offsets_lhs, + offsets_rhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); - assert_eq!(offsets_rhs.len_proxy(), 1); - let rhs_start = *offsets_rhs.first() as usize; - let width = offsets_rhs.range() as usize; + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); - let mut mismatch_pos = 0; + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); - with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { - for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() { - if ((lhs_len == width) & (mismatch_pos == i)) - | unsafe { !self.outer_validity.get_bit_unchecked(i) } - { - mismatch_pos += 1; - } + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToList, Broadcast::Right) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; - let len: usize = lhs_len.min(width); + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); - for i in 0..len { - let l_idx = i + lhs_start; - let r_idx = i + rhs_start; + assert_eq!(offsets_rhs.len_proxy(), 1); + let rhs_start = *offsets_rhs.first() as usize; + let width = offsets_rhs.range() as usize; - let l = unsafe { arr_lhs.value_unchecked(l_idx) }; - let r = unsafe { arr_rhs.value_unchecked(r_idx) }; - let v = $OP(l, r); + let mut mismatch_pos = 0; - unsafe { - out_ptr.add(l_idx).write(v); + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() { + if ((lhs_len == width) & (mismatch_pos == i)) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + { + mismatch_pos += 1; } - } - } - }); - - check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; - - unsafe { out_vec.set_len(n_values) }; - - #[inline(never)] - fn combine_validities_list_to_list_broadcast_right( - offsets_lhs: &OffsetsBuffer, - validity_lhs: Option<&Bitmap>, - validity_rhs: Option<&Bitmap>, - len_lhs: usize, - width: usize, - rhs_start: usize, - ) -> Option { - match (validity_lhs, validity_rhs) { - (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), - (Some(v), None) => return Some(v.clone()), - (None, Some(v)) => { - Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) - }, - (None, None) => None, - } - .map(|(mut validity_out, validity_rhs)| { - for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() { + let len: usize = lhs_len.min(width); for i in 0..len { let l_idx = i + lhs_start; let r_idx = i + rhs_start; - let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; - let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; - let is_valid = l_valid & r_valid; + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); - // Size and alignment of validity vec are based on LHS. - unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + unsafe { + out_ptr.add(l_idx).write(v); + } } } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + #[inline(never)] + fn combine_validities_list_to_list_broadcast_right( + offsets_lhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + width: usize, + rhs_start: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() { + let len: usize = lhs_len.min(width); - validity_out.freeze() - }) - } - - let leaf_validity = combine_validities_list_to_list_broadcast_right( - offsets_lhs, - arr_lhs.validity(), - arr_rhs.validity(), - arr_lhs.len(), - width, - rhs_start, - ); + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; - let arr = - PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; - let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); - assert_eq!(offsets.len(), 1); + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } - self.finish_offsets_and_validities(Box::new(arr), offsets, validities) - }, - (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) - if self.list_to_prim_lhs.is_none() => - { - let offsets_lhs = self.data_lhs.0.as_slice(); + validity_out.freeze() + }) + } - // Notes - // * Primitive indexing starts from 0 - // * Output is aligned to LHS array + let leaf_validity = combine_validities_list_to_list_broadcast_right( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + width, + rhs_start, + ); - let n_values = arr_lhs.len(); - let mut out_vec = Vec::::with_capacity(n_values); - let out_ptr = out_vec.as_mut_ptr(); + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); - with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { - for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() - { - let r = unsafe { arr_rhs.value_unchecked(i) }; - for l_idx in l_range { - unsafe { - let l = arr_lhs.value_unchecked(l_idx); - let v = $OP(l, r); - out_ptr.add(l_idx).write(v); - } - } - } - }); + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); - unsafe { out_vec.set_len(n_values) } + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + if self.list_to_prim_lhs.is_none() => + { + let offsets_lhs = self.data_lhs.0.as_slice(); - let leaf_validity = combine_validities_list_to_primitive_no_broadcast( - offsets_lhs, - arr_lhs.validity(), - arr_rhs.validity(), - arr_lhs.len(), - ); + // Notes + // * Primitive indexing starts from 0 + // * Output is aligned to LHS array - let arr = - PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + let n_values = arr_lhs.len(); + let mut out_vec = Vec::::with_capacity(n_values); + let out_ptr = out_vec.as_mut_ptr(); - let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); - self.finish_offsets_and_validities(Box::new(arr), offsets, validities) - }, - // If we are dispatched here, it means that the LHS array is a unique allocation created - // after a unit-length list column was broadcasted, so this codepath mutably stores the - // results back into the LHS array to save memory. - (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { - let offsets_lhs = self.data_lhs.0.as_slice(); - - let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap(); - let arr = arr - .as_any_mut() - .downcast_mut::>() - .unwrap(); - let mut arr_lhs = core::mem::take(arr); - - self.op.prepare_numeric_op_side_validities::( - &mut arr_lhs, - &mut arr_rhs, - self.swapped, - ); - - let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap(); - assert_eq!(arr_lhs_mut_slice.len(), n_values); - - with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { - for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() - { - let r = unsafe { arr_rhs.value_unchecked(i) }; - for l_idx in l_range { - unsafe { - let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx); - *l = $OP(*l, r); + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs.value_unchecked(l_idx); + let v = $OP(l, r); + out_ptr.add(l_idx).write(v); + } } } - } - }); + }); - let leaf_validity = combine_validities_list_to_primitive_no_broadcast( - offsets_lhs, - arr_lhs.validity(), - arr_rhs.validity(), - arr_lhs.len(), - ); + unsafe { out_vec.set_len(n_values) } - let arr = arr_lhs.with_validity(leaf_validity); + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); - let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); - self.finish_offsets_and_validities(Box::new(arr), offsets, validities) - }, - (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { - assert_eq!(arr_rhs.len(), 1); + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); - let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { - // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); - return self.finish_offsets_and_validities( - Box::new( - arr_lhs - .clone() - .with_validity(Some(Bitmap::new_with_value(false, arr_lhs.len()))), - ), - offsets, - validities, + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + // If we are dispatched here, it means that the LHS array is a unique allocation created + // after a unit-length list column was broadcasted, so this codepath mutably stores the + // results back into the LHS array to save memory. + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + let offsets_lhs = self.data_lhs.0.as_slice(); + + let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap(); + let arr = arr + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let mut arr_lhs = core::mem::take(arr); + + self.op.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, ); - }; - let arr = self.op.apply_array_to_scalar::(arr_lhs, r, self.swapped); - let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap(); + assert_eq!(arr_lhs_mut_slice.len(), n_values); - self.finish_offsets_and_validities(Box::new(arr), offsets, validities) - }, - v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) - | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) - | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) - | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) - | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { - if cfg!(debug_assertions) { - panic!("operation was not re-written: {:?}", v) - } else { - unreachable!() - } - }, - }?; + with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx); + *l = $OP(*l, r); + } + } + } + }); - Ok(out) - } + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); - /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every - /// level. - fn finish_offsets_and_validities( - &mut self, - leaf_array: Box, - offsets: Vec>, - validities: Vec>, - ) -> PolarsResult { - assert!(!offsets.is_empty()); - assert_eq!(offsets.len(), validities.len()); - let mut results = leaf_array; - - let mut iter = offsets.into_iter().zip(validities).rev(); - - while iter.len() > 1 { - let (offsets, validity) = iter.next().unwrap(); - let dtype = LargeListArray::default_datatype(results.dtype().clone()); - results = Box::new(LargeListArray::new(dtype, offsets, results, validity)); + let arr = arr_lhs.with_validity(leaf_validity); + + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + assert_eq!(arr_rhs.len(), 1); + + let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { + // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + return self.finish_offsets_and_validities( + Box::new( + arr_lhs.clone().with_validity(Some(Bitmap::new_with_value( + false, + arr_lhs.len(), + ))), + ), + offsets, + validities, + ); + }; + + let arr = self.op.apply_array_to_scalar::(arr_lhs, r, self.swapped); + let (offsets, validities, _) = core::mem::take(&mut self.data_lhs); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) + | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) + | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + if cfg!(debug_assertions) { + panic!("operation was not re-written: {:?}", v) + } else { + unreachable!() + } + }, + }?; + + Ok(out) } - // The combined outer validity is pre-computed during `try_new()` - let (offsets, _) = iter.next().unwrap(); - let validity = core::mem::take(&mut self.outer_validity); - let dtype = LargeListArray::default_datatype(results.dtype().clone()); - let results = LargeListArray::new(dtype, offsets, results, Some(validity)); + /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every + /// level. + fn finish_offsets_and_validities( + &mut self, + leaf_array: Box, + offsets: Vec>, + validities: Vec>, + ) -> PolarsResult { + assert!(!offsets.is_empty()); + assert_eq!(offsets.len(), validities.len()); + let mut results = leaf_array; + + let mut iter = offsets.into_iter().zip(validities).rev(); + + while iter.len() > 1 { + let (offsets, validity) = iter.next().unwrap(); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + results = Box::new(LargeListArray::new(dtype, offsets, results, validity)); + } - Ok(ListChunked::with_chunk( - core::mem::take(&mut self.output_name), - results, - )) - } + // The combined outer validity is pre-computed during `try_new()` + let (offsets, _) = iter.next().unwrap(); + let validity = core::mem::take(&mut self.outer_validity); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + let results = LargeListArray::new(dtype, offsets, results, Some(validity)); - fn materialize_broadcasted_list( - side_data: &mut (Vec>, Vec>, Series), - output_len: usize, - output_primitive_dtype: &DataType, - ) -> (Box, usize) { - let s = &side_data.2; - assert_eq!(s.len(), 1); - - let expected_n_values = { - let offsets = s.list_offsets_and_validities_recursive().0; - output_len * OffsetsBuffer::::leaf_full_start_end(&offsets).len() - }; - - let ca = s.list().unwrap(); - // Remember to cast the leaf primitives to the supertype. - let ca = ca - .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone())) - .unwrap(); - assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data. - let ca = ca.new_from_index(0, output_len).rechunk(); - - let s = ca.into_series(); - - *side_data = { - let (a, b) = s.list_offsets_and_validities_recursive(); - // `Series::default()`: This field in the tuple is no longer used. - (a, b, Series::default()) - }; - - let n_values = OffsetsBuffer::::leaf_full_start_end(&side_data.0).len(); - assert_eq!(n_values, expected_n_values); - - let mut s = s.get_leaf_array(); - let v = unsafe { s.chunks_mut() }; - - assert_eq!(v.len(), 1); - (v.swap_remove(0), n_values) - } -} + Ok(ListChunked::with_chunk( + core::mem::take(&mut self.output_name), + results, + )) + } -/// Used in 2 places, so it's outside here. -#[inline(never)] -fn combine_validities_list_to_primitive_no_broadcast( - offsets_lhs: &[OffsetsBuffer], - validity_lhs: Option<&Bitmap>, - validity_rhs: Option<&Bitmap>, - len_lhs: usize, -) -> Option { - match (validity_lhs, validity_rhs) { - (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), - (Some(v), None) => return Some(v.clone()), - // Materialize a full-true validity to re-use the codepath, as we still - // need to spread the bits from the RHS to the correct positions. - (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)), - (None, None) => None, + fn materialize_broadcasted_list( + side_data: &mut (Vec>, Vec>, Series), + output_len: usize, + output_primitive_dtype: &DataType, + ) -> (Box, usize) { + let s = &side_data.2; + assert_eq!(s.len(), 1); + + let expected_n_values = { + let offsets = s.list_offsets_and_validities_recursive().0; + output_len * OffsetsBuffer::::leaf_full_start_end(&offsets).len() + }; + + let ca = s.list().unwrap(); + // Remember to cast the leaf primitives to the supertype. + let ca = ca + .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone())) + .unwrap(); + assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data. + let ca = ca.new_from_index(0, output_len).rechunk(); + + let s = ca.into_series(); + + *side_data = { + let (a, b) = s.list_offsets_and_validities_recursive(); + // `Series::default()`: This field in the tuple is no longer used. + (a, b, Series::default()) + }; + + let n_values = OffsetsBuffer::::leaf_full_start_end(&side_data.0).len(); + assert_eq!(n_values, expected_n_values); + + let mut s = s.get_leaf_array(); + let v = unsafe { s.chunks_mut() }; + + assert_eq!(v.len(), 1); + (v.swap_remove(0), n_values) + } } - .map(|(mut validity_out, validity_rhs)| { - for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() { - let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) }; - for l_idx in l_range { - let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; - let is_valid = l_valid & r_valid; - - // Size and alignment of validity vec are based on LHS. - unsafe { validity_out.set_unchecked(l_idx, is_valid) }; - } + + /// Used in 2 places, so it's outside here. + #[inline(never)] + fn combine_validities_list_to_primitive_no_broadcast( + offsets_lhs: &[OffsetsBuffer], + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + // Materialize a full-true validity to re-use the codepath, as we still + // need to spread the bits from the RHS to the correct positions. + (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)), + (None, None) => None, } + .map(|(mut validity_out, validity_rhs)| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() { + let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) }; + for l_idx in l_range { + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } - validity_out.freeze() - }) + validity_out.freeze() + }) + } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index b91df29a0a38..ace52993b8a1 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -358,11 +358,7 @@ impl SeriesTrait for SeriesWrap { Ok(Scalar::new(self.dtype().clone(), av)) } - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { Ok(Scalar::new(self.dtype().clone(), AnyValue::Null)) } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 30125ccc15b6..612505057eca 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -404,13 +404,9 @@ impl SeriesTrait for SeriesWrap { Ok(self.apply_scale(self.0.std_reduce(ddof))) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { self.0 - .quantile_reduce(quantile, interpol) + .quantile_reduce(quantile, method) .map(|v| self.apply_scale(v)) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 13b121aee0ca..803ca813aa1c 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -501,12 +501,8 @@ impl SeriesTrait for SeriesWrap { v.as_duration(self.0.time_unit()), )) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.0.quantile_reduce(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.0.quantile_reduce(quantile, method)?; let to = self.dtype().to_physical(); let v = v.value().cast(&to); Ok(Scalar::new( diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 24be56671d69..846e326d35b2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -365,9 +365,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] fn and_reduce(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9d8357a905bc..b2cb97e39b69 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -468,9 +468,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 805f06d86bac..d40e53d1a01e 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -232,7 +232,14 @@ impl SeriesTrait for SeriesWrap { } fn reverse(&self) -> Series { - self.0._apply_fields(|s| s.reverse()).unwrap().into_series() + let validity = self + .rechunk_validity() + .map(|x| x.into_iter().rev().collect::()); + self.0 + ._apply_fields(|s| s.reverse()) + .unwrap() + .with_outer_validity(validity) + .into_series() } fn shift(&self, periods: i64) -> Series { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 35f5b1258b19..fadd2b4f570f 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -23,6 +23,7 @@ use arrow::offset::Offsets; pub use from::*; pub use iterator::{SeriesIter, SeriesPhysIter}; use num_traits::NumCast; +use polars_utils::itertools::Itertools; pub use series_trait::{IsSorted, *}; use crate::chunked_array::cast::CastOptions; @@ -298,11 +299,6 @@ impl Series { Self::try_from((name, array)) } - #[cfg(feature = "arrow_rs")] - pub fn from_arrow_rs(name: PlSmallStr, array: &dyn arrow_array::Array) -> PolarsResult { - Self::from_arrow(name, array.into()) - } - /// Shrink the capacity of this array to fit its length. pub fn shrink_to_fit(&mut self) { self._get_inner_mut().shrink_to_fit() @@ -595,15 +591,17 @@ impl Series { lhs.zip_with_same_type(mask, rhs.as_ref()) } - /// Cast a datelike Series to their physical representation. - /// Primitives remain unchanged + /// Converts a Series to their physical representation, if they have one, + /// otherwise the series is left unchanged. /// /// * Date -> Int32 - /// * Datetime-> Int64 + /// * Datetime -> Int64 + /// * Duration -> Int64 /// * Time -> Int64 /// * Categorical -> UInt32 /// * List(inner) -> List(physical of inner) - /// + /// * Array(inner) -> Array(physical of inner) + /// * Struct -> Struct with physical repr of each struct column pub fn to_physical_repr(&self) -> Cow { use DataType::*; match self.dtype() { @@ -623,6 +621,11 @@ impl Series { Cow::Owned(ca.physical().clone().into_series()) }, List(inner) => Cow::Owned(self.cast(&List(Box::new(inner.to_physical()))).unwrap()), + #[cfg(feature = "dtype-array")] + Array(inner, size) => Cow::Owned( + self.cast(&Array(Box::new(inner.to_physical()), *size)) + .unwrap(), + ), #[cfg(feature = "dtype-struct")] Struct(_) => { let arr = self.struct_().unwrap(); @@ -644,6 +647,75 @@ impl Series { } } + /// Attempts to convert a Series to dtype, only allowing conversions from + /// physical to logical dtypes--the inverse of to_physical_repr(). + /// + /// # Safety + /// When converting from UInt32 to Categorical it is not checked that the + /// values are in-bound for the categorical mapping. + pub unsafe fn to_logical_repr_unchecked(&self, dtype: &DataType) -> PolarsResult { + use DataType::*; + + let err = || { + Err( + polars_err!(ComputeError: "can't cast from {} to {} in to_logical_repr_unchecked", self.dtype(), dtype), + ) + }; + + match dtype { + dt if self.dtype() == dt => Ok(self.clone()), + #[cfg(feature = "dtype-date")] + Date => Ok(self.i32()?.clone().into_date().into_series()), + #[cfg(feature = "dtype-datetime")] + Datetime(u, z) => Ok(self + .i64()? + .clone() + .into_datetime(*u, z.clone()) + .into_series()), + #[cfg(feature = "dtype-duration")] + Duration(u) => Ok(self.i64()?.clone().into_duration(*u).into_series()), + #[cfg(feature = "dtype-time")] + Time => Ok(self.i64()?.clone().into_time().into_series()), + #[cfg(feature = "dtype-categorical")] + Categorical { .. } | Enum { .. } => { + Ok(CategoricalChunked::from_cats_and_dtype_unchecked( + self.u32()?.clone(), + dtype.clone(), + ) + .into_series()) + }, + List(inner) => { + if let List(self_inner) = self.dtype() { + if inner.to_physical() == **self_inner { + return self.cast(dtype); + } + } + err() + }, + #[cfg(feature = "dtype-struct")] + Struct(target_fields) => { + let ca = self.struct_().unwrap(); + if ca.struct_fields().len() != target_fields.len() { + return err(); + } + let fields = ca + .fields_as_series() + .iter() + .zip(target_fields) + .map(|(s, tf)| s.to_logical_repr_unchecked(tf.dtype())) + .try_collect_vec()?; + let mut result = + StructChunked::from_series(self.name().clone(), ca.len(), fields.iter())?; + if ca.null_count() > 0 { + result.zip_outer_validity(ca); + } + Ok(result.into_series()) + }, + + _ => err(), + } + } + /// Take by index if ChunkedArray contains a single chunk. /// /// # Safety @@ -873,8 +945,7 @@ impl Series { DataType::Categorical(Some(rv), _) | DataType::Enum(Some(rv), _) => match &**rv { RevMapping::Local(arr, _) => size += estimated_bytes_size(arr), RevMapping::Global(map, arr, _) => { - size += - map.capacity() * std::mem::size_of::() * 2 + estimated_bytes_size(arr); + size += map.capacity() * size_of::() * 2 + estimated_bytes_size(arr); }, }, _ => {}, @@ -932,7 +1003,7 @@ fn equal_outer_type(dtype: &DataType) -> bool { } } -impl<'a, T> AsRef> for dyn SeriesTrait + 'a +impl AsRef> for dyn SeriesTrait + '_ where T: 'static + PolarsDataType, { @@ -949,7 +1020,7 @@ where } } -impl<'a, T> AsMut> for dyn SeriesTrait + 'a +impl AsMut> for dyn SeriesTrait + '_ where T: 'static + PolarsDataType, { diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index 642faafbfdf8..85998aa54de3 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -12,16 +12,11 @@ use crate::datatypes::{DataType, ListChunked}; use crate::prelude::{IntoSeries, Series, *}; fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series { - let mut ca = match s.dtype() { - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - ListChunked::with_chunk(name, array_to_unit_list(s.array_ref(0).clone())) - }, - _ => ListChunked::from_chunk_iter( - name, - s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())), - ), - }; + let mut ca = ListChunked::from_chunk_iter( + name, + s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())), + ); + ca.set_inner_dtype(s.dtype().clone()); ca.set_fast_explode(); ca.into_series() @@ -76,6 +71,16 @@ impl Series { (offsets, validities) } + /// For ListArrays, recursively normalizes the offsets to begin from 0, and + /// slices excess length from the values array. + pub fn list_rechunk_and_trim_to_normalized_offsets(&self) -> Self { + if let Some(ca) = self.try_list() { + ca.rechunk_and_trim_to_normalized_offsets().into_series() + } else { + self.rechunk() + } + } + /// Convert the values of this Series to a ListChunked with a length of 1, /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`. pub fn implode(&self) -> PolarsResult { @@ -111,7 +116,7 @@ impl Series { InvalidOperation: "at least one dimension must be specified" ); - let leaf_array = self.get_leaf_array(); + let leaf_array = self.get_leaf_array().rechunk(); let size = leaf_array.len(); let mut total_dim_size = 1; diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 1ee69300fa92..0352343baa82 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -519,11 +519,7 @@ pub trait SeriesTrait: polars_bail!(opq = std, self._dtype()); } /// Get the quantile of the ChunkedArray as a new Series of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { polars_bail!(opq = quantile, self._dtype()); } /// Get the bitwise AND of the Series as a new Series of length 1, @@ -617,7 +613,7 @@ pub trait SeriesTrait: } } -impl<'a> (dyn SeriesTrait + 'a) { +impl (dyn SeriesTrait + '_) { pub fn unpack(&self) -> PolarsResult<&ChunkedArray> where N: 'static + PolarsDataType, diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 169c30fd0498..05171fe35cfe 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -1159,7 +1159,7 @@ pub fn coalesce_nulls_columns(a: &Column, b: &Column) -> (Column, Column) { } pub fn operation_exceeded_idxsize_msg(operation: &str) -> String { - if core::mem::size_of::() == core::mem::size_of::() { + if size_of::() == size_of::() { format!( "{} exceeded the maximum supported limit of {} rows. Consider installing 'polars-u64-idx'.", operation, diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 027e85886793..18b8c9ddd00a 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -498,3 +498,56 @@ fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> { }, } } + +pub fn merge_dtypes_many + Clone, D: AsRef>( + into_iter: I, +) -> PolarsResult { + let mut iter = into_iter.clone().into_iter(); + + let mut st = iter + .next() + .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype")) + .map(|d| d.as_ref().clone())?; + + for d in iter { + st = try_get_supertype(d.as_ref(), &st)?; + } + + match st { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(_), ordering) => { + // This merges the global rev maps with linear complexity. + // If we do a binary reduce, it would be quadratic. + let mut iter = into_iter.into_iter(); + let first_dt = iter.next().unwrap(); + let first_dt = first_dt.as_ref(); + let DataType::Categorical(Some(rm), _) = first_dt else { + unreachable!() + }; + + let mut merger = GlobalRevMapMerger::new(rm.clone()); + + for d in iter { + if let DataType::Categorical(Some(rm), _) = d.as_ref() { + merger.merge_map(rm)? + } + } + let rev_map = merger.finish(); + + Ok(DataType::Categorical(Some(rev_map), ordering)) + }, + // This would be quadratic if we do this with the binary `merge_dtypes`. + DataType::List(inner) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + _ => Ok(st), + } +} diff --git a/crates/polars-expr/Cargo.toml b/crates/polars-expr/Cargo.toml index d53585b17d43..29aa34652146 100644 --- a/crates/polars-expr/Cargo.toml +++ b/crates/polars-expr/Cargo.toml @@ -12,6 +12,7 @@ description = "Physical expression implementation of the Polars project." ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } +hashbrown = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } polars-compute = { workspace = true } @@ -20,8 +21,10 @@ polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } polars-ops = { workspace = true, features = ["chunked_ids"] } polars-plan = { workspace = true } +polars-row = { workspace = true } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } +rand = { workspace = true } rayon = { workspace = true } [features] diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index af5383e83c83..f1cfa5251899 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -715,19 +715,19 @@ impl PartitionedAggregation for AggregationExpr { pub struct AggQuantileExpr { pub(crate) input: Arc, pub(crate) quantile: Arc, - pub(crate) interpol: QuantileInterpolOptions, + pub(crate) method: QuantileMethod, } impl AggQuantileExpr { pub fn new( input: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { Self { input, quantile, - interpol, + method, } } @@ -750,7 +750,7 @@ impl PhysicalExpr for AggQuantileExpr { let input = self.input.evaluate(df, state)?; let quantile = self.get_quantile(df, state)?; input - .quantile_reduce(quantile, self.interpol) + .quantile_reduce(quantile, self.method) .map(|sc| sc.into_series(input.name().clone())) } #[allow(clippy::ptr_arg)] @@ -771,7 +771,7 @@ impl PhysicalExpr for AggQuantileExpr { let mut agg = unsafe { ac.flat_naive() .into_owned() - .agg_quantile(ac.groups(), quantile, self.interpol) + .agg_quantile(ac.groups(), quantile, self.method) }; agg.rename(keep_name); Ok(AggregationContext::from_agg_state( diff --git a/crates/polars-expr/src/expressions/alias.rs b/crates/polars-expr/src/expressions/alias.rs index 6b38d8dc8270..8d321263a3f5 100644 --- a/crates/polars-expr/src/expressions/alias.rs +++ b/crates/polars-expr/src/expressions/alias.rs @@ -59,6 +59,10 @@ impl PhysicalExpr for AliasExpr { )) } + fn is_literal(&self) -> bool { + self.physical_expr.is_literal() + } + fn is_scalar(&self) -> bool { self.physical_expr.is_scalar() } diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 8c71c90cd152..53579b763033 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -525,14 +525,6 @@ fn apply_multiple_elementwise<'a>( impl StatsEvaluator for ApplyExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { let read = self.should_read_impl(stats)?; - if ExecutionState::new().verbose() { - if read { - eprintln!("parquet file must be read, statistics not sufficient for predicate.") - } else { - eprintln!("parquet file can be skipped, the statistics were sufficient to apply the predicate.") - } - } - Ok(read) } } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index c1cb286e7104..23f50af45273 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -351,7 +351,7 @@ mod stats { use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), - Operator::NotEq => apply_operator_stats_eq(min_max, literal), + Operator::NotEq => apply_operator_stats_neq(min_max, literal), Operator::Gt => { // Literal is bigger than max value, selection needs all rows. C::gt(literal, min_max).map(|ca| ca.any()).unwrap_or(false) @@ -454,10 +454,6 @@ mod stats { impl StatsEvaluator for BinaryExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { - if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { - return Ok(true); - } - use Operator::*; match ( self.left.as_stats_evaluator(), diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index e7d6e523a62d..c9b1c26b6a94 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -127,7 +127,7 @@ impl GatherExpr { let idx: IdxCa = match groups.as_ref() { GroupsProxy::Idx(groups) => { if groups.all().iter().zip(idx).any(|(g, idx)| match idx { - None => true, + None => false, Some(idx) => idx >= g.len() as IdxSize, }) { self.oob_err()?; @@ -148,7 +148,7 @@ impl GatherExpr { }, GroupsProxy::Slice { groups, .. } => { if groups.iter().zip(idx).any(|(g, idx)| match idx { - None => true, + None => false, Some(idx) => idx >= g[1], }) { self.oob_err()?; diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs index 6b1d54d0ac13..b42851e49d2a 100644 --- a/crates/polars-expr/src/expressions/group_iter.rs +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -4,7 +4,7 @@ use polars_core::series::amortized_iter::AmortSeries; use super::*; -impl<'a> AggregationContext<'a> { +impl AggregationContext<'_> { pub(super) fn iter_groups( &mut self, keep_names: bool, diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index c776e4b951dd..37600c71f06a 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -230,7 +230,7 @@ impl PhysicalExpr for TernaryExpr { // * `zip_with` can be called directly with the series // * mix of unit literals and AggregatedList // * `zip_with` can be called with the flat values after the offsets - // have been been checked for alignment + // have been checked for alignment let ac_target = non_literal_acs.first().unwrap(); let agg_state_out = match ac_target.agg_state() { diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index d03467d01da9..f843c0e83d95 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -754,7 +754,7 @@ where unsafe { values.set_len(len) } ChunkedArray::new_vec(ca.name().clone(), values).into_series() } else { - // We don't use a mutable bitmap as bits will have have race conditions! + // We don't use a mutable bitmap as bits will have race conditions! // A single byte might alias if we write from single threads. let mut validity: Vec = vec![false; len]; let validity_ptr = validity.as_mut_ptr(); diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs new file mode 100644 index 000000000000..43091244c661 --- /dev/null +++ b/crates/polars-expr/src/groups/mod.rs @@ -0,0 +1,67 @@ +use std::any::Any; +use std::path::Path; + +use polars_core::prelude::*; +use polars_utils::aliases::PlRandomState; +use polars_utils::IdxSize; + +mod row_encoded; + +/// A Grouper maps keys to groups, such that duplicate keys map to the same group. +pub trait Grouper: Any + Send { + /// Creates a new empty Grouper similar to this one. + fn new_empty(&self) -> Box; + + /// Returns the number of groups in this Grouper. + fn num_groups(&self) -> IdxSize; + + /// Inserts the given keys into this Grouper, mutating groups_idxs such + /// that group_idxs[i] is the group index of keys[..][i]. + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec); + + /// Adds the given Grouper into this one, mutating groups_idxs such that + /// the ith group of other now has group index group_idxs[i] in self. + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec); + + /// Partitions this Grouper into the given number of partitions. + /// + /// Updates partition_idxs such that the ith group of self moves to partition + /// partition_idxs[i]. + /// + /// It is guaranteed that two equal keys in two independent partition_into + /// calls map to the same partition index if the seed and the number of + /// partitions is equal. + fn partition( + &self, + seed: u64, + num_partitions: usize, + partition_idxs: &mut Vec, + ) -> Vec>; + + /// Returns the keys in this Grouper in group order, that is the key for + /// group i is returned in row i. + fn get_keys_in_group_order(&self) -> DataFrame; + + /// Returns the keys in this Grouper, mutating group_idxs such that the ith + /// key returned corresponds to group group_idxs[i]. + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame; + + /// Stores this Grouper at the given path. + fn store_ooc(&self, _path: &Path) { + unimplemented!(); + } + + /// Loads this Grouper from the given path. + fn load_ooc(&mut self, _path: &Path) { + unimplemented!(); + } + + fn as_any(&self) -> &dyn Any; +} + +pub fn new_hash_grouper(key_schema: Arc, random_state: PlRandomState) -> Box { + Box::new(row_encoded::RowEncodedHashGrouper::new( + key_schema, + random_state, + )) +} diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs new file mode 100644 index 000000000000..1a2fd5209436 --- /dev/null +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -0,0 +1,251 @@ +use std::mem::MaybeUninit; + +use hashbrown::hash_table::{Entry, HashTable}; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_unordered; +use polars_row::EncodingField; +use polars_utils::aliases::PlRandomState; +use polars_utils::hashing::{folded_multiply, hash_to_partition}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; +use rand::Rng; + +use super::*; + +struct Group { + key_hash: u64, + key_offset: usize, + key_length: u32, + group_idx: IdxSize, +} + +impl Group { + unsafe fn key<'k>(&self, key_data: &'k [u8]) -> &'k [u8] { + key_data.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) + } +} + +#[derive(Default)] +pub struct RowEncodedHashGrouper { + key_schema: Arc, + table: HashTable, + key_data: Vec, + + // Used for computing canonical hashes. + random_state: PlRandomState, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl RowEncodedHashGrouper { + pub fn new(key_schema: Arc, random_state: PlRandomState) -> Self { + Self { + key_schema, + random_state, + seed: rand::random::() | 1, + ..Default::default() + } + } + + fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let num_groups = self.table.len(); + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |g| unsafe { hash == g.key_hash && key == g.key(&self.key_data) }, + |g| g.key_hash.wrapping_mul(self.seed), + ); + + match entry { + Entry::Occupied(e) => e.get().group_idx, + Entry::Vacant(e) => { + let group_idx: IdxSize = num_groups.try_into().unwrap(); + let group = Group { + key_hash: hash, + key_offset: self.key_data.len(), + key_length: key.len().try_into().unwrap(), + group_idx, + }; + self.key_data.extend(key); + e.insert(group); + group_idx + }, + } + } + + /// Insert a key, without checking that it is unique. + fn insert_key_unique(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let group_idx = self.table.len().try_into().unwrap(); + let group = Group { + key_hash: hash, + key_offset: self.key_data.len(), + key_length: key.len().try_into().unwrap(), + group_idx, + }; + self.key_data.extend(key); + self.table + .insert_unique(hash.wrapping_mul(self.seed), group, |g| { + g.key_hash.wrapping_mul(self.seed) + }); + group_idx + } + + fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { + let key_dtypes = self + .key_schema + .iter() + .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let fields = vec![EncodingField::new_unsorted(); key_dtypes.len()]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &key_dtypes) }; + + let cols = self + .key_schema + .iter() + .zip(key_columns) + .map(|((name, dt), col)| { + let s = Series::try_from((name.clone(), col)).unwrap(); + unsafe { s.to_logical_repr_unchecked(dt) } + .unwrap() + .into_column() + }) + .collect(); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl Grouper for RowEncodedHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new( + self.key_schema.clone(), + self.random_state.clone(), + )) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec) { + let series = keys + .get_columns() + .iter() + .map(|c| c.as_materialized_series().clone()) + .collect_vec(); + let keys_encoded = _get_rows_encoded_unordered(&series[..]) + .unwrap() + .into_array(); + assert!(keys_encoded.len() == keys[0].len()); + + group_idxs.clear(); + group_idxs.reserve(keys_encoded.len()); + for key in keys_encoded.values_iter() { + let hash = self.random_state.hash_one(key); + unsafe { + group_idxs.push_unchecked(self.insert_key(hash, key)); + } + } + } + + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec) { + let other = other.as_any().downcast_ref::().unwrap(); + + // TODO: cardinality estimation. + self.table + .reserve(other.table.len(), |g| g.key_hash.wrapping_mul(self.seed)); + + unsafe { + group_idxs.clear(); + group_idxs.reserve(other.table.len()); + let idx_out = group_idxs.spare_capacity_mut(); + for group in other.table.iter() { + let group_key = group.key(&other.key_data); + let new_idx = self.insert_key(group.key_hash, group_key); + *idx_out.get_unchecked_mut(group.group_idx as usize) = MaybeUninit::new(new_idx); + } + group_idxs.set_len(other.table.len()); + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.table.len()); + unsafe { + let out = key_rows.spare_capacity_mut(); + for group in &self.table { + *out.get_unchecked_mut(group.group_idx as usize) = + MaybeUninit::new(group.key(&self.key_data)); + } + key_rows.set_len(self.table.len()); + } + self.finalize_keys(key_rows) + } + + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame { + group_idxs.clear(); + group_idxs.reserve(self.table.len()); + self.finalize_keys( + self.table + .iter() + .map(|group| unsafe { + group_idxs.push(group.group_idx); + group.key(&self.key_data) + }) + .collect(), + ) + } + + fn partition( + &self, + seed: u64, + num_partitions: usize, + partition_idxs: &mut Vec, + ) -> Vec> { + assert!(num_partitions > 0); + + // Two-pass algorithm to prevent reallocations. + let mut partition_size = vec![(0, 0); num_partitions]; // (keys, bytes) + unsafe { + for group in self.table.iter() { + let ph = folded_multiply(group.key_hash, seed | 1); + let p_idx = hash_to_partition(ph, num_partitions); + let (p_keys, p_bytes) = partition_size.get_unchecked_mut(p_idx as usize); + *p_keys += 1; + *p_bytes += group.key_length as usize; + } + } + + let mut rng = rand::thread_rng(); + let mut partitions = partition_size + .into_iter() + .map(|(keys, bytes)| Self { + key_schema: self.key_schema.clone(), + table: HashTable::with_capacity(keys), + key_data: Vec::with_capacity(bytes), + random_state: self.random_state.clone(), + seed: rng.gen::() | 1, + }) + .collect_vec(); + + unsafe { + partition_idxs.clear(); + partition_idxs.reserve(self.table.len()); + let partition_idxs_out = partition_idxs.spare_capacity_mut(); + for group in self.table.iter() { + let ph = folded_multiply(group.key_hash, seed | 1); + let p_idx = hash_to_partition(ph, num_partitions); + let p = partitions.get_unchecked_mut(p_idx); + p.insert_key_unique(group.key_hash, group.key(&self.key_data)); + *partition_idxs_out.get_unchecked_mut(group.group_idx as usize) = + MaybeUninit::new(p_idx as IdxSize); + } + partition_idxs.set_len(self.table.len()); + } + + partitions.into_iter().map(|p| Box::new(p) as _).collect() + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 9981e47f1451..2778f4621dc2 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -1,4 +1,5 @@ mod expressions; +pub mod groups; pub mod planner; pub mod prelude; pub mod reduce; diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 620f8bf87089..c4006de0c8ec 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -402,7 +402,9 @@ fn create_physical_expr_inner( }, _ => { if let IRAggExpr::Quantile { - quantile, interpol, .. + quantile, + method: interpol, + .. } = agg { let quantile = diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs index 3573192ae16f..55a4b325bda1 100644 --- a/crates/polars-expr/src/reduce/convert.rs +++ b/crates/polars-expr/src/reduce/convert.rs @@ -7,6 +7,7 @@ use crate::reduce::len::LenReduce; use crate::reduce::mean::new_mean_reduction; use crate::reduce::min_max::{new_max_reduction, new_min_reduction}; use crate::reduce::sum::new_sum_reduction; +use crate::reduce::var_std::new_var_std_reduction; /// Converts a node into a reduction + its associated selector expression. pub fn into_reduction( @@ -17,7 +18,8 @@ pub fn into_reduction( let get_dt = |node| { expr_arena .get(node) - .to_dtype(schema, Context::Default, expr_arena) + .to_dtype(schema, Context::Default, expr_arena)? + .materialize_unknown() }; let out = match expr_arena.get(node) { AExpr::Agg(agg) => match agg { @@ -31,6 +33,12 @@ pub fn into_reduction( propagate_nans, input, } => (new_max_reduction(get_dt(*input)?, *propagate_nans), *input), + IRAggExpr::Var(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, false, *ddof), *input) + }, + IRAggExpr::Std(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, true, *ddof), *input) + }, _ => todo!(), }, AExpr::Len => { diff --git a/crates/polars-expr/src/reduce/len.rs b/crates/polars-expr/src/reduce/len.rs index db8aee647824..57641b1a02b6 100644 --- a/crates/polars-expr/src/reduce/len.rs +++ b/crates/polars-expr/src/reduce/len.rs @@ -1,6 +1,7 @@ use polars_core::error::constants::LENGTH_LIMIT_MSG; use super::*; +use crate::reduce::partition::partition_vec; #[derive(Default)] pub struct LenReduce { @@ -42,7 +43,7 @@ impl GroupedReduction for LenReduce { group_idxs: &[IdxSize], ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - assert!(self.groups.len() == other.groups.len()); + assert!(other.groups.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, v) in group_idxs.iter().zip(other.groups.iter()) { @@ -61,6 +62,17 @@ impl GroupedReduction for LenReduce { Ok(ca.into_series()) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition_vec(self.groups, partition_sizes, partition_idxs) + .into_iter() + .map(|groups| Box::new(Self { groups }) as _) + .collect() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs index 0caa2ccabcb8..4a8ec962f237 100644 --- a/crates/polars-expr/src/reduce/mean.rs +++ b/crates/polars-expr/src/reduce/mean.rs @@ -9,14 +9,14 @@ pub fn new_mean_reduction(dtype: DataType) -> Box { use DataType::*; use VecGroupedReduction as VGR; match dtype { - Boolean => Box::new(VGR::::new(dtype)), + Boolean => Box::new(VGR::new(dtype, BoolMeanReducer)), _ if dtype.is_numeric() || dtype.is_temporal() => { with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { - Box::new(VGR::>::new(dtype)) + Box::new(VGR::new(dtype, NumMeanReducer::<$T>(PhantomData))) }) }, #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => Box::new(VGR::>::new(dtype)), + Decimal(_, _) => Box::new(VGR::new(dtype, NumMeanReducer::(PhantomData))), _ => unimplemented!(), } } @@ -67,6 +67,11 @@ fn finish_output(values: Vec<(f64, usize)>, dtype: &DataType) -> Series { } struct NumMeanReducer(PhantomData); +impl Clone for NumMeanReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} impl Reducer for NumMeanReducer where @@ -77,37 +82,43 @@ where type Value = (f64, usize); #[inline(always)] - fn init() -> Self::Value { + fn init(&self) -> Self::Value { (0.0, 0) } - fn cast_series(s: &Series) -> Cow<'_, Series> { + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { s.to_physical_repr() } #[inline(always)] - fn combine(a: &mut Self::Value, b: &Self::Value) { + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { a.0 += b.0; a.1 += b.1; } #[inline(always)] - fn reduce_one(a: &mut Self::Value, b: Option) { + fn reduce_one(&self, a: &mut Self::Value, b: Option) { a.0 += b.unwrap_or(T::Native::zero()).as_(); a.1 += b.is_some() as usize; } - fn reduce_ca(v: &mut Self::Value, ca: &ChunkedArray) { + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { v.0 += ChunkAgg::_sum_as_f64(ca); v.1 += ca.len() - ca.null_count(); } - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult { + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { assert!(m.is_none()); Ok(finish_output(v, dtype)) } } +#[derive(Clone)] struct BoolMeanReducer; impl Reducer for BoolMeanReducer { @@ -115,33 +126,38 @@ impl Reducer for BoolMeanReducer { type Value = (usize, usize); #[inline(always)] - fn init() -> Self::Value { + fn init(&self) -> Self::Value { (0, 0) } #[inline(always)] - fn combine(a: &mut Self::Value, b: &Self::Value) { + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { a.0 += b.0; a.1 += b.1; } #[inline(always)] - fn reduce_one(a: &mut Self::Value, b: Option) { + fn reduce_one(&self, a: &mut Self::Value, b: Option) { a.0 += b.unwrap_or(false) as usize; a.1 += b.is_some() as usize; } - fn reduce_ca(v: &mut Self::Value, ca: &ChunkedArray) { + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { v.0 += ca.sum().unwrap_or(0) as usize; v.1 += ca.len() - ca.null_count(); } - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult { + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { assert!(m.is_none()); assert!(dtype == &DataType::Boolean); let ca: Float64Chunked = v .into_iter() - .map(|(s, c)| s as f64 / c as f64) + .map(|(s, c)| (c != 0).then(|| s as f64 / c as f64)) .collect_ca(PlSmallStr::EMPTY); Ok(ca.into_series()) } diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs index f1ec0cbcc5d2..de25d3efc927 100644 --- a/crates/polars-expr/src/reduce/min_max.rs +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -11,6 +11,7 @@ use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use super::*; +use crate::reduce::partition::partition_mask; pub fn new_min_reduction(dtype: DataType, propagate_nans: bool) -> Box { use DataType::*; @@ -18,19 +19,23 @@ pub fn new_min_reduction(dtype: DataType, propagate_nans: bool) -> Box Box::new(BoolMinGroupedReduction::default()), #[cfg(feature = "propagate_nans")] - Float32 if propagate_nans => Box::new(VMGR::>::new(dtype)), + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, #[cfg(feature = "propagate_nans")] - Float64 if propagate_nans => Box::new(VMGR::>::new(dtype)), - Float32 => Box::new(VMGR::>::new(dtype)), - Float64 => Box::new(VMGR::>::new(dtype)), - String | Binary => Box::new(VecGroupedReduction::::new(dtype)), + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMinReducer)), _ if dtype.is_integer() || dtype.is_temporal() => { with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { - Box::new(VMGR::>::new(dtype)) + Box::new(VMGR::new(dtype, NumReducer::>::new())) }) }, #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => Box::new(VMGR::>::new(dtype)), + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), _ => unimplemented!(), } } @@ -41,34 +46,38 @@ pub fn new_max_reduction(dtype: DataType, propagate_nans: bool) -> Box Box::new(BoolMaxGroupedReduction::default()), #[cfg(feature = "propagate_nans")] - Float32 if propagate_nans => Box::new(VMGR::>::new(dtype)), + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, #[cfg(feature = "propagate_nans")] - Float64 if propagate_nans => Box::new(VMGR::>::new(dtype)), - Float32 => Box::new(VMGR::>::new(dtype)), - Float64 => Box::new(VMGR::>::new(dtype)), - String | Binary => Box::new(VecGroupedReduction::::new(dtype)), + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMaxReducer)), _ if dtype.is_integer() || dtype.is_temporal() => { with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { - Box::new(VMGR::>::new(dtype)) + Box::new(VMGR::new(dtype, NumReducer::>::new())) }) }, #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => Box::new(VMGR::>::new(dtype)), + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), _ => unimplemented!(), } } // These two variants ignore nans. -struct MinReducer(PhantomData); -struct MaxReducer(PhantomData); +struct Min(PhantomData); +struct Max(PhantomData); // These two variants propagate nans. #[cfg(feature = "propagate_nans")] -struct NanMinReducer(PhantomData); +struct NanMin(PhantomData); #[cfg(feature = "propagate_nans")] -struct NanMaxReducer(PhantomData); +struct NanMax(PhantomData); -impl NumericReducer for MinReducer +impl NumericReduction for Min where T: PolarsNumericType, ChunkedArray: ChunkAgg, @@ -95,7 +104,7 @@ where } } -impl NumericReducer for MaxReducer +impl NumericReduction for Max where T: PolarsNumericType, ChunkedArray: ChunkAgg, @@ -123,7 +132,7 @@ where } #[cfg(feature = "propagate_nans")] -impl NumericReducer for NanMinReducer { +impl NumericReduction for NanMin { type Dtype = T; #[inline(always)] @@ -143,7 +152,7 @@ impl NumericReducer for NanMinReducer { } #[cfg(feature = "propagate_nans")] -impl NumericReducer for NanMaxReducer { +impl NumericReduction for NanMax { type Dtype = T; #[inline(always)] @@ -162,27 +171,29 @@ impl NumericReducer for NanMaxReducer { } } +#[derive(Clone)] struct BinaryMinReducer; +#[derive(Clone)] struct BinaryMaxReducer; impl Reducer for BinaryMinReducer { type Dtype = BinaryType; type Value = Option>; // TODO: evaluate SmallVec. - fn init() -> Self::Value { + fn init(&self) -> Self::Value { None } #[inline(always)] - fn cast_series(s: &Series) -> Cow<'_, Series> { + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { Cow::Owned(s.cast(&DataType::Binary).unwrap()) } - fn combine(a: &mut Self::Value, b: &Self::Value) { - Self::reduce_one(a, b.as_deref()) + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref()) } - fn reduce_one(a: &mut Self::Value, b: Option<&[u8]>) { + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>) { match (a, b) { (_, None) => {}, (l @ None, Some(r)) => *l = Some(r.to_owned()), @@ -195,11 +206,16 @@ impl Reducer for BinaryMinReducer { } } - fn reduce_ca(v: &mut Self::Value, ca: &BinaryChunked) { - Self::reduce_one(v, ca.min_binary()) + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked) { + self.reduce_one(v, ca.min_binary()) } - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult { + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { assert!(m.is_none()); // This should only be used with VecGroupedReduction. let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); ca.into_series().cast(dtype) @@ -211,22 +227,22 @@ impl Reducer for BinaryMaxReducer { type Value = Option>; // TODO: evaluate SmallVec. #[inline(always)] - fn init() -> Self::Value { + fn init(&self) -> Self::Value { None } #[inline(always)] - fn cast_series(s: &Series) -> Cow<'_, Series> { + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { Cow::Owned(s.cast(&DataType::Binary).unwrap()) } #[inline(always)] - fn combine(a: &mut Self::Value, b: &Self::Value) { - Self::reduce_one(a, b.as_deref()) + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref()) } #[inline(always)] - fn reduce_one(a: &mut Self::Value, b: Option<&[u8]>) { + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>) { match (a, b) { (_, None) => {}, (l @ None, Some(r)) => *l = Some(r.to_owned()), @@ -240,12 +256,17 @@ impl Reducer for BinaryMaxReducer { } #[inline(always)] - fn reduce_ca(v: &mut Self::Value, ca: &BinaryChunked) { - Self::reduce_one(v, ca.max_binary()) + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked) { + self.reduce_one(v, ca.max_binary()) } #[inline(always)] - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult { + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { assert!(m.is_none()); // This should only be used with VecGroupedReduction. let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); ca.into_series().cast(dtype) @@ -324,6 +345,25 @@ impl GroupedReduction for BoolMinGroupedReduction { Ok(()) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + let p_values = partition_mask(&self.values.freeze(), partition_sizes, partition_idxs); + let p_mask = partition_mask(&self.mask.freeze(), partition_sizes, partition_idxs); + p_values + .into_iter() + .zip(p_mask) + .map(|(values, mask)| { + Box::new(Self { + values: values.into_mut(), + mask: mask.into_mut(), + }) as _ + }) + .collect() + } + fn finalize(&mut self) -> PolarsResult { let v = core::mem::take(&mut self.values); let m = core::mem::take(&mut self.mask); @@ -401,8 +441,7 @@ impl GroupedReduction for BoolMaxGroupedReduction { group_idxs: &[IdxSize], ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - assert!(self.values.len() == other.values.len()); - assert!(self.mask.len() == other.mask.len()); + assert!(other.values.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, (v, o)) in group_idxs @@ -431,6 +470,25 @@ impl GroupedReduction for BoolMaxGroupedReduction { }) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + let p_values = partition_mask(&self.values.freeze(), partition_sizes, partition_idxs); + let p_mask = partition_mask(&self.mask.freeze(), partition_sizes, partition_idxs); + p_values + .into_iter() + .zip(p_mask) + .map(|(values, mask)| { + Box::new(Self { + values: values.into_mut(), + mask: mask.into_mut(), + }) as _ + }) + .collect() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs index 8fc0620f27fe..bfe4cb56417b 100644 --- a/crates/polars-expr/src/reduce/mod.rs +++ b/crates/polars-expr/src/reduce/mod.rs @@ -2,15 +2,15 @@ mod convert; mod len; mod mean; mod min_max; -// #[cfg(feature = "propagate_nans")] -// mod nan_min_max; +mod partition; mod sum; +mod var_std; use std::any::Any; use std::borrow::Cow; use std::marker::PhantomData; -use arrow::array::PrimitiveArray; +use arrow::array::{Array, PrimitiveArray, StaticArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; pub use convert::into_reduction; use polars_core::prelude::*; @@ -50,6 +50,22 @@ pub trait GroupedReduction: Any + Send { group_idxs: &[IdxSize], ) -> PolarsResult<()>; + /// Partitions this GroupedReduction into several partitions. + /// + /// The ith group of this GroupedReduction should becomes the group_idxs[i] + /// group in partition partition_idxs[i]. + /// + /// # Safety + /// partitions_idxs[i] < partition_sizes.len() for all i. + /// group_idxs[i] < partition_sizes[partition_idxs[i]] for all i. + /// Each partition p has an associated set of group_idxs, this set contains + /// 0..partition_size[p] exactly once. + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec>; + /// Returns the finalized value per group as a Series. /// /// After this operation the number of groups is reset to 0. @@ -61,21 +77,30 @@ pub trait GroupedReduction: Any + Send { // Helper traits used in the VecGroupedReduction and VecMaskGroupedReduction to // reduce code duplication. -pub trait Reducer: Send + Sync + 'static { +pub trait Reducer: Send + Sync + Clone + 'static { type Dtype: PolarsDataType; type Value: Clone + Send + Sync + 'static; - fn init() -> Self::Value; + fn init(&self) -> Self::Value; #[inline(always)] - fn cast_series(s: &Series) -> Cow<'_, Series> { + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { Cow::Borrowed(s) } - fn combine(a: &mut Self::Value, b: &Self::Value); - fn reduce_one(a: &mut Self::Value, b: Option<::Physical<'_>>); - fn reduce_ca(v: &mut Self::Value, ca: &ChunkedArray); - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult; + fn combine(&self, a: &mut Self::Value, b: &Self::Value); + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + ); + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray); + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult; } -pub trait NumericReducer: Send + Sync + 'static { +pub trait NumericReduction: Send + Sync + 'static { type Dtype: PolarsNumericType; fn init() -> ::Native; fn combine( @@ -87,40 +112,61 @@ pub trait NumericReducer: Send + Sync + 'static { ) -> Option<::Native>; } -impl Reducer for T { - type Dtype = ::Dtype; - type Value = <::Dtype as PolarsNumericType>::Native; +struct NumReducer(PhantomData); +impl NumReducer { + fn new() -> Self { + Self(PhantomData) + } +} +impl Clone for NumReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumReducer { + type Dtype = ::Dtype; + type Value = <::Dtype as PolarsNumericType>::Native; #[inline(always)] - fn init() -> Self::Value { - ::init() + fn init(&self) -> Self::Value { + ::init() } #[inline(always)] - fn cast_series(s: &Series) -> Cow<'_, Series> { + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { s.to_physical_repr() } #[inline(always)] - fn combine(a: &mut Self::Value, b: &Self::Value) { - *a = ::combine(*a, *b); + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + *a = ::combine(*a, *b); } #[inline(always)] - fn reduce_one(a: &mut Self::Value, b: Option<::Physical<'_>>) { + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + ) { if let Some(b) = b { - *a = ::combine(*a, b); + *a = ::combine(*a, b); } } #[inline(always)] - fn reduce_ca(v: &mut Self::Value, ca: &ChunkedArray) { - if let Some(r) = ::reduce_ca(ca) { - *v = ::combine(*v, r); + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + if let Some(r) = ::reduce_ca(ca) { + *v = ::combine(*v, r); } } - fn finish(v: Vec, m: Option, dtype: &DataType) -> PolarsResult { + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { let arr = Box::new(PrimitiveArray::::from_vec(v).with_validity(m)); Ok(unsafe { Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], dtype) }) } @@ -129,15 +175,15 @@ impl Reducer for T { pub struct VecGroupedReduction { values: Vec, in_dtype: DataType, - reducer: PhantomData, + reducer: R, } impl VecGroupedReduction { - fn new(in_dtype: DataType) -> Self { + fn new(in_dtype: DataType, reducer: R) -> Self { Self { values: Vec::new(), in_dtype, - reducer: PhantomData, + reducer, } } } @@ -150,21 +196,20 @@ where Box::new(Self { values: Vec::new(), in_dtype: self.in_dtype.clone(), - reducer: PhantomData, + reducer: self.reducer.clone(), }) } fn resize(&mut self, num_groups: IdxSize) { - self.values.resize(num_groups as usize, R::init()); + self.values.resize(num_groups as usize, self.reducer.init()); } fn update_group(&mut self, values: &Series, group_idx: IdxSize) -> PolarsResult<()> { - // TODO: we should really implement a sum-as-other-type operation instead - // of doing this materialized cast. assert!(values.dtype() == &self.in_dtype); - let values = R::cast_series(values); + let values = self.reducer.cast_series(values); let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); - R::reduce_ca(&mut self.values[group_idx as usize], ca); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca); Ok(()) } @@ -173,17 +218,27 @@ where values: &Series, group_idxs: &[IdxSize], ) -> PolarsResult<()> { - // TODO: we should really implement a sum-as-other-type operation instead - // of doing this materialized cast. assert!(values.dtype() == &self.in_dtype); assert!(values.len() == group_idxs.len()); - let values = R::cast_series(values); + let values = self.reducer.cast_series(values); let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. - for (g, ov) in group_idxs.iter().zip(ca.iter()) { - let grp = self.values.get_unchecked_mut(*g as usize); - R::reduce_one(grp, ov); + if values.has_nulls() { + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, ov); + } + } else { + let mut offset = 0; + for arr in ca.downcast_iter() { + let subgroup = &group_idxs[offset..offset + arr.len()]; + for (g, v) in subgroup.iter().zip(arr.values_iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, Some(v)); + } + offset += arr.len(); + } } } Ok(()) @@ -196,20 +251,37 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.values.len() == other.values.len()); + assert!(group_idxs.len() == other.values.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, v) in group_idxs.iter().zip(other.values.iter()) { let grp = self.values.get_unchecked_mut(*g as usize); - R::combine(grp, v); + self.reducer.combine(grp, v); } } Ok(()) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec(self.values, partition_sizes, partition_idxs) + .into_iter() + .map(|values| { + Box::new(Self { + values, + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) as _ + }) + .collect() + } + fn finalize(&mut self) -> PolarsResult { let v = core::mem::take(&mut self.values); - R::finish(v, None, &self.in_dtype) + self.reducer.finish(v, None, &self.in_dtype) } fn as_any(&self) -> &dyn Any { @@ -221,16 +293,16 @@ pub struct VecMaskGroupedReduction { values: Vec, mask: MutableBitmap, in_dtype: DataType, - reducer: PhantomData, + reducer: R, } impl VecMaskGroupedReduction { - fn new(in_dtype: DataType) -> Self { + fn new(in_dtype: DataType, reducer: R) -> Self { Self { values: Vec::new(), mask: MutableBitmap::new(), in_dtype, - reducer: PhantomData, + reducer, } } } @@ -244,12 +316,12 @@ where values: Vec::new(), mask: MutableBitmap::new(), in_dtype: self.in_dtype.clone(), - reducer: PhantomData, + reducer: self.reducer.clone(), }) } fn resize(&mut self, num_groups: IdxSize) { - self.values.resize(num_groups as usize, R::init()); + self.values.resize(num_groups as usize, self.reducer.init()); self.mask.resize(num_groups as usize, false); } @@ -259,7 +331,8 @@ where assert!(values.dtype() == &self.in_dtype); let values = values.to_physical_repr(); let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); - R::reduce_ca(&mut self.values[group_idx as usize], ca); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca); if ca.len() != ca.null_count() { self.mask.set(group_idx as usize, true); } @@ -282,7 +355,7 @@ where for (g, ov) in group_idxs.iter().zip(ca.iter()) { if let Some(v) = ov { let grp = self.values.get_unchecked_mut(*g as usize); - R::reduce_one(grp, Some(v)); + self.reducer.reduce_one(grp, Some(v)); self.mask.set_unchecked(*g as usize, true); } } @@ -297,8 +370,7 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.values.len() == other.values.len()); - assert!(self.mask.len() == other.mask.len()); + assert!(group_idxs.len() == other.values.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, (v, o)) in group_idxs @@ -307,7 +379,7 @@ where { if o { let grp = self.values.get_unchecked_mut(*g as usize); - R::combine(grp, v); + self.reducer.combine(grp, v); self.mask.set_unchecked(*g as usize, true); } } @@ -315,10 +387,33 @@ where Ok(()) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec_mask( + self.values, + &self.mask.freeze(), + partition_sizes, + partition_idxs, + ) + .into_iter() + .map(|(values, mask)| { + Box::new(Self { + values, + mask: mask.into_mut(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) as _ + }) + .collect() + } + fn finalize(&mut self) -> PolarsResult { let v = core::mem::take(&mut self.values); let m = core::mem::take(&mut self.mask); - R::finish(v, Some(m.freeze()), &self.in_dtype) + self.reducer.finish(v, Some(m.freeze()), &self.in_dtype) } fn as_any(&self) -> &dyn Any { diff --git a/crates/polars-expr/src/reduce/partition.rs b/crates/polars-expr/src/reduce/partition.rs new file mode 100644 index 000000000000..0152035879bd --- /dev/null +++ b/crates/polars-expr/src/reduce/partition.rs @@ -0,0 +1,105 @@ +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; +use polars_utils::IdxSize; + +/// Partitions this Vec into multiple Vecs. +/// +/// # Safety +/// partitions_idxs[i] < partition_sizes.len() for all i. +/// idx_in_partition[i] < partition_sizes[partition_idxs[i]] for all i. +/// Each partition p has an associated set of idx_in_partition, this set +/// contains 0..partition_size[p] exactly once. +pub unsafe fn partition_vec( + v: Vec, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec> { + assert!(partition_idxs.len() == v.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| Vec::::with_capacity(*sz as usize)) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for (i, val) in v.into_iter().enumerate() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + debug_assert!(p_idx < partitions.len()); + let p = partitions.get_unchecked_mut(p_idx); + p.push_unchecked(val); + } + + for (p, sz) in partitions.iter_mut().zip(partition_sizes) { + p.set_len(*sz as usize); + } + } + + partitions +} + +/// # Safety +/// Same as partition_vec. +pub unsafe fn partition_mask( + m: &Bitmap, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec { + assert!(partition_idxs.len() == m.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| BitmapBuilder::with_capacity(*sz as usize)) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for i in 0..m.len() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + let p = partitions.get_unchecked_mut(p_idx); + p.push_unchecked(m.get_bit_unchecked(i)); + } + } + + partitions +} + +/// A fused loop of partition_vec and partition_mask. +/// # Safety +/// Same as partition_vec. +pub unsafe fn partition_vec_mask( + v: Vec, + m: &Bitmap, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], +) -> Vec<(Vec, BitmapBuilder)> { + assert!(partition_idxs.len() == v.len()); + assert!(m.len() == v.len()); + + let mut partitions = partition_sizes + .iter() + .map(|sz| { + ( + Vec::::with_capacity(*sz as usize), + BitmapBuilder::with_capacity(*sz as usize), + ) + }) + .collect_vec(); + + unsafe { + // Scatter into each partition. + for (i, val) in v.into_iter().enumerate() { + let p_idx = *partition_idxs.get_unchecked(i) as usize; + let (pv, pm) = partitions.get_unchecked_mut(p_idx); + pv.push_unchecked(val); + pm.push_unchecked(m.get_bit_unchecked(i)); + } + + for (p, sz) in partitions.iter_mut().zip(partition_sizes) { + p.0.set_len(*sz as usize); + } + } + + partitions +} diff --git a/crates/polars-expr/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs index 2b5d9d79c13f..466d5ffb9f9d 100644 --- a/crates/polars-expr/src/reduce/sum.rs +++ b/crates/polars-expr/src/reduce/sum.rs @@ -116,7 +116,7 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.sums.len() == other.sums.len()); + assert!(other.sums.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, v) in group_idxs.iter().zip(other.sums.iter()) { @@ -126,6 +126,22 @@ where Ok(()) } + unsafe fn partition( + self: Box, + partition_sizes: &[IdxSize], + partition_idxs: &[IdxSize], + ) -> Vec> { + partition::partition_vec(self.sums, partition_sizes, partition_idxs) + .into_iter() + .map(|sums| { + Box::new(Self { + sums, + in_dtype: self.in_dtype.clone(), + }) as _ + }) + .collect() + } + fn finalize(&mut self) -> PolarsResult { let v = core::mem::take(&mut self.sums); let arr = Box::new(PrimitiveArray::::from_vec(v)); diff --git a/crates/polars-expr/src/reduce/var_std.rs b/crates/polars-expr/src/reduce/var_std.rs new file mode 100644 index 000000000000..7993c06e6a4e --- /dev/null +++ b/crates/polars-expr/src/reduce/var_std.rs @@ -0,0 +1,168 @@ +use std::marker::PhantomData; + +use num_traits::AsPrimitive; +use polars_compute::var_cov::VarState; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub fn new_var_std_reduction(dtype: DataType, is_std: bool, ddof: u8) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolVarStdReducer { is_std, ddof })), + _ if dtype.is_numeric() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, VarStdReducer::<$T> { + is_std, + ddof, + needs_cast: false, + _phantom: PhantomData, + })) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new( + dtype, + VarStdReducer:: { + is_std, + ddof, + needs_cast: true, + _phantom: PhantomData, + }, + )), + Duration(..) => todo!(), + _ => unimplemented!(), + } +} + +struct VarStdReducer { + is_std: bool, + ddof: u8, + needs_cast: bool, + _phantom: PhantomData, +} + +impl Clone for VarStdReducer { + fn clone(&self) -> Self { + Self { + is_std: self.is_std, + ddof: self.ddof, + needs_cast: self.needs_cast, + _phantom: PhantomData, + } + } +} + +impl Reducer for VarStdReducer { + type Dtype = T; + type Value = VarState; + + fn init(&self) -> Self::Value { + VarState::default() + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + if self.needs_cast { + Cow::Owned(s.cast(&DataType::Float64).unwrap()) + } else { + Cow::Borrowed(s) + } + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.combine(b) + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + if let Some(x) = b { + a.add_one(x.as_()); + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + for arr in ca.downcast_iter() { + v.combine(&polars_compute::var_cov::var(arr)) + } + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|s| { + let var = s.finalize(self.ddof); + if self.is_std { + var.map(f64::sqrt) + } else { + var + } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} + +#[derive(Clone)] +struct BoolVarStdReducer { + is_std: bool, + ddof: u8, +} + +impl Reducer for BoolVarStdReducer { + type Dtype = BooleanType; + type Value = (usize, usize); + + fn init(&self) -> Self::Value { + (0, 0) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option) { + a.0 += b.unwrap_or(false) as usize; + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray) { + v.0 += ca.sum().unwrap_or(0) as usize; + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|v| { + if v.1 <= self.ddof as usize { + return None; + } + + let sum = v.0 as f64; // Both the sum and sum-of-squares, letting us simplify. + let n = v.1; + let var = sum * (1.0 - sum / n as f64) / ((n - self.ddof as usize) as f64); + if self.is_std { + Some(var.sqrt()) + } else { + Some(var) + } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} diff --git a/crates/polars-ffi/src/version_0.rs b/crates/polars-ffi/src/version_0.rs index 3cffd4425045..504f6cc126d1 100644 --- a/crates/polars-ffi/src/version_0.rs +++ b/crates/polars-ffi/src/version_0.rs @@ -132,7 +132,7 @@ impl CallerContext { self.bitflags |= 1 << k } - /// Parallelism is done by polars' main engine, the plugin should not run run its own parallelism. + /// Parallelism is done by polars' main engine, the plugin should not run its own parallelism. /// If this is `false`, the plugin could use parallelism without (much) contention with polars /// parallelism strategies. pub fn parallel(&self) -> bool { diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index c3ed9b93bd5c..2e0a51acc9e4 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -37,6 +37,7 @@ num-traits = { workspace = true } object_store = { workspace = true, optional = true } once_cell = { workspace = true } percent-encoding = { workspace = true } +pyo3 = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true } reqwest = { workspace = true, optional = true } @@ -129,7 +130,7 @@ gcp = ["object_store/gcp", "cloud"] http = ["object_store/http", "cloud"] temporal = ["dtype-datetime", "dtype-date", "dtype-time"] simd = [] -python = ["polars-error/python"] +python = ["pyo3", "polars-error/python", "polars-utils/python"] [package.metadata.docs.rs] all-features = true diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs new file mode 100644 index 000000000000..e6de837488c1 --- /dev/null +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -0,0 +1,738 @@ +use std::fmt::Debug; +use std::future::Future; +use std::hash::Hash; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +#[cfg(feature = "aws")] +pub use object_store::aws::AwsCredential; +#[cfg(feature = "azure")] +pub use object_store::azure::AzureCredential; +#[cfg(feature = "gcp")] +pub use object_store::gcp::GcpCredential; +use polars_core::config; +use polars_error::{polars_bail, PolarsResult}; +#[cfg(feature = "python")] +use polars_utils::python_function::PythonFunction; +#[cfg(feature = "python")] +use python_impl::PythonCredentialProvider; + +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum PlCredentialProvider { + /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly + Function(CredentialProviderFunction), + #[cfg(feature = "python")] + Python(python_impl::PythonCredentialProvider), +} + +impl PlCredentialProvider { + /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH) + /// + /// This functionality is unstable. + pub fn from_func( + // Internal notes + // * This function is exposed as the Rust API for `PlCredentialProvider` + func: impl Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync + + 'static, + ) -> Self { + Self::Function(CredentialProviderFunction(Arc::new(func))) + } + + #[cfg(feature = "python")] + pub fn from_python_func(func: PythonFunction) -> Self { + Self::Python(python_impl::PythonCredentialProvider(Arc::new(func))) + } + + #[cfg(feature = "python")] + pub fn from_python_func_object(func: pyo3::PyObject) -> Self { + Self::Python(python_impl::PythonCredentialProvider(Arc::new( + PythonFunction(func), + ))) + } + + pub(super) fn func_addr(&self) -> usize { + match self { + Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize, + #[cfg(feature = "python")] + Self::Python(PythonCredentialProvider(v)) => Arc::as_ptr(v) as *const () as usize, + } + } +} + +pub enum ObjectStoreCredential { + #[cfg(feature = "aws")] + Aws(Arc), + #[cfg(feature = "azure")] + Azure(Arc), + #[cfg(feature = "gcp")] + Gcp(Arc), + /// For testing purposes + None, +} + +impl ObjectStoreCredential { + fn variant_name(&self) -> &'static str { + match self { + #[cfg(feature = "aws")] + Self::Aws(_) => "Aws", + #[cfg(feature = "azure")] + Self::Azure(_) => "Azure", + #[cfg(feature = "gcp")] + Self::Gcp(_) => "Gcp", + Self::None => "None", + } + } + + fn panic_type_mismatch(&self, expected: &str) { + panic!( + "impl error: credential type mismatch: expected {}, got {} instead", + expected, + self.variant_name() + ) + } + + #[cfg(feature = "aws")] + fn unwrap_aws(self) -> Arc { + let Self::Aws(v) = self else { + self.panic_type_mismatch("aws"); + unreachable!() + }; + v + } + + #[cfg(feature = "azure")] + fn unwrap_azure(self) -> Arc { + let Self::Azure(v) = self else { + self.panic_type_mismatch("azure"); + unreachable!() + }; + v + } + + #[cfg(feature = "gcp")] + fn unwrap_gcp(self) -> Arc { + let Self::Gcp(v) = self else { + self.panic_type_mismatch("gcp"); + unreachable!() + }; + v + } +} + +pub trait IntoCredentialProvider: Sized { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + unimplemented!() + } +} + +impl IntoCredentialProvider for PlCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + match self { + Self::Function(v) => v.into_aws_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_aws_provider(), + } + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + match self { + Self::Function(v) => v.into_azure_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_azure_provider(), + } + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + match self { + Self::Function(v) => v.into_gcp_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_gcp_provider(), + } + } +} + +type CredentialProviderFunctionImpl = Arc< + dyn Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync, +>; + +/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc. +#[derive(Clone)] +pub struct CredentialProviderFunction(CredentialProviderFunctionImpl); + +macro_rules! build_to_object_store_err { + ($s:expr) => {{ + fn to_object_store_err( + e: impl std::error::Error + Send + Sync + 'static, + ) -> object_store::Error { + object_store::Error::Generic { + store: $s, + source: Box::new(e), + } + } + + to_object_store_err + }}; +} + +impl IntoCredentialProvider for CredentialProviderFunction { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::aws::AwsCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_aws(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-aws")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + })), + )) + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::azure::AzureCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_azure(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-azure")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))), + )) + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::gcp::GcpCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0 .0().await?; + PolarsResult::Ok((creds.unwrap_gcp(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-gcp")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(GcpCredential { + bearer: String::new(), + })), + )) + } +} + +impl Debug for CredentialProviderFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "credential provider function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for CredentialProviderFunction {} + +impl PartialEq for CredentialProviderFunction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Hash for CredentialProviderFunction { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for PlCredentialProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cfg(feature = "python")] + { + Ok(Self::Python(PythonCredentialProvider::deserialize( + deserializer, + )?)) + } + #[cfg(not(feature = "python"))] + { + use serde::de::Error; + Err(D::Error::custom("cannot deserialize PlCredentialProvider")) + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PlCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + #[cfg(feature = "python")] + if let PlCredentialProvider::Python(v) = self { + return v.serialize(serializer); + } + + Err(S::Error::custom(format!("cannot serialize {:?}", self))) + } +} + +/// Avoids calling the credential provider function if we have not yet passed the expiry time. +#[derive(Debug)] +struct FetchedCredentialsCache(tokio::sync::Mutex<(C, u64)>); + +impl FetchedCredentialsCache { + fn new(init_creds: C) -> Self { + Self(tokio::sync::Mutex::new((init_creds, 0))) + } + + async fn get_maybe_update( + &self, + // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from + // a `Fn() -> Pin>` by having it wrapped in an `async { f() }` block. We + // will not poll that block if the credentials have not yet expired. + update_func: impl Future>, + ) -> PolarsResult { + let verbose = config::verbose(); + + fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String { + if last_fetched_expiry == u64::MAX { + "expiry = (never expires)".into() + } else { + format!( + "expiry = {} (in {} seconds)", + last_fetched_expiry, + last_fetched_expiry.saturating_sub(now) + ) + } + } + + let mut inner = self.0.lock().await; + let (last_fetched_credentials, last_fetched_expiry) = &mut *inner; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Ensure the credential is valid for at least this many seconds to + // accommodate for latency. + const REQUEST_TIME_BUFFER: u64 = 7; + + if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER { + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Call update_func: current_time = {}\ + , last_fetched_expiry = {}", + current_time, *last_fetched_expiry + ) + } + let (credentials, expiry) = update_func.await?; + + *last_fetched_credentials = credentials; + *last_fetched_expiry = expiry; + + if expiry < current_time && expiry != 0 { + polars_bail!( + ComputeError: + "credential expiry time {} is older than system time {} \ + by {} seconds", + expiry, + current_time, + current_time - expiry + ) + } + + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Finish update_func: new {}", + expiry_msg( + *last_fetched_expiry, + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + ) + ) + } + } else if verbose { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + eprintln!( + "[FetchedCredentialsCache]: Using cached credentials: \ + current_time = {}, {}", + now, + expiry_msg(*last_fetched_expiry, now) + ) + } + + Ok(last_fetched_credentials.clone()) + } +} + +#[cfg(feature = "python")] +mod python_impl { + use std::hash::Hash; + use std::sync::Arc; + + use polars_error::PolarsError; + use polars_utils::python_function::PythonFunction; + use pyo3::exceptions::PyValueError; + use pyo3::pybacked::PyBackedStr; + use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; + use pyo3::Python; + + use super::IntoCredentialProvider; + + #[derive(Clone, Debug)] + pub struct PythonCredentialProvider(pub(super) Arc); + + impl From for PythonCredentialProvider { + fn from(value: PythonFunction) -> Self { + Self(Arc::new(value)) + } + } + + impl IntoCredentialProvider for PythonCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = object_store::aws::AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::>()?; + + match k.as_ref() { + "aws_access_key_id" => { + credentials.key_id = v.ok_or_else(|| { + PyValueError::new_err("aws_access_key_id was None") + })?; + }, + "aws_secret_access_key" => { + credentials.secret_key = v.ok_or_else(|| { + PyValueError::new_err("aws_secret_access_key was None") + })? + }, + "aws_session_token" => credentials.token = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for aws: {}, \ + valid configuration keys are: \ + {}, {}, {}", + v, + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token" + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.key_id.is_empty() { + return Err(PolarsError::ComputeError( + "aws_access_key_id was empty or not given".into(), + )); + } + + if credentials.secret_key.is_empty() { + return Err(PolarsError::ComputeError( + "aws_secret_access_key was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry)) + }) + })) + .into_aws_provider() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = + object_store::azure::AzureCredential::BearerToken(String::new()); + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + // We only support bearer for now + match k.as_ref() { + "bearer_token" => { + credentials = + object_store::azure::AzureCredential::BearerToken(v) + }, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for azure: {}, \ + valid configuration keys are: {}", + v, "bearer_token", + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + let object_store::azure::AzureCredential::BearerToken(bearer) = &credentials + else { + unreachable!() + }; + + if bearer.is_empty() { + return Err(PolarsError::ComputeError( + "bearer was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry)) + }) + })) + .into_azure_provider() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = object_store::gcp::GcpCredential { + bearer: String::new(), + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + match k.as_ref() { + "bearer_token" => credentials.bearer = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for gcp: {}, \ + valid configuration keys are: {}", + v, "bearer_token", + ))) + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.bearer.is_empty() { + return Err(PolarsError::ComputeError( + "bearer was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry)) + }) + })) + .into_gcp_provider() + } + } + + impl Eq for PythonCredentialProvider {} + + impl PartialEq for PythonCredentialProvider { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } + } + + impl Hash for PythonCredentialProvider { + fn hash(&self, state: &mut H) { + // # Safety + // * Inner is an `Arc` + // * Visibility is limited to super + // * No code in `mod python_impl` or `super` mutates the Arc inner. + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } + } + + #[cfg(feature = "serde")] + mod _serde_impl { + use polars_utils::python_function::PySerializeWrap; + + use super::PythonCredentialProvider; + + impl serde::Serialize for PythonCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + PySerializeWrap(self.0.as_ref()).serialize(serializer) + } + } + + impl<'a> serde::Deserialize<'a> for PythonCredentialProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + PySerializeWrap::::deserialize(deserializer) + .map(|x| x.0.into()) + } + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "serde")] + #[allow(clippy::redundant_pattern_matching)] + #[test] + fn test_serde() { + use super::*; + + assert!(matches!( + serde_json::to_string(&Some(PlCredentialProvider::from_func(|| { + Box::pin(core::future::ready(PolarsResult::Ok(( + ObjectStoreCredential::None, + 0, + )))) + }))), + Err(_) + )); + + assert!(matches!( + serde_json::to_string(&Option::::None), + Ok(String { .. }) + )); + + assert!(matches!( + serde_json::from_str::>( + serde_json::to_string(&Option::::None) + .unwrap() + .as_str() + ), + Ok(None) + )); + } +} diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs index b41f7d45cf21..7ae2d99444a7 100644 --- a/crates/polars-io/src/cloud/mod.rs +++ b/crates/polars-io/src/cloud/mod.rs @@ -19,3 +19,6 @@ pub use object_store_setup::*; pub use options::*; #[cfg(feature = "cloud")] pub use polars_object_store::*; + +#[cfg(feature = "cloud")] +pub mod credential_provider; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index b6464b109535..0bb33b333c18 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -3,12 +3,14 @@ use std::sync::Arc; use object_store::local::LocalFileSystem; use object_store::ObjectStore; use once_cell::sync::Lazy; +use polars_core::config; use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; use polars_utils::aliases::PlHashMap; use tokio::sync::RwLock; use url::Url; use super::{parse_url, CloudLocation, CloudOptions, CloudType}; +use crate::cloud::CloudConfig; /// Object stores must be cached. Every object-store will do DNS lookups and /// get rate limited when querying the DNS (can take up to 5s). @@ -28,10 +30,40 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { } /// Get the key of a url for object store registration. -/// The credential info will be removed fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> String { + #[derive(Clone, Debug, PartialEq, Hash, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + struct S { + max_retries: usize, + #[cfg(feature = "file_cache")] + file_cache_ttl: u64, + config: Option, + #[cfg(feature = "cloud")] + credential_provider: usize, + } + // We include credentials as they can expire, so users will send new credentials for the same url. - let creds = serde_json::to_string(&options).unwrap_or_else(|_| "".into()); + let creds = serde_json::to_string(&options.map( + |CloudOptions { + // Destructure to ensure this breaks if anything changes. + max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl, + config, + #[cfg(feature = "cloud")] + credential_provider, + }| { + S { + max_retries: *max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl: *file_cache_ttl, + config: config.clone(), + #[cfg(feature = "cloud")] + credential_provider: credential_provider.as_ref().map_or(0, |x| x.func_addr()), + } + }, + )) + .unwrap(); format!( "{}://{}<\\creds\\>{}", url.scheme(), @@ -58,6 +90,8 @@ pub async fn build_object_store( let parsed = parse_url(url).map_err(to_compute_err)?; let cloud_location = CloudLocation::from_url(&parsed, glob)?; + // FIXME: `credential_provider` is currently serializing the entire Python function here + // into a string with pickle for this cache key because we are using `serde_json::to_string` let key = url_and_creds_to_key(&parsed, options); let mut allow_cache = true; @@ -124,6 +158,12 @@ pub async fn build_object_store( let mut cache = OBJECT_STORE_CACHE.write().await; // Clear the cache if we surpass a certain amount of buckets. if cache.len() > 8 { + if config::verbose() { + eprintln!( + "build_object_store: clearing store cache (cache.len(): {})", + cache.len() + ); + } cache.clear() } cache.insert(key, store.clone()); diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index efaab673f634..9549b837f06d 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -36,6 +36,8 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "cloud")] use url::Url; +#[cfg(feature = "cloud")] +use super::credential_provider::PlCredentialProvider; #[cfg(feature = "file_cache")] use crate::file_cache::get_env_file_cache_ttl; #[cfg(feature = "aws")] @@ -75,6 +77,8 @@ pub struct CloudOptions { #[cfg(feature = "file_cache")] pub file_cache_ttl: u64, pub(crate) config: Option, + #[cfg(feature = "cloud")] + pub(crate) credential_provider: Option, } impl Default for CloudOptions { @@ -84,6 +88,8 @@ impl Default for CloudOptions { #[cfg(feature = "file_cache")] file_cache_ttl: get_env_file_cache_ttl(), config: None, + #[cfg(feature = "cloud")] + credential_provider: Default::default(), } } } @@ -248,6 +254,15 @@ impl CloudOptions { self } + #[cfg(feature = "cloud")] + pub fn with_credential_provider( + mut self, + credential_provider: Option, + ) -> Self { + self.credential_provider = credential_provider; + self + } + /// Set the configuration for AWS connections. This is the preferred API from rust. #[cfg(feature = "aws")] pub fn with_aws)>>( @@ -263,6 +278,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for AWS. #[cfg(feature = "aws")] pub async fn build_aws(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = AmazonS3Builder::from_env().with_url(url); if let Some(options) = &self.config { let CloudConfig::Aws(options) = options else { @@ -346,11 +363,17 @@ impl CloudOptions { }; }; - builder + let builder = builder .with_client_options(get_client_options()) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_aws_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for Azure connections. This is the preferred API from rust. @@ -368,6 +391,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for Azure. #[cfg(feature = "azure")] pub fn build_azure(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = MicrosoftAzureBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Azure(options) = options else { @@ -378,12 +403,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_azure_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for GCP connections. This is the preferred API from rust. @@ -401,6 +432,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for GCP. #[cfg(feature = "gcp")] pub fn build_gcp(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = GoogleCloudStorageBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Gcp(options) = options else { @@ -411,12 +444,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_gcp_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } #[cfg(feature = "http")] diff --git a/crates/polars-io/src/csv/read/parser.rs b/crates/polars-io/src/csv/read/parser.rs index 2b11013be951..9272dc6d65b8 100644 --- a/crates/polars-io/src/csv/read/parser.rs +++ b/crates/polars-io/src/csv/read/parser.rs @@ -835,7 +835,7 @@ pub(super) fn parse_lines( \n\ You might want to try:\n\ - increasing `infer_schema_length` (e.g. `infer_schema_length=10000`),\n\ - - specifying correct dtype with the `dtypes` argument\n\ + - specifying correct dtype with the `schema_overrides` argument\n\ - setting `ignore_errors` to `True`,\n\ - adding `{}` to the `null_values` list.\n\n\ Original error: ```{}```", diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index a8c304b8f65c..52f29ee0a128 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -126,7 +126,7 @@ pub(crate) struct CoreReader<'a> { truncate_ragged_lines: bool, } -impl<'a> fmt::Debug for CoreReader<'a> { +impl fmt::Debug for CoreReader<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Reader") .field("schema", &self.schema) @@ -191,7 +191,7 @@ impl<'a> CoreReader<'a> { if let Some(b) = decompress(&reader_bytes, total_n_rows, separator, quote_char, eol_char) { - reader_bytes = ReaderBytes::Owned(b); + reader_bytes = ReaderBytes::Owned(b.into()); } } @@ -467,16 +467,10 @@ impl<'a> CoreReader<'a> { continue; } - let b = unsafe { - bytes.get_unchecked_release(total_offset..total_offset + position) - }; - // The parsers will not create a null row if we end on a new line. - if b.last() == Some(&self.eol_char) { - chunk_size *= 2; - continue; - } + let end = total_offset + position + 1; + let b = unsafe { bytes.get_unchecked_release(total_offset..end) }; - total_offset += position + 1; + total_offset = end; (b, count) }; let check_utf8 = matches!(self.encoding, CsvEncoding::Utf8) diff --git a/crates/polars-io/src/csv/read/read_impl/batched.rs b/crates/polars-io/src/csv/read/read_impl/batched.rs index 3bf6e2dd4e32..90e0b4e4e37c 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched.rs @@ -66,7 +66,7 @@ struct ChunkOffsetIter<'a> { eol_char: u8, } -impl<'a> Iterator for ChunkOffsetIter<'a> { +impl Iterator for ChunkOffsetIter<'_> { type Item = (usize, usize); fn next(&mut self) -> Option { @@ -209,7 +209,7 @@ pub struct BatchedCsvReader<'a> { decimal_comma: bool, } -impl<'a> BatchedCsvReader<'a> { +impl BatchedCsvReader<'_> { pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { if n == 0 || self.remaining == 0 { return Ok(None); diff --git a/crates/polars-io/src/csv/read/schema_inference.rs b/crates/polars-io/src/csv/read/schema_inference.rs index 3684df1a2bac..a01ec5ddef3f 100644 --- a/crates/polars-io/src/csv/read/schema_inference.rs +++ b/crates/polars-io/src/csv/read/schema_inference.rs @@ -296,7 +296,7 @@ fn infer_file_schema_inner( buf.push(eol_char); return infer_file_schema_inner( - &ReaderBytes::Owned(buf), + &ReaderBytes::Owned(buf.into()), separator, max_read_rows, has_header, @@ -481,7 +481,7 @@ fn infer_file_schema_inner( rb.extend_from_slice(reader_bytes); rb.push(eol_char); return infer_file_schema_inner( - &ReaderBytes::Owned(rb), + &ReaderBytes::Owned(rb.into()), separator, max_read_rows, has_header, diff --git a/crates/polars-io/src/csv/write/write_impl/serializer.rs b/crates/polars-io/src/csv/write/write_impl/serializer.rs index e9a3e055278a..6a4f964d88b3 100644 --- a/crates/polars-io/src/csv/write/write_impl/serializer.rs +++ b/crates/polars-io/src/csv/write/write_impl/serializer.rs @@ -686,17 +686,11 @@ pub(super) fn serializer_for<'a>( }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, scale) => { - let array = array.as_any().downcast_ref().unwrap(); - match options.quote_style { - QuoteStyle::Never => Box::new(decimal_serializer(array, scale.unwrap_or(0))) - as Box, - _ => Box::new(quote_serializer(decimal_serializer( - array, - scale.unwrap_or(0), - ))), - } + quote_if_always!(decimal_serializer, scale.unwrap_or(0)) + }, + _ => { + polars_bail!(ComputeError: "datatype {dtype} cannot be written to CSV\n\nConsider using JSON or a binary format.") }, - _ => polars_bail!(ComputeError: "datatype {dtype} cannot be written to csv"), }; Ok(serializer) } diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index ab6805d18967..100d37b2c941 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -1,6 +1,6 @@ //! # (De)serializing Arrows IPC format. //! -//! Arrow IPC is a [binary format format](https://arrow.apache.org/docs/python/ipc.html). +//! Arrow IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). //! It is the recommended way to serialize and deserialize Polars DataFrames as this is most true //! to the data schema. //! diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index 6b16579ac93d..6393c639cf35 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -1,6 +1,6 @@ //! # (De)serializing Arrows Streaming IPC format. //! -//! Arrow Streaming IPC is a [binary format format](https://arrow.apache.org/docs/python/ipc.html). +//! Arrow Streaming IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). //! It used for sending an arbitrary length sequence of record batches. //! The format must be processed from start to end, and does not support random access. //! It is different than IPC, if you can't deserialize a file with `IpcReader::new`, it's probably an IPC Stream File. diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index 208b1ba8befa..df2ff8bb9c28 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -236,7 +236,7 @@ pub fn remove_bom(bytes: &[u8]) -> PolarsResult<&[u8]> { Ok(bytes) } } -impl<'a, R> SerReader for JsonReader<'a, R> +impl SerReader for JsonReader<'_, R> where R: MmapBytesReader, { @@ -287,6 +287,8 @@ where } } + let allow_extra_fields_in_struct = self.schema.is_some(); + // struct type let dtype = if let Some(mut schema) = self.schema { if let Some(overwrite) = self.schema_overwrite { @@ -338,7 +340,11 @@ where dtype }; - let arr = polars_json::json::deserialize(&json_value, dtype)?; + let arr = polars_json::json::deserialize( + &json_value, + dtype, + allow_extra_fields_in_struct, + )?; let arr = arr.as_any().downcast_ref::().ok_or_else( || polars_err!(ComputeError: "can only deserialize json objects"), )?; diff --git a/crates/polars-io/src/mmap.rs b/crates/polars-io/src/mmap.rs index df91f32942f9..2373257469e7 100644 --- a/crates/polars-io/src/mmap.rs +++ b/crates/polars-io/src/mmap.rs @@ -1,9 +1,8 @@ use std::fs::File; use std::io::{BufReader, Cursor, Read, Seek}; -use std::sync::Arc; use polars_core::config::verbose; -use polars_utils::mmap::{MMapSemaphore, MemSlice}; +use polars_utils::mmap::MemSlice; /// Trait used to get a hold to file handler or to the underlying bytes /// without performing a Read. @@ -67,8 +66,7 @@ impl MmapBytesReader for &mut T { // Handle various forms of input bytes pub enum ReaderBytes<'a> { Borrowed(&'a [u8]), - Owned(Vec), - Mapped(MMapSemaphore, &'a File), + Owned(MemSlice), } impl std::ops::Deref for ReaderBytes<'_> { @@ -77,19 +75,21 @@ impl std::ops::Deref for ReaderBytes<'_> { match self { Self::Borrowed(ref_bytes) => ref_bytes, Self::Owned(vec) => vec, - Self::Mapped(mmap, _) => mmap.as_ref(), } } } -/// Require 'static to force the caller to do any transmute as it's usually much -/// clearer to see there whether it's sound. +/// There are some places that perform manual lifetime management after transmuting `ReaderBytes` +/// to have a `'static` inner lifetime. The advantage to doing this is that it lets you construct a +/// `MemSlice` from the `ReaderBytes` in a zero-copy manner regardless of the underlying enum +/// variant. impl ReaderBytes<'static> { - pub fn into_mem_slice(self) -> MemSlice { + /// Construct a `MemSlice` in a zero-copy manner from the underlying bytes, with the assumption + /// that the underlying bytes have a `'static` lifetime. + pub fn to_memslice(&self) -> MemSlice { match self { ReaderBytes::Borrowed(v) => MemSlice::from_static(v), - ReaderBytes::Owned(v) => MemSlice::from_vec(v), - ReaderBytes::Mapped(v, _) => MemSlice::from_mmap(Arc::new(v)), + ReaderBytes::Owned(v) => v.clone(), } } } @@ -104,16 +104,14 @@ impl<'a, T: 'a + MmapBytesReader> From<&'a mut T> for ReaderBytes<'a> { }, None => { if let Some(f) = m.to_file() { - let f = unsafe { std::mem::transmute::<&File, &'a File>(f) }; - let mmap = MMapSemaphore::new_from_file(f).unwrap(); - ReaderBytes::Mapped(mmap, f) + ReaderBytes::Owned(MemSlice::from_file(f).unwrap()) } else { if verbose() { eprintln!("could not memory map file; read to buffer.") } let mut buf = vec![]; m.read_to_end(&mut buf).expect("could not read"); - ReaderBytes::Owned(buf) + ReaderBytes::Owned(MemSlice::from_vec(buf)) } }, } diff --git a/crates/polars-io/src/ndjson/buffer.rs b/crates/polars-io/src/ndjson/buffer.rs index 2bb2a028f1ca..1c9938979af5 100644 --- a/crates/polars-io/src/ndjson/buffer.rs +++ b/crates/polars-io/src/ndjson/buffer.rs @@ -12,9 +12,9 @@ use simd_json::{BorrowedValue as Value, KnownKey, StaticNode}; #[derive(Debug, Clone, PartialEq)] pub(crate) struct BufferKey<'a>(pub(crate) KnownKey<'a>); -impl<'a> Eq for BufferKey<'a> {} +impl Eq for BufferKey<'_> {} -impl<'a> Hash for BufferKey<'a> { +impl Hash for BufferKey<'_> { fn hash(&self, state: &mut H) { self.0.key().hash(state) } diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index a72b4ccf7038..2fabd7dcd589 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -133,7 +133,7 @@ where } } -impl<'a> JsonLineReader<'a, File> { +impl JsonLineReader<'_, File> { /// This is the recommended way to create a json reader as this allows for fastest parsing. pub fn from_path>(path: P) -> PolarsResult { let path = crate::resolve_homedir(&path.into()); @@ -141,7 +141,7 @@ impl<'a> JsonLineReader<'a, File> { Ok(Self::new(f).with_path(Some(path))) } } -impl<'a, R> SerReader for JsonLineReader<'a, R> +impl SerReader for JsonLineReader<'_, R> where R: MmapBytesReader, { diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs index 1fec749af5ce..cc0020cc7857 100644 --- a/crates/polars-io/src/parquet/read/mod.rs +++ b/crates/polars-io/src/parquet/read/mod.rs @@ -33,6 +33,7 @@ or set 'streaming'", pub use options::{ParallelStrategy, ParquetOptions}; use polars_error::{ErrString, PolarsError}; +pub use read_impl::{create_sorting_map, try_set_sorted_flag}; #[cfg(feature = "cloud")] pub use reader::ParquetAsyncReader; pub use reader::{BatchedParquetReader, ParquetReader}; diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs index eb8f7747f078..a3269341c1a3 100644 --- a/crates/polars-io/src/parquet/read/predicates.rs +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -1,3 +1,4 @@ +use polars_core::config; use polars_core::prelude::*; use polars_parquet::read::statistics::{deserialize, Statistics}; use polars_parquet::read::RowGroupMetadata; @@ -50,18 +51,38 @@ pub fn read_this_row_group( md: &RowGroupMetadata, schema: &ArrowSchema, ) -> PolarsResult { + if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { + return Ok(true); + } + + let mut should_read = true; + if let Some(pred) = predicate { if let Some(pred) = pred.as_stats_evaluator() { if let Some(stats) = collect_statistics(md, schema)? { - let should_read = pred.should_read(&stats); + let pred_result = pred.should_read(&stats); + // a parquet file may not have statistics of all columns - if matches!(should_read, Ok(false)) { - return Ok(false); - } else if !matches!(should_read, Err(PolarsError::ColumnNotFound(_))) { - let _ = should_read?; + match pred_result { + Err(PolarsError::ColumnNotFound(errstr)) => { + return Err(PolarsError::ColumnNotFound(errstr)) + }, + Ok(false) => should_read = false, + _ => {}, } } } + + if config::verbose() { + if should_read { + eprintln!( + "parquet row group must be read, statistics not sufficient for predicate." + ); + } else { + eprintln!("parquet row group can be skipped, the statistics were sufficient to apply the predicate."); + } + } } - Ok(true) + + Ok(should_read) } diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 0389a73b5081..de22b639bf8b 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -7,14 +7,14 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowSchemaRef; use polars_core::chunked_array::builder::NullChunkedBuilder; use polars_core::prelude::*; +use polars_core::series::IsSorted; use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -use polars_core::POOL; +use polars_core::{config, POOL}; use polars_parquet::parquet::error::ParquetResult; use polars_parquet::parquet::statistics::Statistics; use polars_parquet::read::{ self, ColumnChunkMetadata, FileMetadata, Filter, PhysicalType, RowGroupMetadata, }; -use polars_utils::mmap::MemSlice; use rayon::prelude::*; #[cfg(feature = "cloud")] @@ -43,6 +43,9 @@ fn assert_dtypes(dtype: &ArrowDataType) { // These should all be casted to the BinaryView / Utf8View variants D::Utf8 | D::Binary | D::LargeUtf8 | D::LargeBinary => unreachable!(), + // These should be casted to Float32 + D::Float16 => unreachable!(), + // This should have been converted to a LargeList D::List(_) => unreachable!(), @@ -60,6 +63,57 @@ fn assert_dtypes(dtype: &ArrowDataType) { } } +fn should_copy_sortedness(dtype: &DataType) -> bool { + // @NOTE: For now, we are a bit conservative with this. + use DataType as D; + + matches!( + dtype, + D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64 + ) +} + +pub fn try_set_sorted_flag( + series: &mut Series, + col_idx: usize, + sorting_map: &PlHashMap, +) { + if let Some(is_sorted) = sorting_map.get(&col_idx) { + if should_copy_sortedness(series.dtype()) { + if config::verbose() { + eprintln!( + "Parquet conserved SortingColumn for column chunk of '{}' to {is_sorted:?}", + series.name() + ); + } + + series.set_sorted_flag(*is_sorted); + } + } +} + +pub fn create_sorting_map(md: &RowGroupMetadata) -> PlHashMap { + let capacity = md.sorting_columns().map_or(0, |s| s.len()); + let mut sorting_map = PlHashMap::with_capacity(capacity); + + if let Some(sorting_columns) = md.sorting_columns() { + for sorting in sorting_columns { + let prev_value = sorting_map.insert( + sorting.column_idx as usize, + if sorting.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }, + ); + + debug_assert!(prev_value.is_none()); + } + } + + sorting_map +} + fn column_idx_to_series( column_i: usize, // The metadata belonging to this column @@ -68,6 +122,8 @@ fn column_idx_to_series( file_schema: &ArrowSchema, store: &mmap::ColumnStore, ) -> PolarsResult { + let did_filter = filter.is_some(); + let field = file_schema.get_at_index(column_i).unwrap().1; #[cfg(debug_assertions)] @@ -91,6 +147,11 @@ fn column_idx_to_series( _ => {}, } + // We cannot trust the statistics if we filtered the parquet already. + if did_filter { + return Ok(series); + } + // See if we can find some statistics for this series. If we cannot find anything just return // the series as is. let Ok(Some(stats)) = stats.map(|mut s| s.pop().flatten()) else { @@ -320,6 +381,8 @@ fn rg_to_dfs_prefiltered( } } + let sorting_map = create_sorting_map(md); + // Collect the data for the live columns let live_columns = (0..num_live_columns) .into_par_iter() @@ -338,8 +401,12 @@ fn rg_to_dfs_prefiltered( let part = iter.collect::>(); - column_idx_to_series(col_idx, part.as_slice(), None, schema, store) - .map(Column::from) + let mut series = + column_idx_to_series(col_idx, part.as_slice(), None, schema, store)?; + + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; @@ -445,7 +512,7 @@ fn rg_to_dfs_prefiltered( array.filter(&mask_arr) }; - let array = if mask_setting.should_prefilter( + let mut series = if mask_setting.should_prefilter( prefilter_cost, &schema.get_at_index(col_idx).unwrap().1.dtype, ) { @@ -454,9 +521,11 @@ fn rg_to_dfs_prefiltered( post()? }; - debug_assert_eq!(array.len(), filter_mask.set_bits()); + debug_assert_eq!(series.len(), filter_mask.set_bits()); - Ok(array.into_column()) + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; @@ -569,6 +638,8 @@ fn rg_to_dfs_optionally_par_over_columns( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = if let ParallelStrategy::Columns = parallel { POOL.install(|| { projection @@ -586,14 +657,17 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>() })? @@ -613,14 +687,17 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()? }; @@ -705,6 +782,8 @@ fn rg_to_dfs_par_over_rg( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = projection .iter() .map(|column_i| { @@ -720,14 +799,17 @@ fn rg_to_dfs_par_over_rg( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; @@ -825,10 +907,9 @@ pub fn read_parquet( } let reader = ReaderBytes::from(&mut reader); - let store = mmap::ColumnStore::Local( - unsafe { std::mem::transmute::, ReaderBytes<'static>>(reader) } - .into_mem_slice(), - ); + let store = mmap::ColumnStore::Local(unsafe { + std::mem::transmute::, ReaderBytes<'static>>(reader).to_memslice() + }); let dfs = rg_to_dfs( &store, @@ -876,9 +957,7 @@ impl FetchRowGroupsFromMmapReader { fn fetch_row_groups(&mut self, _row_groups: Range) -> PolarsResult { // @TODO: we can something smarter here with mmap - Ok(mmap::ColumnStore::Local(MemSlice::from_vec( - self.0.deref().to_vec(), - ))) + Ok(mmap::ColumnStore::Local(self.0.to_memslice())) } } diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 2a70ef2c5046..25d1f51b098b 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -89,9 +89,15 @@ impl ParquetReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema()?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema()?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -104,7 +110,7 @@ impl ParquetReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" @@ -328,9 +334,15 @@ impl ParquetAsyncReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema().await?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema().await?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -343,7 +355,7 @@ impl ParquetAsyncReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs index 1795cda6ebd0..71c59fecb31d 100644 --- a/crates/polars-io/src/path_utils/mod.rs +++ b/crates/polars-io/src/path_utils/mod.rs @@ -99,7 +99,7 @@ struct HiveIdxTracker<'a> { check_directory_level: bool, } -impl<'a> HiveIdxTracker<'a> { +impl HiveIdxTracker<'_> { fn update(&mut self, i: usize, path_idx: usize) -> PolarsResult<()> { let check_directory_level = self.check_directory_level; let paths = self.paths; diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs index 023d61fe525b..dda13bf1ea51 100644 --- a/crates/polars-io/src/utils/other.rs +++ b/crates/polars-io/src/utils/other.rs @@ -6,14 +6,14 @@ use once_cell::sync::Lazy; use polars_core::prelude::*; #[cfg(any(feature = "ipc_streaming", feature = "parquet"))] use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref}; -use polars_utils::mmap::MMapSemaphore; +use polars_utils::mmap::{MMapSemaphore, MemSlice}; use regex::{Regex, RegexBuilder}; use crate::mmap::{MmapBytesReader, ReaderBytes}; -pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( - reader: &'a mut R, -) -> PolarsResult> { +pub fn get_reader_bytes( + reader: &mut R, +) -> PolarsResult> { // we have a file so we can mmap // only seekable files are mmap-able if let Some((file, offset)) = reader @@ -23,14 +23,8 @@ pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( { let mut options = memmap::MmapOptions::new(); options.offset(offset); - - // somehow bck thinks borrows alias - // this is sound as file was already bound to 'a - use std::fs::File; - - let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; let mmap = MMapSemaphore::new_from_file_with_options(file, options)?; - Ok(ReaderBytes::Mapped(mmap, file)) + Ok(ReaderBytes::Owned(MemSlice::from_mmap(Arc::new(mmap)))) } else { // we can get the bytes for free if reader.to_bytes().is_some() { @@ -40,7 +34,7 @@ pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( // we have to read to an owned buffer to get the bytes. let mut bytes = Vec::with_capacity(1024 * 128); reader.read_to_end(&mut bytes)?; - Ok(ReaderBytes::Owned(bytes)) + Ok(ReaderBytes::Owned(bytes.into())) } } } @@ -79,20 +73,43 @@ pub(crate) fn columns_to_projection( Ok(prj) } +#[cfg(debug_assertions)] +fn check_offsets(dfs: &[DataFrame]) { + dfs.windows(2).for_each(|s| { + let a = &s[0].get_columns()[0]; + let b = &s[1].get_columns()[0]; + + let prev = a.get(a.len() - 1).unwrap().extract::().unwrap(); + let next = b.get(0).unwrap().extract::().unwrap(); + assert_eq!(prev + 1, next); + }) +} + /// Because of threading every row starts from `0` or from `offset`. /// We must correct that so that they are monotonically increasing. #[cfg(any(feature = "csv", feature = "json"))] pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { if !dfs.is_empty() { - let mut previous = dfs[0].height() as IdxSize + offset; - for df in &mut dfs[1..] { + let mut previous = offset; + for df in &mut *dfs { + if df.is_empty() { + continue; + } let n_read = df.height() as IdxSize; if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = &*s + previous; + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } } previous += n_read; } } + #[cfg(debug_assertions)] + { + check_offsets(dfs) + } } /// Because of threading every row starts from `0` or from `offset`. @@ -101,15 +118,21 @@ pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], offset: IdxSize) { assert_eq!(dfs.len(), heights.len()); if !dfs.is_empty() { - let mut previous = heights[0] + offset; - for i in 1..dfs.len() { + let mut previous = offset; + for i in 0..dfs.len() { let df = &mut dfs[i]; - let n_read = heights[i]; + if df.is_empty() { + continue; + } if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = &*s + previous; + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } } - + let n_read = heights[i]; previous += n_read; } } diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs index 600bd6dbea53..5e6977eff3e2 100644 --- a/crates/polars-json/src/json/deserialize.rs +++ b/crates/polars-json/src/json/deserialize.rs @@ -17,169 +17,245 @@ const JSON_NULL_VALUE: BorrowedValue = BorrowedValue::Static(StaticNode::Null); fn deserialize_boolean_into<'a, A: Borrow>>( target: &mut MutableBooleanArray, rows: &[A], -) { - let iter = rows.iter().map(|row| match row.borrow() { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::Bool(v)) => Some(v), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "boolean") } fn deserialize_primitive_into<'a, T: NativeType + NumCast, A: Borrow>>( target: &mut MutablePrimitiveArray, rows: &[A], -) { - let iter = rows.iter().map(|row| match row.borrow() { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::I64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::U64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::F64(v)) => T::from(*v), BorrowedValue::Static(StaticNode::Bool(v)) => T::from(*v as u8), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "numeric") } -fn deserialize_binary<'a, A: Borrow>>(rows: &[A]) -> BinaryArray { - let iter = rows.iter().map(|row| match row.borrow() { +fn deserialize_binary<'a, A: Borrow>>( + rows: &[A], +) -> PolarsResult> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::String(v) => Some(v.as_bytes()), - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); - BinaryArray::from_trusted_len_iter(iter) + let out = BinaryArray::from_trusted_len_iter(iter); + check_err_idx(rows, err_idx, "binary")?; + Ok(out) } fn deserialize_utf8_into<'a, O: Offset, A: Borrow>>( target: &mut MutableUtf8Array, rows: &[A], -) { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); let mut scratch = String::new(); - for row in rows { + for (i, row) in rows.iter().enumerate() { match row.borrow() { BorrowedValue::String(v) => target.push(Some(v.as_ref())), BorrowedValue::Static(StaticNode::Bool(v)) => { target.push(Some(if *v { "true" } else { "false" })) }, - BorrowedValue::Static(node) if !matches!(node, StaticNode::Null) => { + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { write!(scratch, "{node}").unwrap(); target.push(Some(scratch.as_str())); scratch.clear(); }, - _ => target.push_null(), + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, } } + check_err_idx(rows, err_idx, "string") } fn deserialize_utf8view_into<'a, A: Borrow>>( target: &mut MutableBinaryViewArray, rows: &[A], -) { +) -> PolarsResult<()> { + let mut err_idx = rows.len(); let mut scratch = String::new(); - for row in rows { + for (i, row) in rows.iter().enumerate() { match row.borrow() { BorrowedValue::String(v) => target.push_value(v.as_ref()), BorrowedValue::Static(StaticNode::Bool(v)) => { target.push_value(if *v { "true" } else { "false" }) }, - BorrowedValue::Static(node) if !matches!(node, StaticNode::Null) => { + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { write!(scratch, "{node}").unwrap(); target.push_value(scratch.as_str()); scratch.clear(); }, - _ => target.push_null(), + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, } } + check_err_idx(rows, err_idx, "string") } fn deserialize_list<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> ListArray { + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { + let mut err_idx = rows.len(); let child = ListArray::::get_child_type(&dtype); let mut validity = MutableBitmap::with_capacity(rows.len()); let mut offsets = Offsets::::with_capacity(rows.len()); let mut inner = vec![]; - rows.iter().for_each(|row| match row.borrow() { - BorrowedValue::Array(value) => { - inner.extend(value.iter()); - validity.push(true); - offsets - .try_push(value.len()) - .expect("List offset is too large :/"); - }, - BorrowedValue::Static(StaticNode::Null) => { - validity.push(false); - offsets.extend_constant(1) - }, - value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { - inner.push(value); - validity.push(true); - offsets.try_push(1).expect("List offset is too large :/"); - }, - _ => { - validity.push(false); - offsets.extend_constant(1); - }, - }); + rows.iter() + .enumerate() + .for_each(|(i, row)| match row.borrow() { + BorrowedValue::Array(value) => { + inner.extend(value.iter()); + validity.push(true); + offsets + .try_push(value.len()) + .expect("List offset is too large :/"); + }, + BorrowedValue::Static(StaticNode::Null) => { + validity.push(false); + offsets.extend_constant(1) + }, + value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { + inner.push(value); + validity.push(true); + offsets.try_push(1).expect("List offset is too large :/"); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + }); + + check_err_idx(rows, err_idx, "list")?; - let values = _deserialize(&inner, child.clone()); + let values = _deserialize(&inner, child.clone(), allow_extra_fields_in_struct)?; - ListArray::::new(dtype, offsets.into(), values, validity.into()) + Ok(ListArray::::new( + dtype, + offsets.into(), + values, + validity.into(), + )) } fn deserialize_struct<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> StructArray { + allow_extra_fields_in_struct: bool, +) -> PolarsResult { + let mut err_idx = rows.len(); let fields = StructArray::get_fields(&dtype); - let mut values = fields + let mut out_values = fields .iter() .map(|f| (f.name.as_str(), (f.dtype(), vec![]))) .collect::>(); let mut validity = MutableBitmap::with_capacity(rows.len()); + // Custom error tracker + let mut extra_field = None; - rows.iter().for_each(|row| { + rows.iter().enumerate().for_each(|(i, row)| { match row.borrow() { - BorrowedValue::Object(value) => { - values.iter_mut().for_each(|(s, (_, inner))| { - inner.push(value.get(*s).unwrap_or(&JSON_NULL_VALUE)) - }); + BorrowedValue::Object(values) => { + let mut n_matched = 0usize; + for (&key, &mut (_, ref mut inner)) in out_values.iter_mut() { + if let Some(v) = values.get(key) { + n_matched += 1; + inner.push(v) + } else { + inner.push(&JSON_NULL_VALUE) + } + } + validity.push(true); + + if n_matched < values.len() && extra_field.is_none() { + for k in values.keys() { + if !out_values.contains_key(k.as_ref()) { + extra_field = Some(k.as_ref()) + } + } + } }, - _ => { - values + BorrowedValue::Static(StaticNode::Null) => { + out_values .iter_mut() .for_each(|(_, (_, inner))| inner.push(&JSON_NULL_VALUE)); validity.push(false); }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, }; }); + if let Some(v) = extra_field { + if !allow_extra_fields_in_struct { + polars_bail!(ComputeError: "extra key in struct data: {}", v) + } + } + + check_err_idx(rows, err_idx, "struct")?; + // ensure we collect in the proper order let values = fields .iter() .map(|fld| { - let (dtype, vals) = values.get(fld.name.as_str()).unwrap(); - _deserialize(vals, (*dtype).clone()) + let (dtype, vals) = out_values.get(fld.name.as_str()).unwrap(); + _deserialize(vals, (*dtype).clone(), allow_extra_fields_in_struct) }) - .collect::>(); + .collect::>>()?; - StructArray::new(dtype.clone(), rows.len(), values, validity.into()) + Ok(StructArray::new( + dtype.clone(), + rows.len(), + values, + validity.into(), + )) } fn fill_array_from( - f: fn(&mut MutablePrimitiveArray, &[B]), + f: fn(&mut MutablePrimitiveArray, &[B]) -> PolarsResult<()>, dtype: ArrowDataType, rows: &[B], -) -> Box +) -> PolarsResult> where T: NativeType, A: From> + Array, { let mut array = MutablePrimitiveArray::::with_capacity(rows.len()).to(dtype); - f(&mut array, rows); - Box::new(A::from(array)) + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) } /// A trait describing an array with a backing store that can be preallocated to @@ -236,22 +312,34 @@ impl Container for MutableUtf8Array { } } -fn fill_generic_array_from(f: fn(&mut M, &[B]), rows: &[B]) -> Box +fn fill_generic_array_from( + f: fn(&mut M, &[B]) -> PolarsResult<()>, + rows: &[B], +) -> PolarsResult> where M: Container, A: From + Array, { let mut array = M::with_capacity(rows.len()); - f(&mut array, rows); - Box::new(A::from(array)) + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) } pub(crate) fn _deserialize<'a, A: Borrow>>( rows: &[A], dtype: ArrowDataType, -) -> Box { + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { match &dtype { - ArrowDataType::Null => Box::new(NullArray::new(dtype, rows.len())), + ArrowDataType::Null => { + if let Some(err_idx) = (0..rows.len()) + .find(|i| !matches!(rows[*i].borrow(), BorrowedValue::Static(StaticNode::Null))) + { + check_err_idx(rows, err_idx, "null")?; + } + + Ok(Box::new(NullArray::new(dtype, rows.len()))) + }, ArrowDataType::Boolean => { fill_generic_array_from::<_, _, BooleanArray>(deserialize_boolean_into, rows) }, @@ -277,7 +365,8 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) }, ArrowDataType::Timestamp(tu, tz) => { - let iter = rows.iter().map(|row| match row.borrow() { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { BorrowedValue::Static(StaticNode::I64(v)) => Some(*v), BorrowedValue::String(v) => match (tu, tz) { (_, None) => temporal_conversions::utf8_to_naive_timestamp_scalar(v, "%+", tu), @@ -286,9 +375,15 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( temporal_conversions::utf8_to_timestamp_scalar(v, "%+", &tz, tu) }, }, - _ => None, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, }); - Box::new(Int64Array::from_iter(iter).to(dtype)) + let out = Box::new(Int64Array::from_iter(iter).to(dtype)); + check_err_idx(rows, err_idx, "timestamp")?; + Ok(out) }, ArrowDataType::UInt8 => { fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) @@ -315,19 +410,51 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( ArrowDataType::Utf8View => { fill_generic_array_from::<_, _, Utf8ViewArray>(deserialize_utf8view_into, rows) }, - ArrowDataType::LargeList(_) => Box::new(deserialize_list(rows, dtype)), - ArrowDataType::LargeBinary => Box::new(deserialize_binary(rows)), - ArrowDataType::Struct(_) => Box::new(deserialize_struct(rows, dtype)), + ArrowDataType::LargeList(_) => Ok(Box::new(deserialize_list( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), + ArrowDataType::LargeBinary => Ok(Box::new(deserialize_binary(rows)?)), + ArrowDataType::Struct(_) => Ok(Box::new(deserialize_struct( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), _ => todo!(), } } -pub fn deserialize(json: &BorrowedValue, dtype: ArrowDataType) -> PolarsResult> { +pub fn deserialize( + json: &BorrowedValue, + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { match json { BorrowedValue::Array(rows) => match dtype { - ArrowDataType::LargeList(inner) => Ok(_deserialize(rows, inner.dtype)), + ArrowDataType::LargeList(inner) => { + _deserialize(rows, inner.dtype, allow_extra_fields_in_struct) + }, _ => todo!("read an Array from a non-Array data type"), }, - _ => Ok(_deserialize(&[json], dtype)), + _ => _deserialize(&[json], dtype, allow_extra_fields_in_struct), } } + +fn check_err_idx<'a>( + rows: &[impl Borrow>], + err_idx: usize, + type_name: &'static str, +) -> PolarsResult<()> { + if err_idx != rows.len() { + polars_bail!( + ComputeError: + r#"error deserializing value "{:?}" as {}. \ + Try increasing `infer_schema_length` or specifying a schema. + "#, + rows[err_idx].borrow(), type_name, + ) + } + + Ok(()) +} diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs index a23b245b68b2..6796ef7436bb 100644 --- a/crates/polars-json/src/json/write/mod.rs +++ b/crates/polars-json/src/json/write/mod.rs @@ -101,7 +101,7 @@ impl<'a> RecordSerializer<'a> { } } -impl<'a> FallibleStreamingIterator for RecordSerializer<'a> { +impl FallibleStreamingIterator for RecordSerializer<'_> { type Item = [u8]; type Error = PolarsError; diff --git a/crates/polars-json/src/ndjson/deserialize.rs b/crates/polars-json/src/ndjson/deserialize.rs index 94a482b7b275..4441691cf034 100644 --- a/crates/polars-json/src/ndjson/deserialize.rs +++ b/crates/polars-json/src/ndjson/deserialize.rs @@ -18,19 +18,28 @@ pub fn deserialize_iter<'a>( dtype: ArrowDataType, buf_size: usize, count: usize, + allow_extra_fields_in_struct: bool, ) -> PolarsResult { let mut arr: Vec> = Vec::new(); let mut buf = Vec::with_capacity(std::cmp::min(buf_size + count + 2, u32::MAX as usize)); buf.push(b'['); - fn _deserializer(s: &mut [u8], dtype: ArrowDataType) -> PolarsResult> { + fn _deserializer( + s: &mut [u8], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, + ) -> PolarsResult> { let out = simd_json::to_borrowed_value(s) .map_err(|e| PolarsError::ComputeError(format!("json parsing error: '{e}'").into()))?; - Ok(if let BorrowedValue::Array(rows) = out { - super::super::json::deserialize::_deserialize(&rows, dtype.clone()) + if let BorrowedValue::Array(rows) = out { + super::super::json::deserialize::_deserialize( + &rows, + dtype.clone(), + allow_extra_fields_in_struct, + ) } else { unreachable!() - }) + } } let mut row_iter = rows.peekable(); @@ -42,7 +51,11 @@ pub fn deserialize_iter<'a>( if buf.len() + next_row_length >= u32::MAX as usize { let _ = buf.pop(); buf.push(b']'); - arr.push(_deserializer(&mut buf, dtype.clone())?); + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); buf.clear(); buf.push(b'['); } @@ -53,9 +66,13 @@ pub fn deserialize_iter<'a>( buf.push(b']'); if arr.is_empty() { - _deserializer(&mut buf, dtype.clone()) + _deserializer(&mut buf, dtype.clone(), allow_extra_fields_in_struct) } else { - arr.push(_deserializer(&mut buf, dtype.clone())?); + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); concatenate_owned_unchecked(&arr) } } diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 2dfd642cde1f..78f8274fb079 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -71,7 +71,7 @@ temporal = [ ] # debugging purposes fmt = ["polars-core/fmt", "polars-plan/fmt"] -strings = ["polars-plan/strings"] +strings = ["polars-plan/strings", "polars-stream?/strings"] future = [] dtype-full = [ @@ -163,7 +163,7 @@ bitwise = [ "polars-plan/bitwise", "polars-expr/bitwise", "polars-core/bitwise", - "polars-stream/bitwise", + "polars-stream?/bitwise", "polars-ops/bitwise", ] approx_unique = ["polars-plan/approx_unique"] @@ -203,6 +203,7 @@ dynamic_group_by = [ "temporal", "polars-expr/dynamic_group_by", "polars-mem-engine/dynamic_group_by", + "polars-stream?/dynamic_group_by", ] ewma = ["polars-plan/ewma"] ewma_by = ["polars-plan/ewma_by"] @@ -257,7 +258,7 @@ replace = ["polars-plan/replace"] binary_encoding = ["polars-plan/binary_encoding"] string_encoding = ["polars-plan/string_encoding"] -bigidx = ["polars-plan/bigidx"] +bigidx = ["polars-plan/bigidx", "polars-utils/bigidx"] polars_cloud = ["polars-plan/polars_cloud"] panic_on_schema = ["polars-plan/panic_on_schema", "polars-expr/panic_on_schema"] diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index fd4334cad066..cf90c5232450 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -714,48 +714,12 @@ impl LazyFrame { pub fn collect(self) -> PolarsResult { #[cfg(feature = "new_streaming")] { - let auto_new_streaming = - std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); - if self.opt_state.contains(OptFlags::NEW_STREAMING) || auto_new_streaming { - // Try to run using the new streaming engine, falling back - // if it fails in a todo!() error if auto_new_streaming is set. - let mut new_stream_lazy = self.clone(); - new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; - new_stream_lazy.opt_state &= !OptFlags::STREAMING; - let mut alp_plan = new_stream_lazy.to_alp_optimized()?; - let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { - input: alp_plan.lp_top, - payload: SinkType::Memory, - }); - - let f = || { - polars_stream::run_query( - stream_lp_top, - alp_plan.lp_arena, - &mut alp_plan.expr_arena, - ) - }; - match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { - Ok(r) => return r, - Err(e) => { - // Fallback to normal engine if error is due to not being implemented - // and auto_new_streaming is set, otherwise propagate error. - if auto_new_streaming - && e.downcast_ref::<&str>() - .map(|s| s.starts_with("not yet implemented")) - .unwrap_or(false) - { - if polars_core::config::verbose() { - eprintln!("caught unimplemented error in new streaming engine, falling back to normal engine"); - } - } else { - std::panic::resume_unwind(e); - } - }, - } + let mut slf = self; + if let Some(df) = slf.try_new_streaming_if_requested(SinkType::Memory) { + return Ok(df?.unwrap()); } - let mut alp_plan = self.to_alp_optimized()?; + let mut alp_plan = slf.to_alp_optimized()?; let mut physical_plan = create_physical_plan( alp_plan.lp_top, &mut alp_plan.lp_arena, @@ -895,6 +859,54 @@ impl LazyFrame { ) } + #[cfg(feature = "new_streaming")] + pub fn try_new_streaming_if_requested( + &mut self, + payload: SinkType, + ) -> Option>> { + let auto_new_streaming = std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); + + if self.opt_state.contains(OptFlags::NEW_STREAMING) || auto_new_streaming { + // Try to run using the new streaming engine, falling back + // if it fails in a todo!() error if auto_new_streaming is set. + let mut new_stream_lazy = self.clone(); + new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; + new_stream_lazy.opt_state &= !OptFlags::STREAMING; + let mut alp_plan = match new_stream_lazy.to_alp_optimized() { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { + input: alp_plan.lp_top, + payload, + }); + + let f = || { + polars_stream::run_query(stream_lp_top, alp_plan.lp_arena, &mut alp_plan.expr_arena) + }; + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { + Ok(v) => return Some(v), + Err(e) => { + // Fallback to normal engine if error is due to not being implemented + // and auto_new_streaming is set, otherwise propagate error. + if auto_new_streaming + && e.downcast_ref::<&str>() + .map(|s| s.starts_with("not yet implemented")) + .unwrap_or(false) + { + if polars_core::config::verbose() { + eprintln!("caught unimplemented error in new streaming engine, falling back to normal engine"); + } + } else { + std::panic::resume_unwind(e); + } + }, + } + } + + None + } + #[cfg(any( feature = "ipc", feature = "parquet", @@ -903,11 +915,21 @@ impl LazyFrame { feature = "json", ))] fn sink(mut self, payload: SinkType, msg_alternative: &str) -> Result<(), PolarsError> { - self.opt_state |= OptFlags::STREAMING; + #[cfg(feature = "new_streaming")] + { + if self + .try_new_streaming_if_requested(payload.clone()) + .is_some() + { + return Ok(()); + } + } + self.logical_plan = DslPlan::Sink { input: Arc::new(self.logical_plan), payload, }; + self.opt_state |= OptFlags::STREAMING; let (mut state, mut physical_plan, is_streaming) = self.prepare_collect(true)?; polars_ensure!( is_streaming, @@ -1004,7 +1026,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1012,7 +1034,7 @@ impl LazyFrame { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` @@ -1327,7 +1349,7 @@ impl LazyFrame { right_on: E, args: JoinArgs, ) -> LazyFrame { - // if any of the nodes reads from files we must activate this this plan as well. + // if any of the nodes reads from files we must activate this plan as well. if other.opt_state.contains(OptFlags::FILE_CACHING) { self.opt_state |= OptFlags::FILE_CACHING; } @@ -1495,10 +1517,10 @@ impl LazyFrame { } /// Aggregate all the columns as their quantile values. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { self.map_private(DslFunction::Stats(StatsFunction::Quantile { quantile, - interpol, + method, })) } @@ -1885,7 +1907,7 @@ impl LazyGroupBy { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1893,7 +1915,7 @@ impl LazyGroupBy { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 3059384a1c8c..f3dff5710170 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -104,7 +104,7 @@ //! use polars_core::prelude::*; //! use polars_core::df; //! use polars_lazy::prelude::*; -//! use arrow::legacy::prelude::QuantileInterpolOptions; +//! use arrow::legacy::prelude::QuantileMethod; //! //! fn example() -> PolarsResult { //! let df = df!( @@ -118,7 +118,7 @@ //! .agg([ //! col("rain").min().alias("min_rain"), //! col("rain").sum().alias("sum_rain"), -//! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), +//! col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), //! ]) //! .sort(["date"], Default::default()) //! .collect() diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index e3c2b9e58120..08673ca1f032 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -24,22 +24,25 @@ pub(crate) fn prepare_expression_for_context( // create a dummy lazyframe and run a very simple optimization run so that // type coercion and simplify expression optimizations run. let column = Series::full_null(name, 0, dtype); - let mut lf = column - .into_frame() + let df = column.into_frame(); + let input_schema = Arc::new(df.schema()); + let lf = df .lazy() .without_optimizations() .with_simplify_expr(true) .select([expr.clone()]); - let schema = lf.collect_schema()?; let optimized = lf.optimize(&mut lp_arena, &mut expr_arena)?; let lp = lp_arena.get(optimized); - let aexpr = lp.get_exprs().pop().unwrap(); + let aexpr = lp + .get_exprs() + .pop() + .ok_or_else(|| polars_err!(ComputeError: "expected expressions in the context"))?; create_physical_expr( &aexpr, ctxt, &expr_arena, - &schema, + &input_schema, &mut ExpressionConversionState::new(true, 0), ) } diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 7100c083bd47..6c84af4510b5 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -163,13 +163,19 @@ pub(crate) fn insert_streaming_nodes( execution_id += 1; match lp_arena.get(root) { Filter { input, predicate } - if is_streamable(predicate.node(), expr_arena, Context::Default) => + if is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Default::default()), + ) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - HStack { input, exprs, .. } if all_streamable(exprs, expr_arena, Context::Default) => { + HStack { input, exprs, .. } + if all_streamable(exprs, expr_arena, Default::default()) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) @@ -194,7 +200,13 @@ pub(crate) fn insert_streaming_nodes( state.operators_sinks.push(PipelineNode::Sink(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - Select { input, expr, .. } if all_streamable(expr, expr_arena, Context::Default) => { + Select { input, expr, .. } + if all_streamable( + expr, + expr_arena, + IsStreamableContext::new(Default::default()), + ) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index 6ed8e1cc67c8..615b74e7738f 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -204,7 +204,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> { let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = c.optimize(&mut lp_arena, &mut expr_arena).unwrap(); - // Ensure we get only one cache and the it is not above the join + // Ensure we get only one cache and it is not above the join // and ensure that every cache only has 1 hit. let cache_ids = (&lp_arena) .iter(lp) @@ -218,7 +218,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> { .. } => { assert_eq!(*cache_hits, 1); - assert!(matches!(lp_arena.get(*input), IR::DataFrameScan { .. })); + assert!(matches!(lp_arena.get(*input), IR::SimpleProjection { .. })); Some(*id) }, diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs index 57bb99c0fad9..ad41378b3086 100644 --- a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -144,7 +144,7 @@ fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResul if keys.len() == 1 { // we sample as that will work also with sorted data. - // not that sampling without replacement is very very expensive. don't do that. + // not that sampling without replacement is *very* expensive. don't do that. let s = keys[0].sample_n(sample_size, true, false, None).unwrap(); // fast multi-threaded way to get unique. let groups = s.as_materialized_series().group_tuples(true, false)?; diff --git a/crates/polars-mem-engine/src/executors/scan/ndjson.rs b/crates/polars-mem-engine/src/executors/scan/ndjson.rs index 58862bd71f9e..1f90e07a72c1 100644 --- a/crates/polars-mem-engine/src/executors/scan/ndjson.rs +++ b/crates/polars-mem-engine/src/executors/scan/ndjson.rs @@ -1,5 +1,6 @@ use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; +use polars_io::prelude::{JsonLineReader, SerReader}; use polars_io::utils::compression::maybe_decompress_bytes; use super::*; diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 40f874d8afe4..ace3c7406883 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -210,7 +210,7 @@ fn create_physical_plan_impl( }, SinkType::File { file_type, .. } => { polars_bail!(InvalidOperation: - "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_parquet()'" + "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_{file_type:?}()'" ) }, #[cfg(feature = "cloud")] @@ -239,7 +239,11 @@ fn create_physical_plan_impl( Ok(Box::new(executors::SliceExec { input, offset, len })) }, Filter { input, predicate } => { - let mut streamable = is_streamable(predicate.node(), expr_arena, Context::Default); + let mut streamable = is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); if streamable { // This can cause problems with string caches @@ -374,9 +378,6 @@ fn create_physical_plan_impl( POOL.current_num_threads() > expr.len(), state.expr_depth, ); - - let streamable = - options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default); let phys_expr = create_physical_expressions_from_irs( &expr, Context::Default, @@ -384,6 +385,13 @@ fn create_physical_plan_impl( &input_schema, &mut state, )?; + + let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false)) + // If all columns are literal we would get a 1 row per thread. + && !phys_expr.iter().all(|p| { + p.is_literal() + }); + Ok(Box::new(executors::ProjectionExec { input, expr: phys_expr, @@ -627,8 +635,12 @@ fn create_physical_plan_impl( let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - let streamable = - options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default); + let streamable = options.should_broadcast + && all_streamable( + &exprs, + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let mut state = ExpressionConversionState::new( POOL.current_num_threads() > exprs.len(), diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 027d846b485e..63e52cffc1e4 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -34,8 +34,10 @@ rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true } +regex-syntax = { workspace = true } serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } unicode-reverse = { workspace = true, optional = true } [dependencies.jsonpath_lib] diff --git a/crates/polars-ops/src/chunked_array/cov.rs b/crates/polars-ops/src/chunked_array/cov.rs index 5a9b952097b8..dbfa6b48f4fb 100644 --- a/crates/polars-ops/src/chunked_array/cov.rs +++ b/crates/polars-ops/src/chunked_array/cov.rs @@ -1,196 +1,34 @@ -use num_traits::{ToPrimitive, Zero}; -use polars_compute::float_sum::FloatSum; +use num_traits::AsPrimitive; +use polars_compute::var_cov::{CovState, PearsonState}; use polars_core::prelude::*; use polars_core::utils::align_chunks_binary; -const COV_BUF_SIZE: usize = 64; - -/// Calculates the sum of x[i] * y[i] from 0..k. -fn multiply_sum(x: &[f64; COV_BUF_SIZE], y: &[f64; COV_BUF_SIZE], k: usize) -> f64 { - assert!(k <= COV_BUF_SIZE); - let tmp: [f64; COV_BUF_SIZE] = std::array::from_fn(|i| x[i] * y[i]); - FloatSum::sum(&tmp[..k]) -} - /// Compute the covariance between two columns. pub fn cov(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option where T: PolarsNumericType, - T::Native: ToPrimitive, -{ - if a.len() != b.len() { - None - } else { - let (a, b) = align_chunks_binary(a, b); - - let out = if a.null_count() > 0 || b.null_count() > 0 { - let iters = a.downcast_iter().zip(b.downcast_iter()).map(|(a, b)| { - a.into_iter().zip(b).filter_map(|(a, b)| match (a, b) { - (Some(a), Some(b)) => Some((*a, *b)), - _ => None, - }) - }); - online_cov(iters, ddof) - } else { - let iters = a - .downcast_iter() - .zip(b.downcast_iter()) - .map(|(a, b)| a.values_iter().copied().zip(b.values_iter().copied())); - online_cov(iters, ddof) - }; - Some(out) - } -} - -/// # Arguments -/// `iter` - Iterator over `T` tuple where any `Option` would skip the tuple. -fn online_cov(iters: I, ddof: u8) -> f64 -where - I: Iterator, - J: IntoIterator + Clone, - T: ToPrimitive, + T::Native: AsPrimitive, + ChunkedArray: ChunkVar, { - // The algorithm is derived from - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version - // We simply set the weights to 1.0. This allows us to simplify the expressions - // a lot, and move out subtractions out of sums. - let mut mean_x = 0.0; - let mut mean_y = 0.0; - let mut cxy = 0.0; - let mut n = 0.0; - - let mut x_tmp = [0.0; COV_BUF_SIZE]; - let mut y_tmp = [0.0; COV_BUF_SIZE]; - - for iter in iters { - let mut iter = iter.clone().into_iter(); - - loop { - let mut k = 0; - for (x, y) in iter.by_ref().take(COV_BUF_SIZE) { - let x = x.to_f64().unwrap(); - let y = y.to_f64().unwrap(); - - x_tmp[k] = x; - y_tmp[k] = y; - k += 1; - } - if k == 0 { - break; - } - - // TODO: combine these all in one SIMD'ized pass. - let xsum: f64 = FloatSum::sum(&x_tmp[..k]); - let ysum: f64 = FloatSum::sum(&y_tmp[..k]); - let xysum = multiply_sum(&x_tmp, &y_tmp, k); - - let old_mean_x = mean_x; - let old_mean_y = mean_y; - n += k as f64; - mean_x += (xsum - k as f64 * old_mean_x) / n; - mean_y += (ysum - k as f64 * old_mean_y) / n; - - cxy += xysum - xsum * old_mean_y - ysum * mean_x + mean_x * old_mean_y * (k as f64); - } + let (a, b) = align_chunks_binary(a, b); + let mut out = CovState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::var_cov::cov(a, b)) } - - cxy / (n - ddof as f64) + out.finalize(ddof) } /// Compute the pearson correlation between two columns. pub fn pearson_corr(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option where T: PolarsNumericType, - T::Native: ToPrimitive, + T::Native: AsPrimitive, ChunkedArray: ChunkVar, { let (a, b) = align_chunks_binary(a, b); - - let out = if a.null_count() > 0 || b.null_count() > 0 { - let iters = a.downcast_iter().zip(b.downcast_iter()).map(|(a, b)| { - a.into_iter().zip(b).filter_map(|(a, b)| match (a, b) { - (Some(a), Some(b)) => Some((*a, *b)), - _ => None, - }) - }); - online_pearson_corr(iters, ddof) - } else { - let iters = a - .downcast_iter() - .zip(b.downcast_iter()) - .map(|(a, b)| a.values_iter().copied().zip(b.values_iter().copied())); - online_pearson_corr(iters, ddof) - }; - Some(out) -} - -/// # Arguments -/// `iter` - Iterator over `T` tuple where any `Option` would skip the tuple. -fn online_pearson_corr(iters: I, ddof: u8) -> f64 -where - I: Iterator, - J: IntoIterator + Clone, - T: ToPrimitive, -{ - // Algorithm is same as cov, we just maintain cov(X, X), cov(X, Y), and - // cov(Y, Y), noting that var(X) = cov(X, X). - // Then corr(X, Y) = cov(X, Y)/(std(X) * std(Y)). - let mut mean_x = 0.0; - let mut mean_y = 0.0; - let mut cxy = 0.0; - let mut cxx = 0.0; - let mut cyy = 0.0; - let mut n = 0.0; - - let mut x_tmp = [0.0; COV_BUF_SIZE]; - let mut y_tmp = [0.0; COV_BUF_SIZE]; - - for iter in iters { - let mut iter = iter.clone().into_iter(); - - loop { - let mut k = 0; - for (x, y) in iter.by_ref().take(COV_BUF_SIZE) { - let x = x.to_f64().unwrap(); - let y = y.to_f64().unwrap(); - - x_tmp[k] = x; - y_tmp[k] = y; - k += 1; - } - if k == 0 { - break; - } - - // TODO: combine these all in one SIMD'ized pass. - let xsum: f64 = FloatSum::sum(&x_tmp[..k]); - let ysum: f64 = FloatSum::sum(&y_tmp[..k]); - let xxsum = multiply_sum(&x_tmp, &x_tmp, k); - let xysum = multiply_sum(&x_tmp, &y_tmp, k); - let yysum = multiply_sum(&y_tmp, &y_tmp, k); - - let old_mean_x = mean_x; - let old_mean_y = mean_y; - n += k as f64; - mean_x += (xsum - k as f64 * old_mean_x) / n; - mean_y += (ysum - k as f64 * old_mean_y) / n; - - cxx += xxsum - xsum * old_mean_x - xsum * mean_x + mean_x * old_mean_x * (k as f64); - cxy += xysum - xsum * old_mean_y - ysum * mean_x + mean_x * old_mean_y * (k as f64); - cyy += yysum - ysum * old_mean_y - ysum * mean_y + mean_y * old_mean_y * (k as f64); - } - } - - let sample_n = n - ddof as f64; - let sample_cov = cxy / sample_n; - let sample_std_x = (cxx / sample_n).sqrt(); - let sample_std_y = (cyy / sample_n).sqrt(); - - let denom = sample_std_x * sample_std_y; - let result = sample_cov / denom; - if denom.is_zero() { - f64::NAN - } else { - result + let mut out = PearsonState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::var_cov::pearson_corr(a, b)) } + Some(out.finalize(ddof)) } diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index 6676de3983db..fad1bcebb9a1 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -5,82 +5,220 @@ use polars_utils::pl_str::PlSmallStr; use super::*; -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ListToStructArgs { + FixedWidth(Arc<[PlSmallStr]>), + InferWidth { + infer_field_strategy: ListToStructWidthStrategy, + get_index_name: Option, + /// If this is 0, it means unbounded. + max_fields: usize, + }, +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum ListToStructWidthStrategy { FirstNonNull, MaxWidth, } -fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize { - match n_fields { - ListToStructWidthStrategy::MaxWidth => { - let mut max = 0; - - ca.downcast_iter().for_each(|arr| { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - let len = (*o - last) as usize; - max = std::cmp::max(max, len); - last = *o; +impl ListToStructArgs { + pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult { + let DataType::List(inner_dtype) = input_dtype else { + polars_bail!( + InvalidOperation: + "attempted list to_struct on non-list dtype: {}", + input_dtype + ); + }; + let inner_dtype = inner_dtype.as_ref(); + + match self { + Self::FixedWidth(names) => Ok(DataType::Struct( + names + .iter() + .map(|x| Field::new(x.clone(), inner_dtype.clone())) + .collect::>(), + )), + Self::InferWidth { + get_index_name, + max_fields, + .. + } if *max_fields > 0 => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + Ok(DataType::Struct( + (0..*max_fields) + .map(|i| Field::new(get_index_name_func(i), inner_dtype.clone())) + .collect::>(), + )) + }, + Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), + } + } + + fn det_n_fields(&self, ca: &ListChunked) -> usize { + match self { + Self::FixedWidth(v) => v.len(), + Self::InferWidth { + infer_field_strategy, + max_fields, + .. + } => { + let inferred = match infer_field_strategy { + ListToStructWidthStrategy::MaxWidth => { + let mut max = 0; + + ca.downcast_iter().for_each(|arr| { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + let len = (*o - last) as usize; + max = std::cmp::max(max, len); + last = *o; + } + }); + max + }, + ListToStructWidthStrategy::FirstNonNull => { + let mut len = 0; + for arr in ca.downcast_iter() { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + len = (*o - last) as usize; + if len > 0 { + break; + } + last = *o; + } + if len > 0 { + break; + } + } + len + }, + }; + + if *max_fields > 0 { + inferred.min(*max_fields) + } else { + inferred } - }); - max - }, - ListToStructWidthStrategy::FirstNonNull => { - let mut len = 0; - for arr in ca.downcast_iter() { - let offsets = arr.offsets().as_slice(); - let mut last = offsets[0]; - for o in &offsets[1..] { - len = (*o - last) as usize; - if len > 0 { - break; - } - last = *o; + }, + } + } + + fn set_output_names(&self, columns: &mut [Series]) { + match self { + Self::FixedWidth(v) => { + assert_eq!(columns.len(), v.len()); + + for (c, name) in columns.iter_mut().zip(v.iter()) { + c.rename(name.clone()); } - if len > 0 { - break; + }, + Self::InferWidth { get_index_name, .. } => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + + for (i, c) in columns.iter_mut().enumerate() { + c.rename(get_index_name_func(i)); } - } - len - }, + }, + } + } +} + +#[derive(Clone)] +pub struct NameGenerator(pub Arc PlSmallStr + Send + Sync>); + +impl NameGenerator { + pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self { + Self(Arc::new(func)) + } +} + +impl std::fmt::Debug for NameGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "list::to_struct::NameGenerator function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for NameGenerator {} + +impl PartialEq for NameGenerator { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) } } -pub type NameGenerator = Arc PlSmallStr + Send + Sync>; +impl std::hash::Hash for NameGenerator { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr { format_pl_smallstr!("field_{idx}") } pub trait ToStruct: AsList { - fn to_struct( - &self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - ) -> PolarsResult { + fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult { let ca = self.as_list(); - let n_fields = det_n_fields(ca, n_fields); + let n_fields = args.det_n_fields(ca); - let name_generator = name_generator - .as_deref() - .unwrap_or(&_default_struct_name_gen); - - let fields = POOL.install(|| { + let mut fields = POOL.install(|| { (0..n_fields) .into_par_iter() - .map(|i| { - ca.lst_get(i as i64, true).map(|mut s| { - s.rename(name_generator(i)); - s - }) - }) + .map(|i| ca.lst_get(i as i64, true)) .collect::>>() })?; + args.set_output_names(&mut fields); + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } } impl ToStruct for ListChunked {} + +#[cfg(feature = "serde")] +mod _serde_impl { + use super::*; + + impl serde::Serialize for NameGenerator { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + Err(S::Error::custom( + "cannot serialize name generator function for to_struct, \ + consider passing a list of field names instead.", + )) + } + } + + impl<'de> serde::Deserialize<'de> for NameGenerator { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + Err(D::Error::custom( + "invalid data: attempted to deserialize list::to_struct::NameGenerator", + )) + } + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/case.rs b/crates/polars-ops/src/chunked_array/strings/case.rs index 7bb348e28803..4004f122143f 100644 --- a/crates/polars-ops/src/chunked_array/strings/case.rs +++ b/crates/polars-ops/src/chunked_array/strings/case.rs @@ -5,7 +5,7 @@ fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec) { out.clear(); out.reserve(b.len()); - const USIZE_SIZE: usize = std::mem::size_of::(); + const USIZE_SIZE: usize = size_of::(); const MAGIC_UNROLL: usize = 2; const N: usize = USIZE_SIZE * MAGIC_UNROLL; const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]); diff --git a/crates/polars-ops/src/chunked_array/strings/escape_regex.rs b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs new file mode 100644 index 000000000000..1edb9146e9f4 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs @@ -0,0 +1,21 @@ +use polars_core::prelude::{StringChunked, StringChunkedBuilder}; + +#[inline] +pub fn escape_regex_str(s: &str) -> String { + regex_syntax::escape(s) +} + +pub fn escape_regex(ca: &StringChunked) -> StringChunked { + let mut buffer = String::new(); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + for opt_s in ca.iter() { + if let Some(s) = opt_s { + buffer.clear(); + regex_syntax::escape_into(s, &mut buffer); + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } + builder.finish() +} diff --git a/crates/polars-ops/src/chunked_array/strings/json_path.rs b/crates/polars-ops/src/chunked_array/strings/json_path.rs index 02b3c076efd7..7ef813d63027 100644 --- a/crates/polars-ops/src/chunked_array/strings/json_path.rs +++ b/crates/polars-ops/src/chunked_array/strings/json_path.rs @@ -98,6 +98,8 @@ pub trait Utf8JsonPathImpl: AsString { infer_schema_len: Option, ) -> PolarsResult { let ca = self.as_string(); + // Ignore extra fields instead of erroring if the dtype was explicitly given. + let allow_extra_fields_in_struct = dtype.is_some(); let dtype = match dtype { Some(dt) => dt, None => ca.json_infer(infer_schema_len)?, @@ -110,6 +112,7 @@ pub trait Utf8JsonPathImpl: AsString { dtype.to_arrow(CompatLevel::newest()), buf_size, ca.len(), + allow_extra_fields_in_struct, ) .map_err(|e| polars_err!(ComputeError: "error deserializing JSON: {}", e))?; Series::try_from((PlSmallStr::EMPTY, array)) diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index b9149983307b..326349c36815 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -3,6 +3,8 @@ mod case; #[cfg(feature = "strings")] mod concat; #[cfg(feature = "strings")] +mod escape_regex; +#[cfg(feature = "strings")] mod extract; #[cfg(feature = "find_many")] mod find_many; @@ -20,12 +22,13 @@ mod split; mod strip; #[cfg(feature = "strings")] mod substring; - #[cfg(all(not(feature = "nightly"), feature = "strings"))] mod unicode_internals; #[cfg(feature = "strings")] pub use concat::*; +#[cfg(feature = "strings")] +pub use escape_regex::*; #[cfg(feature = "find_many")] pub use find_many::*; #[cfg(feature = "extract_jsonpath")] diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 812dfbfcba91..93574e5f3080 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -640,6 +640,12 @@ pub trait StringNameSpaceImpl: AsString { substring::tail(ca, n.i64()?) } + #[cfg(feature = "strings")] + /// Escapes all regular expression meta characters in the string. + fn str_escape_regex(&self) -> StringChunked { + let ca = self.as_string(); + escape_regex::escape_regex(ca) + } } impl StringNameSpaceImpl for StringChunked {} diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index b4d347170cbb..4c845a2ba541 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -18,6 +18,7 @@ pub type ChunkJoinIds = Vec; use polars_core::export::once_cell::sync::Lazy; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; #[derive(Clone, PartialEq, Eq, Debug, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -108,8 +109,9 @@ impl JoinArgs { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum JoinType { Inner, Left, diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs index f5c91de88a74..b3193ce76628 100644 --- a/crates/polars-ops/src/frame/join/dispatch_left_right.rs +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -90,14 +90,14 @@ fn materialize_left_join( if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - left._create_left_df_from_slice(left_idx, true, true) + left._create_left_df_from_slice(left_idx, true, args.slice.is_some(), true) }, ChunkJoinIds::Right(left_idx) => unsafe { let mut left_idx = &*left_idx; if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - left.create_left_df_chunked(left_idx, true) + left.create_left_df_chunked(left_idx, true, args.slice.is_some()) }, }; @@ -133,7 +133,8 @@ fn materialize_left_join( if let Some((offset, len)) = args.slice { left_idx = slice_slice(left_idx, offset, len); } - let materialize_left = || unsafe { left._create_left_df_from_slice(&left_idx, true, true) }; + let materialize_left = + || unsafe { left._create_left_df_from_slice(&left_idx, true, args.slice.is_some(), true) }; let mut right_idx = &*right_idx; if let Some((offset, len)) = args.slice { diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index 1420d7b66062..0bf0a86cd972 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -56,7 +56,7 @@ pub fn _coalesce_full_join( df_left: &DataFrame, ) -> DataFrame { // No need to allocate the schema because we already - // know for certain that the column name for left left is `name` + // know for certain that the column name for left is `name` // and for right is `name + suffix` let schema_left = if keys_left == keys_right { Schema::default() diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index fcda1be6802d..8d5533c41e5b 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -55,9 +55,18 @@ pub trait JoinDispatch: IntoDf { /// # Safety /// Join tuples must be in bounds #[cfg(feature = "chunked_ids")] - unsafe fn create_left_df_chunked(&self, chunk_ids: &[ChunkId], left_join: bool) -> DataFrame { + unsafe fn create_left_df_chunked( + &self, + chunk_ids: &[ChunkId], + left_join: bool, + was_sliced: bool, + ) -> DataFrame { let df_self = self.to_df(); - if left_join && chunk_ids.len() == df_self.height() { + + let left_join_no_duplicate_matches = + left_join && !was_sliced && chunk_ids.len() == df_self.height(); + + if left_join_no_duplicate_matches { df_self.clone() } else { // left join keys are in ascending order @@ -76,10 +85,15 @@ pub trait JoinDispatch: IntoDf { &self, join_tuples: &[IdxSize], left_join: bool, + was_sliced: bool, sorted_tuple_idx: bool, ) -> DataFrame { let df_self = self.to_df(); - if left_join && join_tuples.len() == df_self.height() { + + let left_join_no_duplicate_matches = + left_join && !was_sliced && join_tuples.len() == df_self.height(); + + if left_join_no_duplicate_matches { df_self.clone() } else { // left join tuples are always in ascending order diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 81f4fe54e7e4..fd59ef7f4a1c 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -34,11 +34,11 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; pub use iejoin::{IEJoinOptions, InequalityOperator}; #[cfg(feature = "merge_sorted")] pub use merge_sorted::_merge_sorted_dfs; -use polars_core::hashing::_HASHMAP_INIT_SIZE; #[allow(unused_imports)] -use polars_core::prelude::sort::arg_sort_multiple::{ +use polars_core::chunked_array::ops::row_encode::{ encode_rows_vertical_par_unordered, encode_rows_vertical_par_unordered_broadcast_nulls, }; +use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; pub(super) use polars_core::series::IsSorted; use polars_core::utils::slice_offsets; @@ -506,7 +506,14 @@ trait DataFrameJoinOpsPrivate: IntoDf { let (df_left, df_right) = POOL.join( // SAFETY: join indices are known to be in bounds - || unsafe { left_df._create_left_df_from_slice(join_tuples_left, false, sorted) }, + || unsafe { + left_df._create_left_df_from_slice( + join_tuples_left, + false, + args.slice.is_some(), + sorted, + ) + }, || unsafe { if let Some(drop_names) = drop_names { other.drop_many(drop_names) diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 4604920351eb..d72ff4488251 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -11,9 +11,6 @@ use polars_core::utils::accumulate_dataframes_horizontal; #[cfg(feature = "to_dummies")] use polars_core::POOL; -#[allow(unused_imports)] -use crate::prelude::*; - pub trait IntoDf { fn to_df(&self) -> &DataFrame; } @@ -94,6 +91,8 @@ pub trait DataFrameOps: IntoDf { separator: Option<&str>, drop_first: bool, ) -> PolarsResult { + use crate::series::ToDummies; + let df = self.to_df(); let set: PlHashSet<&str> = if let Some(columns) = columns { diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index f91d537fb9d4..b7058de0a05e 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -240,13 +240,13 @@ pub(super) fn compute_col_idx( let col_locations = match column_agg_physical.dtype() { T::Int32 | T::UInt32 => { let Some(BitRepr::Small(ca)) = column_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); }; compute_col_idx_numeric(&ca) }, T::Int64 | T::UInt64 => { let Some(BitRepr::Large(ca)) = column_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); }; compute_col_idx_numeric(&ca) }, @@ -413,13 +413,13 @@ pub(super) fn compute_row_idx( match index_agg_physical.dtype() { T::Int32 | T::UInt32 => { let Some(BitRepr::Small(ca)) = index_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); }; compute_row_index(index, &ca, count, index_s.dtype()) }, T::Int64 | T::UInt64 => { let Some(BitRepr::Large(ca)) = index_agg_physical.bit_repr() else { - polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); }; compute_row_index(index, &ca, count, index_s.dtype()) }, diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs index 60f5bff8eae9..49eeaeba4498 100644 --- a/crates/polars-ops/src/frame/pivot/unpivot.rs +++ b/crates/polars-ops/src/frame/pivot/unpivot.rs @@ -4,7 +4,7 @@ use polars_core::datatypes::{DataType, PlSmallStr}; use polars_core::frame::column::Column; use polars_core::frame::DataFrame; use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR}; -use polars_core::utils::try_get_supertype; +use polars_core::utils::merge_dtypes_many; use polars_error::{polars_err, PolarsResult}; use polars_utils::aliases::PlHashSet; @@ -104,9 +104,9 @@ pub trait UnpivotDF: IntoDf { let len = self_.height(); - // if value vars is empty we take all columns that are not in id_vars. + // If value vars is empty we take all columns that are not in id_vars. if on.is_empty() { - // return empty frame if there are no columns available to use as value vars + // Return empty frame if there are no columns available to use as value vars. if index.len() == self_.width() { let variable_col = Column::new_empty(variable_name, &DataType::String); let value_col = Column::new_empty(value_name, &DataType::Null); @@ -133,15 +133,14 @@ pub trait UnpivotDF: IntoDf { .collect(); } - // values will all be placed in single column, so we must find their supertype + // Values will all be placed in single column, so we must find their supertype let schema = self_.schema(); - let mut iter = on + let dtypes = on .iter() - .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))); - let mut st = iter.next().unwrap()?.clone(); - for dt in iter { - st = try_get_supertype(&st, dt?)?; - } + .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))) + .collect::>>()?; + + let st = merge_dtypes_many(dtypes.iter())?; // The column name of the variable that is unpivoted let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); @@ -167,7 +166,7 @@ pub trait UnpivotDF: IntoDf { let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?; let col = &columns[pos]; let value_col = col.cast(&st).map_err( - |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()), + |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}\n\nConsider casting to String.", col.dtype()), )?; values.extend_from_slice(value_col.as_materialized_series().chunks()) } diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index c78aa67efca8..b7b8d3e9f179 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -144,11 +144,7 @@ pub fn qcut( let s2 = s.sort(SortOptions::default())?; let ca = s2.f64()?; - let f = |&p| { - ca.quantile(p, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - }; + let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap(); let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); diff --git a/crates/polars-ops/src/series/ops/is_between.rs b/crates/polars-ops/src/series/ops/is_between.rs index 053493d552f6..96b1074b4d82 100644 --- a/crates/polars-ops/src/series/ops/is_between.rs +++ b/crates/polars-ops/src/series/ops/is_between.rs @@ -3,9 +3,11 @@ use std::ops::BitAnd; use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedInterval { #[default] Both, diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index dfe3ba1a3ddf..eb2cf3a228c1 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -46,7 +46,8 @@ impl ToDummies for Series { }) .collect::>(); - Ok(unsafe { DataFrame::new_no_checks_height_from_first(sort_columns(columns)) }) + // SAFETY: `dummies_helper` functions preserve `self.len()` length + unsafe { DataFrame::new_no_length_checks(sort_columns(columns)) } } } diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index e49e24ddf7a8..47d467f7f7ba 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -1,7 +1,7 @@ use num_traits::Bounded; -use polars_core::prelude::arity::unary_elementwise_values; #[cfg(feature = "dtype-struct")] -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_ca; +use polars_core::prelude::arity::unary_elementwise_values; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-parquet/Cargo.toml b/crates/polars-parquet/Cargo.toml index 26a57b22e713..8ae9108a1fa7 100644 --- a/crates/polars-parquet/Cargo.toml +++ b/crates/polars-parquet/Cargo.toml @@ -22,12 +22,12 @@ fallible-streaming-iterator = { workspace = true, optional = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } num-traits = { workspace = true } -polars-compute = { workspace = true } +polars-compute = { workspace = true, features = ["approx_unique"] } polars-error = { workspace = true } +polars-parquet-format = "0.1" polars-utils = { workspace = true, features = ["mmap"] } simdutf8 = { workspace = true } -parquet-format-safe = "0.2" streaming-decompression = "0.1" async-stream = { version = "0.3.3", optional = true } @@ -61,6 +61,6 @@ gzip_zlib_ng = ["flate2/zlib-ng"] lz4 = ["dep:lz4"] lz4_flex = ["dep:lz4_flex"] -async = ["async-stream", "futures", "parquet-format-safe/async"] +async = ["async-stream", "futures", "polars-parquet-format/async"] bloom_filter = ["xxhash-rust"] serde_types = ["serde"] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs index 6777f7e639c9..86e46756788a 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs @@ -362,7 +362,7 @@ impl DeltaGatherer for StatGatherer { } } -impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut DeltaCollector<'a, 'b> { +impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut DeltaCollector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -394,7 +394,7 @@ impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut Delta } } -impl<'a, 'b> DeltaCollector<'a, 'b> { +impl DeltaCollector<'_, '_> { pub fn flush(&mut self, target: &mut MutableBinaryViewArray<[u8]>) { if !self.pushed_lengths.is_empty() { let start_bytes_len = target.total_bytes_len(); @@ -428,7 +428,7 @@ impl<'a, 'b> DeltaCollector<'a, 'b> { } } -impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for DeltaBytesCollector<'a, 'b> { +impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for DeltaBytesCollector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -621,7 +621,7 @@ impl utils::Decoder for BinViewDecoder { max_length: &'b mut usize, } - impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } @@ -709,7 +709,7 @@ impl utils::Decoder for BinViewDecoder { ) -> ParquetResult<()> { struct DictionaryTranslator<'a>(&'a [View]); - impl<'a> HybridRleGatherer for DictionaryTranslator<'a> { + impl HybridRleGatherer for DictionaryTranslator<'_> { type Target = MutableBinaryViewArray<[u8]>; fn target_reserve(&self, target: &mut Self::Target, n: usize) { @@ -803,7 +803,7 @@ impl utils::Decoder for BinViewDecoder { translator: DictionaryTranslator<'b>, } - impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + impl BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'_, '_> { fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs index af2e504d2646..51026f483bd7 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -161,7 +161,7 @@ impl HybridRleGatherer for BitmapGatherer { // @TODO: The slice impl here can speed some stuff up } struct BitmapCollector<'a, 'b>(&'b mut HybridRleDecoder<'a>); -impl<'a, 'b> BatchableCollector for BitmapCollector<'a, 'b> { +impl BatchableCollector for BitmapCollector<'_, '_> { fn reserve(target: &mut MutableBitmap, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs index de2bfe2e47f3..478c7cca0f2e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs @@ -184,7 +184,7 @@ pub(crate) struct DictArrayTranslator { dict_size: usize, } -impl<'a, 'b, K: DictionaryKey> BatchableCollector<(), Vec> for DictArrayCollector<'a, 'b> { +impl BatchableCollector<(), Vec> for DictArrayCollector<'_, '_> { fn reserve(target: &mut Vec, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 3825d528c8f5..5657a20dd151 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -154,7 +154,7 @@ impl Decoder for BinaryDecoder { size: usize, } - impl<'a, 'b> BatchableCollector<(), Vec> for FixedSizeBinaryCollector<'a, 'b> { + impl BatchableCollector<(), Vec> for FixedSizeBinaryCollector<'_, '_> { fn reserve(target: &mut Vec, n: usize) { target.reserve(n); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index a622153dfca8..d37a6d4bf3b1 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -215,8 +215,8 @@ pub struct BatchedNestedDecoder<'a, 'b, 'c, D: utils::NestedDecoder> { decoder: &'c mut D, } -impl<'a, 'b, 'c, D: utils::NestedDecoder> BatchableCollector<(), D::DecodedState> - for BatchedNestedDecoder<'a, 'b, 'c, D> +impl BatchableCollector<(), D::DecodedState> + for BatchedNestedDecoder<'_, '_, '_, D> { fn reserve(_target: &mut D::DecodedState, _n: usize) { unreachable!() diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs index 0a43141abd06..eb4815b6fbfe 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -55,7 +55,7 @@ where let values = split_buffer(page)?.values; Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( values, - std::mem::size_of::

(), + size_of::

(), )?)) }, _ => Err(utils::not_implemented(page)), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs index ed8e0a541a68..ff9c0b014b08 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -58,7 +58,7 @@ where let values = split_buffer(page)?.values; Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( values, - std::mem::size_of::

(), + size_of::

(), )?)) }, (Encoding::DeltaBinaryPacked, _) => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs index 1a9d50a66d31..5539595fda48 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs @@ -125,8 +125,8 @@ where pub(crate) _pd: std::marker::PhantomData, } -impl<'a, 'b, P, T, D: DecoderFunction> BatchableCollector<(), Vec> - for PlainDecoderFnCollector<'a, 'b, P, T, D> +impl> BatchableCollector<(), Vec> + for PlainDecoderFnCollector<'_, '_, P, T, D> where T: NativeType, P: ParquetNativeType, @@ -167,7 +167,7 @@ where D: DecoderFunction, { values - .chunks_exact(std::mem::size_of::

()) + .chunks_exact(size_of::

()) .map(decode) .map(|v| decoder.decode(v)) .collect::>() @@ -239,7 +239,7 @@ where } } -impl<'a, 'b, P, T, D> BatchableCollector<(), Vec> for DeltaCollector<'a, 'b, P, T, D> +impl BatchableCollector<(), Vec> for DeltaCollector<'_, '_, P, T, D> where T: NativeType, P: ParquetNativeType, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs index 56912934e100..f86523b74d9b 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -1,7 +1,7 @@ use arrow::array::{Array, DictionaryArray, DictionaryKey, FixedSizeBinaryArray, PrimitiveArray}; use arrow::datatypes::{ArrowDataType, IntervalUnit, TimeUnit}; use arrow::match_integer_type; -use arrow::types::{days_ms, i256}; +use arrow::types::{days_ms, i256, NativeType}; use ethnum::I256; use polars_error::{polars_bail, PolarsResult}; @@ -275,6 +275,34 @@ pub fn page_iter_to_array( primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), + + // Float16 + (PhysicalType::FixedLenByteArray(2), Float32) => { + // @NOTE: To reduce code bloat, we just use the FixedSizeBinary decoder. + + let mut fsb_array = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(2), + fixed_size_binary::BinaryDecoder { size: 2 }, + )?.collect_n(filter)?; + + + let validity = fsb_array.take_validity(); + let values = fsb_array.values().as_slice(); + assert_eq!(values.len() % 2, 0); + let values = values.chunks_exact(2); + let values = values.map(|v| { + // SAFETY: We know that `v` is always of size two. + let le_bytes: [u8; 2] = unsafe { v.try_into().unwrap_unchecked() }; + let v = arrow::types::f16::from_le_bytes(le_bytes); + v.to_f32() + }).collect(); + + let array = PrimitiveArray::::new(dtype, values, validity); + + Box::new(array) + }, + (PhysicalType::Float, Float32) => Box::new(PageDecoder::new( pages, dtype, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs index 330ad77a7c44..f66544ee3183 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs @@ -16,7 +16,7 @@ impl<'a, P: ParquetNativeType> ArrayChunks<'a, P> { /// /// This returns null if the `bytes` slice's length is not a multiple of the size of `P::Bytes`. pub(crate) fn new(bytes: &'a [u8]) -> Option { - if bytes.len() % std::mem::size_of::() != 0 { + if bytes.len() % size_of::() != 0 { return None; } @@ -47,4 +47,4 @@ impl<'a, P: ParquetNativeType> Iterator for ArrayChunks<'a, P> { } } -impl<'a, P: ParquetNativeType> ExactSizeIterator for ArrayChunks<'a, P> {} +impl ExactSizeIterator for ArrayChunks<'_, P> {} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs index dba00fc97930..7c6cf840bdce 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs @@ -433,7 +433,7 @@ where } } -impl<'a, 'b, 'c, O, T> BatchableCollector> for TranslatedHybridRle<'a, 'b, 'c, O, T> +impl BatchableCollector> for TranslatedHybridRle<'_, '_, '_, O, T> where O: Clone + Default, T: Translator, @@ -487,7 +487,7 @@ where } } -impl<'a, 'b, 'c, O, G> BatchableCollector> for GatheredHybridRle<'a, 'b, 'c, O, G> +impl BatchableCollector> for GatheredHybridRle<'_, '_, '_, O, G> where O: Clone, G: HybridRleGatherer>, @@ -516,8 +516,8 @@ where } } -impl<'a, 'b, 'c, T> BatchableCollector> - for TranslatedHybridRle<'a, 'b, 'c, View, T> +impl BatchableCollector> + for TranslatedHybridRle<'_, '_, '_, View, T> where T: Translator, { diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs index 915936c81296..f73c401bd72b 100644 --- a/crates/polars-parquet/src/arrow/read/schema/metadata.rs +++ b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -35,6 +35,7 @@ fn convert_dtype(mut dtype: ArrowDataType) -> ArrowDataType { convert_field(field); } }, + Float16 => dtype = Float32, Binary | LargeBinary => dtype = BinaryView, Utf8 | LargeUtf8 => dtype = Utf8View, Dictionary(_, ref mut dtype, _) | Extension(_, ref mut dtype, _) => { diff --git a/crates/polars-parquet/src/arrow/read/schema/mod.rs b/crates/polars-parquet/src/arrow/read/schema/mod.rs index 347cd49faefd..ea27aa03d46d 100644 --- a/crates/polars-parquet/src/arrow/read/schema/mod.rs +++ b/crates/polars-parquet/src/arrow/read/schema/mod.rs @@ -40,7 +40,7 @@ impl Default for SchemaInferenceOptions { /// /// # Error /// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, -/// indicating that that the file's arrow metadata was incorrectly written. +/// indicating that the file's arrow metadata was incorrectly written. pub fn infer_schema(file_metadata: &FileMetadata) -> PolarsResult { infer_schema_with_options(file_metadata, &None) } diff --git a/crates/polars-parquet/src/arrow/read/statistics/mod.rs b/crates/polars-parquet/src/arrow/read/statistics/mod.rs index d9210abcb5aa..22ba71caba17 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/mod.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/mod.rs @@ -3,7 +3,7 @@ use std::collections::VecDeque; use arrow::array::*; use arrow::datatypes::{ArrowDataType, Field, IntervalUnit, PhysicalType}; -use arrow::types::i256; +use arrow::types::{f16, i256, NativeType}; use arrow::with_match_primitive_type_full; use ethnum::I256; use polars_error::{polars_bail, PolarsResult}; @@ -28,6 +28,7 @@ mod struct_; mod utf8; use self::list::DynMutableListArray; +use super::PrimitiveLogicalType; /// Arrow-deserialized parquet Statistics of a file #[derive(Debug, PartialEq)] @@ -549,6 +550,37 @@ fn push( } } +pub fn cast_statistics( + statistics: ParquetStatistics, + primitive_type: &ParquetPrimitiveType, + output_type: &ArrowDataType, +) -> ParquetStatistics { + use {ArrowDataType as DT, PrimitiveLogicalType as PT}; + + match (primitive_type.logical_type, output_type) { + (Some(PT::Float16), DT::Float32) => { + let statistics = statistics.expect_fixedlen(); + + let primitive_type = primitive_type.clone(); + + ParquetStatistics::Float(PrimitiveStatistics:: { + primitive_type, + null_count: statistics.null_count, + distinct_count: statistics.distinct_count, + min_value: statistics + .min_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + max_value: statistics + .max_value + .as_ref() + .map(|v| f16::from_le_bytes([v[0], v[1]]).to_f32()), + }) + }, + _ => statistics, + } +} + /// Deserializes the statistics in the column chunks from a single `row_group` /// into [`Statistics`] associated from `field`'s name. /// @@ -562,9 +594,13 @@ pub fn deserialize<'a>( let mut stats = field_md .map(|column| { + let primitive_type = &column.descriptor().descriptor.primitive_type; Ok(( - column.statistics().transpose()?, - column.descriptor().descriptor.primitive_type.clone(), + column + .statistics() + .transpose()? + .map(|stats| cast_statistics(stats, primitive_type, &field.dtype)), + primitive_type.clone(), )) }) .collect::, ParquetPrimitiveType)>>>()?; diff --git a/crates/polars-parquet/src/arrow/write/binary/basic.rs b/crates/polars-parquet/src/arrow/write/binary/basic.rs index 0032a5d1e29d..7af6ab5f353d 100644 --- a/crates/polars-parquet/src/arrow/write/binary/basic.rs +++ b/crates/polars-parquet/src/arrow/write/binary/basic.rs @@ -30,15 +30,15 @@ pub(crate) fn encode_plain( ) { if options.is_optional() && array.validity().is_some() { let len_before = buffer.len(); - let capacity = array.get_values_size() - + (array.len() - array.null_count()) * std::mem::size_of::(); + let capacity = + array.get_values_size() + (array.len() - array.null_count()) * size_of::(); buffer.reserve(capacity); encode_non_null_values(array.non_null_values_iter(), buffer); // Ensure we allocated properly. debug_assert_eq!(buffer.len() - len_before, capacity); } else { let len_before = buffer.len(); - let capacity = array.get_values_size() + array.len() * std::mem::size_of::(); + let capacity = array.get_values_size() + array.len() * size_of::(); buffer.reserve(capacity); encode_non_null_values(array.values_iter(), buffer); // Ensure we allocated properly. diff --git a/crates/polars-parquet/src/arrow/write/binview/basic.rs b/crates/polars-parquet/src/arrow/write/binview/basic.rs index 8251818e722f..018e1627d25f 100644 --- a/crates/polars-parquet/src/arrow/write/binview/basic.rs +++ b/crates/polars-parquet/src/arrow/write/binview/basic.rs @@ -16,8 +16,8 @@ pub(crate) fn encode_plain( buffer: &mut Vec, ) { if options.is_optional() && array.validity().is_some() { - let capacity = array.total_bytes_len() - + (array.len() - array.null_count()) * std::mem::size_of::(); + let capacity = + array.total_bytes_len() + (array.len() - array.null_count()) * size_of::(); let len_before = buffer.len(); buffer.reserve(capacity); @@ -26,7 +26,7 @@ pub(crate) fn encode_plain( // Append the non-null values. debug_assert_eq!(buffer.len() - len_before, capacity); } else { - let capacity = array.total_bytes_len() + array.len() * std::mem::size_of::(); + let capacity = array.total_bytes_len() + array.len() * size_of::(); let len_before = buffer.len(); buffer.reserve(capacity); diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index 4e0d57302314..fc97c268c0fd 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -3,16 +3,17 @@ use arrow::array::{ }; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; -use arrow::datatypes::{ArrowDataType, IntegerType}; +use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType}; +use arrow::legacy::utils::CustomIterTools; +use arrow::trusted_len::TrustMyLength; use arrow::types::NativeType; use polars_compute::min_max::MinMaxKernel; use polars_error::{polars_bail, PolarsResult}; -use polars_utils::unwrap::UnwrapUncheckedRelease; use super::binary::{ build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, }; -use super::fixed_len_bytes::{ +use super::fixed_size_binary::{ build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain, }; use super::pages::PrimitiveNested; @@ -31,33 +32,51 @@ use crate::parquet::CowBuffer; use crate::write::DynIter; trait MinMaxThreshold { - const DELTA_THRESHOLD: Self; + const DELTA_THRESHOLD: usize; + const BITMASK_THRESHOLD: usize; + + fn from_start_and_offset(start: Self, offset: usize) -> Self; } macro_rules! minmaxthreshold_impls { - ($($t:ty => $threshold:literal,)+) => { + ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => { $( - impl MinMaxThreshold for $t { - const DELTA_THRESHOLD: Self = $threshold; + impl MinMaxThreshold for $signed { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + ((offset as $unsigned) as $signed) + } + } + impl MinMaxThreshold for $unsigned { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + (offset as $unsigned) + } } )+ }; } minmaxthreshold_impls! { - i8 => 16, - i16 => 256, - i32 => 512, - i64 => 2048, - u8 => 16, - u16 => 256, - u32 => 512, - u64 => 2048, + i8, u8 => 16, u8::MAX as usize, + i16, u16 => 256, u16::MAX as usize, + i32, u32 => 512, u16::MAX as usize, + i64, u64 => 2048, u16::MAX as usize, +} + +enum DictionaryDecision { + NotWorth, + TryAgain, + Found(DictionaryArray), } fn min_max_integer_encode_as_dictionary_optional<'a, E, T>( array: &'a dyn Array, -) -> Option> +) -> DictionaryDecision where E: std::fmt::Debug, T: NativeType @@ -65,26 +84,82 @@ where + std::cmp::Ord + TryInto + std::ops::Sub - + num_traits::CheckedSub, + + num_traits::CheckedSub + + num_traits::cast::AsPrimitive, std::ops::RangeInclusive: Iterator, PrimitiveArray: MinMaxKernel = T>, { - use ArrowDataType as DT; - let (min, max): (T, T) = as MinMaxKernel>::min_max_ignore_nan_kernel( + let min_max = as MinMaxKernel>::min_max_ignore_nan_kernel( array.as_any().downcast_ref().unwrap(), - )?; + ); + + let Some((min, max)) = min_max else { + return DictionaryDecision::TryAgain; + }; debug_assert!(max >= min, "{max} >= {min}"); - if !max - .checked_sub(&min) - .is_some_and(|v| v <= T::DELTA_THRESHOLD) - { - return None; + let Some(diff) = max.checked_sub(&min) else { + return DictionaryDecision::TryAgain; + }; + + let diff = diff.as_(); + + if diff > T::BITMASK_THRESHOLD { + return DictionaryDecision::TryAgain; + } + + let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1); + + let array = array.as_any().downcast_ref::>().unwrap(); + + if array.has_nulls() { + for v in array.non_null_values_iter() { + let offset = (v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } + } else { + for v in array.values_iter() { + let offset = (*v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } } - // @TODO: This currently overestimates the values, it might be interesting to use the unique - // kernel here. - let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), (min..=max).collect(), None); + let cardinality = seen_mask.set_bits(); + + let mut is_worth_it = false; + + is_worth_it |= cardinality <= T::DELTA_THRESHOLD; + is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75; + + if !is_worth_it { + return DictionaryDecision::NotWorth; + } + + let seen_mask = seen_mask.freeze(); + + // SAFETY: We just did the calculation for this. + let indexes = seen_mask + .true_idx_iter() + .map(|idx| T::from_start_and_offset(min, idx)); + let indexes = unsafe { TrustMyLength::new(indexes, cardinality) }; + let indexes = indexes.collect_trusted::>(); + + let mut lookup = vec![0u16; diff + 1]; + + for (i, &idx) in indexes.iter().enumerate() { + lookup[(idx - min).as_()] = i as u16; + } + + use ArrowDataType as DT; + let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None); let values = Box::new(values); let keys: Buffer = array @@ -93,20 +168,19 @@ where .unwrap() .values() .iter() - .map(|v| unsafe { + .map(|v| { // @NOTE: // Since the values might contain nulls which have a undefined value. We just // clamp the values to between the min and max value. This way, they will still - // be valid dictionary keys. This is mostly to make the - // unwrap_unchecked_release not produce any unsafety. - (*v.clamp(&min, &max) - min) - .try_into() - .unwrap_unchecked_release() + // be valid dictionary keys. + let idx = *v.clamp(&min, &max) - min; + let value = unsafe { lookup.get_unchecked(idx.as_()) }; + (*value).into() }) .collect(); let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned()); - Some( + DictionaryDecision::Found( DictionaryArray::::try_new( ArrowDataType::Dictionary( IntegerType::UInt32, @@ -126,26 +200,15 @@ pub(crate) fn encode_as_dictionary_optional( type_: PrimitiveType, options: WriteOptions, ) -> Option>>> { - use ArrowDataType as DT; - let fast_dictionary = match array.dtype() { - DT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), - DT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), - DT::Int32 | DT::Date32 | DT::Time32(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i32>(array) - }, - DT::Int64 | DT::Date64 | DT::Time64(_) | DT::Timestamp(_, _) | DT::Duration(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i64>(array) - }, - DT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), - DT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), - DT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), - DT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), - _ => None, - }; + if array.is_empty() { + let array = DictionaryArray::::new_empty(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(array.dtype().clone()), + false, // @TODO: This might be able to be set to true? + )); - if let Some(fast_dictionary) = fast_dictionary { return Some(array_to_pages( - &fast_dictionary, + &array, type_, nested, options, @@ -153,9 +216,44 @@ pub(crate) fn encode_as_dictionary_optional( )); } + use arrow::types::PrimitiveType as PT; + let fast_dictionary = match array.dtype().to_physical_type() { + PhysicalType::Primitive(pt) => match pt { + PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), + PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), + PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array), + PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array), + PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), + PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), + PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), + PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), + _ => DictionaryDecision::TryAgain, + }, + _ => DictionaryDecision::TryAgain, + }; + + match fast_dictionary { + DictionaryDecision::NotWorth => return None, + DictionaryDecision::Found(dictionary_array) => { + return Some(array_to_pages( + &dictionary_array, + type_, + nested, + options, + Encoding::RleDictionary, + )) + }, + DictionaryDecision::TryAgain => {}, + } + let dtype = Box::new(array.dtype().clone()); - let len_before = array.len(); + let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array); + + if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 { + return None; + } + // This does the group by. let array = arrow::compute::cast::cast( array, @@ -169,10 +267,6 @@ pub(crate) fn encode_as_dictionary_optional( .downcast_ref::>() .unwrap(); - if (array.values().len() as f64) / (len_before as f64) > 0.75 { - return None; - } - Some(array_to_pages( array, type_, diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs new file mode 100644 index 000000000000..27151ce51f70 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs @@ -0,0 +1,47 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{utils, EncodeNullability, Encoding, WriteOptions}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs similarity index 79% rename from crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs rename to crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs index 9277b9c78a98..58f11adfa491 100644 --- a/crates/polars-parquet/src/arrow/write/fixed_len_bytes.rs +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs @@ -1,12 +1,13 @@ +mod basic; +mod nested; + use arrow::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; use arrow::types::i256; -use polars_error::PolarsResult; +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; use super::binary::ord_binary; -use super::{utils, EncodeNullability, StatisticsOptions, WriteOptions}; -use crate::arrow::read::schema::is_nullable; -use crate::parquet::encoding::Encoding; -use crate::parquet::page::DataPage; +use super::{EncodeNullability, StatisticsOptions}; use crate::parquet::schema::types::PrimitiveType; use crate::parquet::statistics::FixedLenStatistics; @@ -27,44 +28,6 @@ pub(crate) fn encode_plain( } } -pub fn array_to_page( - array: &FixedSizeBinaryArray, - options: WriteOptions, - type_: PrimitiveType, - statistics: Option, -) -> PolarsResult { - let is_optional = is_nullable(&type_.field_info); - let encode_options = EncodeNullability::new(is_optional); - - let validity = array.validity(); - - let mut buffer = vec![]; - utils::write_def_levels( - &mut buffer, - is_optional, - validity, - array.len(), - options.version, - )?; - - let definition_levels_byte_length = buffer.len(); - - encode_plain(array, encode_options, &mut buffer); - - utils::build_plain_page( - buffer, - array.len(), - array.len(), - array.null_count(), - 0, - definition_levels_byte_length, - statistics.map(|x| x.serialize()), - type_, - options, - Encoding::Plain, - ) -} - pub(super) fn build_statistics( array: &FixedSizeBinaryArray, primitive_type: PrimitiveType, diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs new file mode 100644 index 000000000000..81175cf5db18 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs @@ -0,0 +1,39 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{nested, utils, EncodeNullability, Encoding, Nested, WriteOptions}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs index 02f0165d04c7..17a342ac9d67 100644 --- a/crates/polars-parquet/src/arrow/write/mod.rs +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -17,7 +17,7 @@ mod binview; mod boolean; mod dictionary; mod file; -mod fixed_len_bytes; +mod fixed_size_binary; mod nested; mod pages; mod primitive; @@ -528,7 +528,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( &array, type_.clone(), &options.statistics, @@ -536,7 +536,7 @@ pub fn array_to_page_simple( } else { None }; - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) }, ArrowDataType::Interval(IntervalUnit::DayTime) => { let array = array @@ -555,7 +555,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( &array, type_.clone(), &options.statistics, @@ -563,12 +563,12 @@ pub fn array_to_page_simple( } else { None }; - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) }, ArrowDataType::FixedSizeBinary(_) => { let array = array.as_any().downcast_ref().unwrap(); let statistics = if options.has_statistics() { - Some(fixed_len_bytes::build_statistics( + Some(fixed_size_binary::build_statistics( array, type_.clone(), &options.statistics, @@ -577,7 +577,7 @@ pub fn array_to_page_simple( None }; - fixed_len_bytes::array_to_page(array, options, type_, statistics) + fixed_size_binary::array_to_page(array, options, type_, statistics) }, ArrowDataType::Decimal256(precision, _) => { let precision = *precision; @@ -620,7 +620,7 @@ pub fn array_to_page_simple( } else if precision <= 38 { let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( array, type_.clone(), size, @@ -641,7 +641,7 @@ pub fn array_to_page_simple( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } else { let size = 32; let array = array @@ -649,7 +649,7 @@ pub fn array_to_page_simple( .downcast_ref::>() .unwrap(); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256( + let stats = fixed_size_binary::build_statistics_decimal256( array, type_.clone(), size, @@ -670,7 +670,7 @@ pub fn array_to_page_simple( array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } }, ArrowDataType::Decimal(precision, _) => { @@ -715,7 +715,7 @@ pub fn array_to_page_simple( let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal( + let stats = fixed_size_binary::build_statistics_decimal( array, type_.clone(), size, @@ -736,7 +736,7 @@ pub fn array_to_page_simple( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::array_to_page(&array, options, type_, statistics) } }, other => polars_bail!(nyi = "Writing parquet pages for data type {other:?}"), @@ -858,7 +858,7 @@ fn array_to_page_nested( let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal( + let stats = fixed_size_binary::build_statistics_decimal( array, type_.clone(), size, @@ -879,7 +879,7 @@ fn array_to_page_nested( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } }, Decimal256(precision, _) => { @@ -919,7 +919,7 @@ fn array_to_page_nested( } else if precision <= 38 { let size = decimal_length_from_precision(precision); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( array, type_.clone(), size, @@ -940,7 +940,7 @@ fn array_to_page_nested( values.into(), array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } else { let size = 32; let array = array @@ -948,7 +948,7 @@ fn array_to_page_nested( .downcast_ref::>() .unwrap(); let statistics = if options.has_statistics() { - let stats = fixed_len_bytes::build_statistics_decimal256( + let stats = fixed_size_binary::build_statistics_decimal256( array, type_.clone(), size, @@ -969,7 +969,7 @@ fn array_to_page_nested( array.validity().cloned(), ); - fixed_len_bytes::array_to_page(&array, options, type_, statistics) + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) } }, other => polars_bail!(nyi = "Writing nested parquet pages for data type {other:?}"), diff --git a/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs index 546efd034a9c..961393bf4ed2 100644 --- a/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs +++ b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs @@ -79,7 +79,7 @@ pub fn num_values(nested: &[Nested]) -> usize { BufferedDremelIter::new(nested).count() } -impl<'a> Level<'a> { +impl Level<'_> { /// Fetch the number of elements given on the next level at `offset` on this level fn next_level_length(&self, offset: usize, is_valid: bool) -> usize { match self.lengths { @@ -407,7 +407,7 @@ impl<'a> BufferedDremelIter<'a> { } } -impl<'a> Iterator for BufferedDremelIter<'a> { +impl Iterator for BufferedDremelIter<'_> { type Item = DremelValue; fn next(&mut self) -> Option { diff --git a/crates/polars-parquet/src/arrow/write/primitive/basic.rs b/crates/polars-parquet/src/arrow/write/primitive/basic.rs index d970e3659dcb..88bab6ca5ea6 100644 --- a/crates/polars-parquet/src/arrow/write/primitive/basic.rs +++ b/crates/polars-parquet/src/arrow/write/primitive/basic.rs @@ -38,7 +38,7 @@ where let mut iter = validity.iter(); let values = array.values().as_slice(); - buffer.reserve(std::mem::size_of::() * (array.len() - null_count)); + buffer.reserve(size_of::() * (array.len() - null_count)); let mut offset = 0; let mut remaining_valid = array.len() - null_count; @@ -61,7 +61,7 @@ where } } - buffer.reserve(std::mem::size_of::

() * array.len()); + buffer.reserve(size_of::

() * array.len()); buffer.extend( array .values() diff --git a/crates/polars-parquet/src/arrow/write/schema.rs b/crates/polars-parquet/src/arrow/write/schema.rs index 1403a7f4eeec..61c7f9fad218 100644 --- a/crates/polars-parquet/src/arrow/write/schema.rs +++ b/crates/polars-parquet/src/arrow/write/schema.rs @@ -290,7 +290,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { ArrowDataType::Struct(fields) => { if fields.is_empty() { polars_bail!(InvalidOperation: - "Parquet does not support writing empty structs".to_string(), + "Unable to write struct type with no child field to Parquet. Consider adding a dummy child field.".to_string(), ) } // recursively convert children to types/nodes diff --git a/crates/polars-parquet/src/parquet/bloom_filter/read.rs b/crates/polars-parquet/src/parquet/bloom_filter/read.rs index deda00b36272..fe4ee718cb7e 100644 --- a/crates/polars-parquet/src/parquet/bloom_filter/read.rs +++ b/crates/polars-parquet/src/parquet/bloom_filter/read.rs @@ -1,7 +1,7 @@ use std::io::{Read, Seek, SeekFrom}; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::{ +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::{ BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHeader, SplitBlockAlgorithm, Uncompressed, }; diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs index 41bfb5f557bf..f8e90f65e3ee 100644 --- a/crates/polars-parquet/src/parquet/compression.rs +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -245,7 +245,7 @@ fn try_decompress_hadoop(input_buf: &[u8], output_buf: &mut [u8]) -> ParquetResu // The Hadoop Lz4Codec source code can be found here: // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc - const SIZE_U32: usize = std::mem::size_of::(); + const SIZE_U32: usize = size_of::(); const PREFIX_LEN: usize = SIZE_U32 * 2; let mut input_len = input_buf.len(); let mut input = input_buf; diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs index 6e37507d137f..b5ea9b815dc1 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs @@ -13,7 +13,7 @@ pub struct Decoder<'a, T: Unpackable> { _pd: std::marker::PhantomData, } -impl<'a, T: Unpackable> Default for Decoder<'a, T> { +impl Default for Decoder<'_, T> { fn default() -> Self { Self { packed: [].chunks(1), @@ -56,7 +56,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { num_bits: usize, length: usize, ) -> ParquetResult { - let block_size = std::mem::size_of::() * num_bits; + let block_size = size_of::() * num_bits; if packed.len() * 8 < length * num_bits { return Err(ParquetError::oos(format!( @@ -78,7 +78,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. pub fn try_new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { - let block_size = std::mem::size_of::() * num_bits; + let block_size = size_of::() * num_bits; if num_bits == 0 { return Err(ParquetError::oos("Bitpacking requires num_bits > 0")); @@ -114,7 +114,7 @@ pub struct ChunkedDecoder<'a, 'b, T: Unpackable> { pub(crate) decoder: &'b mut Decoder<'a, T>, } -impl<'a, 'b, T: Unpackable> Iterator for ChunkedDecoder<'a, 'b, T> { +impl Iterator for ChunkedDecoder<'_, '_, T> { type Item = T::Unpacked; #[inline] @@ -136,9 +136,9 @@ impl<'a, 'b, T: Unpackable> Iterator for ChunkedDecoder<'a, 'b, T> { } } -impl<'a, 'b, T: Unpackable> ExactSizeIterator for ChunkedDecoder<'a, 'b, T> {} +impl ExactSizeIterator for ChunkedDecoder<'_, '_, T> {} -impl<'a, 'b, T: Unpackable> ChunkedDecoder<'a, 'b, T> { +impl ChunkedDecoder<'_, '_, T> { /// Get and consume the remainder chunk if it exists pub fn remainder(&mut self) -> Option<(T::Unpacked, usize)> { let remainder_len = self.decoder.len() % T::Unpacked::LENGTH; @@ -181,7 +181,7 @@ impl<'a, T: Unpackable> Decoder<'a, T> { } pub fn take(&mut self) -> Self { - let block_size = std::mem::size_of::() * self.num_bits; + let block_size = size_of::() * self.num_bits; let packed = std::mem::replace(&mut self.packed, [].chunks(block_size)); let length = self.length; self.length = 0; @@ -262,7 +262,7 @@ mod tests { use super::super::tests::case1; use super::*; - impl<'a, T: Unpackable> Decoder<'a, T> { + impl Decoder<'_, T> { pub fn collect(self) -> Vec { let mut vec = Vec::new(); self.collect_into(&mut vec); diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs index 793fa6f111d7..1b383e9522f1 100644 --- a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs @@ -96,7 +96,7 @@ where converter: F, } -impl<'a, 'b, T, F> Iterator for DecoderIterator<'a, 'b, T, F> +impl Iterator for DecoderIterator<'_, '_, T, F> where F: Copy + Fn(&[u8]) -> T, { diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs index 1ef6a99a2128..555954ff8069 100644 --- a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs @@ -14,7 +14,7 @@ mod tests { let mut buffer = vec![]; encode(&data, &mut buffer); - let mut decoder = Decoder::try_new(&buffer, std::mem::size_of::())?; + let mut decoder = Decoder::try_new(&buffer, size_of::())?; let values = decoder .iter_converted(|bytes| f32::from_le_bytes(bytes.try_into().unwrap())) .collect::>(); @@ -30,7 +30,7 @@ mod tests { let mut buffer = vec![]; encode(&data, &mut buffer); - let mut decoder = Decoder::try_new(&buffer, std::mem::size_of::())?; + let mut decoder = Decoder::try_new(&buffer, size_of::())?; let values = decoder .iter_converted(|bytes| f64::from_le_bytes(bytes.try_into().unwrap())) .collect::>(); @@ -61,9 +61,9 @@ mod tests { } fn encode(data: &[T], buffer: &mut Vec) { - let element_size = std::mem::size_of::(); + let element_size = size_of::(); let num_elements = data.len(); - let total_length = std::mem::size_of_val(data); + let total_length = size_of_val(data); buffer.resize(total_length, 0); for (i, v) in data.iter().enumerate() { diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs index 03889e0aa5d3..deb95f1dd3a2 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs @@ -56,7 +56,7 @@ impl<'a> Decoder<'a> { mod tests { use super::*; - impl<'a> Iterator for Decoder<'a> { + impl Iterator for Decoder<'_> { type Item = ParquetResult>; fn next(&mut self) -> Option { diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs index f46f22f84adb..0d67dc935857 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs @@ -39,7 +39,7 @@ impl<'a> BitmapIter<'a> { } } -impl<'a> Iterator for BitmapIter<'a> { +impl Iterator for BitmapIter<'_> { type Item = bool; #[inline] diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs index 824638d253ad..95d53b2769e4 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/buffered.rs @@ -44,7 +44,7 @@ impl Iterator for BufferedRle { impl ExactSizeIterator for BufferedRle {} -impl<'a> Iterator for BufferedBitpacked<'a> { +impl Iterator for BufferedBitpacked<'_> { type Item = u32; fn next(&mut self) -> Option { @@ -74,9 +74,9 @@ impl<'a> Iterator for BufferedBitpacked<'a> { } } -impl<'a> ExactSizeIterator for BufferedBitpacked<'a> {} +impl ExactSizeIterator for BufferedBitpacked<'_> {} -impl<'a> Iterator for HybridRleBuffered<'a> { +impl Iterator for HybridRleBuffered<'_> { type Item = u32; fn next(&mut self) -> Option { @@ -94,9 +94,9 @@ impl<'a> Iterator for HybridRleBuffered<'a> { } } -impl<'a> ExactSizeIterator for HybridRleBuffered<'a> {} +impl ExactSizeIterator for HybridRleBuffered<'_> {} -impl<'a> BufferedBitpacked<'a> { +impl BufferedBitpacked<'_> { fn gather_limited_into>( &mut self, target: &mut G::Target, @@ -212,7 +212,7 @@ impl BufferedRle { } } -impl<'a> HybridRleBuffered<'a> { +impl HybridRleBuffered<'_> { pub fn gather_limited_into>( &mut self, target: &mut G::Target, diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs index 1548f6e50a02..c66ef5873439 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/gatherer.rs @@ -432,7 +432,7 @@ impl Translator for UnitTranslator { /// [`HybridRleDecoder`]: super::HybridRleDecoder pub struct DictionaryTranslator<'a, T>(pub &'a [T]); -impl<'a, T: Copy> Translator for DictionaryTranslator<'a, T> { +impl Translator for DictionaryTranslator<'_, T> { fn translate(&self, value: u32) -> ParquetResult { self.0 .get(value as usize) diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index c71b6455ddf4..f721274d00cd 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -133,7 +133,7 @@ impl<'a> HybridRleDecoder<'a> { if run_length == 0 { 0 } else { - let mut bytes = [0u8; std::mem::size_of::()]; + let mut bytes = [0u8; size_of::()]; pack.iter().zip(bytes.iter_mut()).for_each(|(src, dst)| { *dst = *src; }); @@ -380,7 +380,7 @@ impl<'a> HybridRleDecoder<'a> { if run_length <= n { run_length } else { - let mut bytes = [0u8; std::mem::size_of::()]; + let mut bytes = [0u8; size_of::()]; pack.iter().zip(bytes.iter_mut()).for_each(|(src, dst)| { *dst = *src; }); diff --git a/crates/polars-parquet/src/parquet/error.rs b/crates/polars-parquet/src/parquet/error.rs index fdd7f8fcadfd..3d00cf3c647f 100644 --- a/crates/polars-parquet/src/parquet/error.rs +++ b/crates/polars-parquet/src/parquet/error.rs @@ -94,8 +94,8 @@ impl From for ParquetError { } } -impl From for ParquetError { - fn from(e: parquet_format_safe::thrift::Error) -> ParquetError { +impl From for ParquetError { + fn from(e: polars_parquet_format::thrift::Error) -> ParquetError { ParquetError::OutOfSpec(format!("Invalid thrift: {}", e)) } } diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs index 30a606d6108a..dba897a8eeea 100644 --- a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::{ColumnChunk, ColumnMetaData, Encoding}; +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Encoding}; use super::column_descriptor::ColumnDescriptor; use crate::parquet::compression::Compression; @@ -10,7 +10,7 @@ use crate::parquet::statistics::Statistics; mod serde_types { pub use std::io::Cursor; - pub use parquet_format_safe::thrift::protocol::{ + pub use polars_parquet_format::thrift::protocol::{ TCompactInputProtocol, TCompactOutputProtocol, }; pub use serde::de::Error as DeserializeError; diff --git a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs index 2705c2a7b70d..ed14a1e130d6 100644 --- a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::ColumnOrder as TColumnOrder; +use polars_parquet_format::ColumnOrder as TColumnOrder; use super::column_order::ColumnOrder; use super::schema_descriptor::SchemaDescriptor; @@ -8,7 +8,7 @@ use crate::parquet::metadata::get_sort_order; pub use crate::parquet::thrift_format::KeyValue; /// Metadata for a Parquet file. -// This is almost equal to [`parquet_format_safe::FileMetaData`] but contains the descriptors, +// This is almost equal to [`polars_parquet_format::FileMetaData`] but contains the descriptors, // which are crucial to deserialize pages. #[derive(Debug, Clone)] pub struct FileMetadata { @@ -65,7 +65,7 @@ impl FileMetadata { /// Deserializes [`crate::parquet::thrift_format::FileMetadata`] into this struct pub fn try_from_thrift( - metadata: parquet_format_safe::FileMetaData, + metadata: polars_parquet_format::FileMetaData, ) -> Result { let schema_descr = SchemaDescriptor::try_from_thrift(&metadata.schema)?; diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs index 9cca27553415..b02983a760ed 100644 --- a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use hashbrown::hash_map::RawEntryMut; -use parquet_format_safe::RowGroup; +use polars_parquet_format::{RowGroup, SortingColumn}; use polars_utils::aliases::{InitHashMaps, PlHashMap}; use polars_utils::idx_vec::UnitVec; use polars_utils::pl_str::PlSmallStr; @@ -41,6 +41,7 @@ pub struct RowGroupMetadata { num_rows: usize, total_byte_size: usize, full_byte_range: core::ops::Range, + sorting_columns: Option>, } impl RowGroupMetadata { @@ -59,6 +60,11 @@ impl RowGroupMetadata { .map(|x| x.iter().map(|&x| &self.columns[x])) } + /// Fetch all columns under this root name if it exists. + pub fn columns_idxs_under_root_iter<'a>(&'a self, root_name: &str) -> Option<&'a [usize]> { + self.column_lookup.get(root_name).map(|x| x.as_slice()) + } + /// Number of rows in this row group. pub fn num_rows(&self) -> usize { self.num_rows @@ -85,6 +91,10 @@ impl RowGroupMetadata { self.columns.iter().map(|x| x.byte_range()) } + pub fn sorting_columns(&self) -> Option<&[SortingColumn]> { + self.sorting_columns.as_deref() + } + /// Method to convert from Thrift. pub(crate) fn try_from_thrift( schema_descr: &SchemaDescriptor, @@ -106,6 +116,8 @@ impl RowGroupMetadata { 0..0 }; + let sorting_columns = rg.sorting_columns.clone(); + let columns = rg .columns .into_iter() @@ -131,6 +143,7 @@ impl RowGroupMetadata { num_rows, total_byte_size, full_byte_range, + sorting_columns, }) } } diff --git a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs index 7c29f983ee1d..c40fcdd1309b 100644 --- a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs +++ b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::SchemaElement; +use polars_parquet_format::SchemaElement; use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/metadata/sort.rs b/crates/polars-parquet/src/parquet/metadata/sort.rs index 93aac06605b6..d75d77134103 100644 --- a/crates/polars-parquet/src/parquet/metadata/sort.rs +++ b/crates/polars-parquet/src/parquet/metadata/sort.rs @@ -56,6 +56,7 @@ fn get_logical_sort_order(logical_type: &PrimitiveLogicalType) -> SortOrder { Timestamp { .. } => SortOrder::Signed, Unknown => SortOrder::Undefined, Uuid => SortOrder::Unsigned, + Float16 => SortOrder::Unsigned, } } diff --git a/crates/polars-parquet/src/parquet/mod.rs b/crates/polars-parquet/src/parquet/mod.rs index ea6b5b2c8357..1926e641fd04 100644 --- a/crates/polars-parquet/src/parquet/mod.rs +++ b/crates/polars-parquet/src/parquet/mod.rs @@ -15,7 +15,7 @@ pub mod write; use std::ops::Deref; -use parquet_format_safe as thrift_format; +use polars_parquet_format as thrift_format; use polars_utils::mmap::MemSlice; pub use streaming_decompression::{fallible_streaming_iterator, FallibleStreamingIterator}; diff --git a/crates/polars-parquet/src/parquet/parquet_bridge.rs b/crates/polars-parquet/src/parquet/parquet_bridge.rs index 21261f7ca011..523ea0e3e12e 100644 --- a/crates/polars-parquet/src/parquet/parquet_bridge.rs +++ b/crates/polars-parquet/src/parquet/parquet_bridge.rs @@ -497,6 +497,7 @@ pub enum PrimitiveLogicalType { Json, Bson, Uuid, + Float16, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -575,6 +576,7 @@ impl TryFrom for PrimitiveLogicalType { ParquetLogicalType::JSON(_) => PrimitiveLogicalType::Json, ParquetLogicalType::BSON(_) => PrimitiveLogicalType::Bson, ParquetLogicalType::UUID(_) => PrimitiveLogicalType::Uuid, + ParquetLogicalType::FLOAT16(_) => PrimitiveLogicalType::Float16, _ => return Err(ParquetError::oos("LogicalType value out of range")), }) } @@ -629,6 +631,7 @@ impl From for ParquetLogicalType { PrimitiveLogicalType::Json => ParquetLogicalType::JSON(Default::default()), PrimitiveLogicalType::Bson => ParquetLogicalType::BSON(Default::default()), PrimitiveLogicalType::Uuid => ParquetLogicalType::UUID(Default::default()), + PrimitiveLogicalType::Float16 => ParquetLogicalType::FLOAT16(Default::default()), } } } diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs index a79989c39e26..1bd457474e06 100644 --- a/crates/polars-parquet/src/parquet/read/compression.rs +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::DataPageHeaderV2; +use polars_parquet_format::DataPageHeaderV2; use super::PageReader; use crate::parquet::compression::{self, Compression}; diff --git a/crates/polars-parquet/src/parquet/read/metadata.rs b/crates/polars-parquet/src/parquet/read/metadata.rs index e14a2a60e997..a260fe71ff06 100644 --- a/crates/polars-parquet/src/parquet/read/metadata.rs +++ b/crates/polars-parquet/src/parquet/read/metadata.rs @@ -1,8 +1,8 @@ use std::cmp::min; use std::io::{Read, Seek, SeekFrom}; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::FileMetaData as TFileMetadata; +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::FileMetaData as TFileMetadata; use super::super::metadata::FileMetadata; use super::super::{DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, HEADER_SIZE, PARQUET_MAGIC}; diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index cd23af0499d7..7dfa2e144d8d 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -1,7 +1,7 @@ use std::io::Seek; use std::sync::OnceLock; -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; use polars_utils::mmap::{MemReader, MemSlice}; use super::PageIterator; @@ -13,6 +13,7 @@ use crate::parquet::page::{ ParquetPageHeader, }; use crate::parquet::CowBuffer; +use crate::write::Encoding; /// This meta is a small part of [`ColumnChunkMetadata`]. #[derive(Debug, Clone, PartialEq, Eq)] @@ -96,7 +97,7 @@ impl PageReader { Self::new_with_page_meta(reader, column.into(), scratch, max_page_size) } - /// Create a a new [`PageReader`] with [`PageMetaData`]. + /// Create a new [`PageReader`] with [`PageMetaData`]. /// /// It assumes that the reader has been `sought` (`seek`) to the beginning of `column`. pub fn new_with_page_meta( @@ -251,7 +252,10 @@ pub(super) fn finish_page( })?; if do_verbose { - println!("DictPage ( )"); + eprintln!( + "Parquet DictPage ( num_values: {}, datatype: {:?} )", + dict_header.num_values, descriptor.primitive_type + ); } let is_sorted = dict_header.is_sorted.unwrap_or(false); @@ -275,9 +279,11 @@ pub(super) fn finish_page( })?; if do_verbose { - println!( - "DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + eprintln!( + "Parquet DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } @@ -298,8 +304,10 @@ pub(super) fn finish_page( if do_verbose { println!( - "DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + "Parquet DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } diff --git a/crates/polars-parquet/src/parquet/read/page/stream.rs b/crates/polars-parquet/src/parquet/read/page/stream.rs index fbd36b3ccfe1..7145689493fc 100644 --- a/crates/polars-parquet/src/parquet/read/page/stream.rs +++ b/crates/polars-parquet/src/parquet/read/page/stream.rs @@ -2,7 +2,7 @@ use std::io::SeekFrom; use async_stream::try_stream; use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream}; -use parquet_format_safe::thrift::protocol::TCompactInputStreamProtocol; +use polars_parquet_format::thrift::protocol::TCompactInputStreamProtocol; use polars_utils::mmap::MemSlice; use super::reader::{finish_page, PageMetaData}; diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs index d4f2c692e95d..ccf293e48b08 100644 --- a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -42,7 +42,7 @@ //! println!("{:?}", schema); //! ``` -use parquet_format_safe::Type; +use polars_parquet_format::Type; use polars_utils::pl_str::PlSmallStr; use types::PrimitiveLogicalType; @@ -303,7 +303,7 @@ fn parse_timeunit( }) } -impl<'a> Parser<'a> { +impl Parser<'_> { // Entry function to parse message type, uses internal tokenizer. fn parse_message_type(&mut self) -> ParquetResult { // Check that message type starts with "message". diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs index b0bbe20999bc..7a874fb59e46 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::SchemaElement; +use polars_parquet_format::SchemaElement; use polars_utils::pl_str::PlSmallStr; use super::super::types::ParquetType; diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs index 3aef1fe792fa..db372b733593 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::{ConvertedType, SchemaElement}; +use polars_parquet_format::{ConvertedType, SchemaElement}; use super::super::types::ParquetType; use crate::parquet::schema::types::PrimitiveType; diff --git a/crates/polars-parquet/src/parquet/schema/types/converted_type.rs b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs index 8432167fcd3b..91b4b5ac78f9 100644 --- a/crates/polars-parquet/src/parquet/schema/types/converted_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::ConvertedType; +use polars_parquet_format::ConvertedType; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/schema/types/physical_type.rs b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs index 01595134c6b3..ed7242adac71 100644 --- a/crates/polars-parquet/src/parquet/schema/types/physical_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Type; +use polars_parquet_format::Type; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/schema/types/spec.rs b/crates/polars-parquet/src/parquet/schema/types/spec.rs index f21cdbb9d611..f18f7ae9015b 100644 --- a/crates/polars-parquet/src/parquet/schema/types/spec.rs +++ b/crates/polars-parquet/src/parquet/schema/types/spec.rs @@ -170,6 +170,7 @@ pub fn check_logical_invariants( (String | Json | Bson, PhysicalType::ByteArray) => {}, // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#uuid (Uuid, PhysicalType::FixedLenByteArray(16)) => {}, + (Float16, PhysicalType::FixedLenByteArray(2)) => {}, (a, b) => { return Err(ParquetError::oos(format!( "Cannot annotate {:?} from {:?} fields", diff --git a/crates/polars-parquet/src/parquet/statistics/binary.rs b/crates/polars-parquet/src/parquet/statistics/binary.rs index e9506c375a71..7f1dabf21ec8 100644 --- a/crates/polars-parquet/src/parquet/statistics/binary.rs +++ b/crates/polars-parquet/src/parquet/statistics/binary.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::ParquetResult; use crate::parquet::schema::types::PrimitiveType; @@ -32,8 +32,10 @@ impl BinaryStatistics { distinct_count: self.distinct_count, max_value: self.max_value.clone(), min_value: self.min_value.clone(), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/boolean.rs b/crates/polars-parquet/src/parquet/statistics/boolean.rs index 607897bdddf0..55e478d5b957 100644 --- a/crates/polars-parquet/src/parquet/statistics/boolean.rs +++ b/crates/polars-parquet/src/parquet/statistics/boolean.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; @@ -13,14 +13,14 @@ pub struct BooleanStatistics { impl BooleanStatistics { pub fn deserialize(v: &ParquetStatistics) -> ParquetResult { if let Some(ref v) = v.max_value { - if v.len() != std::mem::size_of::() { + if v.len() != size_of::() { return Err(ParquetError::oos( "The max_value of statistics MUST be plain encoded", )); } }; if let Some(ref v) = v.min_value { - if v.len() != std::mem::size_of::() { + if v.len() != size_of::() { return Err(ParquetError::oos( "The min_value of statistics MUST be plain encoded", )); @@ -49,8 +49,10 @@ impl BooleanStatistics { distinct_count: self.distinct_count, max_value: self.max_value.map(|x| vec![x as u8]), min_value: self.min_value.map(|x| vec![x as u8]), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs index 8de2aef0a508..87642246907d 100644 --- a/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs +++ b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::schema::types::PrimitiveType; @@ -54,8 +54,10 @@ impl FixedLenStatistics { distinct_count: self.distinct_count, max_value: self.max_value.clone(), min_value: self.min_value.clone(), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/statistics/mod.rs b/crates/polars-parquet/src/parquet/statistics/mod.rs index 7501df2a3886..03335c27817b 100644 --- a/crates/polars-parquet/src/parquet/statistics/mod.rs +++ b/crates/polars-parquet/src/parquet/statistics/mod.rs @@ -41,6 +41,34 @@ impl Statistics { } } + pub fn clear_min(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.min_value.take(), + S::Boolean(s) => _ = s.min_value.take(), + S::FixedLen(s) => _ = s.min_value.take(), + S::Int32(s) => _ = s.min_value.take(), + S::Int64(s) => _ = s.min_value.take(), + S::Int96(s) => _ = s.min_value.take(), + S::Float(s) => _ = s.min_value.take(), + S::Double(s) => _ = s.min_value.take(), + }; + } + + pub fn clear_max(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.max_value.take(), + S::Boolean(s) => _ = s.max_value.take(), + S::FixedLen(s) => _ = s.max_value.take(), + S::Int32(s) => _ = s.max_value.take(), + S::Int64(s) => _ = s.max_value.take(), + S::Int96(s) => _ = s.max_value.take(), + S::Float(s) => _ = s.max_value.take(), + S::Double(s) => _ = s.max_value.take(), + }; + } + /// Deserializes a raw parquet statistics into [`Statistics`]. /// # Error /// This function errors if it is not possible to read the statistics to the @@ -51,7 +79,7 @@ impl Statistics { primitive_type: PrimitiveType, ) -> ParquetResult { use {PhysicalType as T, PrimitiveStatistics as PrimStat}; - Ok(match primitive_type.physical_type { + let mut stats: Self = match primitive_type.physical_type { T::ByteArray => BinaryStatistics::deserialize(statistics, primitive_type)?.into(), T::Boolean => BooleanStatistics::deserialize(statistics)?.into(), T::Int32 => PrimStat::::deserialize(statistics, primitive_type)?.into(), @@ -62,7 +90,16 @@ impl Statistics { T::FixedLenByteArray(size) => { FixedLenStatistics::deserialize(statistics, size, primitive_type)?.into() }, - }) + }; + + if statistics.is_min_value_exact.is_some_and(|v| !v) { + stats.clear_min(); + } + if statistics.is_max_value_exact.is_some_and(|v| !v) { + stats.clear_max(); + } + + Ok(stats) } } diff --git a/crates/polars-parquet/src/parquet/statistics/primitive.rs b/crates/polars-parquet/src/parquet/statistics/primitive.rs index ed5ae71515b0..7bd8d227faa2 100644 --- a/crates/polars-parquet/src/parquet/statistics/primitive.rs +++ b/crates/polars-parquet/src/parquet/statistics/primitive.rs @@ -1,4 +1,4 @@ -use parquet_format_safe::Statistics as ParquetStatistics; +use polars_parquet_format::Statistics as ParquetStatistics; use crate::parquet::error::{ParquetError, ParquetResult}; use crate::parquet::schema::types::PrimitiveType; @@ -20,7 +20,7 @@ impl PrimitiveStatistics { ) -> ParquetResult { if v.max_value .as_ref() - .is_some_and(|v| v.len() != std::mem::size_of::()) + .is_some_and(|v| v.len() != size_of::()) { return Err(ParquetError::oos( "The max_value of statistics MUST be plain encoded", @@ -28,7 +28,7 @@ impl PrimitiveStatistics { }; if v.min_value .as_ref() - .is_some_and(|v| v.len() != std::mem::size_of::()) + .is_some_and(|v| v.len() != size_of::()) { return Err(ParquetError::oos( "The min_value of statistics MUST be plain encoded", @@ -50,8 +50,10 @@ impl PrimitiveStatistics { distinct_count: self.distinct_count, max_value: self.max_value.map(|x| x.to_le_bytes().as_ref().to_vec()), min_value: self.min_value.map(|x| x.to_le_bytes().as_ref().to_vec()), - min: None, max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, } } } diff --git a/crates/polars-parquet/src/parquet/types.rs b/crates/polars-parquet/src/parquet/types.rs index 7591f6ba0bd7..1dd65d0ff622 100644 --- a/crates/polars-parquet/src/parquet/types.rs +++ b/crates/polars-parquet/src/parquet/types.rs @@ -22,7 +22,7 @@ pub trait NativeType: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { macro_rules! native { ($type:ty, $physical_type:expr) => { impl NativeType for $type { - type Bytes = [u8; std::mem::size_of::()]; + type Bytes = [u8; size_of::()]; #[inline] fn to_le_bytes(&self) -> Self::Bytes { Self::to_le_bytes(*self) @@ -51,7 +51,7 @@ native!(f64, PhysicalType::Double); impl NativeType for [u32; 3] { const TYPE: PhysicalType = PhysicalType::Int96; - type Bytes = [u8; std::mem::size_of::()]; + type Bytes = [u8; size_of::()]; #[inline] fn to_le_bytes(&self) -> Self::Bytes { let mut bytes = [0; 12]; @@ -137,7 +137,7 @@ pub fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { #[inline] pub fn decode(chunk: &[u8]) -> T { - assert!(chunk.len() >= std::mem::size_of::<::Bytes>()); + assert!(chunk.len() >= size_of::<::Bytes>()); unsafe { decode_unchecked(chunk) } } diff --git a/crates/polars-parquet/src/parquet/write/column_chunk.rs b/crates/polars-parquet/src/parquet/write/column_chunk.rs index 6ae51a191dc5..8728f289eb65 100644 --- a/crates/polars-parquet/src/parquet/write/column_chunk.rs +++ b/crates/polars-parquet/src/parquet/write/column_chunk.rs @@ -2,10 +2,10 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::{ColumnChunk, ColumnMetaData, Type}; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Type}; use polars_utils::aliases::PlHashSet; #[cfg(feature = "async")] @@ -195,6 +195,8 @@ fn build_column_chunk( statistics, encoding_stats: None, bloom_filter_offset: None, + bloom_filter_length: None, + size_statistics: None, }; Ok(ColumnChunk { diff --git a/crates/polars-parquet/src/parquet/write/dyn_iter.rs b/crates/polars-parquet/src/parquet/write/dyn_iter.rs index f47710b56b22..a232c06375e8 100644 --- a/crates/polars-parquet/src/parquet/write/dyn_iter.rs +++ b/crates/polars-parquet/src/parquet/write/dyn_iter.rs @@ -7,7 +7,7 @@ pub struct DynIter<'a, V> { iter: Box + 'a + Send + Sync>, } -impl<'a, V> Iterator for DynIter<'a, V> { +impl Iterator for DynIter<'_, V> { type Item = V; fn next(&mut self) -> Option { self.iter.next() @@ -35,7 +35,7 @@ pub struct DynStreamingIterator<'a, V, E> { iter: Box + 'a + Send + Sync>, } -impl<'a, V, E> FallibleStreamingIterator for DynStreamingIterator<'a, V, E> { +impl FallibleStreamingIterator for DynStreamingIterator<'_, V, E> { type Item = V; type Error = E; diff --git a/crates/polars-parquet/src/parquet/write/file.rs b/crates/polars-parquet/src/parquet/write/file.rs index 8dd3212bb76a..d46f85dd3138 100644 --- a/crates/polars-parquet/src/parquet/write/file.rs +++ b/crates/polars-parquet/src/parquet/write/file.rs @@ -1,7 +1,7 @@ use std::io::Write; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; -use parquet_format_safe::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::RowGroup; use super::indexes::{write_column_index, write_offset_index}; use super::page::PageWriteSpec; @@ -39,7 +39,7 @@ pub(super) fn end_file( Ok(metadata_len as u64 + FOOTER_SIZE) } -fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec { +fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec { // We only include ColumnOrder for leaf nodes. // Currently only supported ColumnOrder is TypeDefinedOrder so we set this // for all leaf nodes. @@ -47,7 +47,9 @@ fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec ParquetResult ParquetResult>>()?; - Ok(OffsetIndex { page_locations }) + Ok(OffsetIndex { + page_locations, + unencoded_byte_array_data_bytes: None, + }) } diff --git a/crates/polars-parquet/src/parquet/write/indexes/write.rs b/crates/polars-parquet/src/parquet/write/indexes/write.rs index 7c82b1dcc9ae..73325654e518 100644 --- a/crates/polars-parquet/src/parquet/write/indexes/write.rs +++ b/crates/polars-parquet/src/parquet/write/indexes/write.rs @@ -2,9 +2,9 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; use super::serialize::{serialize_column_index, serialize_offset_index}; use crate::parquet::error::ParquetResult; diff --git a/crates/polars-parquet/src/parquet/write/page.rs b/crates/polars-parquet/src/parquet/write/page.rs index f9e527d5a9db..8fb65c3daf12 100644 --- a/crates/polars-parquet/src/parquet/write/page.rs +++ b/crates/polars-parquet/src/parquet/write/page.rs @@ -2,10 +2,10 @@ use std::io::Write; #[cfg(feature = "async")] use futures::{AsyncWrite, AsyncWriteExt}; -use parquet_format_safe::thrift::protocol::TCompactOutputProtocol; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; #[cfg(feature = "async")] -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::{DictionaryPageHeader, Encoding, PageType}; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{DictionaryPageHeader, Encoding, PageType}; use crate::parquet::compression::Compression; use crate::parquet::error::{ParquetError, ParquetResult}; diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs index 43404dc32a89..dfca3d27f948 100644 --- a/crates/polars-parquet/src/parquet/write/row_group.rs +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -2,7 +2,7 @@ use std::io::Write; #[cfg(feature = "async")] use futures::AsyncWrite; -use parquet_format_safe::{ColumnChunk, RowGroup}; +use polars_parquet_format::{ColumnChunk, RowGroup}; use super::column_chunk::write_column_chunk; #[cfg(feature = "async")] diff --git a/crates/polars-parquet/src/parquet/write/stream.rs b/crates/polars-parquet/src/parquet/write/stream.rs index eca712db65dc..05c50e6e3a2c 100644 --- a/crates/polars-parquet/src/parquet/write/stream.rs +++ b/crates/polars-parquet/src/parquet/write/stream.rs @@ -1,8 +1,8 @@ use std::io::Write; use futures::{AsyncWrite, AsyncWriteExt}; -use parquet_format_safe::thrift::protocol::TCompactOutputStreamProtocol; -use parquet_format_safe::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::RowGroup; use super::row_group::write_row_group_async; use super::{RowGroupIterColumns, WriteOptions}; @@ -20,7 +20,7 @@ async fn start_file(writer: &mut W) -> ParquetResult async fn end_file( mut writer: &mut W, - metadata: parquet_format_safe::FileMetaData, + metadata: polars_parquet_format::FileMetaData, ) -> ParquetResult { // Write file metadata let mut protocol = TCompactOutputStreamProtocol::new(&mut writer); @@ -169,7 +169,7 @@ impl FileStreamer { } } - let metadata = parquet_format_safe::FileMetaData::new( + let metadata = polars_parquet_format::FileMetaData::new( self.options.version.into(), self.schema.clone().into_thrift(), num_rows, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs index e6711b7a339b..14ce4836096c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -208,7 +208,7 @@ where let agg_fn = match logical_dtype.to_physical() { // Boolean is aggregated as the IDX type. DataType::Boolean => { - if std::mem::size_of::() == 4 { + if size_of::() == 4 { AggregateFunction::SumU32(SumAgg::::new()) } else { AggregateFunction::SumU64(SumAgg::::new()) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs index 77a939c64290..280cd236afa6 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs @@ -8,7 +8,7 @@ use crate::pipeline::{morsels_per_sink, FORCE_OOC}; pub(super) struct OocState { // OOC // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. mem_track: MemTracker, // sort in-memory or out-of-core pub(super) ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs index 1f79e20bcdab..f2c664087daf 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs @@ -13,7 +13,7 @@ use crate::pipeline::morsels_per_sink; pub(super) struct OocState { // OOC // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. _mem_track: MemTracker, // sort in-memory or out-of-core pub(super) ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index d6014c344978..77466578d3e9 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -111,7 +111,7 @@ impl Operator for CrossJoinProbe { _context: &PExecutionContext, chunk: &DataChunk, ) -> PolarsResult { - // Expected output is size**2, so this needs to be a a small number. + // Expected output is size**2, so this needs to be a small number. // However, if one of the DataFrames is much smaller than 250, we want // to take rather more from the other DataFrame so we don't end up with // overly small chunks. diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 43589c9783a1..49d51cc2e2fb 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -20,7 +20,7 @@ pub struct SortSink { schema: SchemaRef, chunks: Vec, // Stores available memory in the system at the start of this sink. - // and stores the memory used by this this sink. + // and stores the memory used by this sink. mem_track: MemTracker, // sort in-memory or out-of-core ooc: bool, diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 9cd85697ebaa..7c0a35db38a1 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,8 +1,8 @@ use std::any::Any; use arrow::array::BinaryArray; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_compat_array; use polars_core::prelude::sort::_broadcast_bools; -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_row::decode::decode_rows_from_binary; diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index 38484f9c7255..900be25256b4 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -216,18 +216,11 @@ impl Source for CsvSource { }; for data_chunk in &mut out { - // The batched reader creates the column containing all nulls because the schema it - // gets passed contains the column. - // + let n = data_chunk.data.height(); // SAFETY: Columns are only replaced with columns // 1. of the same name, and // 2. of the same length. - for s in unsafe { data_chunk.data.get_columns_mut() } { - if s.name() == ca.name() { - *s = ca.slice(0, s.len()).into_column(); - break; - } - } + unsafe { data_chunk.data.get_columns_mut() }.push(ca.slice(0, n).into_column()) } } diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs index c1f63019a611..d237510b4636 100644 --- a/crates/polars-pipe/src/operators/chunks.rs +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -39,7 +39,7 @@ pub(crate) fn chunks_to_df_unchecked(chunks: Vec) -> DataFrame { /// /// The benefit of having a series of `DataFrame` that are e.g. 4MB each that /// are then made contiguous is that you're not using a lot of memory (an extra -/// 4MB), but you're still doing better than if you had a series of of 2KB +/// 4MB), but you're still doing better than if you had a series of 2KB /// `DataFrame`s. /// /// Changing the `DataFrame` into contiguous chunks is the caller's diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 23df59fa59f7..b0b19aa26708 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -109,7 +109,7 @@ where FileScan::Csv { options, .. } => { let src = sources::CsvSource::new( sources, - file_info.schema, + file_info.reader_schema.clone().unwrap().unwrap_right(), options, file_options, verbose, diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 499bb115396a..9c5a3fe32913 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -34,6 +34,7 @@ either = { workspace = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } memmap = { workspace = true } +num-traits = { workspace = true } once_cell = { workspace = true } percent-encoding = { workspace = true } pyo3 = { workspace = true, optional = true } @@ -50,7 +51,7 @@ version_check = { workspace = true } [features] # debugging utility debugging = [] -python = ["dep:pyo3", "ciborium"] +python = ["dep:pyo3", "ciborium", "polars-utils/python"] serde = [ "dep:serde", "polars-core/serde-lazy", @@ -185,7 +186,7 @@ month_start = ["polars-time/month_start"] month_end = ["polars-time/month_end"] offset_by = ["polars-time/offset_by"] -bigidx = ["polars-core/bigidx"] +bigidx = ["polars-core/bigidx", "polars-utils/bigidx"] polars_cloud = ["serde", "ciborium"] ir_serde = ["serde", "polars-utils/ir_serde"] diff --git a/crates/polars-plan/src/client/check.rs b/crates/polars-plan/src/client/check.rs index f76f508643f0..97b1ed23da2e 100644 --- a/crates/polars-plan/src/client/check.rs +++ b/crates/polars-plan/src/client/check.rs @@ -6,6 +6,9 @@ use crate::plans::{DslPlan, FileScan, ScanSources}; /// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud. pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { + if std::env::var("POLARS_SKIP_CLIENT_CHECK").as_deref() == Ok("1") { + return Ok(()); + } for plan_node in dsl.into_iter() { match plan_node { #[cfg(feature = "python")] diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 33eb20e86da6..32fa45528d3a 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -33,7 +33,7 @@ pub enum AggExpr { Quantile { expr: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Arc), AggGroups(Arc), diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 2c0acfeeb13b..483dafcc83f1 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq> { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let udf = python_udf::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(udf)) @@ -399,7 +399,7 @@ impl<'a> Deserialize<'a> for GetOutput { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let get_output = python_udf::PythonGetOutput::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(get_output)) diff --git a/crates/polars-plan/src/dsl/function_expr/bitwise.rs b/crates/polars-plan/src/dsl/function_expr/bitwise.rs index 2d4dd779cff0..1f0be9247993 100644 --- a/crates/polars-plan/src/dsl/function_expr/bitwise.rs +++ b/crates/polars-plan/src/dsl/function_expr/bitwise.rs @@ -2,6 +2,7 @@ use std::fmt; use std::sync::Arc; use polars_core::prelude::*; +use strum_macros::IntoStaticStr; use super::{ColumnsUdf, SpecialEq}; use crate::dsl::FieldsMapper; @@ -21,7 +22,8 @@ pub enum BitwiseFunction { } #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] pub enum BitwiseAggFunction { And, Or, diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 9225dfa4cae0..ddf8fb1fff20 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -4,7 +4,7 @@ use polars_ops::chunked_array::list::*; use super::*; use crate::{map, map_as_slice, wrap}; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ListFunction { Concat, @@ -56,6 +56,8 @@ pub enum ListFunction { Join(bool), #[cfg(feature = "dtype-array")] ToArray(usize), + #[cfg(feature = "list_to_struct")] + ToStruct(ListToStructArgs), } impl ListFunction { @@ -103,6 +105,8 @@ impl ListFunction { #[cfg(feature = "dtype-array")] ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), NUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), } } } @@ -174,6 +178,8 @@ impl Display for ListFunction { Join(_) => "join", #[cfg(feature = "dtype-array")] ToArray(_) => "to_array", + #[cfg(feature = "list_to_struct")] + ToStruct(_) => "to_struct", }; write!(f, "list.{name}") } @@ -235,6 +241,8 @@ impl From for SpecialEq> { #[cfg(feature = "dtype-array")] ToArray(width) => map!(to_array, width), NUnique => map!(n_unique), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => map!(to_struct, &args), } } } @@ -503,7 +511,7 @@ pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult let idx = &args[1]; let ca = ca.list()?; - if idx.len() == 1 && null_on_oob { + if idx.len() == 1 && idx.dtype().is_numeric() && null_on_oob { // fast path let idx = idx.get(0)?.try_extract::()?; let out = ca.lst_get(idx, null_on_oob).map(Column::from)?; @@ -650,6 +658,11 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult { s.cast(&array_dtype) } +#[cfg(feature = "list_to_struct")] +pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult { + Ok(s.list()?.to_struct(args)?.into_series().into()) +} + pub(super) fn n_unique(s: &Column) -> PolarsResult { Ok(s.list()?.lst_n_unique()?.into_column()) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 6cebaa301b85..17dc82e23e8c 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -954,6 +954,19 @@ impl From for SpecialEq> { Std(options) => map!(rolling::rolling_std, options.clone()), #[cfg(feature = "moment")] Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias), + #[cfg(feature = "cov")] + CorrCov { + rolling_options, + corr_cov_options, + is_corr, + } => { + map_as_slice!( + rolling::rolling_corr_cov, + rolling_options.clone(), + corr_cov_options, + is_corr + ) + }, } }, #[cfg(feature = "rolling_window_by")] diff --git a/crates/polars-plan/src/dsl/function_expr/pow.rs b/crates/polars-plan/src/dsl/function_expr/pow.rs index 44394e9ae10a..912aafc762d9 100644 --- a/crates/polars-plan/src/dsl/function_expr/pow.rs +++ b/crates/polars-plan/src/dsl/function_expr/pow.rs @@ -1,8 +1,8 @@ -use arrow::legacy::kernels::pow::pow as pow_kernel; use num::pow::Pow; +use num_traits::{One, Zero}; use polars_core::export::num; use polars_core::export::num::{Float, ToPrimitive}; -use polars_core::prelude::arity::unary_elementwise_values; +use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values}; use polars_core::with_match_physical_integer_type; use super::*; @@ -29,30 +29,27 @@ impl Display for PowFunction { fn pow_on_chunked_arrays( base: &ChunkedArray, exponent: &ChunkedArray, -) -> PolarsResult> +) -> ChunkedArray where T: PolarsNumericType, F: PolarsNumericType, T::Native: num::pow::Pow + ToPrimitive, - ChunkedArray: IntoColumn, { - if (base.len() == 1) && (exponent.len() != 1) { - let name = base.name(); - let base = base - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "base is null"))?; - - Ok(Some( - unary_elementwise_values(exponent, |exp| Pow::pow(base, exp)) - .into_column() - .with_name(name.clone()), - )) - } else { - Ok(Some( - polars_core::chunked_array::ops::arity::binary(base, exponent, pow_kernel) - .into_column(), - )) + if exponent.len() == 1 { + if let Some(e) = exponent.get(0) { + if e == F::Native::zero() { + return unary_elementwise_values(base, |_| T::Native::one()); + } + if e == F::Native::one() { + return base.clone(); + } + if e == F::Native::one() + F::Native::one() { + return base * base; + } + } } + + broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?))) } fn pow_on_floats( @@ -93,7 +90,7 @@ where }; Ok(Some(s)) } else { - pow_on_chunked_arrays(base, exponent) + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) } } @@ -133,7 +130,7 @@ where }; Ok(Some(s)) } else { - pow_on_chunked_arrays(base, exponent) + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) } } diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index a61264ce7aca..e220a7107435 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -223,7 +223,7 @@ pub(super) fn datetime_ranges( out.cast(&to_type).map(Column::from) } -impl<'a> FieldsMapper<'a> { +impl FieldsMapper<'_> { pub(super) fn map_to_datetime_range_dtype( &self, time_unit: Option<&TimeUnit>, diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index c108c92b571a..44af4f1b1510 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -1,6 +1,12 @@ +#[cfg(feature = "cov")] +use std::ops::BitAnd; + +use polars_core::utils::Container; use polars_time::chunkedarray::*; use super::*; +#[cfg(feature = "cov")] +use crate::dsl::pow::pow; #[derive(Clone, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -14,6 +20,13 @@ pub enum RollingFunction { Std(RollingOptionsFixedWindow), #[cfg(feature = "moment")] Skew(usize, bool), + #[cfg(feature = "cov")] + CorrCov { + rolling_options: RollingOptionsFixedWindow, + corr_cov_options: RollingCovOptions, + // Whether is Corr or Cov + is_corr: bool, + }, } impl Display for RollingFunction { @@ -30,6 +43,14 @@ impl Display for RollingFunction { Std(_) => "rolling_std", #[cfg(feature = "moment")] Skew(..) => "rolling_skew", + #[cfg(feature = "cov")] + CorrCov { is_corr, .. } => { + if *is_corr { + "rolling_corr" + } else { + "rolling_cov" + } + }, }; write!(f, "{name}") @@ -47,6 +68,10 @@ impl Hash for RollingFunction { window_size.hash(state); bias.hash(state) }, + #[cfg(feature = "cov")] + CorrCov { is_corr, .. } => { + is_corr.hash(state); + }, _ => {}, } } @@ -111,3 +136,100 @@ pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> Polars .rolling_skew(window_size, bias) .map(Column::from) } + +#[cfg(feature = "cov")] +fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series { + match dtype { + DataType::Float64 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f64) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + DataType::Float32 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f32) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "cov")] +pub(super) fn rolling_corr_cov( + s: &[Column], + rolling_options: RollingOptionsFixedWindow, + cov_options: RollingCovOptions, + is_corr: bool, +) -> PolarsResult { + let mut x = s[0].as_materialized_series().rechunk(); + let mut y = s[1].as_materialized_series().rechunk(); + + if !x.dtype().is_float() { + x = x.cast(&DataType::Float64)?; + } + if !y.dtype().is_float() { + y = y.cast(&DataType::Float64)?; + } + let dtype = x.dtype().clone(); + + let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?; + let rolling_options_count = RollingOptionsFixedWindow { + window_size: rolling_options.window_size, + min_periods: 0, + ..Default::default() + }; + + let count_x_y = if (x.null_count() + y.null_count()) > 0 { + // mask out nulls on both sides before compute mean/var + let valids = x.is_not_null().bitand(y.is_not_null()); + let valids_arr = valids.clone().downcast_into_array(); + let valids_bitmap = valids_arr.values(); + + unsafe { + let xarr = &mut x.chunks_mut()[0]; + *xarr = xarr.with_validity(Some(valids_bitmap.clone())); + let yarr = &mut y.chunks_mut()[0]; + *yarr = yarr.with_validity(Some(valids_bitmap.clone())); + x.compute_len(); + y.compute_len(); + } + valids + .cast(&dtype) + .unwrap() + .rolling_sum(rolling_options_count)? + } else { + det_count_x_y(rolling_options.window_size, x.len(), &dtype) + }; + + let mean_x = x.rolling_mean(rolling_options.clone())?; + let mean_y = y.rolling_mean(rolling_options.clone())?; + let ddof = Series::new( + PlSmallStr::EMPTY, + &[AnyValue::from(cov_options.ddof).cast(&dtype)], + ); + + let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap() + * (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap()) + .unwrap(); + + if is_corr { + let var_x = x.rolling_var(rolling_options.clone())?; + let var_y = y.rolling_var(rolling_options.clone())?; + + let base = (var_x * var_y).unwrap(); + let sc = Scalar::new( + base.dtype().clone(), + AnyValue::Float64(0.5).cast(&dtype).into_static(), + ); + let denominator = pow(&mut [base.into_column(), sc.into_column("".into())]) + .unwrap() + .unwrap() + .take_materialized_series(); + + Ok((numerator / denominator)?.into_column()) + } else { + Ok(numerator.into_column()) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index bc09cca94215..606ab81207c4 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -66,6 +66,8 @@ impl FunctionExpr { match rolling_func { Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(), Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), + #[cfg(feature = "cov")] + CorrCov {..} => mapper.map_to_float_dtype(), #[cfg(feature = "moment")] Skew(..) => mapper.map_to_float_dtype(), } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index ba06dc00e67c..039c995557be 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -9,7 +9,7 @@ use polars_core::utils::handle_casting_failures; #[cfg(feature = "dtype-struct")] use polars_utils::format_pl_smallstr; #[cfg(feature = "regex")] -use regex::{escape, Regex}; +use regex::{escape, NoExpand, Regex}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -130,6 +130,8 @@ pub enum StringFunction { ascii_case_insensitive: bool, overlapping: bool, }, + #[cfg(feature = "regex")] + EscapeRegex, } impl StringFunction { @@ -197,6 +199,8 @@ impl StringFunction { ReplaceMany { .. } => mapper.with_same_dtype(), #[cfg(feature = "find_many")] ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "regex")] + EscapeRegex => mapper.with_same_dtype(), } } } @@ -285,6 +289,8 @@ impl Display for StringFunction { ReplaceMany { .. } => "replace_many", #[cfg(feature = "find_many")] ExtractMany { .. } => "extract_many", + #[cfg(feature = "regex")] + EscapeRegex => "escape_regex", }; write!(f, "str.{s}") } @@ -400,6 +406,8 @@ impl From for SpecialEq> { } => { map_as_slice!(extract_many, ascii_case_insensitive, overlapping) }, + #[cfg(feature = "regex")] + EscapeRegex => map!(escape_regex), } } } @@ -836,20 +844,26 @@ fn replace_n<'a>( "replacement value length ({}) does not match string column length ({})", len_val, ca.len(), ); - let literal = literal || is_literal_pat(&pat); + let lit = is_literal_pat(&pat); + let literal_pat = literal || lit; - if literal { + if literal_pat { pat = escape(&pat) } let reg = Regex::new(&pat)?; - let lit = pat.chars().all(|c| !c.is_ascii_punctuation()); let f = |s: &'a str, val: &'a str| { if lit && (s.len() <= 32) { Cow::Owned(s.replacen(&pat, val, 1)) } else { - reg.replace(s, val) + // According to the docs for replace + // when literal = True then capture groups are ignored. + if literal { + reg.replace(s, NoExpand(val)) + } else { + reg.replace(s, val) + } } }; Ok(iter_and_replace(ca, val, f)) @@ -888,15 +902,25 @@ fn replace_all<'a>( "replacement value length ({}) does not match string column length ({})", len_val, ca.len(), ); - let literal = literal || is_literal_pat(&pat); - if literal { + let literal_pat = literal || is_literal_pat(&pat); + + if literal_pat { pat = escape(&pat) } let reg = Regex::new(&pat)?; - let f = |s: &'a str, val: &'a str| reg.replace_all(s, val); + let f = |s: &'a str, val: &'a str| { + // According to the docs for replace_all + // when literal = True then capture groups are ignored. + if literal { + reg.replace_all(s, NoExpand(val)) + } else { + reg.replace_all(s, val) + } + }; + Ok(iter_and_replace(ca, val, f)) }, _ => polars_bail!( @@ -1023,3 +1047,9 @@ pub(super) fn json_path_match(s: &[Column]) -> PolarsResult { let pat = s[1].str()?; Ok(ca.json_path_match(pat)?.into_column()) } + +#[cfg(feature = "regex")] +pub(super) fn escape_regex(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_escape_regex().into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/trigonometry.rs b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs index c0d83822aef9..5398f43f4323 100644 --- a/crates/polars-plan/src/dsl/function_expr/trigonometry.rs +++ b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs @@ -1,5 +1,5 @@ -use arrow::legacy::kernels::atan2::atan2 as atan2_kernel; use num::Float; +use polars_core::chunked_array::ops::arity::broadcast_binary_elementwise; use polars_core::export::num; use super::*; @@ -117,23 +117,9 @@ where .unpack_series_matching_type(x.as_materialized_series()) .unwrap(); - if x.len() == 1 { - let x_value = x - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "arctan2 x value is null"))?; - - Ok(Some(y.apply_values(|v| v.atan2(x_value)).into_column())) - } else if y.len() == 1 { - let y_value = y - .get(0) - .ok_or_else(|| polars_err!(ComputeError: "arctan2 y value is null"))?; - - Ok(Some(x.apply_values(|v| y_value.atan2(v)).into_column())) - } else { - Ok(Some( - polars_core::prelude::arity::binary(y, x, atan2_kernel).into_column(), - )) - } + Ok(Some( + broadcast_binary_elementwise(y, x, |yv, xv| Some(yv?.atan2(xv?))).into_column(), + )) } fn apply_trigonometric_function_to_float( diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index dd7521ad20a9..97e14f5df2f8 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -70,8 +70,8 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E } } -#[cfg(feature = "rolling_window")] -pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { +#[cfg(all(feature = "rolling_window", feature = "cov"))] +fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 let rolling_options = RollingOptionsFixedWindow { window_size: options.window_size as usize, @@ -79,59 +79,23 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone()); - let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options); - - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let ddof = options.ddof as f64; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - let numerator = (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))); - let denominator = (var_x * var_y).pow(lit(0.5)); + Expr::Function { + input: vec![x, y], + function: FunctionExpr::RollingExpr(RollingFunction::CorrCov { + rolling_options, + corr_cov_options: options, + is_corr, + }), + options: Default::default(), + } +} - numerator / denominator +#[cfg(all(feature = "rolling_window", feature = "cov"))] +pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { + dispatch_corr_cov(x, y, options, true) } -#[cfg(feature = "rolling_window")] +#[cfg(all(feature = "rolling_window", feature = "cov"))] pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { - // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700 - let rolling_options = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: options.min_periods as usize, - ..Default::default() - }; - - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options); - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - - let ddof = options.ddof as f64; - - (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))) + dispatch_corr_cov(x, y, options, false) } diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index e1ef64ee02ec..4d0e4c105014 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -33,8 +33,8 @@ pub fn median(name: &str) -> Expr { } /// Find a specific quantile of all the values in the column named `name`. -pub fn quantile(name: &str, quantile: Expr, interpol: QuantileInterpolOptions) -> Expr { - col(name).quantile(quantile, interpol) +pub fn quantile(name: &str, quantile: Expr, method: QuantileMethod) -> Expr { + col(name).quantile(quantile, method) } /// Negates a boolean column. @@ -55,7 +55,7 @@ pub fn is_not_null(expr: Expr) -> Expr { /// Casts the column given by `Expr` to a different type. /// /// Follows the rules of Rust casting, with the exception that integers and floats can be cast to `DataType::Date` and -/// `DataType::DateTime(_, _)`. A column consisting entirely of of `Null` can be cast to any type, regardless of the +/// `DataType::DateTime(_, _)`. A column consisting entirely of `Null` can be cast to any type, regardless of the /// nominal type of the column. pub fn cast(expr: Expr, dtype: DataType) -> Expr { Expr::Cast { diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index 48508dba40d8..9a18abda55fd 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -427,7 +427,7 @@ pub fn duration(args: DurationArgs) -> Expr { function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, + flags: FunctionFlags::default(), ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index fb0c7a83b463..3a1a37c9f393 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "list_to_struct")] -use std::sync::RwLock; - use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; @@ -281,50 +278,9 @@ impl ListNameSpace { /// an `upper_bound` of struct fields that will be set. /// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression /// will look in the current schema to determine which columns to select. - pub fn to_struct( - self, - n_fields: ListToStructWidthStrategy, - name_generator: Option, - upper_bound: usize, - ) -> Expr { - // heap allocate the output type and fill it later - let out_dtype = Arc::new(RwLock::new(None::)); - + pub fn to_struct(self, args: ListToStructArgs) -> Expr { self.0 - .map( - move |s| { - s.list()? - .to_struct(n_fields, name_generator.clone()) - .map(|s| Some(s.into_column())) - }, - // we don't yet know the fields - GetOutput::map_dtype(move |dt: &DataType| { - polars_ensure!(matches!(dt, DataType::List(_)), SchemaMismatch: "expected 'List' as input to 'list.to_struct' got {}", dt); - let out = out_dtype.read().unwrap(); - match out.as_ref() { - // dtype already set - Some(dt) => Ok(dt.clone()), - // dtype still unknown, set it - None => { - drop(out); - let mut lock = out_dtype.write().unwrap(); - - let inner = dt.inner_dtype().unwrap(); - let fields = (0..upper_bound) - .map(|i| { - let name = _default_struct_name_gen(i); - Field::new(name, inner.clone()) - }) - .collect(); - let dt = DataType::Struct(fields); - - *lock = Some(dt.clone()); - Ok(dt) - }, - } - }), - ) - .with_fmt("list.to_struct") + .map_private(FunctionExpr::ListExpr(ListFunction::ToStruct(args))) } #[cfg(feature = "is_in")] @@ -332,34 +288,24 @@ impl ListNameSpace { pub fn contains>(self, other: E) -> Expr { let other = other.into(); - self.0 - .map_many_private( - FunctionExpr::ListExpr(ListFunction::Contains), - &[other], - false, - None, - ) - .with_function_options(|mut options| { - options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; - options - }) + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Contains), + &[other], + false, + None, + ) } #[cfg(feature = "list_count")] /// Count how often the value produced by ``element`` occurs. pub fn count_matches>(self, element: E) -> Expr { let other = element.into(); - self.0 - .map_many_private( - FunctionExpr::ListExpr(ListFunction::CountMatches), - &[other], - false, - None, - ) - .with_function_options(|mut options| { - options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; - options - }) + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::CountMatches), + &[other], + false, + None, + ) } #[cfg(feature = "list_sets")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 5d814877a977..a88ff858e6ee 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -45,7 +45,7 @@ use std::sync::Arc; pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; pub use expr::*; pub use function_expr::schema::FieldsMapper; pub use function_expr::*; @@ -227,11 +227,11 @@ impl Expr { } /// Compute the quantile per group. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { AggExpr::Quantile { expr: Arc::new(self), quantile: Arc::new(quantile), - interpol, + method, } .into() } @@ -1358,13 +1358,13 @@ impl Expr { pub fn rolling_quantile_by( self, by: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsDynamicWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) @@ -1385,7 +1385,7 @@ impl Expr { /// Apply a rolling median based on another column. #[cfg(feature = "rolling_window_by")] pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { - self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile_by(by, QuantileMethod::Linear, 0.5, options) } /// Apply a rolling minimum. @@ -1425,7 +1425,7 @@ impl Expr { /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { - self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile(QuantileMethod::Linear, 0.5, options) } /// Apply a rolling quantile. @@ -1434,13 +1434,13 @@ impl Expr { #[cfg(feature = "rolling_window")] pub fn rolling_quantile( self, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling(options, RollingFunction::Quantile) diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index 73481796d3e0..259d66af95ae 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -5,6 +5,7 @@ use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::dsl::Selector; @@ -87,8 +88,9 @@ impl Default for WindowType { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum WindowMapping { /// Map the group values to the position #[default] diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 0f9543a24921..cd133ceb646e 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -1,7 +1,6 @@ use std::io::Cursor; use std::sync::Arc; -use once_cell::sync::Lazy; use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::column::Column; @@ -10,10 +9,6 @@ use polars_core::schema::Schema; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; -#[cfg(feature = "serde")] -use serde::ser::Error; -#[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::expr_dyn_fn::*; use crate::constants::MAP_LIST_NAME; @@ -26,88 +21,10 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option< pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; -#[cfg(feature = "serde")] -pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); -static PYTHON_VERSION_MINOR: Lazy = Lazy::new(get_python_minor_version); -#[derive(Debug)] -pub struct PythonFunction(pub PyObject); - -impl Clone for PythonFunction { - fn clone(&self) -> Self { - Python::with_gil(|py| Self(self.0.clone_ref(py))) - } -} - -impl From for PythonFunction { - fn from(value: PyObject) -> Self { - Self(value) - } -} - -impl Eq for PythonFunction {} - -impl PartialEq for PythonFunction { - fn eq(&self, other: &Self) -> bool { - Python::with_gil(|py| { - let eq = self.0.getattr(py, "__eq__").unwrap(); - eq.call1(py, (other.0.clone_ref(py),)) - .unwrap() - .extract::(py) - // equality can be not implemented, so default to false - .unwrap_or(false) - }) - } -} - -#[cfg(feature = "serde")] -impl Serialize for PythonFunction { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("unable to import 'cloudpickle' or 'pickle'") - .getattr("dumps") - .unwrap(); - - let python_function = self.0.clone_ref(py); - - let dumped = pickle - .call1((python_function,)) - .map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?; - let dumped = dumped.extract::().unwrap(); - - serializer.serialize_bytes(&dumped) - }) - } -} - -#[cfg(feature = "serde")] -impl<'a> Deserialize<'a> for PythonFunction { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'a>, - { - use serde::de::Error; - let bytes = Vec::::deserialize(deserializer)?; - - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "pickle") - .expect("unable to import 'pickle'") - .getattr("loads") - .unwrap(); - let arg = (PyBytes::new_bound(py, &bytes),); - let python_function = pickle - .call1(arg) - .map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?; - - Ok(Self(python_function.into())) - }) - } -} +pub use polars_utils::python_function::{ + PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK, +}; pub struct PythonUdfExpression { python_function: PyObject, @@ -134,23 +51,23 @@ impl PythonUdfExpression { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { // Handle byte mark - debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); - let buf = &buf[MAGIC_BYTE_MARK.len()..]; + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; // Handle pickle metadata let use_cloudpickle = buf[0]; if use_cloudpickle != 0 { - let ser_py_version = buf[1]; - let cur_py_version = *PYTHON_VERSION_MINOR; + let ser_py_version = &buf[1..3]; + let cur_py_version = *PYTHON3_VERSION; polars_ensure!( ser_py_version == cur_py_version, InvalidOperation: - "current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})", - cur_py_version, - ser_py_version + "current Python version {:?} does not match the Python version used to serialize the UDF {:?}", + (3, cur_py_version[0], cur_py_version[1]), + (3, ser_py_version[0], ser_py_version[1] ) ); } - let buf = &buf[2..]; + let buf = &buf[3..]; // Load UDF metadata let mut reader = Cursor::new(buf); @@ -181,7 +98,7 @@ fn from_pyerr(e: PyErr) -> PolarsError { PolarsError::ComputeError(format!("error raised in python: {e}").into()) } -impl DataFrameUdf for PythonFunction { +impl DataFrameUdf for polars_utils::python_function::PythonFunction { fn call_udf(&self, df: DataFrame) -> PolarsResult { let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() }; func(df, &self.0) @@ -215,7 +132,7 @@ impl ColumnsUdf for PythonUdfExpression { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { // Write byte marks - buf.extend_from_slice(MAGIC_BYTE_MARK); + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); Python::with_gil(|py| { // Try pickle to serialize the UDF, otherwise fall back to cloudpickle. @@ -224,8 +141,8 @@ impl ColumnsUdf for PythonUdfExpression { .getattr("dumps") .unwrap(); let pickle_result = pickle.call1((self.python_function.clone_ref(py),)); - let (dumped, use_cloudpickle, py_version) = match pickle_result { - Ok(dumped) => (dumped, false, 0), + let (dumped, use_cloudpickle) = match pickle_result { + Ok(dumped) => (dumped, false), Err(_) => { let cloudpickle = PyModule::import_bound(py, "cloudpickle") .map_err(from_pyerr)? @@ -234,12 +151,13 @@ impl ColumnsUdf for PythonUdfExpression { let dumped = cloudpickle .call1((self.python_function.clone_ref(py),)) .map_err(from_pyerr)?; - (dumped, true, *PYTHON_VERSION_MINOR) + (dumped, true) }, }; // Write pickle metadata - buf.extend_from_slice(&[use_cloudpickle as u8, py_version]); + buf.push(use_cloudpickle as u8); + buf.extend_from_slice(&*PYTHON3_VERSION); // Write UDF metadata ciborium::ser::into_writer( @@ -273,8 +191,8 @@ impl PythonGetOutput { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { // Skip header. - debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); - let buf = &buf[MAGIC_BYTE_MARK.len()..]; + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; let mut reader = Cursor::new(buf); let return_dtype: Option = @@ -302,7 +220,7 @@ impl FunctionOutputField for PythonGetOutput { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { - buf.extend_from_slice(MAGIC_BYTE_MARK); + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap(); Ok(()) } @@ -342,17 +260,3 @@ impl Expr { } } } - -/// Get the minor Python version from the `sys` module. -fn get_python_minor_version() -> u8 { - Python::with_gil(|py| { - PyModule::import_bound(py, "sys") - .unwrap() - .getattr("version_info") - .unwrap() - .getattr("minor") - .unwrap() - .extract() - .unwrap() - }) -} diff --git a/crates/polars-plan/src/dsl/selector.rs b/crates/polars-plan/src/dsl/selector.rs index 16e7d7b374e0..7877edb152df 100644 --- a/crates/polars-plan/src/dsl/selector.rs +++ b/crates/polars-plan/src/dsl/selector.rs @@ -11,7 +11,7 @@ pub enum Selector { Add(Box, Box), Sub(Box, Box), ExclusiveOr(Box, Box), - InterSect(Box, Box), + Intersect(Box, Box), Root(Box), } @@ -34,7 +34,7 @@ impl BitAnd for Selector { #[allow(clippy::suspicious_arithmetic_impl)] fn bitand(self, rhs: Self) -> Self::Output { - Selector::InterSect(Box::new(self), Box::new(rhs)) + Selector::Intersect(Box::new(self), Box::new(rhs)) } } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index d392d403d1b6..2514d1a5f6a4 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -592,4 +592,14 @@ impl StringNameSpace { None, ) } + + #[cfg(feature = "regex")] + pub fn escape_regex(self) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::EscapeRegex), + &[], + false, + None, + ) + } } diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 565710c0dbaf..286ea86ac968 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -44,7 +44,7 @@ pub enum IRAggExpr { Quantile { expr: Node, quantile: Node, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Node), Count(Node, bool), @@ -62,7 +62,9 @@ impl Hash for IRAggExpr { Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => { propagate_nans.hash(state) }, - Self::Quantile { interpol, .. } => interpol.hash(state), + Self::Quantile { + method: interpol, .. + } => interpol.hash(state), Self::Std(_, v) | Self::Var(_, v) => v.hash(state), #[cfg(feature = "bitwise")] Self::Bitwise(_, f) => f.hash(state), @@ -92,7 +94,7 @@ impl IRAggExpr { propagate_nans: r, .. }, ) => l == r, - (Quantile { interpol: l, .. }, Quantile { interpol: r, .. }) => l == r, + (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r, (Std(_, l), Std(_, r)) => l == r, (Var(_, l), Var(_, r)) => l == r, #[cfg(feature = "bitwise")] diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index a290321f4cf8..8cb4b8cc2387 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -19,17 +19,17 @@ impl AExpr { pub fn to_dtype( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { - self.to_field(schema, ctxt, arena).map(|f| f.dtype) + self.to_field(schema, ctx, arena).map(|f| f.dtype) } /// Get Field result of the expression. The schema is the input data. pub fn to_field( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { // During aggregation a column that isn't aggregated gets an extra nesting level @@ -37,7 +37,7 @@ impl AExpr { // But not if we do an aggregation: // col(foo: i64).sum() -> i64 // The `nested` keeps track of the nesting we need to add. - let mut nested = matches!(ctxt, Context::Aggregation) as u8; + let mut nested = matches!(ctx, Context::Aggregation) as u8; let mut field = self.to_field_impl(schema, arena, &mut nested)?; if nested >= 1 { diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 7163e18de165..1697a5571d4e 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -85,7 +85,7 @@ impl AExpr { } } - pub(crate) fn replace_inputs(mut self, inputs: &[Node]) -> Self { + pub fn replace_inputs(mut self, inputs: &[Node]) -> Self { use AExpr::*; let input = match &mut self { Column(_) | Literal(_) | Len => return self, diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index aef7cd157334..6520cc476178 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -1,3 +1,5 @@ +use bitflags::bitflags; + use super::*; fn has_series_or_range(ae: &AExpr) -> bool { @@ -7,7 +9,46 @@ fn has_series_or_range(ae: &AExpr) -> bool { ) } -pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> bool { +bitflags! { + #[derive(Default, Copy, Clone)] + struct StreamableFlags: u8 { + const ALLOW_CAST_CATEGORICAL = 1; + } +} + +#[derive(Copy, Clone)] +pub struct IsStreamableContext { + flags: StreamableFlags, + context: Context, +} + +impl Default for IsStreamableContext { + fn default() -> Self { + Self { + flags: StreamableFlags::all(), + context: Default::default(), + } + } +} + +impl IsStreamableContext { + pub fn new(ctx: Context) -> Self { + Self { + flags: StreamableFlags::all(), + context: ctx, + } + } + + pub fn with_allow_cast_categorical(mut self, allow_cast_categorical: bool) -> Self { + self.flags.set( + StreamableFlags::ALLOW_CAST_CATEGORICAL, + allow_cast_categorical, + ); + self + } +} + +pub fn is_streamable(node: Node, expr_arena: &Arena, ctx: IsStreamableContext) -> bool { // check whether leaf column is Col or Lit let mut seen_column = false; let mut seen_lit_range = false; @@ -16,13 +57,14 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> function: FunctionExpr::SetSortedFlag(_), .. } => true, - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => match context - { - Context::Default => matches!( - options.collect_groups, - ApplyOptions::ElementWise | ApplyOptions::ApplyList - ), - Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { + match ctx.context { + Context::Default => matches!( + options.collect_groups, + ApplyOptions::ElementWise | ApplyOptions::ApplyList + ), + Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + } }, AExpr::Column(_) => { seen_column = true; @@ -41,6 +83,10 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> && !has_aexpr(*falsy, expr_arena, has_series_or_range) && !has_aexpr(*predicate, expr_arena, has_series_or_range) }, + #[cfg(feature = "dtype-categorical")] + AExpr::Cast { dtype, .. } if matches!(dtype, DataType::Categorical(_, _)) => { + ctx.flags.contains(StreamableFlags::ALLOW_CAST_CATEGORICAL) + }, AExpr::Alias(_, _) | AExpr::Cast { .. } => true, AExpr::Literal(lv) => match lv { LiteralValue::Series(_) | LiteralValue::Range { .. } => { @@ -64,8 +110,12 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> false } -pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena, context: Context) -> bool { +pub fn all_streamable( + exprs: &[ExprIR], + expr_arena: &Arena, + ctx: IsStreamableContext, +) -> bool { exprs .iter() - .all(|e| is_streamable(e.node(), expr_arena, context)) + .all(|e| is_streamable(e.node(), expr_arena, ctx)) } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index c0178f5b383c..793ab63194d7 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -369,7 +369,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; - return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Context::Default) { + return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Default::default()) { // Split expression that are ANDed into multiple Filter nodes as the optimizer can then // push them down independently. Especially if they refer columns from different tables // this will be more performant. @@ -738,9 +738,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult |name| col(name.clone()).std(ddof), &input_schema, ), - StatsFunction::Quantile { quantile, interpol } => stats_helper( + StatsFunction::Quantile { quantile, method } => stats_helper( |dt| dt.is_numeric(), - |name| col(name.clone()).quantile(quantile.clone(), interpol), + |name| col(name.clone()).quantile(quantile.clone(), method), &input_schema, ), StatsFunction::Mean => stats_helper( diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index bec3fbe852cd..4709641662f9 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -1,5 +1,4 @@ //! this contains code used for rewriting projections, expanding wildcards, regex selection etc. -use std::ops::BitXor; use super::*; @@ -176,26 +175,28 @@ fn expand_columns( schema: &Schema, exclude: &PlHashSet, ) -> PolarsResult<()> { - let mut is_valid = true; + if !expr.into_iter().all(|e| match e { + // check for invalid expansions such as `col([a, b]) + col([c, d])` + Expr::Columns(ref members) => members.as_ref() == names, + _ => true, + }) { + polars_bail!(ComputeError: "expanding more than one `col` is not allowed"); + } for name in names { if !exclude.contains(name) { - let new_expr = expr.clone(); - let (new_expr, new_expr_valid) = replace_columns_with_column(new_expr, names, name); - is_valid &= new_expr_valid; - // we may have regex col in columns. - #[allow(clippy::collapsible_else_if)] + let new_expr = expr.clone().map_expr(|e| match e { + Expr::Columns(_) => Expr::Column((*name).clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }); + #[cfg(feature = "regex")] - { - replace_regex(&new_expr, result, schema, exclude)?; - } + replace_regex(&new_expr, result, schema, exclude)?; + #[cfg(not(feature = "regex"))] - { - let new_expr = rewrite_special_aliases(new_expr)?; - result.push(new_expr) - } + result.push(rewrite_special_aliases(new_expr)?); } } - polars_ensure!(is_valid, ComputeError: "expanding more than one `col` is not allowed"); Ok(()) } @@ -246,30 +247,6 @@ fn replace_dtype_or_index_with_column( }) } -/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the -/// expression chain. -pub(super) fn replace_columns_with_column( - mut expr: Expr, - names: &[PlSmallStr], - column_name: &PlSmallStr, -) -> (Expr, bool) { - let mut is_valid = true; - expr = expr.map_expr(|e| match e { - Expr::Columns(members) => { - // `col([a, b]) + col([c, d])` - if members.as_ref() == names { - Expr::Column(column_name.clone()) - } else { - is_valid = false; - Expr::Columns(members) - } - }, - Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), - e => e, - }); - (expr, is_valid) -} - fn dtypes_match(d1: &DataType, d2: &DataType) -> bool { match (d1, d2) { // note: allow Datetime "*" wildcard for timezones... @@ -562,7 +539,7 @@ fn expand_function_inputs( }) } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] struct ExpansionFlags { multiple_columns: bool, has_nth: bool, @@ -819,42 +796,31 @@ fn replace_selector_inner( members.extend(scratch.drain(..)) }, Selector::Add(lhs, rhs) => { + let mut tmp_members: PlIndexSet = Default::default(); replace_selector_inner(*lhs, members, scratch, schema, keys)?; - let mut rhs_members: PlIndexSet = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - members.extend(rhs_members) + replace_selector_inner(*rhs, &mut tmp_members, scratch, schema, keys)?; + members.extend(tmp_members) }, Selector::ExclusiveOr(lhs, rhs) => { - let mut lhs_members = Default::default(); - replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let xor_members = lhs_members.bitxor(&rhs_members); - *members = xor_members; + *members = tmp_members.symmetric_difference(members).cloned().collect(); }, - Selector::InterSect(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + Selector::Intersect(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - *members = members.intersection(&rhs_members).cloned().collect() + *members = tmp_members.intersection(members).cloned().collect(); }, Selector::Sub(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let mut new_members = PlIndexSet::with_capacity(members.len()); - for e in members.drain(..) { - if !rhs_members.contains(&e) { - new_members.insert(e); - } - } - *members = new_members; + *members = tmp_members.difference(members).cloned().collect(); }, } Ok(()) diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index 6873ad3f6851..d3e0c17f8098 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -237,11 +237,11 @@ pub(super) fn to_aexpr_impl( AggExpr::Quantile { expr, quantile, - interpol, + method, } => IRAggExpr::Quantile { expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, - interpol, + method, }, AggExpr::Sum(expr) => { IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 5d2e4c373b30..160b70951962 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -129,14 +129,14 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { IRAggExpr::Quantile { expr, quantile, - interpol, + method, } => { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { expr: Arc::new(expr), quantile: Arc::new(quantile), - interpol, + method, } .into() }, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 4fa47e7695c8..9d63d18b0a46 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -115,7 +115,7 @@ pub fn resolve_join( // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. let all_elementwise = - |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default); + |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Default::default()); polars_ensure!( all_elementwise(&left_on) && all_elementwise(&right_on), InvalidOperation: "All join key expressions must be elementwise." diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index e470bd3044bc..f1aa33a7e7dd 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -72,7 +72,7 @@ pub enum StatsFunction { }, Quantile { quantile: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Median, Mean, diff --git a/crates/polars-plan/src/plans/hive.rs b/crates/polars-plan/src/plans/hive.rs index a711aeb11848..d99054cb405c 100644 --- a/crates/polars-plan/src/plans/hive.rs +++ b/crates/polars-plan/src/plans/hive.rs @@ -231,21 +231,18 @@ pub fn hive_partitions_from_paths( } /// Determine the path separator for identifying Hive partitions. -#[cfg(target_os = "windows")] -fn separator(url: &Path) -> char { - if polars_io::path_utils::is_cloud_url(url) { - '/' +fn separator(url: &Path) -> &[char] { + if cfg!(target_family = "windows") { + if polars_io::path_utils::is_cloud_url(url) { + &['/'] + } else { + &['/', '\\'] + } } else { - '\\' + &['/'] } } -/// Determine the path separator for identifying Hive partitions. -#[cfg(not(target_os = "windows"))] -fn separator(_url: &Path) -> char { - '/' -} - /// Parse a Hive partition string (e.g. "column=1.5") into a name and value part. /// /// Returns `None` if the string is not a Hive partition string. diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 982b74c9d851..44f2f187ce4b 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -420,7 +420,7 @@ impl fmt::Display for OptionExprIRDisplay<'_> { /// Utility structure to write to a [`fmt::Formatter`] whilst escaping the output as a label name pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write); -impl<'a> fmt::Write for EscapeLabel<'a> { +impl fmt::Write for EscapeLabel<'_> { fn write_str(&mut self, mut s: &str) -> fmt::Result { loop { let mut char_indices = s.char_indices(); diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 135c03115596..61b364f93b4a 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -413,13 +413,13 @@ impl<'a> ExprIRDisplay<'a> { } } -impl<'a> Display for IRDisplay<'a> { +impl Display for IRDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self._format(f, 0) } } -impl<'a, T: AsExpr> Display for ExprIRSliceDisplay<'a, T> { +impl Display for ExprIRSliceDisplay<'_, T> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { // Display items in slice delimited by a comma @@ -452,13 +452,13 @@ impl<'a, T: AsExpr> Display for ExprIRSliceDisplay<'a, T> { } } -impl<'a, T: AsExpr> fmt::Debug for ExprIRSliceDisplay<'a, T> { +impl fmt::Debug for ExprIRSliceDisplay<'_, T> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } } -impl<'a> Display for ExprIRDisplay<'a> { +impl Display for ExprIRDisplay<'_> { #[recursive] fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let root = self.expr_arena.get(self.node); @@ -673,7 +673,7 @@ impl<'a> Display for ExprIRDisplay<'a> { } } -impl<'a> fmt::Debug for ExprIRDisplay<'a> { +impl fmt::Debug for ExprIRDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index a9eb45b6406f..14b658bea22f 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -272,6 +272,6 @@ mod test { #[ignore] #[test] fn test_alp_size() { - assert!(std::mem::size_of::() <= 152); + assert!(size_of::() <= 152); } } diff --git a/crates/polars-plan/src/plans/ir/scan_sources.rs b/crates/polars-plan/src/plans/ir/scan_sources.rs index 789a5c4f4811..bcb80bd6140d 100644 --- a/crates/polars-plan/src/plans/ir/scan_sources.rs +++ b/crates/polars-plan/src/plans/ir/scan_sources.rs @@ -330,4 +330,4 @@ impl<'a> Iterator for ScanSourceIter<'a> { } } -impl<'a> ExactSizeIterator for ScanSourceIter<'a> {} +impl ExactSizeIterator for ScanSourceIter<'_> {} diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index 51dd550ee9e4..74feffd60da0 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -266,7 +266,7 @@ impl Literal for String { } } -impl<'a> Literal for &'a str { +impl Literal for &str { fn lit(self) -> Expr { Expr::Literal(LiteralValue::String(PlSmallStr::from_str(self))) } @@ -278,7 +278,7 @@ impl Literal for Vec { } } -impl<'a> Literal for &'a [u8] { +impl Literal for &[u8] { fn lit(self) -> Expr { Expr::Literal(LiteralValue::Binary(self.to_vec())) } diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 03eb06387cc6..314ca8bb0cb2 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -49,11 +49,12 @@ pub use schema::*; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Default)] pub enum Context { /// Any operation that is done on groups Aggregation, /// Any operation that is done while projection/ selection of data + #[default] Default, } diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs index da13d047d43f..f6968cc6f7d2 100644 --- a/crates/polars-plan/src/plans/optimizer/cache_states.rs +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -348,17 +348,23 @@ pub(super) fn set_cache_states( let lp = IRBuilder::new(new_child, expr_arena, lp_arena) .project_simple(projection) - .unwrap() + .expect("unique names") .build(); let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; - // Remove the projection added by the optimization. - let lp = - if let IR::Select { input, .. } | IR::SimpleProjection { input, .. } = lp { - lp_arena.take(input) + // Optimization can lead to a double projection. Only take the last. + let lp = if let IR::SimpleProjection { input, columns } = lp { + let input = if let IR::SimpleProjection { input: input2, .. } = + lp_arena.get(input) + { + *input2 } else { - lp + input }; + IR::SimpleProjection { input, columns } + } else { + lp + }; lp_arena.replace(child, lp); } } else { diff --git a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs index b3f52c6e30a9..4e109903fdce 100644 --- a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs @@ -141,7 +141,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) // @NOTE: Pruning of re-assigned columns // // We checked if this expression output is also assigned by the input and - // that that assignment is not used in the current WITH_COLUMNS. + // that this assignment is not used in the current WITH_COLUMNS. // Consequently, we are free to prune the input's assignment to the output. // // We immediately prune here to simplify the later code. diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 608c7122f2ec..778efee6aa9b 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -123,7 +123,7 @@ struct MintermIter<'a> { expr_arena: &'a Arena, } -impl<'a> Iterator for MintermIter<'a> { +impl Iterator for MintermIter<'_> { type Item = Node; fn next(&mut self) -> Option { diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs index 6b7763760fa1..700a82720eb4 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs @@ -181,6 +181,11 @@ enum VisitRecord { fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool { match ae { AExpr::Window { .. } => true, + #[cfg(feature = "dtype-struct")] + AExpr::Function { + function: FunctionExpr::AsStruct, + .. + } => true, AExpr::Ternary { .. } => is_groupby, _ => false, } diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs index 075414597edf..f4522ee3a3ca 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs @@ -184,7 +184,7 @@ fn skip_children(lp: &IR) -> bool { } } -impl<'a> Visitor for LpIdentifierVisitor<'a> { +impl Visitor for LpIdentifierVisitor<'_> { type Node = IRNode; type Arena = IRNodeArena; @@ -265,7 +265,7 @@ impl<'a> CommonSubPlanRewriter<'a> { } } -impl<'a> RewritingVisitor for CommonSubPlanRewriter<'a> { +impl RewritingVisitor for CommonSubPlanRewriter<'_> { type Node = IRNode; type Arena = IRNodeArena; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index ed3f3e0376bd..ff5f2f89ff0d 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -672,7 +672,7 @@ impl<'a> PredicatePushDown<'a> { if let Some(predicate) = predicate { // For IO plugins we only accept streamable expressions as // we want to apply the predicates to the batches. - if !is_streamable(predicate.node(), expr_arena, Context::Default) + if !is_streamable(predicate.node(), expr_arena, Default::default()) && matches!(options.python_source, PythonScanSource::IOPlugin) { let lp = PythonScan { options }; diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs index 8096b5bde3d8..81328fe208e6 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs @@ -63,7 +63,7 @@ pub(super) fn process_hstack( acc_projections, &lp_arena.get(input).schema(lp_arena), expr_arena, - false, + true, // expands_schema ); proj_pd.pushdown_and_assign( diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 122d86369252..67c2cdedf227 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -344,6 +344,7 @@ impl ProjectionPushDown { projections_seen, lp_arena, expr_arena, + false, ), SimpleProjection { columns, input, .. } => { let exprs = names_to_expr_irs(columns.iter_names_cloned(), expr_arena); @@ -356,6 +357,7 @@ impl ProjectionPushDown { projections_seen, lp_arena, expr_arena, + true, ) }, DataFrameScan { @@ -509,6 +511,21 @@ impl ProjectionPushDown { file_options.row_index = None; } }; + + if let Some(col_name) = &file_options.include_file_paths { + if output_schema + .as_ref() + .map_or(false, |schema| !schema.contains(col_name)) + { + // Need to remove it from the input schema so + // that projection indices are correct. + let mut file_schema = Arc::unwrap_or_clone(file_info.schema); + file_schema.shift_remove(col_name); + file_info.schema = Arc::new(file_schema); + file_options.include_file_paths = None; + } + }; + let lp = Scan { sources, file_info, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs index 6b1106a7ca19..854cc17d6fbe 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs @@ -54,6 +54,8 @@ pub(super) fn process_projection( projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, + // Whether is SimpleProjection. + simple: bool, ) -> PolarsResult { let mut local_projection = Vec::with_capacity(exprs.len()); @@ -130,7 +132,14 @@ pub(super) fn process_projection( )?; let builder = IRBuilder::new(input, expr_arena, lp_arena); - let lp = proj_pd.finish_node(local_projection, builder); + + let lp = if !local_projection.is_empty() && simple { + builder + .project_simple_nodes(local_projection.into_iter().map(|e| e.node()))? + .build() + } else { + proj_pd.finish_node(local_projection, builder) + }; Ok(lp) } diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs index 1df68a0adcfa..db123a5bd09d 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs @@ -152,7 +152,7 @@ impl OptimizationRule for SimplifyBooleanRule { AExpr::Literal(LiteralValue::Boolean(true)) ) && in_filter => { - // Only in filter as we we might change the name from "literal" + // Only in filter as we might change the name from "literal" // to whatever lhs columns is. return Ok(Some(expr_arena.get(*right).clone())); }, @@ -210,7 +210,7 @@ impl OptimizationRule for SimplifyBooleanRule { AExpr::Literal(LiteralValue::Boolean(false)) ) && in_filter => { - // Only in filter as we we might change the name from "literal" + // Only in filter as we might change the name from "literal" // to whatever lhs columns is. return Ok(Some(expr_arena.get(*right).clone())); }, diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index b656795f53d2..9c2f8497fac8 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -31,7 +31,7 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena) - // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, // `str.contains`, `str.contains_many` etc. - observe a column node is present // but the output height is not dependent on it. - let is_elementwise = is_streamable(expr_ir.node(), arena, Context::Default); + let is_elementwise = is_streamable(expr_ir.node(), arena, Default::default()); let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold( (false, true), |(has_column, lit_scalar), (_node, ae)| { diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 71b287d03b85..62a64319ae2e 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -67,7 +67,7 @@ impl TreeWalker for Expr { Mean(x) => Mean(am(x, f)?), Implode(x) => Implode(am(x, f)?), Count(x, nulls) => Count(am(x, f)?, nulls), - Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol }, Sum(x) => Sum(am(x, f)?), AggGroups(x) => AggGroups(am(x, f)?), Std(x, ddf) => Std(am(x, f)?, ddf), diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index 616ba491a19a..16af7a3071df 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -11,8 +11,10 @@ description = "Enable running Polars workloads in Python" [dependencies] polars-core = { workspace = true, features = ["python"] } polars-error = { workspace = true } +polars-expr = { workspace = true } polars-io = { workspace = true } polars-lazy = { workspace = true, features = ["python"] } +polars-mem-engine = { workspace = true } polars-ops = { workspace = true, features = ["bitwise"] } polars-parquet = { workspace = true, optional = true } polars-plan = { workspace = true } @@ -32,7 +34,9 @@ itoa = { workspace = true } libc = { workspace = true } ndarray = { workspace = true } num-traits = { workspace = true } -numpy = { workspace = true } +# TODO: Pin to released version once NumPy 2.0 support is merged +# https://github.com/PyO3/rust-numpy/issues/409 +numpy = { git = "https://github.com/stinodego/rust-numpy.git", rev = "9ba9962ae57ba26e35babdce6f179edf5fe5b9c8", default-features = false } once_cell = { workspace = true } pyo3 = { workspace = true, features = ["abi3-py39", "chrono", "multiple-pymethods"] } recursive = { workspace = true } @@ -75,6 +79,7 @@ features = [ "lazy", "list_eval", "list_to_struct", + "list_arithmetic", "array_to_struct", "log", "mode", @@ -230,7 +235,7 @@ optimizations = [ "streaming", ] -polars_cloud = ["polars/polars_cloud"] +polars_cloud = ["polars/polars_cloud", "polars/ir_serde"] # also includes simd nightly = ["polars/nightly"] @@ -250,7 +255,7 @@ all = [ "binary_encoding", "ffi_plugin", "polars_cloud", - # "new_streaming", + "new_streaming", ] # we cannot conditionally activate simd diff --git a/crates/polars-python/src/cloud.rs b/crates/polars-python/src/cloud.rs index dacca675c551..39410a6fa7a1 100644 --- a/crates/polars-python/src/cloud.rs +++ b/crates/polars-python/src/cloud.rs @@ -1,8 +1,17 @@ -use pyo3::prelude::*; -use pyo3::types::PyBytes; +use std::io::Cursor; + +use polars_core::error::{polars_err, to_compute_err, PolarsResult}; +use polars_expr::state::ExecutionState; +use polars_mem_engine::create_physical_plan; +use polars_plan::plans::{AExpr, IRPlan, IR}; +use polars_plan::prelude::{Arena, Node}; +use pyo3::intern; +use pyo3::prelude::{PyAnyMethods, PyModule, Python, *}; +use pyo3::types::{IntoPyDict, PyBytes}; use crate::error::PyPolarsErr; -use crate::PyLazyFrame; +use crate::lazyframe::visit::NodeTraverser; +use crate::{PyDataFrame, PyLazyFrame}; #[pyfunction] pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python) -> PyResult { @@ -11,3 +20,76 @@ pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python) -> PyResult { Ok(PyBytes::new_bound(py, &bytes).to_object(py)) } + +/// Take a serialized `IRPlan` and execute it on the GPU engine. +/// +/// This is done as a Python function because the `NodeTraverser` class created for this purpose +/// must exactly match the one expected by the `cudf_polars` package. +#[pyfunction] +pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec, py: Python) -> PyResult { + // Deserialize into IRPlan. + let reader = Cursor::new(ir_plan_ser); + let mut ir_plan = ciborium::from_reader::(reader) + .map_err(to_compute_err) + .map_err(PyPolarsErr::from)?; + + // Edit for use with GPU engine. + gpu_post_opt( + py, + ir_plan.lp_top, + &mut ir_plan.lp_arena, + &mut ir_plan.expr_arena, + ) + .map_err(PyPolarsErr::from)?; + + // Convert to physical plan. + let mut physical_plan = + create_physical_plan(ir_plan.lp_top, &mut ir_plan.lp_arena, &ir_plan.expr_arena) + .map_err(PyPolarsErr::from)?; + + // Execute the plan. + let mut state = ExecutionState::new(); + let df = physical_plan + .execute(&mut state) + .map_err(PyPolarsErr::from)?; + + Ok(df.into()) +} + +/// Prepare the IR for execution by the Polars GPU engine. +fn gpu_post_opt( + py: Python, + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + // Get cuDF Python function. + let cudf = PyModule::import_bound(py, intern!(py, "cudf_polars")).unwrap(); + let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap(); + + // Define cuDF config. + let polars = PyModule::import_bound(py, intern!(py, "polars")).unwrap(); + let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap(); + let kwargs = [("raise_on_fail", true)].into_py_dict_bound(py); + let engine = engine.call((), Some(&kwargs)).unwrap(); + + // Define node traverser. + let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena)); + + // Get a copy of the arenas. + let arenas = nt.get_arenas(); + + // Pass the node visitor which allows the Python callback to replace parts of the query plan. + // Remove "cuda" or specify better once we have multiple post-opt callbacks. + let kwargs = [("config", engine)].into_py_dict_bound(py); + lambda + .call((nt,), Some(&kwargs)) + .map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?; + + // Unpack the arena's. + // At this point the `nt` is useless. + std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap()); + std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap()); + + Ok(()) +} diff --git a/crates/polars-python/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs index 712e0817e5b7..eb4835ada90f 100644 --- a/crates/polars-python/src/conversion/any_value.rs +++ b/crates/polars-python/src/conversion/any_value.rs @@ -115,7 +115,7 @@ pub(crate) fn any_value_into_py_object(av: AnyValue, py: Python) -> PyObject { let buf = unsafe { std::slice::from_raw_parts( buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), + N * size_of::(), ) }; let digits = PyTuple::new_bound(py, buf.iter().take(n_digits)); diff --git a/crates/polars-python/src/conversion/chunked_array.rs b/crates/polars-python/src/conversion/chunked_array.rs index 3a69d61f7dd1..404fe68ce8ef 100644 --- a/crates/polars-python/src/conversion/chunked_array.rs +++ b/crates/polars-python/src/conversion/chunked_array.rs @@ -127,7 +127,7 @@ pub(crate) fn decimal_to_pyobject_iter<'a>( let buf = unsafe { std::slice::from_raw_parts( buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), + N * size_of::(), ) }; let digits = PyTuple::new_bound(py, buf.iter().take(n_digits)); diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index abde51745554..16abea471d7f 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -59,8 +59,8 @@ unsafe impl Transparent for Option { } pub(crate) fn reinterpret_vec(input: Vec) -> Vec { - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - assert_eq!(std::mem::align_of::(), std::mem::align_of::()); + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); let len = input.len(); let cap = input.capacity(); let mut manual_drop_vec = std::mem::ManuallyDrop::new(input); @@ -604,20 +604,12 @@ impl IntoPy for Wrap<&Schema> { } } -#[derive(Debug)] +#[derive(Clone, Debug)] #[repr(transparent)] pub struct ObjectValue { pub inner: PyObject, } -impl Clone for ObjectValue { - fn clone(&self) -> Self { - Python::with_gil(|py| Self { - inner: self.inner.clone_ref(py), - }) - } -} - impl Hash for ObjectValue { fn hash(&self, state: &mut H) { let h = Python::with_gil(|py| self.inner.bind(py).hash().expect("should be hashable")); @@ -986,17 +978,18 @@ impl<'py> FromPyObject<'py> for Wrap { } } -impl<'py> FromPyObject<'py> for Wrap { +impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let parsed = match &*ob.extract::()? { - "lower" => QuantileInterpolOptions::Lower, - "higher" => QuantileInterpolOptions::Higher, - "nearest" => QuantileInterpolOptions::Nearest, - "linear" => QuantileInterpolOptions::Linear, - "midpoint" => QuantileInterpolOptions::Midpoint, + "lower" => QuantileMethod::Lower, + "higher" => QuantileMethod::Higher, + "nearest" => QuantileMethod::Nearest, + "linear" => QuantileMethod::Linear, + "midpoint" => QuantileMethod::Midpoint, + "equiprobable" => QuantileMethod::Equiprobable, v => { return Err(PyValueError::new_err(format!( - "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", + "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint', 'equiprobable'}}, got {v}", ))) } }; diff --git a/crates/polars-python/src/dataframe/general.rs b/crates/polars-python/src/dataframe/general.rs index be2652a1fb8b..ac4febced0f6 100644 --- a/crates/polars-python/src/dataframe/general.rs +++ b/crates/polars-python/src/dataframe/general.rs @@ -262,16 +262,16 @@ impl PyDataFrame { Ok(PyDataFrame::new(df)) } - pub fn gather(&self, indices: Wrap>) -> PyResult { + pub fn gather(&self, py: Python, indices: Wrap>) -> PyResult { let indices = indices.0; let indices = IdxCa::from_vec("".into(), indices); - let df = self.df.take(&indices).map_err(PyPolarsErr::from)?; + let df = Python::allow_threads(py, || self.df.take(&indices).map_err(PyPolarsErr::from))?; Ok(PyDataFrame::new(df)) } - pub fn gather_with_series(&self, indices: &PySeries) -> PyResult { + pub fn gather_with_series(&self, py: Python, indices: &PySeries) -> PyResult { let indices = indices.series.idx().map_err(PyPolarsErr::from)?; - let df = self.df.take(indices).map_err(PyPolarsErr::from)?; + let df = Python::allow_threads(py, || self.df.take(indices).map_err(PyPolarsErr::from))?; Ok(PyDataFrame::new(df)) } @@ -591,11 +591,11 @@ impl PyDataFrame { every: &str, stable: bool, ) -> PyResult { + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let out = if stable { - self.df - .upsample_stable(by, index_column, Duration::parse(every)) + self.df.upsample_stable(by, index_column, every) } else { - self.df.upsample(by, index_column, Duration::parse(every)) + self.df.upsample(by, index_column, every) }; let out = out.map_err(PyPolarsErr::from)?; Ok(out.into()) diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index d0b6d30c31e1..7125388e88cd 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -9,6 +9,7 @@ use pyo3::class::basic::CompareOp; use pyo3::prelude::*; use crate::conversion::{parse_fill_null_strategy, vec_extract_wrapped, Wrap}; +use crate::error::PyPolarsErr; use crate::map::lazy::map_single; use crate::PyExpr; @@ -149,7 +150,7 @@ impl PyExpr { fn implode(&self) -> Self { self.inner.clone().implode().into() } - fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { self.inner .clone() .quantile(quantile.inner, interpolation.0) @@ -614,15 +615,15 @@ impl PyExpr { period: &str, offset: &str, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingGroupOptions { index_column: index_column.into(), - period: Duration::parse(period), - offset: Duration::parse(offset), + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, closed_window: closed.0, }; - self.inner.clone().rolling(options).into() + Ok(self.inner.clone().rolling(options).into()) } fn and_(&self, expr: Self) -> Self { @@ -812,12 +813,13 @@ impl PyExpr { }; self.inner.clone().ewm_mean(options).into() } - fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> Self { - let half_life = Duration::parse(half_life); - self.inner + fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> PyResult { + let half_life = Duration::try_parse(half_life).map_err(PyPolarsErr::from)?; + Ok(self + .inner .clone() .ewm_mean_by(times.inner, half_life) - .into() + .into()) } fn ewm_std( diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index 1bd087144634..af3be10449b1 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -4,6 +4,7 @@ use polars::prelude::*; use polars::series::ops::NullBehavior; use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; +use pyo3::types::PySequence; use crate::conversion::Wrap; use crate::PyExpr; @@ -214,20 +215,39 @@ impl PyExpr { upper_bound: usize, ) -> PyResult { let name_gen = name_gen.map(|lambda| { - Arc::new(move |idx: usize| { + NameGenerator::from_func(move |idx: usize| { Python::with_gil(|py| { let out = lambda.call1(py, (idx,)).unwrap(); let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); out }) - }) as NameGenerator + }) }); Ok(self .inner .clone() .list() - .to_struct(width_strat.0, name_gen, upper_bound) + .to_struct(ListToStructArgs::InferWidth { + infer_field_strategy: width_strat.0, + get_index_name: name_gen, + max_fields: upper_bound, + }) + .into()) + } + + #[pyo3(signature = (names))] + fn list_to_struct_fixed_width(&self, names: Bound<'_, PySequence>) -> PyResult { + Ok(self + .inner + .clone() + .list() + .to_struct(ListToStructArgs::FixedWidth( + names + .iter()? + .map(|x| Ok(x?.extract::>()?.0)) + .collect::>>()?, + )) .into()) } diff --git a/crates/polars-python/src/expr/rolling.rs b/crates/polars-python/src/expr/rolling.rs index 629f1eab391d..5ef511902613 100644 --- a/crates/polars-python/src/expr/rolling.rs +++ b/crates/polars-python/src/expr/rolling.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyFloat; use crate::conversion::Wrap; +use crate::error::PyPolarsErr; use crate::map::lazy::call_lambda_with_series; use crate::{PyExpr, PySeries}; @@ -34,14 +35,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_sum_by(by.inner, options).into() + Ok(self.inner.clone().rolling_sum_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -70,14 +71,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_min_by(by.inner, options).into() + Ok(self.inner.clone().rolling_min_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -105,14 +106,14 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_max_by(by.inner, options).into() + Ok(self.inner.clone().rolling_max_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -142,15 +143,15 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner.clone().rolling_mean_by(by.inner, options).into() + Ok(self.inner.clone().rolling_mean_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] @@ -182,15 +183,15 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })), }; - self.inner.clone().rolling_std_by(by.inner, options).into() + Ok(self.inner.clone().rolling_std_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] @@ -222,15 +223,15 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })), }; - self.inner.clone().rolling_var_by(by.inner, options).into() + Ok(self.inner.clone().rolling_var_by(by.inner, options).into()) } #[pyo3(signature = (window_size, weights, min_periods, center))] @@ -259,24 +260,25 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner + Ok(self + .inner .clone() .rolling_median_by(by.inner, options) - .into() + .into()) } #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))] fn rolling_quantile( &self, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: usize, weights: Option>, min_periods: Option, @@ -302,22 +304,23 @@ impl PyExpr { &self, by: PyExpr, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: &str, min_periods: usize, closed: Wrap, - ) -> Self { + ) -> PyResult { let options = RollingOptionsDynamicWindow { - window_size: Duration::parse(window_size), + window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?, min_periods, closed_window: closed.0, fn_params: None, }; - self.inner + Ok(self + .inner .clone() .rolling_quantile_by(by.inner, interpolation.0, quantile, options) - .into() + .into()) } fn rolling_skew(&self, window_size: usize, bias: bool) -> Self { diff --git a/crates/polars-python/src/expr/string.rs b/crates/polars-python/src/expr/string.rs index 6f0836ad8d13..87521a2b7aa1 100644 --- a/crates/polars-python/src/expr/string.rs +++ b/crates/polars-python/src/expr/string.rs @@ -339,4 +339,9 @@ impl PyExpr { .extract_many(patterns.inner, ascii_case_insensitive, overlapping) .into() } + + #[cfg(feature = "regex")] + fn str_escape_regex(&self) -> Self { + self.inner.clone().str().escape_regex().into() + } } diff --git a/crates/polars-python/src/file.rs b/crates/polars-python/src/file.rs index 074e06115993..efbcbff3fc18 100644 --- a/crates/polars-python/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -17,22 +17,15 @@ use pyo3::types::{PyBytes, PyString, PyStringMethods}; use crate::error::PyPolarsErr; use crate::prelude::resolve_homedir; +#[derive(Clone)] pub struct PyFileLikeObject { inner: PyObject, } -impl Clone for PyFileLikeObject { - fn clone(&self) -> Self { - Python::with_gil(|py| Self { - inner: self.inner.clone_ref(py), - }) - } -} - /// Wraps a `PyObject`, and implements read, seek, and write for it. impl PyFileLikeObject { /// Creates an instance of a `PyFileLikeObject` from a `PyObject`. - /// To assert the object has the required methods methods, + /// To assert the object has the required methods, /// instantiate it with `PyFileLikeObject::require` pub fn new(object: PyObject) -> Self { PyFileLikeObject { inner: object } diff --git a/crates/polars-python/src/functions/lazy.rs b/crates/polars-python/src/functions/lazy.rs index e96a35c6e93a..24db48144508 100644 --- a/crates/polars-python/src/functions/lazy.rs +++ b/crates/polars-python/src/functions/lazy.rs @@ -252,7 +252,7 @@ pub fn cum_reduce(lambda: PyObject, exprs: Vec) -> PyExpr { } #[pyfunction] -#[pyo3(signature = (year, month, day, hour=None, minute=None, second=None, microsecond=None, time_unit=Wrap(TimeUnit::Microseconds), time_zone=None, ambiguous=None))] +#[pyo3(signature = (year, month, day, hour=None, minute=None, second=None, microsecond=None, time_unit=Wrap(TimeUnit::Microseconds), time_zone=None, ambiguous=PyExpr::from(dsl::lit(String::from("raise")))))] pub fn datetime( year: PyExpr, month: PyExpr, @@ -263,15 +263,13 @@ pub fn datetime( microsecond: Option, time_unit: Wrap, time_zone: Option>, - ambiguous: Option, + ambiguous: PyExpr, ) -> PyExpr { let year = year.inner; let month = month.inner; let day = day.inner; set_unwrapped_or_0!(hour, minute, second, microsecond); - let ambiguous = ambiguous - .map(|e| e.inner) - .unwrap_or(dsl::lit(String::from("raise"))); + let ambiguous = ambiguous.inner; let time_unit = time_unit.0; let time_zone = time_zone.map(|x| x.0); let args = DatetimeArgs { diff --git a/crates/polars-python/src/functions/mod.rs b/crates/polars-python/src/functions/mod.rs index 0bb5e55ea23c..ddf58c7acde6 100644 --- a/crates/polars-python/src/functions/mod.rs +++ b/crates/polars-python/src/functions/mod.rs @@ -8,6 +8,7 @@ mod misc; mod random; mod range; mod string_cache; +mod strings; mod whenthen; pub use aggregation::*; @@ -20,4 +21,5 @@ pub use misc::*; pub use random::*; pub use range::*; pub use string_cache::*; +pub use strings::*; pub use whenthen::*; diff --git a/crates/polars-python/src/functions/range.rs b/crates/polars-python/src/functions/range.rs index e6e421dac84c..b6eae4400dd8 100644 --- a/crates/polars-python/src/functions/range.rs +++ b/crates/polars-python/src/functions/range.rs @@ -71,12 +71,12 @@ pub fn date_range( end: PyExpr, interval: &str, closed: Wrap, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let interval = Duration::parse(interval); + let interval = Duration::try_parse(interval).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::date_range(start, end, interval, closed).into() + Ok(dsl::date_range(start, end, interval, closed).into()) } #[pyfunction] @@ -85,12 +85,12 @@ pub fn date_ranges( end: PyExpr, interval: &str, closed: Wrap, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let interval = Duration::parse(interval); + let interval = Duration::try_parse(interval).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::date_ranges(start, end, interval, closed).into() + Ok(dsl::date_ranges(start, end, interval, closed).into()) } #[pyfunction] @@ -102,14 +102,14 @@ pub fn datetime_range( closed: Wrap, time_unit: Option>, time_zone: Option>, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; let time_unit = time_unit.map(|x| x.0); let time_zone = time_zone.map(|x| x.0); - dsl::datetime_range(start, end, every, closed, time_unit, time_zone).into() + Ok(dsl::datetime_range(start, end, every, closed, time_unit, time_zone).into()) } #[pyfunction] @@ -121,30 +121,40 @@ pub fn datetime_ranges( closed: Wrap, time_unit: Option>, time_zone: Option>, -) -> PyExpr { +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; let time_unit = time_unit.map(|x| x.0); let time_zone = time_zone.map(|x| x.0); - dsl::datetime_ranges(start, end, every, closed, time_unit, time_zone).into() + Ok(dsl::datetime_ranges(start, end, every, closed, time_unit, time_zone).into()) } #[pyfunction] -pub fn time_range(start: PyExpr, end: PyExpr, every: &str, closed: Wrap) -> PyExpr { +pub fn time_range( + start: PyExpr, + end: PyExpr, + every: &str, + closed: Wrap, +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::time_range(start, end, every, closed).into() + Ok(dsl::time_range(start, end, every, closed).into()) } #[pyfunction] -pub fn time_ranges(start: PyExpr, end: PyExpr, every: &str, closed: Wrap) -> PyExpr { +pub fn time_ranges( + start: PyExpr, + end: PyExpr, + every: &str, + closed: Wrap, +) -> PyResult { let start = start.inner; let end = end.inner; - let every = Duration::parse(every); + let every = Duration::try_parse(every).map_err(PyPolarsErr::from)?; let closed = closed.0; - dsl::time_ranges(start, end, every, closed).into() + Ok(dsl::time_ranges(start, end, every, closed).into()) } diff --git a/crates/polars-python/src/functions/strings.rs b/crates/polars-python/src/functions/strings.rs new file mode 100644 index 000000000000..d75666ecf367 --- /dev/null +++ b/crates/polars-python/src/functions/strings.rs @@ -0,0 +1,7 @@ +use pyo3::prelude::*; + +#[pyfunction] +pub fn escape_regex(s: &str) -> PyResult { + let escaped_s = polars_ops::chunked_array::strings::escape_regex_str(s); + Ok(escaped_s) +} diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 905b04ee0b95..bbf2defd5d2f 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -40,7 +40,7 @@ impl PyLazyFrame { #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( source, sources, infer_schema_length, schema, schema_overrides, batch_size, n_rows, low_memory, rechunk, - row_index, ignore_errors, include_file_paths, cloud_options, retries, file_cache_ttl + row_index, ignore_errors, include_file_paths, cloud_options, credential_provider, retries, file_cache_ttl ))] fn new_from_ndjson( source: Option, @@ -56,9 +56,11 @@ impl PyLazyFrame { ignore_errors: bool, include_file_paths: Option, cloud_options: Option>, + credential_provider: Option, retries: usize, file_cache_ttl: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; let row_index = row_index.map(|(name, offset)| RowIndex { name: name.into(), offset, @@ -78,7 +80,11 @@ impl PyLazyFrame { let mut cloud_options = parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; - cloud_options = cloud_options.with_max_retries(retries); + cloud_options = cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ); if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; @@ -110,7 +116,7 @@ impl PyLazyFrame { low_memory, comment_prefix, quote_char, null_values, missing_utf8_is_empty_string, infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header, encoding, row_index, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, glob, schema, - cloud_options, retries, file_cache_ttl, include_file_paths + cloud_options, credential_provider, retries, file_cache_ttl, include_file_paths ) )] fn new_from_csv( @@ -142,10 +148,13 @@ impl PyLazyFrame { glob: bool, schema: Option>, cloud_options: Option>, + credential_provider: Option, retries: usize, file_cache_ttl: Option, include_file_paths: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; + let null_values = null_values.map(|w| w.0); let quote_char = quote_char .map(|s| { @@ -197,7 +206,11 @@ impl PyLazyFrame { if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - cloud_options = cloud_options.with_max_retries(retries); + cloud_options = cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ); r = r.with_cloud_options(Some(cloud_options)); } @@ -256,9 +269,11 @@ impl PyLazyFrame { #[cfg(feature = "parquet")] #[staticmethod] - #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index, - low_memory, cloud_options, use_statistics, hive_partitioning, schema, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns) - )] + #[pyo3(signature = ( + source, sources, n_rows, cache, parallel, rechunk, row_index, low_memory, cloud_options, + credential_provider, use_statistics, hive_partitioning, schema, hive_schema, + try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns, + ))] fn new_from_parquet( source: Option, sources: Wrap, @@ -269,6 +284,7 @@ impl PyLazyFrame { row_index: Option<(String, IdxSize)>, low_memory: bool, cloud_options: Option>, + credential_provider: Option, use_statistics: bool, hive_partitioning: Option, schema: Option>, @@ -279,6 +295,8 @@ impl PyLazyFrame { include_file_paths: Option, allow_missing_columns: bool, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; + let parallel = parallel.0; let hive_schema = hive_schema.map(|s| Arc::new(s.0)); @@ -321,7 +339,13 @@ impl PyLazyFrame { let first_path_url = first_path.to_string_lossy(); let cloud_options = parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; - args.cloud_options = Some(cloud_options.with_max_retries(retries)); + args.cloud_options = Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ); } let lf = LazyFrame::scan_parquet_sources(sources, args).map_err(PyPolarsErr::from)?; @@ -331,7 +355,11 @@ impl PyLazyFrame { #[cfg(feature = "ipc")] #[staticmethod] - #[pyo3(signature = (source, sources, n_rows, cache, rechunk, row_index, cloud_options, hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, include_file_paths))] + #[pyo3(signature = ( + source, sources, n_rows, cache, rechunk, row_index, cloud_options,credential_provider, + hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, + include_file_paths + ))] fn new_from_ipc( source: Option, sources: Wrap, @@ -340,6 +368,7 @@ impl PyLazyFrame { rechunk: bool, row_index: Option<(String, IdxSize)>, cloud_options: Option>, + credential_provider: Option, hive_partitioning: Option, hive_schema: Option>, try_parse_hive_dates: bool, @@ -347,6 +376,7 @@ impl PyLazyFrame { file_cache_ttl: Option, include_file_paths: Option, ) -> PyResult { + use cloud::credential_provider::PlCredentialProvider; let row_index = row_index.map(|(name, offset)| RowIndex { name: name.into(), offset, @@ -385,7 +415,13 @@ impl PyLazyFrame { if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - args.cloud_options = Some(cloud_options.with_max_retries(retries)); + args.cloud_options = Some( + cloud_options + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_func_object), + ), + ); } let lf = LazyFrame::scan_ipc_sources(sources, args).map_err(PyPolarsErr::from)?; @@ -825,7 +861,7 @@ impl PyLazyFrame { offset: &str, closed: Wrap, by: Vec, - ) -> PyLazyGroupBy { + ) -> PyResult { let closed_window = closed.0; let ldf = self.ldf.clone(); let by = by @@ -837,13 +873,13 @@ impl PyLazyFrame { by, RollingGroupOptions { index_column: "".into(), - period: Duration::parse(period), - offset: Duration::parse(offset), + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, closed_window, }, ); - PyLazyGroupBy { lgb: Some(lazy_gb) } + Ok(PyLazyGroupBy { lgb: Some(lazy_gb) }) } fn group_by_dynamic( @@ -857,7 +893,7 @@ impl PyLazyFrame { closed: Wrap, group_by: Vec, start_by: Wrap, - ) -> PyLazyGroupBy { + ) -> PyResult { let closed_window = closed.0; let group_by = group_by .into_iter() @@ -868,9 +904,9 @@ impl PyLazyFrame { index_column.inner, group_by, DynamicGroupOptions { - every: Duration::parse(every), - period: Duration::parse(period), - offset: Duration::parse(offset), + every: Duration::try_parse(every).map_err(PyPolarsErr::from)?, + period: Duration::try_parse(period).map_err(PyPolarsErr::from)?, + offset: Duration::try_parse(offset).map_err(PyPolarsErr::from)?, label: label.0, include_boundaries, closed_window, @@ -879,7 +915,7 @@ impl PyLazyFrame { }, ); - PyLazyGroupBy { lgb: Some(lazy_gb) } + Ok(PyLazyGroupBy { lgb: Some(lazy_gb) }) } fn with_context(&self, contexts: Vec) -> Self { @@ -1070,7 +1106,7 @@ impl PyLazyFrame { out.into() } - fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { let ldf = self.ldf.clone(); let out = ldf.quantile(quantile.inner, interpolation.0); out.into() diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 9e1bde005bd0..32fbeb07f683 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -56,7 +56,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (2, 3); + const VERSION: Version = (3, 1); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 67c25d755084..06a98e3fe970 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -2,9 +2,7 @@ use polars::datatypes::TimeUnit; #[cfg(feature = "iejoin")] use polars::prelude::InequalityOperator; use polars::series::ops::NullBehavior; -use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; use polars_core::series::IsSorted; -use polars_ops::prelude::ClosedInterval; use polars_ops::series::InterpolationMethod; #[cfg(feature = "search_sorted")] use polars_ops::series::SearchSortedSide; @@ -16,6 +14,7 @@ use polars_plan::prelude::{ WindowMapping, WindowType, }; use polars_time::prelude::RollingGroupOptions; +use polars_time::{Duration, DynamicGroupOptions}; use pyo3::exceptions::PyNotImplementedError; use pyo3::prelude::*; @@ -44,20 +43,8 @@ pub struct Literal { dtype: PyObject, } -impl IntoPy for Wrap { - fn into_py(self, py: Python<'_>) -> PyObject { - match self.0 { - ClosedInterval::Both => "both", - ClosedInterval::Left => "left", - ClosedInterval::Right => "right", - ClosedInterval::None => "none", - } - .into_py(py) - } -} - -#[pyclass(name = "Operator", eq)] -#[derive(Copy, Clone, PartialEq)] +#[pyclass(name = "Operator")] +#[derive(Copy, Clone)] pub enum PyOperator { Eq, EqValidity, @@ -129,8 +116,8 @@ impl IntoPy for Wrap { } } -#[pyclass(name = "StringFunction", eq)] -#[derive(Copy, Clone, PartialEq)] +#[pyclass(name = "StringFunction")] +#[derive(Copy, Clone)] pub enum PyStringFunction { ConcatHorizontal, ConcatVertical, @@ -174,6 +161,7 @@ pub enum PyStringFunction { ZFill, ContainsMany, ReplaceMany, + EscapeRegex, } #[pymethods] @@ -183,8 +171,8 @@ impl PyStringFunction { } } -#[pyclass(name = "BooleanFunction", eq)] -#[derive(Copy, Clone, PartialEq)] +#[pyclass(name = "BooleanFunction")] +#[derive(Copy, Clone)] pub enum PyBooleanFunction { Any, All, @@ -212,8 +200,8 @@ impl PyBooleanFunction { } } -#[pyclass(name = "TemporalFunction", eq)] -#[derive(Copy, Clone, PartialEq)] +#[pyclass(name = "TemporalFunction")] +#[derive(Copy, Clone)] pub enum PyTemporalFunction { Millennium, Century, @@ -403,15 +391,25 @@ pub struct PyWindowMapping { impl PyWindowMapping { #[getter] fn kind(&self, py: Python<'_>) -> PyResult { - let result = match self.inner { - WindowMapping::GroupsToRows => "groups_to_rows".to_object(py), - WindowMapping::Explode => "explode".to_object(py), - WindowMapping::Join => "join".to_object(py), - }; + let result: &str = self.inner.into(); Ok(result.into_py(py)) } } +impl IntoPy for Wrap { + fn into_py(self, py: Python<'_>) -> PyObject { + ( + self.0.months(), + self.0.weeks(), + self.0.days(), + self.0.nanoseconds(), + self.0.parsed_int, + self.0.negative(), + ) + .into_py(py) + } +} + #[pyclass(name = "RollingGroupOptions")] pub struct PyRollingGroupOptions { inner: RollingGroupOptions, @@ -426,41 +424,68 @@ impl PyRollingGroupOptions { #[getter] fn period(&self, py: Python<'_>) -> PyResult { - let result = vec![ - self.inner.period.months().to_object(py), - self.inner.period.weeks().to_object(py), - self.inner.period.days().to_object(py), - self.inner.period.nanoseconds().to_object(py), - self.inner.period.parsed_int.to_object(py), - self.inner.period.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.period).into_py(py)) } #[getter] fn offset(&self, py: Python<'_>) -> PyResult { - let result = vec![ - self.inner.offset.months().to_object(py), - self.inner.offset.weeks().to_object(py), - self.inner.offset.days().to_object(py), - self.inner.offset.nanoseconds().to_object(py), - self.inner.offset.parsed_int.to_object(py), - self.inner.offset.negative().to_object(py), - ] - .into_py(py); - Ok(result) + Ok(Wrap(self.inner.offset).into_py(py)) } #[getter] fn closed_window(&self, py: Python<'_>) -> PyResult { - let result = match self.inner.closed_window { - polars::time::ClosedWindow::Left => "left".to_object(py), - polars::time::ClosedWindow::Right => "right".to_object(py), - polars::time::ClosedWindow::Both => "both".to_object(py), - polars::time::ClosedWindow::None => "none".to_object(py), - }; - Ok(result.into_py(py)) + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } +} + +#[pyclass(name = "DynamicGroupOptions")] +pub struct PyDynamicGroupOptions { + inner: DynamicGroupOptions, +} + +#[pymethods] +impl PyDynamicGroupOptions { + #[getter] + fn index_column(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.index_column.to_object(py)) + } + + #[getter] + fn every(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.every).into_py(py)) + } + + #[getter] + fn period(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.period).into_py(py)) + } + + #[getter] + fn offset(&self, py: Python<'_>) -> PyResult { + Ok(Wrap(self.inner.offset).into_py(py)) + } + + #[getter] + fn label(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.label.into(); + Ok(result.to_object(py)) + } + + #[getter] + fn include_boundaries(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.include_boundaries.into_py(py)) + } + + #[getter] + fn closed_window(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.closed_window.into(); + Ok(result.to_object(py)) + } + #[getter] + fn start_by(&self, py: Python<'_>) -> PyResult { + let result: &str = self.inner.start_by.into(); + Ok(result.to_object(py)) } } @@ -485,6 +510,14 @@ impl PyGroupbyOptions { .map_or_else(|| py.None(), |f| f.to_object(py))) } + #[getter] + fn dynamic(&self, py: Python<'_>) -> PyResult { + Ok(self.inner.dynamic.as_ref().map_or_else( + || py.None(), + |f| PyDynamicGroupOptions { inner: f.clone() }.into_py(py), + )) + } + #[getter] fn rolling(&self, py: Python<'_>) -> PyResult { Ok(self.inner.rolling.as_ref().map_or_else( @@ -700,18 +733,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Quantile { expr, quantile, - interpol, + method: interpol, } => Agg { name: "quantile".to_object(py), arguments: vec![expr.0, quantile.0], - options: match interpol { - QuantileInterpolOptions::Nearest => "nearest", - QuantileInterpolOptions::Lower => "lower", - QuantileInterpolOptions::Higher => "higher", - QuantileInterpolOptions::Midpoint => "midpoint", - QuantileInterpolOptions::Linear => "linear", - } - .to_object(py), + options: Into::<&str>::into(interpol).to_object(py), }, IRAggExpr::Sum(n) => Agg { name: "sum".to_object(py), @@ -741,12 +767,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Bitwise(n, f) => Agg { name: "bitwise".to_object(py), arguments: vec![n.0], - options: match f { - polars::prelude::BitwiseAggFunction::And => "and", - polars::prelude::BitwiseAggFunction::Or => "or", - polars::prelude::BitwiseAggFunction::Xor => "xor", - } - .to_object(py), + options: Into::<&str>::into(f).to_object(py), }, } .into_py(py), @@ -952,6 +973,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::ExtractMany { .. } => { return Err(PyNotImplementedError::new_err("extract_many")) }, + StringFunction::EscapeRegex => { + (PyStringFunction::EscapeRegex.into_py(py),).to_object(py) + }, }, FunctionExpr::StructExpr(_) => { return Err(PyNotImplementedError::new_err("struct expr")) @@ -1030,10 +1054,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { time_zone .as_ref() .map_or_else(|| py.None(), |s| s.to_object(py)), - match non_existent { - NonExistent::Null => "nullify", - NonExistent::Raise => "raise", - }, + Into::<&str>::into(non_existent), ) .into_py(py), TemporalFunction::Combine(time_unit) => { @@ -1073,7 +1094,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { BooleanFunction::IsUnique => (PyBooleanFunction::IsUnique,).into_py(py), BooleanFunction::IsDuplicated => (PyBooleanFunction::IsDuplicated,).into_py(py), BooleanFunction::IsBetween { closed } => { - (PyBooleanFunction::IsBetween, Wrap(*closed)).into_py(py) + (PyBooleanFunction::IsBetween, Into::<&str>::into(closed)).into_py(py) }, #[cfg(feature = "is_in")] BooleanFunction::IsIn => (PyBooleanFunction::IsIn,).into_py(py), @@ -1166,6 +1187,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { RollingFunction::Skew(_, _) => { return Err(PyNotImplementedError::new_err("rolling skew")) }, + RollingFunction::CorrCov { .. } => { + return Err(PyNotImplementedError::new_err("rolling cor_cov")) + }, }, FunctionExpr::RollingExprBy(rolling) => match rolling { RollingFunctionBy::MinBy(_) => { diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index ba6b7ffefa95..71f47921f626 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{IdxSize, UniqueKeepStrategy}; +use polars_core::prelude::IdxSize; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; use polars_plan::prelude::{ @@ -455,7 +455,6 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { ))) })?, maintain_order: *maintain_order, - // TODO: dynamic options options: PyGroupbyOptions::new(options.as_ref().clone()).into_py(py), } .into_py(py), @@ -473,23 +472,16 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { right_on: right_on.iter().map(|e| e.into()).collect(), options: { let how = &options.args.how; - + let name = Into::<&str>::into(how).to_object(py); ( match how { - JoinType::Left => "left".to_object(py), - JoinType::Right => "right".to_object(py), - JoinType::Inner => "inner".to_object(py), - JoinType::Full => "full".to_object(py), #[cfg(feature = "asof_join")] JoinType::AsOf(_) => { return Err(PyNotImplementedError::new_err("asof join")) }, - JoinType::Cross => "cross".to_object(py), - JoinType::Semi => "leftsemi".to_object(py), - JoinType::Anti => "leftanti".to_object(py), #[cfg(feature = "iejoin")] JoinType::IEJoin(ie_options) => ( - "inequality".to_object(py), + name, crate::Wrap(ie_options.operator1).into_py(py), ie_options .operator2 @@ -497,6 +489,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { .map_or_else(|| py.None(), |op| crate::Wrap(*op).into_py(py)), ) .into_py(py), + _ => name, }, options.args.join_nulls, options.args.slice, @@ -530,12 +523,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { IR::Distinct { input, options } => Distinct { input: input.0, options: ( - match options.keep_strategy { - UniqueKeepStrategy::First => "first", - UniqueKeepStrategy::Last => "last", - UniqueKeepStrategy::None => "none", - UniqueKeepStrategy::Any => "any", - }, + Into::<&str>::into(options.keep_strategy), options.subset.as_ref().map_or_else( || py.None(), |f| { diff --git a/crates/polars-python/src/on_startup.rs b/crates/polars-python/src/on_startup.rs index 9b6f17d46f72..8c0b4275b4ba 100644 --- a/crates/polars-python/src/on_startup.rs +++ b/crates/polars-python/src/on_startup.rs @@ -88,7 +88,7 @@ pub fn register_startup_deps() { Box::new(object) as Box }); - let object_size = std::mem::size_of::(); + let object_size = size_of::(); let physical_dtype = ArrowDataType::FixedSizeBinary(object_size); registry::register_object_builder(object_builder, object_converter, physical_dtype); // register SERIES UDF diff --git a/crates/polars-python/src/series/aggregation.rs b/crates/polars-python/src/series/aggregation.rs index dbcbad59ddac..5aa8ee16639e 100644 --- a/crates/polars-python/src/series/aggregation.rs +++ b/crates/polars-python/src/series/aggregation.rs @@ -105,11 +105,7 @@ impl PySeries { .into_py(py)) } - fn quantile( - &self, - quantile: f64, - interpolation: Wrap, - ) -> PyResult { + fn quantile(&self, quantile: f64, interpolation: Wrap) -> PyResult { let bind = self.series.quantile_reduce(quantile, interpolation.0); let sc = bind.map_err(PyPolarsErr::from)?; diff --git a/crates/polars-python/src/series/numpy_ufunc.rs b/crates/polars-python/src/series/numpy_ufunc.rs index 10d765c3fc25..5df438686447 100644 --- a/crates/polars-python/src/series/numpy_ufunc.rs +++ b/crates/polars-python/src/series/numpy_ufunc.rs @@ -1,4 +1,4 @@ -use std::{mem, ptr}; +use std::ptr; use ndarray::IntoDimension; use numpy::npyffi::types::npy_intp; @@ -30,7 +30,7 @@ unsafe fn aligned_array( let buffer_ptr = buf.as_mut_ptr(); let mut dims = [len].into_dimension(); - let strides = [mem::size_of::() as npy_intp]; + let strides = [size_of::() as npy_intp]; let ptr = PY_ARRAY_API.PyArray_NewFromDescr( py, diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index c60e614350cd..04c320e33ec3 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -66,6 +66,9 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp .collect(); StructArray::new(dtype.clone(), rows.len(), values, None).to_boxed() }, + ArrowDataType::List { .. } | ArrowDataType::LargeList { .. } => { + todo!("list decoding is not yet supported in polars' row encoding") + }, dt => { with_match_arrow_primitive_type!(dt, |$T| { decode_primitive::<$T>(rows, field).to_boxed() diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 363a43b6a9e8..fa9699c72f1b 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -230,7 +230,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE match encoder { Encoder::List { .. } => { let iter = encoder.list_iter(); - crate::variable::encode_iter(iter, out, &EncodingField::new_unsorted()) + crate::variable::encode_iter(iter, out, field) }, Encoder::Leaf(array) => { match array.dtype() { @@ -260,6 +260,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE .map(|opt_s| opt_s.map(|s| s.as_bytes())); crate::variable::encode_iter(iter, out, field) }, + ArrowDataType::Null => {}, // No output needed. dt => { with_match_arrow_primitive_type!(dt, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); @@ -286,6 +287,7 @@ pub fn encoded_size(dtype: &ArrowDataType) -> usize { Float32 => f32::ENCODED_LEN, Float64 => f64::ENCODED_LEN, Boolean => bool::ENCODED_LEN, + Null => 0, dt => unimplemented!("{dt:?}"), } } @@ -371,20 +373,13 @@ fn allocate_rows_buf( for opt_val in iter { unsafe { lengths.push_unchecked( - row_size_fixed - + crate::variable::encoded_len( - opt_val, - &EncodingField::new_unsorted(), - ), + row_size_fixed + crate::variable::encoded_len(opt_val, &field), ); } } } else { for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len( - opt_val, - &EncodingField::new_unsorted(), - ) + *row_length += crate::variable::encoded_len(opt_val, &field) } } processed_count += 1; @@ -637,7 +632,7 @@ mod test { let out = out.into_array(); assert_eq!( out.values().iter().map(|v| *v as usize).sum::(), - 82411 + 42774 ); } } diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index 7932420d1577..89f904e54296 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -25,7 +25,7 @@ impl FromSlice for [u8; N] { pub trait FixedLengthEncoding: Copy + Debug { // 1 is validity 0 or 1 // bit repr of encoding - const ENCODED_LEN: usize = 1 + std::mem::size_of::(); + const ENCODED_LEN: usize = 1 + size_of::(); type Encoded: Sized + Copy + AsRef<[u8]> + AsMut<[u8]>; diff --git a/crates/polars-row/src/lib.rs b/crates/polars-row/src/lib.rs index 823e5c6e4566..ddfbae9ea52b 100644 --- a/crates/polars-row/src/lib.rs +++ b/crates/polars-row/src/lib.rs @@ -120,6 +120,10 @@ //! This approach is loosely inspired by [COBS] encoding, and chosen over more traditional //! [byte stuffing] as it is more amenable to vectorisation, in particular AVX-256. //! +//! For the unordered row encoding we use a simpler scheme, we prepend the length +//! encoded as 4 bytes followed by the raw data, with nulls being marked with a +//! length of u32::MAX. +//! //! ## Dictionary Encoding //! //! [`RowsEncoded`] needs to support converting dictionary encoded arrays with unsorted, and diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index d48f6f51c205..1aa50e8b0e43 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -40,8 +40,8 @@ pub struct RowsEncoded { fn checks(offsets: &[usize]) { assert_eq!( - std::mem::size_of::(), - std::mem::size_of::(), + size_of::(), + size_of::(), "only supported on 64bit arch" ); assert!( diff --git a/crates/polars-row/src/variable.rs b/crates/polars-row/src/variable.rs index 5032e41085a8..2ccdf0f686f3 100644 --- a/crates/polars-row/src/variable.rs +++ b/crates/polars-row/src/variable.rs @@ -13,6 +13,7 @@ use std::mem::MaybeUninit; use arrow::array::{BinaryArray, BinaryViewArray, MutableBinaryViewArray}; +use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; use arrow::offset::Offsets; use polars_utils::slice::{GetSaferUnchecked, Slice2Uninit}; @@ -46,30 +47,12 @@ fn padded_length(a: usize) -> usize { 1 + ceil(a, BLOCK_SIZE) * (BLOCK_SIZE + 1) } -#[inline] -fn padded_length_opt(a: Option) -> usize { - if let Some(a) = a { - padded_length(a) - } else { - 1 - } -} - -#[inline] -fn length_opt(a: Option) -> usize { - if let Some(a) = a { - 1 + a - } else { - 1 - } -} - #[inline] pub fn encoded_len(a: Option<&[u8]>, field: &EncodingField) -> usize { if field.no_order { - length_opt(a.map(|v| v.len())) + 4 + a.map(|v| v.len()).unwrap_or(0) } else { - padded_length_opt(a.map(|v| v.len())) + a.map(|v| padded_length(v.len())).unwrap_or(1) } } @@ -78,30 +61,19 @@ unsafe fn encode_one_no_order( val: Option<&[MaybeUninit]>, field: &EncodingField, ) -> usize { + debug_assert!(field.no_order); match val { - Some([]) => { - let byte = if field.descending { - !EMPTY_SENTINEL - } else { - EMPTY_SENTINEL - }; - *out.get_unchecked_release_mut(0) = MaybeUninit::new(byte); - 1 - }, Some(val) => { - let end_offset = 1 + val.len(); - - // Write `2_u8` to demarcate as non-empty, non-null string - *out.get_unchecked_release_mut(0) = MaybeUninit::new(NON_EMPTY_SENTINEL); - std::ptr::copy_nonoverlapping(val.as_ptr(), out.as_mut_ptr().add(1), val.len()); - - end_offset + assert!(val.len() < u32::MAX as usize); + let encoded_len = (val.len() as u32).to_le_bytes().map(MaybeUninit::new); + std::ptr::copy_nonoverlapping(encoded_len.as_ptr(), out.as_mut_ptr(), 4); + std::ptr::copy_nonoverlapping(val.as_ptr(), out.as_mut_ptr().add(4), val.len()); + 4 + val.len() }, None => { - *out.get_unchecked_release_mut(0) = MaybeUninit::new(get_null_sentinel(field)); - // // write remainder as zeros - // out.get_unchecked_release_mut(1..).fill(MaybeUninit::new(0)); - 1 + let sentinel = u32::MAX.to_le_bytes().map(MaybeUninit::new); + std::ptr::copy_nonoverlapping(sentinel.as_ptr(), out.as_mut_ptr(), 4); + 4 }, } } @@ -258,7 +230,63 @@ unsafe fn decoded_len( } } +unsafe fn decoded_len_unordered(row: &[u8]) -> Option { + let len = u32::from_le_bytes(row.get_unchecked(0..4).try_into().unwrap()); + Some(len).filter(|l| *l < u32::MAX) +} + +unsafe fn decode_binary_unordered(rows: &mut [&[u8]]) -> BinaryArray { + let mut has_nulls = false; + let mut total_len = 0; + for row in rows.iter() { + if let Some(len) = decoded_len_unordered(row) { + total_len += len as usize; + } else { + has_nulls = true; + } + } + + let validity = has_nulls.then(|| { + Bitmap::from_trusted_len_iter_unchecked( + rows.iter().map(|row| decoded_len_unordered(row).is_none()), + ) + }); + + let mut values = Vec::with_capacity(total_len); + let mut offsets = Vec::with_capacity(rows.len() + 1); + offsets.push(0); + for row in rows.iter_mut() { + let len = decoded_len_unordered(row).unwrap_or(0) as usize; + values.extend_from_slice(row.get_unchecked(4..4 + len)); + *row = row.get_unchecked(4 + len..); + offsets.push(values.len() as i64); + } + BinaryArray::new( + ArrowDataType::LargeBinary, + Offsets::new_unchecked(offsets).into(), + values.into(), + validity, + ) +} + +unsafe fn decode_binview_unordered(rows: &mut [&[u8]]) -> BinaryViewArray { + let mut mutable = MutableBinaryViewArray::with_capacity(rows.len()); + for row in rows.iter_mut() { + if let Some(len) = decoded_len_unordered(row) { + mutable.push_value(row.get_unchecked(4..4 + len as usize)); + *row = row.get_unchecked(4 + len as usize..); + } else { + mutable.push_null(); + } + } + mutable.freeze() +} + pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &EncodingField) -> BinaryArray { + if field.no_order { + return decode_binary_unordered(rows); + } + let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { @@ -330,6 +358,10 @@ pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &EncodingField) -> } pub(super) unsafe fn decode_binview(rows: &mut [&[u8]], field: &EncodingField) -> BinaryViewArray { + if field.no_order { + return decode_binview_unordered(rows); + } + let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 9db54d1c3333..b5b875403b6d 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -24,7 +24,6 @@ rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlparser = { workspace = true } -# sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs.git", rev = "ae3b5844c839072c235965fe0d1bddc473dced87" } [dev-dependencies] # to display dataframes in case of test failures @@ -34,6 +33,7 @@ polars-core = { workspace = true, features = ["fmt"] } default = [] nightly = [] binary_encoding = ["polars-lazy/binary_encoding"] +bitwise = ["polars-lazy/bitwise"] csv = ["polars-lazy/csv"] diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 342a5e0883d2..1a060545439c 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -709,6 +709,11 @@ impl SQLContext { }; lf = if group_by_keys.is_empty() { + // The 'having' clause is only valid inside 'group by' + if select_stmt.having.is_some() { + polars_bail!(SQLSyntax: "HAVING clause not valid outside of GROUP BY; found:\n{:?}", select_stmt.having); + }; + // Final/selected cols, accounting for 'SELECT *' modifiers let mut retained_cols = Vec::with_capacity(projections.len()); let have_order_by = query.order_by.is_some(); diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 48011a323764..ff1a926f27e3 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -3,7 +3,7 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; use polars_core::prelude::{ - polars_bail, polars_err, DataType, PolarsResult, QuantileInterpolOptions, Schema, TimeUnit, + polars_bail, polars_err, DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, }; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] @@ -30,11 +30,40 @@ pub(crate) struct SQLFunctionVisitor<'a> { /// SQL functions that are supported by Polars pub(crate) enum PolarsSQLFunctions { + // ---- + // Bitwise functions + // ---- + /// SQL 'bit_and' function. + /// Returns the bitwise AND of the input expressions. + /// ```sql + /// SELECT BIT_AND(column_1, column_2) FROM df; + /// ``` + BitAnd, + /// SQL 'bit_count' function. + /// Returns the number of set bits in the input expression. + /// ```sql + /// SELECT BIT_COUNT(column_1) FROM df; + /// ``` + #[cfg(feature = "bitwise")] + BitCount, + /// SQL 'bit_or' function. + /// Returns the bitwise OR of the input expressions. + /// ```sql + /// SELECT BIT_OR(column_1, column_2) FROM df; + /// ``` + BitOr, + /// SQL 'bit_xor' function. + /// Returns the bitwise XOR of the input expressions. + /// ```sql + /// SELECT BIT_XOR(column_1, column_2) FROM df; + /// ``` + BitXor, + // ---- // Math functions // ---- /// SQL 'abs' function - /// Returns the absolute value of the input column. + /// Returns the absolute value of the input expression. /// ```sql /// SELECT ABS(column_1) FROM df; /// ``` @@ -142,97 +171,97 @@ pub(crate) enum PolarsSQLFunctions { // Trig functions // ---- /// SQL 'cos' function - /// Compute the cosine sine of the input column (in radians). + /// Compute the cosine sine of the input expression (in radians). /// ```sql /// SELECT COS(column_1) FROM df; /// ``` Cos, /// SQL 'cot' function - /// Compute the cotangent of the input column (in radians). + /// Compute the cotangent of the input expression (in radians). /// ```sql /// SELECT COT(column_1) FROM df; /// ``` Cot, /// SQL 'sin' function - /// Compute the sine of the input column (in radians). + /// Compute the sine of the input expression (in radians). /// ```sql /// SELECT SIN(column_1) FROM df; /// ``` Sin, /// SQL 'tan' function - /// Compute the tangent of the input column (in radians). + /// Compute the tangent of the input expression (in radians). /// ```sql /// SELECT TAN(column_1) FROM df; /// ``` Tan, /// SQL 'cosd' function - /// Compute the cosine sine of the input column (in degrees). + /// Compute the cosine sine of the input expression (in degrees). /// ```sql /// SELECT COSD(column_1) FROM df; /// ``` CosD, /// SQL 'cotd' function - /// Compute cotangent of the input column (in degrees). + /// Compute cotangent of the input expression (in degrees). /// ```sql /// SELECT COTD(column_1) FROM df; /// ``` CotD, /// SQL 'sind' function - /// Compute the sine of the input column (in degrees). + /// Compute the sine of the input expression (in degrees). /// ```sql /// SELECT SIND(column_1) FROM df; /// ``` SinD, /// SQL 'tand' function - /// Compute the tangent of the input column (in degrees). + /// Compute the tangent of the input expression (in degrees). /// ```sql /// SELECT TAND(column_1) FROM df; /// ``` TanD, /// SQL 'acos' function - /// Compute inverse cosinus of the input column (in radians). + /// Compute inverse cosinus of the input expression (in radians). /// ```sql /// SELECT ACOS(column_1) FROM df; /// ``` Acos, /// SQL 'asin' function - /// Compute inverse sine of the input column (in radians). + /// Compute inverse sine of the input expression (in radians). /// ```sql /// SELECT ASIN(column_1) FROM df; /// ``` Asin, /// SQL 'atan' function - /// Compute inverse tangent of the input column (in radians). + /// Compute inverse tangent of the input expression (in radians). /// ```sql /// SELECT ATAN(column_1) FROM df; /// ``` Atan, /// SQL 'atan2' function - /// Compute the inverse tangent of column_2/column_1 (in radians). + /// Compute the inverse tangent of column_1/column_2 (in radians). /// ```sql /// SELECT ATAN2(column_1, column_2) FROM df; /// ``` Atan2, /// SQL 'acosd' function - /// Compute inverse cosinus of the input column (in degrees). + /// Compute inverse cosinus of the input expression (in degrees). /// ```sql /// SELECT ACOSD(column_1) FROM df; /// ``` AcosD, /// SQL 'asind' function - /// Compute inverse sine of the input column (in degrees). + /// Compute inverse sine of the input expression (in degrees). /// ```sql /// SELECT ASIND(column_1) FROM df; /// ``` AsinD, /// SQL 'atand' function - /// Compute inverse tangent of the input column (in degrees). + /// Compute inverse tangent of the input expression (in degrees). /// ```sql /// SELECT ATAND(column_1) FROM df; /// ``` AtanD, /// SQL 'atan2d' function - /// Compute the inverse tangent of column_2/column_1 (in degrees). + /// Compute the inverse tangent of column_1/column_2 (in degrees). /// ```sql /// SELECT ATAN2D(column_1) FROM df; /// ``` @@ -513,6 +542,13 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT QUANTILE_CONT(column_1) FROM df; /// ``` QuantileCont, + /// SQL 'quantile_disc' function + /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, + /// and returns the value associated with the subinterval where the quantile value falls. + /// ```sql + /// SELECT QUANTILE_DISC(column_1) FROM df; + /// ``` + QuantileDisc, /// SQL 'min' function /// Returns the smallest (minimum) of all the elements in the grouping. /// ```sql @@ -649,7 +685,11 @@ impl PolarsSQLFunctions { "atan2d", "atand", "avg", + "bit_and", + "bit_count", "bit_length", + "bit_or", + "bit_xor", "cbrt", "ceil", "ceiling", @@ -688,6 +728,7 @@ impl PolarsSQLFunctions { "ltrim", "max", "median", + "quantile_disc", "min", "mod", "nullif", @@ -696,6 +737,7 @@ impl PolarsSQLFunctions { "pow", "power", "quantile_cont", + "quantile_disc", "radians", "regexp_like", "replace", @@ -732,6 +774,15 @@ impl PolarsSQLFunctions { fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult { let function_name = function.name.0[0].value.to_lowercase(); Ok(match function_name.as_str() { + // ---- + // Bitwise functions + // ---- + "bit_and" | "bitand" => Self::BitAnd, + #[cfg(feature = "bitwise")] + "bit_count" | "bitcount" => Self::BitCount, + "bit_or" | "bitor" => Self::BitOr, + "bit_xor" | "bitxor" | "xor" => Self::BitXor, + // ---- // Math functions // ---- @@ -829,6 +880,7 @@ impl PolarsSQLFunctions { "max" => Self::Max, "median" => Self::Median, "quantile_cont" => Self::QuantileCont, + "quantile_disc" => Self::QuantileDisc, "min" => Self::Min, "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev, "sum" => Self::Sum, @@ -884,6 +936,15 @@ impl SQLFunctionVisitor<'_> { } match function_name { + // ---- + // Bitwise functions + // ---- + BitAnd => self.visit_binary::(Expr::and), + #[cfg(feature = "bitwise")] + BitCount => self.visit_unary(Expr::bitwise_count_ones), + BitOr => self.visit_binary::(Expr::or), + BitXor => self.visit_binary::(Expr::xor), + // ---- // Math functions // ---- @@ -1275,11 +1336,37 @@ impl SQLFunctionVisitor<'_> { }, _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1]) }; - Ok(e.quantile(value, QuantileInterpolOptions::Linear)) + Ok(e.quantile(value, QuantileMethod::Linear)) }), _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()), } }, + QuantileDisc => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1]) + }; + Ok(e.quantile(value, QuantileMethod::Equiprobable)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()), + } + }, Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min), StdDev => self.visit_unary(|e| e.std(1)), Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum), diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index f9caa288cb82..5eb2bdd843b4 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -374,10 +374,9 @@ impl SQLExprVisitor<'_> { }, // identify "CAST(expr AS type) string" and/or "expr::type string" expressions (Expr::Cast { expr, dtype, .. }, Expr::Literal(LiteralValue::String(s))) => { - if let Expr::Column(name) = &**expr { - (Some(name.clone()), Some(s), Some(dtype)) - } else { - (None, Some(s), Some(dtype)) + match &**expr { + Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)), + _ => (None, Some(s), Some(dtype)), } }, _ => (None, None, None), @@ -385,23 +384,25 @@ impl SQLExprVisitor<'_> { if expr_dtype.is_none() && self.active_schema.is_none() { right.clone() } else { - let left_dtype = expr_dtype - .unwrap_or_else(|| self.active_schema.as_ref().unwrap().get(&name).unwrap()); - + let left_dtype = expr_dtype.or_else(|| { + self.active_schema + .as_ref() + .and_then(|schema| schema.get(&name)) + }); match left_dtype { - DataType::Time if is_iso_time(s) => { + Some(DataType::Time) if is_iso_time(s) => { right.clone().str().to_time(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Date if is_iso_date(s) => { + Some(DataType::Date) if is_iso_date(s) => { right.clone().str().to_date(StrptimeOptions { strict: true, ..Default::default() }) }, - DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => { + Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => { if s.len() == 10 { // handle upcast from ISO date string (10 chars) to datetime lit(format!("{}T00:00:00", s)) @@ -469,48 +470,53 @@ impl SQLExprVisitor<'_> { rhs = self.convert_temporal_strings(&lhs, &rhs); Ok(match op { - SQLBinaryOperator::And => lhs.and(rhs), - SQLBinaryOperator::Divide => lhs / rhs, - SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), - SQLBinaryOperator::Eq => lhs.eq(rhs), - SQLBinaryOperator::Gt => lhs.gt(rhs), - SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), - SQLBinaryOperator::Lt => lhs.lt(rhs), - SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), - SQLBinaryOperator::Minus => lhs - rhs, - SQLBinaryOperator::Modulo => lhs % rhs, - SQLBinaryOperator::Multiply => lhs * rhs, - SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), - SQLBinaryOperator::Or => lhs.or(rhs), - SQLBinaryOperator::Plus => lhs + rhs, - SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), - SQLBinaryOperator::StringConcat => { + // ---- + // Bitwise operators + // ---- + SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), // "x & y" + SQLBinaryOperator::BitwiseOr => lhs.or(rhs), // "x | y" + SQLBinaryOperator::Xor => lhs.xor(rhs), // "x XOR y" + + // ---- + // General operators + // ---- + SQLBinaryOperator::And => lhs.and(rhs), // "x AND y" + SQLBinaryOperator::Divide => lhs / rhs, // "x / y" + SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), // "x // y" + SQLBinaryOperator::Eq => lhs.eq(rhs), // "x = y" + SQLBinaryOperator::Gt => lhs.gt(rhs), // "x > y" + SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), // "x >= y" + SQLBinaryOperator::Lt => lhs.lt(rhs), // "x < y" + SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), // "x <= y" + SQLBinaryOperator::Minus => lhs - rhs, // "x - y" + SQLBinaryOperator::Modulo => lhs % rhs, // "x % y" + SQLBinaryOperator::Multiply => lhs * rhs, // "x * y" + SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), // "x != y" + SQLBinaryOperator::Or => lhs.or(rhs), // "x OR y" + SQLBinaryOperator::Plus => lhs + rhs, // "x + y" + SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), // "x <=> y" + SQLBinaryOperator::StringConcat => { // "x || y" lhs.cast(DataType::String) + rhs.cast(DataType::String) }, - SQLBinaryOperator::Xor => lhs.xor(rhs), - SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), + SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // "x ^@ y" // ---- // Regular expression operators // ---- - // "a ~ b" - SQLBinaryOperator::PGRegexMatch => match rhs { + SQLBinaryOperator::PGRegexMatch => match rhs { // "x ~ y" Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true), _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs), }, - // "a !~ b" - SQLBinaryOperator::PGRegexNotMatch => match rhs { + SQLBinaryOperator::PGRegexNotMatch => match rhs { // "x !~ y" Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(), _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs), }, - // "a ~* b" - SQLBinaryOperator::PGRegexIMatch => match rhs { + SQLBinaryOperator::PGRegexIMatch => match rhs { // "x ~* y" Expr::Literal(LiteralValue::String(pat)) => { lhs.str().contains(lit(format!("(?i){}", pat)), true) }, _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs), }, - // "a !~* b" - SQLBinaryOperator::PGRegexNotIMatch => match rhs { + SQLBinaryOperator::PGRegexNotIMatch => match rhs { // "x !~* y" Expr::Literal(LiteralValue::String(pat)) => { lhs.str().contains(lit(format!("(?i){}", pat)), true).not() }, @@ -521,10 +527,10 @@ impl SQLExprVisitor<'_> { // ---- // LIKE/ILIKE operators // ---- - SQLBinaryOperator::PGLikeMatch - | SQLBinaryOperator::PGNotLikeMatch - | SQLBinaryOperator::PGILikeMatch - | SQLBinaryOperator::PGNotILikeMatch => { + SQLBinaryOperator::PGLikeMatch // "x ~~ y" + | SQLBinaryOperator::PGNotLikeMatch // "x !~~ y" + | SQLBinaryOperator::PGILikeMatch // "x ~~* y" + | SQLBinaryOperator::PGNotILikeMatch => { // "x !~~* y" let expr = if matches!( op, SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch @@ -548,7 +554,7 @@ impl SQLExprVisitor<'_> { // ---- // JSON/Struct field access operators // ---- - SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { + SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { // "x -> y", "x ->> y" Expr::Literal(LiteralValue::String(path)) => { let mut expr = self.struct_field_access_expr(&lhs, &path, false)?; if let SQLBinaryOperator::LongArrow = op { @@ -567,7 +573,7 @@ impl SQLExprVisitor<'_> { polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right) }, }, - SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { + SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { // "x #> y", "x #>> y" if let Expr::Literal(LiteralValue::String(path)) = rhs { let mut expr = self.struct_field_access_expr(&lhs, &path, true)?; if let SQLBinaryOperator::HashLongArrow = op { diff --git a/crates/polars-sql/tests/functions_aggregate.rs b/crates/polars-sql/tests/functions_aggregate.rs index 621ca18bd355..092a340f5f18 100644 --- a/crates/polars-sql/tests/functions_aggregate.rs +++ b/crates/polars-sql/tests/functions_aggregate.rs @@ -5,9 +5,7 @@ use polars_sql::*; fn create_df() -> LazyFrame { df! { - "Year" => [2018, 2018, 2019, 2019, 2020, 2020], - "Country" => ["US", "UK", "US", "UK", "US", "UK"], - "Sales" => [1000, 2000, 3000, 4000, 5000, 6000] + "Data" => [1000, 2000, 3000, 4000, 5000, 6000] } .unwrap() .lazy() @@ -41,9 +39,9 @@ fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { #[test] fn test_median() { - let expr = col("Sales").median(); + let expr = col("Data").median(); - let sql_expr = "MEDIAN(Sales)"; + let sql_expr = "MEDIAN(Data)"; let (expected, actual) = create_expected(expr, sql_expr); assert!(expected.equals(&actual)) @@ -52,9 +50,9 @@ fn test_median() { #[test] fn test_quantile_cont() { for &q in &[0.25, 0.5, 0.75] { - let expr = col("Sales").quantile(lit(q), QuantileInterpolOptions::Linear); + let expr = col("Data").quantile(lit(q), QuantileMethod::Linear); - let sql_expr = format!("QUANTILE_CONT(Sales, {})", q); + let sql_expr = format!("QUANTILE_CONT(Data, {})", q); let (expected, actual) = create_expected(expr, &sql_expr); assert!( @@ -63,3 +61,61 @@ fn test_quantile_cont() { ) } } + +#[test] +fn test_quantile_disc() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Data").quantile(lit(q), QuantileMethod::Equiprobable); + + let sql_expr = format!("QUANTILE_DISC(Data, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!(expected.equals(&actual)) + } +} + +#[test] +fn test_quantile_out_of_range() { + for &q in &["-1", "2", "-0.01", "1.01"] { + for &func in &["QUANTILE_CONT", "QUANTILE_DISC"] { + let query = format!("SELECT {func}(Data, {q})"); + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + let actual = ctx.execute(&query); + assert!(actual.is_err()) + } + } +} + +#[test] +fn test_quantile_disc_conformance() { + let expected = df![ + "q" => [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], + "Data" => [1000, 1000, 2000, 2000, 3000, 3000, 4000, 5000, 5000, 6000, 6000], + ] + .unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + + let mut actual: Option = None; + for &q in &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] { + let res = ctx + .execute(&format!( + "SELECT {q}::float as q, QUANTILE_DISC(Data, {q}) as Data FROM df" + )) + .unwrap() + .collect() + .unwrap(); + actual = if let Some(df) = actual { + Some(df.vstack(&res).unwrap()) + } else { + Some(res) + }; + } + + assert!( + expected.equals(actual.as_ref().unwrap()), + "expected {expected:?}, got {actual:?}" + ) +} diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index 78cdbc9115d0..fc130a035140 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -37,4 +37,11 @@ version_check = { workspace = true } [features] nightly = [] -bitwise = ["polars-core/bitwise", "polars-plan/bitwise"] +bitwise = ["polars-core/bitwise", "polars-plan/bitwise", "polars-expr/bitwise"] +merge_sorted = ["polars-plan/merge_sorted"] +dynamic_group_by = [] +strings = [] + +# We need to specify default features here to match workspace defaults. +# Otherwise we get warnings with cargo check/clippy. +default = ["bitwise"] diff --git a/crates/polars-stream/src/async_executor/mod.rs b/crates/polars-stream/src/async_executor/mod.rs index dec560845b09..23789e5a20df 100644 --- a/crates/polars-stream/src/async_executor/mod.rs +++ b/crates/polars-stream/src/async_executor/mod.rs @@ -1,12 +1,16 @@ +#![allow(clippy::disallowed_types)] + mod park_group; mod task; use std::cell::{Cell, UnsafeCell}; +use std::collections::HashMap; use std::future::Future; use std::marker::PhantomData; -use std::panic::AssertUnwindSafe; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::{Arc, OnceLock, Weak}; +use std::panic::{AssertUnwindSafe, Location}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, LazyLock, OnceLock, Weak}; +use std::time::Duration; use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue}; use crossbeam_utils::CachePadded; @@ -30,6 +34,27 @@ thread_local!( static TLS_THREAD_ID: Cell = const { Cell::new(usize::MAX) }; ); +static NS_SPENT_BLOCKED: LazyLock, u64>>> = + LazyLock::new(Mutex::default); + +static TRACK_WAIT_STATISTICS: AtomicBool = AtomicBool::new(false); + +pub fn track_task_wait_statistics(should_track: bool) { + TRACK_WAIT_STATISTICS.store(should_track, Ordering::Relaxed); +} + +pub fn get_task_wait_statistics() -> Vec<(&'static Location<'static>, Duration)> { + NS_SPENT_BLOCKED + .lock() + .iter() + .map(|(l, ns)| (*l, Duration::from_nanos(*ns))) + .collect() +} + +pub fn clear_task_wait_statistics() { + NS_SPENT_BLOCKED.lock().clear() +} + slotmap::new_key_type! { struct TaskKey; } @@ -48,6 +73,8 @@ struct ScopedTaskMetadata { } struct TaskMetadata { + spawn_location: &'static Location<'static>, + ns_spent_blocked: AtomicU64, priority: TaskPriority, freshly_spawned: AtomicBool, scoped: Option, @@ -55,6 +82,10 @@ struct TaskMetadata { impl Drop for TaskMetadata { fn drop(&mut self) { + *NS_SPENT_BLOCKED + .lock() + .entry(self.spawn_location) + .or_default() += self.ns_spent_blocked.load(Ordering::Relaxed); if let Some(scoped) = &self.scoped { if let Some(completed_tasks) = scoped.completed_tasks.upgrade() { completed_tasks.lock().push(scoped.task_key); @@ -182,6 +213,7 @@ impl Executor { let mut rng = SmallRng::from_rng(&mut rand::thread_rng()).unwrap(); let mut worker = self.park_group.new_worker(); + let mut last_block_start = None; loop { let ttl = &self.thread_task_lists[thread]; @@ -206,11 +238,23 @@ impl Executor { if let Some(task) = self.try_steal_task(thread, &mut rng) { return Some(task); } + + if last_block_start.is_none() && TRACK_WAIT_STATISTICS.load(Ordering::Relaxed) { + last_block_start = Some(std::time::Instant::now()); + } park.park(); None })(); if let Some(task) = task { + if let Some(t) = last_block_start.take() { + if TRACK_WAIT_STATISTICS.load(Ordering::Relaxed) { + let ns: u64 = t.elapsed().as_nanos().try_into().unwrap(); + task.metadata() + .ns_spent_blocked + .fetch_add(ns, Ordering::Relaxed); + } + } worker.recruit_next(); task.run(); } @@ -264,7 +308,7 @@ pub struct TaskScope<'scope, 'env: 'scope> { env: PhantomData<&'env mut &'env ()>, } -impl<'scope, 'env> TaskScope<'scope, 'env> { +impl<'scope> TaskScope<'scope, '_> { // Not Drop because that extends lifetimes. fn destroy(&self) { // Make sure all tasks are cancelled. @@ -280,6 +324,7 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { } } + #[track_caller] pub fn spawn_task( &self, priority: TaskPriority, @@ -288,6 +333,7 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { where ::Output: Send + 'static, { + let spawn_location = Location::caller(); self.clear_completed_tasks(); let mut runnable = None; @@ -301,6 +347,8 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { fut, on_wake, TaskMetadata { + spawn_location, + ns_spent_blocked: AtomicU64::new(0), priority, freshly_spawned: AtomicBool::new(true), scoped: Some(ScopedTaskMetadata { @@ -345,16 +393,20 @@ where } } +#[track_caller] pub fn spawn(priority: TaskPriority, fut: F) -> JoinHandle where ::Output: Send + 'static, { + let spawn_location = Location::caller(); let executor = Executor::global(); let on_wake = move |task| executor.schedule_task(task); let (runnable, join_handle) = task::spawn( fut, on_wake, TaskMetadata { + spawn_location, + ns_spent_blocked: AtomicU64::new(0), priority, freshly_spawned: AtomicBool::new(true), scoped: None, diff --git a/crates/polars-stream/src/async_executor/park_group.rs b/crates/polars-stream/src/async_executor/park_group.rs index d9da30ce7f3e..d72a474da1e4 100644 --- a/crates/polars-stream/src/async_executor/park_group.rs +++ b/crates/polars-stream/src/async_executor/park_group.rs @@ -149,7 +149,7 @@ pub struct ParkAttempt<'a> { worker: &'a mut ParkGroupWorker, } -impl<'a> ParkAttempt<'a> { +impl ParkAttempt<'_> { /// Actually park this worker. /// /// If there were calls to unpark between calling prepare_park() and park(), diff --git a/crates/polars-stream/src/async_executor/task.rs b/crates/polars-stream/src/async_executor/task.rs index 9991377eb718..1383da2edde8 100644 --- a/crates/polars-stream/src/async_executor/task.rs +++ b/crates/polars-stream/src/async_executor/task.rs @@ -118,9 +118,9 @@ where } } -impl<'a, F, S, M> Wake for Task +impl Wake for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -143,9 +143,9 @@ pub trait DynTask: Send + Sync { fn schedule(self: Arc); } -impl<'a, F, S, M> DynTask for Task +impl DynTask for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -202,9 +202,9 @@ trait Joinable: Send + Sync { fn poll_join(&self, ctx: &mut Context<'_>) -> Poll; } -impl<'a, F, S, M> Joinable for Task +impl Joinable for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Fn(Runnable) + Send + Sync + Copy + 'static, M: Send + Sync + 'static, @@ -233,9 +233,9 @@ trait Cancellable: Send + Sync { fn cancel(&self); } -impl<'a, F, S, M> Cancellable for Task +impl Cancellable for Task where - F: Future + Send + 'a, + F: Future + Send, F::Output: Send + 'static, S: Send + Sync + 'static, M: Send + Sync + 'static, diff --git a/crates/polars-stream/src/async_primitives/connector.rs b/crates/polars-stream/src/async_primitives/connector.rs index 94999fff4e7a..8b53193b95f1 100644 --- a/crates/polars-stream/src/async_primitives/connector.rs +++ b/crates/polars-stream/src/async_primitives/connector.rs @@ -217,7 +217,7 @@ pin_project! { } } -unsafe impl<'a, T: Send> Send for SendFuture<'a, T> {} +unsafe impl Send for SendFuture<'_, T> {} impl Sender { /// Returns a future that when awaited will send the value to the [`Receiver`]. @@ -255,7 +255,7 @@ pin_project! { } } -unsafe impl<'a, T: Send> Send for RecvFuture<'a, T> {} +unsafe impl Send for RecvFuture<'_, T> {} impl Receiver { /// Returns a future that when awaited will return `Ok(value)` once the diff --git a/crates/polars-stream/src/async_primitives/task_parker.rs b/crates/polars-stream/src/async_primitives/task_parker.rs index 9e48b79e468b..d6cde679980b 100644 --- a/crates/polars-stream/src/async_primitives/task_parker.rs +++ b/crates/polars-stream/src/async_primitives/task_parker.rs @@ -43,7 +43,7 @@ pub struct TaskParkFuture<'a> { parker: &'a TaskParker, } -impl<'a> Future for TaskParkFuture<'a> { +impl Future for TaskParkFuture<'_> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/crates/polars-stream/src/async_primitives/wait_group.rs b/crates/polars-stream/src/async_primitives/wait_group.rs index 716363528505..e08f556d3b95 100644 --- a/crates/polars-stream/src/async_primitives/wait_group.rs +++ b/crates/polars-stream/src/async_primitives/wait_group.rs @@ -62,7 +62,7 @@ impl Future for WaitGroupFuture<'_> { } } -impl<'a> Drop for WaitGroupFuture<'a> { +impl Drop for WaitGroupFuture<'_> { fn drop(&mut self) { self.inner.is_waiting.store(false, Ordering::Relaxed); } diff --git a/crates/polars-stream/src/execute.rs b/crates/polars-stream/src/execute.rs index b199c0044e92..2d68cae2c90e 100644 --- a/crates/polars-stream/src/execute.rs +++ b/crates/polars-stream/src/execute.rs @@ -220,12 +220,21 @@ fn run_subgraph( } // Wait until all tasks are done. - polars_io::pl_async::get_runtime().block_on(async move { + // Only now do we turn on/off wait statistics tracking to reduce noise + // from task startup. + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + async_executor::track_task_wait_statistics(true); + } + let ret = polars_io::pl_async::get_runtime().block_on(async move { for handle in join_handles { handle.await?; } PolarsResult::Ok(()) - }) + }); + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + async_executor::track_task_wait_statistics(false); + } + ret })?; Ok(()) diff --git a/crates/polars-stream/src/nodes/filter.rs b/crates/polars-stream/src/nodes/filter.rs index 9f0b0301ef91..f89a53adbf23 100644 --- a/crates/polars-stream/src/nodes/filter.rs +++ b/crates/polars-stream/src/nodes/filter.rs @@ -27,14 +27,14 @@ impl ComputeNode for FilterNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/group_by.rs b/crates/polars-stream/src/nodes/group_by.rs new file mode 100644 index 000000000000..c534924d1433 --- /dev/null +++ b/crates/polars-stream/src/nodes/group_by.rs @@ -0,0 +1,290 @@ +use std::mem::ManuallyDrop; +use std::sync::Arc; + +use polars_core::prelude::IntoColumn; +use polars_core::schema::Schema; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_expr::groups::Grouper; +use polars_expr::reduce::GroupedReduction; +use polars_utils::itertools::Itertools; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use super::compute_node_prelude::*; +use crate::async_primitives::connector::Receiver; +use crate::expression::StreamExpr; +use crate::nodes::in_memory_source::InMemorySourceNode; + +struct LocalGroupBySinkState { + grouper: Box, + grouped_reductions: Vec>, +} + +struct GroupBySinkState { + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouper: Box, + grouped_reductions: Vec>, + local: Vec, +} + +impl GroupBySinkState { + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + receivers: Vec>, + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(receivers.len() >= self.local.len()); + self.local + .resize_with(receivers.len(), || LocalGroupBySinkState { + grouper: self.grouper.new_empty(), + grouped_reductions: self + .grouped_reductions + .iter() + .map(|r| r.new_empty()) + .collect(), + }); + for (mut recv, local) in receivers.into_iter().zip(&mut self.local) { + let key_selectors = &self.key_selectors; + let grouped_reduction_selectors = &self.grouped_reduction_selectors; + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + let mut group_idxs = Vec::new(); + while let Ok(morsel) = recv.recv().await { + // Compute group indices from key. + let df = morsel.into_df(); + let mut key_columns = Vec::new(); + for selector in key_selectors { + let s = selector.evaluate(&df, state).await?; + key_columns.push(s.into_column()); + } + let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; + local.grouper.insert_keys(&keys, &mut group_idxs); + + // Update reductions. + for (selector, reduction) in grouped_reduction_selectors + .iter() + .zip(&mut local.grouped_reductions) + { + unsafe { + // SAFETY: we resize the reduction to the number of groups beforehand. + reduction.resize(local.grouper.num_groups()); + reduction.update_groups( + &selector.evaluate(&df, state).await?, + &group_idxs, + )?; + } + } + } + Ok(()) + })); + } + } + + fn combine_locals( + output_schema: &Schema, + mut locals: Vec, + ) -> PolarsResult { + let mut group_idxs = Vec::new(); + let mut combined = locals.pop().unwrap(); + for local in locals { + combined.grouper.combine(&*local.grouper, &mut group_idxs); + for (l, r) in combined + .grouped_reductions + .iter_mut() + .zip(&local.grouped_reductions) + { + unsafe { + l.resize(combined.grouper.num_groups()); + l.combine(&**r, &group_idxs)?; + } + } + } + let mut out = combined.grouper.get_keys_in_group_order(); + let out_names = output_schema.iter_names().skip(out.width()); + for (mut r, name) in combined.grouped_reductions.into_iter().zip(out_names) { + unsafe { + out.with_column_unchecked(r.finalize()?.with_name(name.clone()).into_column()); + } + } + Ok(out) + } + + fn into_source_parallel(self, output_schema: &Schema) -> PolarsResult { + let num_partitions = self.local.len(); + let seed = 0xdeadbeef; + let partitioned_locals: Vec<_> = self + .local + .into_par_iter() + .with_max_len(1) + .map(|local| { + let mut partition_idxs = Vec::new(); + let p_groupers = local + .grouper + .partition(seed, num_partitions, &mut partition_idxs); + let partition_sizes = p_groupers.iter().map(|g| g.num_groups()).collect_vec(); + let grouped_reductions_p = local + .grouped_reductions + .into_iter() + .map(|r| unsafe { r.partition(&partition_sizes, &partition_idxs) }) + .collect_vec(); + (p_groupers, grouped_reductions_p) + }) + .collect(); + + let frames = unsafe { + let mut partitioned_locals = ManuallyDrop::new(partitioned_locals); + let partitioned_locals_ptr = SyncPtr::new(partitioned_locals.as_mut_ptr()); + (0..num_partitions) + .into_par_iter() + .with_max_len(1) + .map(|p| { + let locals_in_p = (0..num_partitions) + .map(|l| { + let partitioned_local = &*partitioned_locals_ptr.get().add(l); + let (p_groupers, grouped_reductions_p) = partitioned_local; + LocalGroupBySinkState { + grouper: p_groupers.as_ptr().add(p).read(), + grouped_reductions: grouped_reductions_p + .iter() + .map(|r| r.as_ptr().add(p).read()) + .collect(), + } + }) + .collect(); + Self::combine_locals(output_schema, locals_in_p) + }) + .collect::>>() + }; + + let df = accumulate_dataframes_vertical_unchecked(frames?); + let mut source_node = InMemorySourceNode::new(Arc::new(df)); + source_node.initialize(num_partitions); + Ok(source_node) + } + + fn into_source(self, output_schema: &Schema) -> PolarsResult { + if std::env::var("POLARS_PARALLEL_GROUPBY_FINALIZE").as_deref() == Ok("1") { + self.into_source_parallel(output_schema) + } else { + let num_pipelines = self.local.len(); + let df = Self::combine_locals(output_schema, self.local); + let mut source_node = InMemorySourceNode::new(Arc::new(df?)); + source_node.initialize(num_pipelines); + Ok(source_node) + } + } +} + +enum GroupByState { + Sink(GroupBySinkState), + Source(InMemorySourceNode), + Done, +} + +pub struct GroupByNode { + state: GroupByState, + output_schema: Arc, +} + +impl GroupByNode { + pub fn new( + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouped_reductions: Vec>, + grouper: Box, + output_schema: Arc, + ) -> Self { + Self { + state: GroupByState::Sink(GroupBySinkState { + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + local: Vec::new(), + }), + output_schema, + } + } +} + +impl ComputeNode for GroupByNode { + fn name(&self) -> &str { + "group_by" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 1 && send.len() == 1); + + // State transitions. + match &mut self.state { + // If the output doesn't want any more data, transition to being done. + _ if send[0] == PortState::Done => { + self.state = GroupByState::Done; + }, + // Input is done, transition to being a source. + GroupByState::Sink(_) if matches!(recv[0], PortState::Done) => { + let GroupByState::Sink(sink) = + core::mem::replace(&mut self.state, GroupByState::Done) + else { + unreachable!() + }; + self.state = GroupByState::Source(sink.into_source(&self.output_schema)?); + }, + // Defer to source node implementation. + GroupByState::Source(src) => { + src.update_state(&mut [], send)?; + if send[0] == PortState::Done { + self.state = GroupByState::Done; + } + }, + // Nothing to change. + GroupByState::Done | GroupByState::Sink(_) => {}, + } + + // Communicate our state. + match &self.state { + GroupByState::Sink { .. } => { + send[0] = PortState::Blocked; + recv[0] = PortState::Ready; + }, + GroupByState::Source(..) => { + recv[0] = PortState::Done; + send[0] = PortState::Ready; + }, + GroupByState::Done => { + recv[0] = PortState::Done; + send[0] = PortState::Done; + }, + } + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(send_ports.len() == 1 && recv_ports.len() == 1); + match &mut self.state { + GroupByState::Sink(sink) => { + assert!(send_ports[0].is_none()); + sink.spawn( + scope, + recv_ports[0].take().unwrap().parallel(), + state, + join_handles, + ) + }, + GroupByState::Source(source) => { + assert!(recv_ports[0].is_none()); + source.spawn(scope, &mut [], send_ports, state, join_handles); + }, + GroupByState::Done => unreachable!(), + } + } +} diff --git a/crates/polars-stream/src/nodes/in_memory_map.rs b/crates/polars-stream/src/nodes/in_memory_map.rs index 3a8bff496a18..27af6be9aa87 100644 --- a/crates/polars-stream/src/nodes/in_memory_map.rs +++ b/crates/polars-stream/src/nodes/in_memory_map.rs @@ -86,16 +86,16 @@ impl ComputeNode for InMemoryMapNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { match self { Self::Sink { sink_node, .. } => { - sink_node.spawn(scope, recv, &mut [], state, join_handles) + sink_node.spawn(scope, recv_ports, &mut [], state, join_handles) }, - Self::Source(source) => source.spawn(scope, &mut [], send, state, join_handles), + Self::Source(source) => source.spawn(scope, &mut [], send_ports, state, join_handles), Self::Done => unreachable!(), } } diff --git a/crates/polars-stream/src/nodes/in_memory_sink.rs b/crates/polars-stream/src/nodes/in_memory_sink.rs index afd6ccfd95cc..58d2f9e8ffe6 100644 --- a/crates/polars-stream/src/nodes/in_memory_sink.rs +++ b/crates/polars-stream/src/nodes/in_memory_sink.rs @@ -45,13 +45,13 @@ impl ComputeNode for InMemorySinkNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.is_empty()); - let receivers = recv[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.is_empty()); + let receivers = recv_ports[0].take().unwrap().parallel(); for mut recv in receivers { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs index 5ab6b0f75d50..c8dfec9d0032 100644 --- a/crates/polars-stream/src/nodes/in_memory_source.rs +++ b/crates/polars-stream/src/nodes/in_memory_source.rs @@ -60,13 +60,13 @@ impl ComputeNode for InMemorySourceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.is_empty() && send.len() == 1); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.is_empty() && send_ports.len() == 1); + let senders = send_ports[0].take().unwrap().parallel(); let source = self.source.as_ref().unwrap(); // TODO: can this just be serial, using the work distributor? diff --git a/crates/polars-stream/src/nodes/input_independent_select.rs b/crates/polars-stream/src/nodes/input_independent_select.rs index f1a9113d05d4..9df4c1ab5281 100644 --- a/crates/polars-stream/src/nodes/input_independent_select.rs +++ b/crates/polars-stream/src/nodes/input_independent_select.rs @@ -36,13 +36,13 @@ impl ComputeNode for InputIndependentSelectNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.is_empty() && send.len() == 1); - let mut sender = send[0].take().unwrap().serial(); + assert!(recv_ports.is_empty() && send_ports.len() == 1); + let mut sender = send_ports[0].take().unwrap().serial(); join_handles.push(scope.spawn_task(TaskPriority::Low, async move { let empty_df = DataFrame::empty(); diff --git a/crates/polars-stream/src/nodes/io_sinks/ipc.rs b/crates/polars-stream/src/nodes/io_sinks/ipc.rs new file mode 100644 index 000000000000..5587221d894c --- /dev/null +++ b/crates/polars-stream/src/nodes/io_sinks/ipc.rs @@ -0,0 +1,81 @@ +use std::fs::{File, OpenOptions}; +use std::path::Path; + +use polars_core::schema::SchemaRef; +use polars_error::PolarsResult; +use polars_expr::state::ExecutionState; +use polars_io::ipc::{BatchedWriter, IpcWriter, IpcWriterOptions}; +use polars_io::SerWriter; + +use crate::nodes::{ComputeNode, JoinHandle, PortState, TaskPriority, TaskScope}; +use crate::pipe::{RecvPort, SendPort}; + +pub struct IpcSinkNode { + is_finished: bool, + writer: BatchedWriter, +} + +impl IpcSinkNode { + pub fn new( + input_schema: SchemaRef, + path: &Path, + write_options: &IpcWriterOptions, + ) -> PolarsResult { + let file = OpenOptions::new().write(true).open(path)?; + let writer = IpcWriter::new(file) + .with_compression(write_options.compression) + .batched(&input_schema)?; + + Ok(Self { + is_finished: false, + writer, + }) + } +} + +impl ComputeNode for IpcSinkNode { + fn name(&self) -> &str { + "ipc_sink" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(send.is_empty()); + assert!(recv.len() == 1); + + if recv[0] == PortState::Done && !self.is_finished { + // @NOTE: This function can be called afterwards multiple times. So make sure to only + // finish the writer once. + self.is_finished = true; + self.writer.finish()?; + } + + // We are always ready to receive, unless the sender is done, then we're + // also done. + if recv[0] != PortState::Done { + recv[0] = PortState::Ready; + } + + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + _state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(send_ports.is_empty()); + assert!(recv_ports.len() == 1); + let mut receiver = recv_ports[0].take().unwrap().serial(); + + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + while let Ok(morsel) = receiver.recv().await { + self.writer.write_batch(&morsel.into_df())?; + } + + Ok(()) + })); + } +} diff --git a/crates/polars-stream/src/nodes/io_sinks/mod.rs b/crates/polars-stream/src/nodes/io_sinks/mod.rs new file mode 100644 index 000000000000..ce14ad3b0f7a --- /dev/null +++ b/crates/polars-stream/src/nodes/io_sinks/mod.rs @@ -0,0 +1 @@ +pub mod ipc; diff --git a/crates/polars-stream/src/nodes/map.rs b/crates/polars-stream/src/nodes/map.rs index 007dfa921672..c1994d1e4a9a 100644 --- a/crates/polars-stream/src/nodes/map.rs +++ b/crates/polars-stream/src/nodes/map.rs @@ -29,14 +29,14 @@ impl ComputeNode for MapNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index 82ad0f8293e9..559e4717c4e9 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -1,8 +1,10 @@ pub mod filter; +pub mod group_by; pub mod in_memory_map; pub mod in_memory_sink; pub mod in_memory_source; pub mod input_independent_select; +pub mod io_sinks; pub mod map; pub mod multiplexer; pub mod ordered_union; @@ -61,8 +63,8 @@ pub trait ComputeNode: Send { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ); diff --git a/crates/polars-stream/src/nodes/multiplexer.rs b/crates/polars-stream/src/nodes/multiplexer.rs index 65f2e752d28d..d4e4ac62cf01 100644 --- a/crates/polars-stream/src/nodes/multiplexer.rs +++ b/crates/polars-stream/src/nodes/multiplexer.rs @@ -92,13 +92,13 @@ impl ComputeNode for MultiplexerNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && !send.is_empty()); - assert!(self.buffers.len() == send.len()); + assert!(recv_ports.len() == 1 && !send_ports.is_empty()); + assert!(self.buffers.len() == send_ports.len()); enum Listener<'a> { Active(UnboundedSender), @@ -114,7 +114,7 @@ impl ComputeNode for MultiplexerNode { .enumerate() .map(|(port_idx, buffer)| { if let BufferedStream::Open(buf) = buffer { - if send[port_idx].is_some() { + if send_ports[port_idx].is_some() { // TODO: replace with a bounded channel and store data // out-of-core beyond a certain size. let (rx, tx) = unbounded_channel(); @@ -129,7 +129,7 @@ impl ComputeNode for MultiplexerNode { .unzip(); // TODO: parallel multiplexing. - if let Some(mut receiver) = recv[0].take().map(|r| r.serial()) { + if let Some(mut receiver) = recv_ports[0].take().map(|r| r.serial()) { let buffered_source_token = buffered_source_token.clone(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { loop { @@ -176,7 +176,7 @@ impl ComputeNode for MultiplexerNode { })); } - for (send_port, opt_buf_recv) in send.iter_mut().zip(buf_receivers) { + for (send_port, opt_buf_recv) in send_ports.iter_mut().zip(buf_receivers) { if let Some((buf, mut rx)) = opt_buf_recv { let mut sender = send_port.take().unwrap().serial(); diff --git a/crates/polars-stream/src/nodes/ordered_union.rs b/crates/polars-stream/src/nodes/ordered_union.rs index 3c72d9cc6e15..cb65175292e2 100644 --- a/crates/polars-stream/src/nodes/ordered_union.rs +++ b/crates/polars-stream/src/nodes/ordered_union.rs @@ -52,15 +52,15 @@ impl ComputeNode for OrderedUnionNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - let ready_count = recv.iter().filter(|r| r.is_some()).count(); - assert!(ready_count == 1 && send.len() == 1); - let receivers = recv[self.cur_input_idx].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + let ready_count = recv_ports.iter().filter(|r| r.is_some()).count(); + assert!(ready_count == 1 && send_ports.len() == 1); + let receivers = recv_ports[self.cur_input_idx].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); let mut inner_handles = Vec::new(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { diff --git a/crates/polars-stream/src/nodes/parquet_source/init.rs b/crates/polars-stream/src/nodes/parquet_source/init.rs index 3187bbe797e4..a722186ff497 100644 --- a/crates/polars-stream/src/nodes/parquet_source/init.rs +++ b/crates/polars-stream/src/nodes/parquet_source/init.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::future::Future; use std::sync::Arc; @@ -14,7 +15,6 @@ use super::{AsyncTaskData, ParquetSourceNode}; use crate::async_executor; use crate::async_primitives::connector::connector; use crate::async_primitives::wait_group::{WaitGroup, WaitToken}; -use crate::morsel::get_ideal_morsel_size; use crate::nodes::{MorselSeq, TaskPriority}; use crate::utils::task_handles_ext; @@ -118,6 +118,8 @@ impl ParquetSourceNode { let row_group_decoder = self.init_row_group_decoder(); let row_group_decoder = Arc::new(row_group_decoder); + let ideal_morsel_size = self.config.ideal_morsel_size; + // Distributes morsels across pipelines. This does not perform any CPU or I/O bound work - // it is purely a dispatch loop. let raw_morsel_distributor_task_handle = io_runtime.spawn(async move { @@ -191,25 +193,31 @@ impl ParquetSourceNode { ); let morsel_seq_ref = &mut MorselSeq::default(); - let mut dfs = vec![].into_iter(); + let mut dfs = VecDeque::with_capacity(1); 'main: loop { let Some(mut indexed_wait_group) = wait_groups.next().await else { break; }; - if dfs.len() == 0 { + while dfs.is_empty() { let Some(v) = df_stream.next().await else { - break; + break 'main; }; - let v = v?; - assert!(!v.is_empty()); + let df = v?; + + if df.is_empty() { + continue; + } - dfs = v.into_iter(); + let (iter, n) = split_to_morsels(&df, ideal_morsel_size); + + dfs.reserve(n); + dfs.extend(iter); } - let mut df = dfs.next().unwrap(); + let mut df = dfs.pop_front().unwrap(); let morsel_seq = *morsel_seq_ref; *morsel_seq_ref = morsel_seq.successor(); @@ -270,7 +278,6 @@ impl ParquetSourceNode { let projected_arrow_schema = self.projected_arrow_schema.clone().unwrap(); let row_index = self.file_options.row_index.clone(); let physical_predicate = self.physical_predicate.clone(); - let ideal_morsel_size = get_ideal_morsel_size(); let min_values_per_thread = self.config.min_values_per_thread; let mut use_prefiltered = physical_predicate.is_some() @@ -348,7 +355,6 @@ impl ParquetSourceNode { predicate_arrow_field_indices, non_predicate_arrow_field_indices, predicate_arrow_field_mask, - ideal_morsel_size, min_values_per_thread, } } @@ -402,6 +408,28 @@ fn filtered_range(exclude: &[usize], len: usize) -> Vec { .collect() } +/// Note: The 2nd return is an upper bound on the number of morsels rather than an exact count. +fn split_to_morsels( + df: &DataFrame, + ideal_morsel_size: usize, +) -> (impl Iterator + '_, usize) { + let n_morsels = if df.height() > 3 * ideal_morsel_size / 2 { + // num_rows > (1.5 * ideal_morsel_size) + (df.height() / ideal_morsel_size).max(2) + } else { + 1 + }; + + let rows_per_morsel = 1 + df.height() / n_morsels; + + ( + (0..i64::try_from(df.height()).unwrap()) + .step_by(rows_per_morsel) + .map(move |offset| df.slice(offset, rows_per_morsel)), + n_morsels, + ) +} + mod tests { #[test] diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index 746c517ce744..e3377036b908 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -141,7 +141,7 @@ impl ParquetSourceNode { } if allow_missing_columns { - ensure_matching_dtypes_if_found(&first_schema, &schema)?; + ensure_matching_dtypes_if_found(projected_arrow_schema.as_ref(), &schema)?; } else { ensure_schema_has_projected_fields( &schema, diff --git a/crates/polars-stream/src/nodes/parquet_source/mod.rs b/crates/polars-stream/src/nodes/parquet_source/mod.rs index 44fc4e1f1239..a5efa4cb3b89 100644 --- a/crates/polars-stream/src/nodes/parquet_source/mod.rs +++ b/crates/polars-stream/src/nodes/parquet_source/mod.rs @@ -18,7 +18,7 @@ use polars_plan::prelude::FileScanOptions; use super::compute_node_prelude::*; use super::{MorselSeq, TaskPriority}; use crate::async_primitives::wait_group::WaitToken; -use crate::morsel::SourceToken; +use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::utils::task_handles_ext; mod init; @@ -70,6 +70,7 @@ struct Config { /// Minimum number of values for a parallel spawned task to process to amortize /// parallelism overhead. min_values_per_thread: usize, + ideal_morsel_size: usize, } #[allow(clippy::too_many_arguments)] @@ -110,6 +111,7 @@ impl ParquetSourceNode { metadata_decode_ahead_size: 0, row_group_prefetch_size: 0, min_values_per_thread: 0, + ideal_morsel_size: 0, }, verbose, physical_predicate: None, @@ -142,6 +144,7 @@ impl ComputeNode for ParquetSourceNode { let min_values_per_thread = std::env::var("POLARS_MIN_VALUES_PER_THREAD") .map(|x| x.parse::().expect("integer").max(1)) .unwrap_or(16_777_216); + let ideal_morsel_size = get_ideal_morsel_size(); Config { num_pipelines, @@ -149,6 +152,7 @@ impl ComputeNode for ParquetSourceNode { metadata_decode_ahead_size, row_group_prefetch_size, min_values_per_thread, + ideal_morsel_size, } }; @@ -198,18 +202,18 @@ impl ComputeNode for ParquetSourceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { use std::sync::atomic::Ordering; - assert!(recv.is_empty()); - assert_eq!(send.len(), 1); + assert!(recv_ports.is_empty()); + assert_eq!(send_ports.len(), 1); assert!(!self.is_finished.load(Ordering::Relaxed)); - let morsel_senders = send[0].take().unwrap().parallel(); + let morsel_senders = send_ports[0].take().unwrap().parallel(); let mut async_task_data_guard = self.async_task_data.try_lock().unwrap(); let (raw_morsel_receivers, _) = async_task_data_guard.as_mut().unwrap(); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs index dfa4b11e3b02..52d3003de7ea 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs @@ -2,11 +2,12 @@ use std::future::Future; use std::sync::Arc; use polars_core::prelude::{ArrowSchema, InitHashMaps, PlHashMap}; +use polars_core::series::IsSorted; use polars_core::utils::operation_exceeded_idxsize_msg; use polars_error::{polars_err, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; -use polars_io::prelude::FileMetadata; use polars_io::prelude::_internal::read_this_row_group; +use polars_io::prelude::{create_sorting_map, FileMetadata}; use polars_io::utils::byte_source::{ByteSource, DynByteSource}; use polars_io::utils::slice::SplitSlicePosition; use polars_parquet::read::RowGroupMetadata; @@ -27,6 +28,7 @@ pub(super) struct RowGroupData { pub(super) slice: Option<(usize, usize)>, pub(super) file_max_row_group_height: usize, pub(super) row_group_metadata: RowGroupMetadata, + pub(super) sorting_map: PlHashMap, pub(super) shared_file_state: Arc>, } @@ -86,6 +88,7 @@ impl RowGroupDataFetcher { let current_row_group_idx = self.current_row_group_idx; let num_rows = row_group_metadata.num_rows(); + let sorting_map = create_sorting_map(&row_group_metadata); self.current_row_offset = current_row_offset.saturating_add(num_rows); self.current_row_group_idx += 1; @@ -246,6 +249,7 @@ impl RowGroupDataFetcher { slice, file_max_row_group_height: current_max_row_group_height, row_group_metadata, + sorting_map, shared_file_state: current_shared_file_state.clone(), }) }); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index 119345295686..d31f1e51f71e 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -11,6 +11,7 @@ use polars_error::{polars_bail, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; use polars_io::prelude::_internal::calc_prefilter_cost; pub use polars_io::prelude::_internal::PrefilterMaskSetting; +use polars_io::prelude::try_set_sorted_flag; use polars_io::RowIndex; use polars_plan::plans::hive::HivePartitions; use polars_plan::plans::ScanSources; @@ -37,7 +38,6 @@ pub(super) struct RowGroupDecoder { pub(super) non_predicate_arrow_field_indices: Vec, /// The nth bit is set to `true` if the field at that index is used in the predicate. pub(super) predicate_arrow_field_mask: Vec, - pub(super) ideal_morsel_size: usize, pub(super) min_values_per_thread: usize, } @@ -45,7 +45,7 @@ impl RowGroupDecoder { pub(super) async fn row_group_data_to_df( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { if self.use_prefiltered.is_some() { self.row_group_data_to_df_prefiltered(row_group_data).await } else { @@ -56,7 +56,7 @@ impl RowGroupDecoder { async fn row_group_data_to_df_impl( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { let row_group_data = Arc::new(row_group_data); let out_width = self.row_index.is_some() as usize @@ -130,7 +130,7 @@ impl RowGroupDecoder { assert_eq!(df.width(), out_width); // `out_width` should have been calculated correctly - Ok(self.split_to_morsels(df)) + Ok(df) } async fn shared_file_state_init_func(&self, row_group_data: &RowGroupData) -> SharedFileState { @@ -307,26 +307,6 @@ impl RowGroupDecoder { Ok(()) } - - fn split_to_morsels(&self, df: DataFrame) -> Vec { - let n_morsels = if df.height() > 3 * self.ideal_morsel_size / 2 { - // num_rows > (1.5 * ideal_morsel_size) - (df.height() / self.ideal_morsel_size).max(2) - } else { - 1 - } as u64; - - if n_morsels == 1 { - return vec![df]; - } - - let rows_per_morsel = 1 + df.height() / n_morsels as usize; - - (0..i64::try_from(df.height()).unwrap()) - .step_by(rows_per_morsel) - .map(|offset| df.slice(offset, rows_per_morsel)) - .collect::>() - } } fn decode_column( @@ -367,11 +347,20 @@ fn decode_column( assert_eq!(array.len(), expected_num_rows); - let series = Series::try_from((arrow_field, array))?; + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } // TODO: Also load in the metadata. - Ok(series.into()) + Ok(series.into_column()) } /// # Safety @@ -468,7 +457,7 @@ impl RowGroupDecoder { async fn row_group_data_to_df_prefiltered( &self, row_group_data: RowGroupData, - ) -> PolarsResult> { + ) -> PolarsResult { debug_assert!(row_group_data.slice.is_none()); // Invariant of the optimizer. assert!(self.predicate_arrow_field_indices.len() <= self.projected_arrow_schema.len()); @@ -604,7 +593,7 @@ impl RowGroupDecoder { assert_eq!(dead_rem.len(), 0); let df = unsafe { DataFrame::new_no_checks(expected_num_rows, out_columns) }; - Ok(self.split_to_morsels(df)) + Ok(df) } } @@ -652,17 +641,26 @@ fn decode_column_prefiltered( deserialize_filter, )?; - let column = Series::try_from((arrow_field, array))?.into_column(); + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } - let column = if !prefilter { - column.filter(mask)? + let series = if !prefilter { + series.filter(mask)? } else { - column + series }; - assert_eq!(column.len(), expected_num_rows); + assert_eq!(series.len(), expected_num_rows); - Ok(column) + Ok(series.into_column()) } mod tests { diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs index ded581a4cf38..565854e97b81 100644 --- a/crates/polars-stream/src/nodes/reduce.rs +++ b/crates/polars-stream/src/nodes/reduce.rs @@ -162,24 +162,24 @@ impl ComputeNode for ReduceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(send.len() == 1 && recv.len() == 1); + assert!(send_ports.len() == 1 && recv_ports.len() == 1); match &mut self.state { ReduceState::Sink { selectors, reductions, } => { - assert!(send[0].is_none()); - let recv_port = recv[0].take().unwrap(); + assert!(send_ports[0].is_none()); + let recv_port = recv_ports[0].take().unwrap(); Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles) }, ReduceState::Source(df) => { - assert!(recv[0].is_none()); - let send_port = send[0].take().unwrap(); + assert!(recv_ports[0].is_none()); + let send_port = send_ports[0].take().unwrap(); Self::spawn_source(df, scope, send_port, join_handles) }, ReduceState::Done => unreachable!(), diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs index 3b060e78e654..bf12904ff12c 100644 --- a/crates/polars-stream/src/nodes/select.rs +++ b/crates/polars-stream/src/nodes/select.rs @@ -36,14 +36,14 @@ impl ComputeNode for SelectNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/simple_projection.rs b/crates/polars-stream/src/nodes/simple_projection.rs index 95f002df2889..00cd8ed55ad0 100644 --- a/crates/polars-stream/src/nodes/simple_projection.rs +++ b/crates/polars-stream/src/nodes/simple_projection.rs @@ -33,14 +33,14 @@ impl ComputeNode for SimpleProjectionNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let receivers = recv[0].take().unwrap().parallel(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let receivers = recv_ports[0].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); for (mut recv, mut send) in receivers.into_iter().zip(senders) { let slf = &*self; diff --git a/crates/polars-stream/src/nodes/streaming_slice.rs b/crates/polars-stream/src/nodes/streaming_slice.rs index 950b39331588..5d9f5a003340 100644 --- a/crates/polars-stream/src/nodes/streaming_slice.rs +++ b/crates/polars-stream/src/nodes/streaming_slice.rs @@ -43,14 +43,14 @@ impl ComputeNode for StreamingSliceNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let mut recv = recv[0].take().unwrap().serial(); - let mut send = send[0].take().unwrap().serial(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let mut recv = recv_ports[0].take().unwrap().serial(); + let mut send = send_ports[0].take().unwrap().serial(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { let stop_offset = self.start_offset + self.length; diff --git a/crates/polars-stream/src/nodes/with_row_index.rs b/crates/polars-stream/src/nodes/with_row_index.rs index 942d23219fec..fe075120963d 100644 --- a/crates/polars-stream/src/nodes/with_row_index.rs +++ b/crates/polars-stream/src/nodes/with_row_index.rs @@ -35,14 +35,14 @@ impl ComputeNode for WithRowIndexNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(recv.len() == 1 && send.len() == 1); - let mut receiver = recv[0].take().unwrap().serial(); - let senders = send[0].take().unwrap().parallel(); + assert!(recv_ports.len() == 1 && send_ports.len() == 1); + let mut receiver = recv_ports[0].take().unwrap().serial(); + let senders = send_ports[0].take().unwrap().parallel(); let (mut distributor, distr_receivers) = distributor_channel(senders.len(), DEFAULT_DISTRIBUTOR_BUFFER_SIZE); diff --git a/crates/polars-stream/src/nodes/zip.rs b/crates/polars-stream/src/nodes/zip.rs index cd72a3567442..614c7b506128 100644 --- a/crates/polars-stream/src/nodes/zip.rs +++ b/crates/polars-stream/src/nodes/zip.rs @@ -205,20 +205,20 @@ impl ComputeNode for ZipNode { fn spawn<'env, 's>( &'env mut self, scope: &'s TaskScope<'s, 'env>, - recv: &mut [Option>], - send: &mut [Option>], + recv_ports: &mut [Option>], + send_ports: &mut [Option>], _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { - assert!(send.len() == 1); - assert!(!recv.is_empty()); - let mut sender = send[0].take().unwrap().serial(); + assert!(send_ports.len() == 1); + assert!(!recv_ports.is_empty()); + let mut sender = send_ports[0].take().unwrap().serial(); - let mut receivers = recv + let mut receivers = recv_ports .iter_mut() - .map(|r| { + .map(|recv_port| { // Add buffering to each receiver to reduce contention between input heads. - let mut serial_recv = r.take()?.serial(); + let mut serial_recv = recv_port.take()?.serial(); let (buf_send, buf_recv) = tokio::sync::mpsc::channel(DEFAULT_ZIP_HEAD_BUFFER_SIZE); join_handles.push(scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = serial_recv.recv().await { diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 21dd8e9dd634..ed0f08a0d48f 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -2,6 +2,7 @@ use std::fmt::Write; use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{AExpr, EscapeLabel, FileScan, ScanSourcesDisplay}; +use polars_plan::prelude::FileType; use polars_utils::arena::Arena; use polars_utils::itertools::Itertools; use slotmap::{Key, SecondaryMap, SlotMap}; @@ -95,6 +96,14 @@ fn visualize_plan_rec( from_ref(input), ), PhysNodeKind::InMemorySink { input } => ("in-memory-sink".to_string(), from_ref(input)), + PhysNodeKind::FileSink { + input, file_type, .. + } => match file_type { + FileType::Parquet(_) => ("parquet-sink".to_string(), from_ref(input)), + FileType::Ipc(_) => ("ipc-sink".to_string(), from_ref(input)), + FileType::Csv(_) => ("csv-sink".to_string(), from_ref(input)), + FileType::Json(_) => ("json-sink".to_string(), from_ref(input)), + }, PhysNodeKind::InMemoryMap { input, map: _ } => { ("in-memory-map".to_string(), from_ref(input)) }, @@ -183,6 +192,17 @@ fn visualize_plan_rec( (out, &[][..]) }, + PhysNodeKind::GroupBy { input, key, aggs } => { + let label = "group-by"; + ( + format!( + "{label}\\nkey:\\n{}\\naggs:\\n{}", + fmt_exprs(key, expr_arena), + fmt_exprs(aggs, expr_arena) + ), + from_ref(input), + ) + }, }; out.push(format!( diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 6b61a98979f2..3af80df16f9f 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -98,6 +98,7 @@ pub(crate) fn is_elementwise( match function { // Non-strict strptime must be done in-memory to ensure the format // is consistent across the entire dataframe. + #[cfg(feature = "strings")] FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)) => opts.strict, _ => { options.is_elementwise() @@ -573,7 +574,9 @@ fn lower_exprs_with_ctx( .. } | IRAggExpr::Sum(ref mut inner) - | IRAggExpr::Mean(ref mut inner) => { + | IRAggExpr::Mean(ref mut inner) + | IRAggExpr::Var(ref mut inner, _ /* ddof */) + | IRAggExpr::Std(ref mut inner, _ /* ddof */) => { let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[*inner], ctx)?; *inner = trans_exprs[0]; @@ -596,8 +599,6 @@ fn lower_exprs_with_ctx( | IRAggExpr::Implode(_) | IRAggExpr::Quantile { .. } | IRAggExpr::Count(_, _) - | IRAggExpr::Std(_, _) - | IRAggExpr::Var(_, _) | IRAggExpr::AggGroups(_) => { let out_name = unique_column_name(); fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone()))); @@ -664,7 +665,7 @@ fn lower_exprs_with_ctx( /// Computes the schema that selecting the given expressions on the input schema /// would result in. -fn compute_output_schema( +pub fn compute_output_schema( input_schema: &Schema, exprs: &[ExprIR], expr_arena: &Arena, diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index f2b532ca1ca3..485bbf03a7fe 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -2,16 +2,37 @@ use std::sync::Arc; use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; use polars_core::schema::Schema; -use polars_error::PolarsResult; +use polars_error::{polars_ensure, PolarsResult}; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, FunctionIR, IR}; -use polars_plan::prelude::SinkType; +use polars_plan::plans::{AExpr, FunctionIR, IRAggExpr, IR}; +use polars_plan::prelude::{FileType, SinkType}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use slotmap::SlotMap; use super::{PhysNode, PhysNodeKey, PhysNodeKind}; -use crate::physical_plan::lower_expr::{is_elementwise, ExprCache}; +use crate::physical_plan::lower_expr::{build_select_node, is_elementwise, lower_exprs, ExprCache}; + +fn build_slice_node( + input: PhysNodeKey, + offset: i64, + length: usize, + phys_sm: &mut SlotMap, +) -> PhysNodeKey { + if offset >= 0 { + let offset = offset as usize; + phys_sm.insert(PhysNode::new( + phys_sm[input].output_schema.clone(), + PhysNodeKind::StreamingSlice { + input, + offset, + length, + }, + )) + } else { + todo!() + } +} #[recursive::recursive] pub fn lower_ir( @@ -22,19 +43,26 @@ pub fn lower_ir( schema_cache: &mut PlHashMap>, expr_cache: &mut ExprCache, ) -> PolarsResult { - let ir_node = ir_arena.get(node); - let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); - let node_kind = match ir_node { - IR::SimpleProjection { input, columns } => { - let columns = columns.iter_names_cloned().collect::>(); - let phys_input = lower_ir( - *input, + // Helper macro to simplify recursive calls. + macro_rules! lower_ir { + ($input:expr) => { + lower_ir( + $input, ir_arena, expr_arena, phys_sm, schema_cache, expr_cache, - )?; + ) + }; + } + + let ir_node = ir_arena.get(node); + let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); + let node_kind = match ir_node { + IR::SimpleProjection { input, columns } => { + let columns = columns.iter_names_cloned().collect::>(); + let phys_input = lower_ir!(*input)?; PhysNodeKind::SimpleProjection { input: phys_input, columns, @@ -43,17 +71,8 @@ pub fn lower_ir( IR::Select { input, expr, .. } => { let selectors = expr.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + let phys_input = lower_ir!(*input)?; + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::HStack { input, exprs, .. } @@ -63,14 +82,7 @@ pub fn lower_ir( { // FIXME: constant literal columns should be broadcasted with hstack. let selectors = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; PhysNodeKind::Select { input: phys_input, selectors, @@ -84,14 +96,7 @@ pub fn lower_ir( // // FIXME: constant literal columns should be broadcasted with hstack. let exprs = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let input_schema = &phys_sm[phys_input].output_schema; let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len()); for name in input_schema.iter_names() { @@ -106,43 +111,19 @@ pub fn lower_ir( selectors.insert(expr.output_name().clone(), expr); } let selectors = selectors.into_values().collect_vec(); - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::Slice { input, offset, len } => { - if *offset >= 0 { - let offset = *offset as usize; - let length = *len as usize; - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - PhysNodeKind::StreamingSlice { - input: phys_input, - offset, - length, - } - } else { - todo!() - } + let offset = *offset; + let len = *len as usize; + let phys_input = lower_ir!(*input)?; + return Ok(build_slice_node(phys_input, offset, len, phys_sm)); }, IR::Filter { input, predicate } => { let predicate = predicate.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let cols_and_predicate = output_schema .iter_names() .cloned() @@ -154,7 +135,7 @@ pub fn lower_ir( }) .chain([predicate]) .collect_vec(); - let (trans_input, mut trans_cols_and_predicate) = super::lower_expr::lower_exprs( + let (trans_input, mut trans_cols_and_predicate) = lower_exprs( phys_input, &cols_and_predicate, expr_arena, @@ -170,7 +151,7 @@ pub fn lower_ir( let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); trans_cols_and_predicate.pop(); // Remove predicate. - return super::lower_expr::build_select_node( + return build_select_node( post_filter, &trans_cols_and_predicate, expr_arena, @@ -221,38 +202,40 @@ pub fn lower_ir( node_kind }, - IR::Sink { input, payload } => { - if *payload == SinkType::Memory { - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + IR::Sink { input, payload } => match payload { + SinkType::Memory => { + let phys_input = lower_ir!(*input)?; PhysNodeKind::InMemorySink { input: phys_input } - } else { - todo!() - } + }, + SinkType::File { path, file_type } => { + let path = path.clone(); + let file_type = file_type.clone(); + + match file_type { + FileType::Ipc(_) => { + let phys_input = lower_ir!(*input)?; + PhysNodeKind::FileSink { + path, + file_type, + input: phys_input, + } + }, + _ => todo!(), + } + }, + SinkType::Cloud { .. } => todo!(), }, IR::MapFunction { input, function } => { // MergeSorted uses a rechunk hack incompatible with the // streaming engine. + #[cfg(feature = "merge_sorted")] if let FunctionIR::MergeSorted { .. } = function { todo!() } let function = function.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; match function { FunctionIR::RowIndex { @@ -292,14 +275,7 @@ pub fn lower_ir( by_column: by_column.clone(), slice: *slice, sort_options: sort_options.clone(), - input: lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?, + input: lower_ir!(*input)?, }, IR::Union { inputs, options } => { @@ -310,16 +286,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::OrderedUnion { inputs } }, @@ -332,16 +299,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::Zip { inputs, @@ -377,7 +335,84 @@ pub fn lower_ir( IR::PythonScan { .. } => todo!(), IR::Reduce { .. } => todo!(), IR::Cache { .. } => todo!(), - IR::GroupBy { .. } => todo!(), + IR::GroupBy { + input, + keys, + aggs, + schema: _, + apply, + maintain_order, + options, + } => { + if apply.is_some() || *maintain_order { + todo!() + } + + #[cfg(feature = "dynamic_group_by")] + if options.dynamic.is_some() || options.rolling.is_some() { + todo!() + } + + let key = keys.clone(); + let mut aggs = aggs.clone(); + let options = options.clone(); + + polars_ensure!(!keys.is_empty(), ComputeError: "at least one key is required in a group_by operation"); + + // TODO: allow all aggregates. + let mut input_exprs = key.clone(); + for agg in &aggs { + match expr_arena.get(agg.node()) { + AExpr::Agg(expr) => match expr { + IRAggExpr::Min { input, .. } + | IRAggExpr::Max { input, .. } + | IRAggExpr::Mean(input) + | IRAggExpr::Sum(input) + | IRAggExpr::Var(input, ..) + | IRAggExpr::Std(input, ..) => { + if is_elementwise(*input, expr_arena, expr_cache) { + input_exprs.push(ExprIR::from_node(*input, expr_arena)); + } else { + todo!() + } + }, + _ => todo!(), + }, + AExpr::Len => input_exprs.push(key[0].clone()), // Hack, use the first key column for the length. + _ => todo!(), + } + } + + let phys_input = lower_ir!(*input)?; + let (trans_input, trans_exprs) = + lower_exprs(phys_input, &input_exprs, expr_arena, phys_sm, expr_cache)?; + let trans_key = trans_exprs[..key.len()].to_vec(); + let trans_aggs = aggs + .iter_mut() + .zip(trans_exprs.iter().skip(key.len())) + .map(|(agg, trans_expr)| { + let old_expr = expr_arena.get(agg.node()).clone(); + let new_expr = old_expr.replace_inputs(&[trans_expr.node()]); + ExprIR::new(expr_arena.add(new_expr), agg.output_name_inner().clone()) + }) + .collect(); + + let mut node = phys_sm.insert(PhysNode::new( + output_schema, + PhysNodeKind::GroupBy { + input: trans_input, + key: trans_key, + aggs: trans_aggs, + }, + )); + + // TODO: actually limit number of groups instead of computing full + // result and then slicing. + if let Some((offset, len)) = options.slice { + node = build_slice_node(node, offset, len, phys_sm); + } + return Ok(node); + }, IR::Join { .. } => todo!(), IR::Distinct { .. } => todo!(), IR::ExtContext { .. } => todo!(), diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index eddbc87bda99..3b4643100249 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -1,3 +1,4 @@ +use std::path::PathBuf; use std::sync::Arc; use polars_core::frame::DataFrame; @@ -14,7 +15,7 @@ mod lower_ir; mod to_graph; pub use fmt::visualize_plan; -use polars_plan::prelude::FileScanOptions; +use polars_plan::prelude::{FileScanOptions, FileType}; use polars_utils::arena::{Arena, Node}; use polars_utils::pl_str::PlSmallStr; use slotmap::{Key, SecondaryMap, SlotMap}; @@ -93,6 +94,12 @@ pub enum PhysNodeKind { input: PhysNodeKey, }, + FileSink { + path: Arc, + file_type: FileType, + input: PhysNodeKey, + }, + InMemoryMap { input: PhysNodeKey, map: Arc, @@ -136,6 +143,12 @@ pub enum PhysNodeKind { scan_type: FileScan, file_options: FileScanOptions, }, + + GroupBy { + input: PhysNodeKey, + key: Vec, + aggs: Vec, + }, } #[recursive::recursive] @@ -176,10 +189,12 @@ fn insert_multiplexers( | PhysNodeKind::Filter { input, .. } | PhysNodeKind::SimpleProjection { input, .. } | PhysNodeKind::InMemorySink { input } + | PhysNodeKind::FileSink { input, .. } | PhysNodeKind::InMemoryMap { input, .. } | PhysNodeKind::Map { input, .. } | PhysNodeKind::Sort { input, .. } - | PhysNodeKind::Multiplexer { input } => { + | PhysNodeKind::Multiplexer { input } + | PhysNodeKind::GroupBy { input, .. } => { insert_multiplexers(*input, phys_sm, referenced); }, diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 34692aa10b9a..d9253e48dfa5 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use parking_lot::Mutex; -use polars_core::schema::Schema; +use polars_core::schema::{Schema, SchemaExt}; use polars_error::PolarsResult; +use polars_expr::groups::new_hash_grouper; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; @@ -10,7 +11,7 @@ use polars_mem_engine::create_physical_plan; use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{AExpr, ArenaExprIter, Context, IR}; -use polars_plan::prelude::FunctionFlags; +use polars_plan::prelude::{FileType, FunctionFlags}; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use recursive::recursive; @@ -20,6 +21,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; +use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; fn has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { @@ -202,6 +204,23 @@ fn to_graph_rec<'a>( ) }, + FileSink { + path, + file_type, + input, + } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); + let input_key = to_graph_rec(*input, ctx)?; + + match file_type { + FileType::Ipc(ipc_writer_options) => ctx.graph.add_node( + nodes::io_sinks::ipc::IpcSinkNode::new(input_schema, path, ipc_writer_options)?, + [input_key], + ), + _ => todo!(), + } + }, + InMemoryMap { input, map } => { let input_schema = ctx.phys_sm[*input].output_schema.clone(); let input_key = to_graph_rec(*input, ctx)?; @@ -349,6 +368,46 @@ fn to_graph_rec<'a>( } } }, + + GroupBy { input, key, aggs } => { + let input_key = to_graph_rec(*input, ctx)?; + + let input_schema = &ctx.phys_sm[*input].output_schema; + let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + let random_state = Default::default(); + let grouper = new_hash_grouper(Arc::new(key_schema), random_state); + + let key_selectors = key + .iter() + .map(|e| create_stream_expr(e, ctx, input_schema)) + .try_collect_vec()?; + + let mut grouped_reductions = Vec::new(); + let mut grouped_reduction_selectors = Vec::new(); + for agg in aggs { + let (reduction, input_node) = + into_reduction(agg.node(), ctx.expr_arena, input_schema)?; + let selector = create_stream_expr( + &ExprIR::from_node(input_node, ctx.expr_arena), + ctx, + input_schema, + )?; + grouped_reductions.push(reduction); + grouped_reduction_selectors.push(selector); + } + + ctx.graph.add_node( + nodes::group_by::GroupByNode::new( + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + node.output_schema.clone(), + ), + [input_key], + ) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/crates/polars-stream/src/pipe.rs b/crates/polars-stream/src/pipe.rs index 019d8779d18a..21b6a5672618 100644 --- a/crates/polars-stream/src/pipe.rs +++ b/crates/polars-stream/src/pipe.rs @@ -20,7 +20,7 @@ pub enum PhysicalPipe { pub struct SendPort<'a>(&'a mut PhysicalPipe); pub struct RecvPort<'a>(&'a mut PhysicalPipe); -impl<'a> RecvPort<'a> { +impl RecvPort<'_> { pub fn serial(self) -> Receiver { let PhysicalPipe::Uninit(num_pipelines) = self.0 else { unreachable!() @@ -41,7 +41,7 @@ impl<'a> RecvPort<'a> { } } -impl<'a> SendPort<'a> { +impl SendPort<'_> { #[allow(unused)] pub fn is_receiver_serial(&self) -> bool { matches!(self.0, PhysicalPipe::SerialReceiver(..)) diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index 20ca189de9e0..9516be3b902a 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -1,22 +1,24 @@ #![allow(unused)] // TODO: remove me +use std::cmp::Reverse; + use polars_core::prelude::*; use polars_core::POOL; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; -use polars_plan::plans::{Context, IRPlan, IR}; +use polars_plan::plans::{Context, IRPlan, IsStreamableContext, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::AExpr; use polars_utils::arena::{Arena, Node}; use slotmap::{SecondaryMap, SlotMap}; fn is_streamable(node: Node, arena: &Arena) -> bool { - polars_plan::plans::is_streamable(node, arena, Context::Default) + polars_plan::plans::is_streamable(node, arena, IsStreamableContext::new(Context::Default)) } pub fn run_query( node: Node, mut ir_arena: Arena, expr_arena: &mut Arena, -) -> PolarsResult { +) -> PolarsResult> { if let Ok(visual_path) = std::env::var("POLARS_VISUALIZE_IR") { let plan = IRPlan { lp_top: node, @@ -35,6 +37,16 @@ pub fn run_query( } let (mut graph, phys_to_graph) = crate::physical_plan::physical_plan_to_graph(root, &phys_sm, expr_arena)?; + crate::async_executor::clear_task_wait_statistics(); let mut results = crate::execute::execute_graph(&mut graph)?; - Ok(results.remove(phys_to_graph[root]).unwrap()) + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { + let mut stats = crate::async_executor::get_task_wait_statistics(); + stats.sort_by_key(|(_l, w)| Reverse(*w)); + eprintln!("Time spent waiting for async tasks:"); + for (loc, wait_time) in stats { + eprintln!("{}:{} - {:?}", loc.file(), loc.line(), wait_time); + } + } + + Ok(results.remove(phys_to_graph[root])) } diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index d75d634d213d..6d878b0ba27f 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -23,6 +23,7 @@ now = { version = "0.1" } once_cell = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } +strum_macros = { workspace = true } [dev-dependencies] polars-ops = { workspace = true, features = ["abs"] } diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 8a8d2312d580..3ff08ee4d308 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -789,12 +789,12 @@ mod test { let quantile = unsafe { a.as_materialized_series() - .agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) + .agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); assert_eq!(quantile, expected); - let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; + let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); assert_eq!(quantile, expected); diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index 4f300f733100..56ce3e4bdcd5 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -153,7 +153,7 @@ impl Duration { /// # Panics /// If the given str is invalid for any reason. pub fn parse(duration: &str) -> Self { - Self::_parse(duration, false) + Self::try_parse(duration).unwrap() } #[doc(hidden)] @@ -161,23 +161,31 @@ impl Duration { /// units (such as 'year', 'minutes', etc.) and whitespace, as /// well as being case-insensitive. pub fn parse_interval(interval: &str) -> Self { + Self::try_parse_interval(interval).unwrap() + } + + pub fn try_parse(duration: &str) -> PolarsResult { + Self::_parse(duration, false) + } + + pub fn try_parse_interval(interval: &str) -> PolarsResult { Self::_parse(&interval.to_ascii_lowercase(), true) } - fn _parse(s: &str, as_interval: bool) -> Self { + fn _parse(s: &str, as_interval: bool) -> PolarsResult { let s = if as_interval { s.trim_start() } else { s }; let parse_type = if as_interval { "interval" } else { "duration" }; let num_minus_signs = s.matches('-').count(); if num_minus_signs > 1 { - panic!("{} string can only have a single minus sign", parse_type) + polars_bail!(InvalidOperation: "{} string can only have a single minus sign", parse_type); } if num_minus_signs > 0 { if as_interval { // TODO: intervals need to support per-element minus signs - panic!("minus signs are not currently supported in interval strings") + polars_bail!(InvalidOperation: "minus signs are not currently supported in interval strings"); } else if !s.starts_with('-') { - panic!("only a single minus sign is allowed, at the front of the string") + polars_bail!(InvalidOperation: "only a single minus sign is allowed, at the front of the string"); } } let mut months = 0; @@ -211,12 +219,12 @@ impl Duration { while let Some((i, mut ch)) = iter.next() { if !ch.is_ascii_digit() { - let n = s[start..i].parse::().unwrap_or_else(|_| { - panic!( + let Ok(n) = s[start..i].parse::() else { + polars_bail!(InvalidOperation: "expected leading integer in the {} string, found {}", parse_type, ch - ) - }); + ); + }; loop { match ch { @@ -233,10 +241,10 @@ impl Duration { } } if unit.is_empty() { - panic!( + polars_bail!(InvalidOperation: "expected a unit to follow integer in the {} string '{}'", parse_type, s - ) + ); } match &*unit { // matches that are allowed for both duration/interval @@ -270,24 +278,25 @@ impl Duration { "year" | "years" => months += n * 12, _ => { let valid_units = "'year', 'month', 'quarter', 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond', 'nanosecond'"; - panic!("unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units) + polars_bail!(InvalidOperation: "unit: '{unit}' not supported; available units include: {} (and their plurals)", valid_units); }, }, _ => { - panic!("unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'") + polars_bail!(InvalidOperation: "unit: '{unit}' not supported; available units are: 'y', 'mo', 'q', 'w', 'd', 'h', 'm', 's', 'ms', 'us', 'ns'"); }, } unit.clear(); } } - Duration { + + Ok(Duration { nsecs: nsecs.abs(), days: days.abs(), weeks: weeks.abs(), months: months.abs(), negative, parsed_int, - } + }) } fn to_positive(v: i64) -> (bool, i64) { diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 9ba3a2d3dbc2..0a40b9af6fbc 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -8,11 +8,13 @@ use polars_core::POOL; use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; use crate::prelude::*; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum ClosedWindow { Left, Right, @@ -20,16 +22,18 @@ pub enum ClosedWindow { None, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum Label { Left, Right, DataPoint, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] pub enum StartBy { WindowBound, DataPoint, diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 90afe791e4d2..c7a29b846c58 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -316,7 +316,7 @@ impl<'a> BoundsIter<'a> { } } -impl<'a> Iterator for BoundsIter<'a> { +impl Iterator for BoundsIter<'_> { type Item = Bounds; fn next(&mut self) -> Option { diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index 442d319b7753..ef968918d4e8 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -21,6 +21,7 @@ libc = { workspace = true } memmap = { workspace = true, optional = true } num-traits = { workspace = true } once_cell = { workspace = true } +pyo3 = { workspace = true, optional = true } raw-cpuid = { workspace = true } rayon = { workspace = true } serde = { workspace = true, optional = true } @@ -39,3 +40,4 @@ bigidx = [] nightly = [] ir_serde = ["serde"] serde = ["dep:serde", "serde/derive"] +python = ["pyo3"] diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index c8b5823d695e..270a3fb9b835 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -7,11 +7,11 @@ use crate::error::*; use crate::slice::GetSaferUnchecked; unsafe fn index_of_unchecked(slice: &[T], item: &T) -> usize { - (item as *const _ as usize - slice.as_ptr() as usize) / std::mem::size_of::() + (item as *const _ as usize - slice.as_ptr() as usize) / size_of::() } fn index_of(slice: &[T], item: &T) -> Option { - debug_assert!(std::mem::size_of::() > 0); + debug_assert!(size_of::() > 0); let ptr = item as *const T; unsafe { if slice.as_ptr() < ptr && slice.as_ptr().add(slice.len()) > ptr { diff --git a/crates/polars-utils/src/hashing.rs b/crates/polars-utils/src/hashing.rs index 12e59bf52f26..63f4c661a2c3 100644 --- a/crates/polars-utils/src/hashing.rs +++ b/crates/polars-utils/src/hashing.rs @@ -2,6 +2,11 @@ use std::hash::{Hash, Hasher}; use crate::nulls::IsNull; +pub const fn folded_multiply(a: u64, b: u64) -> u64 { + let full = (a as u128).wrapping_mul(b as u128); + (full as u64) ^ ((full >> 64) as u64) +} + /// Contains a byte slice and a precomputed hash for that string. /// During rehashes, we will rehash the hash instead of the string, that makes /// rehashing cheap and allows cache coherent small hash tables. @@ -33,13 +38,13 @@ impl<'a> IsNull for BytesHash<'a> { } } -impl<'a> Hash for BytesHash<'a> { +impl Hash for BytesHash<'_> { fn hash(&self, state: &mut H) { state.write_u64(self.hash) } } -impl<'a> PartialEq for BytesHash<'a> { +impl PartialEq for BytesHash<'_> { #[inline] fn eq(&self, other: &Self) -> bool { (self.hash == other.hash) && (self.payload == other.payload) @@ -94,7 +99,7 @@ impl DirtyHash for i128 { } } -impl<'a> DirtyHash for BytesHash<'a> { +impl DirtyHash for BytesHash<'_> { fn dirty_hash(&self) -> u64 { self.hash } diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index 8bfdfafa2fd4..b3fbbed403f5 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -45,10 +45,7 @@ impl UnitVec { #[inline] pub fn new() -> Self { // This is optimized away, all const. - assert!( - std::mem::size_of::() <= std::mem::size_of::<*mut T>() - && std::mem::align_of::() <= std::mem::align_of::<*mut T>() - ); + assert!(size_of::() <= size_of::<*mut T>() && align_of::() <= align_of::<*mut T>()); Self { len: 0, capacity: NonZeroUsize::new(1).unwrap(), diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index eacd517d1254..5c302067e146 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -51,3 +51,6 @@ pub mod partitioned; pub use index::{IdxSize, NullableIdxSize}; pub use io::*; + +#[cfg(feature = "python")] +pub mod python_function; diff --git a/crates/polars-utils/src/mmap.rs b/crates/polars-utils/src/mmap.rs index 29651d5eb56a..0ac1a643d93d 100644 --- a/crates/polars-utils/src/mmap.rs +++ b/crates/polars-utils/src/mmap.rs @@ -61,6 +61,12 @@ mod private { } } + impl From> for MemSlice { + fn from(value: Vec) -> Self { + Self::from_vec(value) + } + } + impl MemSlice { pub const EMPTY: Self = Self::from_static(&[]); diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs new file mode 100644 index 000000000000..178f89aa2ca1 --- /dev/null +++ b/crates/polars-utils/src/python_function.rs @@ -0,0 +1,235 @@ +use polars_error::{polars_bail, PolarsError, PolarsResult}; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedBytes; +use pyo3::types::PyBytes; +#[cfg(feature = "serde")] +pub use serde_wrap::{ + PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION, + SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK, +}; + +use crate::flatten; + +#[derive(Clone, Debug)] +pub struct PythonFunction(pub PyObject); + +impl From for PythonFunction { + fn from(value: PyObject) -> Self { + Self(value) + } +} + +impl Eq for PythonFunction {} + +impl PartialEq for PythonFunction { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + let eq = self.0.getattr(py, "__eq__").unwrap(); + eq.call1(py, (other.0.clone_ref(py),)) + .unwrap() + .extract::(py) + // equality can be not implemented, so default to false + .unwrap_or(false) + }) + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PythonFunction { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + serializer.serialize_bytes( + self.try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))? + .as_slice(), + ) + } +} + +#[cfg(feature = "serde")] +impl<'a> serde::Deserialize<'a> for PythonFunction { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + let bytes = Vec::::deserialize(deserializer)?; + Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string())) + } +} + +#[cfg(feature = "serde")] +impl TrySerializeToBytes for PythonFunction { + fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult> { + serialize_pyobject_with_cloudpickle_fallback(&self.0) + } + + fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult { + deserialize_pyobject_bytes_maybe_cloudpickle(bytes) + } +} + +pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult> { + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("dumps") + .unwrap(); + + let dumped = pickle.call1((py_object.clone_ref(py),)); + + let (dumped, used_cloudpickle) = if let Ok(v) = dumped { + (v, false) + } else { + let cloudpickle = PyModule::import_bound(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle + .call1((py_object.clone_ref(py),)) + .map_err(from_pyerr)?; + (dumped, true) + }; + + let py_bytes = dumped.extract::().map_err(from_pyerr)?; + + Ok(flatten( + &[&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()], + None, + )) + }) +} + +pub fn deserialize_pyobject_bytes_maybe_cloudpickle From>( + bytes: &[u8], +) -> PolarsResult { + // TODO: Actually deserialize with cloudpickle if it's set. + let [_used_cloudpickle @ 0 | _used_cloudpickle @ 1, b'C', rem @ ..] = bytes else { + polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes") + }; + + let bytes = rem; + + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("loads") + .unwrap(); + let arg = (PyBytes::new_bound(py, bytes),); + let pyany_bound = pickle.call1(arg).map_err(from_pyerr)?; + Ok(PyObject::from(pyany_bound).into()) + }) +} + +#[cfg(feature = "serde")] +mod serde_wrap { + use once_cell::sync::Lazy; + use polars_error::PolarsResult; + + use crate::flatten; + + pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes(); + /// [minor, micro] + pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version); + + /// Serializes a Python object without additional system metadata. This is intended to be used + /// together with `PySerializeWrap`, which attaches e.g. Python version metadata. + pub trait TrySerializeToBytes: Sized { + fn try_serialize_to_bytes(&self) -> PolarsResult>; + fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult; + } + + /// Serialization wrapper for T: TrySerializeToBytes that attaches Python + /// version metadata. + pub struct PySerializeWrap(pub T); + + impl serde::Serialize for PySerializeWrap<&T> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + let dumped = self + .0 + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + serializer.serialize_bytes( + flatten( + &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()], + None, + ) + .as_slice(), + ) + } + } + + impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + let bytes = Vec::::deserialize(deserializer)?; + + let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject version", + )); + }; + + if magic != SERDE_MAGIC_BYTE_MARK { + return Err(D::Error::custom( + "serialized pyobject did not begin with magic byte mark", + )); + } + + let bytes = rem; + + let [a, b, rem @ ..] = bytes else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject metadata", + )); + }; + + let py3_version = [*a, *b]; + + if py3_version != *PYTHON3_VERSION { + return Err(D::Error::custom(format!( + "python version that pyobject was serialized with {:?} \ + differs from system python version {:?}", + (3, py3_version[0], py3_version[1]), + (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]), + ))); + } + + let bytes = rem; + + T::try_deserialize_bytes(bytes) + .map(Self) + .map_err(|e| D::Error::custom(e.to_string())) + } + } +} + +/// Get the [minor, micro] Python3 version from the `sys` module. +fn get_python3_version() -> [u8; 2] { + Python::with_gil(|py| { + let version_info = PyModule::import_bound(py, "sys") + .unwrap() + .getattr("version_info") + .unwrap(); + + [ + version_info.getattr("minor").unwrap().extract().unwrap(), + version_info.getattr("micro").unwrap().extract().unwrap(), + ] + }) +} + +fn from_pyerr(e: PyErr) -> PolarsError { + PolarsError::ComputeError(format!("error raised in python: {e}").into()) +} diff --git a/crates/polars-utils/src/sort.rs b/crates/polars-utils/src/sort.rs index 780dc39d1b9c..0c83f2becddd 100644 --- a/crates/polars-utils/src/sort.rs +++ b/crates/polars-utils/src/sort.rs @@ -90,8 +90,8 @@ where Idx: FromPrimitive + Copy, { // Needed to be able to write back to back in the same buffer. - debug_assert_eq!(std::mem::align_of::(), std::mem::align_of::<(T, Idx)>()); - let size = std::mem::size_of::<(T, Idx)>(); + debug_assert_eq!(align_of::(), align_of::<(T, Idx)>()); + let size = size_of::<(T, Idx)>(); let upper_bound = size * n + size; scratch.reserve(upper_bound); let scratch_slice = unsafe { diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs index cfaa05f0141d..982dc707e3de 100644 --- a/crates/polars-utils/src/total_ord.rs +++ b/crates/polars-utils/src/total_ord.rs @@ -453,7 +453,7 @@ impl TotalOrd for (T, U) { } } -impl<'a> TotalHash for BytesHash<'a> { +impl TotalHash for BytesHash<'_> { #[inline(always)] fn tot_hash(&self, state: &mut H) where @@ -463,7 +463,7 @@ impl<'a> TotalHash for BytesHash<'a> { } } -impl<'a> TotalEq for BytesHash<'a> { +impl TotalEq for BytesHash<'_> { #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { self == other diff --git a/crates/polars-utils/src/vec.rs b/crates/polars-utils/src/vec.rs index 108e7d573d1c..9060a348230c 100644 --- a/crates/polars-utils/src/vec.rs +++ b/crates/polars-utils/src/vec.rs @@ -20,7 +20,7 @@ impl IntoRawParts for Vec { } } -/// Fill current allocation if if > 0 +/// Fill current allocation if > 0 /// otherwise realloc pub trait ResizeFaster { fn fill_or_alloc(&mut self, new_len: usize, value: T); diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index a710011aa045..685ed71d8306 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -26,8 +26,7 @@ polars-utils = { workspace = true } [dev-dependencies] ahash = { workspace = true } apache-avro = { version = "0.17", features = ["snappy"] } -arrow = { workspace = true, features = ["arrow_rs"] } -arrow-buffer = { workspace = true } +arrow = { workspace = true } avro-schema = { workspace = true, features = ["async"] } either = { workspace = true } ethnum = "1" @@ -132,7 +131,13 @@ array_any_all = ["polars-lazy?/array_any_all", "dtype-array"] asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"] iejoin = ["polars-lazy?/iejoin"] binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] -bitwise = ["polars-core/bitwise", "polars-plan?/bitwise", "polars-ops/bitwise", "polars-lazy?/bitwise"] +bitwise = [ + "polars-core/bitwise", + "polars-plan?/bitwise", + "polars-ops/bitwise", + "polars-lazy?/bitwise", + "polars-sql?/bitwise", +] business = ["polars-lazy?/business", "polars-ops/business"] checked_arithmetic = ["polars-core/checked_arithmetic"] chunked_ids = ["polars-ops?/chunked_ids"] @@ -184,6 +189,7 @@ list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] list_sample = ["polars-lazy?/list_sample"] list_sets = ["polars-lazy?/list_sets"] list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +list_arithmetic = ["polars-core/list_arithmetic"] array_to_struct = ["polars-ops/array_to_struct", "polars-lazy?/array_to_struct"] log = ["polars-ops/log", "polars-lazy?/log"] merge_sorted = ["polars-lazy?/merge_sorted"] @@ -229,7 +235,7 @@ true_div = ["polars-lazy?/true_div"] unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"] zip_with = ["polars-core/zip_with"] -bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] +bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"] polars_cloud = ["polars-lazy?/polars_cloud"] ir_serde = ["polars-plan/ir_serde"] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 5ecc28c94c34..dba9bb39d46d 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -373,6 +373,7 @@ //! * `ASCII_BORDERS_ONLY_CONDENSED` //! * `ASCII_HORIZONTAL_ONLY` //! * `ASCII_MARKDOWN` +//! * `MARKDOWN` //! * `UTF8_FULL` //! * `UTF8_FULL_CONDENSED` //! * `UTF8_NO_BORDERS` diff --git a/crates/polars/tests/it/arrow/bitmap/immutable.rs b/crates/polars/tests/it/arrow/bitmap/immutable.rs index 4f2b3f3748b0..336322534de2 100644 --- a/crates/polars/tests/it/arrow/bitmap/immutable.rs +++ b/crates/polars/tests/it/arrow/bitmap/immutable.rs @@ -76,28 +76,3 @@ fn debug() { "Bitmap { len: 7, offset: 2, bytes: [0b111110__, 0b_______1] }" ); } - -#[test] -fn from_arrow() { - use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; - let buffer = arrow_buffer::Buffer::from_iter(vec![true, true, true, false, false, false, true]); - let bools = BooleanBuffer::new(buffer, 0, 7); - let nulls = NullBuffer::new(bools); - assert_eq!(nulls.null_count(), 3); - - let bitmap = Bitmap::from_null_buffer(nulls.clone()); - assert_eq!(nulls.null_count(), bitmap.unset_bits()); - assert_eq!(nulls.len(), bitmap.len()); - let back = NullBuffer::from(bitmap); - assert_eq!(nulls, back); - - let nulls = nulls.slice(1, 3); - assert_eq!(nulls.null_count(), 1); - assert_eq!(nulls.len(), 3); - - let bitmap = Bitmap::from_null_buffer(nulls.clone()); - assert_eq!(nulls.null_count(), bitmap.unset_bits()); - assert_eq!(nulls.len(), bitmap.len()); - let back = NullBuffer::from(bitmap); - assert_eq!(nulls, back); -} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs index fa138cd528c9..08cfbf31c62e 100644 --- a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs +++ b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs @@ -2,7 +2,7 @@ use arrow::bitmap::utils::fmt; struct A<'a>(&'a [u8], usize, usize); -impl<'a> std::fmt::Debug for A<'a> { +impl std::fmt::Debug for A<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt(self.0, self.1, self.2, f) } diff --git a/crates/polars/tests/it/arrow/buffer/immutable.rs b/crates/polars/tests/it/arrow/buffer/immutable.rs index 9065b52fba35..cc8742ba73ae 100644 --- a/crates/polars/tests/it/arrow/buffer/immutable.rs +++ b/crates/polars/tests/it/arrow/buffer/immutable.rs @@ -43,73 +43,3 @@ fn from_vec() { assert_eq!(buffer.len(), 3); assert_eq!(buffer.as_slice(), &[0, 1, 2]); } - -#[test] -fn from_arrow() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 3); - assert_eq!(b.as_slice(), &[1, 2, 3]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = buffer.slice(4); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 2); - assert_eq!(b.as_slice(), &[2, 3]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i64, 2_i64]); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 4); - assert_eq!(b.as_slice(), &[1, 0, 2, 0]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); - - let buffer = buffer.slice(4); - let b = Buffer::::from(buffer.clone()); - assert_eq!(b.len(), 3); - assert_eq!(b.as_slice(), &[0, 2, 0]); - let back = arrow_buffer::Buffer::from(b); - assert_eq!(back, buffer); -} - -#[test] -fn from_arrow_vec() { - // Zero-copy vec conversion in arrow-rs - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let back: Vec = buffer.into_vec().unwrap(); - - // Zero-copy vec conversion in arrow2 - let buffer = Buffer::::from(back); - let back: Vec = buffer.into_mut().unwrap_right(); - - let buffer = arrow_buffer::Buffer::from_vec(back); - let buffer = Buffer::::from(buffer); - - // But not possible after conversion between buffer representations - let _ = buffer.into_mut().unwrap_left(); - - let buffer = Buffer::::from(vec![1_i32]); - let buffer = arrow_buffer::Buffer::from(buffer); - - // But not possible after conversion between buffer representations - let _ = buffer.into_vec::().unwrap_err(); -} - -#[test] -#[should_panic(expected = "arrow_buffer::Buffer misaligned")] -fn from_arrow_misaligned() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]).slice(1); - let _ = Buffer::::from(buffer); -} - -#[test] -fn from_arrow_sliced() { - let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); - let b = Buffer::::from(buffer); - let sliced = b.sliced(1, 2); - let back = arrow_buffer::Buffer::from(sliced); - assert_eq!(back.typed_data::(), &[2, 3]); -} diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs index 075e5179e1ca..942fd38cc4a7 100644 --- a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -5,7 +5,7 @@ use arrow::datatypes::{ArrowDataType, Field}; #[test] fn primitive() { let a = Int32Array::from_slice([1, 2, 3, 4, 5]); - assert_eq!(5 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(5 * size_of::(), estimated_bytes_size(&a)); } #[test] @@ -17,7 +17,7 @@ fn boolean() { #[test] fn utf8() { let a = Utf8Array::::from_slice(["aaa"]); - assert_eq!(3 + 2 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(3 + 2 * size_of::(), estimated_bytes_size(&a)); } #[test] @@ -28,5 +28,5 @@ fn fixed_size_list() { ); let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); let a = FixedSizeListArray::new(dtype, 2, values, None); - assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); + assert_eq!(6 * size_of::(), estimated_bytes_size(&a)); } diff --git a/crates/polars/tests/it/io/avro/read_async.rs b/crates/polars/tests/it/io/avro/read_async.rs index 049910c0ce28..d50fd7595c58 100644 --- a/crates/polars/tests/it/io/avro/read_async.rs +++ b/crates/polars/tests/it/io/avro/read_async.rs @@ -26,25 +26,16 @@ async fn test(codec: Codec) -> PolarsResult<()> { Ok(()) } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_without_codec() -> PolarsResult<()> { test(Codec::Null).await } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_deflate() -> PolarsResult<()> { test(Codec::Deflate).await } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn read_snappy() -> PolarsResult<()> { test(Codec::Snappy).await diff --git a/crates/polars/tests/it/io/avro/write_async.rs b/crates/polars/tests/it/io/avro/write_async.rs index 1be109e2733a..77cb212f89db 100644 --- a/crates/polars/tests/it/io/avro/write_async.rs +++ b/crates/polars/tests/it/io/avro/write_async.rs @@ -42,9 +42,6 @@ async fn roundtrip(compression: Option) -> PolarsResult<()> { Ok(()) } -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#[allow(clippy::needless_return)] #[tokio::test] async fn no_compression() -> PolarsResult<()> { roundtrip(None).await diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs index 9b2185656462..188adf8efeba 100644 --- a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs +++ b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs @@ -30,7 +30,7 @@ pub fn read( num_values: usize, _is_sorted: bool, ) -> ParquetResult> { - let size_of = std::mem::size_of::(); + let size_of = size_of::(); let typed_size = num_values.wrapping_mul(size_of); diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs index 36fdb254420a..430df46d1239 100644 --- a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -10,7 +10,7 @@ use super::dictionary::PrimitivePageDict; use super::{hybrid_rle_iter, Array}; fn read_buffer(values: &[u8]) -> impl Iterator + '_ { - let chunks = values.chunks_exact(std::mem::size_of::()); + let chunks = values.chunks_exact(size_of::()); chunks.map(|chunk| { // unwrap is infalible due to the chunk size. let chunk: T::Bytes = match chunk.try_into() { @@ -179,7 +179,7 @@ pub struct DecoderIter<'a, T: Unpackable> { pub(crate) unpacked_end: usize, } -impl<'a, T: Unpackable> Iterator for DecoderIter<'a, T> { +impl Iterator for DecoderIter<'_, T> { type Item = T; fn next(&mut self) -> Option { @@ -203,7 +203,7 @@ impl<'a, T: Unpackable> Iterator for DecoderIter<'a, T> { } } -impl<'a, T: Unpackable> ExactSizeIterator for DecoderIter<'a, T> {} +impl ExactSizeIterator for DecoderIter<'_, T> {} impl<'a, T: Unpackable> DecoderIter<'a, T> { pub fn new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { diff --git a/crates/polars/tests/it/io/parquet/read/utils.rs b/crates/polars/tests/it/io/parquet/read/utils.rs index 19feaee29534..14b664e0d962 100644 --- a/crates/polars/tests/it/io/parquet/read/utils.rs +++ b/crates/polars/tests/it/io/parquet/read/utils.rs @@ -197,13 +197,11 @@ pub fn native_cast(page: &DataPage) -> ParquetResult> { def: _, values, } = split_buffer(page)?; - if values.len() % std::mem::size_of::() != 0 { + if values.len() % size_of::() != 0 { panic!("A primitive page data's len must be a multiple of the type"); } - Ok(values - .chunks_exact(std::mem::size_of::()) - .map(decode::)) + Ok(values.chunks_exact(size_of::()).map(decode::)) } /// The deserialization state of a `DataPage` of `Primitive` parquet primitive type diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index ad043e698e2e..10c386037d17 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -26,7 +26,7 @@ fn test_lazy_agg() { col("rain").min().alias("min"), col("rain").sum().alias("sum"), col("rain") - .quantile(lit(0.5), QuantileInterpolOptions::default()) + .quantile(lit(0.5), QuantileMethod::default()) .alias("median_rain"), ]) .sort(["date"], Default::default()); diff --git a/docs/source/development/contributing/code-style.md b/docs/source/development/contributing/code-style.md index 00ad8a8f726e..c22b9f3ed7ac 100644 --- a/docs/source/development/contributing/code-style.md +++ b/docs/source/development/contributing/code-style.md @@ -30,7 +30,7 @@ use polars::export::arrow::array::*; use polars::export::arrow::compute::arity::binary; use polars::export::arrow::types::NativeType; use polars::prelude::*; -use polars_core::utils::{align_chunks_binary, combine_validities_or}; +use polars_core::utils::{align_chunks_binary, combine_validities_and}; use polars_core::with_match_physical_numeric_polars_type; // Prefer to do the compute closest to the arrow arrays. @@ -45,7 +45,7 @@ where let validity_1 = arr_1.validity(); let validity_2 = arr_2.validity(); - let validity = combine_validities_or(validity_1, validity_2); + let validity = combine_validities_and(validity_1, validity_2); // process the numerical data as if there were no validities let values_1: &[T] = arr_1.values().as_slice(); diff --git a/docs/source/src/python/user-guide/io/cloud-storage.py b/docs/source/src/python/user-guide/io/cloud-storage.py index 73cf597ec84e..12b02df28e61 100644 --- a/docs/source/src/python/user-guide/io/cloud-storage.py +++ b/docs/source/src/python/user-guide/io/cloud-storage.py @@ -7,7 +7,16 @@ df = pl.read_parquet(source) # --8<-- [end:read_parquet] -# --8<-- [start:scan_parquet] +# --8<-- [start:scan_parquet_query] +import polars as pl + +source = "s3://bucket/*.parquet" + +df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() +# --8<-- [end:scan_parquet_query] + + +# --8<-- [start:scan_parquet_storage_options_aws] import polars as pl source = "s3://bucket/*.parquet" @@ -17,17 +26,42 @@ "aws_secret_access_key": "", "aws_region": "us-east-1", } -df = pl.scan_parquet(source, storage_options=storage_options) -# --8<-- [end:scan_parquet] +df = pl.scan_parquet(source, storage_options=storage_options).collect() +# --8<-- [end:scan_parquet_storage_options_aws] + +# --8<-- [start:credential_provider_class] +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=pl.CredentialProviderAWS( + profile_name="..." + assume_role={ + "RoleArn": f"...", + "RoleSessionName": "...", + } + ), +) -# --8<-- [start:scan_parquet_query] -import polars as pl +df = lf.collect() +# --8<-- [end:credential_provider_class] -source = "s3://bucket/*.parquet" +# --8<-- [start:credential_provider_custom_func] +def get_credentials() -> pl.CredentialProviderFunctionReturn: + expiry = None + return { + "aws_access_key_id": "...", + "aws_secret_access_key": "...", + "aws_session_token": "...", + }, expiry -df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() -# --8<-- [end:scan_parquet_query] + +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=get_credentials, +) + +df = lf.collect() +# --8<-- [end:credential_provider_custom_func] # --8<-- [start:scan_pyarrow_dataset] import polars as pl diff --git a/docs/source/src/rust/user-guide/io/cloud-storage.rs b/docs/source/src/rust/user-guide/io/cloud-storage.rs index 5c297739eeee..2df882a39c00 100644 --- a/docs/source/src/rust/user-guide/io/cloud-storage.rs +++ b/docs/source/src/rust/user-guide/io/cloud-storage.rs @@ -1,7 +1,3 @@ -// Issue with clippy interacting with tokio. See: -// https://github.com/rust-lang/rust-clippy/issues/13458 -#![allow(clippy::needless_return)] - // --8<-- [start:read_parquet] use aws_config::BehaviorVersion; use polars::prelude::*; @@ -31,12 +27,18 @@ async fn main() { } // --8<-- [end:read_parquet] -// --8<-- [start:scan_parquet] -// --8<-- [end:scan_parquet] - // --8<-- [start:scan_parquet_query] // --8<-- [end:scan_parquet_query] +// --8<-- [start:scan_parquet_storage_options_aws] +// --8<-- [end:scan_parquet_storage_options_aws] + +// --8<-- [start:credential_provider_class] +// --8<-- [end:credential_provider_class] + +// --8<-- [start:credential_provider_custom_func] +// --8<-- [end:credential_provider_custom_func] + // --8<-- [start:scan_pyarrow_dataset] // --8<-- [end:scan_pyarrow_dataset] diff --git a/docs/source/user-guide/ecosystem.md b/docs/source/user-guide/ecosystem.md index 21f1dbc2ba60..9a8f96c4f72f 100644 --- a/docs/source/user-guide/ecosystem.md +++ b/docs/source/user-guide/ecosystem.md @@ -20,25 +20,7 @@ On this page you can find a non-exhaustive list of libraries and tools that supp ### Data visualisation -#### hvPlot - -[hvPlot](https://hvplot.holoviz.org/) is available as the default plotting backend for Polars making it simple to create interactive and static visualisations. You can use hvPlot by using the feature flag `plot` during installing. - -```python -pip install 'polars[plot]' -``` - -#### Matplotlib - -[Matplotlib](https://matplotlib.org/) is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible. - -#### Plotly - -[Plotly](https://plotly.com/python/) is an interactive, open-source, and browser-based graphing library for Python. Built on top of plotly.js, it ships with over 30 chart types, including scientific charts, 3D graphs, statistical charts, SVG maps, financial charts, and more. - -#### [Seaborn](https://seaborn.pydata.org/) - -Seaborn is a Python data visualization library based on Matplotlib. It provides a high-level interface for drawing attractive and informative statistical graphics. +See the [dedicated visualization section](misc/visualization.md). ### IO @@ -71,3 +53,7 @@ With [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html #### Mage [Mage](https://www.mage.ai) is an open-source data pipeline tool for transforming and integrating data. Learn about integration between Polars and Mage at [docs.mage.ai](https://docs.mage.ai/integrations/polars). + +#### marimo + +[marimo](https://marimo.io) is a reactive notebook for Python and SQL that models notebooks as dataflow graphs. It offers built-in support for Polars, allowing seamless integration of Polars dataframes in an interactive, reactive environment - such as displaying rich Polars tables, no-code transformations of Polars dataframes, or selecting points on a Polars-backed reactive chart. diff --git a/docs/source/user-guide/expressions/plugins.md b/docs/source/user-guide/expressions/plugins.md index e679b09ee180..9ef5633cfcd0 100644 --- a/docs/source/user-guide/expressions/plugins.md +++ b/docs/source/user-guide/expressions/plugins.md @@ -37,7 +37,7 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "*" } -pyo3 = { version = "*", features = ["extension-module", "abi-py38"] } +pyo3 = { version = "*", features = ["extension-module", "abi3-py38"] } pyo3-polars = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] } ``` diff --git a/docs/source/user-guide/io/cloud-storage.md b/docs/source/user-guide/io/cloud-storage.md index ba686a5a0f11..f3b5d7a8fb09 100644 --- a/docs/source/user-guide/io/cloud-storage.md +++ b/docs/source/user-guide/io/cloud-storage.md @@ -18,23 +18,39 @@ To read from cloud storage, additional dependencies may be needed depending on t ## Reading from cloud storage -Polars can read a CSV, IPC or Parquet file in eager mode from cloud storage. +Polars supports reading Parquet, CSV, IPC and NDJSON files from cloud storage: {{code_block('user-guide/io/cloud-storage','read_parquet',['read_parquet','read_csv','read_ipc'])}} -This eager query downloads the file to a buffer in memory and creates a `DataFrame` from there. Polars uses `fsspec` to manage this download internally for all cloud storage providers. - ## Scanning from cloud storage with query optimisation -Polars can scan a Parquet file in lazy mode from cloud storage. We may need to provide further details beyond the source url such as authentication details or storage region. Polars looks for these as environment variables but we can also do this manually by passing a `dict` as the `storage_options` argument. +Using `pl.scan_*` functions to read from cloud storage can benefit from [predicate and projection pushdowns](../lazy/optimizations.md), where the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. -{{code_block('user-guide/io/cloud-storage','scan_parquet',['scan_parquet'])}} +{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} -This query creates a `LazyFrame` without downloading the file. In the `LazyFrame` we have access to file metadata such as the schema. Polars uses the `object_store.rs` library internally to manage the interface with the cloud storage providers and so no extra dependencies are required in Python to scan a cloud Parquet file. +## Cloud authentication -If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. +Polars is able to automatically load default credential configurations for some cloud providers. For +cases when this does not happen, it is possible to manually configure the credentials for Polars to +use for authentication. This can be done in a few ways: -{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} +### Using `storage_options`: + +- Credentials can be passed as configuration keys in a dict with the `storage_options` parameter: + +{{code_block('user-guide/io/cloud-storage','scan_parquet_storage_options_aws',['scan_parquet'])}} + +### Using one of the available `CredentialProvider*` utility classes + +- There may be a utility class `pl.CredentialProvider*` that provides the required authentication functionality. For example, `pl.CredentialProviderAWS` supports selecting AWS profiles, as well as assuming an IAM role: + +{{code_block('user-guide/io/cloud-storage','credential_provider_class',['scan_parquet'])}} + +### Using a custom `credential_provider` function + +- Some environments may require custom authentication logic (e.g. AWS IAM role-chaining). For these cases a Python function can be provided for Polars to use to retrieve credentials: + +{{code_block('user-guide/io/cloud-storage','credential_provider_custom_func',['scan_parquet'])}} ## Scanning with PyArrow diff --git a/docs/source/user-guide/migration/pandas.md b/docs/source/user-guide/migration/pandas.md index 5fa435278949..3d1f0996bdad 100644 --- a/docs/source/user-guide/migration/pandas.md +++ b/docs/source/user-guide/migration/pandas.md @@ -50,8 +50,7 @@ eager evaluation. The lazy evaluation mode is powerful because Polars carries ou automatic query optimization when it examines the query plan and looks for ways to accelerate the query or reduce memory usage. -`Dask` also supports lazy evaluation when it generates a query plan. However, `Dask` -does not carry out query optimization on the query plan. +`Dask` also supports lazy evaluation when it generates a query plan. ## Key syntax differences diff --git a/examples/python_rust_compiled_function/Cargo.toml b/examples/python_rust_compiled_function/Cargo.toml deleted file mode 100644 index 94982fe498ef..000000000000 --- a/examples/python_rust_compiled_function/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "python_rust_compiled_function" -version = "0.1.0" -edition = "2021" - -[lib] -name = "my_polars_functions" -crate-type = ["cdylib"] - -[dependencies] -arrow = { workspace = true } -polars = { path = "../../crates/polars" } - -pyo3 = { workspace = true, features = ["extension-module"] } - -[build-dependencies] -pyo3-build-config = "0.21" diff --git a/examples/python_rust_compiled_function/README.md b/examples/python_rust_compiled_function/README.md deleted file mode 100644 index e958fc2fc24f..000000000000 --- a/examples/python_rust_compiled_function/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Compile Custom Rust functions and use in Python Polars - -## Compile a development binary in your current environment - -```sh -pip install -U maturin && maturin develop -``` - -## Run - -```sh -python example.py -``` - -## Compile a **release** build - -```sh -maturin develop --release -``` diff --git a/examples/python_rust_compiled_function/build.rs b/examples/python_rust_compiled_function/build.rs deleted file mode 100644 index dace4a9ba9f8..000000000000 --- a/examples/python_rust_compiled_function/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - pyo3_build_config::add_extension_module_link_args(); -} diff --git a/examples/python_rust_compiled_function/example.py b/examples/python_rust_compiled_function/example.py deleted file mode 100644 index 3b26d63975b1..000000000000 --- a/examples/python_rust_compiled_function/example.py +++ /dev/null @@ -1,19 +0,0 @@ -import polars as pl -from my_polars_functions import hamming_distance - -a = pl.Series("a", ["foo", "bar"]) -b = pl.Series("b", ["fooy", "ham"]) - -dist = hamming_distance(a, b) -expected = pl.Series("", [None, 2], dtype=pl.UInt32) - -# run on 2 Series -print("hamming distance: ", hamming_distance(a, b)) -assert dist.series_equal(expected, null_equal=True) - -# or use in polars expressions -print( - pl.DataFrame([a, b]).select( - pl.map(["a", "b"], lambda series: hamming_distance(series[0], series[1])) - ) -) diff --git a/examples/python_rust_compiled_function/pyproject.toml b/examples/python_rust_compiled_function/pyproject.toml deleted file mode 100644 index c0e2db718795..000000000000 --- a/examples/python_rust_compiled_function/pyproject.toml +++ /dev/null @@ -1,7 +0,0 @@ -[build-system] -requires = ["maturin~=1.2.1"] -build-backend = "maturin" - -[project] -name = "my_polars_functions" -version = "0.1.0" diff --git a/examples/python_rust_compiled_function/src/ffi.rs b/examples/python_rust_compiled_function/src/ffi.rs deleted file mode 100644 index 3597d1f83a03..000000000000 --- a/examples/python_rust_compiled_function/src/ffi.rs +++ /dev/null @@ -1,84 +0,0 @@ -use arrow::ffi; -use polars::prelude::*; -use pyo3::exceptions::PyValueError; -use pyo3::ffi::Py_uintptr_t; -use pyo3::prelude::*; -use pyo3::{PyAny, PyObject, PyResult}; - -/// Take an arrow array from python and convert it to a rust arrow array. -/// This operation does not copy data. -fn array_to_rust(arrow_array: &Bound) -> PyResult { - // prepare a pointer to receive the Array struct - let array = Box::new(ffi::ArrowArray::empty()); - let schema = Box::new(ffi::ArrowSchema::empty()); - - let array_ptr = &*array as *const ffi::ArrowArray; - let schema_ptr = &*schema as *const ffi::ArrowSchema; - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - arrow_array.call_method1( - "_export_to_c", - (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), - )?; - - unsafe { - let field = ffi::import_field_from_c(schema.as_ref()).unwrap(); - let array = ffi::import_array_from_c(*array, field.dtype).unwrap(); - Ok(array) - } -} - -/// Arrow array to Python. -pub(crate) fn to_py_array(py: Python, pyarrow: &Bound, array: ArrayRef) -> PyResult { - let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( - "", - array.dtype().clone(), - true, - ))); - let array = Box::new(ffi::export_array_to_c(array)); - - let schema_ptr: *const ffi::ArrowSchema = &*schema; - let array_ptr: *const ffi::ArrowArray = &*array; - - let array = pyarrow.getattr("Array")?.call_method1( - "_import_from_c", - (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), - )?; - - Ok(array.to_object(py)) -} - -pub fn py_series_to_rust_series(series: &Bound) -> PyResult { - // rechunk series so that they have a single arrow array - let series = series.call_method0("rechunk")?; - - let name = series.getattr("name")?.extract::()?; - - // retrieve pyarrow array - let array = series.call_method0("to_arrow")?; - - // retrieve rust arrow array - let array = array_to_rust(&array)?; - - Series::try_from((name.as_str(), array)).map_err(|e| PyValueError::new_err(format!("{}", e))) -} - -pub fn rust_series_to_py_series(series: &Series) -> PyResult { - // ensure we have a single chunk - let series = series.rechunk(); - let array = series.to_arrow(0, false); - - Python::with_gil(|py| { - // import pyarrow - let pyarrow = py.import_bound("pyarrow")?; - - // pyarrow array - let pyarrow_array = to_py_array(py, &pyarrow, array)?; - - // import polars - let polars = py.import_bound("polars")?; - let out = polars.call_method1("from_arrow", (pyarrow_array,))?; - Ok(out.to_object(py)) - }) -} diff --git a/examples/python_rust_compiled_function/src/lib.rs b/examples/python_rust_compiled_function/src/lib.rs deleted file mode 100644 index f8c2caec2123..000000000000 --- a/examples/python_rust_compiled_function/src/lib.rs +++ /dev/null @@ -1,50 +0,0 @@ -mod ffi; - -use polars::prelude::*; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -#[pyfunction] -fn hamming_distance(series_a: &Bound, series_b: &Bound) -> PyResult { - let series_a = ffi::py_series_to_rust_series(series_a)?; - let series_b = ffi::py_series_to_rust_series(series_b)?; - - let out = hamming_distance_impl(&series_a, &series_b) - .map_err(|e| PyValueError::new_err(format!("Something went wrong: {:?}", e)))?; - ffi::rust_series_to_py_series(&out.into_series()) -} - -/// This function iterates over 2 `StringChunked` arrays and computes the hamming distance between the values . -fn hamming_distance_impl(a: &Series, b: &Series) -> PolarsResult { - Ok(a.str()? - .into_iter() - .zip(b.str()?) - .map(|(lhs, rhs)| hamming_distance_strs(lhs, rhs)) - .collect()) -} - -/// Compute the hamming distance between 2 string values. -fn hamming_distance_strs(a: Option<&str>, b: Option<&str>) -> Option { - match (a, b) { - (None, _) => None, - (_, None) => None, - (Some(a), Some(b)) => { - if a.len() != b.len() { - None - } else { - Some( - a.chars() - .zip(b.chars()) - .map(|(a_char, b_char)| (a_char != b_char) as u32) - .sum::(), - ) - } - }, - } -} - -#[pymodule] -fn my_polars_functions(_py: Python, m: &Bound) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(hamming_distance)).unwrap(); - Ok(()) -} diff --git a/examples/read_csv/Cargo.toml b/examples/read_csv/Cargo.toml deleted file mode 100644 index 15cc4a573733..000000000000 --- a/examples/read_csv/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "read_csv" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "csv", "ipc"] } - -[features] -write_output = ["polars/ipc", "polars/parquet"] -default = ["write_output"] - -[workspace] diff --git a/examples/read_csv/src/main.rs b/examples/read_csv/src/main.rs deleted file mode 100644 index fdf2dc73be2e..000000000000 --- a/examples/read_csv/src/main.rs +++ /dev/null @@ -1,32 +0,0 @@ -use polars::io::mmap::MmapBytesReader; -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let file = std::fs::File::open("/home/ritchie46/Downloads/pdsh/tables_scale_100/lineitem.tbl") - .unwrap(); - let file = Box::new(file) as Box; - let _df = CsvReadOptions::default() - .map_parse_options(|x| x.with_separator(b'|')) - .with_has_header(false) - .with_chunk_size(10) - .into_reader_with_file_handle(file); - - // write_other_formats(&mut df)?; - Ok(()) -} - -fn _write_other_formats(df: &mut DataFrame) -> PolarsResult<()> { - let parquet_out = "../datasets/foods1.parquet"; - if std::fs::metadata(parquet_out).is_err() { - let f = std::fs::File::create(parquet_out).unwrap(); - ParquetWriter::new(f) - .with_statistics(StatisticsOptions::full()) - .finish(df)?; - } - let ipc_out = "../datasets/foods1.ipc"; - if std::fs::metadata(ipc_out).is_err() { - let f = std::fs::File::create(ipc_out).unwrap(); - IpcWriter::new(f).finish(df)? - } - Ok(()) -} diff --git a/examples/read_json/Cargo.toml b/examples/read_json/Cargo.toml deleted file mode 100644 index d9e66858c2c0..000000000000 --- a/examples/read_json/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "read_json" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["json"] } diff --git a/examples/read_json/src/main.rs b/examples/read_json/src/main.rs deleted file mode 100644 index f16c2abfdc69..000000000000 --- a/examples/read_json/src/main.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::io::Cursor; - -use polars::prelude::*; - -fn main() { - let data = r#"[ - {"date": "1996-12-16T00:00:00.000", "open": 16.86, "close": 16.86, "high": 16.86, "low": 16.86, "volume": 62442.0, "turnover": 105277000.0}, - {"date": "1996-12-17T00:00:00.000", "open": 15.17, "close": 15.17, "high": 16.79, "low": 15.17, "volume": 463675.0, "turnover": 718902016.0}, - {"date": "1996-12-18T00:00:00.000", "open": 15.28, "close": 16.69, "high": 16.69, "low": 15.18, "volume": 445380.0, "turnover": 719400000.0}, - {"date": "1996-12-19T00:00:00.000", "open": 17.01, "close": 16.4, "high": 17.9, "low": 15.99, "volume": 572946.0, "turnover": 970124992.0} - ]"#; - - let res = JsonReader::new(Cursor::new(data)).finish(); - println!("{:?}", res); - assert!(res.is_ok()); - let df = res.unwrap(); - println!("{:?}", df); -} diff --git a/examples/read_parquet/Cargo.toml b/examples/read_parquet/Cargo.toml deleted file mode 100644 index 0e88fb7b62ca..000000000000 --- a/examples/read_parquet/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "read_parquet" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "parquet"] } diff --git a/examples/read_parquet/src/main.rs b/examples/read_parquet/src/main.rs deleted file mode 100644 index 5023a95c2368..000000000000 --- a/examples/read_parquet/src/main.rs +++ /dev/null @@ -1,15 +0,0 @@ -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let df = LazyFrame::scan_parquet("../datasets/foods1.parquet", ScanArgsParquet::default())? - .select([ - // select all columns - all(), - // and do some aggregations - cols(["fats_g", "sugars_g"]).sum().name().suffix("_summed"), - ]) - .collect()?; - - println!("{}", df); - Ok(()) -} diff --git a/examples/read_parquet_cloud/Cargo.toml b/examples/read_parquet_cloud/Cargo.toml deleted file mode 100644 index bbb43403bd95..000000000000 --- a/examples/read_parquet_cloud/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "read_parquet_cloud" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet"] } - -aws-creds = "0.36.0" diff --git a/examples/read_parquet_cloud/src/main.rs b/examples/read_parquet_cloud/src/main.rs deleted file mode 100644 index 367575bbdd30..000000000000 --- a/examples/read_parquet_cloud/src/main.rs +++ /dev/null @@ -1,30 +0,0 @@ -use awscreds::Credentials; -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -// Login to your aws account and then copy the ../datasets/foods1.parquet file to your own bucket. -// Adjust the link below. -const TEST_S3: &str = "s3://lov2test/polars/datasets/*.parquet"; - -fn main() -> PolarsResult<()> { - let cred = Credentials::default().unwrap(); - - // Propagate the credentials and other cloud options. - let mut args = ScanArgsParquet::default(); - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, &cred.access_key.unwrap()), - (Key::SecretAccessKey, &cred.secret_key.unwrap()), - (Key::Region, &"us-west-2".into()), - ]); - args.cloud_options = Some(cloud_options); - let df = LazyFrame::scan_parquet(TEST_S3, args)? - .with_streaming(true) - .select([ - // select all columns - all(), - ]) - .collect()?; - - println!("{}", df); - Ok(()) -} diff --git a/examples/string_filter/Cargo.toml b/examples/string_filter/Cargo.toml deleted file mode 100644 index a7738dd2db7f..000000000000 --- a/examples/string_filter/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "string_filter" -version = "0.1.0" -edition = "2021" - -[dependencies] -polars = { path = "../../crates/polars", features = ["strings", "lazy"] } diff --git a/examples/string_filter/src/main.rs b/examples/string_filter/src/main.rs deleted file mode 100644 index 5364f9b5b4d2..000000000000 --- a/examples/string_filter/src/main.rs +++ /dev/null @@ -1,13 +0,0 @@ -use polars::lazy::prelude::*; -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let a = Series::new("a", [1, 2, 3, 4]); - let b = Series::new("b", ["one", "two", "three", "four"]); - let df = DataFrame::new(vec![a, b])? - .lazy() - .filter(col("b").str().starts_with(lit("t"))) - .collect()?; - println!("{df:?}"); - Ok(()) -} diff --git a/examples/write_ipc_cloud/Cargo.toml b/examples/write_ipc_cloud/Cargo.toml deleted file mode 100644 index 764f67ed4504..000000000000 --- a/examples/write_ipc_cloud/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "write_ipc_cloud" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -aws-creds = "0.36.0" -polars = { path = "../../crates/polars", features = ["lazy", "aws", "ipc", "cloud_write", "streaming"] } - -[workspace] diff --git a/examples/write_ipc_cloud/src/main.rs b/examples/write_ipc_cloud/src/main.rs deleted file mode 100644 index 3da5c95fc362..000000000000 --- a/examples/write_ipc_cloud/src/main.rs +++ /dev/null @@ -1,30 +0,0 @@ -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -const TEST_S3_LOCATION: &str = "s3://test-bucket/test-writes/polars_write_example_cloud.ipc"; - -fn main() -> PolarsResult<()> { - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, "test".to_string()), - (Key::SecretAccessKey, "test".to_string()), - (Key::Endpoint, "http://localhost:4566".to_string()), - (Key::Region, "us-east-1".to_string()), - ]); - let cloud_options = Some(cloud_options); - - let df = df!( - "foo" => &[1, 2, 3], - "bar" => &[None, Some("bak"), Some("baz")], - ) - .unwrap(); - - df.lazy() - .sink_ipc_cloud( - TEST_S3_LOCATION.to_string(), - cloud_options, - Default::default(), - ) - .unwrap(); - - Ok(()) -} diff --git a/examples/write_parquet_cloud/Cargo.toml b/examples/write_parquet_cloud/Cargo.toml deleted file mode 100644 index 1e79b04739b2..000000000000 --- a/examples/write_parquet_cloud/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "write_parquet_cloud" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -aws-creds = "0.36.0" -polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet", "cloud_write", "streaming"] } - -[workspace] diff --git a/examples/write_parquet_cloud/src/main.rs b/examples/write_parquet_cloud/src/main.rs deleted file mode 100644 index c928ceb4c765..000000000000 --- a/examples/write_parquet_cloud/src/main.rs +++ /dev/null @@ -1,62 +0,0 @@ -use awscreds::Credentials; -use cloud::AmazonS3ConfigKey as Key; -use polars::prelude::*; - -// Adjust the link below. -const TEST_S3_LOCATION: &str = "s3://polarstesting/polars_write_example_cloud.parquet"; - -fn main() -> PolarsResult<()> { - sink_file(); - sink_cloud_local(); - sink_aws(); - - Ok(()) -} - -fn sink_file() { - let df = example_dataframe(); - - // Writing to a local file: - let path = "/tmp/polars_write_example.parquet".into(); - df.lazy().sink_parquet(path, Default::default()).unwrap(); -} - -fn sink_cloud_local() { - let df = example_dataframe(); - - // Writing to a location that might be in the cloud: - let uri = "file:///tmp/polars_write_example_cloud.parquet".to_string(); - df.lazy() - .sink_parquet_cloud(uri, None, Default::default()) - .unwrap(); -} - -fn sink_aws() { - let cred = Credentials::default().unwrap(); - - // Propagate the credentials and other cloud options. - let cloud_options = cloud::CloudOptions::default().with_aws([ - (Key::AccessKeyId, &cred.access_key.unwrap()), - (Key::SecretAccessKey, &cred.secret_key.unwrap()), - (Key::Region, &"eu-central-1".into()), - ]); - let cloud_options = Some(cloud_options); - - let df = example_dataframe(); - - df.lazy() - .sink_parquet_cloud( - TEST_S3_LOCATION.to_string(), - cloud_options, - Default::default(), - ) - .unwrap(); -} - -fn example_dataframe() -> DataFrame { - df!( - "foo" => &[1, 2, 3], - "bar" => &[None, Some("bak"), Some("baz")], - ) - .unwrap() -} diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 63863959b720..fc3e520e5ecc 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "1.9.0" +version = "1.12.0" edition = "2021" [lib] @@ -104,6 +104,7 @@ all = [ "parquet", "ipc", "polars-python/all", + "performant", ] default = ["all", "nightly"] diff --git a/py-polars/Makefile b/py-polars/Makefile index 3c98adab08cb..3d9e5d7ffddc 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -23,39 +23,23 @@ requirements-all: .venv ## Install/refresh all Python requirements (including t @$(MAKE) -s -C .. $@ .PHONY: build -build: .venv ## Compile and install Polars for development - @$(MAKE) -s -C .. $@ - -.PHONY: build-debug-opt -build-debug-opt: .venv ## Compile and install Polars with minimal optimizations turned on - @$(MAKE) -s -C .. $@ - -.PHONY: build-debug-opt-subset -build-debug-opt-subset: .venv ## Compile and install Polars with minimal optimizations turned on and no default features - @$(MAKE) -s -C .. $@ - -.PHONY: build-opt -build-opt: .venv ## Compile and install Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on +build: .venv ## Compile and install Python Polars for development @$(MAKE) -s -C .. $@ .PHONY: build-release -build-release: .venv ## Compile and install a faster Polars binary with full optimizations - @$(MAKE) -s -C .. $@ - -.PHONY: build-native -build-native: .venv ## Same as build, except with native CPU optimizations turned on +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols @$(MAKE) -s -C .. $@ -.PHONY: build-debug-opt-native -build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) @$(MAKE) -s -C .. $@ -.PHONY: build-opt-native -build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) @$(MAKE) -s -C .. $@ -.PHONY: build-release-native -build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution @$(MAKE) -s -C .. $@ .PHONY: lint diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index a0cde717f0da..7c1358b480f6 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -16,6 +16,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.decode Expr.str.encode Expr.str.ends_with + Expr.str.escape_regex Expr.str.explode Expr.str.extract Expr.str.extract_all diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst index c672aaa77eac..33ee296844db 100644 --- a/py-polars/docs/source/reference/functions.rst +++ b/py-polars/docs/source/reference/functions.rst @@ -25,6 +25,7 @@ Miscellaneous align_frames concat + escape_regex Parallelization ~~~~~~~~~~~~~~~ diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index 1f088958a3c0..bdfe93ee9ebe 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ -117,3 +117,14 @@ Connect to pyarrow datasets. :toctree: api/ scan_pyarrow_dataset + +Cloud Credentials +~~~~~~~~~~~~~~~~~ +Configuration for cloud credential provisioning. + +.. autosummary:: + :toctree: api/ + + CredentialProvider + CredentialProviderAWS + CredentialProviderGCP diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst index 9e44b0d9b105..85dcf4b1b2d6 100644 --- a/py-polars/docs/source/reference/series/string.rst +++ b/py-polars/docs/source/reference/series/string.rst @@ -16,6 +16,7 @@ The following methods are available under the `Series.str` attribute. Series.str.decode Series.str.encode Series.str.ends_with + Series.str.escape_regex Series.str.explode Series.str.extract Series.str.extract_all diff --git a/py-polars/docs/source/reference/sql/functions/aggregate.rst b/py-polars/docs/source/reference/sql/functions/aggregate.rst index 0244af235175..07081c72454c 100644 --- a/py-polars/docs/source/reference/sql/functions/aggregate.rst +++ b/py-polars/docs/source/reference/sql/functions/aggregate.rst @@ -21,6 +21,11 @@ Aggregate - Returns the median element from the grouping. * - :ref:`MIN ` - Returns the smallest (minimum) of all the elements in the grouping. + * - :ref:`QUANTILE_CONT ` + - Returns the continuous quantile element from the grouping (interpolated value between two closest values). + * - :ref:`QUANTILE_DISC ` + - Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, and returns the + value associated with the subinterval where the quantile value falls. * - :ref:`STDDEV ` - Returns the standard deviation of all the elements in the grouping. * - :ref:`SUM ` @@ -198,6 +203,64 @@ Returns the smallest (minimum) of all the elements in the grouping. # │ 10 │ # └─────────┘ + +.. _quantile_cont: + +QUANTILE_CONT +------------- +Returns the continuous quantile element from the grouping (interpolated value between two closest values). + +**Example:** + +.. code-block:: python + + df = pl.DataFrame({"foo": [5, 20, 10, 30, 70, 40, 10, 90]}) + df.sql(""" + SELECT + QUANTILE_CONT(foo, 0.25) AS foo_q25, + QUANTILE_CONT(foo, 0.50) AS foo_q50, + QUANTILE_CONT(foo, 0.75) AS foo_q75, + FROM self + """) + # shape: (1, 3) + # ┌─────────┬─────────┬─────────┐ + # │ foo_q25 ┆ foo_q50 ┆ foo_q75 │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞═════════╪═════════╪═════════╡ + # │ 10.0 ┆ 25.0 ┆ 47.5 │ + # └─────────┴─────────┴─────────┘ + + +.. _quantile_disc: + +QUANTILE_DISC +------------- +Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, and +returns the value associated with the subinterval where the quantile value falls. + +**Example:** + +.. code-block:: python + + df = pl.DataFrame({"foo": [5, 20, 10, 30, 70, 40, 10, 90]}) + df.sql(""" + SELECT + QUANTILE_DISC(foo, 0.25) AS foo_q25, + QUANTILE_DISC(foo, 0.50) AS foo_q50, + QUANTILE_DISC(foo, 0.75) AS foo_q75, + FROM self + """) + # shape: (1, 3) + # ┌─────────┬─────────┬─────────┐ + # │ foo_q25 ┆ foo_q50 ┆ foo_q75 │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞═════════╪═════════╪═════════╡ + # │ 10.0 ┆ 20.0 ┆ 40.0 │ + # └─────────┴─────────┴─────────┘ + + .. _stddev: STDDEV diff --git a/py-polars/docs/source/reference/sql/functions/bitwise.rst b/py-polars/docs/source/reference/sql/functions/bitwise.rst new file mode 100644 index 000000000000..bd66c3810df1 --- /dev/null +++ b/py-polars/docs/source/reference/sql/functions/bitwise.rst @@ -0,0 +1,151 @@ +Temporal +======== + +.. list-table:: + :header-rows: 1 + :widths: 20 60 + + * - Function + - Description + + * - :ref:`BIT_AND ` + - Returns the bitwise AND of the given values. + * - :ref:`BIT_COUNT ` + - Returns the number of bits set to 1 in the binary representation of the given value. + * - :ref:`BIT_OR ` + - Returns the bitwise OR of the given values. + * - :ref:`BIT_XOR ` + - Returns the bitwise XOR of the given values. + + +.. _bit_and: + +BIT_AND +------- +Returns the bitwise AND of the given values. +Also available as the `&` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i & j AS i_bitand_op_j, + BIT_AND(i, j) AS i_bitand_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬───────────────┬────────────┐ + # │ i ┆ j ┆ i_bitand_op_j ┆ i_bitand_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪═══════════════╪════════════╡ + # │ 3 ┆ 4 ┆ 0 ┆ 0 │ + # │ 10 ┆ 7 ┆ 2 ┆ 2 │ + # │ 4 ┆ 9 ┆ 0 ┆ 0 │ + # │ 8 ┆ 10 ┆ 8 ┆ 8 │ + # └─────┴─────┴───────────────┴────────────┘ + +.. _bit_count: + +BIT_COUNT +--------- +Returns the number of bits set to 1 in the binary representation of the given value. + +.. code-block:: python + + df = pl.DataFrame({"i": [16, 10, 55, 127]}) + df.sql(""" + SELECT + i, + BIT_COUNT(i) AS i_bitcount + FROM self + """) + # shape: (4, 2) + # ┌─────┬────────────┐ + # │ i ┆ i_bitcount │ + # │ --- ┆ --- │ + # │ i64 ┆ u32 │ + # ╞═════╪════════════╡ + # │ 16 ┆ 1 │ + # │ 10 ┆ 2 │ + # │ 55 ┆ 5 │ + # │ 127 ┆ 7 │ + # └─────┴────────────┘ + +.. _bit_or: + +BIT_OR +------ +Returns the bitwise OR of the given values. +Also available as the `|` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i | j AS i_bitor_op_j, + BIT_OR(i, j) AS i_bitor_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬──────────────┬───────────┐ + # │ i ┆ j ┆ i_bitor_op_j ┆ i_bitor_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪══════════════╪═══════════╡ + # │ 3 ┆ 4 ┆ 7 ┆ 7 │ + # │ 10 ┆ 7 ┆ 15 ┆ 15 │ + # │ 4 ┆ 9 ┆ 13 ┆ 13 │ + # │ 8 ┆ 10 ┆ 10 ┆ 10 │ + # └─────┴─────┴──────────────┴───────────┘ + +.. _bit_xor: + +BIT_XOR +------- +Returns the bitwise XOR of the given values. +Also available as the `XOR` binary operator. + +.. code-block:: python + + df = pl.DataFrame( + { + "i": [3, 10, 4, 8], + "j": [4, 7, 9, 10], + } + ) + df.sql(""" + SELECT + i, + j, + i XOR j AS i_bitxor_op_j, + BIT_XOR(i, j) AS i_bitxor_j + FROM self + """) + # shape: (4, 4) + # ┌─────┬─────┬───────────────┬────────────┐ + # │ i ┆ j ┆ i_bitxor_op_j ┆ i_bitxor_j │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞═════╪═════╪═══════════════╪════════════╡ + # │ 3 ┆ 4 ┆ 7 ┆ 7 │ + # │ 10 ┆ 7 ┆ 13 ┆ 13 │ + # │ 4 ┆ 9 ┆ 13 ┆ 13 │ + # │ 8 ┆ 10 ┆ 2 ┆ 2 │ + # └─────┴─────┴───────────────┴────────────┘ diff --git a/py-polars/docs/source/reference/sql/functions/index.rst b/py-polars/docs/source/reference/sql/functions/index.rst index cf6247e41a10..3473a0741a91 100644 --- a/py-polars/docs/source/reference/sql/functions/index.rst +++ b/py-polars/docs/source/reference/sql/functions/index.rst @@ -32,6 +32,18 @@ SQL Functions array + .. grid-item-card:: + + **Bitwise** + ^^^^^^^^^^^ + + .. toctree:: + :maxdepth: 2 + + bitwise + +.. grid:: + .. grid-item-card:: **Conditional** @@ -52,8 +64,6 @@ SQL Functions math -.. grid:: - .. grid-item-card:: **String** @@ -64,6 +74,8 @@ SQL Functions string +.. grid:: + .. grid-item-card:: **Temporal** diff --git a/py-polars/docs/source/reference/sql/functions/trigonometry.rst b/py-polars/docs/source/reference/sql/functions/trigonometry.rst index a2ee47ad3d06..dcd7ba4fe18d 100644 --- a/py-polars/docs/source/reference/sql/functions/trigonometry.rst +++ b/py-polars/docs/source/reference/sql/functions/trigonometry.rst @@ -18,9 +18,9 @@ Trigonometry * - :ref:`ATAND ` - Compute inverse tangent of the input column (in degrees). * - :ref:`ATAN2 ` - - Compute the inverse tangent of column_2/column_1 (in radians). + - Compute the inverse tangent of column_1/column_2 (in radians). * - :ref:`ATAN2D ` - - Compute the inverse tangent of column_2/column_1 (in degrees). + - Compute the inverse tangent of column_1/column_2 (in degrees). * - :ref:`COT ` - Compute the cotangent of the input column (in radians). * - :ref:`COTD ` @@ -187,7 +187,7 @@ Compute inverse tangent of the input column (in degrees). ATAN2 ----- -Compute the inverse tangent of column_2/column_1 (in radians). +Compute the inverse tangent of column_1/column_2 (in radians). **Example:** @@ -216,7 +216,7 @@ Compute the inverse tangent of column_2/column_1 (in radians). ATAN2D ------ -Compute the inverse tangent of column_2/column_1 (in degrees). +Compute the inverse tangent of column_1/column_2 (in degrees). **Example:** @@ -224,22 +224,22 @@ Compute the inverse tangent of column_2/column_1 (in degrees). df = pl.DataFrame( { - "a": [0, 90, 180, 360], - "b": [360, 270, 180, 90], + "a": [-1.0, 0.0, 1.0, 1.0], + "b": [1.0, 1.0, 0.0, -1.0], } ) df.sql("SELECT a, b, ATAN2D(a, b) AS atan2d_ab FROM self") # shape: (4, 3) - # ┌─────┬─────┬───────────┐ - # │ a ┆ b ┆ atan2d_ab │ - # │ --- ┆ --- ┆ --- │ - # │ i64 ┆ i64 ┆ f64 │ - # ╞═════╪═════╪═══════════╡ - # │ 0 ┆ 360 ┆ 0.0 │ - # │ 90 ┆ 270 ┆ 18.434949 │ - # │ 180 ┆ 180 ┆ 45.0 │ - # │ 360 ┆ 90 ┆ 75.963757 │ - # └─────┴─────┴───────────┘ + # ┌──────┬──────┬───────────┐ + # │ a ┆ b ┆ atan2d_ab │ + # │ --- ┆ --- ┆ --- │ + # │ f64 ┆ f64 ┆ f64 │ + # ╞══════╪══════╪═══════════╡ + # │ -1 ┆ 1.0 ┆ 135.0 │ + # │ 0.0 ┆ 1.0 ┆ 90.0 │ + # │ 1.0 ┆ 0.0 ┆ 0.0 │ + # │ 1.0 ┆ -1.0 ┆ -45.0 │ + # └──────┴──────┴───────────┘ .. _cot: diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 10f0ee54228b..83ea52acc822 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -104,6 +104,7 @@ datetime_ranges, duration, element, + escape_regex, exclude, field, first, @@ -176,6 +177,13 @@ scan_parquet, scan_pyarrow_dataset, ) +from polars.io.cloud import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) from polars.lazyframe import GPUEngine, LazyFrame from polars.meta import ( build_info, @@ -266,6 +274,12 @@ "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", + # polars.io.cloud + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", # polars.stringcache "StringCache", "disable_string_cache", @@ -290,6 +304,7 @@ "time_range", "time_ranges", "zeros", + "escape_regex", # polars.functions.aggregation "all", "all_horizontal", diff --git a/py-polars/polars/_reexport.py b/py-polars/polars/_reexport.py index 408fead781de..10818f473166 100644 --- a/py-polars/polars/_reexport.py +++ b/py-polars/polars/_reexport.py @@ -3,12 +3,14 @@ from polars.dataframe import DataFrame from polars.expr import Expr, When from polars.lazyframe import LazyFrame +from polars.schema import Schema from polars.series import Series __all__ = [ "DataFrame", "Expr", "LazyFrame", + "Schema", "Series", "When", ] diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 1670b08aeb2f..67a06a2c689e 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -1,7 +1,9 @@ from __future__ import annotations from collections.abc import Collection, Iterable, Mapping, Sequence +from pathlib import Path from typing import ( + IO, TYPE_CHECKING, Any, Literal, @@ -158,9 +160,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: RollingInterpolationMethod: TypeAlias = Literal[ "nearest", "higher", "lower", "midpoint", "linear" ] # QuantileInterpolOptions -ToStructStrategy: TypeAlias = Literal[ - "first_non_null", "max_width" -] # ListToStructWidthStrategy +ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"] # The following have no equivalent on the Rust side ConcatMethod = Literal[ @@ -294,3 +294,14 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: # LazyFrame engine selection EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"] + +ScanSource: TypeAlias = Union[ + str, + Path, + IO[bytes], + bytes, + list[str], + list[Path], + list[IO[bytes]], + list[bytes], +] diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index d6774cfd126f..dc5060c4e4c3 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -35,6 +35,7 @@ "ASCII_BORDERS_ONLY_CONDENSED", "ASCII_HORIZONTAL_ONLY", "ASCII_MARKDOWN", + "MARKDOWN", "UTF8_FULL", "UTF8_FULL_CONDENSED", "UTF8_NO_BORDERS", @@ -195,7 +196,7 @@ def __init__( >>> df = pl.DataFrame({"abc": [1.0, 2.5, 5.0], "xyz": [True, False, True]}) >>> with pl.Config( ... # these options will be set for scope duration - ... tbl_formatting="ASCII_MARKDOWN", + ... tbl_formatting="MARKDOWN", ... tbl_hide_dataframe_shape=True, ... tbl_rows=10, ... ): @@ -1037,7 +1038,8 @@ def set_tbl_formatting( * "ASCII_BORDERS_ONLY": ASCII, borders only. * "ASCII_BORDERS_ONLY_CONDENSED": ASCII, borders only, dense row spacing. * "ASCII_HORIZONTAL_ONLY": ASCII, horizontal lines only. - * "ASCII_MARKDOWN": ASCII, Markdown compatible. + * "ASCII_MARKDOWN": Markdown format (ascii ellipses for truncated values). + * "MARKDOWN": Markdown format (utf8 ellipses for truncated values). * "UTF8_FULL": UTF8, with all borders and lines, including row dividers. * "UTF8_FULL_CONDENSED": Same as UTF8_FULL, but with dense row spacing. * "UTF8_NO_BORDERS": UTF8, no borders. @@ -1060,7 +1062,7 @@ def set_tbl_formatting( ... {"abc": [-2.5, 5.0], "mno": ["hello", "world"], "xyz": [True, False]} ... ) >>> with pl.Config( - ... tbl_formatting="ASCII_MARKDOWN", + ... tbl_formatting="MARKDOWN", ... tbl_hide_column_data_types=True, ... tbl_hide_dataframe_shape=True, ... ): diff --git a/py-polars/polars/convert/general.py b/py-polars/polars/convert/general.py index cee4b925e9e9..f80526e5a2c9 100644 --- a/py-polars/polars/convert/general.py +++ b/py-polars/polars/convert/general.py @@ -4,7 +4,7 @@ import itertools import re from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Literal, overload import polars._reexport as pl from polars import functions as F @@ -98,7 +98,7 @@ def from_dict( def from_dicts( - data: Sequence[dict[str, Any]], + data: Iterable[dict[str, Any]], schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, @@ -487,15 +487,26 @@ def from_pandas( @overload def from_pandas( - data: pd.Series[Any] | pd.Index[Any], + data: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, schema_overrides: SchemaDict | None = ..., rechunk: bool = ..., nan_to_null: bool = ..., - include_index: bool = ..., + include_index: Literal[False] = ..., ) -> Series: ... +@overload +def from_pandas( + data: pd.Series[Any], + *, + schema_overrides: SchemaDict | None = ..., + rechunk: bool = ..., + nan_to_null: bool = ..., + include_index: Literal[True] = ..., +) -> DataFrame: ... + + def from_pandas( data: pd.DataFrame | pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, *, @@ -525,8 +536,8 @@ def from_pandas( Load any non-default pandas indexes as columns. .. note:: - If the input is a pandas ``Series`` or ``DataFrame`` and has a nameless - index which just enumerates the rows, then it will not be included in the + If the input is a pandas ``DataFrame`` and has a nameless index + which just enumerates the rows, then it will not be included in the result, regardless of this parameter. If you want to be sure to include it, please call ``.reset_index()`` prior to calling this function. @@ -566,6 +577,9 @@ def from_pandas( 3 ] """ + if include_index and isinstance(data, pd.Series): + data = data.reset_index() + if isinstance(data, (pd.Series, pd.Index, pd.DatetimeIndex)): return wrap_s(pandas_to_pyseries("", data, nan_to_null=nan_to_null)) elif isinstance(data, pd.DataFrame): @@ -724,6 +738,7 @@ def _from_dataframe_repr(m: re.Match[str]) -> DataFrame: if schema and data and (n_extend_cols := (len(schema) - len(data))) > 0: empty_data = [None] * len(data[0]) data.extend((pl.Series(empty_data, dtype=String)) for _ in range(n_extend_cols)) + for dtype in set(schema.values()): if dtype in (List, Struct, Object): msg = ( diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index ae32d3b454e4..721a02c7fec6 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -912,7 +912,7 @@ def schema(self) -> Schema: >>> df.schema Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return Schema(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes), check_dtypes=False) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None @@ -1980,7 +1980,7 @@ def to_jax( Create the Array on a specific GPU device: - >>> gpu_device = jax.devices("gpu")[1]) # doctest: +SKIP + >>> gpu_device = jax.devices("gpu")[1] # doctest: +SKIP >>> a = df.to_jax(device=gpu_device) # doctest: +SKIP >>> a.device() # doctest: +SKIP GpuDevice(id=1, process_index=0) @@ -3893,11 +3893,14 @@ def write_database( Select the engine to use for writing frame data; only necessary when supplying a URI string (defaults to 'sqlalchemy' if unset) engine_options - Additional options to pass to the engine's associated insert method: - - * "sqlalchemy" - currently inserts using Pandas' `to_sql` method, though - this will eventually be phased out in favor of a native solution. - * "adbc" - inserts using the ADBC cursor's `adbc_ingest` method. + Additional options to pass to the insert method associated with the engine + specified by the option `engine`. + + * Setting `engine` to "sqlalchemy" currently inserts using Pandas' `to_sql` + method (though this will eventually be phased out in favor of a native + solution). + * Setting `engine` to "adbc" inserts using the ADBC cursor's `adbc_ingest` + method. Examples -------- diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d46d8c111581..f773cc28b6ca 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -65,10 +65,14 @@ def is_polars_dtype( - dtype: Any, *, include_unknown: bool = False + dtype: Any, + *, + include_unknown: bool = False, + require_instantiated: bool = False, ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + check_classes = DataType if require_instantiated else (DataType, DataTypeClass) + is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type] if not include_unknown: return is_dtype and dtype != Unknown diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index 9aaed1352d09..4628ee2a9c15 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -6,7 +6,7 @@ import polars._reexport as pl from polars import functions as F from polars._utils.convert import parse_as_duration_string -from polars._utils.deprecation import deprecate_function +from polars._utils.deprecation import deprecate_function, deprecate_nonkeyword_arguments from polars._utils.parse import parse_into_expression from polars._utils.unstable import unstable from polars._utils.wrap import wrap_expr @@ -35,6 +35,7 @@ class ExprDateTimeNameSpace: def __init__(self, expr: Expr) -> None: self._pyexpr = expr._pyexpr + @deprecate_nonkeyword_arguments(allowed_args=["self", "n"], version="1.12.0") def add_business_days( self, n: int | IntoExpr, @@ -97,7 +98,9 @@ def add_business_days( You can pass a custom weekend - for example, if you only take Sunday off: >>> week_mask = (True, True, True, True, True, True, False) - >>> df.with_columns(result=pl.col("start").dt.add_business_days(5, week_mask)) + >>> df.with_columns( + ... result=pl.col("start").dt.add_business_days(5, week_mask=week_mask) + ... ) shape: (2, 2) ┌────────────┬────────────┐ │ start ┆ result │ diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 98f638b0846a..4613aabd4beb 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -407,9 +407,16 @@ def to_physical(self) -> Expr: - :func:`polars.datatypes.Duration` -> :func:`polars.datatypes.Int64` - :func:`polars.datatypes.Categorical` -> :func:`polars.datatypes.UInt32` - `List(inner)` -> `List(physical of inner)` + - `Array(inner)` -> `Struct(physical of inner)` + - `Struct(fields)` -> `Array(physical of fields)` Other data types will be left unchanged. + Warning + ------- + The physical representations are an implementation detail + and not guaranteed to be stable. + Examples -------- Replicating the pandas @@ -2815,7 +2822,7 @@ def fill_nan(self, value: int | float | Expr | None) -> Expr: def forward_fill(self, limit: int | None = None) -> Expr: """ - Fill missing values with the latest seen values. + Fill missing values with the last non-null value. Parameters ---------- @@ -2851,7 +2858,7 @@ def forward_fill(self, limit: int | None = None) -> Expr: def backward_fill(self, limit: int | None = None) -> Expr: """ - Fill missing values with the next to be seen values. + Fill missing values with the next non-null value. Parameters ---------- @@ -4195,7 +4202,6 @@ def filter( Filter expressions can also take constraints as keyword arguments. - >>> import polars.selectors as cs >>> df = pl.DataFrame( ... { ... "key": ["a", "a", "a", "a", "b", "b", "b", "b", "b"], diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 48b4d1da9c49..4d239460d6b5 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -16,8 +16,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) @@ -1092,7 +1092,7 @@ def to_array(self, width: int) -> Expr: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Sequence[str] | Callable[[int], str] | None = None, upper_bound: int = 0, ) -> Expr: @@ -1180,9 +1180,8 @@ def to_struct( [{'n': {'one': 0, 'two': 1}}, {'n': {'one': 2, 'two': 3}}] """ if isinstance(fields, Sequence): - field_names = list(fields) - pyexpr = self._pyexpr.list_to_struct(n_field_strategy, None, upper_bound) - return wrap_expr(pyexpr).struct.rename_fields(field_names) + pyexpr = self._pyexpr.list_to_struct_fixed_width(fields) + return wrap_expr(pyexpr) else: pyexpr = self._pyexpr.list_to_struct(n_field_strategy, fields, upper_bound) return wrap_expr(pyexpr) diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 7582758d5921..e94f995ee700 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2781,6 +2781,28 @@ def concat( delimiter = "-" return self.join(delimiter, ignore_nulls=ignore_nulls) + def escape_regex(self) -> Expr: + r""" + Returns string values with all regular expression meta characters escaped. + + Examples + -------- + >>> df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + >>> df.with_columns(pl.col("text").str.escape_regex().alias("escaped")) + shape: (4, 2) + ┌──────────┬──────────────┐ + │ text ┆ escaped │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞══════════╪══════════════╡ + │ abc ┆ abc │ + │ def ┆ def │ + │ null ┆ null │ + │ abc(\w+) ┆ abc\(\\w\+\) │ + └──────────┴──────────────┘ + """ + return wrap_expr(self._pyexpr.str_escape_regex()) + def _validate_format_argument(format: str | None) -> None: if format is not None and ".%f" in format: diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index fedd0ac2bff0..32fbe4578059 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -26,6 +26,7 @@ from polars.functions.business import business_day_count from polars.functions.col import col from polars.functions.eager import align_frames, concat +from polars.functions.escape_regex import escape_regex from polars.functions.lazy import ( approx_n_unique, arctan2, @@ -170,4 +171,6 @@ # polars.functions.whenthen "when", "sql_expr", + # polars.functions.escape_regex + "escape_regex", ] diff --git a/py-polars/polars/functions/escape_regex.py b/py-polars/polars/functions/escape_regex.py new file mode 100644 index 000000000000..1c038347e8af --- /dev/null +++ b/py-polars/polars/functions/escape_regex.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import contextlib + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr +import polars._reexport as pl + + +def escape_regex(s: str) -> str: + r""" + Escapes string regex meta characters. + + Parameters + ---------- + s + The string that all of its meta characters will be escaped. + + """ + if isinstance(s, pl.Expr): + msg = "escape_regex function is unsupported for `Expr`, you may want use `Expr.str.escape_regex` instead" + raise TypeError(msg) + elif not isinstance(s, str): + msg = f"escape_regex function supports only `str` type, got `{type(s)}`" + raise TypeError(msg) + + return plr.escape_regex(s) diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index 68d4b604d6a6..527f1da01240 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -7,7 +7,11 @@ from pathlib import Path from typing import IO, TYPE_CHECKING, Any, overload -from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath +from polars._utils.various import ( + is_int_sequence, + is_str_sequence, + normalize_filepath, +) from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError diff --git a/py-polars/polars/io/cloud/__init__.py b/py-polars/polars/io/cloud/__init__.py new file mode 100644 index 000000000000..f5ef9c5fd0bf --- /dev/null +++ b/py-polars/polars/io/cloud/__init__.py @@ -0,0 +1,15 @@ +from polars.io.cloud.credential_provider import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) + +__all__ = [ + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", +] diff --git a/py-polars/polars/io/cloud/_utils.py b/py-polars/polars/io/cloud/_utils.py new file mode 100644 index 000000000000..7279838aa005 --- /dev/null +++ b/py-polars/polars/io/cloud/_utils.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from pathlib import Path +from typing import IO + +from polars._utils.various import is_path_or_str_sequence + + +def _first_scan_path( + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], +) -> str | Path | None: + if isinstance(source, (str, Path)): + return source + elif is_path_or_str_sequence(source) and source: + return source[0] + + return None + + +def _get_path_scheme(path: str | Path) -> str | None: + splitted = str(path).split("://", maxsplit=1) + + return None if not splitted else splitted[0] + + +def _is_aws_cloud(scheme: str) -> bool: + return any(scheme == x for x in ["s3", "s3a"]) + + +def _is_gcp_cloud(scheme: str) -> bool: + return any(scheme == x for x in ["gs", "gcp", "gcs"]) diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider.py new file mode 100644 index 000000000000..26e8ebf6826e --- /dev/null +++ b/py-polars/polars/io/cloud/credential_provider.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import abc +import importlib.util +import os +import sys +import zoneinfo +from typing import IO, TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + from pathlib import Path + +from polars._utils.unstable import issue_unstable_warning + +# These typedefs are here to avoid circular import issues, as +# `CredentialProviderFunction` specifies "CredentialProvider" +CredentialProviderFunctionReturn: TypeAlias = tuple[ + dict[str, Optional[str]], Optional[int] +] + +CredentialProviderFunction: TypeAlias = Union[ + Callable[[], CredentialProviderFunctionReturn], "CredentialProvider" +] + + +class AWSAssumeRoleKWArgs(TypedDict): + """Parameters for [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role).""" + + RoleArn: str + RoleSessionName: str + PolicyArns: list[dict[str, str]] + Policy: str + DurationSeconds: int + Tags: list[dict[str, str]] + TransitiveTagKeys: list[str] + ExternalId: str + SerialNumber: str + TokenCode: str + SourceIdentity: str + ProvidedContexts: list[dict[str, str]] + + +class CredentialProvider(abc.ABC): + """ + Base class for credential providers. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + @abc.abstractmethod + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetches the credentials.""" + + +class CredentialProviderAWS(CredentialProvider): + """ + AWS Credential Provider. + + Using this requires the `boto3` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__( + self, + *, + profile_name: str | None = None, + assume_role: AWSAssumeRoleKWArgs | None = None, + ) -> None: + """ + Initialize a credential provider for AWS. + + Parameters + ---------- + profile_name : str + Profile name to use from credentials file. + assume_role : AWSAssumeRoleKWArgs | None + Configure a role to assume. These are passed as kwarg parameters to + [STS.client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role) + """ + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + self.profile_name = profile_name + self.assume_role = assume_role + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import boto3 + + session = boto3.Session(profile_name=self.profile_name) + + if self.assume_role is not None: + return self._finish_assume_role(session) + + creds = session.get_credentials() + + if creds is None: + msg = "unexpected None value returned from boto3.Session.get_credentials()" + raise ValueError(msg) + + return { + "aws_access_key_id": creds.access_key, + "aws_secret_access_key": creds.secret_key, + "aws_session_token": creds.token, + }, None + + def _finish_assume_role(self, session: Any) -> CredentialProviderFunctionReturn: + client = session.client("sts") + + sts_response = client.assume_role(**self.assume_role) + creds = sts_response["Credentials"] + + expiry = creds["Expiration"] + + if expiry.tzinfo is None: + msg = "expiration time in STS response did not contain timezone information" + raise ValueError(msg) + + return { + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds["SecretAccessKey"], + "aws_session_token": creds["SessionToken"], + }, int(expiry.timestamp()) + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("boto3") is None: + msg = "boto3 must be installed to use `CredentialProviderAWS`" + raise ImportError(msg) + + +class CredentialProviderGCP(CredentialProvider): + """ + GCP Credential Provider. + + Using this requires the `google-auth` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__(self) -> None: + """Initialize a credential provider for Google Cloud (GCP).""" + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + + import google.auth + import google.auth.credentials + + # CI runs with both `mypy` and `mypy --allow-untyped-calls` depending on + # Python version. If we add a `type: ignore[no-untyped-call]`, then the + # check that runs with `--allow-untyped-calls` will complain about an + # unused "type: ignore" comment. And if we don't add the ignore, then + # he check that runs `mypy` will complain. + # + # So we just bypass it with a __dict__[] (because ruff complains about + # getattr) :| + creds, _ = google.auth.__dict__["default"]() + self.creds = creds + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import google.auth.transport.requests + + self.creds.refresh(google.auth.transport.requests.__dict__["Request"]()) + + return {"bearer_token": self.creds.token}, ( + int( + ( + expiry.replace(tzinfo=zoneinfo.ZoneInfo("UTC")) + if expiry.tzinfo is None + else expiry + ).timestamp() + ) + if (expiry := self.creds.expiry) is not None + else None + ) + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("google.auth") is None: + msg = "google-auth must be installed to use `CredentialProviderGCP`" + raise ImportError(msg) + + +def _maybe_init_credential_provider( + credential_provider: CredentialProviderFunction | Literal["auto"] | None, + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], + storage_options: dict[str, Any] | None, + caller_name: str, +) -> CredentialProviderFunction | CredentialProvider | None: + from polars.io.cloud._utils import ( + _first_scan_path, + _get_path_scheme, + _is_aws_cloud, + _is_gcp_cloud, + ) + + if credential_provider is not None: + msg = f"The `credential_provider` parameter of `{caller_name}` is considered unstable." + issue_unstable_warning(msg) + + if credential_provider != "auto": + return credential_provider + + if storage_options is not None: + return None + + verbose = os.getenv("POLARS_VERBOSE") == "1" + + if (path := _first_scan_path(source)) is None: + return None + + if (scheme := _get_path_scheme(path)) is None: + return None + + provider = None + + try: + provider = ( + CredentialProviderAWS() + if _is_aws_cloud(scheme) + else CredentialProviderGCP() + if _is_gcp_cloud(scheme) + else None + ) + except ImportError as e: + if verbose: + msg = f"Unable to auto-select credential provider: {e}" + print(msg, file=sys.stderr) + + if provider is not None and verbose: + msg = f"Auto-selected credential provider: {type(provider).__name__}" + print(msg, file=sys.stderr) + + return provider diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 7ded05836f90..daebcf452c1e 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable +from typing import IO, TYPE_CHECKING, Any, Callable, Literal import polars._reexport as pl import polars.functions as F @@ -24,6 +24,7 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider from polars.io.csv._utils import _check_arg_is_1byte, _update_columns from polars.io.csv.batched_reader import BatchedCsvReader @@ -35,6 +36,7 @@ from polars import DataFrame, LazyFrame from polars._typing import CsvEncoding, PolarsDataType, SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31") @@ -1034,6 +1036,7 @@ def scan_csv( decimal_comma: bool = False, glob: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -1154,6 +1157,14 @@ def scan_csv( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -1259,6 +1270,10 @@ def with_column_names(cols: list[str]) -> list[str]: if not infer_schema: infer_schema_length = 0 + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_csv" + ) + return _scan_csv_impl( source, has_header=has_header, @@ -1289,6 +1304,7 @@ def with_column_names(cols: list[str]) -> list[str]: glob=glob, retries=retries, storage_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, include_file_paths=include_file_paths, ) @@ -1332,6 +1348,7 @@ def _scan_csv_impl( decimal_comma: bool = False, glob: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -1384,6 +1401,7 @@ def _scan_csv_impl( glob=glob, schema=schema, cloud_options=storage_options, + credential_provider=credential_provider, retries=retries, file_cache_ttl=file_cache_ttl, include_file_paths=include_file_paths, diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index 278e3e8e0738..cb6b6b92ff1b 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -338,7 +338,7 @@ def _inject_type_overrides( @staticmethod def _is_alchemy_async(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is SQLALchemy async.""" try: from sqlalchemy.ext.asyncio import ( AsyncConnection, @@ -352,7 +352,7 @@ def _is_alchemy_async(conn: Any) -> bool: @staticmethod def _is_alchemy_engine(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is a SQLAlchemy Engine.""" from sqlalchemy.engine import Engine if isinstance(conn, Engine): @@ -364,9 +364,14 @@ def _is_alchemy_engine(conn: Any) -> bool: except ImportError: return False + @staticmethod + def _is_alchemy_object(conn: Any) -> bool: + """Check if the given connection is a SQLAlchemy object (of any kind).""" + return type(conn).__module__.split(".", 1)[0] == "sqlalchemy" + @staticmethod def _is_alchemy_session(conn: Any) -> bool: - """Check if the cursor/connection/session object is async.""" + """Check if the given connection is a SQLAlchemy Session object.""" from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, sessionmaker @@ -392,7 +397,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor: return conn.engine.raw_connection().cursor() elif conn.engine.driver == "duckdb_engine": self.driver_name = "duckdb" - return conn.engine.raw_connection().driver_connection + return conn elif self._is_alchemy_engine(conn): # note: if we create it, we can close it self.can_close_cursor = True @@ -482,7 +487,7 @@ def execute( options = options or {} - if self.driver_name == "sqlalchemy": + if self._is_alchemy_object(self.cursor): cursor_execute, options, query = self._sqlalchemy_setup(query, options) else: cursor_execute = self.cursor.execute @@ -505,8 +510,11 @@ def execute( ) result = cursor_execute(query, *positional_options) - # note: some cursors execute in-place + # note: some cursors execute in-place, some access results via a property result = self.cursor if result is None else result + if self.driver_name == "duckdb": + result = result.cursor + self.result = result return self diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index ac5dcaaac3e7..21e436dc0557 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -25,10 +25,12 @@ except ImportError: Selectable: TypeAlias = Any # type: ignore[no-redef] + from sqlalchemy.sql.elements import TextClause + @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[False] = ..., @@ -41,7 +43,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: Literal[True], @@ -54,7 +56,7 @@ def read_database( @overload def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool, @@ -66,7 +68,7 @@ def read_database( def read_database( - query: str | Selectable, + query: str | TextClause | Selectable, connection: ConnectionOrCursor | str, *, iter_batches: bool = False, @@ -263,6 +265,51 @@ def read_database( ) +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["adbc"], + schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: list[str] | str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["connectorx"] | None = None, + schema_overrides: SchemaDict | None = None, + execute_options: None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: DbReadEngine | None = None, + schema_overrides: None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + def read_database_uri( query: list[str] | str, uri: str, diff --git a/py-polars/polars/io/delta.py b/py-polars/polars/io/delta.py index 9b0219fa8dff..b2b27c74a20f 100644 --- a/py-polars/polars/io/delta.py +++ b/py-polars/polars/io/delta.py @@ -1,10 +1,12 @@ from __future__ import annotations +import warnings from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse +from polars import DataFrame from polars.convert import from_arrow from polars.datatypes import Null, Time from polars.datatypes.convert import unpack_dtypes @@ -12,11 +14,13 @@ from polars.io.pyarrow_dataset import scan_pyarrow_dataset if TYPE_CHECKING: - from polars import DataFrame, DataType, LazyFrame + from deltalake import DeltaTable + + from polars import DataType, LazyFrame def read_delta( - source: str, + source: str | DeltaTable, *, version: int | str | datetime | None = None, columns: list[str] | None = None, @@ -31,7 +35,7 @@ def read_delta( Parameters ---------- source - Path or URI to the root of the Delta lake table. + DeltaTable or a Path or URI to the root of the Delta lake table. Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. @@ -138,22 +142,23 @@ def read_delta( if pyarrow_options is None: pyarrow_options = {} - resolved_uri = _resolve_delta_lake_uri(source) - dl_tbl = _get_delta_lake_table( - table_path=resolved_uri, + table_path=source, version=version, storage_options=storage_options, delta_table_options=delta_table_options, ) - return from_arrow( - dl_tbl.to_pyarrow_table(columns=columns, **pyarrow_options), rechunk=rechunk - ) # type: ignore[return-value] + return cast( + DataFrame, + from_arrow( + dl_tbl.to_pyarrow_table(columns=columns, **pyarrow_options), rechunk=rechunk + ), + ) def scan_delta( - source: str, + source: str | DeltaTable, *, version: int | str | datetime | None = None, storage_options: dict[str, Any] | None = None, @@ -166,7 +171,7 @@ def scan_delta( Parameters ---------- source - Path or URI to the root of the Delta lake table. + DeltaTable or a Path or URI to the root of the Delta lake table. Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. @@ -274,9 +279,8 @@ def scan_delta( if pyarrow_options is None: pyarrow_options = {} - resolved_uri = _resolve_delta_lake_uri(source) dl_tbl = _get_delta_lake_table( - table_path=resolved_uri, + table_path=source, version=version, storage_options=storage_options, delta_table_options=delta_table_options, @@ -299,7 +303,7 @@ def _resolve_delta_lake_uri(table_uri: str, *, strict: bool = True) -> str: def _get_delta_lake_table( - table_path: str, + table_path: str | DeltaTable, version: int | str | datetime | None = None, storage_options: dict[str, Any] | None = None, delta_table_options: dict[str, Any] | None = None, @@ -314,12 +318,27 @@ def _get_delta_lake_table( """ _check_if_delta_available() + if isinstance(table_path, deltalake.DeltaTable): + if any( + [ + version is not None, + storage_options is not None, + delta_table_options is not None, + ] + ): + warnings.warn( + """When supplying a DeltaTable directly, `version`, `storage_options`, and `delta_table_options` are ignored. + To silence this warning, don't supply those parameters.""", + RuntimeWarning, + stacklevel=1, + ) + return table_path if delta_table_options is None: delta_table_options = {} - + resolved_uri = _resolve_delta_lake_uri(table_path) if not isinstance(version, (str, datetime)): dl_tbl = deltalake.DeltaTable( - table_path, + resolved_uri, version=version, storage_options=storage_options, **delta_table_options, diff --git a/py-polars/polars/io/iceberg.py b/py-polars/polars/io/iceberg.py index d816703f56ac..5d47eb2f389e 100644 --- a/py-polars/polars/io/iceberg.py +++ b/py-polars/polars/io/iceberg.py @@ -42,6 +42,7 @@ def scan_iceberg( source: str | Table, *, + snapshot_id: int | None = None, storage_options: dict[str, Any] | None = None, ) -> LazyFrame: """ @@ -54,6 +55,8 @@ def scan_iceberg( Note: For Local filesystem, absolute and relative paths are supported but for the supported object storages - GCS, Azure and S3 full URI must be provided. + snapshot_id + The snapshot ID to scan from. storage_options Extra options for the storage backends supported by `pyiceberg`. For cloud storages, this may include configurations for authentication etc. @@ -126,6 +129,12 @@ def scan_iceberg( >>> pl.scan_iceberg( ... table_path, storage_options=storage_options ... ).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table using a specific snapshot ID. + + >>> table_path = "/path/to/iceberg-table/metadata.json" + >>> snapshot_id = 7051579356916758811 + >>> pl.scan_iceberg(table_path, snapshot_id=snapshot_id).collect() # doctest: +SKIP """ from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.table import StaticTable @@ -135,7 +144,13 @@ def scan_iceberg( metadata_location=source, properties=storage_options or {} ) - func = partial(_scan_pyarrow_dataset_impl, source) + if snapshot_id is not None: + snapshot = source.snapshot_by_id(snapshot_id) + if snapshot is None: + msg = f"Snapshot ID not found: {snapshot_id}" + raise ValueError(msg) + + func = partial(_scan_pyarrow_dataset_impl, source, snapshot_id=snapshot_id) arrow_schema = schema_to_pyarrow(source.schema()) return pl.LazyFrame._scan_python_function(arrow_schema, func, pyarrow=True) @@ -145,6 +160,7 @@ def _scan_pyarrow_dataset_impl( with_columns: list[str] | None = None, predicate: str = "", n_rows: int | None = None, + snapshot_id: int | None = None, **kwargs: Any, ) -> DataFrame | Series: """ @@ -160,6 +176,8 @@ def _scan_pyarrow_dataset_impl( pyarrow expression that can be evaluated with eval n_rows: Materialize only n rows from the arrow dataset. + snapshot_id: + The snapshot ID to scan from. batch_size The maximum row count for scanned pyarrow record batches. kwargs: @@ -171,7 +189,7 @@ def _scan_pyarrow_dataset_impl( """ from polars import from_arrow - scan = tbl.scan(limit=n_rows) + scan = tbl.scan(limit=n_rows, snapshot_id=snapshot_id) if with_columns is not None: scan = scan.select(*with_columns) diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 6d64a560d094..1348134fc0d6 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -3,7 +3,7 @@ import contextlib import os from pathlib import Path -from typing import IO, TYPE_CHECKING, Any +from typing import IO, TYPE_CHECKING, Any, Literal import polars._reexport as pl import polars.functions as F @@ -22,6 +22,7 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -32,6 +33,7 @@ from polars import DataFrame, DataType, LazyFrame from polars._typing import SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @@ -362,6 +364,7 @@ def scan_ipc( row_index_name: str | None = None, row_index_offset: int = 0, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, memory_map: bool = True, retries: int = 2, file_cache_ttl: int | None = None, @@ -407,6 +410,15 @@ def scan_ipc( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + memory_map Try to memory map the file. This can greatly improve performance on repeated queries as the OS may cache pages. @@ -451,6 +463,16 @@ def scan_ipc( # Memory Mapping is now a no-op _ = memory_map + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_parquet" + ) + + if storage_options: + storage_options = list(storage_options.items()) # type: ignore[assignment] + else: + # Handle empty dict input + storage_options = None + pylf = PyLazyFrame.new_from_ipc( source, sources, @@ -459,6 +481,7 @@ def scan_ipc( rechunk, parse_row_index_args(row_index_name, row_index_offset), cloud_options=storage_options, + credential_provider=credential_provider, retries=retries, file_cache_ttl=file_cache_ttl, hive_partitioning=hive_partitioning, diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 7a5fb2c0d1e6..983b8cddcfe1 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -4,13 +4,14 @@ from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any +from typing import IO, TYPE_CHECKING, Any, Literal from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.various import is_path_or_str_sequence, normalize_filepath from polars._utils.wrap import wrap_df, wrap_ldf from polars.datatypes import N_INFER_DEFAULT from polars.io._utils import parse_row_index_args +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -20,6 +21,7 @@ from polars import DataFrame, LazyFrame from polars._typing import SchemaDefinition + from polars.io.cloud import CredentialProviderFunction def read_ndjson( @@ -36,6 +38,7 @@ def read_ndjson( row_index_offset: int = 0, ignore_errors: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -96,6 +99,14 @@ def read_ndjson( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -147,6 +158,10 @@ def read_ndjson( return df + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "read_ndjson" + ) + return scan_ndjson( source, schema=schema, @@ -162,6 +177,7 @@ def read_ndjson( include_file_paths=include_file_paths, retries=retries, storage_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, ).collect() @@ -190,6 +206,7 @@ def scan_ndjson( row_index_offset: int = 0, ignore_errors: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, @@ -249,6 +266,14 @@ def scan_ndjson( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. file_cache_ttl @@ -276,6 +301,10 @@ def scan_ndjson( msg = "'infer_schema_length' should be positive" raise ValueError(msg) + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_ndjson" + ) + if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] else: @@ -297,6 +326,7 @@ def scan_ndjson( include_file_paths=include_file_paths, retries=retries, cloud_options=storage_options, + credential_provider=credential_provider, file_cache_ttl=file_cache_ttl, ) return wrap_ldf(pylf) diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index c27ee99ca94b..16ff7f614349 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -21,27 +21,24 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _maybe_init_credential_provider with contextlib.suppress(ImportError): from polars.polars import PyLazyFrame from polars.polars import read_parquet_schema as _read_parquet_schema if TYPE_CHECKING: + from typing import Literal + from polars import DataFrame, DataType, LazyFrame - from polars._typing import ParallelStrategy, SchemaDict + from polars._typing import ParallelStrategy, ScanSource, SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_parquet( - source: str - | Path - | IO[bytes] - | bytes - | list[str] - | list[Path] - | list[IO[bytes]] - | list[bytes], + source: ScanSource, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -57,6 +54,7 @@ def read_parquet( rechunk: bool = False, low_memory: bool = False, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, use_pyarrow: bool = False, pyarrow_options: dict[str, Any] | None = None, @@ -139,6 +137,14 @@ def read_parquet( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. use_pyarrow @@ -203,16 +209,9 @@ def read_parquet( rechunk=rechunk, ) - # Read file and bytes inputs using `read_parquet` - if isinstance(source, bytes): - source = io.BytesIO(source) - elif isinstance(source, list) and len(source) > 0 and isinstance(source[0], bytes): - assert all(isinstance(s, bytes) for s in source) - source = [io.BytesIO(s) for s in source] # type: ignore[arg-type, assignment] - # For other inputs, defer to `scan_parquet` lf = scan_parquet( - source, # type: ignore[arg-type] + source, n_rows=n_rows, row_index_name=row_index_name, row_index_offset=row_index_offset, @@ -226,6 +225,7 @@ def read_parquet( low_memory=low_memory, cache=False, storage_options=storage_options, + credential_provider=credential_provider, retries=retries, glob=glob, include_file_paths=include_file_paths, @@ -322,7 +322,7 @@ def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, Dat @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_parquet( - source: str | Path | IO[bytes] | list[str] | list[Path] | list[IO[bytes]], + source: ScanSource, *, n_rows: int | None = None, row_index_name: str | None = None, @@ -338,6 +338,7 @@ def scan_parquet( low_memory: bool = False, cache: bool = True, storage_options: dict[str, Any] | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, include_file_paths: str | None = None, allow_missing_columns: bool = False, @@ -426,6 +427,14 @@ def scan_parquet( If `storage_options` is not provided, Polars will try to infer the information from environment variables. + credential_provider + Provide a function that can be called to provide cloud storage + credentials. The function is expected to return a dictionary of + credential keys along with an optional credential expiry time. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. retries Number of retries if accessing a cloud instance fails. include_file_paths @@ -474,6 +483,10 @@ def scan_parquet( normalize_filepath(source, check_not_directory=False) for source in source ] + credential_provider = _maybe_init_credential_provider( + credential_provider, source, storage_options, "scan_parquet" + ) + return _scan_parquet_impl( source, # type: ignore[arg-type] n_rows=n_rows, @@ -483,6 +496,7 @@ def scan_parquet( row_index_name=row_index_name, row_index_offset=row_index_offset, storage_options=storage_options, + credential_provider=credential_provider, low_memory=low_memory, use_statistics=use_statistics, hive_partitioning=hive_partitioning, @@ -506,6 +520,7 @@ def _scan_parquet_impl( row_index_name: str | None = None, row_index_offset: int = 0, storage_options: dict[str, object] | None = None, + credential_provider: CredentialProviderFunction | None = None, low_memory: bool = False, use_statistics: bool = True, hive_partitioning: bool | None = None, @@ -539,6 +554,7 @@ def _scan_parquet_impl( parse_row_index_args(row_index_name, row_index_offset), low_memory, cloud_options=storage_options, + credential_provider=credential_provider, use_statistics=use_statistics, hive_partitioning=hive_partitioning, schema=schema, diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 8b3df004f4b3..1d4bb5fe90b6 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -565,7 +565,7 @@ def _read_spreadsheet( infer_schema_length=infer_schema_length, ) engine_options = (engine_options or {}).copy() - schema_overrides = dict(schema_overrides or {}) + schema_overrides = pl.Schema(schema_overrides or {}) # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 93eb038d94f9..307c02db4c07 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1380,7 +1380,12 @@ def sort( └──────┴─────┴─────┘ """ # Fast path for sorting by a single existing column - if isinstance(by, str) and not more_by: + if ( + isinstance(by, str) + and not more_by + and isinstance(descending, bool) + and isinstance(nulls_last, bool) + ): return self._from_pyldf( self._ldf.sort( by, descending, nulls_last, maintain_order, multithreaded @@ -2269,7 +2274,7 @@ def collect_schema(self) -> Schema: >>> schema.len() 3 """ - return Schema(self._ldf.collect_schema()) + return Schema(self._ldf.collect_schema(), check_dtypes=False) @unstable() def sink_parquet( diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 72eb8b86d25e..81ade5a6b206 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -1,23 +1,54 @@ from __future__ import annotations +import sys from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union -from polars.datatypes import DataType +from polars._typing import PythonDataType +from polars.datatypes import DataType, DataTypeClass, is_polars_dtype from polars.datatypes._parse import parse_into_dtype -BaseSchema = OrderedDict[str, DataType] - if TYPE_CHECKING: from collections.abc import Iterable - from polars._typing import PythonDataType + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + +if sys.version_info >= (3, 10): + + def _required_init_args(tp: DataTypeClass) -> bool: + # note: this check is ~20% faster than the check for a + # custom "__init__", below, but is not available on py39 + return bool(tp.__annotations__) +else: + + def _required_init_args(tp: DataTypeClass) -> bool: + # indicates override of the default __init__ + # (eg: this type requires specific args) + return "__init__" in tp.__dict__ + + +BaseSchema = OrderedDict[str, DataType] +SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType] __all__ = ["Schema"] +def _check_dtype(tp: DataType | DataTypeClass) -> DataType: + if not isinstance(tp, DataType): + # note: if nested/decimal, or has signature params, this implies required args + if tp.is_nested() or tp.is_decimal() or _required_init_args(tp): + msg = f"dtypes must be fully-specified, got: {tp!r}" + raise TypeError(msg) + tp = tp() + return tp # type: ignore[return-value] + + class Schema(BaseSchema): """ Ordered mapping of column names to their data type. @@ -54,18 +85,42 @@ class Schema(BaseSchema): def __init__( self, schema: ( - Mapping[str, DataType | PythonDataType] - | Iterable[tuple[str, DataType | PythonDataType]] + Mapping[str, SchemaInitDataType] + | Iterable[tuple[str, SchemaInitDataType]] | None ) = None, + *, + check_dtypes: bool = True, ) -> None: input = ( schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) ) - super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] - - def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: - super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] + for name, tp in input: # type: ignore[misc] + if not check_dtypes: + super().__setitem__(name, tp) # type: ignore[assignment] + elif is_polars_dtype(tp): + super().__setitem__(name, _check_dtype(tp)) + else: + self[name] = tp + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + if len(self) != len(other): + return False + for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()): + if nm1 != nm2 or not tp1.is_(tp2): + return False + return True + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __setitem__( + self, name: str, dtype: DataType | DataTypeClass | PythonDataType + ) -> None: + dtype = _check_dtype(parse_into_dtype(dtype)) + super().__setitem__(name, dtype) def names(self) -> list[str]: """Get the column names of the schema.""" @@ -81,7 +136,7 @@ def len(self) -> int: def to_python(self) -> dict[str, type]: """ - Return Schema as a dictionary of column names and their Python types. + Return a dictionary of column names and Python types. Examples -------- diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 2631f222612a..4cb11506b3f6 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -508,7 +508,7 @@ def as_expr(self) -> Expr: def _re_string(string: str | Collection[str], *, escape: bool = True) -> str: """Return escaped regex, potentially representing multiple string fragments.""" if isinstance(string, str): - rx = f"{re_escape(string)}" if escape else string + rx = re_escape(string) if escape else string else: strings: list[str] = [] for st in string: diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index dcf65ff15312..3b0e905b84fc 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -97,7 +97,7 @@ def add_business_days( You can pass a custom weekend - for example, if you only take Sunday off: >>> week_mask = (True, True, True, True, True, True, False) - >>> s.dt.add_business_days(5, week_mask) + >>> s.dt.add_business_days(5, week_mask=week_mask) shape: (2,) Series: 'start' [date] [ diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index cf70f5225f56..0c4b08982606 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -14,8 +14,8 @@ from polars._typing import ( IntoExpr, IntoExprColumn, + ListToStructWidthStrategy, NullBehavior, - ToStructStrategy, ) from polars.polars import PySeries @@ -855,7 +855,7 @@ def to_array(self, width: int) -> Series: def to_struct( self, - n_field_strategy: ToStructStrategy = "first_non_null", + n_field_strategy: ListToStructWidthStrategy = "first_non_null", fields: Callable[[int], str] | Sequence[str] | None = None, ) -> Series: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index ea37a64aa778..8e27b3470b16 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4047,8 +4047,15 @@ def to_physical(self) -> Series: - :func:`polars.datatypes.Duration` -> :func:`polars.datatypes.Int64` - :func:`polars.datatypes.Categorical` -> :func:`polars.datatypes.UInt32` - `List(inner)` -> `List(physical of inner)` + - `Array(inner)` -> `Array(physical of inner)` + - `Struct(fields)` -> `Struct(physical of fields)` - Other data types will be left unchanged. + Warning + ------- + The physical representations are an implementation detail + and not guaranteed to be stable. + Examples -------- Replicating the pandas diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index af9ce66850c4..97f3e373e98f 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -2077,3 +2077,25 @@ def concat( null ] """ + + def escape_regex(self) -> Series: + r""" + Returns string values with all regular expression meta characters escaped. + + Returns + ------- + Series + Series of data type :class:`String`. + + Examples + -------- + >>> pl.Series(["abc", "def", None, "abc(\\w+)"]).str.escape_regex() + shape: (4,) + Series: '' [str] + [ + "abc" + "def" + null + "abc\(\\w\+\)" + ] + """ diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index e8137a23be32..a04d254808d3 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -107,7 +107,7 @@ def schema(self) -> Schema: return Schema({}) schema = self._s.dtype().to_schema() - return Schema(schema) + return Schema(schema, check_dtypes=False) def unnest(self) -> DataFrame: """ diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 9d0a2df292cb..e89a8e19c0b6 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -74,3 +74,4 @@ flask-cors # Stub files pandas-stubs boto3-stubs +google-auth-stubs diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 1c645738102a..859609828d19 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -275,6 +275,10 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::set_random_seed)) .unwrap(); + // Functions - escape_regex + m.add_wrapped(wrap_pyfunction!(functions::escape_regex)) + .unwrap(); + // Exceptions - Errors m.add( "PolarsError", @@ -377,6 +381,9 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { #[cfg(feature = "polars_cloud")] m.add_wrapped(wrap_pyfunction!(cloud::prepare_cloud_plan)) .unwrap(); + #[cfg(feature = "polars_cloud")] + m.add_wrapped(wrap_pyfunction!(cloud::_execute_ir_plan_with_gpu)) + .unwrap(); // Build info m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index d340433ddf10..3ab507f31fa8 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -150,7 +150,7 @@ def test_init_dict() -> None: data={"dt": dates, "dtm": datetimes}, schema=coldefs, ) - assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime} + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} assert df.rows() == list(zip(py_dates, py_datetimes)) # Overriding dict column names/types @@ -251,7 +251,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical, + "tk": pl.Categorical(ordering="physical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } @@ -284,7 +284,6 @@ class PageView(BaseModel): models = adapter.validate_json(data_json) result = pl.DataFrame(models) - expected = pl.DataFrame( { "user_id": ["x"], diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 015b64ce6303..d8910cda4fb2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -303,10 +303,9 @@ def test_dataframe_membership_operator() -> None: def test_sort() -> None: df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) - assert_frame_equal(df.sort("a"), pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]})) - assert_frame_equal( - df.sort(["a", "b"]), pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]}) - ) + expected = pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]}) + assert_frame_equal(df.sort("a"), expected) + assert_frame_equal(df.sort(["a", "b"]), expected) def test_sort_multi_output_exprs_01() -> None: @@ -761,7 +760,7 @@ def test_to_dummies() -> None: "i": [1, 2, 3], "category": ["dog", "cat", "cat"], }, - schema={"i": pl.Int32, "category": pl.Categorical}, + schema={"i": pl.Int32, "category": pl.Categorical("lexical")}, ) expected = pl.DataFrame( { diff --git a/py-polars/tests/unit/dataframe/test_getitem.py b/py-polars/tests/unit/dataframe/test_getitem.py index 5d112ad67528..ab618a944b80 100644 --- a/py-polars/tests/unit/dataframe/test_getitem.py +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -479,3 +479,9 @@ def test_df_getitem_5343() -> None: assert df[4, 5] == 1024 assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]})) assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]})) + + +def test_no_deadlock_19358() -> None: + s = pl.Series(["text"] * 100 + [1] * 100, dtype=pl.Object) + result = s.to_frame()[[0, -1]] + assert result[""].to_list() == ["text", 1] diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index c5888abee67e..f8750f9e2a95 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -864,3 +864,15 @@ def test_nested_categorical_concat( with pytest.raises(pl.exceptions.StringCacheMismatchError): pl.concat([a, b]) + + +def test_perfect_group_by_19452() -> None: + n = 40 + df2 = pl.DataFrame( + { + "a": pl.int_range(n, eager=True).cast(pl.String).cast(pl.Categorical), + "b": pl.int_range(n, eager=True), + } + ) + + assert df2.with_columns(a=(pl.col("b")).over(pl.col("a")))["a"].is_sorted() diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 9ad384cd1bab..9bd4a49ddd1d 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -536,3 +536,21 @@ def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None: assert s.to_list() == ["a", "b", "c"] expected_s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) assert_series_equal(s, expected_s) + + +def test_enum_19269() -> None: + en = pl.Enum(["X", "Z", "Y"]) + df = pl.DataFrame( + {"test": pl.Series(["X", "Y", "Z"], dtype=en), "group": [1, 2, 2]} + ) + out = ( + df.group_by("group", maintain_order=True) + .agg(pl.col("test").mode()) + .select( + a=pl.col("test").list.max(), + b=pl.col("test").list.min(), + ) + ) + + assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]} + assert out.dtypes == [en, en] diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 8c5502d698fd..f7774fd70191 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -49,7 +49,7 @@ def test_dtype() -> None: "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), - "dtm": pl.List(pl.Datetime), + "dtm": pl.List(pl.Datetime("us")), } assert all(tp.is_nested() for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None: assert df.to_dict(as_series=False) == expected df = pl.DataFrame(schema=[("col", pl.List)]) - assert df.schema == {"col": pl.List} + assert df.schema == {"col": pl.List(pl.Null)} assert df.rows() == [] diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 149367621aa0..706ae904775b 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -3,7 +3,7 @@ import io from dataclasses import dataclass from datetime import datetime, time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import pandas as pd import pyarrow as pa @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: # struct column df = build_struct_df([{"struct_col": {"inner": 1}}]) assert df.columns == ["struct_col"] - assert df.schema == {"struct_col": pl.Struct} + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} assert df["struct_col"].struct.field("inner").to_list() == [1] # struct in struct @@ -1149,3 +1149,40 @@ def test_list_to_struct_19208() -> None: ).to_dict(as_series=False) == { "nested": [{"field_0": {"a": 1}}, {"field_0": None}, {"field_0": {"a": 3}}] } + + +def test_struct_reverse_outer_validity_19445() -> None: + assert_series_equal( + pl.Series([{"a": 1}, None]).reverse(), + pl.Series([None, {"a": 1}]), + ) + + +@pytest.mark.parametrize("maybe_swap", [lambda a, b: (a, b), lambda a, b: (b, a)]) +def test_struct_eq_missing_outer_validity_19156( + maybe_swap: Callable[[pl.Series, pl.Series], tuple[pl.Series, pl.Series]], +) -> None: + # Ensure that lit({'x': NULL}).eq_missing(lit(NULL)) => False + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None, {"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, False])) + assert_series_equal(l.ne_missing(r), pl.Series([True, True])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, True])) + assert_series_equal(l.ne_missing(r), pl.Series([True, False])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([{"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([True, False])) + assert_series_equal(l.ne_missing(r), pl.Series([False, True])) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a925c6f18781..042a0fca786b 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -632,16 +632,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075", ] - ticker = [ - "GOOG", - "MSFT", - "MSFT", - "MSFT", - "GOOG", - "AAPL", - "GOOG", - "MSFT", - ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] quotes = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -656,13 +647,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048", ] - ticker = [ - "MSFT", - "MSFT", - "GOOG", - "GOOG", - "AAPL", - ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] trades = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -678,11 +663,11 @@ def test_asof_join() -> None: out = trades.join_asof(quotes, on="dates", strategy="backward") assert out.schema == { - "bid": pl.Float64, - "bid_right": pl.Float64, "dates": pl.Datetime("ms"), "ticker": pl.String, + "bid": pl.Float64, "ticker_right": pl.String, + "bid_right": pl.Float64, } assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] assert (out["dates"].cast(int)).to_list() == [ diff --git a/py-polars/tests/unit/functions/as_datatype/test_duration.py b/py-polars/tests/unit/functions/as_datatype/test_duration.py index 8bef49e1f074..a3fc5dc658ef 100644 --- a/py-polars/tests/unit/functions/as_datatype/test_duration.py +++ b/py-polars/tests/unit/functions/as_datatype/test_duration.py @@ -181,3 +181,12 @@ def test_duration_time_unit_ms() -> None: result = pl.duration(milliseconds=4) expected = pl.duration(milliseconds=4, time_unit="us") assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in pl.duration + # https://github.com/pola-rs/polars/issues/19007 + df = df = pl.DataFrame({"a": [1], "b": [2]}) + assert df.select(pl.duration(hours=pl.all()).name.keep()).to_dict( + as_series=False + ) == {"a": [timedelta(seconds=3600)], "b": [timedelta(seconds=7200)]} diff --git a/py-polars/tests/unit/functions/range/test_date_range.py b/py-polars/tests/unit/functions/range/test_date_range.py index a881d30c1e41..a88287bedf41 100644 --- a/py-polars/tests/unit/functions/range/test_date_range.py +++ b/py-polars/tests/unit/functions/range/test_date_range.py @@ -7,7 +7,7 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, PanicException +from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -21,7 +21,7 @@ def test_date_range() -> None: def test_date_range_invalid_time_unit() -> None: - with pytest.raises(PanicException, match="'x' not supported"): + with pytest.raises(InvalidOperationError, match="'x' not supported"): pl.date_range( start=date(2021, 12, 16), end=date(2021, 12, 18), @@ -312,6 +312,7 @@ def test_date_ranges_datetime_input() -> None: assert_series_equal(result, expected) +@pytest.mark.may_fail_auto_streaming def test_date_range_with_subclass_18470_18447() -> None: class MyAmazingDate(date): pass diff --git a/py-polars/tests/unit/functions/range/test_datetime_range.py b/py-polars/tests/unit/functions/range/test_datetime_range.py index 8dbc9da15c29..ce99cd27b802 100644 --- a/py-polars/tests/unit/functions/range/test_datetime_range.py +++ b/py-polars/tests/unit/functions/range/test_datetime_range.py @@ -9,7 +9,7 @@ import polars as pl from polars.datatypes import DTYPE_TEMPORAL_UNITS -from polars.exceptions import ComputeError, PanicException, SchemaError +from polars.exceptions import ComputeError, InvalidOperationError, SchemaError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: @@ -96,7 +96,7 @@ def test_datetime_range_precision( def test_datetime_range_invalid_time_unit() -> None: - with pytest.raises(PanicException, match="'x' not supported"): + with pytest.raises(InvalidOperationError, match="'x' not supported"): pl.datetime_range( start=datetime(2021, 12, 16), end=datetime(2021, 12, 16, 3), diff --git a/py-polars/tests/unit/functions/test_functions.py b/py-polars/tests/unit/functions/test_functions.py index de7e49574393..05bd11976fd8 100644 --- a/py-polars/tests/unit/functions/test_functions.py +++ b/py-polars/tests/unit/functions/test_functions.py @@ -538,3 +538,22 @@ def test_head_tail(fruits_cars: pl.DataFrame) -> None: res_expr = fruits_cars.select(pl.tail("A", 2)) expected = pl.Series("A", [4, 5]) assert_series_equal(res_expr.to_series(), expected) + + +def test_escape_regex() -> None: + result = pl.escape_regex("abc(\\w+)") + expected = "abc\\(\\\\w\\+\\)" + assert result == expected + + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + with pytest.raises( + TypeError, + match="escape_regex function is unsupported for `Expr`, you may want use `Expr.str.escape_regex` instead", + ): + df.with_columns(escaped=pl.escape_regex(pl.col("text"))) # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="escape_regex function supports only `str` type, got ``", + ): + pl.escape_regex(3) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index c22d8abafd21..50ef4f2ac0ad 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -60,7 +60,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical, + "strings-cat": pl.Categorical(ordering="physical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), @@ -190,6 +190,18 @@ def test_from_pandas_include_indexes() -> None: assert df.to_dict(as_series=False) == data +def test_from_pandas_series_include_indexes() -> None: + # no default index + pd_series = pd.Series({"a": 1, "b": 2}, name="number").rename_axis(["letter"]) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"letter": ["a", "b"], "number": [1, 2]} + + # default index + pd_series = pd.Series(range(2)) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"index": [0, 1], "0": [0, 1]} + + def test_duplicate_cols_diff_types() -> None: df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=["0", 0, "1", 1]) with pytest.raises( diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index eff356391b4b..b69a10671ca7 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None: assert df.shape == (2, 2) assert df["a"][0] == {"b": 1, "c": 2} assert df["a"][1] == {"b": 3, "c": 4} - assert df.schema == {"a": pl.Struct, "d": pl.Int64} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} def test_from_dicts() -> None: @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical, + "c": pl.Categorical(ordering="physical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py index f943ab5e2c26..e9b56b6b9f15 100644 --- a/py-polars/tests/unit/io/cloud/test_cloud.py +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -1,3 +1,7 @@ +import io +import sys +from typing import Any + import pytest import polars as pl @@ -23,3 +27,89 @@ def test_scan_nonexistent_cloud_path_17444(format: str) -> None: # Upon collection, it should fail with pytest.raises(ComputeError): result.collect() + + +@pytest.mark.parametrize( + "io_func", + [ + *[pl.scan_parquet, pl.read_parquet], + pl.scan_csv, + *[pl.scan_ndjson, pl.read_ndjson], + pl.scan_ipc, + ], +) +def test_scan_credential_provider( + io_func: Any, monkeypatch: pytest.MonkeyPatch +) -> None: + err_magic = "err_magic_3" + + def raises(*_: None, **__: None) -> None: + raise AssertionError(err_magic) + + monkeypatch.setattr(pl.CredentialProviderAWS, "__init__", raises) + + with pytest.raises(AssertionError, match=err_magic): + io_func("s3://bucket/path", credential_provider="auto") + + # We can't test these with the `read_` functions as they end up executing + # the query + if io_func.__name__.startswith("scan_"): + # Passing `None` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func("s3://bucket/path", credential_provider=None) + # Passing `storage_options` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func("s3://bucket/path", credential_provider="auto", storage_options={}) + + err_magic = "err_magic_7" + + def raises_2() -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + # Note to reader: It is converted to a ComputeError as it is being called + # from Rust. + with pytest.raises(ComputeError, match=err_magic): + io_func("s3://bucket/path", credential_provider=raises_2).collect() + + +def test_scan_credential_provider_serialization() -> None: + err_magic = "err_magic_3" + + class ErrCredentialProvider(pl.CredentialProvider): + def __call__(self) -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=ErrCredentialProvider() + ) + + serialized = lf.serialize() + + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + + with pytest.raises(ComputeError, match=err_magic): + lf.collect() + + +def test_scan_credential_provider_serialization_pyversion() -> None: + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=pl.CredentialProviderAWS() + ) + + serialized = lf.serialize() + serialized = bytearray(serialized) + + # We can't monkeypatch sys.python_version so we just mutate the output + # instead. + + v = b"PLPYFN" + i = serialized.index(v) + len(v) + a, b = serialized[i:][:2] + serialized_pyver = (a, b) + assert serialized_pyver == (sys.version_info.minor, sys.version_info.micro) + # Note: These are loaded as u8's + serialized[i] = 255 + serialized[i + 1] = 254 + + with pytest.raises(ComputeError, match=r"python version.*(3, 255, 254).*differs.*"): + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index deb44a5a79f4..69e7853172a1 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -12,7 +12,7 @@ import pyarrow as pa import pytest import sqlalchemy -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast @@ -292,17 +292,18 @@ def test_read_database( tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": + connect_using = cast("DbReadEngine", connect_using) # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) df_empty = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data WHERE name LIKE '%polars%'", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: @@ -382,6 +383,39 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: assert_frame_equal(batches[0], expected) +def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + # establish sqlalchemy "textclause" and validate usage + textclause_query = text(""" + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < 0 + """) + + expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(textclause_query, connection=conn), + expected, + ) + + batches = list( + pl.read_database( + textclause_query, + connection=conn, + iter_batches=True, + batch_size=1, + ) + ) + assert len(batches) == 1 + assert_frame_equal(batches[0], expected) + + def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 1a635e68ceab..22df34dad668 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1833,9 +1833,9 @@ class TemporalFormats(TypedDict): ) assert df.write_csv(quote_style="necessary", **temporal_formats) == ( "float,string,int,bool,date,datetime,time,decimal\n" - '1.0,a,1,true,2077-07-05,,03:01:00,"1.0"\n' - '2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,"2.0"\n' - ',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,""\n' + "1.0,a,1,true,2077-07-05,,03:01:00,1.0\n" + '2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,2.0\n' + ',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,\n' ) assert df.write_csv(quote_style="never", **temporal_formats) == ( "float,string,int,bool,date,datetime,time,decimal\n" @@ -1847,9 +1847,9 @@ class TemporalFormats(TypedDict): quote_style="non_numeric", quote_char="8", **temporal_formats ) == ( "8float8,8string8,8int8,8bool8,8date8,8datetime8,8time8,8decimal8\n" - "1.0,8a8,1,8true8,82077-07-058,,803:01:008,81.08\n" - "2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,82.08\n" - ',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,88\n' + "1.0,8a8,1,8true8,82077-07-058,,803:01:008,1.0\n" + "2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,2.0\n" + ',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,\n' ) @@ -2299,3 +2299,11 @@ def test_read_csv_cast_unparsable_later( df.write_csv(f) f.seek(0) assert df.equals(pl.read_csv(f, schema={"x": dtype})) + + +def test_csv_double_new_line() -> None: + assert pl.read_csv(b"a,b,c\n\n", has_header=False).to_dict(as_series=False) == { + "column_1": ["a", None], + "column_2": ["b", None], + "column_3": ["c", None], + } diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index 33c6b052ffdf..213988964de3 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -516,3 +516,11 @@ def test_read_parquet_respects_rechunk_16982( rechunk, expected_chunks = rechunk_and_expected_chunks result = pl.read_delta(str(tmp_path), rechunk=rechunk) assert result.n_chunks() == expected_chunks + + +def test_scan_delta_DT_input(delta_table_path: Path) -> None: + DT = DeltaTable(str(delta_table_path), version=0) + ldf = pl.scan_delta(DT) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index a01a2ef6e59d..9e9213ac9bd4 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -779,3 +779,45 @@ def test_hive_predicate_dates_14712( ) pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect() assert "hive partitioning: skipped 1 files" in capfd.readouterr().err + + +@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths") +@pytest.mark.write_disk +def test_hive_windows_splits_on_forward_slashes(tmp_path: Path) -> None: + # Note: This needs to be an absolute path. + tmp_path = tmp_path.resolve() + path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1" + Path(path).mkdir(exist_ok=True, parents=True) + + df = pl.DataFrame({"x": "x"}) + df.write_parquet(f"{path}/data.parquet") + + expect = pl.DataFrame( + [ + s.new_from_index(0, 5) + for s in pl.DataFrame( + { + "x": "x", + "a": 1, + "b": 1, + "c": 1, + "d": 1, + "e": 1, + } + ) + ] + ) + + assert_frame_equal( + pl.scan_parquet( + [ + f"{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet", + f"{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet", + f"{tmp_path}\\a=1/b=1/c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1\\c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*", + ], + hive_partitioning=True, + ).collect(), + expect, + ) diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py index 5a5f6769e3a5..59ba92559459 100644 --- a/py-polars/tests/unit/io/test_iceberg.py +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -44,6 +44,19 @@ def test_scan_iceberg_plain(self, iceberg_path: str) -> None: "ts": pl.Datetime(time_unit="us", time_zone=None), } + def test_scan_iceberg_snapshot_id(self, iceberg_path: str) -> None: + df = pl.scan_iceberg(iceberg_path, snapshot_id=7051579356916758811) + assert len(df.collect()) == 3 + assert df.collect_schema() == { + "id": pl.Int32, + "str": pl.String, + "ts": pl.Datetime(time_unit="us", time_zone=None), + } + + def test_scan_iceberg_snapshot_id_not_found(self, iceberg_path: str) -> None: + with pytest.raises(ValueError, match="Snapshot ID not found"): + pl.scan_iceberg(iceberg_path, snapshot_id=1234567890) + def test_scan_iceberg_filter_on_partition(self, iceberg_path: str) -> None: ts1 = datetime(2023, 3, 1, 18, 15) ts2 = datetime(2023, 3, 1, 19, 25) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index a4afd57e73e4..3b003ace6b1e 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -18,6 +18,7 @@ import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal @@ -68,9 +69,10 @@ def test_write_json_decimal() -> None: def test_json_infer_schema_length_11148() -> None: response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 - result = pl.read_json(json.dumps(response).encode(), infer_schema_length=2) - with pytest.raises(AssertionError): - assert set(result.columns) == {"col1", "col2"} + with pytest.raises( + pl.exceptions.ComputeError, match="extra key in struct data: col2" + ): + pl.read_json(json.dumps(response).encode(), infer_schema_length=2) response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 result = pl.read_json(json.dumps(response).encode(), infer_schema_length=3) @@ -433,7 +435,12 @@ def test_empty_list_json() -> None: def test_json_infer_3_dtypes() -> None: # would SO before df = pl.DataFrame({"a": ["{}", "1", "[1, 2]"]}) - out = df.select(pl.col("a").str.json_decode()) + + with pytest.raises(pl.exceptions.ComputeError): + df.select(pl.col("a").str.json_decode()) + + df = pl.DataFrame({"a": [None, "1", "[1, 2]"]}) + out = df.select(pl.col("a").str.json_decode(dtype=pl.List(pl.String))) assert out["a"].to_list() == [None, ["1"], ["1", "2"]] assert out.dtypes[0] == pl.List(pl.String) @@ -448,3 +455,57 @@ def test_zfs_json_roundtrip(size: int) -> None: f.seek(0) assert_frame_equal(a, pl.read_json(f)) + + +def test_read_json_raise_on_data_type_mismatch() -> None: + with pytest.raises(ComputeError): + pl.read_json( + b"""\ +[ + {"a": null}, + {"a": 1} +] +""", + infer_schema_length=1, + ) + + +def test_read_json_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra key in struct data: b"): + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=1, + ) + + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=2, + ), + pl.DataFrame({"a": [1, 2], "b": [None, 2]}), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + schema={"a": pl.Int64}, + ), + pl.DataFrame({"a": [1, 2]}), + ) diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 792ddc42b02a..49a842386ea2 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -235,12 +235,12 @@ def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) captured = capfd.readouterr().err assert ( - "parquet file must be read, statistics not sufficient for predicate." + "parquet row group must be read, statistics not sufficient for predicate." in captured ) assert ( - "parquet file can be skipped, the statistics were sufficient" - " to apply the predicate." in captured + "parquet row group can be skipped, the statistics were sufficient to apply the predicate." + in captured ) @@ -710,10 +710,18 @@ def test_parquet_schema_arg( schema: dict[str, type[pl.DataType]] = {"a": pl.Int64} # type: ignore[no-redef] - lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + for allow_missing_columns in [True, False]: + lf = pl.scan_parquet( + paths, + parallel=parallel, + schema=schema, + allow_missing_columns=allow_missing_columns, + ) - with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"): - lf.collect(streaming=streaming) + with pytest.raises( + pl.exceptions.SchemaError, match="file contained extra columns" + ): + lf.collect(streaming=streaming) lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a") @@ -731,3 +739,29 @@ def test_parquet_schema_arg( match="data type mismatch for column b: expected: i8, found: i64", ): lf.collect(streaming=streaming) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("allow_missing_columns", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249( + tmp_path: Path, + allow_missing_columns: bool, + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet( + paths[0] + ) + pl.DataFrame( + {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64} + ).write_parquet(paths[1]) + + assert_frame_equal( + pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns) + .select("a") + .collect(streaming=streaming), + pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}), + ) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 850bf61d978b..bf9d8ac4fad8 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1,9 +1,10 @@ from __future__ import annotations +import decimal import io from datetime import datetime, time, timezone from decimal import Decimal -from typing import IO, TYPE_CHECKING, Any, Literal, cast +from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast import fsspec import numpy as np @@ -1092,7 +1093,7 @@ def test_hybrid_rle() -> None: pl.Boolean, ], min_size=1, - max_size=5000, + max_size=500, ) ) @pytest.mark.slow @@ -1990,3 +1991,125 @@ def test_nested_nonnullable_19158() -> None: f.seek(0) assert_frame_equal(pl.read_parquet(f), pl.DataFrame(tbl)) + + +D = Decimal + + +@pytest.mark.parametrize("precision", range(1, 37, 2)) +@pytest.mark.parametrize( + "nesting", + [ + # Struct + lambda t: ([{"x": None}, None], pl.Struct({"x": t})), + lambda t: ([None, {"x": None}], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, None], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, {"x": D("4.8")}], pl.Struct({"x": t})), + # Array + lambda t: ([[None, None, D("8.2")], None], pl.Array(t, 3)), + lambda t: ([None, [None, D("8.9"), None]], pl.Array(t, 3)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], None], pl.Array(t, 3)), + lambda t: ( + [[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("5.2"), D("8.9")]], + pl.Array(t, 3), + ), + # List + lambda t: ([[None, D("8.2")], None], pl.List(t)), + lambda t: ([None, [D("8.9"), None]], pl.List(t)), + lambda t: ([[D("1.5"), D("4.1")], None], pl.List(t)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("8.9")]], pl.List(t)), + ], +) +def test_decimal_precision_nested_roundtrip( + nesting: Callable[[pl.DataType], tuple[list[Any], pl.DataType]], + precision: int, +) -> None: + # Limit the context as to not disturb any other tests + with decimal.localcontext() as ctx: + ctx.prec = precision + + decimal_dtype = pl.Decimal(precision=precision) + values, dtype = nesting(decimal_dtype) + + df = pl.Series("a", values, dtype).to_frame() + + test_round_trip(df) + + +@pytest.mark.parametrize("parallel", ["prefiltered", "columns", "row_groups", "auto"]) +def test_conserve_sortedness( + monkeypatch: Any, capfd: Any, parallel: pl.ParallelStrategy +) -> None: + f = io.BytesIO() + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, None], + "b": [1.0, 2.0, 3.0, 4.0, 5.0, None], + "c": [None, 5, 4, 3, 2, 1], + "d": [None, 5.0, 4.0, 3.0, 2.0, 1.0], + "a_nosort": [1, 2, 3, 4, 5, None], + "f": range(6), + } + ) + + pq.write_table( + df.to_arrow(), + f, + sorting_columns=[ + pq.SortingColumn(0, False, False), + pq.SortingColumn(1, False, False), + pq.SortingColumn(2, True, True), + pq.SortingColumn(3, True, True), + ], + ) + + f.seek(0) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.scan_parquet(f, parallel=parallel).filter(pl.col.f > 1).collect() + + captured = capfd.readouterr().err + + # @NOTE: We don't conserve sortedness for anything except integers at the + # moment. + assert captured.count("Parquet conserved SortingColumn for column chunk of") == 2 + assert ( + "Parquet conserved SortingColumn for column chunk of 'a' to Ascending" + in captured + ) + assert ( + "Parquet conserved SortingColumn for column chunk of 'c' to Descending" + in captured + ) + + +def test_f16() -> None: + values = [float("nan"), 0.0, 0.5, 1.0, 1.5] + + table = pa.Table.from_pydict( + { + "x": pa.array(np.array(values, dtype=np.float16), type=pa.float16()), + } + ) + + df = pl.Series("x", values, pl.Float32).to_frame() + + f = io.BytesIO() + pq.write_table(table, f) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col.x > 0.5).collect(), + df.filter(pl.col.x > 0.5), + ) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(1, 3).collect(), + df.slice(1, 3), + ) diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 799c4953cbf6..30af7b830ff8 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -801,3 +801,37 @@ def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None: schema={"index": pl.UInt32, "a": pl.Int64}, ), ) + + +def test_scan_include_file_paths_respects_projection_pushdown() -> None: + q = pl.scan_csv(b"a,b,c\na1,b1,c1", include_file_paths="path_name").select( + ["a", "b"] + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": "a1", "b": "b1"})) + + +def test_streaming_scan_csv_include_file_paths_18257(io_files_path: Path) -> None: + lf = pl.scan_csv( + io_files_path / "foods1.csv", + include_file_paths="path", + ).select("category", "path") + + assert lf.collect(streaming=True).columns == ["category", "path"] + + +def test_streaming_scan_csv_with_row_index_19172(io_files_path: Path) -> None: + lf = ( + pl.scan_csv(io_files_path / "foods1.csv", infer_schema=False) + .with_row_index() + .select("calories", "index") + .head(1) + ) + + assert_frame_equal( + lf.collect(streaming=True), + pl.DataFrame( + {"calories": "45", "index": 0}, + schema={"calories": pl.String, "index": pl.UInt32}, + ), + ) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a764f3c70755..b7b03a0bd02e 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: xls = BytesIO() df.write_excel(xls, position="C5") - schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean} + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()} df_compare = df.with_columns( pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() ) @@ -322,13 +322,12 @@ def test_read_mixed_dtype_columns( ) -> None: spreadsheet_path = request.getfixturevalue(source) schema_overrides = { - "Employee ID": pl.Utf8, - "Employee Name": pl.Utf8, - "Date": pl.Date, - "Details": pl.Categorical, - "Asset ID": pl.Utf8, + "Employee ID": pl.Utf8(), + "Employee Name": pl.Utf8(), + "Date": pl.Date(), + "Details": pl.Categorical("lexical"), + "Asset ID": pl.Utf8(), } - df = read_spreadsheet( spreadsheet_path, sheet_id=0, diff --git a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py index 64ad0e533d8d..0e7af02d8290 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py @@ -101,6 +101,7 @@ def func( BROADCAST_SERIES_COMBINATIONS, ) @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_arithmetic_values( list_side: str, broadcast_series: Callable[ @@ -380,6 +381,7 @@ def test_list_add_supertype( "broadcast_series", BROADCAST_SERIES_COMBINATIONS, ) +@pytest.mark.slow def test_list_numeric_op_validity_combination( broadcast_series: Callable[ [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] @@ -451,6 +453,7 @@ def test_list_add_alignment() -> None: @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_add_empty_lists( exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], ) -> None: @@ -516,6 +519,7 @@ def test_list_add_height_mismatch( ], ) @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_date_to_numeric_arithmetic_raises_error( op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series] ) -> None: diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 966fee3ea5ac..1264a5ed8773 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -7,7 +7,11 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError +from polars.exceptions import ( + ComputeError, + OutOfBoundsError, + SchemaError, +) from polars.testing import assert_frame_equal, assert_series_equal @@ -246,6 +250,16 @@ def test_list_contains_invalid_datatype() -> None: df.select(pl.col("a").list.contains(2)) +def test_list_contains_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.contains + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.contains(3)).to_dict(as_series=False) == { + "a": [False], + "b": [True], + } + + def test_list_concat() -> None: df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]}) @@ -643,6 +657,26 @@ def test_list_to_struct() -> None: {"n": {"one": 0, "two": 1, "three": None}}, ] + q = df.lazy().select( + pl.col("n").list.to_struct(fields=["a", "b"]).struct.field("a") + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": [0, 0]})) + + # Check that: + # * Specifying an upper bound calls the field name getter function to + # retrieve the lazy schema + # * The upper bound is respected during execution + q = df.lazy().select( + pl.col("n").list.to_struct(fields=str, upper_bound=2).struct.unnest() + ) + assert q.collect_schema() == {"0": pl.Int64, "1": pl.Int64} + assert_frame_equal(q.collect(), pl.DataFrame({"0": [0, 0], "1": [1, 1]})) + + assert df.lazy().select(pl.col("n").list.to_struct()).collect_schema() == { + "n": pl.Unknown + } + def test_select_from_list_to_struct_11143() -> None: ldf = pl.LazyFrame({"some_col": [[1.0, 2.0], [1.5, 3.0]]}) @@ -686,6 +720,16 @@ def test_list_count_matches_boolean_nulls_9141() -> None: assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] +def test_list_count_matches_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.count_match + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.count_matches(3)).to_dict(as_series=False) == { + "a": [0], + "b": [1], + } + + def test_list_gather_oob_10079() -> None: df = pl.DataFrame( { @@ -885,3 +929,29 @@ def test_list_get_with_null() -> None: def test_list_sum_bool_schema() -> None: q = pl.LazyFrame({"x": [[True, True, False]]}) assert q.select(pl.col("x").list.sum()).collect_schema()["x"] == pl.UInt32 + + +def test_list_concat_struct_19279() -> None: + df = pl.select( + pl.struct(s=pl.lit("abcd").str.split("").explode(), i=pl.int_range(0, 4)) + ) + df = pl.concat([df[:2], df[-2:]]) + assert df.select(pl.concat_list("s")).to_dict(as_series=False) == { + "s": [ + [{"s": "a", "i": 0}], + [{"s": "b", "i": 1}], + [{"s": "c", "i": 2}], + [{"s": "d", "i": 3}], + ] + } + + +def test_list_eval_element_schema_19345() -> None: + assert_frame_equal( + ( + pl.LazyFrame({"a": [[{"a": 1}]]}) + .select(pl.col("a").list.eval(pl.element().struct.field("a"))) + .collect() + ), + pl.DataFrame({"a": [[1]]}), + ) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index 842b0fd141a5..fcca5c5987b1 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -1006,6 +1006,66 @@ def test_replace_all() -> None: ) +def test_replace_all_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + df2 = pl.DataFrame( + { + "text": ["I found yesterday.", "I lost yesterday."], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert df2.get_column("text2")[1] == "I lost $2 yesterday." + + +def test_replace_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + # A string shorter than 32 chars, + # and one longer than 32 chars to test both sub-paths + df2 = pl.DataFrame( + { + "text": [ + "I found yesterday.", + "I lost yesterday and this string is longer than 32 characters.", + ], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert ( + df2.get_column("text2")[1] + == "I lost $2 yesterday and this string is longer than 32 characters." + ) + + def test_replace_expressions() -> None: df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"], "value": ["A", "B"]}) out = df.select([pl.col("foo").str.replace(pl.col("foo").first(), pl.col("value"))]) @@ -1727,3 +1787,55 @@ def test_extract_many() -> None: assert df.select(pl.col("values").str.extract_many("patterns")).to_dict( as_series=False ) == {"values": [["disco"], ["rhap", "ody"]]} + + +def test_json_decode_raise_on_data_type_mismatch_13061() -> None: + assert_series_equal( + pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1), + pl.Series([None, None]), + ) + + with pytest.raises(ComputeError): + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=1) + + assert_series_equal( + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=2), + pl.Series([None, 1]), + ) + + +def test_json_decode_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra key in struct data: b"): + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=1 + ) + + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=2 + ), + pl.Series([{"a": 1, "b": None}, {"a": 2, "b": 2}]), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + dtype=pl.Struct({"a": pl.Int64}) + ), + pl.Series([{"a": 1}, {"a": 2}]), + ) + + +def test_escape_regex() -> None: + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + result_df = df.with_columns(pl.col("text").str.escape_regex().alias("escaped")) + expected_df = pl.DataFrame( + { + "text": ["abc", "def", None, "abc(\\w+)"], + "escaped": ["abc", "def", None, "abc\\(\\\\w\\+\\)"], + } + ) + + assert_frame_equal(result_df, expected_df) + assert_series_equal(result_df["escaped"], expected_df["escaped"]) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index e01912237a19..dca9eeb3e767 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -461,12 +461,6 @@ def test_cast_temporal( "expected_value", ), [ - (str(2**7 - 1).encode(), pl.Binary, pl.Int8, 2**7 - 1), - (str(2**15 - 1).encode(), pl.Binary, pl.Int16, 2**15 - 1), - (str(2**31 - 1).encode(), pl.Binary, pl.Int32, 2**31 - 1), - (str(2**63 - 1).encode(), pl.Binary, pl.Int64, 2**63 - 1), - (b"1.0", pl.Binary, pl.Float32, 1.0), - (b"1.0", pl.Binary, pl.Float64, 1.0), (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1), (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1), (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1), @@ -478,13 +472,9 @@ def test_cast_temporal( (str(2**15), pl.String, pl.Int16, None), (str(2**31), pl.String, pl.Int32, None), (str(2**63), pl.String, pl.Int64, None), - (str(2**7).encode(), pl.Binary, pl.Int8, None), - (str(2**15).encode(), pl.Binary, pl.Int16, None), - (str(2**31).encode(), pl.Binary, pl.Int32, None), - (str(2**63).encode(), pl.Binary, pl.Int64, None), ], ) -def test_cast_string_and_binary( +def test_cast_string( value: int, from_dtype: PolarsDataType, to_dtype: PolarsDataType, @@ -522,12 +512,6 @@ def test_cast_string_and_binary( "expected_value", ), [ - (str(2**7 - 1).encode(), pl.Binary, pl.Int8, True, 2**7 - 1), - (str(2**15 - 1).encode(), pl.Binary, pl.Int16, True, 2**15 - 1), - (str(2**31 - 1).encode(), pl.Binary, pl.Int32, True, 2**31 - 1), - (str(2**63 - 1).encode(), pl.Binary, pl.Int64, True, 2**63 - 1), - (b"1.0", pl.Binary, pl.Float32, True, 1.0), - (b"1.0", pl.Binary, pl.Float64, True, 1.0), (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1), (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1), (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1), @@ -539,13 +523,9 @@ def test_cast_string_and_binary( (str(2**15), pl.String, pl.Int16, False, None), (str(2**31), pl.String, pl.Int32, False, None), (str(2**63), pl.String, pl.Int64, False, None), - (str(2**7).encode(), pl.Binary, pl.Int8, False, None), - (str(2**15).encode(), pl.Binary, pl.Int16, False, None), - (str(2**31).encode(), pl.Binary, pl.Int32, False, None), - (str(2**63).encode(), pl.Binary, pl.Int64, False, None), ], ) -def test_strict_cast_string_and_binary( +def test_strict_cast_string( value: int, from_dtype: PolarsDataType, to_dtype: PolarsDataType, @@ -692,3 +672,9 @@ def test_cast_consistency() -> None: assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String) ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]} + + +def test_cast_int_to_string_unsets_sorted_flag_19424() -> None: + s = pl.Series([1, 2]).set_sorted() + assert s.flags["SORTED_ASC"] + assert not s.cast(pl.String).flags["SORTED_ASC"] diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index d56e5d80858d..4c74f60a7ff9 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import polars as pl @@ -156,3 +157,47 @@ def test_gather_str_col_18099() -> None: "foo": [1, 1, 2], "idx": [0, 0, 1], } + + +def test_gather_list_19243() -> None: + df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]}) + assert df.with_columns(pl.lit([0]).alias("c")).with_columns( + gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True) + ).to_dict(as_series=False) == { + "a": [[0.1, 0.2, 0.3]], + "c": [[0]], + "gather": [[0.1]], + } + + +def test_gather_array_list_null_19302() -> None: + data = pl.DataFrame( + {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))} + ) + assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == { + "data": [None] + } + + +def test_gather_array() -> None: + a = np.arange(16).reshape(-1, 2, 2) + s = pl.Series(a) + + for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]: + assert (s.gather(idx).to_numpy() == a[idx]).all() + + v = s[[0, 1, None, 3]] # type: ignore[list-item] + assert v[2] is None + + +def test_gather_array_outer_validity_19482() -> None: + s = ( + pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1)) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first())) + .to_series() + ) + + expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1)) + assert_series_equal(s, expect) + assert_series_equal(s.gather([0, 1]), expect) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 5ed57b374149..cff43b43274c 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical, "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 68220cf81551..87ad4dfc5a9b 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -1082,3 +1082,20 @@ def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None: msg = "cross join should not pass join keys" with pytest.raises(ValueError, match=msg): df1.join(df2, how="cross", **on_args) # type: ignore[arg-type] + + +@pytest.mark.parametrize("set_sorted", [True, False]) +def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None: + left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]}) + right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]}) + + if set_sorted: + # The data isn't actually sorted on purpose to ensure we default to a + # hash join unless we set the sorted flag here, in case there is new + # code in the future that automatically identifies sortedness during + # Series construction from Python. + left = left.set_sorted("k") + right = right.set_sorted("k") + + q = left.join(right, on="k", how="left").head(5) + assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]})) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 57dbec1a13ee..b19d008556e1 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -413,6 +413,7 @@ def test_sort_by_in_over_5499() -> None: } +@pytest.mark.may_fail_auto_streaming def test_merge_sorted() -> None: df_a = ( pl.datetime_range( diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py index ada642c294ae..434c2fdc3af9 100644 --- a/py-polars/tests/unit/operations/test_unpivot.py +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -2,6 +2,7 @@ import polars as pl import polars.selectors as cs +from polars import StringCache from polars.testing import assert_frame_equal @@ -94,3 +95,21 @@ def test_unpivot_empty_18170() -> None: assert pl.DataFrame().unpivot().schema == pl.Schema( {"variable": pl.String(), "value": pl.Null()} ) + + +@StringCache() +def test_unpivot_categorical_global() -> None: + df = pl.DataFrame( + { + "index": [0, 1], + "1": pl.Series(["a", "b"], dtype=pl.Categorical), + "2": pl.Series(["b", "c"], dtype=pl.Categorical), + } + ) + out = df.unpivot(["1", "2"], index="index") + assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="physical")] + assert out.to_dict(as_series=False) == { + "index": [0, 1, 0, 1], + "variable": ["1", "1", "2", "2"], + "value": ["a", "b", "b", "c"], + } diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py index 479a52ca2f9a..b50edd981e5f 100644 --- a/py-polars/tests/unit/operations/unique/test_unique.py +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -154,3 +154,11 @@ def test_unique_with_null() -> None: {"a": [1, 2, 3, 4], "b": ["a", "b", "c", "c"], "c": [None, None, None, None]} ) assert_frame_equal(df.unique(maintain_order=True), expected_df) + + +def test_categorical_unique_19409() -> None: + df = pl.DataFrame({"x": [str(n % 50) for n in range(127)]}).cast(pl.Categorical) + uniq = df.unique() + assert uniq.height == 50 + assert uniq.null_count().item() == 0 + assert set(uniq["x"]) == set(df["x"]) diff --git a/py-polars/tests/unit/series/test_scatter.py b/py-polars/tests/unit/series/test_scatter.py index c3c0b38d6805..95e4aa5b31e3 100644 --- a/py-polars/tests/unit/series/test_scatter.py +++ b/py-polars/tests/unit/series/test_scatter.py @@ -43,7 +43,7 @@ def test_scatter() -> None: assert s.to_list() == ["a", "x", "x"] assert s.scatter([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"] - # set multiple values values + # set multiple values s = pl.Series(["z", "z", "z"]) assert s.scatter([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"] s = pl.Series([True, False, True]) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index f1a858f62add..0ed478b9aa83 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -23,6 +23,7 @@ Unknown, ) from polars.exceptions import ( + DuplicateError, InvalidOperationError, PolarsInefficientMapWarning, ShapeError, @@ -1356,6 +1357,13 @@ def test_to_dummies_drop_first() -> None: assert_frame_equal(result, expected) +def test_to_dummies_null_clash_19096() -> None: + with pytest.raises( + DuplicateError, match="column with name '_null' has more than one occurrence" + ): + pl.Series([None, "null"]).to_dummies() + + def test_chunk_lengths() -> None: s = pl.Series("a", [1, 2, 2, 3]) # this is a Series with one chunk, of length 4 diff --git a/py-polars/tests/unit/sql/test_bitwise.py b/py-polars/tests/unit/sql/test_bitwise.py new file mode 100644 index 000000000000..7bba8cfde5c5 --- /dev/null +++ b/py-polars/tests/unit/sql/test_bitwise.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [20, 32, 50, 88, 128], + "y": [-128, 0, 10, -1, None], + } + ) + + +def test_bitwise_and(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x & y AS x_bitand_op_y, + BITAND(y, x) AS y_bitand_x, + BIT_AND(x, y) AS x_bitand_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitand_op_y": [0, 0, 2, 88, None], + "y_bitand_x": [0, 0, 2, 88, None], + "x_bitand_y": [0, 0, 2, 88, None], + } + + +def test_bitwise_count(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + BITCOUNT(x) AS x_bits_set, + BIT_COUNT(y) AS y_bits_set, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bits_set": [2, 1, 3, 3, 1], + "y_bits_set": [57, 0, 2, 64, None], + } + + +def test_bitwise_or(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x | y AS x_bitor_op_y, + BITOR(y, x) AS y_bitor_x, + BIT_OR(x, y) AS x_bitor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitor_op_y": [-108, 32, 58, -1, None], + "y_bitor_x": [-108, 32, 58, -1, None], + "x_bitor_y": [-108, 32, 58, -1, None], + } + + +def test_bitwise_xor(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x XOR y AS x_bitxor_op_y, + BITXOR(y, x) AS y_bitxor_x, + BIT_XOR(x, y) AS x_bitxor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitxor_op_y": [-108, 32, 56, -89, None], + "y_bitxor_x": [-108, 32, 56, -89, None], + "x_bitxor_y": [-108, 32, 56, -89, None], + } diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 08e4b236c833..71fa1572831c 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -238,3 +238,9 @@ def test_group_by_errors() -> None: match=r"'a' should participate in the GROUP BY clause or an aggregate function", ): df.sql("SELECT a, SUM(b) FROM self GROUP BY b") + + with pytest.raises( + SQLSyntaxError, + match=r"HAVING clause not valid outside of GROUP BY", + ): + df.sql("SELECT a, COUNT(a) AS n FROM self HAVING n > 1") diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 95ba8461bebe..f7d0615e13c6 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -7,7 +7,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal if TYPE_CHECKING: @@ -362,3 +362,26 @@ def test_global_variable_inference_17398() -> None: eager=True, ) assert_frame_equal(res, users) + + +@pytest.mark.parametrize( + "query", + [ + "SELECT invalid_column FROM self", + "SELECT key, invalid_column FROM self", + "SELECT invalid_column * 2 FROM self", + "SELECT * FROM self ORDER BY invalid_column", + "SELECT * FROM self WHERE invalid_column = 200", + "SELECT * FROM self WHERE invalid_column = '200'", + "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column", + ], +) +def test_invalid_cols(query: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "xx", "yy"], + "n": ["100", "200", "300"], + } + ) + with pytest.raises(ColumnNotFoundError, match="invalid_column"): + df.sql(query) diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index 80730273672e..225c0b97553c 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -19,6 +19,7 @@ pytestmark = pytest.mark.xdist_group("streaming") +@pytest.mark.may_fail_auto_streaming def test_streaming_categoricals_5921() -> None: with pl.StringCache(): out_lazy = ( @@ -74,6 +75,7 @@ def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: @pytest.mark.slow +@pytest.mark.may_fail_auto_streaming def test_cross_join_stack() -> None: a = pl.Series(np.arange(100_000)).to_frame().lazy() t0 = time.time() diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py index 6c74f0693caa..02397afcf9d4 100644 --- a/py-polars/tests/unit/test_config.py +++ b/py-polars/tests/unit/test_config.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from textwrap import dedent from typing import TYPE_CHECKING, Any import pytest @@ -497,6 +498,16 @@ def test_shape_format_for_big_numbers() -> None: "╰─────────┴───╯" ) + pl.Config.set_tbl_formatting("ASCII_FULL_CONDENSED") + assert ( + str(df) == "shape: (1, 1_000)\n" + "+---------+-----+\n" + "| 0 (i64) | ... |\n" + "+===============+\n" + "| 1 | ... |\n" + "+---------+-----+" + ) + def test_numeric_right_alignment() -> None: pl.Config.set_tbl_cell_numeric_alignment("RIGHT") @@ -771,6 +782,79 @@ def test_set_fmt_str_lengths_invalid_length() -> None: cfg.set_fmt_str_lengths(-2) +def test_truncated_rows_cols_values_ascii() -> None: + df = pl.DataFrame({f"c{n}": list(range(-n, 100 - n)) for n in range(10)}) + + pl.Config.set_tbl_formatting("UTF8_BORDERS_ONLY", rounded_corners=True) + assert ( + str(df) == "shape: (100, 10)\n" + "╭───────────────────────────────────────────────────╮\n" + "│ c0 c1 c2 c3 … c6 c7 c8 c9 │\n" + "│ --- --- --- --- --- --- --- --- │\n" + "│ i64 i64 i64 i64 i64 i64 i64 i64 │\n" + "╞═══════════════════════════════════════════════════╡\n" + "│ 0 -1 -2 -3 … -6 -7 -8 -9 │\n" + "│ 1 0 -1 -2 … -5 -6 -7 -8 │\n" + "│ 2 1 0 -1 … -4 -5 -6 -7 │\n" + "│ 3 2 1 0 … -3 -4 -5 -6 │\n" + "│ 4 3 2 1 … -2 -3 -4 -5 │\n" + "│ … … … … … … … … … │\n" + "│ 95 94 93 92 … 89 88 87 86 │\n" + "│ 96 95 94 93 … 90 89 88 87 │\n" + "│ 97 96 95 94 … 91 90 89 88 │\n" + "│ 98 97 96 95 … 92 91 90 89 │\n" + "│ 99 98 97 96 … 93 92 91 90 │\n" + "╰───────────────────────────────────────────────────╯" + ) + with pl.Config(tbl_formatting="ASCII_FULL_CONDENSED"): + assert ( + str(df) == "shape: (100, 10)\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n" + "| c0 | c1 | c2 | c3 | ... | c6 | c7 | c8 | c9 |\n" + "| --- | --- | --- | --- | | --- | --- | --- | --- |\n" + "| i64 | i64 | i64 | i64 | | i64 | i64 | i64 | i64 |\n" + "+=====================================================+\n" + "| 0 | -1 | -2 | -3 | ... | -6 | -7 | -8 | -9 |\n" + "| 1 | 0 | -1 | -2 | ... | -5 | -6 | -7 | -8 |\n" + "| 2 | 1 | 0 | -1 | ... | -4 | -5 | -6 | -7 |\n" + "| 3 | 2 | 1 | 0 | ... | -3 | -4 | -5 | -6 |\n" + "| 4 | 3 | 2 | 1 | ... | -2 | -3 | -4 | -5 |\n" + "| ... | ... | ... | ... | ... | ... | ... | ... | ... |\n" + "| 95 | 94 | 93 | 92 | ... | 89 | 88 | 87 | 86 |\n" + "| 96 | 95 | 94 | 93 | ... | 90 | 89 | 88 | 87 |\n" + "| 97 | 96 | 95 | 94 | ... | 91 | 90 | 89 | 88 |\n" + "| 98 | 97 | 96 | 95 | ... | 92 | 91 | 90 | 89 |\n" + "| 99 | 98 | 97 | 96 | ... | 93 | 92 | 91 | 90 |\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+" + ) + + with pl.Config(tbl_formatting="MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |---------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0… |""").lstrip() + ) + + with pl.Config(tbl_formatting="ASCII_MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |-----------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0... |""").lstrip() + ) + + def test_warn_unstable(recwarn: pytest.WarningsRecorder) -> None: issue_unstable_warning() assert len(recwarn) == 0 diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index d6fdb8976c5b..5ac102efe5cb 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -788,3 +788,34 @@ def test_eager_cse_during_struct_expansion_18411() -> None: df.select(pl.col("foo").replace(classes, counts)) == df.select(pl.col("foo").replace(classes, counts)) )["foo"].all() + + +def test_cse_skip_as_struct_19253() -> None: + df = pl.LazyFrame({"x": [1, 2], "y": [4, 5]}) + + assert ( + df.with_columns( + q1=pl.struct(pl.col.x - pl.col.y.mean()), + q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")), + ).collect() + ).to_dict(as_series=False) == { + "x": [1, 2], + "y": [4, 5], + "q1": [{"x": -3.5}, {"x": -2.5}], + "q2": [{"x": -3.0}, {"x": -3.0}], + } + + +def test_cse_union_19227() -> None: + lf = pl.LazyFrame({"A": [1], "B": [2]}) + lf_1 = lf.select(C="A", B="B") + lf_2 = lf.select(C="A", A="B") + + direct = lf_2.join(lf, on=["A"]).select("C", "A", "B") + + indirect = lf_1.join(direct, on=["C", "B"]).select("C", "A", "B") + + out = pl.concat([direct, indirect]) + assert out.collect().schema == pl.Schema( + [("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)] + ) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index c730ee8d30a7..0ca504d20e20 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -16,7 +16,6 @@ ComputeError, InvalidOperationError, OutOfBoundsError, - PanicException, SchemaError, SchemaFieldNotFoundError, ShapeError, @@ -116,7 +115,7 @@ def test_string_numeric_comp_err() -> None: def test_panic_error() -> None: with pytest.raises( - PanicException, + InvalidOperationError, match="unit: 'k' not supported", ): pl.datetime_range( @@ -696,7 +695,7 @@ def test_no_panic_pandas_nat() -> None: def test_list_to_struct_invalid_type() -> None: - with pytest.raises(pl.exceptions.SchemaError): + with pytest.raises(pl.exceptions.InvalidOperationError): pl.DataFrame({"a": 1}).select(pl.col("a").list.to_struct()) @@ -708,3 +707,15 @@ def test_raise_invalid_agg() -> None: .group_by("index") .agg(pl.col("foo").filter(pl.col("i_do_not_exist"))) ).collect() + + +def test_err_mean_horizontal_lists() -> None: + df = pl.DataFrame( + { + "experiment_id": [1, 2], + "sensor1": [[1, 2, 3], [7, 8, 9]], + "sensor2": [[4, 5, 6], [10, 11, 12]], + } + ) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns(pl.mean_horizontal("sensor1", "sensor2").alias("avg_sensor")) diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 48b3077423e0..e075fdb6acc2 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -582,3 +582,33 @@ def test_projections_collapse_17781() -> None: else: lf = lf.join(lfj, on="index", how="left") assert "SELECT " not in lf.explain() # type: ignore[union-attr] + + +def test_with_columns_projection_pushdown() -> None: + # # Summary + # `process_hstack` in projection PD incorrectly took a fast-path meant for + # LP nodes that don't add new columns to the schema, which stops projection + # PD if it sees that the schema lengths on the upper node matches. + # + # To trigger this, we drop the same number of columns before and after + # the with_columns, and in the with_columns we also add the same number of + # columns. + lf = ( + pl.scan_csv( + b"""\ +a,b,c,d,e +1,1,1,1,1 +""", + include_file_paths="path", + ) + .drop("a", "b") + .with_columns(pl.lit(1).alias(x) for x in ["x", "y"]) + .drop("c", "d") + ) + + plan = lf.explain().strip() + + assert plan.startswith("WITH_COLUMNS:") + # [dyn int: 1.alias("x"), dyn int: 1.alias("y")] + # Csv SCAN [20 in-mem bytes] + assert plan.endswith("PROJECT 1/6 COLUMNS") diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 7b6a1583ef11..9c5848382f42 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,6 +1,8 @@ import pickle from datetime import datetime +import pytest + import polars as pl @@ -13,14 +15,40 @@ def test_schema() -> None: assert s.names() == ["foo", "bar"] assert s.dtypes() == [pl.Int8(), pl.String()] + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) + + +def test_schema_equality() -> None: + s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) + s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) + + assert s1 == s1 + assert s2 == s2 + assert s3 == s3 + assert s1 != s2 + assert s1 != s3 + assert s2 != s3 + + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 -def test_schema_parse_nonpolars_dtypes() -> None: + +def test_schema_parse_python_dtypes() -> None: cardinal_directions = pl.Enum(["north", "south", "east", "west"]) - s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] s["ham"] = datetime - assert s["foo"] == pl.List + assert s["foo"] == pl.List(pl.Int32) assert s["bar"] == pl.Int64 assert s["baz"] == cardinal_directions assert s["ham"] == pl.Datetime("us") @@ -33,19 +61,6 @@ def test_schema_parse_nonpolars_dtypes() -> None: assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] -def test_schema_equality() -> None: - s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) - s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) - s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) - - assert s1 == s1 - assert s2 == s2 - assert s3 == s3 - assert s1 != s2 - assert s1 != s3 - assert s2 != s3 - - def test_schema_picklable() -> None: s = pl.Schema( { @@ -88,7 +103,6 @@ def test_schema_in_map_elements_returns_scalar() -> None: "amounts": [100.0, -110.0] * 2, } ) - q = ldf.group_by("portfolio").agg( pl.col("amounts") .map_elements( @@ -112,7 +126,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) ) - assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index bf44ff87bac5..dd2c415c9a13 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -515,7 +515,7 @@ def test_selector_temporal(df: pl.DataFrame) -> None: assert df.select(cs.temporal()).schema == { "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), } all_columns = set(df.columns) @@ -611,7 +611,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: "eee": pl.Boolean, "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), "qqR": pl.String, } @@ -629,7 +629,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) @@ -639,7 +639,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: ).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 179cd9e58d86..90221c3e2edb 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-09-29" +channel = "nightly-2024-10-28"