diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..a4557c17fe9f --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: cargo + directory: "/" + schedule: + interval: weekly + day: sunday + time: "7:00" + open-pull-requests-limit: 10 + target-branch: master + labels: [auto-dependencies] \ No newline at end of file diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml deleted file mode 100644 index 6e54d12968de..000000000000 --- a/.github/workflows/python_build.yml +++ /dev/null @@ -1,131 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Python Release Build -on: - push: - tags: - - "*-rc*" - -defaults: - run: - working-directory: ./python - -jobs: - generate-license: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - name: Generate license file - run: python ../dev/create_license.py - - uses: actions/upload-artifact@v2 - with: - name: python-wheel-license - path: python/LICENSE.txt - - build-python-mac-win: - needs: [generate-license] - name: Mac/Win - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - python-version: ["3.10"] - os: [macos-latest, windows-latest] - steps: - - uses: actions/checkout@v2 - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly-2021-10-23 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install maturin==0.11.5 - - - run: rm LICENSE.txt - - name: Download LICENSE.txt - uses: actions/download-artifact@v2 - with: - name: python-wheel-license - path: python - - - name: Build Python package - run: maturin build --release --no-sdist --strip - - - name: List Windows wheels - if: matrix.os == 'windows-latest' - run: dir target\wheels\ - - - name: List Mac wheels - if: matrix.os != 'windows-latest' - run: find target/wheels/ - - - name: Archive wheels - uses: actions/upload-artifact@v2 - with: - name: dist - path: python/target/wheels/* - - build-manylinux: - needs: [generate-license] - name: Manylinux - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - run: rm LICENSE.txt - - name: Download LICENSE.txt - uses: actions/download-artifact@v2 - with: - name: python-wheel-license - path: python - - run: cat LICENSE.txt - - name: Build wheels - run: | - export RUSTFLAGS='-C target-cpu=skylake' - docker run --rm -v $(pwd)/..:/io \ - --workdir /io/python \ - konstin2/maturin:v0.11.2 \ - build --release --manylinux 2010 - - name: Archive wheels - uses: actions/upload-artifact@v2 - with: - name: dist - path: 
python/target/wheels/* - - # NOTE: PyPI publish needs to be done manually for now after release passed the vote - # release: - # name: Publish in PyPI - # needs: [build-manylinux, build-python-mac-win] - # runs-on: ubuntu-latest - # steps: - # - uses: actions/download-artifact@v2 - # - name: Publish to PyPI - # uses: pypa/gh-action-pypi-publish@master - # with: - # user: __token__ - # password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml deleted file mode 100644 index 01a36af870af..000000000000 --- a/.github/workflows/python_test.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Python test -on: [push, pull_request] - -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Setup Rust toolchain - run: | - rustup toolchain install nightly-2021-10-23 - rustup default nightly-2021-10-23 - rustup component add rustfmt - - name: Cache Cargo - uses: actions/cache@v2 - with: - path: /home/runner/.cargo - key: cargo-maturin-cache- - - name: Cache Rust dependencies - uses: actions/cache@v2 - with: - path: /home/runner/target - key: target-maturin-cache- - - uses: actions/setup-python@v2 - with: - python-version: "3.10" - - name: Create Virtualenv - run: | - python -m venv venv - source venv/bin/activate - pip install -r python/requirements.txt - - name: Run Linters - run: | - source venv/bin/activate - flake8 python --ignore=E501 - black --line-length 79 --diff --check python - - name: Run tests - run: | - source venv/bin/activate - cd python - maturin develop - RUST_BACKTRACE=1 pytest -v . 
- env: - CARGO_HOME: "/home/runner/.cargo" - CARGO_TARGET_DIR: "/home/runner/target" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2768355dc669..5e841f87ffe5 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -116,7 +116,8 @@ jobs: cargo test --no-default-features cargo run --example csv_sql cargo run --example parquet_sql - # cargo run --example avro_sql --features=datafusion/avro + #nopass + cargo run --example avro_sql --features=datafusion/avro env: CARGO_HOME: "/github/home/.cargo" CARGO_TARGET_DIR: "/github/home/target" @@ -127,6 +128,7 @@ jobs: export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data cd ballista/rust # snmalloc requires cmake so build without default features + #nopass cargo test --no-default-features --features sled env: CARGO_HOME: "/github/home/.cargo" diff --git a/Cargo.toml b/Cargo.toml index 66f7f932c7b5..757d671fbe0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,6 @@ lto = true codegen-units = 1 [patch.crates-io] -#arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "f2c7503bc171a4c75c0af9905823c8795bd17f9b" } -arrow2 = { git = "https://github.com/blaze-init/arrow2.git", branch = "shuffle_ipc" } -parquet2 = { git = "https://github.com/blaze-init/parquet2.git", branch = "meta_new" } +arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "ef7937dfe56033c2cc491482c67587b52cd91554" } +#arrow2 = { git = "https://github.com/blaze-init/arrow2.git", branch = "shuffle_ipc" } +#parquet2 = { git = "https://github.com/blaze-init/parquet2.git", branch = "meta_new" } diff --git a/README.md b/README.md index 1e2ffdc05be4..6bef96637712 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ DataFusion is designed to be extensible at all points. To that end, you can prov ## Rust Version Compatbility -This crate is tested with the latest stable version of Rust. We do not currrently test against other, older versions of the Rust compiler. +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. # Supported SQL @@ -264,9 +264,9 @@ This library currently supports many SQL constructs, including - `SELECT ... FROM ...` together with any expression - `ALIAS` to name an expression - `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` -- most mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. +- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. - `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG` +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `VAR`, `STDDEV` (sample and population) - `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` ## Supported Functions @@ -366,7 +366,7 @@ Please see [Roadmap](docs/source/specification/roadmap.md) for information of wh There is no formal document describing DataFusion's architecture yet, but the following presentations offer a good overview of its different components and how they interact together. 
- (March 2021): The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) -- (Feburary 2021): How DataFusion is used within the Ballista Project is described in \*Ballista: Distributed Compute with Rust and Apache Arrow: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) +- (February 2021): How DataFusion is used within the Ballista Project is described in \*Ballista: Distributed Compute with Rust and Apache Arrow: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) # Developer's guide diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index a88adb2c8983..3415d13a3487 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -30,7 +30,7 @@ build = "build.rs" simd = ["datafusion/simd"] [dependencies] -ahash = "0.7" +ahash = { version = "0.7", default-features = false } async-trait = "0.1.36" futures = "0.3" hashbrown = "0.11" @@ -41,7 +41,7 @@ sqlparser = "0.13" tokio = "1.0" tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } -chrono = "0.4" +chrono = { version = "0.4", default-features = false } arrow-format = { version = "0.3", features = ["flight-data", "flight-service"] } arrow = { package = "arrow2", version="0.8", features = ["io_ipc", "io_flight"] } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index e59ec21fb5bd..5a755cc9a2ac 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -169,6 +169,10 @@ enum AggregateFunction { COUNT = 4; APPROX_DISTINCT = 5; ARRAY_AGG = 6; + VARIANCE=7; + VARIANCE_POP=8; + STDDEV=9; + STDDEV_POP=10; } message AggregateExprNode { diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index 8fdae4376bc9..eaacda8badf2 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -17,6 +17,8 @@ //! Client API for sending requests to executors. 
+use arrow::io::flight::deserialize_schemas; +use arrow::io::ipc::IpcSchema; use std::sync::{Arc, Mutex}; use std::{collections::HashMap, pin::Pin}; use std::{ @@ -121,10 +123,12 @@ impl BallistaClient { { Some(flight_data) => { // convert FlightData to a stream - let schema = Arc::new(Schema::try_from(&flight_data)?); + let (schema, ipc_schema) = + deserialize_schemas(flight_data.data_body.as_slice()).unwrap(); + let schema = Arc::new(schema); // all the remaining stream messages should be dictionary and record batches - Ok(Box::pin(FlightDataStream::new(stream, schema))) + Ok(Box::pin(FlightDataStream::new(stream, schema, ipc_schema))) } None => Err(ballista_error( "Did not receive schema batch from flight server", @@ -136,13 +140,19 @@ impl BallistaClient { struct FlightDataStream { stream: Mutex>, schema: SchemaRef, + ipc_schema: IpcSchema, } impl FlightDataStream { - pub fn new(stream: Streaming, schema: SchemaRef) -> Self { + pub fn new( + stream: Streaming, + schema: SchemaRef, + ipc_schema: IpcSchema, + ) -> Self { Self { stream: Mutex::new(stream), schema, + ipc_schema, } } } @@ -161,10 +171,11 @@ impl Stream for FlightDataStream { .map_err(|e| ArrowError::from_external_error(Box::new(e))) .and_then(|flight_data_chunk| { let hm = HashMap::new(); + arrow::io::flight::deserialize_batch( &flight_data_chunk, self.schema.clone(), - true, + &self.ipc_schema, &hm, ) }); diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 49dbb1b4c480..991a9330e2df 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -458,12 +458,17 @@ impl ShuffleWriter { num_rows: 0, num_bytes: 0, path: path.to_owned(), - writer: FileWriter::try_new(buffer_writer, schema, WriteOptions::default())?, + writer: FileWriter::try_new( + buffer_writer, + schema, + None, + WriteOptions::default(), + )?, }) } fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; + self.writer.write(batch, None)?; self.num_batches += 1; self.num_rows += batch.num_rows() as u64; let num_bytes: usize = batch diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 7f8291b1a7f0..f429e175664f 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -494,10 +494,10 @@ fn typechecked_scalar_value_conversion( ScalarValue::Date32(Some(*v)) } (Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => { - ScalarValue::TimestampMicrosecond(Some(*v)) + ScalarValue::TimestampMicrosecond(Some(*v), None) } (Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => { - ScalarValue::TimestampNanosecond(Some(*v)) + ScalarValue::TimestampNanosecond(Some(*v), None) } (Value::Utf8Value(v), PrimitiveScalarType::Utf8) => { ScalarValue::Utf8(Some(v.to_owned())) @@ -530,10 +530,10 @@ fn typechecked_scalar_value_conversion( PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None), PrimitiveScalarType::Date32 => ScalarValue::Date32(None), PrimitiveScalarType::TimeMicrosecond => { - ScalarValue::TimestampMicrosecond(None) + ScalarValue::TimestampMicrosecond(None, None) } PrimitiveScalarType::TimeNanosecond => { - ScalarValue::TimestampNanosecond(None) + ScalarValue::TimestampNanosecond(None, None) } PrimitiveScalarType::Null => { return Err(proto_error( @@ -593,10 +593,10 @@ impl TryInto for 
&protobuf::scalar_value::Value ScalarValue::Date32(Some(*v)) } protobuf::scalar_value::Value::TimeMicrosecondValue(v) => { - ScalarValue::TimestampMicrosecond(Some(*v)) + ScalarValue::TimestampMicrosecond(Some(*v), None) } protobuf::scalar_value::Value::TimeNanosecondValue(v) => { - ScalarValue::TimestampNanosecond(Some(*v)) + ScalarValue::TimestampNanosecond(Some(*v), None) } protobuf::scalar_value::Value::ListValue(v) => v.try_into()?, protobuf::scalar_value::Value::NullListValue(v) => { @@ -758,10 +758,10 @@ impl TryInto for protobuf::PrimitiveScalarType protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None), protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None), protobuf::PrimitiveScalarType::TimeMicrosecond => { - ScalarValue::TimestampMicrosecond(None) + ScalarValue::TimestampMicrosecond(None, None) } protobuf::PrimitiveScalarType::TimeNanosecond => { - ScalarValue::TimestampNanosecond(None) + ScalarValue::TimestampNanosecond(None, None) } }) } @@ -811,10 +811,10 @@ impl TryInto for &protobuf::ScalarValue { ScalarValue::Date32(Some(*v)) } protobuf::scalar_value::Value::TimeMicrosecondValue(v) => { - ScalarValue::TimestampMicrosecond(Some(*v)) + ScalarValue::TimestampMicrosecond(Some(*v), None) } protobuf::scalar_value::Value::TimeNanosecondValue(v) => { - ScalarValue::TimestampNanosecond(Some(*v)) + ScalarValue::TimestampNanosecond(Some(*v), None) } protobuf::scalar_value::Value::ListValue(scalar_list) => { let protobuf::ScalarListValue { diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 75d856499529..50ab4c7b7c91 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -217,8 +217,8 @@ mod roundtrip_tests { ScalarValue::LargeUtf8(None), ScalarValue::List(None, Box::new(DataType::Boolean)), ScalarValue::Date32(None), - ScalarValue::TimestampMicrosecond(None), - ScalarValue::TimestampNanosecond(None), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampNanosecond(None, None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), ScalarValue::Float32(Some(1.0)), @@ -257,11 +257,11 @@ mod roundtrip_tests { ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(i32::MAX)), - ScalarValue::TimestampNanosecond(Some(0)), - ScalarValue::TimestampNanosecond(Some(i64::MAX)), - ScalarValue::TimestampMicrosecond(Some(0)), - ScalarValue::TimestampMicrosecond(Some(i64::MAX)), - ScalarValue::TimestampMicrosecond(None), + ScalarValue::TimestampNanosecond(Some(0), None), + ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + ScalarValue::TimestampMicrosecond(Some(0), None), + ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), + ScalarValue::TimestampMicrosecond(None, None), ScalarValue::List( Some(Box::new(vec![ ScalarValue::Float32(Some(-213.1)), @@ -604,8 +604,8 @@ mod roundtrip_tests { ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), ScalarValue::Date32(None), - ScalarValue::TimestampMicrosecond(None), - ScalarValue::TimestampNanosecond(None), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampNanosecond(None, None), //ScalarValue::List(None, DataType::Boolean) ]; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index dd19cd7c0c4a..573cf86e607d 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ 
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -186,115 +186,7 @@ impl TryInto for &protobuf::ArrowType { "Protobuf deserialization error: ArrowType missing required field 'data_type'", ) })?; - Ok(match pb_arrow_type { - protobuf::arrow_type::ArrowTypeEnum::None(_) => DataType::Null, - protobuf::arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, - protobuf::arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, - protobuf::arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, - protobuf::arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, - protobuf::arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, - protobuf::arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, - protobuf::arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, - protobuf::arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, - protobuf::arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, - protobuf::arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, - protobuf::arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, - protobuf::arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, - protobuf::arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, - protobuf::arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, - protobuf::arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, - protobuf::arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { - DataType::FixedSizeBinary(*size as usize) - } - protobuf::arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, - protobuf::arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, - protobuf::arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, - protobuf::arrow_type::ArrowTypeEnum::Duration(time_unit_i32) => { - DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?) - } - protobuf::arrow_type::ArrowTypeEnum::Timestamp(timestamp) => { - DataType::Timestamp( - protobuf::TimeUnit::from_i32_to_arrow(timestamp.time_unit)?, - match timestamp.timezone.is_empty() { - true => None, - false => Some(timestamp.timezone.to_owned()), - }, - ) - } - protobuf::arrow_type::ArrowTypeEnum::Time32(time_unit_i32) => { - DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?) - } - protobuf::arrow_type::ArrowTypeEnum::Time64(time_unit_i32) => { - DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?) - } - protobuf::arrow_type::ArrowTypeEnum::Interval(interval_unit_i32) => { - DataType::Interval(protobuf::IntervalUnit::from_i32_to_arrow( - *interval_unit_i32, - )?) - } - protobuf::arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { - whole, - fractional, - }) => DataType::Decimal(*whole as usize, *fractional as usize), - protobuf::arrow_type::ArrowTypeEnum::List(boxed_list) => { - let field_ref = boxed_list - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message was missing required field 'field_type'"))? - .as_ref(); - DataType::List(Box::new(field_ref.try_into()?)) - } - protobuf::arrow_type::ArrowTypeEnum::LargeList(boxed_list) => { - let field_ref = boxed_list - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message was missing required field 'field_type'"))? 
- .as_ref(); - DataType::LargeList(Box::new(field_ref.try_into()?)) - } - protobuf::arrow_type::ArrowTypeEnum::FixedSizeList(boxed_list) => { - let fsl_ref = boxed_list.as_ref(); - let pb_fieldtype = fsl_ref - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: FixedSizeList message was missing required field 'field_type'"))?; - DataType::FixedSizeList( - Box::new(pb_fieldtype.as_ref().try_into()?), - fsl_ref.list_size as usize, - ) - } - protobuf::arrow_type::ArrowTypeEnum::Struct(struct_type) => { - let fields = struct_type - .sub_field_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?; - DataType::Struct(fields) - } - protobuf::arrow_type::ArrowTypeEnum::Union(union) => { - let union_types = union - .union_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?; - DataType::Union(union_types, None, UnionMode::Dense) - } - protobuf::arrow_type::ArrowTypeEnum::Dictionary(boxed_dict) => { - let dict_ref = boxed_dict.as_ref(); - let pb_key = dict_ref - .key - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message was missing required field 'key'"))?; - let pb_value = dict_ref - .value - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message was missing required field 'value'"))?; - DataType::Dictionary( - pb_key.try_into()?, - Box::new(pb_value.as_ref().try_into()?), - ) - } - }) + pb_arrow_type.try_into() } } @@ -404,7 +296,7 @@ impl From<&DataType> for protobuf::arrow_type::ArrowTypeEnum { .map(|field| field.into()) .collect::>(), }), - DataType::Dictionary(key_type, value_type) => { + DataType::Dictionary(key_type, value_type, _) => { ArrowTypeEnum::Dictionary(Box::new(protobuf::Dictionary { key: Some(key_type.into()), value: Some(Box::new(value_type.as_ref().into())), @@ -551,7 +443,7 @@ impl TryFrom<&DataType> for protobuf::scalar_type::Datatype { | DataType::LargeList(_) | DataType::Struct(_) | DataType::Union(_, _, _) - | DataType::Dictionary(_, _) + | DataType::Dictionary(_, _, _) | DataType::Decimal(_, _) => { return Err(proto_error(format!( "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.", @@ -710,12 +602,12 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { datafusion::scalar::ScalarValue::Date32(val) => { create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s)) } - datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => { + datafusion::scalar::ScalarValue::TimestampMicrosecond(val, _) => { create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| { Value::TimeMicrosecondValue(*s) }) } - datafusion::scalar::ScalarValue::TimestampNanosecond(val) => { + datafusion::scalar::ScalarValue::TimestampNanosecond(val, _) => { create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| { Value::TimeNanosecondValue(*s) }) @@ -1192,6 +1084,14 @@ impl TryInto for &Expr { AggregateFunction::Sum => protobuf::AggregateFunction::Sum, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } }; let arg = &args[0]; @@ -1422,6 +1322,10 @@ impl From<&AggregateFunction> for 
protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, + AggregateFunction::Variance => Self::Variance, + AggregateFunction::VariancePop => Self::VariancePop, + AggregateFunction::Stddev => Self::Stddev, + AggregateFunction::StddevPop => Self::StddevPop, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 5ed2e27d1752..9ff2a6cedb17 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -98,6 +98,7 @@ pub(crate) fn from_proto_binary_op(op: &str) -> Result "Minus" => Ok(Operator::Minus), "Multiply" => Ok(Operator::Multiply), "Divide" => Ok(Operator::Divide), + "Modulo" => Ok(Operator::Modulo), "Like" => Ok(Operator::Like), "NotLike" => Ok(Operator::NotLike), other => Err(proto_error(format!( @@ -119,6 +120,10 @@ impl From for AggregateFunction { AggregateFunction::ApproxDistinct } protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, + protobuf::AggregateFunction::Variance => AggregateFunction::Variance, + protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop, + protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, + protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, } } } @@ -267,7 +272,7 @@ impl TryInto .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; let key_datatype: IntegerType = pb_key_datatype.try_into()?; let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?; - DataType::Dictionary(key_datatype, Box::new(value_datatype)) + DataType::Dictionary(key_datatype, Box::new(value_datatype), false) } }) } diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 15857678bf01..20820ee2bf23 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -87,6 +87,7 @@ pub async fn write_stream_to_disk( let mut writer = FileWriter::try_new( &mut file, stream.schema().as_ref(), + None, WriteOptions::default(), )?; @@ -103,7 +104,7 @@ pub async fn write_stream_to_disk( num_bytes += batch_size_bytes; let timer = disk_write_metric.timer(); - writer.write(&batch)?; + writer.write(&batch, None)?; timer.done(); } let timer = disk_write_metric.timer(); diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index 398ebca2b8e6..d073d60f7209 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -78,9 +78,7 @@ impl Executor { job_id, stage_id, part, - DisplayableExecutionPlan::with_metrics(&exec) - .indent() - .to_string() + DisplayableExecutionPlan::with_metrics(&exec).indent() ); Ok(partitions) diff --git a/ballista/rust/executor/src/flight_service.rs b/ballista/rust/executor/src/flight_service.rs index 6199a44e509f..79666332a7f4 100644 --- a/ballista/rust/executor/src/flight_service.rs +++ b/ballista/rust/executor/src/flight_service.rs @@ -179,7 +179,7 @@ fn create_flight_iter( options: &WriteOptions, ) -> Box>> { let (flight_dictionaries, flight_batch) = - arrow::io::flight::serialize_batch(batch, options); + arrow::io::flight::serialize_batch(batch, &[], options); Box::new( flight_dictionaries .into_iter() @@ -202,7 +202,7 @@ async fn stream_flight_data(path: String, tx: FlightDataSender) -> Result<(), St let options = WriteOptions::default(); let schema_flight_data = - 
arrow::io::flight::serialize_schema(reader.schema().as_ref()); + arrow::io::flight::serialize_schema(reader.schema().as_ref(), &[]); send_response(&tx, Ok(schema_flight_data)).await?; let mut row_count = 0; diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs index a9aedbf1687d..89f98082e9f7 100644 --- a/ballista/rust/executor/src/standalone.rs +++ b/ballista/rust/executor/src/standalone.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow_format::flight::service::flight_service_server::FlightServiceServer; use ballista_core::{ error::Result, + serde::protobuf::executor_registration::OptionalHost, serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration}, BALLISTA_VERSION, }; @@ -59,7 +60,7 @@ pub async fn new_standalone_executor( ); let executor_meta = ExecutorRegistration { id: Uuid::new_v4().to_string(), // assign this executor a unique ID - optional_host: None, + optional_host: Some(OptionalHost::Host("localhost".to_string())), port: addr.port() as u32, }; tokio::spawn(execution_loop::poll_loop( diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 3291a62abe64..efc7eb607e59 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -293,7 +293,7 @@ mod test { .plan_query_stages(&job_uuid.to_string(), plan) .await?; for stage in &stages { - println!("{}", displayable(stage.as_ref()).indent().to_string()); + println!("{}", displayable(stage.as_ref()).indent()); } /* Expected result: @@ -407,7 +407,7 @@ order by .plan_query_stages(&job_uuid.to_string(), plan) .await?; for stage in &stages { - println!("{}", displayable(stage.as_ref()).indent().to_string()); + println!("{}", displayable(stage.as_ref()).indent()); } /* Expected result: diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 5cbcca561be1..db863d68f335 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -41,6 +41,7 @@ futures = "0.3" env_logger = "0.9" mimalloc = { version = "0.1", optional = true, default-features = false } snmalloc-rs = {version = "0.2", optional = true, features= ["cache-friendly"] } +rand = "0.8.4" [dev-dependencies] ballista-core = { path = "../ballista/rust/core" } diff --git a/benchmarks/README.md b/benchmarks/README.md index a63761b6c2b3..e6c17430d6e2 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -178,5 +178,20 @@ Query 'fare_amt_by_passenger' iteration 1 took 7599 ms Query 'fare_amt_by_passenger' iteration 2 took 7969 ms ``` +## Running the Ballista Loadtest + +```bash + cargo run --bin tpch -- loadtest ballista-load + --query-list 1,3,5,6,7,10,12,13 + --requests 200 + --concurrency 10 + --data-path /**** + --format parquet + --host localhost + --port 50050 + --sql-path /*** + --debug +``` + [1]: http://www.tpc.org/tpch/ [2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index a077d83b3771..9d3302055121 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -17,6 +17,9 @@ //! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. 
+use futures::future::join_all; +use rand::prelude::*; +use std::ops::Div; use std::{ fs, iter::Iterator, @@ -46,16 +49,17 @@ use datafusion::{ }; use arrow::io::parquet::write::{Compression, Version, WriteOptions}; +use arrow::io::print::print; use ballista::prelude::{ BallistaConfig, BallistaContext, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, }; use structopt::StructOpt; -#[cfg(feature = "snmalloc")] +#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; -#[cfg(feature = "mimalloc")] +#[cfg(all(feature = "mimalloc", not(feature = "snmalloc")))] #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -135,6 +139,48 @@ struct DataFusionBenchmarkOpt { mem_table: bool, } +#[derive(Debug, StructOpt, Clone)] +struct BallistaLoadtestOpt { + #[structopt(short = "q", long)] + query_list: String, + + /// Activate debug mode to see query results + #[structopt(short, long)] + debug: bool, + + /// Number of requests + #[structopt(short = "r", long = "requests", default_value = "100")] + requests: usize, + + /// Number of connections + #[structopt(short = "c", long = "concurrency", default_value = "5")] + concurrency: usize, + + /// Number of partitions to process in parallel + #[structopt(short = "n", long = "partitions", default_value = "2")] + partitions: usize, + + /// Path to data files + #[structopt(parse(from_os_str), required = true, short = "p", long = "data-path")] + path: PathBuf, + + /// Path to sql files + #[structopt(parse(from_os_str), required = true, long = "sql-path")] + sql_path: PathBuf, + + /// File format: `csv` or `parquet` + #[structopt(short = "f", long = "format", default_value = "parquet")] + file_format: String, + + /// Ballista executor host + #[structopt(long = "host")] + host: Option, + + /// Ballista executor port + #[structopt(long = "port")] + port: Option, +} + #[derive(Debug, StructOpt)] struct ConvertOpt { /// Path to csv files @@ -171,11 +217,19 @@ enum BenchmarkSubCommandOpt { DataFusionBenchmark(DataFusionBenchmarkOpt), } +#[derive(Debug, StructOpt)] +#[structopt(about = "loadtest command")] +enum LoadtestOpt { + #[structopt(name = "ballista-load")] + BallistaLoadtest(BallistaLoadtestOpt), +} + #[derive(Debug, StructOpt)] #[structopt(name = "TPC-H", about = "TPC-H Benchmarks.")] enum TpchOpt { Benchmark(BenchmarkSubCommandOpt), Convert(ConvertOpt), + Loadtest(LoadtestOpt), } const TABLES: &[&str] = &[ @@ -185,6 +239,7 @@ const TABLES: &[&str] = &[ #[tokio::main] async fn main() -> Result<()> { use BenchmarkSubCommandOpt::*; + use LoadtestOpt::*; env_logger::init(); match TpchOpt::from_args() { @@ -195,6 +250,9 @@ async fn main() -> Result<()> { benchmark_datafusion(opt).await.map(|_| ()) } TpchOpt::Convert(opt) => convert_tbl(opt).await, + TpchOpt::Loadtest(BallistaLoadtest(opt)) => { + loadtest_ballista(opt).await.map(|_| ()) + } } } @@ -266,6 +324,151 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { // register tables with Ballista context let path = opt.path.to_str().unwrap(); let file_format = opt.file_format.as_str(); + + register_tables(path, file_format, &ctx).await; + + let mut millis = vec![]; + + // run benchmark + let sql = get_query_sql(opt.query)?; + println!("Running benchmark with query {}:\n {}", opt.query, sql); + for i in 0..opt.iterations { + let start = Instant::now(); + let df = ctx + .sql(&sql) + .await + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); + let batches = df + .collect() + .await 
+ .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + millis.push(elapsed as f64); + println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed); + if opt.debug { + print(&batches); + } + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {} avg time: {:.2} ms", opt.query, avg); + + Ok(()) +} + +async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { + println!( + "Running loadtest_ballista with the following options: {:?}", + opt + ); + + let config = BallistaConfig::builder() + .set( + BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, + &format!("{}", opt.partitions), + ) + .build() + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + + let concurrency = opt.concurrency; + let request_amount = opt.requests; + let mut clients = vec![]; + + for _num in 0..concurrency { + clients.push(BallistaContext::remote( + opt.host.clone().unwrap().as_str(), + opt.port.unwrap(), + &config, + )); + } + + // register tables with Ballista context + let path = opt.path.to_str().unwrap(); + let file_format = opt.file_format.as_str(); + let sql_path = opt.sql_path.to_str().unwrap().to_string(); + + for ctx in &clients { + register_tables(path, file_format, ctx).await; + } + + let request_per_thread = request_amount.div(concurrency); + // run benchmark + let query_list: Vec = opt + .query_list + .split(',') + .map(|s| s.parse().unwrap()) + .collect(); + println!("query list: {:?} ", &query_list); + + let total = Instant::now(); + let mut futures = vec![]; + + for (client_id, client) in clients.into_iter().enumerate() { + let query_list_clone = query_list.clone(); + let sql_path_clone = sql_path.clone(); + let handle = tokio::spawn(async move { + for i in 0..request_per_thread { + let query_id = query_list_clone + .get( + (0..query_list_clone.len()) + .choose(&mut rand::thread_rng()) + .unwrap(), + ) + .unwrap(); + let sql = + get_query_sql_by_path(query_id.to_owned(), sql_path_clone.clone()) + .unwrap(); + println!( + "Client {} Round {} Query {} started", + &client_id, &i, query_id + ); + let start = Instant::now(); + let df = client + .sql(&sql) + .await + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); + let batches = df + .collect() + .await + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); + let elapsed = start.elapsed().as_secs_f64() * 1000.0; + println!( + "Client {} Round {} Query {} took {:.1} ms ", + &client_id, &i, query_id, elapsed + ); + if opt.debug { + print(&batches); + } + } + }); + futures.push(handle); + } + join_all(futures).await; + let elapsed = total.elapsed().as_secs_f64() * 1000.0; + println!("###############################"); + println!("load test took {:.1} ms", elapsed); + Ok(()) +} + +fn get_query_sql_by_path(query: usize, mut sql_path: String) -> Result { + if sql_path.ends_with('/') { + sql_path.pop(); + } + if query > 0 && query < 23 { + let filename = format!("{}/q{}.sql", sql_path, query); + Ok(fs::read_to_string(&filename).expect("failed to read query")) + } else { + Err(DataFusionError::Plan( + "invalid query. 
Expected value between 1 and 22".to_owned(), + )) + } +} + +async fn register_tables(path: &str, file_format: &str, ctx: &BallistaContext) { for table in TABLES { match file_format { // dbgen creates .tbl ('|' delimited) files without header @@ -279,7 +482,8 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { .file_extension(".tbl"); ctx.register_csv(table, &path, options) .await - .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?; + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); } "csv" => { let path = format!("{}/{}", path, table); @@ -287,47 +491,21 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { let options = CsvReadOptions::new().schema(&schema).has_header(true); ctx.register_csv(table, &path, options) .await - .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?; + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); } "parquet" => { let path = format!("{}/{}", path, table); ctx.register_parquet(table, &path) .await - .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?; + .map_err(|e| DataFusionError::Plan(format!("{:?}", e))) + .unwrap(); } other => { unimplemented!("Invalid file format '{}'", other); } } } - - let mut millis = vec![]; - - // run benchmark - let sql = get_query_sql(opt.query)?; - println!("Running benchmark with query {}:\n {}", opt.query, sql); - for i in 0..opt.iterations { - let start = Instant::now(); - let df = ctx - .sql(&sql) - .await - .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?; - let batches = df - .collect() - .await - .map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?; - let elapsed = start.elapsed().as_secs_f64() * 1000.0; - millis.push(elapsed as f64); - println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed); - if opt.debug { - print::print(&batches); - } - } - - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {} avg time: {:.2} ms", opt.query, avg); - - Ok(()) } fn get_query_sql(query: usize) -> Result { @@ -362,16 +540,14 @@ async fn execute_query( if debug { println!( "=== Physical plan ===\n{}\n", - displayable(physical_plan.as_ref()).indent().to_string() + displayable(physical_plan.as_ref()).indent() ); } let result = collect(physical_plan.clone()).await?; if debug { println!( "=== Physical plan with metrics ===\n{}\n", - DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) - .indent() - .to_string() + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent() ); print::print(&result); } diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 5beca25e4fbf..0b7fd8ff6212 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,11 +16,8 @@ // under the License. //! 
Print format variants -use datafusion::arrow::io::{ - csv::write, - json::{JsonArray, JsonFormat, LineDelimited, Writer}, - print, -}; +use arrow::io::json::write::{JsonArray, JsonFormat, LineDelimited}; +use datafusion::arrow::io::{csv::write, print}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; use std::fmt; @@ -74,11 +71,23 @@ impl fmt::Display for PrintFormat { } fn print_batches_to_json(batches: &[RecordBatch]) -> Result { + if batches.is_empty() { + return Ok("{}".to_string()); + } let mut bytes = vec![]; - { - let mut writer = Writer::<_, J>::new(&mut bytes); - writer.write_batches(batches)?; - writer.finish()?; + let schema = batches[0].schema(); + let names = schema + .fields + .iter() + .map(|f| f.name.clone()) + .collect::>(); + for batch in batches { + arrow::io::json::write::serialize( + &names, + batch.columns(), + J::default(), + &mut bytes, + ); } let formatted = String::from_utf8(bytes) .map_err(|e| DataFusionError::Execution(e.to_string()))?; diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs index be1d46259b6e..2489f3f42f81 100644 --- a/datafusion-examples/examples/avro_sql.rs +++ b/datafusion-examples/examples/avro_sql.rs @@ -27,7 +27,7 @@ async fn main() -> Result<()> { // create local execution context let mut ctx = ExecutionContext::new(); - let testdata = datafusion::arrow::util::test_util::arrow_test_data(); + let testdata = datafusion::test_util::arrow_test_data(); // register avro file with the execution context let avro_file = &format!("{}/avro/alltypes_plain.avro", testdata); diff --git a/datafusion-examples/examples/flight_client.rs b/datafusion-examples/examples/flight_client.rs index c26a8855c0c0..469f3ebef0c8 100644 --- a/datafusion-examples/examples/flight_client.rs +++ b/datafusion-examples/examples/flight_client.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-use std::convert::TryFrom; use std::sync::Arc; -use datafusion::arrow::datatypes::Schema; - +use arrow::io::flight::deserialize_schemas; use arrow_format::flight::data::{flight_descriptor, FlightDescriptor, Ticket}; use arrow_format::flight::service::flight_service_client::FlightServiceClient; use datafusion::arrow::io::print; @@ -43,7 +41,8 @@ async fn main() -> Result<(), Box> { }); let schema_result = client.get_schema(request).await?.into_inner(); - let schema = Schema::try_from(&schema_result)?; + let (schema, _) = deserialize_schemas(schema_result.schema.as_slice()).unwrap(); + let schema = Arc::new(schema); println!("Schema: {:?}", schema); // Call do_get to execute a SQL query and receive results @@ -56,7 +55,9 @@ async fn main() -> Result<(), Box> { // the schema should be the first message returned, else client should error let flight_data = stream.message().await?.unwrap(); // convert FlightData to a stream - let schema = Arc::new(Schema::try_from(&flight_data)?); + let (schema, ipc_schema) = + deserialize_schemas(flight_data.data_body.as_slice()).unwrap(); + let schema = Arc::new(schema); println!("Schema: {:?}", schema); // all the remaining stream messages should be dictionary and record batches @@ -66,7 +67,7 @@ async fn main() -> Result<(), Box> { let record_batch = arrow::io::flight::deserialize_batch( &flight_data, schema.clone(), - true, + &ipc_schema, &dictionaries_by_field, )?; results.push(record_batch); diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs index f2580969c9d3..9a7b8a6bed21 100644 --- a/datafusion-examples/examples/flight_server.rs +++ b/datafusion-examples/examples/flight_server.rs @@ -77,7 +77,7 @@ impl FlightService for FlightServiceImpl { .unwrap(); let schema_result = - arrow::io::flight::serialize_schema_to_result(schema.as_ref()); + arrow::io::flight::serialize_schema_to_result(schema.as_ref(), &[]); Ok(Response::new(schema_result)) } @@ -116,7 +116,7 @@ impl FlightService for FlightServiceImpl { // add an initial FlightData message that sends schema let options = WriteOptions::default(); let schema_flight_data = - arrow::io::flight::serialize_schema(&df.schema().clone().into()); + arrow::io::flight::serialize_schema(&df.schema().clone().into(), &[]); let mut flights: Vec> = vec![Ok(schema_flight_data)]; @@ -125,7 +125,7 @@ impl FlightService for FlightServiceImpl { .iter() .flat_map(|batch| { let (flight_dictionaries, flight_batch) = - arrow::io::flight::serialize_batch(batch, &options); + arrow::io::flight::serialize_batch(batch, &[], &options); flight_dictionaries .into_iter() .chain(std::iter::once(flight_batch)) diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs new file mode 100644 index 000000000000..50edc03df85a --- /dev/null +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::error::Result; +use datafusion::prelude::*; +use std::sync::Arc; + +/// This example demonstrates executing a simple query against an Arrow data source (a directory +/// with multiple Parquet files) and fetching results +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let mut ctx = ExecutionContext::new(); + + let testdata = datafusion::test_util::parquet_test_data(); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(true); + let listing_options = ListingOptions { + file_extension: ".parquet".to_owned(), + format: Arc::new(file_format), + table_partition_cols: vec![], + collect_stat: true, + target_partitions: 1, + }; + + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "my_table", + &format!("file://{}", testdata), + listing_options, + None, + ) + .await + .unwrap(); + + // execute the query + let df = ctx + .sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col", + ) + .await?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 48ecb49ac2f3..8137d6d65ff2 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -39,7 +39,9 @@ path = "src/lib.rs" [features] default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] -simd = ["arrow/simd"] +# FIXME: https://github.com/jorgecarleitao/arrow2/issues/580 +#simd = ["arrow/simd"] +simd = [] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] @@ -48,16 +50,16 @@ pyarrow = ["pyo3"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format -avro = ["avro-rs", "num-traits"] +avro = ["arrow/io_avro", "arrow/io_avro_async", "arrow/io_avro_compression", "num-traits", "avro-schema"] [dependencies] -ahash = "0.7" +ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.11", features = ["raw"] } parquet = { package = "parquet2", version = "0.8", default_features = false, features = ["stream"] } sqlparser = "0.13" paste = "^1.0" num_cpus = "1.13.0" -chrono = "0.4" +chrono = { version = "0.4", default-features = false, features = ["clock"] } async-trait = "0.1.41" futures = "0.3" pin-project-lite= "^0.2.7" @@ -74,9 +76,9 @@ regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0" } smallvec = { version = "1.6", features = ["union"] } rand = "0.8" -avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.14", optional = true } +avro-schema = { version = "0.2", optional = true } [dependencies.arrow] package = "arrow2" diff --git 
a/datafusion/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/src/avro_to_arrow/arrow_array_reader.rs index 9d5552954f53..1a8424ab8448 100644 --- a/datafusion/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/src/avro_to_arrow/arrow_array_reader.rs @@ -17,965 +17,55 @@ //! Avro to Arrow array readers -use crate::arrow::array::{ - make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef, - BooleanBuilder, LargeStringArray, ListBuilder, NullArray, OffsetSizeTrait, - PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, - StringDictionaryBuilder, -}; -use crate::arrow::buffer::{Buffer, MutableBuffer}; -use crate::arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type, - Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, -}; -use crate::arrow::error::ArrowError; use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::bit_util; -use crate::error::{DataFusionError, Result}; -use arrow::array::{BinaryArray, GenericListArray}; +use crate::error::Result; +use crate::physical_plan::coalesce_batches::concat_batches; use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; -use avro_rs::{ - schema::{Schema as AvroSchema, SchemaKind}, - types::Value, - AvroResult, Error as AvroError, Reader as AvroReader, -}; -use num_traits::NumCast; -use std::collections::HashMap; +use arrow::io::avro::read::Reader as AvroReader; +use arrow::io::avro::{read, Compression}; use std::io::Read; -use std::sync::Arc; -type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; - -pub struct AvroArrowArrayReader<'a, R: Read> { - reader: AvroReader<'a, R>, +pub struct AvroBatchReader { + reader: AvroReader, schema: SchemaRef, - projection: Option>, - schema_lookup: HashMap, } -impl<'a, R: Read> AvroArrowArrayReader<'a, R> { +impl<'a, R: Read> AvroBatchReader { pub fn try_new( reader: R, schema: SchemaRef, - projection: Option>, + avro_schemas: Vec, + codec: Option, + file_marker: [u8; 16], ) -> Result { - let reader = AvroReader::new(reader)?; - let writer_schema = reader.writer_schema().clone(); - let schema_lookup = Self::schema_lookup(writer_schema)?; - Ok(Self { - reader, - schema, - projection, - schema_lookup, - }) - } - - pub fn schema_lookup(schema: AvroSchema) -> Result> { - match schema { - AvroSchema::Record { - lookup: ref schema_lookup, - .. 
- } => Ok(schema_lookup.clone()), - _ => Err(DataFusionError::ArrowError(SchemaError( - "expected avro schema to be a record".to_string(), - ))), - } + let reader = AvroReader::new( + read::Decompressor::new( + read::BlockStreamIterator::new(reader, file_marker), + codec, + ), + avro_schemas, + schema.clone(), + ); + Ok(Self { reader, schema }) } /// Read the next batch of records #[allow(clippy::should_implement_trait)] pub fn next_batch(&mut self, batch_size: usize) -> ArrowResult> { - let rows = self - .reader - .by_ref() - .take(batch_size) - .map(|value| match value { - Ok(Value::Record(v)) => Ok(v), - Err(e) => Err(ArrowError::ParseError(format!( - "Failed to parse avro value: {:?}", - e - ))), - other => { - return Err(ArrowError::ParseError(format!( - "Row needs to be of type object, got: {:?}", - other - ))) - } - }) - .collect::>>>()?; - if rows.is_empty() { - // reached end of file - return Ok(None); - } - let rows = rows.iter().collect::>>(); - let projection = self.projection.clone().unwrap_or_else(Vec::new); - let arrays = - self.build_struct_array(rows.as_slice(), self.schema.fields(), &projection); - let projected_fields: Vec = if projection.is_empty() { - self.schema.fields().to_vec() - } else { - projection - .iter() - .map(|name| self.schema.column_with_name(name)) - .flatten() - .map(|(_, field)| field.clone()) - .collect() - }; - let projected_schema = Arc::new(Schema::new(projected_fields)); - arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr).map(Some)) - } - - fn build_boolean_array( - &self, - rows: RecordSlice, - col_name: &str, - ) -> ArrowResult { - let mut builder = BooleanBuilder::new(rows.len()); - for row in rows { - if let Some(value) = self.field_lookup(col_name, row) { - if let Some(boolean) = resolve_boolean(&value) { - builder.append_value(boolean)? 
- } else { - builder.append_null()?; - } - } else { - builder.append_null()?; - } - } - Ok(Arc::new(builder.finish())) - } - - #[allow(clippy::unnecessary_wraps)] - fn build_primitive_array( - &self, - rows: RecordSlice, - col_name: &str, - ) -> ArrowResult - where - T: ArrowNumericType, - T::Native: num_traits::cast::NumCast, - { - Ok(Arc::new( - rows.iter() - .map(|row| { - self.field_lookup(col_name, row) - .and_then(|value| resolve_item::(&value)) - }) - .collect::>(), - )) - } - - #[inline(always)] - #[allow(clippy::unnecessary_wraps)] - fn build_string_dictionary_builder( - &self, - row_len: usize, - ) -> ArrowResult> - where - T: ArrowPrimitiveType + ArrowDictionaryKeyType, - { - let key_builder = PrimitiveBuilder::::new(row_len); - let values_builder = StringBuilder::new(row_len * 5); - Ok(StringDictionaryBuilder::new(key_builder, values_builder)) - } - - fn build_wrapped_list_array( - &self, - rows: RecordSlice, - col_name: &str, - key_type: &DataType, - ) -> ArrowResult { - match *key_type { - DataType::Int8 => { - let dtype = DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::Int16 => { - let dtype = DataType::Dictionary( - Box::new(DataType::Int16), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::Int32 => { - let dtype = DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::Int64 => { - let dtype = DataType::Dictionary( - Box::new(DataType::Int64), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::UInt8 => { - let dtype = DataType::Dictionary( - Box::new(DataType::UInt8), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::UInt16 => { - let dtype = DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::UInt32 => { - let dtype = DataType::Dictionary( - Box::new(DataType::UInt32), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - DataType::UInt64 => { - let dtype = DataType::Dictionary( - Box::new(DataType::UInt64), - Box::new(DataType::Utf8), - ); - self.list_array_string_array_builder::(&dtype, col_name, rows) - } - ref e => Err(SchemaError(format!( - "Data type is currently not supported for dictionaries in list : {:?}", - e - ))), - } - } - - #[inline(always)] - fn list_array_string_array_builder( - &self, - data_type: &DataType, - col_name: &str, - rows: RecordSlice, - ) -> ArrowResult - where - D: ArrowPrimitiveType + ArrowDictionaryKeyType, - { - let mut builder: Box = match data_type { - DataType::Utf8 => { - let values_builder = StringBuilder::new(rows.len() * 5); - Box::new(ListBuilder::new(values_builder)) - } - DataType::Dictionary(_, _) => { - let values_builder = - self.build_string_dictionary_builder::(rows.len() * 5)?; - Box::new(ListBuilder::new(values_builder)) - } - e => { - return Err(SchemaError(format!( - "Nested list data builder type is not supported: {:?}", - e - ))) - } - }; - - for row in rows { - if let Some(value) = self.field_lookup(col_name, row) { - // value can be an array or a scalar - let vals: Vec> = if let Value::String(v) = value { - vec![Some(v.to_string())] - } else if 
let Value::Array(n) = value { - n.iter() - .map(|v| resolve_string(&v)) - .collect::>>()? - .into_iter() - .map(Some) - .collect::>>() - } else if let Value::Null = value { - vec![None] - } else if !matches!(value, Value::Record(_)) { - vec![Some(resolve_string(&value)?)] - } else { - return Err(SchemaError( - "Only scalars are currently supported in Avro arrays".to_string(), - )); - }; - - // TODO: ARROW-10335: APIs of dictionary arrays and others are different. Unify - // them. - match data_type { - DataType::Utf8 => { - let builder = builder - .as_any_mut() - .downcast_mut::>() - .ok_or_else(||ArrowError::SchemaError( - "Cast failed for ListBuilder during nested data parsing".to_string(), - ))?; - for val in vals { - if let Some(v) = val { - builder.values().append_value(&v)? - } else { - builder.values().append_null()? - }; - } - - // Append to the list - builder.append(true)?; - } - DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( - "Cast failed for ListBuilder during nested data parsing".to_string(), - ))?; - for val in vals { - if let Some(v) = val { - let _ = builder.values().append(&v)?; - } else { - builder.values().append_null()? - }; - } - - // Append to the list - builder.append(true)?; - } - e => { - return Err(SchemaError(format!( - "Nested list data builder type is not supported: {:?}", - e - ))) - } - } - } - } - - Ok(builder.finish() as ArrayRef) - } - - #[inline(always)] - fn build_dictionary_array( - &self, - rows: RecordSlice, - col_name: &str, - ) -> ArrowResult - where - T::Native: num_traits::cast::NumCast, - T: ArrowPrimitiveType + ArrowDictionaryKeyType, - { - let mut builder: StringDictionaryBuilder = - self.build_string_dictionary_builder(rows.len())?; - for row in rows { - if let Some(value) = self.field_lookup(col_name, row) { - if let Ok(str_v) = resolve_string(&value) { - builder.append(str_v).map(drop)? + if let Some(Ok(batch)) = self.reader.next() { + let mut batch = batch; + 'batch: while batch.num_rows() < batch_size { + if let Some(Ok(next_batch)) = self.reader.next() { + let num_rows = batch.num_rows() + next_batch.num_rows(); + batch = concat_batches(&self.schema, &[batch, next_batch], num_rows)? } else { - builder.append_null()? - } - } else { - builder.append_null()? 
- } - } - Ok(Arc::new(builder.finish()) as ArrayRef) - } - - #[inline(always)] - fn build_string_dictionary_array( - &self, - rows: RecordSlice, - col_name: &str, - key_type: &DataType, - value_type: &DataType, - ) -> ArrowResult { - if let DataType::Utf8 = *value_type { - match *key_type { - DataType::Int8 => self.build_dictionary_array::(rows, col_name), - DataType::Int16 => { - self.build_dictionary_array::(rows, col_name) - } - DataType::Int32 => { - self.build_dictionary_array::(rows, col_name) - } - DataType::Int64 => { - self.build_dictionary_array::(rows, col_name) - } - DataType::UInt8 => { - self.build_dictionary_array::(rows, col_name) - } - DataType::UInt16 => { - self.build_dictionary_array::(rows, col_name) + break 'batch; } - DataType::UInt32 => { - self.build_dictionary_array::(rows, col_name) - } - DataType::UInt64 => { - self.build_dictionary_array::(rows, col_name) - } - _ => Err(ArrowError::SchemaError( - "unsupported dictionary key type".to_string(), - )), } + Ok(Some(batch)) } else { - Err(ArrowError::SchemaError( - "dictionary types other than UTF-8 not yet supported".to_string(), - )) - } - } - - /// Build a nested GenericListArray from a list of unnested `Value`s - fn build_nested_list_array( - &self, - rows: &[&Value], - list_field: &Field, - ) -> ArrowResult { - // build list offsets - let mut cur_offset = OffsetSize::zero(); - let list_len = rows.len(); - let num_list_bytes = bit_util::ceil(list_len, 8); - let mut offsets = Vec::with_capacity(list_len + 1); - let mut list_nulls = MutableBuffer::from_len_zeroed(num_list_bytes); - let list_nulls = list_nulls.as_slice_mut(); - offsets.push(cur_offset); - rows.iter().enumerate().for_each(|(i, v)| { - // TODO: unboxing Union(Array(Union(...))) should probably be done earlier - let v = maybe_resolve_union(v); - if let Value::Array(a) = v { - cur_offset += OffsetSize::from_usize(a.len()).unwrap(); - bit_util::set_bit(list_nulls, i); - } else if let Value::Null = v { - // value is null, not incremented - } else { - cur_offset += OffsetSize::one(); - } - offsets.push(cur_offset); - }); - let valid_len = cur_offset.to_usize().unwrap(); - let array_data = match list_field.data_type() { - DataType::Null => NullArray::new(valid_len).data().clone(), - DataType::Boolean => { - let num_bytes = bit_util::ceil(valid_len, 8); - let mut bool_values = MutableBuffer::from_len_zeroed(num_bytes); - let mut bool_nulls = - MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let mut curr_index = 0; - rows.iter().for_each(|v| { - if let Value::Array(vs) = v { - vs.iter().for_each(|value| { - if let Value::Boolean(child) = value { - // if valid boolean, append value - if *child { - bit_util::set_bit( - bool_values.as_slice_mut(), - curr_index, - ); - } - } else { - // null slot - bit_util::unset_bit( - bool_nulls.as_slice_mut(), - curr_index, - ); - } - curr_index += 1; - }); - } - }); - ArrayData::builder(list_field.data_type().clone()) - .len(valid_len) - .add_buffer(bool_values.into()) - .null_bit_buffer(bool_nulls.into()) - .build() - .unwrap() - } - DataType::Int8 => self.read_primitive_list_values::(rows), - DataType::Int16 => self.read_primitive_list_values::(rows), - DataType::Int32 => self.read_primitive_list_values::(rows), - DataType::Int64 => self.read_primitive_list_values::(rows), - DataType::UInt8 => self.read_primitive_list_values::(rows), - DataType::UInt16 => self.read_primitive_list_values::(rows), - DataType::UInt32 => self.read_primitive_list_values::(rows), - DataType::UInt64 => 
self.read_primitive_list_values::(rows), - DataType::Float16 => { - return Err(ArrowError::SchemaError("Float16 not supported".to_string())) - } - DataType::Float32 => self.read_primitive_list_values::(rows), - DataType::Float64 => self.read_primitive_list_values::(rows), - DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) => { - return Err(ArrowError::SchemaError( - "Temporal types are not yet supported, see ARROW-4803".to_string(), - )) - } - DataType::Utf8 => flatten_string_values(rows) - .into_iter() - .collect::() - .data() - .clone(), - DataType::LargeUtf8 => flatten_string_values(rows) - .into_iter() - .collect::() - .data() - .clone(), - DataType::List(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; - child.data().clone() - } - DataType::LargeList(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; - child.data().clone() - } - DataType::Struct(fields) => { - // extract list values, with non-lists converted to Value::Null - let array_item_count = rows - .iter() - .map(|row| match row { - Value::Array(values) => values.len(), - _ => 1, - }) - .sum(); - let num_bytes = bit_util::ceil(array_item_count, 8); - let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); - let mut struct_index = 0; - let rows: Vec> = rows - .iter() - .map(|row| { - if let Value::Array(values) = row { - values.iter().for_each(|_| { - bit_util::set_bit( - null_buffer.as_slice_mut(), - struct_index, - ); - struct_index += 1; - }); - values - .iter() - .map(|v| ("".to_string(), v.clone())) - .collect::>() - } else { - struct_index += 1; - vec![("null".to_string(), Value::Null)] - } - }) - .collect(); - let rows = rows.iter().collect::>>(); - let arrays = - self.build_struct_array(rows.as_slice(), fields.as_slice(), &[])?; - let data_type = DataType::Struct(fields.clone()); - let buf = null_buffer.into(); - ArrayDataBuilder::new(data_type) - .len(rows.len()) - .null_bit_buffer(buf) - .child_data(arrays.into_iter().map(|a| a.data().clone()).collect()) - .build() - .unwrap() - } - datatype => { - return Err(ArrowError::SchemaError(format!( - "Nested list of {:?} not supported", - datatype - ))); - } - }; - // build list - let list_data = ArrayData::builder(DataType::List(Box::new(list_field.clone()))) - .len(list_len) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_child_data(array_data) - .null_bit_buffer(list_nulls.into()) - .build() - .unwrap(); - Ok(Arc::new(GenericListArray::::from(list_data))) - } - - /// Builds the child values of a `StructArray`, falling short of constructing the StructArray. - /// The function does not construct the StructArray as some callers would want the child arrays. - /// - /// *Note*: The function is recursive, and will read nested structs. - /// - /// If `projection` is not empty, then all values are returned. The first level of projection - /// occurs at the `RecordBatch` level. No further projection currently occurs, but would be - /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
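// Editor's sketch, not part of the patch: how the removed reader derived its
// projected output schema, reconstructed from the deleted `next_batch` above.
// An empty projection keeps every field; otherwise fields are looked up by name
// via `Schema::column_with_name`, which returns `Option<(usize, &Field)>`.
fn projected_schema(schema: &Schema, projection: &[String]) -> Schema {
    let fields: Vec<Field> = if projection.is_empty() {
        schema.fields().to_vec()
    } else {
        projection
            .iter()
            .filter_map(|name| schema.column_with_name(name))
            .map(|(_, field)| field.clone())
            .collect()
    };
    Schema::new(fields)
}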
- fn build_struct_array( - &self, - rows: RecordSlice, - struct_fields: &[Field], - projection: &[String], - ) -> ArrowResult> { - let arrays: ArrowResult> = struct_fields - .iter() - .filter(|field| projection.is_empty() || projection.contains(field.name())) - .map(|field| { - match field.data_type() { - DataType::Null => { - Ok(Arc::new(NullArray::new(rows.len())) as ArrayRef) - } - DataType::Boolean => self.build_boolean_array(rows, field.name()), - DataType::Float64 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Float32 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Int64 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Int32 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Int16 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Int8 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::UInt64 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::UInt32 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::UInt16 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::UInt8 => { - self.build_primitive_array::(rows, field.name()) - } - // TODO: this is incomplete - DataType::Timestamp(unit, _) => match unit { - TimeUnit::Second => self - .build_primitive_array::( - rows, - field.name(), - ), - TimeUnit::Microsecond => self - .build_primitive_array::( - rows, - field.name(), - ), - TimeUnit::Millisecond => self - .build_primitive_array::( - rows, - field.name(), - ), - TimeUnit::Nanosecond => self - .build_primitive_array::( - rows, - field.name(), - ), - }, - DataType::Date64 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Date32 => { - self.build_primitive_array::(rows, field.name()) - } - DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => self - .build_primitive_array::( - rows, - field.name(), - ), - TimeUnit::Nanosecond => self - .build_primitive_array::( - rows, - field.name(), - ), - t => Err(ArrowError::SchemaError(format!( - "TimeUnit {:?} not supported with Time64", - t - ))), - }, - DataType::Time32(unit) => match unit { - TimeUnit::Second => self - .build_primitive_array::( - rows, - field.name(), - ), - TimeUnit::Millisecond => self - .build_primitive_array::( - rows, - field.name(), - ), - t => Err(ArrowError::SchemaError(format!( - "TimeUnit {:?} not supported with Time32", - t - ))), - }, - DataType::Utf8 | DataType::LargeUtf8 => Ok(Arc::new( - rows.iter() - .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); - maybe_value - .map(|value| resolve_string(&value)) - .transpose() - }) - .collect::>()?, - ) - as ArrayRef), - DataType::Binary | DataType::LargeBinary => Ok(Arc::new( - rows.iter() - .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); - maybe_value.and_then(resolve_bytes) - }) - .collect::(), - ) - as ArrayRef), - DataType::List(ref list_field) => { - match list_field.data_type() { - DataType::Dictionary(ref key_ty, _) => { - self.build_wrapped_list_array(rows, field.name(), key_ty) - } - _ => { - // extract rows by name - let extracted_rows = rows - .iter() - .map(|row| { - self.field_lookup(field.name(), row) - .unwrap_or(&Value::Null) - }) - .collect::>(); - self.build_nested_list_array::( - extracted_rows.as_slice(), - list_field, - ) - } - } - } - DataType::Dictionary(ref key_ty, ref val_ty) => self - .build_string_dictionary_array( - rows, - field.name(), - key_ty, - val_ty, - ), - 
DataType::Struct(fields) => { - let len = rows.len(); - let num_bytes = bit_util::ceil(len, 8); - let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); - let struct_rows = rows - .iter() - .enumerate() - .map(|(i, row)| (i, self.field_lookup(field.name(), row))) - .map(|(i, v)| { - if let Some(Value::Record(value)) = v { - bit_util::set_bit(null_buffer.as_slice_mut(), i); - value - } else { - panic!("expected struct got {:?}", v); - } - }) - .collect::>>(); - let arrays = - self.build_struct_array(struct_rows.as_slice(), fields, &[])?; - // construct a struct array's data in order to set null buffer - let data_type = DataType::Struct(fields.clone()); - let data = ArrayDataBuilder::new(data_type) - .len(len) - .null_bit_buffer(null_buffer.into()) - .child_data( - arrays.into_iter().map(|a| a.data().clone()).collect(), - ) - .build() - .unwrap(); - Ok(make_array(data)) - } - _ => Err(ArrowError::SchemaError(format!( - "type {:?} not supported", - field.data_type() - ))), - } - }) - .collect(); - arrays - } - - /// Read the primitive list's values into ArrayData - fn read_primitive_list_values(&self, rows: &[&Value]) -> ArrayData - where - T: ArrowPrimitiveType + ArrowNumericType, - T::Native: num_traits::cast::NumCast, - { - let values = rows - .iter() - .flat_map(|row| { - let row = maybe_resolve_union(row); - if let Value::Array(values) = row { - values - .iter() - .map(resolve_item::) - .collect::>>() - } else if let Some(f) = resolve_item::(row) { - vec![Some(f)] - } else { - vec![] - } - }) - .collect::>>(); - let array = values.iter().collect::>(); - array.data().clone() - } - - fn field_lookup<'b>( - &self, - name: &str, - row: &'b [(String, Value)], - ) -> Option<&'b Value> { - self.schema_lookup - .get(name) - .and_then(|i| row.get(*i)) - .map(|o| &o.1) - } -} - -/// Flattens a list of Avro values, by flattening lists, and treating all other values as -/// single-value lists. -/// This is used to read into nested lists (list of list, list of struct) and non-dictionary lists. -#[inline] -fn flatten_values<'a>(values: &[&'a Value]) -> Vec<&'a Value> { - values - .iter() - .flat_map(|row| { - let v = maybe_resolve_union(row); - if let Value::Array(values) = v { - values.iter().collect() - } else { - // we interpret a scalar as a single-value list to minimise data loss - vec![v] - } - }) - .collect() -} - -/// Flattens a list into string values, dropping Value::Null in the process. -/// This is useful for interpreting any Avro array as string, dropping nulls. -/// See `value_as_string`. -#[inline] -fn flatten_string_values(values: &[&Value]) -> Vec> { - values - .iter() - .flat_map(|row| { - if let Value::Array(values) = row { - values - .iter() - .map(|s| resolve_string(s).ok()) - .collect::>>() - } else if let Value::Null = row { - vec![] - } else { - vec![resolve_string(row).ok()] - } - }) - .collect::>>() -} - -/// Reads an Avro value as a string, regardless of its type. -/// This is useful if the expected datatype is a string, in which case we preserve -/// all the values regardless of they type. 
-fn resolve_string(v: &Value) -> ArrowResult { - let v = if let Value::Union(b) = v { b } else { v }; - match v { - Value::String(s) => Ok(s.clone()), - Value::Bytes(bytes) => { - String::from_utf8(bytes.to_vec()).map_err(AvroError::ConvertToUtf8) - } - other => Err(AvroError::GetString(other.into())), - } - .map_err(|e| SchemaError(format!("expected resolvable string : {}", e))) -} - -fn resolve_u8(v: &Value) -> AvroResult { - let int = match v { - Value::Int(n) => Ok(Value::Int(*n)), - Value::Long(n) => Ok(Value::Int(*n as i32)), - other => Err(AvroError::GetU8(other.into())), - }?; - if let Value::Int(n) = int { - if n >= 0 && n <= std::convert::From::from(u8::MAX) { - return Ok(n as u8); - } - } - - Err(AvroError::GetU8(int.into())) -} - -fn resolve_bytes(v: &Value) -> Option> { - let v = if let Value::Union(b) = v { b } else { v }; - match v { - Value::Bytes(_) => Ok(v.clone()), - Value::String(s) => Ok(Value::Bytes(s.clone().into_bytes())), - Value::Array(items) => Ok(Value::Bytes( - items - .iter() - .map(resolve_u8) - .collect::, _>>() - .ok()?, - )), - other => Err(AvroError::GetBytes(other.into())), - } - .ok() - .and_then(|v| match v { - Value::Bytes(s) => Some(s), - _ => None, - }) -} - -fn resolve_boolean(value: &Value) -> Option { - let v = if let Value::Union(b) = value { - b - } else { - value - }; - match v { - Value::Boolean(boolean) => Some(*boolean), - _ => None, - } -} - -trait Resolver: ArrowPrimitiveType { - fn resolve(value: &Value) -> Option; -} - -fn resolve_item(value: &Value) -> Option { - T::resolve(value) -} - -fn maybe_resolve_union(value: &Value) -> &Value { - if SchemaKind::from(value) == SchemaKind::Union { - // Pull out the Union, and attempt to resolve against it. - match value { - Value::Union(b) => b, - _ => unreachable!(), - } - } else { - value - } -} - -impl Resolver for N -where - N: ArrowNumericType, - N::Native: num_traits::cast::NumCast, -{ - fn resolve(value: &Value) -> Option { - let value = maybe_resolve_union(value); - match value { - Value::Int(i) | Value::TimeMillis(i) | Value::Date(i) => NumCast::from(*i), - Value::Long(l) - | Value::TimeMicros(l) - | Value::TimestampMillis(l) - | Value::TimestampMicros(l) => NumCast::from(*l), - Value::Float(f) => NumCast::from(*f), - Value::Double(f) => NumCast::from(*f), - Value::Duration(_d) => unimplemented!(), // shenanigans type - Value::Null => None, - _ => unreachable!(), + Ok(None) } } } @@ -985,7 +75,7 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::avro_to_arrow::{Reader, ReaderBuilder}; - use arrow::array::{Int32Array, Int64Array, ListArray, TimestampMicrosecondArray}; + use arrow::array::{Int32Array, Int64Array, ListArray}; use arrow::datatypes::DataType; use std::fs::File; @@ -1009,18 +99,18 @@ mod test { assert_eq!(8, batch.num_rows()); let schema = reader.schema(); - let batch_schema = batch.schema(); + let batch_schema = batch.schema().clone(); assert_eq!(schema, batch_schema); let timestamp_col = schema.column_with_name("timestamp_col").unwrap(); assert_eq!( - &DataType::Timestamp(TimeUnit::Microsecond, None), + &DataType::Timestamp(TimeUnit::Microsecond, Some("00:00".to_string())), timestamp_col.1.data_type() ); let timestamp_array = batch .column(timestamp_col.0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); for i in 0..timestamp_array.len() { assert!(timestamp_array.is_valid(i)); @@ -1046,11 +136,11 @@ mod test { let a_array = batch .column(col_id_index) .as_any() - .downcast_ref::() + .downcast_ref::>() 
.unwrap(); assert_eq!( *a_array.data_type(), - DataType::List(Box::new(Field::new("bigint", DataType::Int64, true))) + DataType::List(Box::new(Field::new("item", DataType::Int64, true))) ); let array = a_array.value(0); assert_eq!(*array.data_type(), DataType::Int64); @@ -1088,7 +178,7 @@ mod test { assert_eq!(11, batch.num_columns()); sum_num_rows += batch.num_rows(); num_batches += 1; - let batch_schema = batch.schema(); + let batch_schema = batch.schema().clone(); assert_eq!(schema, batch_schema); let a_array = batch .column(col_id_index) @@ -1098,7 +188,7 @@ mod test { sum_id += (0..a_array.len()).map(|i| a_array.value(i)).sum::(); } assert_eq!(8, sum_num_rows); - assert_eq!(2, num_batches); + assert_eq!(1, num_batches); assert_eq!(28, sum_id); } } diff --git a/datafusion/src/avro_to_arrow/mod.rs b/datafusion/src/avro_to_arrow/mod.rs index f30fbdcc0cec..5071c55bfe91 100644 --- a/datafusion/src/avro_to_arrow/mod.rs +++ b/datafusion/src/avro_to_arrow/mod.rs @@ -21,8 +21,6 @@ mod arrow_array_reader; #[cfg(feature = "avro")] mod reader; -#[cfg(feature = "avro")] -mod schema; use crate::arrow::datatypes::Schema; use crate::error::Result; @@ -33,9 +31,8 @@ use std::io::Read; #[cfg(feature = "avro")] /// Read Avro schema given a reader pub fn read_avro_schema_from_reader(reader: &mut R) -> Result { - let avro_reader = avro_rs::Reader::new(reader)?; - let schema = avro_reader.writer_schema(); - schema::to_arrow_schema(schema) + let (_, schema, _, _) = arrow::io::avro::read::read_metadata(reader)?; + Ok(schema) } #[cfg(not(feature = "avro"))] diff --git a/datafusion/src/avro_to_arrow/reader.rs b/datafusion/src/avro_to_arrow/reader.rs index 8baad14746d3..76f3672fc3a1 100644 --- a/datafusion/src/avro_to_arrow/reader.rs +++ b/datafusion/src/avro_to_arrow/reader.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
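// Editor's sketch, not part of the patch: the new arrow2-based read path, assembled
// only from the calls this diff introduces (`read::read_metadata`,
// `read::BlockStreamIterator`, `read::Decompressor`, and the batch reader they feed).
// `AvroReader` is assumed to be the alias used in arrow_array_reader.rs above for
// arrow2's Avro record-batch reader; any signature detail beyond what the patch
// shows is an assumption.
use arrow::io::avro::read;

fn example(mut file: std::fs::File) -> arrow::error::Result<()> {
    // read_metadata yields (avro schemas, arrow schema, codec, file marker),
    // in the order destructured by `ReaderBuilder::build` below.
    let (avro_schemas, schema, codec, file_marker) = read::read_metadata(&mut file)?;
    let reader = AvroReader::new(
        read::Decompressor::new(read::BlockStreamIterator::new(file, file_marker), codec),
        avro_schemas,
        std::sync::Arc::new(schema),
    );
    // The reader iterates over ArrowResult<RecordBatch>.
    for batch in reader {
        println!("read {} rows", batch?.num_rows());
    }
    Ok(())
}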
-use super::arrow_array_reader::AvroArrowArrayReader; +use super::arrow_array_reader::AvroBatchReader; use crate::arrow::datatypes::SchemaRef; use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use arrow::error::Result as ArrowResult; +use arrow::io::avro::{read, Compression}; use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; @@ -56,11 +57,9 @@ impl ReaderBuilder { /// # Example /// /// ``` - /// extern crate avro_rs; - /// /// use std::fs::File; /// - /// fn example() -> crate::datafusion::avro_to_arrow::Reader<'static, File> { + /// fn example() -> crate::datafusion::avro_to_arrow::Reader { /// let file = File::open("test/data/basic.avro").unwrap(); /// /// // create a builder, inferring the schema with the first 100 records @@ -101,30 +100,49 @@ impl ReaderBuilder { } /// Create a new `Reader` from the `ReaderBuilder` - pub fn build<'a, R>(self, source: R) -> Result> + pub fn build(self, source: R) -> Result> where R: Read + Seek, { let mut source = source; // check if schema should be inferred - let schema = match self.schema { - Some(schema) => schema, - None => Arc::new(super::read_avro_schema_from_reader(&mut source)?), - }; source.seek(SeekFrom::Start(0))?; - Reader::try_new(source, schema, self.batch_size, self.projection) + let (mut avro_schemas, mut schema, codec, file_marker) = + read::read_metadata(&mut source)?; + if let Some(proj) = self.projection { + let indices: Vec = schema + .fields + .iter() + .filter(|f| !proj.contains(&f.name)) + .enumerate() + .map(|(i, _)| i) + .collect(); + for i in indices { + avro_schemas.remove(i); + schema.fields.remove(i); + } + } + + Reader::try_new( + source, + Arc::new(schema), + self.batch_size, + avro_schemas, + codec, + file_marker, + ) } } /// Avro file record reader -pub struct Reader<'a, R: Read> { - array_reader: AvroArrowArrayReader<'a, R>, +pub struct Reader { + array_reader: AvroBatchReader, schema: SchemaRef, batch_size: usize, } -impl<'a, R: Read> Reader<'a, R> { +impl<'a, R: Read> Reader { /// Create a new Avro Reader from any value that implements the `Read` trait. 
/// /// If reading a `File`, you can customise the Reader, such as to enable schema @@ -133,13 +151,17 @@ impl<'a, R: Read> Reader<'a, R> { reader: R, schema: SchemaRef, batch_size: usize, - projection: Option>, + avro_schemas: Vec, + codec: Option, + file_marker: [u8; 16], ) -> Result { Ok(Self { - array_reader: AvroArrowArrayReader::try_new( + array_reader: AvroBatchReader::try_new( reader, schema.clone(), - projection, + avro_schemas, + codec, + file_marker, )?, schema, batch_size, @@ -160,7 +182,7 @@ impl<'a, R: Read> Reader<'a, R> { } } -impl<'a, R: Read> Iterator for Reader<'a, R> { +impl<'a, R: Read> Iterator for Reader { type Item = ArrowResult; fn next(&mut self) -> Option { @@ -200,7 +222,7 @@ mod tests { let schema = reader.schema(); let batch_schema = batch.schema(); - assert_eq!(schema, batch_schema); + assert_eq!(schema, batch_schema.clone()); let id = schema.column_with_name("id").unwrap(); assert_eq!(0, id.0); @@ -259,22 +281,22 @@ mod tests { let date_string_col = schema.column_with_name("date_string_col").unwrap(); assert_eq!(8, date_string_col.0); assert_eq!(&DataType::Binary, date_string_col.1.data_type()); - let col = get_col::(&batch, date_string_col).unwrap(); + let col = get_col::>(&batch, date_string_col).unwrap(); assert_eq!("01/01/09".as_bytes(), col.value(0)); assert_eq!("01/01/09".as_bytes(), col.value(1)); let string_col = schema.column_with_name("string_col").unwrap(); assert_eq!(9, string_col.0); assert_eq!(&DataType::Binary, string_col.1.data_type()); - let col = get_col::(&batch, string_col).unwrap(); + let col = get_col::>(&batch, string_col).unwrap(); assert_eq!("0".as_bytes(), col.value(0)); assert_eq!("1".as_bytes(), col.value(1)); let timestamp_col = schema.column_with_name("timestamp_col").unwrap(); assert_eq!(10, timestamp_col.0); assert_eq!( - &DataType::Timestamp(TimeUnit::Microsecond, None), + &DataType::Timestamp(TimeUnit::Microsecond, Some("00:00".to_string())), timestamp_col.1.data_type() ); - let col = get_col::(&batch, timestamp_col).unwrap(); + let col = get_col::(&batch, timestamp_col).unwrap(); assert_eq!(1230768000000000, col.value(0)); assert_eq!(1230768060000000, col.value(1)); } diff --git a/datafusion/src/avro_to_arrow/schema.rs b/datafusion/src/avro_to_arrow/schema.rs deleted file mode 100644 index c6eda8017012..000000000000 --- a/datafusion/src/avro_to_arrow/schema.rs +++ /dev/null @@ -1,465 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
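// Editor's note, not part of the patch: the hand-written Avro-to-Arrow schema
// conversion deleted below is superseded by `read::read_metadata`, which returns an
// Arrow schema directly. Its key rule, shown here as a minimal sketch, was that a
// two-variant union containing "null" maps to a nullable field of the other
// variant's type, e.g. ["int", "null"] -> Field::new("id", DataType::Int32, true),
// matching the expectations in the tests removed further down.
fn nullable_union_field(name: &str, non_null_variant: DataType) -> Field {
    // nullable = true because one of the two union variants is Avro "null"
    Field::new(name, non_null_variant, true)
}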
- -use crate::arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit}; -use crate::error::{DataFusionError, Result}; -use arrow::datatypes::Field; -use avro_rs::schema::Name; -use avro_rs::types::Value; -use avro_rs::Schema as AvroSchema; -use std::collections::BTreeMap; -use std::convert::TryFrom; - -/// Converts an avro schema to an arrow schema -pub fn to_arrow_schema(avro_schema: &avro_rs::Schema) -> Result { - let mut schema_fields = vec![]; - match avro_schema { - AvroSchema::Record { fields, .. } => { - for field in fields { - schema_fields.push(schema_to_field_with_props( - &field.schema, - Some(&field.name), - false, - Some(&external_props(&field.schema)), - )?) - } - } - schema => schema_fields.push(schema_to_field(schema, Some(""), false)?), - } - - let schema = Schema::new(schema_fields); - Ok(schema) -} - -fn schema_to_field( - schema: &avro_rs::Schema, - name: Option<&str>, - nullable: bool, -) -> Result { - schema_to_field_with_props(schema, name, nullable, None) -} - -fn schema_to_field_with_props( - schema: &AvroSchema, - name: Option<&str>, - nullable: bool, - props: Option<&BTreeMap>, -) -> Result { - let mut nullable = nullable; - let field_type: DataType = match schema { - AvroSchema::Null => DataType::Null, - AvroSchema::Boolean => DataType::Boolean, - AvroSchema::Int => DataType::Int32, - AvroSchema::Long => DataType::Int64, - AvroSchema::Float => DataType::Float32, - AvroSchema::Double => DataType::Float64, - AvroSchema::Bytes => DataType::Binary, - AvroSchema::String => DataType::Utf8, - AvroSchema::Array(item_schema) => DataType::List(Box::new( - schema_to_field_with_props(item_schema, None, false, None)?, - )), - AvroSchema::Map(value_schema) => { - let value_field = - schema_to_field_with_props(value_schema, Some("value"), false, None)?; - DataType::Dictionary( - Box::new(DataType::Utf8), - Box::new(value_field.data_type().clone()), - ) - } - AvroSchema::Union(us) => { - // If there are only two variants and one of them is null, set the other type as the field data type - let has_nullable = us.find_schema(&Value::Null).is_some(); - let sub_schemas = us.variants(); - if has_nullable && sub_schemas.len() == 2 { - nullable = true; - if let Some(schema) = sub_schemas - .iter() - .find(|&schema| !matches!(schema, AvroSchema::Null)) - { - schema_to_field_with_props(schema, None, has_nullable, None)? - .data_type() - .clone() - } else { - return Err(DataFusionError::AvroError( - avro_rs::Error::GetUnionDuplicate, - )); - } - } else { - let fields = sub_schemas - .iter() - .map(|s| schema_to_field_with_props(s, None, has_nullable, None)) - .collect::>>()?; - DataType::Union(fields) - } - } - AvroSchema::Record { name, fields, .. } => { - let fields: Result> = fields - .iter() - .map(|field| { - let mut props = BTreeMap::new(); - if let Some(doc) = &field.doc { - props.insert("avro::doc".to_string(), doc.clone()); - } - /*if let Some(aliases) = fields.aliases { - props.insert("aliases", aliases); - }*/ - schema_to_field_with_props( - &field.schema, - Some(&format!("{}.{}", name.fullname(None), field.name)), - false, - Some(&props), - ) - }) - .collect(); - DataType::Struct(fields?) - } - AvroSchema::Enum { symbols, name, .. } => { - return Ok(Field::new_dict( - &name.fullname(None), - index_type(symbols.len()), - false, - 0, - false, - )) - } - AvroSchema::Fixed { size, .. } => DataType::FixedSizeBinary(*size as i32), - AvroSchema::Decimal { - precision, scale, .. 
- } => DataType::Decimal(*precision, *scale), - AvroSchema::Uuid => DataType::FixedSizeBinary(16), - AvroSchema::Date => DataType::Date32, - AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond), - AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond), - AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None), - AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None), - AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond), - }; - - let data_type = field_type.clone(); - let name = name.unwrap_or_else(|| default_field_name(&data_type)); - - let mut field = Field::new(name, field_type, nullable); - field.set_metadata(props.cloned()); - Ok(field) -} - -fn default_field_name(dt: &DataType) -> &str { - match dt { - DataType::Null => "null", - DataType::Boolean => "bit", - DataType::Int8 => "tinyint", - DataType::Int16 => "smallint", - DataType::Int32 => "int", - DataType::Int64 => "bigint", - DataType::UInt8 => "uint1", - DataType::UInt16 => "uint2", - DataType::UInt32 => "uint4", - DataType::UInt64 => "uint8", - DataType::Float16 => "float2", - DataType::Float32 => "float4", - DataType::Float64 => "float8", - DataType::Date32 => "dateday", - DataType::Date64 => "datemilli", - DataType::Time32(tu) | DataType::Time64(tu) => match tu { - TimeUnit::Second => "timesec", - TimeUnit::Millisecond => "timemilli", - TimeUnit::Microsecond => "timemicro", - TimeUnit::Nanosecond => "timenano", - }, - DataType::Timestamp(tu, tz) => { - if tz.is_some() { - match tu { - TimeUnit::Second => "timestampsectz", - TimeUnit::Millisecond => "timestampmillitz", - TimeUnit::Microsecond => "timestampmicrotz", - TimeUnit::Nanosecond => "timestampnanotz", - } - } else { - match tu { - TimeUnit::Second => "timestampsec", - TimeUnit::Millisecond => "timestampmilli", - TimeUnit::Microsecond => "timestampmicro", - TimeUnit::Nanosecond => "timestampnano", - } - } - } - DataType::Duration(_) => "duration", - DataType::Interval(unit) => match unit { - IntervalUnit::YearMonth => "intervalyear", - IntervalUnit::DayTime => "intervalmonth", - }, - DataType::Binary => "varbinary", - DataType::FixedSizeBinary(_) => "fixedsizebinary", - DataType::LargeBinary => "largevarbinary", - DataType::Utf8 => "varchar", - DataType::LargeUtf8 => "largevarchar", - DataType::List(_) => "list", - DataType::FixedSizeList(_, _) => "fixed_size_list", - DataType::LargeList(_) => "largelist", - DataType::Struct(_) => "struct", - DataType::Union(_) => "union", - DataType::Dictionary(_, _) => "map", - DataType::Map(_, _) => unimplemented!("Map support not implemented"), - DataType::Decimal(_, _) => "decimal", - } -} - -fn index_type(len: usize) -> DataType { - if len <= usize::from(u8::MAX) { - DataType::Int8 - } else if len <= usize::from(u16::MAX) { - DataType::Int16 - } else if usize::try_from(u32::MAX).map(|i| len < i).unwrap_or(false) { - DataType::Int32 - } else { - DataType::Int64 - } -} - -fn external_props(schema: &AvroSchema) -> BTreeMap { - let mut props = BTreeMap::new(); - match &schema { - AvroSchema::Record { - doc: Some(ref doc), .. - } - | AvroSchema::Enum { - doc: Some(ref doc), .. - } => { - props.insert("avro::doc".to_string(), doc.clone()); - } - _ => {} - } - match &schema { - AvroSchema::Record { - name: - Name { - aliases: Some(aliases), - namespace, - .. - }, - .. - } - | AvroSchema::Enum { - name: - Name { - aliases: Some(aliases), - namespace, - .. - }, - .. - } - | AvroSchema::Fixed { - name: - Name { - aliases: Some(aliases), - namespace, - .. 
- }, - .. - } => { - let aliases: Vec = aliases - .iter() - .map(|alias| aliased(alias, namespace.as_deref(), None)) - .collect(); - props.insert( - "avro::aliases".to_string(), - format!("[{}]", aliases.join(",")), - ); - } - _ => {} - } - props -} - -#[allow(dead_code)] -fn get_metadata( - _schema: AvroSchema, - props: BTreeMap, -) -> BTreeMap { - let mut metadata: BTreeMap = Default::default(); - metadata.extend(props); - metadata -} - -/// Returns the fully qualified name for a field -pub fn aliased( - name: &str, - namespace: Option<&str>, - default_namespace: Option<&str>, -) -> String { - if name.contains('.') { - name.to_string() - } else { - let namespace = namespace.as_ref().copied().or(default_namespace); - - match namespace { - Some(ref namespace) => format!("{}.{}", namespace, name), - None => name.to_string(), - } - } -} - -#[cfg(test)] -mod test { - use super::{aliased, external_props, to_arrow_schema}; - use crate::arrow::datatypes::DataType::{Binary, Float32, Float64, Timestamp, Utf8}; - use crate::arrow::datatypes::TimeUnit::Microsecond; - use crate::arrow::datatypes::{Field, Schema}; - use arrow::datatypes::DataType::{Boolean, Int32, Int64}; - use avro_rs::schema::Name; - use avro_rs::Schema as AvroSchema; - - #[test] - fn test_alias() { - assert_eq!(aliased("foo.bar", None, None), "foo.bar"); - assert_eq!(aliased("bar", Some("foo"), None), "foo.bar"); - assert_eq!(aliased("bar", Some("foo"), Some("cat")), "foo.bar"); - assert_eq!(aliased("bar", None, Some("cat")), "cat.bar"); - } - - #[test] - fn test_external_props() { - let record_schema = AvroSchema::Record { - name: Name { - name: "record".to_string(), - namespace: None, - aliases: Some(vec!["fooalias".to_string(), "baralias".to_string()]), - }, - doc: Some("record documentation".to_string()), - fields: vec![], - lookup: Default::default(), - }; - let props = external_props(&record_schema); - assert_eq!( - props.get("avro::doc"), - Some(&"record documentation".to_string()) - ); - assert_eq!( - props.get("avro::aliases"), - Some(&"[fooalias,baralias]".to_string()) - ); - let enum_schema = AvroSchema::Enum { - name: Name { - name: "enum".to_string(), - namespace: None, - aliases: Some(vec!["fooenum".to_string(), "barenum".to_string()]), - }, - doc: Some("enum documentation".to_string()), - symbols: vec![], - }; - let props = external_props(&enum_schema); - assert_eq!( - props.get("avro::doc"), - Some(&"enum documentation".to_string()) - ); - assert_eq!( - props.get("avro::aliases"), - Some(&"[fooenum,barenum]".to_string()) - ); - let fixed_schema = AvroSchema::Fixed { - name: Name { - name: "fixed".to_string(), - namespace: None, - aliases: Some(vec!["foofixed".to_string(), "barfixed".to_string()]), - }, - size: 1, - }; - let props = external_props(&fixed_schema); - assert_eq!( - props.get("avro::aliases"), - Some(&"[foofixed,barfixed]".to_string()) - ); - } - - #[test] - fn test_invalid_avro_schema() {} - - #[test] - fn test_plain_types_schema() { - let schema = AvroSchema::parse_str( - r#" - { - "type" : "record", - "name" : "topLevelRecord", - "fields" : [ { - "name" : "id", - "type" : [ "int", "null" ] - }, { - "name" : "bool_col", - "type" : [ "boolean", "null" ] - }, { - "name" : "tinyint_col", - "type" : [ "int", "null" ] - }, { - "name" : "smallint_col", - "type" : [ "int", "null" ] - }, { - "name" : "int_col", - "type" : [ "int", "null" ] - }, { - "name" : "bigint_col", - "type" : [ "long", "null" ] - }, { - "name" : "float_col", - "type" : [ "float", "null" ] - }, { - "name" : "double_col", - "type" : [ 
"double", "null" ] - }, { - "name" : "date_string_col", - "type" : [ "bytes", "null" ] - }, { - "name" : "string_col", - "type" : [ "bytes", "null" ] - }, { - "name" : "timestamp_col", - "type" : [ { - "type" : "long", - "logicalType" : "timestamp-micros" - }, "null" ] - } ] - }"#, - ); - assert!(schema.is_ok(), "{:?}", schema); - let arrow_schema = to_arrow_schema(&schema.unwrap()); - assert!(arrow_schema.is_ok(), "{:?}", arrow_schema); - let expected = Schema::new(vec![ - Field::new("id", Int32, true), - Field::new("bool_col", Boolean, true), - Field::new("tinyint_col", Int32, true), - Field::new("smallint_col", Int32, true), - Field::new("int_col", Int32, true), - Field::new("bigint_col", Int64, true), - Field::new("float_col", Float32, true), - Field::new("double_col", Float64, true), - Field::new("date_string_col", Binary, true), - Field::new("string_col", Binary, true), - Field::new("timestamp_col", Timestamp(Microsecond, None), true), - ]); - assert_eq!(arrow_schema.unwrap(), expected); - } - - #[test] - fn test_non_record_schema() { - let arrow_schema = to_arrow_schema(&AvroSchema::String); - assert!(arrow_schema.is_ok(), "{:?}", arrow_schema); - assert_eq!( - arrow_schema.unwrap(), - Schema::new(vec![Field::new("", Utf8, false)]) - ); - } -} diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index 515584b16c03..190c893d3e4c 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -82,8 +82,7 @@ mod tests { use super::*; use arrow::array::{ - BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, - TimestampMicrosecondArray, + BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, UInt64Array, }; use futures::StreamExt; @@ -235,9 +234,9 @@ mod tests { let array = batches[0] .column(0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); - let mut values: Vec = vec![]; + let mut values: Vec = vec![]; for i in 0..batches[0].num_rows() { values.push(array.value(i)); } @@ -316,7 +315,7 @@ mod tests { let array = batches[0] .column(0) .as_any() - .downcast_ref::() + .downcast_ref::>() .unwrap(); let mut values: Vec<&str> = vec![]; for i in 0..batches[0].num_rows() { diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index 1edbffc91da9..45c3d3af1195 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -57,17 +57,17 @@ impl FileFormat for JsonFormat { } async fn infer_schema(&self, mut readers: ObjectReaderStream) -> Result { - let mut schemas = Vec::new(); + let mut fields = Vec::new(); let records_to_read = self.schema_infer_max_rec; while let Some(obj_reader) = readers.next().await { let mut reader = std::io::BufReader::new(obj_reader?.sync_reader()?); // FIXME: return number of records read from infer_json_schema so we can enforce // records_to_read - let schema = json::infer_json_schema(&mut reader, records_to_read)?; - schemas.push(schema); + let schema = json::read::infer(&mut reader, records_to_read)?; + fields.extend(schema); } - let schema = Schema::try_merge(schemas)?; + let schema = Schema::new(fields); Ok(Arc::new(schema)) } @@ -158,7 +158,7 @@ mod tests { let projection = Some(vec![0]); let exec = get_exec(&projection, 1024, None).await?; - let batches = collect(exec).await.expect("Collect batches"); + let batches = collect(exec).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); diff --git 
a/datafusion/src/error.rs b/datafusion/src/error.rs index a47bfac8b622..b5676669df00 100644 --- a/datafusion/src/error.rs +++ b/datafusion/src/error.rs @@ -23,8 +23,6 @@ use std::io; use std::result; use arrow::error::ArrowError; -#[cfg(feature = "avro")] -use avro_rs::Error as AvroError; use parquet::error::ParquetError; use sqlparser::parser::ParserError; @@ -39,9 +37,6 @@ pub enum DataFusionError { ArrowError(ArrowError), /// Wraps an error from the Parquet crate ParquetError(ParquetError), - /// Wraps an error from the Avro crate - #[cfg(feature = "avro")] - AvroError(AvroError), /// Error associated to I/O operations and associated traits. IoError(io::Error), /// Error returned when SQL is syntactically incorrect. @@ -88,13 +83,6 @@ impl From for DataFusionError { } } -#[cfg(feature = "avro")] -impl From for DataFusionError { - fn from(e: AvroError) -> Self { - DataFusionError::AvroError(e) - } -} - impl From for DataFusionError { fn from(e: ParserError) -> Self { DataFusionError::SQL(e) @@ -108,10 +96,6 @@ impl Display for DataFusionError { DataFusionError::ParquetError(ref desc) => { write!(f, "Parquet error: {}", desc) } - #[cfg(feature = "avro")] - DataFusionError::AvroError(ref desc) => { - write!(f, "Avro error: {}", desc) - } DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc), DataFusionError::SQL(ref desc) => { write!(f, "SQL error: {:?}", desc) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6f72380b7227..89ea4380e1c0 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1899,9 +1899,9 @@ mod tests { #[ignore] async fn aggregate_decimal_min() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") .await .unwrap(); @@ -1912,6 +1912,10 @@ mod tests { "| -100.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } @@ -1920,6 +1924,7 @@ mod tests { #[ignore] async fn aggregate_decimal_max() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1933,6 +1938,58 @@ mod tests { "| 110.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_sum() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| SUM(d_table.c1) |", + "+-----------------+", + "| 100.0 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(20, 3), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table") + 
.await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| AVG(d_table.c1) |", + "+-----------------+", + "| 5.0 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(14, 7), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index 448e2cd0cbe3..301925227722 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -78,6 +78,8 @@ pub trait StructArrayExt { fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef>; /// Return the number of fields in this struct array fn num_columns(&self) -> usize; + /// Return the column at the position + fn column(&self, pos: usize) -> ArrayRef; } impl StructArrayExt for StructArray { @@ -95,4 +97,15 @@ impl StructArrayExt for StructArray { fn num_columns(&self) -> usize { self.fields().len() } + + fn column(&self, pos: usize) -> ArrayRef { + self.values()[pos].clone() + } +} + +/// Converts a list of field / array pairs to a struct array +pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray { + let fields: Vec = pairs.iter().map(|v| v.0.clone()).collect(); + let values = pairs.iter().map(|v| v.1.clone()).collect(); + StructArray::from_data(DataType::Struct(fields), values, None) } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 14a619b0a6c4..9620236c3721 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -197,14 +197,12 @@ //! //! cargo run --example csv_sql //! -//! cargo run --example parquet_sql +//! PARQUET_TEST_DATA=./parquet-testing/data cargo run --example parquet_sql //! //! cargo run --example dataframe //! //! cargo run --example dataframe_in_memory //! -//! cargo run --example parquet_sql -//! //! cargo run --example simple_udaf //! //! 
cargo run --example simple_udf diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 90d2ae22241e..fc609390bcc0 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -46,8 +46,8 @@ use std::{ use super::dfschema::ToDFSchema; use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType}; use crate::logical_plan::{ - columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField, DFSchema, - DFSchemaRef, Limit, Partitioning, Repartition, Values, + columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column, + CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values, }; use crate::sql::utils::group_window_expr_by_sort_keys; @@ -521,6 +521,8 @@ impl LogicalPlanBuilder { &self, exprs: impl IntoIterator> + Clone, ) -> Result { + let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema @@ -530,7 +532,7 @@ impl LogicalPlanBuilder { .into_iter() .try_for_each::<_, Result<()>>(|expr| { let mut columns: HashSet = HashSet::new(); - utils::expr_to_columns(&expr.into(), &mut columns)?; + utils::expr_to_columns(&expr, &mut columns)?; columns.into_iter().for_each(|c| { if schema.field_from_column(&c).is_err() { diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 31143c4f616d..e8698b8b4f34 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -536,9 +536,10 @@ mod tests { fn from_qualified_schema_into_arrow_schema() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let arrow_schema: Schema = schema.into(); - let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; - assert_eq!(expected, arrow_schema.to_string()); + let expected = + "[Field { name: \"c0\", data_type: Boolean, nullable: true, metadata: {} }, \ + Field { name: \"c1\", data_type: Boolean, nullable: true, metadata: {} }]"; + assert_eq!(expected, format!("{:?}", arrow_schema.fields)); Ok(()) } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index eabb865ea008..5a55f398cdab 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -24,7 +24,9 @@ use arrow::{compute::cast::can_cast_types, datatypes::DataType}; use crate::error::{DataFusionError, Result}; use crate::field_util::get_indexed_field; -use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; +use crate::logical_plan::{ + plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, +}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -1326,7 +1328,6 @@ fn normalize_col_with_schemas( } /// Recursively normalize all Column expressions in a list of expression trees -#[inline] pub fn normalize_cols( exprs: impl IntoIterator>, plan: &LogicalPlan, @@ -1337,6 +1338,80 @@ pub fn normalize_cols( .collect() } +/// Rewrite sort on aggregate expressions to sort on the column of aggregate output +/// For example, `max(x)` is written to `col("MAX(x)")` +pub fn rewrite_sort_cols_by_aggs( + exprs: impl 
IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| { + let expr = e.into(); + match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let sort = Expr::Sort { + expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), + asc, + nulls_first, + }; + Ok(sort) + } + expr => Ok(expr), + } + }) + .collect() +} + +fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, aggr_expr, .. + }) => { + struct Rewriter<'a> { + plan: &'a LogicalPlan, + input: &'a LogicalPlan, + aggr_expr: &'a Vec, + } + + impl<'a> ExprRewriter for Rewriter<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + let normalized_expr = normalize_col(expr.clone(), self.plan); + if normalized_expr.is_err() { + // The expr is not based on Aggregate plan output. Skip it. + return Ok(expr); + } + let normalized_expr = normalized_expr.unwrap(); + if let Some(found_agg) = + self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + { + let agg = normalize_col(found_agg.clone(), self.plan)?; + let col = Expr::Column( + agg.to_field(self.input.schema()) + .map(|f| f.qualified_column())?, + ); + Ok(col) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut Rewriter { + plan, + input, + aggr_expr, + }) + } + LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), + _ => Ok(expr), + } +} + /// Recursively 'unnormalize' (remove all qualifiers) from an /// expression tree. /// @@ -1498,9 +1573,10 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond(Some( - (self.clone()).into(), - ))) + Expr::Literal(ScalarValue::TimestampNanosecond( + Some((self.clone()).into()), + None, + )) } } }; @@ -1584,7 +1660,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { /// Create an convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { - #[doc = "this scalar function is not documented yet"] + #[doc = concat!("Unary scalar function definition for ", stringify!($FUNC) ) ] pub fn $FUNC(e: Expr) -> Expr { Expr::ScalarFunction { fun: functions::BuiltinScalarFunction::$ENUM, @@ -1594,14 +1670,25 @@ macro_rules! unary_scalar_expr { }; } -/// Create an convenience function representing a binary scalar function -macro_rules! binary_scalar_expr { +macro_rules! scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + #[doc = concat!("Scalar function definition for ", stringify!($FUNC) ) ] + pub fn $FUNC($($arg: Expr),*) -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::$ENUM, + args: vec![$($arg),*], + } + } + }; +} + +macro_rules! 
nary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { - #[doc = "this scalar function is not documented yet"] - pub fn $FUNC(arg1: Expr, arg2: Expr) -> Expr { + #[doc = concat!("Scalar function definition for ", stringify!($FUNC) ) ] + pub fn $FUNC(args: Vec) -> Expr { Expr::ScalarFunction { fun: functions::BuiltinScalarFunction::$ENUM, - args: vec![arg1, arg2], + args, } } }; @@ -1630,44 +1717,44 @@ unary_scalar_expr!(Log10, log10); unary_scalar_expr!(Ln, ln); // string functions -unary_scalar_expr!(Ascii, ascii); -unary_scalar_expr!(BitLength, bit_length); -unary_scalar_expr!(Btrim, btrim); -unary_scalar_expr!(CharacterLength, character_length); -unary_scalar_expr!(CharacterLength, length); -unary_scalar_expr!(Chr, chr); -unary_scalar_expr!(InitCap, initcap); -unary_scalar_expr!(Left, left); -unary_scalar_expr!(Lower, lower); -unary_scalar_expr!(Lpad, lpad); -unary_scalar_expr!(Ltrim, ltrim); -unary_scalar_expr!(MD5, md5); -unary_scalar_expr!(OctetLength, octet_length); -unary_scalar_expr!(RegexpMatch, regexp_match); -unary_scalar_expr!(RegexpReplace, regexp_replace); -unary_scalar_expr!(Replace, replace); -unary_scalar_expr!(Repeat, repeat); -unary_scalar_expr!(Reverse, reverse); -unary_scalar_expr!(Right, right); -unary_scalar_expr!(Rpad, rpad); -unary_scalar_expr!(Rtrim, rtrim); -unary_scalar_expr!(SHA224, sha224); -unary_scalar_expr!(SHA256, sha256); -unary_scalar_expr!(SHA384, sha384); -unary_scalar_expr!(SHA512, sha512); -unary_scalar_expr!(SplitPart, split_part); -unary_scalar_expr!(StartsWith, starts_with); -unary_scalar_expr!(Strpos, strpos); -unary_scalar_expr!(Substr, substr); -unary_scalar_expr!(ToHex, to_hex); -unary_scalar_expr!(Translate, translate); -unary_scalar_expr!(Trim, trim); -unary_scalar_expr!(Upper, upper); +scalar_expr!(Ascii, ascii, string); +scalar_expr!(BitLength, bit_length, string); +nary_scalar_expr!(Btrim, btrim); +scalar_expr!(CharacterLength, character_length, string); +scalar_expr!(CharacterLength, length, string); +scalar_expr!(Chr, chr, string); +scalar_expr!(Digest, digest, string, algorithm); +scalar_expr!(InitCap, initcap, string); +scalar_expr!(Left, left, string, count); +scalar_expr!(Lower, lower, string); +nary_scalar_expr!(Lpad, lpad); +scalar_expr!(Ltrim, ltrim, string); +scalar_expr!(MD5, md5, string); +scalar_expr!(OctetLength, octet_length, string); +nary_scalar_expr!(RegexpMatch, regexp_match); +nary_scalar_expr!(RegexpReplace, regexp_replace); +scalar_expr!(Replace, replace, string, from, to); +scalar_expr!(Repeat, repeat, string, count); +scalar_expr!(Reverse, reverse, string); +scalar_expr!(Right, right, string, count); +nary_scalar_expr!(Rpad, rpad); +scalar_expr!(Rtrim, rtrim, string); +scalar_expr!(SHA224, sha224, string); +scalar_expr!(SHA256, sha256, string); +scalar_expr!(SHA384, sha384, string); +scalar_expr!(SHA512, sha512, string); +scalar_expr!(SplitPart, split_part, expr, delimiter, index); +scalar_expr!(StartsWith, starts_with, string, characters); +scalar_expr!(Strpos, strpos, string, substring); +scalar_expr!(Substr, substr, string, position); +scalar_expr!(ToHex, to_hex, string); +scalar_expr!(Translate, translate, string, from, to); +scalar_expr!(Trim, trim, string); +scalar_expr!(Upper, upper, string); // date functions -binary_scalar_expr!(DatePart, date_part); -binary_scalar_expr!(DateTrunc, date_trunc); -binary_scalar_expr!(Digest, digest); +scalar_expr!(DatePart, date_part, part, date); +scalar_expr!(DateTrunc, date_trunc, part, date); /// returns an array of fixed size with each argument on it. 
pub fn array(args: Vec) -> Expr { @@ -2057,7 +2144,8 @@ mod tests { #[test] fn test_lit_timestamp_nano() { let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 - let expected = col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10)))); + let expected = + col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); assert_eq!(expr, expected); let i: i64 = 10; @@ -2237,6 +2325,44 @@ mod tests { }}; } + macro_rules! test_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = vec![$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction { fun, args } = result { + let name = functions::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; + } + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = vec![$(stringify!($arg)),*]; + let result = $FUNC( + vec![ + $( + col(stringify!($arg.to_string())) + ),* + ] + ); + if let Expr::ScalarFunction { fun, args } = result { + let name = functions::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; + } + #[test] fn digest_function_definitions() { if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) { @@ -2268,39 +2394,62 @@ mod tests { test_unary_scalar_expr!(Log2, log2); test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); - test_unary_scalar_expr!(Ascii, ascii); - test_unary_scalar_expr!(BitLength, bit_length); - test_unary_scalar_expr!(Btrim, btrim); - test_unary_scalar_expr!(CharacterLength, character_length); - test_unary_scalar_expr!(CharacterLength, length); - test_unary_scalar_expr!(Chr, chr); - test_unary_scalar_expr!(InitCap, initcap); - test_unary_scalar_expr!(Left, left); - test_unary_scalar_expr!(Lower, lower); - test_unary_scalar_expr!(Lpad, lpad); - test_unary_scalar_expr!(Ltrim, ltrim); - test_unary_scalar_expr!(MD5, md5); - test_unary_scalar_expr!(OctetLength, octet_length); - test_unary_scalar_expr!(RegexpMatch, regexp_match); - test_unary_scalar_expr!(RegexpReplace, regexp_replace); - test_unary_scalar_expr!(Replace, replace); - test_unary_scalar_expr!(Repeat, repeat); - test_unary_scalar_expr!(Reverse, reverse); - test_unary_scalar_expr!(Right, right); - test_unary_scalar_expr!(Rpad, rpad); - test_unary_scalar_expr!(Rtrim, rtrim); - test_unary_scalar_expr!(SHA224, sha224); - test_unary_scalar_expr!(SHA256, sha256); - test_unary_scalar_expr!(SHA384, sha384); - test_unary_scalar_expr!(SHA512, sha512); - test_unary_scalar_expr!(SplitPart, split_part); - test_unary_scalar_expr!(StartsWith, starts_with); - test_unary_scalar_expr!(Strpos, strpos); - test_unary_scalar_expr!(Substr, substr); - test_unary_scalar_expr!(ToHex, to_hex); - test_unary_scalar_expr!(Translate, translate); - test_unary_scalar_expr!(Trim, trim); - test_unary_scalar_expr!(Upper, upper); + + test_scalar_expr!(Ascii, ascii, input); + test_scalar_expr!(BitLength, bit_length, string); + test_nary_scalar_expr!(Btrim, btrim, string); + test_nary_scalar_expr!(Btrim, btrim, string, characters); + test_scalar_expr!(CharacterLength, character_length, string); + test_scalar_expr!(CharacterLength, length, string); + test_scalar_expr!(Chr, chr, string); + test_scalar_expr!(Digest, digest, string, algorithm); + 
test_scalar_expr!(InitCap, initcap, string); + test_scalar_expr!(Left, left, string, count); + test_scalar_expr!(Lower, lower, string); + test_nary_scalar_expr!(Lpad, lpad, string, count); + test_nary_scalar_expr!(Lpad, lpad, string, count, characters); + test_scalar_expr!(Ltrim, ltrim, string); + test_scalar_expr!(MD5, md5, string); + test_scalar_expr!(OctetLength, octet_length, string); + test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern); + test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern, flags); + test_nary_scalar_expr!( + RegexpReplace, + regexp_replace, + string, + pattern, + replacement + ); + test_nary_scalar_expr!( + RegexpReplace, + regexp_replace, + string, + pattern, + replacement, + flags + ); + test_scalar_expr!(Replace, replace, string, from, to); + test_scalar_expr!(Repeat, repeat, string, count); + test_scalar_expr!(Reverse, reverse, string); + test_scalar_expr!(Right, right, string, count); + test_nary_scalar_expr!(Rpad, rpad, string, count); + test_nary_scalar_expr!(Rpad, rpad, string, count, characters); + test_scalar_expr!(Rtrim, rtrim, string); + test_scalar_expr!(SHA224, sha224, string); + test_scalar_expr!(SHA256, sha256, string); + test_scalar_expr!(SHA384, sha384, string); + test_scalar_expr!(SHA512, sha512, string); + test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); + test_scalar_expr!(StartsWith, starts_with, string, characters); + test_scalar_expr!(Strpos, strpos, string, substring); + test_scalar_expr!(Substr, substr, string, position); + test_scalar_expr!(ToHex, to_hex, string); + test_scalar_expr!(Translate, translate, string, from, to); + test_scalar_expr!(Trim, trim, string); + test_scalar_expr!(Upper, upper, string); + + test_scalar_expr!(DatePart, date_part, part, date); + test_scalar_expr!(DateTrunc, date_trunc, part, date); } #[test] diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index a20d57206749..56fec3cf1a0c 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -42,11 +42,11 @@ pub use expr::{ create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round, - rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, - starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, - unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, - ExpressionVisitor, Literal, Recursion, RewriteRecursion, + regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, + signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, + translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, + Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index bf89c9391c28..fdfd3f3ca267 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -127,6 +127,14 @@ impl ops::Div for Expr { } } +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, 
rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} + #[cfg(test)] mod tests { use crate::prelude::lit; @@ -149,5 +157,9 @@ mod tests { format!("{:?}", lit(1u32) / lit(2u32)), "UInt32(1) / UInt32(2)" ); + assert_eq!( + format!("{:?}", lit(1u32) % lit(2u32)), + "UInt32(1) % UInt32(2)" + ); } } diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 6d717df23912..2f448ea73c04 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -45,33 +45,18 @@ use crate::{error::Result, logical_plan::Operator}; /// pub struct SimplifyExpressions {} -fn expr_contains(expr: &Expr, needle: &Expr) -> bool { +/// returns true if `needle` is found in a chain of search_op +/// expressions. Such as: (A AND B) AND C +fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } => expr_contains(left, needle) || expr_contains(right, needle), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } => expr_contains(left, needle) || expr_contains(right, needle), + Expr::BinaryExpr { left, op, right } if *op == search_op => { + expr_contains(left, needle, search_op) + || expr_contains(right, needle, search_op) + } _ => expr == needle, } } -fn as_binary_expr(expr: &Expr) -> Option<&Expr> { - match expr { - Expr::BinaryExpr { .. } => Some(expr), - _ => None, - } -} - -fn operator_is_boolean(op: Operator) -> bool { - op == Operator::And || op == Operator::Or -} - fn is_one(s: &Expr) -> bool { match s { Expr::Literal(ScalarValue::Int8(Some(1))) @@ -95,6 +80,22 @@ fn is_true(expr: &Expr) -> bool { } } +/// returns true if expr is a +/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) +} + +/// Return a literal NULL value +fn lit_null() -> Expr { + Expr::Literal(ScalarValue::Boolean(None)) +} + +/// returns true if expr is a `Not(_)`, false otherwise +fn is_not(expr: &Expr) -> bool { + matches!(expr, Expr::Not(_)) +} + fn is_null(expr: &Expr) -> bool { match expr { Expr::Literal(v) => v.is_null(), @@ -109,160 +110,27 @@ fn is_false(expr: &Expr) -> bool { } } -fn simplify(expr: &Expr) -> Expr { - match expr { - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_true(left) || is_true(right) => lit(true), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_false(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_false(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if left == right => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_false(left) || is_false(right) => lit(false), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_true(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_true(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if left == right => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Multiply, - right, - } if is_one(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Multiply, - right, - } if is_one(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, - } if is_one(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, 
- } if left == right && is_null(left) => *left.clone(), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, - } if left == right => lit(1), +/// returns true if `haystack` looks like (needle OP X) or (X OP needle) +fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { + match haystack { Expr::BinaryExpr { left, op, right } - if left == right && operator_is_boolean(*op) => + if op == &target_op + && (needle == left.as_ref() || needle == right.as_ref()) => { - simplify(left) + true } - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if expr_contains(left, right) => as_binary_expr(left) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&x.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&*right.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if expr_contains(right, left) => as_binary_expr(right) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&*right.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&*left.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if expr_contains(left, right) => as_binary_expr(left) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&*right.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&x.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if expr_contains(right, left) => as_binary_expr(right) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&*left.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&x.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: Box::new(simplify(left)), - op: *op, - right: Box::new(simplify(right)), - }, - _ => expr.clone(), + _ => false, + } +} + +/// returns the contained boolean value in `expr` as +/// `Expr::Literal(ScalarValue::Boolean(v))`. +/// +/// panics if expr is not a literal boolean +fn as_bool_lit(expr: Expr) -> Option { + match expr { + Expr::Literal(ScalarValue::Boolean(v)) => v, + _ => panic!("Expected boolean literal, got {:?}", expr), } } @@ -281,11 +149,9 @@ impl OptimizerRule for SimplifyExpressions { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut simplifier = - super::simplify_expressions::Simplifier::new(plan.all_schemas()); + let mut simplifier = Simplifier::new(plan.all_schemas()); - let mut const_evaluator = - super::simplify_expressions::ConstEvaluator::new(execution_props); + let mut const_evaluator = ConstEvaluator::new(execution_props); let new_inputs = plan .inputs() @@ -301,9 +167,6 @@ impl OptimizerRule for SimplifyExpressions { // Constant folding should not change expression name. 
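// Illustrative sketch (not part of this change): the generalized `expr_contains` walks a
// chain of a single operator, so one helper now backs both the AND and OR absorption rules.
// Roughly, with the `col` builder:
//
//     let chain = col("a").and(col("b")).and(col("c"));          // (a AND b) AND c
//     assert!(expr_contains(&chain, &col("c"), Operator::And));  // found along the AND chain
//     assert!(!expr_contains(&chain, &col("c"), Operator::Or));  // the chain itself is not `c`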
let name = &e.name(plan.schema()); - // TODO combine simplify into Simplifier - let e = simplify(&e); - // TODO iterate until no changes are made // during rewrite (evaluating constants can // enable new simplifications and @@ -316,7 +179,6 @@ impl OptimizerRule for SimplifyExpressions { let new_name = &new_e.name(plan.schema()); - // TODO simplify this logic if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { if expr_name != new_expr_name { Ok(new_e.alias(expr_name)) @@ -497,7 +359,7 @@ impl ConstEvaluator { } /// Internal helper to evaluates an Expr - fn evaluate_to_scalar(&self, expr: Expr) -> Result { + pub(crate) fn evaluate_to_scalar(&self, expr: Expr) -> Result { if let Expr::Literal(s) = expr { return Ok(s); } @@ -554,212 +416,252 @@ impl<'a> Simplifier<'a> { false } - fn boolean_folding_for_or( - const_bool: &Option, - bool_expr: Box, - left_right_order: bool, - ) -> Expr { - // See if we can fold 'const_bool OR bool_expr' to a constant boolean - match const_bool { - // TRUE or expr (including NULL) = TRUE - Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))), - // FALSE or expr (including NULL) = expr - Some(false) => *bool_expr, - None => match *bool_expr { - // NULL or TRUE = TRUE - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(Some(true))) - } - // NULL or FALSE = NULL - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL or NULL = NULL - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL or expr can be either NULL or TRUE - // So let us not rewrite it - _ => { - let mut left = - Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); - let mut right = bool_expr; - if !left_right_order { - std::mem::swap(&mut left, &mut right); - } - - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } - } - }, - } - } - - fn boolean_folding_for_and( - const_bool: &Option, - bool_expr: Box, - left_right_order: bool, - ) -> Expr { - // See if we can fold 'const_bool AND bool_expr' to a constant boolean - match const_bool { - // TRUE and expr (including NULL) = expr - Some(true) => *bool_expr, - // FALSE and expr (including NULL) = FALSE - Some(false) => Expr::Literal(ScalarValue::Boolean(Some(false))), - None => match *bool_expr { - // NULL and TRUE = NULL - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL and FALSE = FALSE - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(Some(false))) - } - // NULL and NULL = NULL - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL and expr can either be NULL or FALSE - // So let us not rewrite it - _ => { - let mut left = - Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); - let mut right = bool_expr; - if !left_right_order { - std::mem::swap(&mut left, &mut right); - } - - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } - } - }, - } + /// Returns true if expr is nullable + fn nullable(&self, expr: &Expr) -> Result { + self.schemas + .iter() + .find_map(|schema| { + // expr may be from another input, so ignore errors + // by converting to None to keep trying + expr.nullable(schema.as_ref()).ok() + }) + .ok_or_else(|| { + // This means we weren't able to compute `Expr::nullable` with + // *any* input schemas, signalling a problem + DataFusionError::Internal(format!( + "Could not find find 
columns in '{}' during simplify", + expr + )) + }) } } impl<'a> ExprRewriter for Simplifier<'a> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { + use Expr::*; + use Operator::{And, Divide, Eq, Multiply, NotEq, Or}; + let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => *right, - Some(false) => Expr::Not(right), - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => *left, - Some(false) => Expr::Not(left), - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - _ => Expr::BinaryExpr { - left, - op: Operator::Eq, - right, - }, - }, - Operator::NotEq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => Expr::Not(right), - Some(false) => *right, - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => Expr::Not(left), - Some(false) => *left, - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - _ => Expr::BinaryExpr { - left, - op: Operator::NotEq, - right, - }, - }, - Operator::Or => match (left.as_ref(), right.as_ref()) { - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - Self::boolean_folding_for_or(b, right, true) - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - Self::boolean_folding_for_or(b, left, false) - } - _ => Expr::BinaryExpr { - left, - op: Operator::Or, - right, - }, - }, - Operator::And => match (left.as_ref(), right.as_ref()) { - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - Self::boolean_folding_for_and(b, right, true) - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - Self::boolean_folding_for_and(b, left, false) - } - _ => Expr::BinaryExpr { - left, - op: Operator::And, - right, - }, - }, - _ => Expr::BinaryExpr { left, op, right }, - }, - // Not(Not(expr)) --> expr - Expr::Not(inner) => { - if let Expr::Not(negated_inner) = *inner { - *negated_inner - } else { - Expr::Not(inner) + // + // Rules for Eq + // + + // true = A --> A + // false = A --> !A + // null = A --> null + BinaryExpr { + left, + op: Eq, + right, + } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + match as_bool_lit(*left) { + Some(true) => *right, + Some(false) => Not(right), + None => lit_null(), } } + // A = true --> A + // A = false --> !A + // A = null --> null + BinaryExpr { + left, + op: Eq, + right, + } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + match as_bool_lit(*right) { + Some(true) => *left, + Some(false) => 
Not(left), + None => lit_null(), + } + } + + // + // Rules for NotEq + // + + // true != A --> !A + // false != A --> A + // null != A --> null + BinaryExpr { + left, + op: NotEq, + right, + } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + match as_bool_lit(*left) { + Some(true) => Not(right), + Some(false) => *right, + None => lit_null(), + } + } + // A != true --> !A + // A != false --> A + // A != null --> null, + BinaryExpr { + left, + op: NotEq, + right, + } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + match as_bool_lit(*right) { + Some(true) => Not(left), + Some(false) => *left, + None => lit_null(), + } + } + + // + // Rules for OR + // + + // true OR A --> true (even if A is null) + BinaryExpr { + left, + op: Or, + right: _, + } if is_true(&left) => *left, + // false OR A --> A + BinaryExpr { + left, + op: Or, + right, + } if is_false(&left) => *right, + // A OR true --> true (even if A is null) + BinaryExpr { + left: _, + op: Or, + right, + } if is_true(&right) => *right, + // A OR false --> A + BinaryExpr { + left, + op: Or, + right, + } if is_false(&right) => *left, + // (..A..) OR A --> (..A..) + BinaryExpr { + left, + op: Or, + right, + } if expr_contains(&left, &right, Or) => *left, + // A OR (..A..) --> (..A..) + BinaryExpr { + left, + op: Or, + right, + } if expr_contains(&right, &left, Or) => *right, + // A OR (A AND B) --> A (if B not null) + BinaryExpr { + left, + op: Or, + right, + } if !self.nullable(&right)? && is_op_with(And, &right, &left) => *left, + // (A AND B) OR A --> A (if B not null) + BinaryExpr { + left, + op: Or, + right, + } if !self.nullable(&left)? && is_op_with(And, &left, &right) => *right, + + // + // Rules for AND + // + + // true AND A --> A + BinaryExpr { + left, + op: And, + right, + } if is_true(&left) => *right, + // false AND A --> false (even if A is null) + BinaryExpr { + left, + op: And, + right: _, + } if is_false(&left) => *left, + // A AND true --> A + BinaryExpr { + left, + op: And, + right, + } if is_true(&right) => *left, + // A AND false --> false (even if A is null) + BinaryExpr { + left: _, + op: And, + right, + } if is_false(&right) => *right, + // (..A..) AND A --> (..A..) + BinaryExpr { + left, + op: And, + right, + } if expr_contains(&left, &right, And) => *left, + // A AND (..A..) --> (..A..) + BinaryExpr { + left, + op: And, + right, + } if expr_contains(&right, &left, And) => *right, + // A AND (A OR B) --> A (if B not null) + BinaryExpr { + left, + op: And, + right, + } if !self.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + // (A OR B) AND A --> A (if B not null) + BinaryExpr { + left, + op: And, + right, + } if !self.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + + // + // Rules for Multiply + // + BinaryExpr { + left, + op: Multiply, + right, + } if is_one(&right) => *left, + BinaryExpr { + left, + op: Multiply, + right, + } if is_one(&left) => *right, + + // + // Rules for Divide + // + + // A / 1 --> A + BinaryExpr { + left, + op: Divide, + right, + } if is_one(&right) => *left, + // A / null --> null + BinaryExpr { + left, + op: Divide, + right, + } if left == right && is_null(&left) => *left, + // A / A --> 1 (if a is not nullable) + BinaryExpr { + left, + op: Divide, + right, + } if !self.nullable(&left)? 
&& left == right => lit(1), + + // + // Rules for Not + // + + // !(!A) --> A + Not(inner) if is_not(&inner) => match *inner { + Not(negated_inner) => *negated_inner, + _ => unreachable!(), + }, + expr => { // no additional rewrites possible expr @@ -791,8 +693,8 @@ mod tests { let expr_b = lit(true).or(col("c2")); let expected = lit(true); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -801,8 +703,8 @@ mod tests { let expr_b = col("c2").or(lit(false)); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -810,7 +712,7 @@ mod tests { let expr = col("c2").or(col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -819,8 +721,8 @@ mod tests { let expr_b = col("c2").and(lit(false)); let expected = lit(false); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -828,7 +730,7 @@ mod tests { let expr = col("c2").and(col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -837,8 +739,8 @@ mod tests { let expr_b = col("c2").and(lit(true)); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -847,8 +749,8 @@ mod tests { let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -856,15 +758,24 @@ mod tests { let expr = binary_expr(col("c2"), Operator::Divide, lit(1)); let expected = col("c2"); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_divide_by_same() { let expr = binary_expr(col("c2"), Operator::Divide, col("c2")); + // if c2 is null, c2 / c2 = null, so can't simplify + let expected = expr.clone(); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_divide_by_same_non_null() { + let expr = binary_expr(col("c2_non_null"), Operator::Divide, col("c2_non_null")); let expected = lit(1); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -873,21 +784,21 @@ mod tests { let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5))); let expected = col("c2").gt(lit(5)); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_composed_and() { - // ((c > 5) AND (d < 6)) AND (c > 5) + // ((c > 5) AND (c1 < 6)) AND (c > 5) let expr = binary_expr( - binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6))), + binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))), Operator::And, col("c2").gt(lit(5)), ); let expected = - binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6))); + binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -900,20 +811,91 @@ 
mod tests { ); let expected = expr.clone(); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_or_and() { - // (c > 5) OR ((d < 6) AND (c > 5) -- can remove - let expr = binary_expr( - col("c2").gt(lit(5)), + let l = col("c2").gt(lit(5)); + let r = binary_expr(col("c1").lt(lit(6)), Operator::And, col("c2").gt(lit(5))); + + // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) + let expr = binary_expr(l.clone(), Operator::Or, r.clone()); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) + let expr = binary_expr(l, Operator::Or, r); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_or_and_non_null() { + let l = col("c2_non_null").gt(lit(5)); + let r = binary_expr( + col("c1_non_null").lt(lit(6)), + Operator::And, + col("c2_non_null").gt(lit(5)), + ); + + // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5 + let expr = binary_expr(l.clone(), Operator::Or, r.clone()); + + // This is only true if `c1 < 6` is not nullable / can not be null. + let expected = col("c2_non_null").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5 + let expr = binary_expr(l, Operator::Or, r); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_or() { + let l = col("c2").gt(lit(5)); + let r = binary_expr(col("c1").lt(lit(6)), Operator::Or, col("c2").gt(lit(5))); + + // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 + let expr = binary_expr(l.clone(), Operator::And, r.clone()); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 + let expr = binary_expr(l, Operator::And, r); + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_or_non_null() { + let l = col("c2_non_null").gt(lit(5)); + let r = binary_expr( + col("c1_non_null").lt(lit(6)), Operator::Or, - binary_expr(col("d").lt(lit(6)), Operator::And, col("c2").gt(lit(5))), + col("c2_non_null").gt(lit(5)), ); - let expected = col("c2").gt(lit(5)); - assert_eq!(simplify(&expr), expected); + // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 + let expr = binary_expr(l.clone(), Operator::And, r.clone()); + + // This is only true if `c1 < 6` is not nullable / can not be null. 
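// Illustrative sketch (not part of this change): the nullability gate comes from
// `Simplifier::nullable`, which probes each input schema via `Expr::nullable`. With the
// `expr_test_schema()` helper below, the two flavours of column differ only in that flag:
//
//     let schema = expr_test_schema();
//     let simplifier = Simplifier::new(vec![&schema]);
//     assert!(simplifier.nullable(&col("c2")).unwrap());           // declared nullable
//     assert!(!simplifier.nullable(&col("c2_non_null")).unwrap()); // declared NOT NULL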
+ let expected = col("c2_non_null").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 + let expr = binary_expr(l, Operator::And, r); + + assert_eq!(simplify(expr), expected); } #[test] @@ -921,7 +903,7 @@ mod tests { let expr = binary_expr(lit_null(), Operator::And, lit(false)); let expr_eq = lit(false); - assert_eq!(simplify(&expr), expr_eq); + assert_eq!(simplify(expr), expr_eq); } #[test] @@ -930,16 +912,16 @@ mod tests { let expr_plus = binary_expr(null.clone(), Operator::Divide, null.clone()); let expr_eq = null; - assert_eq!(simplify(&expr_plus), expr_eq); + assert_eq!(simplify(expr_plus), expr_eq); } #[test] - fn test_simplify_do_not_simplify_arithmetic_expr() { + fn test_simplify_simplify_arithmetic_expr() { let expr_plus = binary_expr(lit(1), Operator::Plus, lit(1)); let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1)); - assert_eq!(simplify(&expr_plus), expr_plus); - assert_eq!(simplify(&expr_eq), expr_eq); + assert_eq!(simplify(expr_plus), lit(2)); + assert_eq!(simplify(expr_eq), lit(true)); } // ------------------------------ @@ -1182,11 +1164,17 @@ mod tests { // ----- Simplifier tests ------- // ------------------------------ - // TODO rename to simplify - fn do_simplify(expr: Expr) -> Expr { + fn simplify(expr: Expr) -> Expr { let schema = expr_test_schema(); let mut rewriter = Simplifier::new(vec![&schema]); - expr.rewrite(&mut rewriter).expect("expected to simplify") + + let execution_props = ExecutionProps::new(); + let mut const_evaluator = ConstEvaluator::new(&execution_props); + + expr.rewrite(&mut rewriter) + .expect("expected to simplify") + .rewrite(&mut const_evaluator) + .expect("expected to const evaluate") } fn expr_test_schema() -> DFSchemaRef { @@ -1194,6 +1182,8 @@ mod tests { DFSchema::new(vec![ DFField::new(None, "c1", DataType::Utf8, true), DFField::new(None, "c2", DataType::Boolean, true), + DFField::new(None, "c1_non_null", DataType::Utf8, false), + DFField::new(None, "c2_non_null", DataType::Boolean, false), ]) .unwrap(), ) @@ -1201,20 +1191,20 @@ mod tests { #[test] fn simplify_expr_not_not() { - assert_eq!(do_simplify(col("c2").not().not().not()), col("c2").not(),); + assert_eq!(simplify(col("c2").not().not().not()), col("c2").not(),); } #[test] fn simplify_expr_null_comparison() { // x = null is always null assert_eq!( - do_simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))), + simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - do_simplify( + simplify( lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))) ), lit(ScalarValue::Boolean(None)), @@ -1222,13 +1212,13 @@ mod tests { // x != null is always null assert_eq!( - do_simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), + simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - do_simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))), + simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))), lit(ScalarValue::Boolean(None)), ); } @@ -1239,16 +1229,16 @@ mod tests { assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); // true = ture -> true - assert_eq!(do_simplify(lit(true).eq(lit(true))), lit(true)); + assert_eq!(simplify(lit(true).eq(lit(true))), lit(true)); // true = false -> false - assert_eq!(do_simplify(lit(true).eq(lit(false))), lit(false),); + assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),); // c2 = true -> 
c2 - assert_eq!(do_simplify(col("c2").eq(lit(true))), col("c2")); + assert_eq!(simplify(col("c2").eq(lit(true))), col("c2")); // c2 = false => !c2 - assert_eq!(do_simplify(col("c2").eq(lit(false))), col("c2").not(),); + assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),); } #[test] @@ -1262,25 +1252,8 @@ mod tests { // Make sure c1 column to be used in tests is not boolean type assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); - // don't fold c1 = true - assert_eq!( - do_simplify(col("c1").eq(lit(true))), - col("c1").eq(lit(true)), - ); - - // don't fold c1 = false - assert_eq!( - do_simplify(col("c1").eq(lit(false))), - col("c1").eq(lit(false)), - ); - - // test constant operands - assert_eq!(do_simplify(lit(1).eq(lit(true))), lit(1).eq(lit(true)),); - - assert_eq!( - do_simplify(lit("a").eq(lit(false))), - lit("a").eq(lit(false)), - ); + // don't fold c1 = foo + assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),); } #[test] @@ -1290,15 +1263,15 @@ mod tests { assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); // c2 != true -> !c2 - assert_eq!(do_simplify(col("c2").not_eq(lit(true))), col("c2").not(),); + assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),); // c2 != false -> c2 - assert_eq!(do_simplify(col("c2").not_eq(lit(false))), col("c2"),); + assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),); // test constant - assert_eq!(do_simplify(lit(true).not_eq(lit(true))), lit(false),); + assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),); - assert_eq!(do_simplify(lit(true).not_eq(lit(false))), lit(true),); + assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),); } #[test] @@ -1311,44 +1284,25 @@ mod tests { assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); assert_eq!( - do_simplify(col("c1").not_eq(lit(true))), - col("c1").not_eq(lit(true)), - ); - - assert_eq!( - do_simplify(col("c1").not_eq(lit(false))), - col("c1").not_eq(lit(false)), - ); - - // test constants - assert_eq!( - do_simplify(lit(1).not_eq(lit(true))), - lit(1).not_eq(lit(true)), - ); - - assert_eq!( - do_simplify(lit("a").not_eq(lit(false))), - lit("a").not_eq(lit(false)), + simplify(col("c1").not_eq(lit("foo"))), + col("c1").not_eq(lit("foo")), ); } #[test] fn simplify_expr_case_when_then_else() { assert_eq!( - do_simplify(Expr::Case { + simplify(Expr::Case { expr: None, when_then_expr: vec![( Box::new(col("c2").not_eq(lit(false))), - Box::new(lit("ok").eq(lit(true))), + Box::new(lit("ok").eq(lit("not_ok"))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), Expr::Case { expr: None, - when_then_expr: vec![( - Box::new(col("c2")), - Box::new(lit("ok").eq(lit(true))) - )], + when_then_expr: vec![(Box::new(col("c2")), Box::new(lit(false)))], else_expr: Some(Box::new(col("c2"))), } ); @@ -1362,22 +1316,22 @@ mod tests { #[test] fn simplify_expr_bool_or() { // col || true is always true - assert_eq!(do_simplify(col("c2").or(lit(true))), lit(true),); + assert_eq!(simplify(col("c2").or(lit(true))), lit(true),); // col || false is always col - assert_eq!(do_simplify(col("c2").or(lit(false))), col("c2"),); + assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),); // true || null is always true - assert_eq!(do_simplify(lit(true).or(lit_null())), lit(true),); + assert_eq!(simplify(lit(true).or(lit_null())), lit(true),); // null || true is always true - assert_eq!(do_simplify(lit_null().or(lit(true))), lit(true),); + assert_eq!(simplify(lit_null().or(lit(true))), lit(true),); // false || 
null is always null - assert_eq!(do_simplify(lit(false).or(lit_null())), lit_null(),); + assert_eq!(simplify(lit(false).or(lit_null())), lit_null(),); // null || false is always null - assert_eq!(do_simplify(lit_null().or(lit(false))), lit_null(),); + assert_eq!(simplify(lit_null().or(lit(false))), lit_null(),); // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL) // it can be either NULL or TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)` @@ -1389,28 +1343,28 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.or(lit_null()); - let result = do_simplify(expr.clone()); + let result = simplify(expr.clone()); assert_eq!(expr, result); } #[test] fn simplify_expr_bool_and() { // col & true is always col - assert_eq!(do_simplify(col("c2").and(lit(true))), col("c2"),); + assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),); // col & false is always false - assert_eq!(do_simplify(col("c2").and(lit(false))), lit(false),); + assert_eq!(simplify(col("c2").and(lit(false))), lit(false),); // true && null is always null - assert_eq!(do_simplify(lit(true).and(lit_null())), lit_null(),); + assert_eq!(simplify(lit(true).and(lit_null())), lit_null(),); // null && true is always null - assert_eq!(do_simplify(lit_null().and(lit(true))), lit_null(),); + assert_eq!(simplify(lit_null().and(lit(true))), lit_null(),); // false && null is always false - assert_eq!(do_simplify(lit(false).and(lit_null())), lit(false),); + assert_eq!(simplify(lit(false).and(lit_null())), lit(false),); // null && false is always false - assert_eq!(do_simplify(lit_null().and(lit(false))), lit(false),); + assert_eq!(simplify(lit_null().and(lit(false))), lit(false),); // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL) // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10` @@ -1422,7 +1376,7 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.and(lit_null()); - let result = do_simplify(expr.clone()); + let result = simplify(expr.clone()); assert_eq!(expr, result); } @@ -1473,12 +1427,12 @@ mod tests { ); } - // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) #[test] fn test_simplify_optimized_plan_with_composed_and() { let table_scan = test_table_scan(); + // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")]) + .project(vec![col("a"), col("b")]) .unwrap() .filter(and( and(col("a").gt(lit(5)), col("b").lt(lit(6))), @@ -1492,7 +1446,7 @@ mod tests { &plan, "\ Filter: #test.a > Int32(5) AND #test.b < Int32(6) AS test.a > Int32(5) AND test.b < Int32(6) AND test.a > Int32(5)\ - \n Projection: #test.a\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None", ); } @@ -1703,7 +1657,7 @@ mod tests { .build() .unwrap(); - let expected = "Projection: TimestampNanosecond(1599566400000000000) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ + let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ \n TableScan: test projection=None" .to_string(); let actual = get_optimized_plan_formatted(&plan, &Utc::now()); @@ -1779,7 +1733,7 @@ mod tests { // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(&plan, &time); let expected = format!( - "Projection: TimestampNanosecond({}) AS now(), TimestampNanosecond({}) AS t2\ + "Projection: TimestampNanosecond({}, Some(\"UTC\")) AS now(), TimestampNanosecond({}, Some(\"UTC\")) AS t2\ \n TableScan: test 
projection=None", time.timestamp_nanos(), time.timestamp_nanos() diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs index 3232fa03ce80..9bddec997db6 100644 --- a/datafusion/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs @@ -20,7 +20,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Aggregate, Projection}; -use crate::logical_plan::{columnize_expr, DFSchema, Expr, LogicalPlan}; +use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use hashbrown::HashSet; @@ -34,14 +34,16 @@ use std::sync::Arc; /// /// Into /// -/// SELECT F1(s),F2(s) +/// SELECT F1(alias1),F2(alias1) /// FROM ( -/// SELECT s, k ... GROUP BY s, k +/// SELECT s as alias1, k ... GROUP BY s, k /// ) /// GROUP BY k /// ``` pub struct SingleDistinctToGroupBy {} +const SINGLE_DISTINCT_ALIAS: &str = "alias1"; + impl SingleDistinctToGroupBy { #[allow(missing_docs)] pub fn new() -> Self { @@ -69,11 +71,12 @@ fn optimize(plan: &LogicalPlan) -> Result { if group_fields_set .insert(args[0].name(input.schema()).unwrap()) { - all_group_args.push(args[0].clone()); + all_group_args + .push(args[0].clone().alias(SINGLE_DISTINCT_ALIAS)); } Expr::AggregateFunction { fun: fun.clone(), - args: args.clone(), + args: vec![col(SINGLE_DISTINCT_ALIAS)], distinct: false, } } @@ -104,7 +107,6 @@ fn optimize(plan: &LogicalPlan) -> Result { ) .unwrap(), ); - let final_agg = LogicalPlan::Aggregate(Aggregate { input: Arc::new(grouped_agg.unwrap()), group_expr: group_expr.clone(), @@ -191,7 +193,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{col, count, count_distinct, max, LogicalPlanBuilder}; + use crate::logical_plan::{col, count, count_distinct, lit, max, LogicalPlanBuilder}; use crate::physical_plan::aggregates; use crate::test::*; @@ -229,9 +231,26 @@ mod tests { .build()?; // Should work - let expected = "Projection: #COUNT(test.b) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(#test.b)]] [COUNT(test.b):UInt64;N]\ - \n Aggregate: groupBy=[[#test.b]], aggr=[[]] [b:UInt32]\ + let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ + \n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn single_distinct_expr() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? 
+ .build()?; + + let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ + \n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); @@ -247,9 +266,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(test.b) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#test.b)]] [a:UInt32, COUNT(test.b):UInt64;N]\ - \n Aggregate: groupBy=[[#test.a, #test.b]], aggr=[[]] [a:UInt32, b:UInt32]\ + let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\ + \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); @@ -293,9 +312,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(test.b) AS COUNT(DISTINCT test.b), #MAX(test.b) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#test.b), MAX(#test.b)]] [a:UInt32, COUNT(test.b):UInt64;N, MAX(test.b):UInt32;N]\ - \n Aggregate: groupBy=[[#test.a, #test.b]], aggr=[[]] [a:UInt32, b:UInt32]\ + let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 228d304dcb84..888de9aeb8bc 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type}; +use expressions::{ + avg_return_type, stddev_return_type, sum_return_type, variance_return_type, +}; use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -64,6 +66,14 @@ pub enum AggregateFunction { ApproxDistinct, /// array_agg ArrayAgg, + /// Variance (Sample) + Variance, + /// Variance (Population) + VariancePop, + /// Standard Deviation (Sample) + Stddev, + /// Standard Deviation (Population) + StddevPop, } impl fmt::Display for AggregateFunction { @@ -84,6 +94,12 @@ impl FromStr for AggregateFunction { "sum" => AggregateFunction::Sum, "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, + "var" => AggregateFunction::Variance, + "var_samp" => AggregateFunction::Variance, + "var_pop" => 
AggregateFunction::VariancePop, + "stddev" => AggregateFunction::Stddev, + "stddev_samp" => AggregateFunction::Stddev, + "stddev_pop" => AggregateFunction::StddevPop, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -116,6 +132,10 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), + AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]), + AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", @@ -212,6 +232,48 @@ pub fn create_aggregate_expr( "AVG(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Variance, true) => { + return Err(DataFusionError::NotImplemented( + "VAR(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::VariancePop, false) => { + Arc::new(expressions::VariancePop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )) + } + (AggregateFunction::VariancePop, true) => { + return Err(DataFusionError::NotImplemented( + "VAR_POP(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Stddev, true) => { + return Err(DataFusionError::NotImplemented( + "STDDEV(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::StddevPop, true) => { + return Err(DataFusionError::NotImplemented( + "STDDEV_POP(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -256,7 +318,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::Sum => { + AggregateFunction::Avg + | AggregateFunction::Sum + | AggregateFunction::Variance + | AggregateFunction::VariancePop + | AggregateFunction::Stddev + | AggregateFunction::StddevPop => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -267,7 +334,7 @@ mod tests { use super::*; use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance, }; #[test] @@ -426,7 +493,7 @@ mod tests { | DataType::Int16 | DataType::Int32 | DataType::Int64 => DataType::Int64, - DataType::Float32 | DataType::Float64 => data_type.clone(), + DataType::Float32 | DataType::Float64 => DataType::Float64, _ => data_type.clone(), }; @@ -450,6 +517,158 @@ mod tests { Ok(()) } + #[test] + fn test_variance_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Variance]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + 
Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_var_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::VariancePop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_stddev_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Stddev]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_stddev_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::StddevPop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; @@ -470,6 +689,29 @@ mod tests { Ok(()) } + #[test] + fn test_sum_return_type() -> Result<()> { + let observed = 
return_type(&AggregateFunction::Sum, &[DataType::Int32])?; + assert_eq!(DataType::Int64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?; + assert_eq!(DataType::UInt64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?; + assert_eq!(DataType::Decimal(20, 5), observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?; + assert_eq!(DataType::Decimal(38, 5), observed); + + Ok(()) + } + #[test] fn test_sum_no_utf8() { let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]); @@ -504,6 +746,15 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?; assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(14, 10), observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?; + assert_eq!(DataType::Decimal(38, 10), observed); Ok(()) } @@ -512,4 +763,56 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); assert!(observed.is_err()); } + + #[test] + fn test_variance_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_variance_no_utf8() { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); + assert!(observed.is_err()); + } + + #[test] + fn test_stddev_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_stddev_no_utf8() { + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); + assert!(observed.is_err()); + } } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index d7b437528d5c..75672fd4fe99 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ 
b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,7 +21,8 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_sum_support_arg_type, try_cast, + is_avg_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, + is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -86,6 +87,42 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Variance => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::VariancePop => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Stddev => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::StddevPop => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } @@ -95,7 +132,7 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // min and max support the dictionary data type // unpack the dictionary to get the value match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { + DataType::Dictionary(_, dict_value_type, _) => { // TODO add checker, if the value type is complex data type Ok(vec![dict_value_type.deref().clone()]) } @@ -193,8 +230,7 @@ mod tests { let input_types = vec![ vec![DataType::Int32], vec![DataType::Float32], - // support the decimal data type - // vec![DataType::Decimal(20, 3)], + vec![DataType::Decimal(20, 3)], ]; for fun in funs { for input_type in &input_types { diff --git a/datafusion/src/physical_plan/datetime_expressions.rs b/datafusion/src/physical_plan/datetime_expressions.rs index dbffba2ec91f..2879378c6331 100644 --- a/datafusion/src/physical_plan/datetime_expressions.rs +++ b/datafusion/src/physical_plan/datetime_expressions.rs @@ -181,6 +181,7 @@ pub fn make_now( move |_arg| { Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( now_ts, + Some("UTC".to_owned()), ))) } } @@ -240,8 +241,11 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { let f = |x: Option<&i64>| x.map(|x| date_trunc_single(granularity, *x)).transpose(); Ok(match array { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v)) => { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond((f)(v.as_ref())?)) + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + (f)(v.as_ref())?, + tz_opt.clone(), + )) } ColumnarValue::Array(array) => { let array = array.as_any().downcast_ref::().unwrap(); diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs index f09481a94400..40f6d58dc051 100644 --- 
a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/distinct_expressions.rs @@ -76,7 +76,7 @@ impl DistinctCount { fn state_type(data_type: DataType) -> DataType { match data_type { // when aggregating dictionary values, use the underlying value type - DataType::Dictionary(_key_type, value_type) => *value_type, + DataType::Dictionary(_key_type, value_type, _) => *value_type, t => t, } } diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs b/datafusion/src/physical_plan/expressions/approx_distinct.rs index 34eb55191aa5..0e4ba9c398ba 100644 --- a/datafusion/src/physical_plan/expressions/approx_distinct.rs +++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs @@ -98,7 +98,7 @@ impl AggregateExpr for ApproxDistinct { DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), other => { return Err(DataFusionError::NotImplemented(format!( - "Support for 'approx_distinct' for data type {} is not implemented", + "Support for 'approx_distinct' for data type {:?} is not implemented", other ))) } diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 3174cda7f81b..3d60c77728ed 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -23,7 +23,9 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ + ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{array::*, datatypes::Field}; @@ -35,11 +37,19 @@ use super::{format_state_name, sum}; pub struct Avg { name: String, expr: Arc, + data_type: DataType, } /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); + let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); + Ok(DataType::Decimal(new_precision, new_scale)) + } DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -70,6 +80,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_, _) ) } @@ -80,14 +91,15 @@ impl Avg { name: impl Into, data_type: DataType, ) -> Self { - // Average is always Float64, but Avg::new() has a data_type - // parameter to keep a consistent signature with the other - // Aggregate expressions. - assert_eq!(data_type, DataType::Float64); - + // the result of avg just support FLOAT64 and Decimal data type. 
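The decimal support added to `avg_return_type` above follows Spark's promotion rule: precision and scale each grow by 4, capped at the Decimal128 maximum. A minimal sketch of that rule, assuming the cap of 38 implied by `MAX_PRECISION_FOR_DECIMAL128` / `MAX_SCALE_FOR_DECIMAL128`; the helper name is illustrative only, not part of the DataFusion API.

// Sketch of the Spark-style promotion used by avg_return_type in this patch:
// DECIMAL(p, s) averages into DECIMAL(min(38, p + 4), min(38, s + 4)).
fn promote_decimal_for_avg(precision: usize, scale: usize) -> (usize, usize) {
    const MAX: usize = 38; // assumed value of MAX_PRECISION/SCALE_FOR_DECIMAL128
    (MAX.min(precision + 4), MAX.min(scale + 4))
}

fn main() {
    // Decimal(10, 5) averages into Decimal(14, 9), as test_avg_return_data_type expects.
    assert_eq!(promote_decimal_for_avg(10, 5), (14, 9));
    // Near the limit both precision and scale saturate at 38.
    assert_eq!(promote_decimal_for_avg(36, 10), (38, 14));
}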
+ assert!(matches!( + data_type, + DataType::Float64 | DataType::Decimal(_, _) + )); Self { name: name.into(), expr, + data_type, } } } @@ -99,7 +111,14 @@ impl AggregateExpr for Avg { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(AvgAccumulator::try_new( + // avg is f64 or decimal + &self.data_type, + )?)) } fn state_fields(&self) -> Result> { @@ -111,19 +130,12 @@ impl AggregateExpr for Avg { ), Field::new( &format_state_name(&self.name, "sum"), - DataType::Float64, + self.data_type.clone(), true, ), ]) } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 - &DataType::Float64, - )?)) - } - fn expressions(&self) -> Vec> { vec![self.expr.clone()] } @@ -202,6 +214,17 @@ impl Accumulator for AvgAccumulator { ScalarValue::Float64(e) => { Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) } + ScalarValue::Decimal128(value, precision, scale) => { + Ok(match value { + None => ScalarValue::Decimal128(None, precision, scale), + // TODO add the checker for overflow the precision + Some(v) => ScalarValue::Decimal128( + Some(v / self.count as i128), + precision, + scale, + ), + }) + } _ => Err(DataFusionError::Internal( "Sum should be f64 on average".to_string(), )), @@ -217,6 +240,73 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn test_avg_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(14, 9), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 14), result_type); + Ok(()) + } + + #[test] + fn avg_decimal() -> Result<()> { + // test agg + let mut decimal_builder = Int128Vec::with_capacity(6); + for i in 1..7 { + decimal_builder.push(Some(i as i128)); + } + let array = decimal_builder.as_arc(); + + generic_test_op!( + array, + DataType::Decimal(32, 32), + Avg, + ScalarValue::Decimal128(Some(35000), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn avg_decimal_with_nulls() -> Result<()> { + let mut decimal_builder = Int128Vec::with_capacity(5); + for i in 1..6 { + if i == 2 { + decimal_builder.push_null(); + } else { + decimal_builder.push(Some(i)); + } + } + let array: ArrayRef = decimal_builder.as_arc(); + generic_test_op!( + array, + DataType::Decimal(32, 32), + Avg, + ScalarValue::Decimal128(Some(32500), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn avg_decimal_all_nulls() -> Result<()> { + // test agg + let mut decimal_builder = Int128Vec::with_capacity(5); + for _i in 1..6 { + decimal_builder.push_null(); + } + let array: ArrayRef = decimal_builder.as_arc(); + generic_test_op!( + array, + DataType::Decimal(32, 32), + Avg, + ScalarValue::Decimal128(None, 14, 4), + DataType::Decimal(14, 4) + ) + } + #[test] fn avg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs index 3ab058d6e1e0..789ab582a7a0 100644 --- a/datafusion/src/physical_plan/expressions/cast.rs +++ b/datafusion/src/physical_plan/expressions/cast.rs @@ -97,7 +97,7 @@ fn cast_with_error(array: &dyn Array, cast_type: &DataType) -> Result>>(); let invalid_values = take::take(array, 
&Int32Array::from(&invalid_indices))?; return Err(DataFusionError::Execution(format!( - "Could not cast {} to value of type {}", + "Could not cast {:?} to value of type {:?}", invalid_values, cast_type ))); } diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs index 4a9da7387616..a04f11f263cd 100644 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ b/datafusion/src/physical_plan/expressions/coercion.rs @@ -63,13 +63,13 @@ fn dictionary_value_coercion( pub fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { ( - DataType::Dictionary(_lhs_index_type, lhs_value_type), - DataType::Dictionary(_rhs_index_type, rhs_value_type), + DataType::Dictionary(_lhs_index_type, lhs_value_type, _), + DataType::Dictionary(_rhs_index_type, rhs_value_type, _), ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), - (DataType::Dictionary(_index_type, value_type), _) => { + (DataType::Dictionary(_index_type, value_type, _), _) => { dictionary_value_coercion(value_type, rhs_type) } - (_, DataType::Dictionary(_index_type, value_type)) => { + (_, DataType::Dictionary(_index_type, value_type, _)) => { dictionary_value_coercion(lhs_type, value_type) } _ => None, @@ -100,11 +100,48 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; + use arrow::datatypes::TimeUnit; match (lhs_type, rhs_type) { (Utf8, Date32) => Some(Date32), (Date32, Utf8) => Some(Date32), (Utf8, Date64) => Some(Date64), (Date64, Utf8) => Some(Date64), + (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { + let tz = match (lhs_tz, rhs_tz) { + // can't cast across timezones + (Some(lhs_tz), Some(rhs_tz)) => { + if lhs_tz != rhs_tz { + return None; + } else { + Some(lhs_tz.clone()) + } + } + (Some(lhs_tz), None) => Some(lhs_tz.clone()), + (None, Some(rhs_tz)) => Some(rhs_tz.clone()), + (None, None) => None, + }; + + let unit = match (lhs_unit, rhs_unit) { + (TimeUnit::Second, TimeUnit::Millisecond) => TimeUnit::Second, + (TimeUnit::Second, TimeUnit::Microsecond) => TimeUnit::Second, + (TimeUnit::Second, TimeUnit::Nanosecond) => TimeUnit::Second, + (TimeUnit::Millisecond, TimeUnit::Second) => TimeUnit::Second, + (TimeUnit::Millisecond, TimeUnit::Microsecond) => TimeUnit::Millisecond, + (TimeUnit::Millisecond, TimeUnit::Nanosecond) => TimeUnit::Millisecond, + (TimeUnit::Microsecond, TimeUnit::Second) => TimeUnit::Second, + (TimeUnit::Microsecond, TimeUnit::Millisecond) => TimeUnit::Millisecond, + (TimeUnit::Microsecond, TimeUnit::Nanosecond) => TimeUnit::Microsecond, + (TimeUnit::Nanosecond, TimeUnit::Second) => TimeUnit::Second, + (TimeUnit::Nanosecond, TimeUnit::Millisecond) => TimeUnit::Millisecond, + (TimeUnit::Nanosecond, TimeUnit::Microsecond) => TimeUnit::Microsecond, + (l, r) => { + assert_eq!(l, r); + *l + } + }; + + Some(Timestamp(unit, tz)) + } _ => None, } } @@ -176,18 +213,23 @@ mod tests { use arrow::datatypes::IntegerType; // TODO: In the future, this would ideally return Dictionary types and avoid unpacking - let lhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int32)); - let rhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16)); + let lhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int32), false); + let rhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16), false); assert_eq!( dictionary_coercion(&lhs_type, &rhs_type), 
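The new Timestamp-vs-Timestamp arm in the temporal coercion rule above compares in the coarser of the two units, and only merges timezones when they agree (an untagged side adopts the other side's zone). A standalone sketch of the unit choice; the helper name and rank encoding are illustrative, not part of the DataFusion API.

// The coarser (larger-interval) unit wins, so no precision is invented when the
// two timestamps are compared. This restates the match table in the patch.
use arrow::datatypes::TimeUnit;

fn coarser_unit(lhs: TimeUnit, rhs: TimeUnit) -> TimeUnit {
    // Lower rank = coarser unit.
    let rank = |u: &TimeUnit| match u {
        TimeUnit::Second => 0,
        TimeUnit::Millisecond => 1,
        TimeUnit::Microsecond => 2,
        TimeUnit::Nanosecond => 3,
    };
    if rank(&lhs) <= rank(&rhs) {
        lhs
    } else {
        rhs
    }
}

fn main() {
    // A millisecond timestamp compared with a nanosecond one is coerced to
    // milliseconds, matching the (Millisecond, Nanosecond) arm in the patch.
    assert_eq!(
        coarser_unit(TimeUnit::Millisecond, TimeUnit::Nanosecond),
        TimeUnit::Millisecond
    );
}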
Some(DataType::Int32) ); - let lhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8)); - let rhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16)); + let lhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); + let rhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16), false); assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); - let lhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8)); + let lhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); let rhs_type = DataType::Utf8; assert_eq!( dictionary_coercion(&lhs_type, &rhs_type), @@ -195,7 +237,8 @@ mod tests { ); let lhs_type = DataType::Utf8; - let rhs_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8)); + let rhs_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); assert_eq!( dictionary_coercion(&lhs_type, &rhs_type), Some(DataType::Utf8) diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index bbe80c76b3e1..ba16f50127cf 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -107,7 +107,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { Some(col) => Ok(ColumnarValue::Array(col.clone())) } } - (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))), + (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {:?} with {} index", dt, key))), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( "field access is not yet implemented for scalar values".to_string(), @@ -227,7 +227,7 @@ mod tests { fn get_indexed_field_invalid_list_index() -> Result<()> { let schema = list_schema("l"); let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") + get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, metadata: {} }) with 0 index") } fn build_struct( diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 7a1cbbd74f64..1d1ba506acba 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -39,7 +39,7 @@ use super::format_state_name; // The reason min/max aggregate produces unpacked output because there is only one // min/max value per group; there is no needs to keep them Dictionary encode fn min_max_aggregate_data_type(input_type: DataType) -> DataType { - if let DataType::Dictionary(_, value_type) = input_type { + if let DataType::Dictionary(_, value_type, _) = input_type { *value_type } else { input_type @@ -123,6 +123,12 @@ macro_rules! 
typed_min_max_batch { let value = $OP(array); ScalarValue::$SCALAR(value) }}; + + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{ + let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + let value = $OP(array); + ScalarValue::$SCALAR(value, $TZ.clone()) + }}; } // TODO implement this in arrow-rs with simd @@ -176,18 +182,30 @@ macro_rules! min_max_batch { DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), - DataType::Timestamp(TimeUnit::Second, _) => { - typed_min_max_batch!($VALUES, Int64Array, TimestampSecond, $OP) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - typed_min_max_batch!($VALUES, Int64Array, TimestampMillisecond, $OP) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - typed_min_max_batch!($VALUES, Int64Array, TimestampMicrosecond, $OP) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - typed_min_max_batch!($VALUES, Int64Array, TimestampNanosecond, $OP) + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!($VALUES, Int64Array, TimestampSecond, $OP, tz_opt) } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + Int64Array, + TimestampMillisecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + Int64Array, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + Int64Array, + TimestampNanosecond, + $OP, + tz_opt + ), DataType::Date32 => typed_min_max_batch!($VALUES, Int32Array, Date32, $OP), DataType::Date64 => typed_min_max_batch!($VALUES, Int64Array, Date64, $OP), other => { @@ -269,6 +287,18 @@ macro_rules! typed_min_max { (Some(a), Some(b)) => Some((*a).$OP(*b)), }) }}; + + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $TZ:expr) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $TZ.clone(), + ) + }}; } // min/max of two scalar string values. @@ -333,26 +363,26 @@ macro_rules! 
min_max { (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) } - (ScalarValue::TimestampSecond(lhs), ScalarValue::TimestampSecond(rhs)) => { - typed_min_max!(lhs, rhs, TimestampSecond, $OP) + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) } ( - ScalarValue::TimestampMillisecond(lhs), - ScalarValue::TimestampMillisecond(rhs), + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), ) => { - typed_min_max!(lhs, rhs, TimestampMillisecond, $OP) + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) } ( - ScalarValue::TimestampMicrosecond(lhs), - ScalarValue::TimestampMicrosecond(rhs), + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), ) => { - typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP) + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) } ( - ScalarValue::TimestampNanosecond(lhs), - ScalarValue::TimestampNanosecond(rhs), + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), ) => { - typed_min_max!(lhs, rhs, TimestampNanosecond, $OP) + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) } ( ScalarValue::Date32(lhs), diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 6b49f5178a30..04127718f961 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -69,8 +69,11 @@ mod nth_value; mod nullif; mod rank; mod row_number; +mod stats; +mod stddev; mod sum; mod try_cast; +mod variance; /// Module with some convenient methods used in expression building pub mod helpers { @@ -101,9 +104,16 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub use stats::StatsType; +pub(crate) use stddev::{ + is_stddev_support_arg_type, stddev_return_type, Stddev, StddevPop, +}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; +pub(crate) use variance::{ + is_variance_support_arg_type, variance_return_type, Variance, VariancePop, +}; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/rank.rs b/datafusion/src/physical_plan/expressions/rank.rs index 62adf460dd87..47b36ebfe676 100644 --- a/datafusion/src/physical_plan/expressions/rank.rs +++ b/datafusion/src/physical_plan/expressions/rank.rs @@ -38,6 +38,7 @@ pub struct Rank { } #[derive(Debug, Copy, Clone)] +#[allow(clippy::enum_variant_names)] pub(crate) enum RankType { Rank, DenseRank, diff --git a/datafusion/src/physical_plan/expressions/stats.rs b/datafusion/src/physical_plan/expressions/stats.rs new file mode 100644 index 000000000000..3f2d266622de --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stats.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Enum used for differenciating population and sample for statistical functions +#[derive(Debug, Clone, Copy)] +pub enum StatsType { + /// Population + Population, + /// Sample + Sample, +} diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs new file mode 100644 index 000000000000..2c8538b28ef4 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -0,0 +1,424 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::{format_state_name, StatsType}; + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +#[derive(Debug)] +pub struct Stddev { + name: String, + expr: Arc, +} + +/// STDDEV_POP population aggregate expression +#[derive(Debug)] +pub struct StddevPop { + name: String, + expr: Arc, +} + +/// function return type of standard deviation +pub(crate) fn stddev_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "STDDEV does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. 
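The new `Stddev`/`StddevPop` expressions reuse the variance accumulator and take a square root at evaluation time; `StatsType` only decides whether the accumulated M2 is divided by n or n - 1. A plain-f64 sketch of that relationship with illustrative names (unlike the DataFusion accumulator, which errors on fewer than two values, this sketch simply returns None).

// Standard deviation is the square root of the (sample or population) variance.
fn variance(data: &[f64], sample: bool) -> Option<f64> {
    let n = data.len();
    let divisor = if sample { n.checked_sub(1)? } else { n };
    if divisor == 0 {
        return None;
    }
    let mean = data.iter().sum::<f64>() / n as f64;
    let m2: f64 = data.iter().map(|x| (x - mean) * (x - mean)).sum();
    Some(m2 / divisor as f64)
}

fn stddev(data: &[f64], sample: bool) -> Option<f64> {
    variance(data, sample).map(f64::sqrt)
}

fn main() {
    // Population stddev of 1..=5 is sqrt(2), as the stddev_i32 test expects.
    assert_eq!(stddev(&[1.0, 2.0, 3.0, 4.0, 5.0], false), Some(2.0_f64.sqrt()));
}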
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl StddevPop { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for StddevPop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} +/// An accumulator to compute the average +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type)?, + }) + } +} + +impl Accumulator for StddevAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + self.variance.get_mean(), + self.variance.get_m2(), + ]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + self.variance.update(values) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + self.variance.merge(states) + } + + fn evaluate(&self) -> Result { + let variance = self.variance.evaluate()?; + match variance { + ScalarValue::Float64(e) => { + if e == None { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) + } + } + _ => Err(DataFusionError::Internal( + "Variance should be f64".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn stddev_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(0.5_f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_2() -> 
Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(0.7760297817881877), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_3() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Stddev, + ScalarValue::from(0.9504384952922168), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, + ])); + generic_test_op!( + a, + DataType::UInt32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_f32() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![ + 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, + ])); + generic_test_op!( + a, + DataType::Float32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn test_stddev_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = stddev_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + assert!(stddev_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_stddev_1_input() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn stddev_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + StddevPop, + ScalarValue::from(1.479019945774904), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Int32Vec::from(vec![None, None]).as_arc(); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/sum.rs 
b/datafusion/src/physical_plan/expressions/sum.rs index 59bd3e9dc769..08e0dfe10d8c 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use arrow::compute; use arrow::{ array::*, @@ -31,6 +31,7 @@ use arrow::{ }; use super::format_state_name; +use crate::arrow::array::Array; /// SUM aggregate expression #[derive(Debug)] @@ -50,8 +51,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { Ok(DataType::UInt64) } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), + // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc, + // the result type of floating-point is FLOAT64 with the double precision. + DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10); + Ok(DataType::Decimal(new_precision, *scale)) + } other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", other @@ -72,6 +80,7 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_, _) ) } @@ -105,6 +114,10 @@ impl AggregateExpr for Sum { )) } + fn create_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + } + fn state_fields(&self) -> Result> { Ok(vec![Field::new( &format_state_name(&self.name, "sum"), @@ -117,10 +130,6 @@ impl AggregateExpr for Sum { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) - } - fn name(&self) -> &str { &self.name } @@ -149,9 +158,34 @@ macro_rules! typed_sum_delta_batch { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 +fn sum_decimal_batch( + values: &ArrayRef, + precision: &usize, + scale: &usize, +) -> Result { + let array = values.as_any().downcast_ref::().unwrap(); + + if array.null_count() == array.len() { + return Ok(ScalarValue::Decimal128(None, *precision, *scale)); + } + + let mut result = 0_i128; + for i in 0..array.len() { + if array.is_valid(i) { + result += array.value(i); + } + } + Ok(ScalarValue::Decimal128(Some(result), *precision, *scale)) +} + // sums the array and returns a ScalarValue of its corresponding type. pub(super) fn sum_batch(values: &ArrayRef) -> Result { Ok(match values.data_type() { + DataType::Decimal(precision, scale) => { + sum_decimal_batch(values, precision, scale)? 
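For decimal inputs, `sum_return_type` above widens precision by 10 (Spark's rule) while keeping the scale, saturating at the Decimal128 maximum. A small sketch of that promotion, assuming the cap of 38 implied by `MAX_PRECISION_FOR_DECIMAL128`; the helper name is illustrative only.

// DECIMAL(p, s) sums into DECIMAL(min(38, p + 10), s).
fn promote_decimal_for_sum(precision: usize, scale: usize) -> (usize, usize) {
    const MAX_PRECISION: usize = 38; // assumed value of MAX_PRECISION_FOR_DECIMAL128
    (MAX_PRECISION.min(precision + 10), scale)
}

fn main() {
    // Matches test_sum_return_data_type: Decimal(10, 5) sums into Decimal(20, 5).
    assert_eq!(promote_decimal_for_sum(10, 5), (20, 5));
    // Precision saturates at 38 near the limit.
    assert_eq!(promote_decimal_for_sum(36, 10), (38, 10));
}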
+ } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), @@ -166,7 +200,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", e - ))) + ))); } }) } @@ -183,8 +217,62 @@ macro_rules! typed_sum { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 +fn sum_decimal( + lhs: &Option, + rhs: &Option, + precision: &usize, + scale: &usize, +) -> ScalarValue { + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *scale), + (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale), + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale), + (Some(lhs_value), Some(rhs_value)) => { + ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale) + } + } +} + +fn sum_decimal_with_diff_scale( + lhs: &Option, + rhs: &Option, + precision: &usize, + lhs_scale: &usize, + rhs_scale: &usize, +) -> ScalarValue { + // the lhs_scale must be greater or equal rhs_scale. + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale), + (None, Some(rhs_value)) => { + let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32); + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale), + (Some(lhs_value), Some(rhs_value)) => { + let new_value = + rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value; + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + } +} + pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { + (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + let max_precision = p1.max(p2); + if s1.eq(s2) { + // s1 = s2 + sum_decimal(v1, v2, max_precision, s1) + } else if s1.gt(s2) { + // s1 > s2 + sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) + } else { + // s1 < s2 + sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) + } + } // float64 coerces everything to f64 (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { typed_sum!(lhs, rhs, Float64, f64) @@ -250,16 +338,14 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}", e - ))) + ))); } }) } impl Accumulator for SumAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.sum = sum(&self.sum, &sum_batch(values)?)?; - Ok(()) + fn state(&self) -> Result> { + Ok(vec![self.sum.clone()]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { @@ -268,6 +354,12 @@ impl Accumulator for SumAccumulator { Ok(()) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.sum = sum(&self.sum, &sum_batch(values)?)?; + Ok(()) + } + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { // sum(sum1, sum2) = sum1 + sum2 self.update(states) @@ -278,11 +370,9 @@ impl Accumulator for SumAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { - Ok(vec![self.sum.clone()]) - } - fn evaluate(&self) -> Result { + // TODO: add the checker for overflow + // For the decimal(precision,_) data type, the absolute of value must be less than 
10^precision. Ok(self.sum.clone()) } } @@ -295,6 +385,139 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn test_sum_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(20, 5), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 10), result_type); + Ok(()) + } + + #[test] + fn sum_decimal() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result); + // test sum decimal with diff scale + let left = ScalarValue::Decimal128(Some(123), 10, 3); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!( + ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3), + result + ); + // diff precision and scale for decimal data type + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 11, 3); + let result = sum(&left, &right); + assert_eq!( + ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3), + result.unwrap() + ); + + // test sum batch + let mut decimal_builder = Int128Vec::with_capacity(5); + for i in 1..6 { + decimal_builder.push(Some(i as i128)); + } + let array: ArrayRef = decimal_builder.as_arc(); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); + + // test agg + let mut decimal_builder = Int128Vec::with_capacity(5); + for i in 1..6 { + decimal_builder.push(Some(i as i128)); + } + let array: ArrayRef = decimal_builder.as_arc(); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(Some(15), 20, 0), + DataType::Decimal(20, 0) + ) + } + + #[test] + fn sum_decimal_with_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(Some(123), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result); + + // test with batch + let mut decimal_builder = Int128Vec::with_capacity(5); + for i in 1..6 { + if i == 2 { + decimal_builder.push_null(); + } else { + decimal_builder.push(Some(i)); + } + } + let array: ArrayRef = decimal_builder.as_arc(); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); + + // test agg + let mut decimal_builder = Int128Vec::with_capacity(5); + for i in 1..6 { + if i == 2 { + decimal_builder.push_null(); + } else { + decimal_builder.push(Some(i)); + } + } + let array: ArrayRef = decimal_builder.as_arc(); + generic_test_op!( + array, + DataType::Decimal(35, 0), + Sum, + ScalarValue::Decimal128(Some(13), 38, 0), + DataType::Decimal(38, 0) + ) + } + + #[test] + fn sum_decimal_all_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(None, 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 2), result); + + // test with batch + let mut decimal_builder = Int128Vec::with_capacity(5); + for _i in 1..6 { + decimal_builder.push_null(); + } + let array: ArrayRef = decimal_builder.as_arc(); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), 
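Two behaviours of the decimal sum path above are worth restating: NULLs are skipped, so only an all-NULL input stays NULL, and operands with different scales are aligned by scaling the smaller-scale value up by a power of ten before the raw i128 representations are added. A self-contained sketch with illustrative names.

// 1) NULL handling, mirroring sum_decimal's match arms.
fn sum_opt(lhs: Option<i128>, rhs: Option<i128>) -> Option<i128> {
    match (lhs, rhs) {
        (None, None) => None,
        (Some(a), None) | (None, Some(a)) => Some(a),
        (Some(a), Some(b)) => Some(a + b),
    }
}

// 2) Scale alignment, mirroring sum_decimal_with_diff_scale: the smaller-scale
//    value is multiplied by 10^(scale difference) so both share one scale.
fn add_decimals(lhs: i128, lhs_scale: u32, rhs: i128, rhs_scale: u32) -> (i128, u32) {
    if lhs_scale >= rhs_scale {
        (lhs + rhs * 10_i128.pow(lhs_scale - rhs_scale), lhs_scale)
    } else {
        (rhs + lhs * 10_i128.pow(rhs_scale - lhs_scale), rhs_scale)
    }
}

fn main() {
    // NULL + 1.23 = 1.23, as in the sum_decimal_with_nulls test.
    assert_eq!(sum_opt(None, Some(123)), Some(123));
    assert_eq!(sum_opt(None, None), None);
    // 0.123 (raw 123, scale 3) + 1.24 (raw 124, scale 2) = 1.363 (raw 1363, scale 3),
    // matching the mixed-scale case in the sum_decimal test.
    assert_eq!(add_decimals(123, 3, 124, 2), (1363, 3));
}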
result); + + // test agg + let mut decimal_builder = Int128Vec::with_capacity(5); + for _i in 1..6 { + decimal_builder.push_null(); + } + let array: ArrayRef = decimal_builder.as_arc(); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(None, 20, 0), + DataType::Decimal(20, 0) + ) + } + #[test] fn sum_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs new file mode 100644 index 000000000000..1786c388e758 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -0,0 +1,528 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::{format_state_name, StatsType}; + +/// VAR and VAR_SAMP aggregate expression +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc, +} + +/// VAR_POP aggregate expression +#[derive(Debug)] +pub struct VariancePop { + name: String, + expr: Arc, +} + +/// function return type of variance +pub(crate) fn variance_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "VARIANCE does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl VariancePop { + /// Create a new VAR_POP aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for VariancePop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute variance +/// The algrithm used is an online implementation and numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. 
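`VarianceAccumulator` below keeps (count, mean, M2) and applies Welford's online recurrence per value, so variance is computed in a single pass in a numerically stable way; `merge` then combines partial states from different partitions. A plain-f64 sketch of the per-row update and the population finalisation; the struct and method names are illustrative, and the DataFusion version additionally works on `ScalarValue` and supports the sample divisor n - 1.

// Welford's recurrence: fold each value in with two deltas against the old and
// the updated mean, accumulating M2 (the sum of squared deviations).
#[derive(Default)]
struct Welford {
    count: u64,
    mean: f64,
    m2: f64,
}

impl Welford {
    fn update(&mut self, x: f64) {
        self.count += 1;
        let delta1 = x - self.mean;            // distance from the old mean
        self.mean += delta1 / self.count as f64;
        let delta2 = x - self.mean;            // distance from the updated mean
        self.m2 += delta1 * delta2;
    }

    /// Population variance (M2 / n); the sample flavour divides by n - 1 instead.
    fn variance_pop(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.m2 / self.count as f64)
        }
    }
}

fn main() {
    let mut acc = Welford::default();
    for x in [1.0, 2.0, 3.0, 4.0, 5.0] {
        acc.update(x);
    }
    // Matches the variance_f64_2 test: VAR_POP(1..=5) = 2.
    assert_eq!(acc.variance_pop(), Some(2.0));
}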
+ +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: ScalarValue, + mean: ScalarValue, + count: u64, + stats_type: StatsType, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + m2: ScalarValue::from(0 as f64), + mean: ScalarValue::from(0 as f64), + count: 0, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> ScalarValue { + self.mean.clone() + } + + pub fn get_m2(&self) -> ScalarValue { + self.m2.clone() + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + self.mean.clone(), + self.m2.clone(), + ]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let values = &values[0]; + let is_empty = values.is_null(); + + if !is_empty { + let new_count = self.count + 1; + let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; + let new_mean = ScalarValue::add( + &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, + &self.mean, + )?; + let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?; + let tmp = ScalarValue::mul(&delta1, &delta2)?; + + let new_m2 = ScalarValue::add(&self.m2, &tmp)?; + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + let count = &states[0]; + let mean = &states[1]; + let m2 = &states[2]; + let mut new_count: u64 = self.count; + + // counts are summed + if let ScalarValue::UInt64(Some(c)) = count { + if *c == 0_u64 { + return Ok(()); + } + + if self.count == 0 { + self.count = *c; + self.mean = mean.clone(); + self.m2 = m2.clone(); + return Ok(()); + } + new_count += c + } else { + unreachable!() + }; + + let new_mean = ScalarValue::div( + &ScalarValue::add(&self.mean, mean)?, + &ScalarValue::from(2_f64), + )?; + let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; + let delta_sqrt = ScalarValue::mul(&delta, &delta)?; + let new_m2 = ScalarValue::add( + &ScalarValue::add( + &ScalarValue::mul( + &delta_sqrt, + &ScalarValue::div( + &ScalarValue::mul(&ScalarValue::from(self.count), count)?, + &ScalarValue::from(new_count as f64), + )?, + )?, + &self.m2, + )?, + m2, + )?; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + + Ok(()) + } + + fn evaluate(&self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + if count <= 1 { + return Err(DataFusionError::Internal( + "At least two values are needed to calculate variance".to_string(), + )); + } + + match self.m2 { + ScalarValue::Float64(e) => { + if self.count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f / count as f64))) + } + } + _ => Err(DataFusionError::Internal( + "M2 should be f64 for variance".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn variance_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + VariancePop, + ScalarValue::from(0.25_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_2() -> 
Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + generic_test_op!( + a, + DataType::Float64, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_3() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(2.5_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(0.9033333333333333_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, + ])); + generic_test_op!( + a, + DataType::UInt32, + VariancePop, + ScalarValue::from(2.0f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f32() -> Result<()> { + let a: ArrayRef = + Float32Vec::from_slice(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]).as_arc(); + generic_test_op!( + a, + DataType::Float32, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn test_variance_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = variance_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + assert!(variance_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_variance_1_input() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn variance_i32_with_nulls() -> Result<()> { + let a: ArrayRef = + Int32Vec::from(vec![Some(1), None, Some(3), Some(4), Some(5)]).as_arc(); + generic_test_op!( + a, + DataType::Int32, + VariancePop, + ScalarValue::from(2.1875f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Int32Vec::from(vec![None, None]).as_arc(); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index b50c0a082686..38be1142c4b7 100644 --- 
a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -18,14 +18,13 @@ //! Execution plan for reading line-delimited Avro files #[cfg(feature = "avro")] use crate::avro_to_arrow; +#[cfg(feature = "avro")] +use crate::datasource::object_store::ReadSeek; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use arrow::datatypes::SchemaRef; -#[cfg(feature = "avro")] -use arrow::error::ArrowError; - use async_trait::async_trait; use std::any::Any; use std::sync::Arc; @@ -106,19 +105,16 @@ impl ExecutionPlan for AvroExec { let file_schema = Arc::clone(&self.base_config.file_schema); // The avro reader cannot limit the number of records, so `remaining` is ignored. - let fun = move |file, _remaining: &Option| { - let reader_res = avro_to_arrow::Reader::try_new( - file, - Arc::clone(&file_schema), - batch_size, - proj.clone(), - ); - match reader_res { - Ok(r) => Box::new(r) as BatchIter, - Err(e) => Box::new( - vec![Err(ArrowError::ExternalError(Box::new(e)))].into_iter(), - ), + let fun = move |file: Box, + _remaining: &Option| { + let mut builder = avro_to_arrow::ReaderBuilder::new() + .with_batch_size(batch_size) + .with_schema(file_schema.clone()); + if let Some(proj) = proj.clone() { + builder = builder.with_projection(proj); } + let reader = builder.build(file).unwrap(); + Box::new(reader) as BatchIter }; Ok(Box::pin(FileStream::new( @@ -238,7 +234,7 @@ mod tests { projection: Some(vec![0, 1, file_schema.fields().len(), 2]), object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![partitioned_file]], - file_schema: file_schema, + file_schema, statistics: Statistics::default(), batch_size: 1024, limit: None, diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index fff1877ecb46..ac517bc63df7 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -27,7 +27,7 @@ use arrow::error::Result as ArrowResult; use arrow::io::json; use arrow::record_batch::RecordBatch; use std::any::Any; -use std::io::Read; +use std::io::{BufRead, BufReader, Read}; use std::sync::Arc; use super::file_stream::{BatchIter, FileStream}; @@ -56,14 +56,37 @@ impl NdJsonExec { // TODO: implement iterator in upstream json::Reader type struct JsonBatchReader { - reader: json::Reader, + reader: R, + schema: SchemaRef, + batch_size: usize, + proj: Option>, } -impl Iterator for JsonBatchReader { +impl Iterator for JsonBatchReader { type Item = ArrowResult; fn next(&mut self) -> Option { - self.reader.next().transpose() + // json::read::read_rows iterates on the empty vec and reads at most n rows + let mut rows: Vec = Vec::with_capacity(self.batch_size); + let read = json::read::read_rows(&mut self.reader, rows.as_mut_slice()); + read.and_then(|records_read| { + if records_read > 0 { + let fields = if let Some(proj) = &self.proj { + self.schema + .fields + .iter() + .filter(|f| proj.contains(&f.name)) + .cloned() + .collect() + } else { + self.schema.fields.clone() + }; + json::read::deserialize(&rows, fields).map(Some) + } else { + Ok(None) + } + }) + .transpose() } } @@ -108,12 +131,10 @@ impl ExecutionPlan for NdJsonExec { // The json reader cannot limit the number of records, so `remaining` is ignored. 
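The reworked `JsonBatchReader` above is meant to pull at most `batch_size` newline-delimited JSON rows per call to `next` and return None at end of input, delegating the actual Arrow deserialization to arrow2. A generic sketch of that batching loop over any `BufRead` source; the deserialization step is deliberately elided and the helper name is illustrative only.

use std::io::BufRead;

// Read up to `batch_size` non-empty NDJSON lines; None signals end of input.
fn next_batch<R: BufRead>(
    reader: &mut R,
    batch_size: usize,
) -> std::io::Result<Option<Vec<String>>> {
    let mut rows = Vec::with_capacity(batch_size);
    for _ in 0..batch_size {
        let mut line = String::new();
        if reader.read_line(&mut line)? == 0 {
            break; // EOF
        }
        if !line.trim().is_empty() {
            rows.push(line);
        }
    }
    Ok(if rows.is_empty() { None } else { Some(rows) })
}

fn main() -> std::io::Result<()> {
    let data = "{\"a\":1}\n{\"a\":2}\n{\"a\":3}\n";
    let mut reader = std::io::Cursor::new(data);
    assert_eq!(next_batch(&mut reader, 2)?.map(|b| b.len()), Some(2));
    assert_eq!(next_batch(&mut reader, 2)?.map(|b| b.len()), Some(1));
    assert!(next_batch(&mut reader, 2)?.is_none());
    Ok(())
}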
let fun = move |file, _remaining: &Option| { Box::new(JsonBatchReader { - reader: json::Reader::new( - file, - Arc::clone(&file_schema), - batch_size, - proj.clone(), - ), + reader: BufReader::new(file), + schema: file_schema.clone(), + batch_size, + proj: proj.clone(), }) as BatchIter }; diff --git a/datafusion/src/physical_plan/file_format/mod.rs b/datafusion/src/physical_plan/file_format/mod.rs index f640e3df9145..0d372810985d 100644 --- a/datafusion/src/physical_plan/file_format/mod.rs +++ b/datafusion/src/physical_plan/file_format/mod.rs @@ -54,7 +54,7 @@ use super::{ColumnStatistics, Statistics}; lazy_static! { /// The datatype used for all partitioning columns for now pub static ref DEFAULT_PARTITION_COLUMN_DATATYPE: DataType = - DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::Utf8)); + DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::Utf8), false); } /// The base configurations to provide when creating a physical plan for diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 15c85d11bea2..e62ecb453a56 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -71,8 +71,8 @@ pub struct ParquetExec { projected_schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Optional predicate builder - predicate_builder: Option, + /// Optional predicate for pruning row groups + pruning_predicate: Option, } /// Stores metrics about the parquet execution for a particular parquet file @@ -95,12 +95,12 @@ impl ParquetExec { let predicate_creation_errors = MetricBuilder::new(&metrics).global_counter("num_predicate_creation_errors"); - let predicate_builder = predicate.and_then(|predicate_expr| { + let pruning_predicate = predicate.and_then(|predicate_expr| { match PruningPredicate::try_new( &predicate_expr, base_config.file_schema.clone(), ) { - Ok(predicate_builder) => Some(predicate_builder), + Ok(pruning_predicate) => Some(pruning_predicate), Err(e) => { debug!( "Could not create pruning predicate for {:?}: {}", @@ -119,7 +119,7 @@ impl ParquetExec { projected_schema, projected_statistics, metrics, - predicate_builder, + pruning_predicate, } } @@ -199,7 +199,7 @@ impl ExecutionPlan for ParquetExec { Some(proj) => proj, None => (0..self.base_config.file_schema.fields().len()).collect(), }; - let predicate_builder = self.predicate_builder.clone(); + let pruning_predicate = self.pruning_predicate.clone(); let batch_size = self.base_config.batch_size; let limit = self.base_config.limit; let object_store = Arc::clone(&self.base_config.object_store); @@ -215,7 +215,7 @@ impl ExecutionPlan for ParquetExec { partition, metrics, &projection, - &predicate_builder, + &pruning_predicate, batch_size, response_tx, limit, @@ -341,12 +341,7 @@ macro_rules! 
get_min_max_values { }; let data_type = field.data_type(); - let null_scalar: ScalarValue = if let Ok(v) = data_type.try_into() { - v - } else { - // DataFusion doesn't have support for ScalarValues of the column type - return None - }; + let null_scalar: ScalarValue = data_type.try_into().ok()?; let scalar_values : Vec = $self.row_group_metadata .iter() @@ -382,17 +377,17 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { } fn build_row_group_predicate( - predicate_builder: &PruningPredicate, + pruning_predicate: &PruningPredicate, metrics: ParquetFileMetrics, row_group_metadata: &[RowGroupMetaData], ) -> Box bool> { - let parquet_schema = predicate_builder.schema().as_ref(); + let parquet_schema = pruning_predicate.schema().as_ref(); let pruning_stats = RowGroupPruningStatistics { row_group_metadata, parquet_schema, }; - let predicate_values = predicate_builder.prune(&pruning_stats); + let predicate_values = pruning_predicate.prune(&pruning_stats); match predicate_values { Ok(values) => { @@ -418,7 +413,7 @@ fn read_partition( partition: Vec, metrics: ExecutionPlanMetricsSet, projection: &[usize], - predicate_builder: &Option, + pruning_predicate: &Option, _batch_size: usize, response_tx: Sender>, limit: Option, @@ -440,10 +435,9 @@ fn read_partition( None, None, )?; - if let Some(predicate_builder) = predicate_builder { - let _file_metadata = record_reader.metadata(); + if let Some(pruning_predicate) = pruning_predicate { record_reader.set_groups_filter(Arc::new(build_row_group_predicate( - predicate_builder, + pruning_predicate, file_metrics, &record_reader.metadata().row_groups, ))); @@ -478,6 +472,7 @@ mod tests { use futures::StreamExt; use parquet::metadata::ColumnChunkMetaData; use parquet::statistics::Statistics as ParquetStatistics; + use parquet_format_async_temp::RowGroup; #[tokio::test] async fn parquet_exec_with_projection() -> Result<()> { @@ -601,12 +596,12 @@ mod tests { } #[test] - fn row_group_predicate_builder_simple_expr() -> Result<()> { + fn row_group_pruning_predicate_simple_expr() -> Result<()> { use crate::logical_plan::{col, lit}; // int > 1 => c1_max > 1 let expr = col("c1").gt(lit(15)); let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let predicate_builder = + let pruning_predicate = PruningPredicate::try_new(&expr, Arc::new(schema.clone()))?; let schema_descr = to_parquet_schema(&schema)?; @@ -632,7 +627,7 @@ mod tests { ); let row_group_metadata = vec![rgm1, rgm2]; let row_group_predicate = build_row_group_predicate( - &predicate_builder, + &pruning_predicate, parquet_file_metrics(), &row_group_metadata, ); @@ -647,12 +642,12 @@ mod tests { } #[test] - fn row_group_predicate_builder_missing_stats() -> Result<()> { + fn row_group_pruning_predicate_missing_stats() -> Result<()> { use crate::logical_plan::{col, lit}; // int > 1 => c1_max > 1 let expr = col("c1").gt(lit(15)); let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let predicate_builder = + let pruning_predicate = PruningPredicate::try_new(&expr, Arc::new(schema.clone()))?; let schema_descr = to_parquet_schema(&schema)?; @@ -678,7 +673,7 @@ mod tests { ); let row_group_metadata = vec![rgm1, rgm2]; let row_group_predicate = build_row_group_predicate( - &predicate_builder, + &pruning_predicate, parquet_file_metrics(), &row_group_metadata, ); @@ -695,7 +690,7 @@ mod tests { } #[test] - fn row_group_predicate_builder_partial_expr() -> Result<()> { + fn row_group_pruning_predicate_partial_expr() -> Result<()> { use 
crate::logical_plan::{col, lit}; // test row group predicate with partially supported expression // int > 1 and int % 2 => c1_max > 1 and true @@ -704,7 +699,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ])); - let predicate_builder = PruningPredicate::try_new(&expr, schema.clone())?; + let pruning_predicate = PruningPredicate::try_new(&expr, schema.clone())?; let schema_descr = to_parquet_schema(&schema)?; let rgm1 = get_row_group_meta_data( @@ -747,7 +742,7 @@ mod tests { ); let row_group_metadata = vec![rgm1, rgm2]; let row_group_predicate = build_row_group_predicate( - &predicate_builder, + &pruning_predicate, parquet_file_metrics(), &row_group_metadata, ); @@ -763,9 +758,9 @@ mod tests { // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").modulus(lit(2))); - let predicate_builder = PruningPredicate::try_new(&expr, schema)?; + let pruning_predicate = PruningPredicate::try_new(&expr, schema)?; let row_group_predicate = build_row_group_predicate( - &predicate_builder, + &pruning_predicate, parquet_file_metrics(), &row_group_metadata, ); @@ -779,9 +774,8 @@ mod tests { Ok(()) } - #[ignore] - #[allow(dead_code)] - fn row_group_predicate_builder_null_expr() -> Result<()> { + #[test] + fn row_group_pruning_predicate_null_expr() -> Result<()> { use crate::logical_plan::{col, lit}; // test row group predicate with an unknown (Null) expr // @@ -793,7 +787,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); - let predicate_builder = PruningPredicate::try_new(&expr, schema.clone())?; + let pruning_predicate = PruningPredicate::try_new(&expr, schema.clone())?; let schema_descr = to_parquet_schema(&schema)?; let rgm1 = get_row_group_meta_data( @@ -834,7 +828,7 @@ mod tests { ); let row_group_metadata = vec![rgm1, rgm2]; let row_group_predicate = build_row_group_predicate( - &predicate_builder, + &pruning_predicate, parquet_file_metrics(), &row_group_metadata, ); @@ -858,6 +852,7 @@ mod tests { use parquet::schema::types::{physical_type_to_type, ParquetType}; use parquet_format_async_temp::{ColumnChunk, ColumnMetaData}; + let mut chunks = vec![]; let mut columns = vec![]; for (i, s) in column_statistics.into_iter().enumerate() { let column_descr = schema_descr.column(i); @@ -895,9 +890,15 @@ mod tests { crypto_metadata: None, encrypted_column_metadata: None, }; - let column = ColumnChunkMetaData::new(column_chunk, column_descr.clone()); + let column = ColumnChunkMetaData::try_from_thrift( + column_descr.clone(), + column_chunk.clone(), + ) + .unwrap(); columns.push(column); + chunks.push(column_chunk); } - RowGroupMetaData::new(columns, 1000, 2000) + let rg = RowGroup::new(chunks, 0, 0, None, None, None, None); + RowGroupMetaData::try_from_thrift(schema_descr, rg).unwrap() } } diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 1ca9231a0bbb..155f391d4c04 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -612,7 +612,10 @@ pub fn return_type( BuiltinScalarFunction::ToTimestampSeconds => { Ok(DataType::Timestamp(TimeUnit::Second, None)) } - BuiltinScalarFunction::Now => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), + BuiltinScalarFunction::Now => Ok(DataType::Timestamp( + TimeUnit::Nanosecond, + 
Some("UTC".to_owned()), + )), BuiltinScalarFunction::Translate => { utf8_to_str_type(&input_expr_types[0], "translate") } diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 932c76bf894f..90608db172d5 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -39,7 +39,6 @@ use crate::{ use arrow::{ array::*, - buffer::MutableBuffer, compute::{cast, concatenate, take}, datatypes::{DataType, Field, Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, @@ -424,7 +423,7 @@ fn group_aggregate_batch( } // Collect all indices + offsets based on keys in this vec - let mut batch_indices = MutableBuffer::::new(); + let mut batch_indices = Vec::::new(); let mut offsets = vec![0]; let mut offset_so_far = 0; for group_idx in groups_with_rows.iter() { diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 1b0f906bcf5e..07144d74a34d 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -29,10 +29,10 @@ use async_trait::async_trait; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::sync::Mutex; +use arrow::array::*; use arrow::datatypes::*; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use arrow::{array::*, buffer::MutableBuffer}; use arrow::compute::take; @@ -56,6 +56,8 @@ use super::{ }; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; +use arrow::bitmap::MutableBitmap; +use arrow::buffer::Buffer; use log::debug; use std::fmt; @@ -389,9 +391,9 @@ impl ExecutionPlan for HashJoinExec { let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { - vec![false; num_rows] + MutableBitmap::from_iter((0..num_rows).map(|_| false)) } - JoinType::Inner | JoinType::Right => vec![], + JoinType::Inner | JoinType::Right => MutableBitmap::with_capacity(0), }; Ok(Box::pin(HashJoinStream::new( self.schema.clone(), @@ -490,8 +492,7 @@ struct HashJoinStream { /// Random state used for hashing initialization random_state: RandomState, /// Keeps track of the left side rows whether they are visited - visited_left_side: Vec, - // TODO: use a more memory efficient data structure, https://github.com/apache/arrow-datafusion/issues/240 + visited_left_side: MutableBitmap, /// There is nothing to process anymore and left side is processed in case of left join is_exhausted: bool, /// Metrics @@ -513,7 +514,7 @@ impl HashJoinStream { right: SendableRecordBatchStream, column_indices: Vec, random_state: RandomState, - visited_left_side: Vec, + visited_left_side: MutableBitmap, join_metrics: HashJoinMetrics, null_equals_null: bool, ) -> Self { @@ -665,8 +666,8 @@ fn build_join_indexes( match join_type { JoinType::Inner | JoinType::Semi | JoinType::Anti => { // Using a buffer builder to avoid slower normal builder - let mut left_indices = MutableBuffer::::new(); - let mut right_indices = MutableBuffer::::new(); + let mut left_indices = Vec::::new(); + let mut right_indices = Vec::::new(); // Visit all of the right rows for (row, hash_value) in hash_values.iter().enumerate() { @@ -708,8 +709,8 @@ fn build_join_indexes( )) } JoinType::Left => { - let mut left_indices = MutableBuffer::::new(); - let mut right_indices = MutableBuffer::::new(); + let mut left_indices = Vec::::new(); + let mut right_indices = Vec::::new(); // 
First visit all of the rows for (row, hash_value) in hash_values.iter().enumerate() { @@ -867,32 +868,26 @@ fn equal_rows( // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( - visited_left_side: &[bool], + visited_left_side: &MutableBitmap, schema: &SchemaRef, column_indices: &[ColumnIndex], left_data: &JoinLeftData, unmatched: bool, ) -> ArrowResult { - // Find indices which didn't match any right row (are false) let indices = if unmatched { - visited_left_side - .iter() - .enumerate() - .filter(|&(_, &value)| !value) - .map(|(index, _)| index as u64) - .collect::>() + Buffer::from_iter( + (0..visited_left_side.len()) + .filter_map(|v| (!visited_left_side.get(v)).then(|| v as u64)), + ) } else { - // produce those that did match - visited_left_side - .iter() - .enumerate() - .filter(|&(_, &value)| value) - .map(|(index, _)| index as u64) - .collect::>() + Buffer::from_iter( + (0..visited_left_side.len()) + .filter_map(|v| (visited_left_side.get(v)).then(|| v as u64)), + ) }; // generate batches by taking values from the left side and generating columns filled with null on the right side - let indices = UInt64Array::from_data(DataType::UInt64, indices.into(), None); + let indices = UInt64Array::from_data(DataType::UInt64, indices, None); let num_rows = indices.len(); let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); @@ -949,7 +944,7 @@ impl Stream for HashJoinStream { | JoinType::Semi | JoinType::Anti => { left_side.iter().flatten().for_each(|x| { - self.visited_left_side[*x as usize] = true; + self.visited_left_side.set(*x as usize, true); }); } JoinType::Inner | JoinType::Right => {} diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index b334c5f2f7c0..4365c8af0a4c 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -17,498 +17,515 @@ //! Functionality used both on logical and physical plans -use crate::error::{DataFusionError, Result}; +use crate::error::Result; pub use ahash::{CallHasher, RandomState}; -use arrow::array::{ - Array, ArrayRef, BooleanArray, DictionaryArray, DictionaryKey, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, Utf8Array, -}; -use arrow::datatypes::{DataType, IntegerType, TimeUnit}; -use std::sync::Arc; - -// Combines two hashes into one hash -#[inline] -fn combine_hashes(l: u64, r: u64) -> u64 { - let hash = (17 * 37u64).wrapping_add(l); - hash.wrapping_mul(37).wrapping_add(r) -} +use arrow::array::ArrayRef; -macro_rules! 
hash_array { - ($array_type:ty, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - if array.null_count() == 0 { - if $multi_col { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = combine_hashes( - $ty::get_hash(&array.value(i), $random_state), - *hash, - ); +#[cfg(not(feature = "force_hash_collisions"))] +mod noforce_hash_collisions { + use super::{ArrayRef, CallHasher, RandomState, Result}; + use crate::error::DataFusionError; + use arrow::array::{Array, DictionaryArray, DictionaryKey}; + use arrow::array::{ + BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, + }; + use arrow::datatypes::{DataType, IntegerType, TimeUnit}; + use std::sync::Arc; + + type StringArray = Utf8Array; + type LargeStringArray = Utf8Array; + + macro_rules! hash_array_float { + ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + if $multi_col { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = combine_hashes( + $ty::get_hash( + &$ty::from_le_bytes(value.to_le_bytes()), + $random_state, + ), + *hash, + ); + } + } else { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $ty::get_hash( + &$ty::from_le_bytes(value.to_le_bytes()), + $random_state, + ) + } } } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = $ty::get_hash(&array.value(i), $random_state); + if $multi_col { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = combine_hashes( + $ty::get_hash( + &$ty::from_le_bytes(value.to_le_bytes()), + $random_state, + ), + *hash, + ); + } + } + } else { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = $ty::get_hash( + &$ty::from_le_bytes(value.to_le_bytes()), + $random_state, + ); + } + } } } - } else { - if $multi_col { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { + }; + } + + macro_rules! hash_array { + ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { *hash = combine_hashes( $ty::get_hash(&array.value(i), $random_state), *hash, ); } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $ty::get_hash(&array.value(i), $random_state); + } } } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = $ty::get_hash(&array.value(i), $random_state); + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = combine_hashes( + $ty::get_hash(&array.value(i), $random_state), + *hash, + ); + } + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $ty::get_hash(&array.value(i), $random_state); + } } } } - } - }; -} + }; + } -macro_rules! 
hash_array_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - let values = array.values(); + macro_rules! hash_array_primitive { + ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); - if array.null_count() == 0 { - if $multi_col { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = combine_hashes($ty::get_hash(value, $random_state), *hash); - } - } else { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = $ty::get_hash(value, $random_state) - } - } - } else { - if $multi_col { - for (i, (hash, value)) in - $hashes.iter_mut().zip(values.iter()).enumerate() - { - if !array.is_null(i) { + if array.null_count() == 0 { + if $multi_col { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { *hash = combine_hashes($ty::get_hash(value, $random_state), *hash); } - } - } else { - for (i, (hash, value)) in - $hashes.iter_mut().zip(values.iter()).enumerate() - { - if !array.is_null(i) { - *hash = $ty::get_hash(value, $random_state); + } else { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $ty::get_hash(value, $random_state) } } - } - } - }; -} - -macro_rules! hash_array_float { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - let values = array.values(); - - if array.null_count() == 0 { - if $multi_col { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = combine_hashes( - $ty::get_hash( - &$ty::from_le_bytes(value.to_le_bytes()), - $random_state, - ), - *hash, - ); - } } else { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = $ty::get_hash( - &$ty::from_le_bytes(value.to_le_bytes()), - $random_state, - ) - } - } - } else { - if $multi_col { - for (i, (hash, value)) in - $hashes.iter_mut().zip(values.iter()).enumerate() - { - if !array.is_null(i) { - *hash = combine_hashes( - $ty::get_hash( - &$ty::from_le_bytes(value.to_le_bytes()), - $random_state, - ), - *hash, - ); + if $multi_col { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = combine_hashes( + $ty::get_hash(value, $random_state), + *hash, + ); + } } - } - } else { - for (i, (hash, value)) in - $hashes.iter_mut().zip(values.iter()).enumerate() - { - if !array.is_null(i) { - *hash = $ty::get_hash( - &$ty::from_le_bytes(value.to_le_bytes()), - $random_state, - ); + } else { + for (i, (hash, value)) in + $hashes.iter_mut().zip(values.iter()).enumerate() + { + if !array.is_null(i) { + *hash = $ty::get_hash(value, $random_state); + } } } } - } - }; -} - -/// Hash the values in a dictionary array -fn create_hashes_dictionary( - array: &ArrayRef, - random_state: &RandomState, - hashes_buffer: &mut Vec, - multi_col: bool, -) -> Result<()> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // Hash each dictionary value once, and then use that computed - // hash for each key value to avoid a potentially expensive - // redundant hashing for large dictionary elements (e.g. 
strings) - let dict_values = Arc::clone(dict_array.values()); - let mut dict_hashes = vec![0; dict_values.len()]; - create_hashes(&[dict_values], random_state, &mut dict_hashes)?; - - // combine hash for each index in values - if multi_col { - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key - .to_usize() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, dict_array.data_type() - )) - })?; - *hash = combine_hashes(dict_hashes[idx], *hash) - } // no update for Null, consistent with other hashes - } - } else { - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key - .to_usize() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] - } // no update for Null, consistent with other hashes - } + }; } - Ok(()) -} -/// Test version of `create_hashes` that produces the same value for -/// all hashes (to test collisions) -/// -/// See comments on `hashes_buffer` for more details -#[cfg(feature = "force_hash_collisions")] -pub fn create_hashes<'a>( - _arrays: &[ArrayRef], - _random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for hash in hashes_buffer.iter_mut() { - *hash = 0 + // Combines two hashes into one hash + #[inline] + fn combine_hashes(l: u64, r: u64) -> u64 { + let hash = (17 * 37u64).wrapping_add(l); + hash.wrapping_mul(37).wrapping_add(r) } - return Ok(hashes_buffer); -} -/// Creates hash values for every row, based on the values in the -/// columns. -/// -/// The number of rows to hash is determined by `hashes_buffer.len()`. 
-/// `hashes_buffer` should be pre-sized appropriately -#[cfg(not(feature = "force_hash_collisions"))] -pub fn create_hashes<'a>( - arrays: &[ArrayRef], - random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - // combine hashes with `combine_hashes` if we have more than 1 column - let multi_col = arrays.len() > 1; - - for col in arrays { - match col.data_type() { - DataType::UInt8 => { - hash_array_primitive!( - UInt8Array, - col, - u8, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::UInt16 => { - hash_array_primitive!( - UInt16Array, - col, - u16, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::UInt32 => { - hash_array_primitive!( - UInt32Array, - col, - u32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::UInt64 => { - hash_array_primitive!( - UInt64Array, - col, - u64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Int8 => { - hash_array_primitive!( - Int8Array, - col, - i8, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Int16 => { - hash_array_primitive!( - Int16Array, - col, - i16, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Int32 => { - hash_array_primitive!( - Int32Array, - col, - i32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Int64 => { - hash_array_primitive!( - Int64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); + /// Hash the values in a dictionary array + fn create_hashes_dictionary( + array: &ArrayRef, + random_state: &RandomState, + hashes_buffer: &mut Vec, + multi_col: bool, + ) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // Hash each dictionary value once, and then use that computed + // hash for each key value to avoid a potentially expensive + // redundant hashing for large dictionary elements (e.g. 
strings) + let dict_values = Arc::clone(dict_array.values()); + let mut dict_hashes = vec![0; dict_values.len()]; + create_hashes(&[dict_values], random_state, &mut dict_hashes)?; + + // combine hash for each index in values + if multi_col { + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key + .to_usize() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, dict_array.data_type() + )) + })?; + *hash = combine_hashes(dict_hashes[idx], *hash) + } // no update for Null, consistent with other hashes } - DataType::Float32 => { - hash_array_float!( - Float32Array, - col, - u32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Float64 => { - hash_array_float!( - Float64Array, - col, - u64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Millisecond, None) => { - hash_array_primitive!( - Int64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Microsecond, None) => { - hash_array_primitive!( - Int64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Timestamp(TimeUnit::Nanosecond, None) => { - hash_array_primitive!( - Int64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Date32 => { - hash_array_primitive!( - Int32Array, - col, - i32, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Date64 => { - hash_array_primitive!( - Int64Array, - col, - i64, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Boolean => { - hash_array!( - BooleanArray, - col, - u8, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::Utf8 => { - hash_array!( - Utf8Array::, - col, - str, - hashes_buffer, - random_state, - multi_col - ); - } - DataType::LargeUtf8 => { - hash_array!( - Utf8Array::, - col, - str, - hashes_buffer, - random_state, - multi_col - ); + } else { + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key + .to_usize() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes } - DataType::Dictionary(index_type, _) => match index_type { - IntegerType::Int8 => { - create_hashes_dictionary::( + } + Ok(()) + } + + /// Creates hash values for every row, based on the values in the + /// columns. + /// + /// The number of rows to hash is determined by `hashes_buffer.len()`. 
+ /// `hashes_buffer` should be pre-sized appropriately + pub fn create_hashes<'a>( + arrays: &[ArrayRef], + random_state: &RandomState, + hashes_buffer: &'a mut Vec, + ) -> Result<&'a mut Vec> { + // combine hashes with `combine_hashes` if we have more than 1 column + let multi_col = arrays.len() > 1; + + for col in arrays { + match col.data_type() { + DataType::UInt8 => { + hash_array_primitive!( + UInt8Array, col, + u8, + hashes_buffer, random_state, + multi_col + ); + } + DataType::UInt16 => { + hash_array_primitive!( + UInt16Array, + col, + u16, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::Int16 => { - create_hashes_dictionary::( + DataType::UInt32 => { + hash_array_primitive!( + UInt32Array, col, + u32, + hashes_buffer, random_state, + multi_col + ); + } + DataType::UInt64 => { + hash_array_primitive!( + UInt64Array, + col, + u64, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::Int32 => { - create_hashes_dictionary::( + DataType::Int8 => { + hash_array_primitive!( + Int8Array, col, + i8, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Int16 => { + hash_array_primitive!( + Int16Array, + col, + i16, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::Int64 => { - create_hashes_dictionary::( + DataType::Int32 => { + hash_array_primitive!( + Int32Array, col, + i32, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Int64 => { + hash_array_primitive!( + Int64Array, + col, + i64, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::UInt8 => { - create_hashes_dictionary::( + DataType::Float32 => { + hash_array_float!( + Float32Array, col, + u32, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Float64 => { + hash_array_float!( + Float64Array, + col, + u64, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::UInt16 => { - create_hashes_dictionary::( + DataType::Timestamp(TimeUnit::Millisecond, None) => { + hash_array_primitive!( + Int64Array, col, + i64, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + hash_array_primitive!( + Int64Array, + col, + i64, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::UInt32 => { - create_hashes_dictionary::( + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + hash_array_primitive!( + Int64Array, col, + i64, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Date32 => { + hash_array_primitive!( + Int32Array, + col, + i32, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); } - IntegerType::UInt64 => { - create_hashes_dictionary::( + DataType::Date64 => { + hash_array_primitive!( + Int64Array, col, + i64, + hashes_buffer, random_state, + multi_col + ); + } + DataType::Boolean => { + hash_array!( + BooleanArray, + col, + u8, hashes_buffer, - multi_col, - )?; + random_state, + multi_col + ); + } + DataType::Utf8 => { + hash_array!( + StringArray, + col, + str, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::LargeUtf8 => { + hash_array!( + LargeStringArray, + col, + str, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Dictionary(index_type, _, _) => match index_type { + IntegerType::Int8 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::Int16 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + 
)?; + } + IntegerType::Int32 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::Int64 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::UInt8 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::UInt16 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::UInt32 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + IntegerType::UInt64 => { + create_hashes_dictionary::( + col, + random_state, + hashes_buffer, + multi_col, + )?; + } + }, + _ => { + // This is internal because we should have caught this before. + return Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {:?}", + col.data_type() + ))); } - }, - _ => { - // This is internal because we should have caught this before. - return Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {}", - col.data_type() - ))); } } + Ok(hashes_buffer) + } +} + +/// Test version of `create_hashes` that produces the same value for +/// all hashes (to test collisions) +/// +/// See comments on `hashes_buffer` for more details +#[cfg(feature = "force_hash_collisions")] +pub fn create_hashes<'a>( + _arrays: &[ArrayRef], + _random_state: &RandomState, + hashes_buffer: &'a mut Vec, +) -> Result<&'a mut Vec> { + for hash in hashes_buffer.iter_mut() { + *hash = 0 } Ok(hashes_buffer) } +#[cfg(not(feature = "force_hash_collisions"))] +pub use noforce_hash_collisions::create_hashes; + #[cfg(test)] mod tests { + use crate::error::Result; use std::sync::Arc; - use arrow::array::TryExtend; - use arrow::array::{MutableDictionaryArray, MutableUtf8Array}; + use arrow::array::{Float32Array, Float64Array}; + #[cfg(not(feature = "force_hash_collisions"))] + use arrow::array::{MutableDictionaryArray, MutableUtf8Array, TryExtend, Utf8Array}; use super::*; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 86490b786b06..817f4caa33dc 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -535,7 +535,7 @@ impl DefaultPhysicalPlanner { let contains_dict = groups .iter() .flat_map(|x| x.0.data_type(physical_input_schema.as_ref())) - .any(|x| matches!(x, DataType::Dictionary(_, _))); + .any(|x| matches!(x, DataType::Dictionary(_, _, _))); let can_repartition = !groups.is_empty() && ctx_state.config.target_partitions > 1 @@ -632,6 +632,7 @@ impl DefaultPhysicalPlanner { let physical_input = self.create_initial_plan(input, ctx_state).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); + let runtime_expr = self.create_physical_expr( predicate, input_dfschema, @@ -1624,7 +1625,7 @@ mod tests { Err(e) => assert!( e.to_string().contains(expected_error), "Error '{}' did not contain expected error '{}'", - e.to_string(), + e, expected_error ), } @@ -1671,7 +1672,7 @@ mod tests { Err(e) => assert!( e.to_string().contains(expected_error), "Error '{}' did not contain expected error '{}'", - e.to_string(), + e, expected_error ), } @@ -1730,7 +1731,7 @@ mod tests { Err(e) => assert!( e.to_string().contains(expected_error), "Error '{}' did not contain expected error '{}'", - e.to_string(), + e, expected_error ), } diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs 
index 5aa0c040dd3d..7b78a442e6c6 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -30,7 +30,7 @@ use crate::physical_plan::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::{Field, Metadata, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -64,13 +64,17 @@ impl ProjectionExec { let fields: Result> = expr .iter() - .map(|(e, name)| match input_schema.field_with_name(name) { - Ok(f) => Ok(f.clone()), - Err(_) => { - let dt = e.data_type(&input_schema)?; - let nullable = e.nullable(&input_schema)?; - Ok(Field::new(name, dt, nullable)) + .map(|(e, name)| { + let mut field = Field::new( + name, + e.data_type(&input_schema)?, + e.nullable(&input_schema)?, + ); + if let Some(metadata) = get_field_metadata(e, &input_schema) { + field = field.with_metadata(metadata); } + + Ok(field) }) .collect(); @@ -177,6 +181,24 @@ impl ExecutionPlan for ProjectionExec { } } +/// If e is a direct column reference, returns the field level +/// metadata for that field, if any. Otherwise returns None +fn get_field_metadata( + e: &Arc, + input_schema: &Schema, +) -> Option { + let name = if let Some(column) = e.as_any().downcast_ref::() { + column.name() + } else { + return None; + }; + + input_schema + .field_with_name(name) + .ok() + .map(|f| f.metadata().clone()) +} + fn stats_projection( stats: Statistics, exprs: impl Iterator>, @@ -298,7 +320,7 @@ mod tests { )?; let col_field = projection.schema.field(0); - let col_metadata = col_field.metadata().clone().unwrap().clone(); + let col_metadata = col_field.metadata().clone(); let data: &str = &col_metadata["testing"]; assert_eq!(data, "test"); diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 2137a8b0780a..5bd2f82f07ce 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -351,6 +351,9 @@ impl RepartitionExec { for (num_output_partition, partition_indices) in indices.into_iter().enumerate() { + if partition_indices.is_empty() { + continue; + } let timer = r_metrics.repart_time.timer(); let indices = UInt64Array::from_slice(&partition_indices); // Produce batches based on indices @@ -580,7 +583,10 @@ mod tests { ) .await?; - let total_rows: usize = output_partitions.iter().map(|x| x.len()).sum(); + let total_rows: usize = output_partitions + .iter() + .map(|x| x.iter().map(|x| x.num_rows()).sum::()) + .sum(); assert_eq!(8, output_partitions.len()); assert_eq!(total_rows, 8 * 50 * 3); @@ -955,4 +961,32 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn hash_repartition_avoid_empty_batch() -> Result<()> { + let batch = RecordBatch::try_from_iter(vec![( + "a", + Arc::new(StringArray::from_slice(vec!["foo"])) as ArrayRef, + )]) + .unwrap(); + let partitioning = Partitioning::Hash( + vec![Arc::new(crate::physical_plan::expressions::Column::new( + "a", 0, + ))], + 2, + ); + let schema = batch.schema().clone(); + let input = MockExec::new(vec![Ok(batch)], schema.clone()); + let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); + let output_stream0 = exec.execute(0).await.unwrap(); + let batch0 = crate::physical_plan::common::collect(output_stream0) + .await + .unwrap(); + let output_stream1 = exec.execute(1).await.unwrap(); + let batch1 = crate::physical_plan::common::collect(output_stream1) + .await + .unwrap(); + 
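// With a single input row hashed into two target partitions, the new
// `partition_indices.is_empty()` guard added above means the partition whose
// index list is empty never produces a RecordBatch at all (instead of an
// empty one), so at least one of the two collected outputs should contain no
// batches here.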
assert!(batch0.is_empty() || batch1.is_empty()); + Ok(()) + } } diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index bf521bb7c1fc..3700380fdb72 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -302,6 +302,8 @@ impl RecordBatchStream for SortStream { #[cfg(test)] mod tests { + use std::collections::{BTreeMap, HashMap}; + use super::*; use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -385,6 +387,54 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_sort_metadata() -> Result<()> { + let field_metadata: BTreeMap = + vec![("foo".to_string(), "bar".to_string())] + .into_iter() + .collect(); + let schema_metadata: HashMap = + vec![("baz".to_string(), "barf".to_string())] + .into_iter() + .collect(); + + let mut field = Field::new("field_name", DataType::UInt64, true); + field = field.with_metadata(field_metadata.clone()); + let schema = Schema::new_from(vec![field], schema_metadata.clone()); + let schema = Arc::new(schema); + + let data: ArrayRef = + Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); + + let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); + let input = + Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + + let sort_exec = Arc::new(SortExec::try_new( + vec![PhysicalSortExpr { + expr: col("field_name", &schema)?, + options: SortOptions::default(), + }], + input, + )?); + + let result: Vec = collect(sort_exec).await?; + + let expected_data: ArrayRef = + Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); + let expected_batch = + RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap(); + + // Data is correct + assert_eq!(&vec![expected_batch], &result); + + // explicitlty ensure the metadata is present + assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata); + assert_eq!(result[0].schema().metadata(), &schema_metadata); + + Ok(()) + } + #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index 8e47ed60ea2b..abc75829ea17 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -32,8 +32,8 @@ pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, - lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_replace, - repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, - split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, Column, - JoinType, Partitioning, + lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, + regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, + sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, + Column, JoinType, Partitioning, }; diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs index da05d63d8c2c..d06e37f9e770 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion/src/pyarrow.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. 
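// The rewrite of this module below drops the PyArrowConvert-based conversion
// and instead moves values across the Arrow C Data Interface: pyarrow's
// private `_export_to_c` writes the array into empty Ffi_ArrowArray /
// Ffi_ArrowSchema structs allocated on the Rust side, which are then imported
// with `arrow::ffi::import_field_from_c` and `import_array_from_c` (both
// unsafe, since a misbehaving exporter can write out of bounds). The reverse
// direction (`IntoPy` for ScalarValue) is left unimplemented for now.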
+use arrow::array::Array; +use arrow::error::ArrowError; use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; use pyo3::types::PyList; use pyo3::PyNativeType; +use std::sync::Arc; -use crate::arrow::array::ArrayData; -use crate::arrow::pyarrow::PyArrowConvert; use crate::error::DataFusionError; use crate::scalar::ScalarValue; @@ -31,8 +33,46 @@ impl From for PyErr { } } -impl PyArrowConvert for ScalarValue { - fn from_pyarrow(value: &PyAny) -> PyResult { +impl From for PyErr { + fn from(err: PyO3ArrowError) -> PyErr { + PyException::new_err(format!("{:?}", err)) + } +} + +#[derive(Debug)] +enum PyO3ArrowError { + ArrowError(ArrowError), +} + +fn to_rust_array(ob: PyObject, py: Python) -> PyResult> { + // prepare a pointer to receive the Array struct + let array = Box::new(arrow::ffi::Ffi_ArrowArray::empty()); + let schema = Box::new(arrow::ffi::Ffi_ArrowSchema::empty()); + + let array_ptr = &*array as *const arrow::ffi::Ffi_ArrowArray; + let schema_ptr = &*schema as *const arrow::ffi::Ffi_ArrowSchema; + + // make the conversion through PyArrow's private API + // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds + ob.call_method1( + py, + "_export_to_c", + (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), + )?; + + let field = unsafe { + arrow::ffi::import_field_from_c(schema.as_ref()) + .map_err(PyO3ArrowError::ArrowError)? + }; + let array = unsafe { + arrow::ffi::import_array_from_c(array, &field) + .map_err(PyO3ArrowError::ArrowError)? + }; + + Ok(array.into()) +} +impl<'source> FromPyObject<'source> for ScalarValue { + fn extract(value: &'source PyAny) -> PyResult { let py = value.py(); let typ = value.getattr("type")?; let val = value.call_method0("as_py")?; @@ -42,26 +82,16 @@ impl PyArrowConvert for ScalarValue { let args = PyList::new(py, &[val]); let array = factory.call1((args, typ))?; - // convert the pyarrow array to rust array using C data interface - let array = array.extract::()?; - let scalar = ScalarValue::try_from_array(&array.into(), 0)?; + // convert the pyarrow array to rust array using C data interface] + let array = to_rust_array(array.to_object(py), py)?; + let scalar = ScalarValue::try_from_array(&array, 0)?; Ok(scalar) } - - fn to_pyarrow(&self, _py: Python) -> PyResult { - Err(PyNotImplementedError::new_err("Not implemented")) - } -} - -impl<'source> FromPyObject<'source> for ScalarValue { - fn extract(value: &'source PyAny) -> PyResult { - Self::from_pyarrow(value) - } } impl<'a> IntoPy for ScalarValue { - fn into_py(self, py: Python) -> PyObject { - self.to_pyarrow(py).unwrap() + fn into_py(self, _py: Python) -> PyObject { + Err(PyNotImplementedError::new_err("Not implemented")).unwrap() } } diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 7bcd41bc6868..ea447a746cc7 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -27,7 +27,6 @@ use arrow::compute::concatenate; use arrow::datatypes::DataType::Decimal; use arrow::{ array::*, - buffer::MutableBuffer, datatypes::{DataType, Field, IntegerType, IntervalUnit, TimeUnit}, scalar::{PrimitiveScalar, Scalar}, types::{days_ms, NativeType}, @@ -44,6 +43,11 @@ type LargeBinaryArray = BinaryArray; type MutableStringArray = MutableUtf8Array; type MutableLargeStringArray = MutableUtf8Array; +// TODO may need to be moved to arrow-rs +/// The max precision and scale for decimal128 +pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; +pub(crate) const 
MAX_SCALE_FOR_DECIMAL128: usize = 38; + /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] @@ -88,13 +92,13 @@ pub enum ScalarValue { /// Date stored as a signed 64bit int Date64(Option), /// Timestamp Second - TimestampSecond(Option), + TimestampSecond(Option, Option), /// Timestamp Milliseconds - TimestampMillisecond(Option), + TimestampMillisecond(Option, Option), /// Timestamp Microseconds - TimestampMicrosecond(Option), + TimestampMicrosecond(Option, Option), /// Timestamp Nanoseconds - TimestampNanosecond(Option), + TimestampNanosecond(Option, Option), /// Interval with YearMonth unit IntervalYearMonth(Option), /// Interval with DayTime unit @@ -161,14 +165,14 @@ impl PartialEq for ScalarValue { (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), (Date64(_), _) => false, - (TimestampSecond(v1), TimestampSecond(v2)) => v1.eq(v2), - (TimestampSecond(_), _) => false, - (TimestampMillisecond(v1), TimestampMillisecond(v2)) => v1.eq(v2), - (TimestampMillisecond(_), _) => false, - (TimestampMicrosecond(v1), TimestampMicrosecond(v2)) => v1.eq(v2), - (TimestampMicrosecond(_), _) => false, - (TimestampNanosecond(v1), TimestampNanosecond(v2)) => v1.eq(v2), - (TimestampNanosecond(_), _) => false, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), + (TimestampSecond(_, _), _) => false, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), + (TimestampMillisecond(_, _), _) => false, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), + (TimestampMicrosecond(_, _), _) => false, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), + (TimestampNanosecond(_, _), _) => false, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), (IntervalYearMonth(_), _) => false, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), @@ -247,15 +251,21 @@ impl PartialOrd for ScalarValue { (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), (Date64(_), _) => None, - (TimestampSecond(v1), TimestampSecond(v2)) => v1.partial_cmp(v2), - (TimestampSecond(_), _) => None, - (TimestampMillisecond(v1), TimestampMillisecond(v2)) => v1.partial_cmp(v2), - (TimestampMillisecond(_), _) => None, - (TimestampMicrosecond(v1), TimestampMicrosecond(v2)) => v1.partial_cmp(v2), - (TimestampMicrosecond(_), _) => None, - (TimestampNanosecond(v1), TimestampNanosecond(v2)) => v1.partial_cmp(v2), - (TimestampNanosecond(_), _) => None, - (_, IntervalYearMonth(_)) => None, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), + (TimestampSecond(_, _), _) => None, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMillisecond(_, _), _) => None, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMicrosecond(_, _), _) => None, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampNanosecond(_, _), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), (IntervalYearMonth(_), _) => None, (_, IntervalDayTime(_)) => None, (IntervalDayTime(_), _) => None, @@ -311,10 +321,10 @@ impl std::hash::Hash for ScalarValue { } Date32(v) => v.hash(state), Date64(v) => v.hash(state), - TimestampSecond(v) => v.hash(state), - TimestampMillisecond(v) => v.hash(state), - TimestampMicrosecond(v) => v.hash(state), - TimestampNanosecond(v) => v.hash(state), + 
TimestampSecond(v, _) => v.hash(state), + TimestampMillisecond(v, _) => v.hash(state), + TimestampMicrosecond(v, _) => v.hash(state), + TimestampNanosecond(v, _) => v.hash(state), IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), Struct(v, t) => { @@ -350,6 +360,19 @@ fn get_dict_value( Ok((dict_array.values(), Some(values_index))) } +macro_rules! typed_cast_tz { + ($array:expr, $index:expr, $SCALAR:ident, $TZ:expr) => {{ + let array = $array.as_any().downcast_ref::().unwrap(); + ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + $TZ.clone(), + ) + }}; +} + macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); @@ -381,8 +404,8 @@ macro_rules! build_list { } macro_rules! build_timestamp_list { - ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ - let child_dt = DataType::Timestamp($TIME_UNIT, $TIME_ZONE); + ($TIME_UNIT:expr, $VALUES:expr, $SIZE:expr, $TZ:expr) => {{ + let child_dt = DataType::Timestamp($TIME_UNIT, $TZ.clone()); match $VALUES { // the return on the macro is necessary, to short-circuit and return ArrayRef None => { @@ -404,16 +427,16 @@ macro_rules! build_timestamp_list { match $TIME_UNIT { TimeUnit::Second => { - build_values_list!(array, TimestampSecond, values, $SIZE) + build_values_list_tz!(array, TimestampSecond, values, $SIZE) } TimeUnit::Microsecond => { - build_values_list!(array, TimestampMillisecond, values, $SIZE) + build_values_list_tz!(array, TimestampMillisecond, values, $SIZE) } TimeUnit::Millisecond => { - build_values_list!(array, TimestampMicrosecond, values, $SIZE) + build_values_list_tz!(array, TimestampMicrosecond, values, $SIZE) } TimeUnit::Nanosecond => { - build_values_list!(array, TimestampNanosecond, values, $SIZE) + build_values_list_tz!(array, TimestampNanosecond, values, $SIZE) } } } @@ -445,13 +468,32 @@ macro_rules! dyn_to_array { ($self:expr, $value:expr, $size:expr, $ty:ty) => {{ Arc::new(PrimitiveArray::<$ty>::from_data( $self.get_datatype(), - MutableBuffer::<$ty>::from_trusted_len_iter(repeat(*$value).take($size)) - .into(), + Buffer::<$ty>::from_iter(repeat(*$value).take($size)), None, )) }}; } +macro_rules! build_values_list_tz { + ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + for _ in 0..$SIZE { + let mut vec = vec![]; + for scalar_value in $VALUES { + match scalar_value { + ScalarValue::$SCALAR_TY(v, _) => { + vec.push(v.clone()); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + $MUTABLE_ARR.try_push(Some(vec)).unwrap(); + } + + let array: ListArray = $MUTABLE_ARR.into(); + Arc::new(array) + }}; +} + macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); @@ -464,20 +506,315 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + matches!( + self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Addition only supports numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Addition does not support empty values".to_string(), + )); + } + + // TODO: Finding a good way to support operation between different types without + // writing a hige match block. + // TODO: Add support for decimal types + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + "Addition with Decimals are not supported for now".to_string() + )) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() + f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() + f2.unwrap()))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 + f2.unwrap() as i64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 + f2.unwrap() as i32))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 + f2.unwrap() as i16))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) 
=> { + Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 + f2.unwrap() as u32))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 + f2.unwrap() as u16))) + }, + _ => Err(DataFusionError::Internal( + format!( + "Addition only support calculation with the same type or f64 as one of the numbers for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + + /// Multiply two numeric ScalarValues + pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Multiplication is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Multiplication does not support empty values".to_string(), + )); + } + + // TODO: Finding a good way to support operation between different types without + // writing a hige match block. + // TODO: Add support for decimal type + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) + | (_, ScalarValue::Decimal128(_, _, _)) => Err(DataFusionError::Internal( + "Multiplication with Decimals are not supported for now".to_string(), + )), + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) + } + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => Ok( + ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64)), + ), + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() * f2.unwrap()))) + } + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => Ok(ScalarValue::Int64( + Some(f1.unwrap() as i64 * f2.unwrap() as i64), + )), + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => Ok(ScalarValue::Int32( + Some(f1.unwrap() as i32 * f2.unwrap() as i32), + )), + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => Ok(ScalarValue::Int16( + Some(f1.unwrap() as i16 * f2.unwrap() as i16), + )), + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => Ok( + ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32)), + ), + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => Ok(ScalarValue::UInt16( + Some(f1.unwrap() as u16 * f2.unwrap() as u16), + )), + _ => Err(DataFusionError::Internal(format!( + "Multiplication only support f64 for now, here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))), + } + } + + /// Division between two numeric ScalarValues + pub fn div(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Division is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Division does not support empty values".to_string(), + )); + } + + // TODO: 
+ // TODO: Find a good way to support operations between different types without + // writing a huge match block. + // TODO: Add support for decimal types + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + "Division with Decimals is not supported for now".to_string() + )) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() / f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + _ => Err(DataFusionError::Internal( + format!( + "Division only supports calculation with the same type or f64 as denominator for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + /// Create null scalar value for specific data type.
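/// For example, `ScalarValue::new_null(DataType::Timestamp(TimeUnit::Second, None))` yields `ScalarValue::TimestampSecond(None, None)`; data types without a match arm below still fall through to the `todo!`.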
pub fn new_null(dt: DataType) -> Self { match dt { DataType::Timestamp(TimeUnit::Second, _) => { - ScalarValue::TimestampSecond(None) + ScalarValue::TimestampSecond(None, None) } DataType::Timestamp(TimeUnit::Millisecond, _) => { - ScalarValue::TimestampMillisecond(None) + ScalarValue::TimestampMillisecond(None, None) } DataType::Timestamp(TimeUnit::Microsecond, _) => { - ScalarValue::TimestampMicrosecond(None) + ScalarValue::TimestampMicrosecond(None, None) } DataType::Timestamp(TimeUnit::Nanosecond, _) => { - ScalarValue::TimestampNanosecond(None) + ScalarValue::TimestampNanosecond(None, None) } _ => todo!("Create null scalar value for datatype: {:?}", dt), } @@ -490,8 +827,7 @@ impl ScalarValue { scale: usize, ) -> Result { // make sure the precision and scale is valid - // TODO const the max precision and min scale - if precision <= 38 && scale <= precision { + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); } return Err(DataFusionError::Internal(format!( @@ -515,17 +851,17 @@ impl ScalarValue { ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal(*precision, *scale) } - ScalarValue::TimestampSecond(_) => { - DataType::Timestamp(TimeUnit::Second, None) + ScalarValue::TimestampSecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) } - ScalarValue::TimestampMillisecond(_) => { - DataType::Timestamp(TimeUnit::Millisecond, None) + ScalarValue::TimestampMillisecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) } - ScalarValue::TimestampMicrosecond(_) => { - DataType::Timestamp(TimeUnit::Microsecond, None) + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) } - ScalarValue::TimestampNanosecond(_) => { - DataType::Timestamp(TimeUnit::Nanosecond, None) + ScalarValue::TimestampNanosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) } ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, @@ -590,9 +926,10 @@ impl ScalarValue { | ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::List(None, _) - | ScalarValue::TimestampMillisecond(None) - | ScalarValue::TimestampMicrosecond(None) - | ScalarValue::TimestampNanosecond(None) + | ScalarValue::TimestampSecond(None, _) + | ScalarValue::TimestampMillisecond(None, _) + | ScalarValue::TimestampMicrosecond(None, _) + | ScalarValue::TimestampNanosecond(None, _) | ScalarValue::Struct(None, _) | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. ) @@ -665,13 +1002,34 @@ impl ScalarValue { data_type, sv ))) } - }) - .collect::>>()?.to($DT) + }).collect::>>()?.to($DT) ) as Box } }}; } + macro_rules! build_array_primitive_tz { + ($SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + + Box::new(array) + } + }}; + } + /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for "string-like" types. macro_rules! 
build_array_string { @@ -775,17 +1133,17 @@ impl ScalarValue { LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), Date32 => build_array_primitive!(i32, Date32, Date32), Date64 => build_array_primitive!(i64, Date64, Date64), - Timestamp(TimeUnit::Second, None) => { - build_array_primitive!(i64, TimestampSecond, data_type) + Timestamp(TimeUnit::Second, _) => { + build_array_primitive_tz!(TimestampSecond) } - Timestamp(TimeUnit::Millisecond, None) => { - build_array_primitive!(i64, TimestampMillisecond, data_type) + Timestamp(TimeUnit::Millisecond, _) => { + build_array_primitive_tz!(TimestampMillisecond) } - Timestamp(TimeUnit::Microsecond, None) => { - build_array_primitive!(i64, TimestampMicrosecond, data_type) + Timestamp(TimeUnit::Microsecond, _) => { + build_array_primitive_tz!(TimestampMicrosecond) } - Timestamp(TimeUnit::Nanosecond, None) => { - build_array_primitive!(i64, TimestampNanosecond, data_type) + Timestamp(TimeUnit::Nanosecond, _) => { + build_array_primitive_tz!(TimestampNanosecond) } Interval(IntervalUnit::DayTime) => { build_array_primitive!(days_ms, IntervalDayTime, data_type) @@ -978,7 +1336,9 @@ impl ScalarValue { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef } ScalarValue::Float64(e) => match e { - Some(value) => dyn_to_array!(self, value, size, f64), + Some(value) => { + dyn_to_array!(self, value, size, f64) + } None => new_null_array(self.get_datatype(), size).into(), }, ScalarValue::Float32(e) => match e { @@ -999,12 +1359,7 @@ impl ScalarValue { Some(value) => dyn_to_array!(self, value, size, i32), None => new_null_array(self.get_datatype(), size).into(), }, - ScalarValue::Int64(e) - | ScalarValue::Date64(e) - | ScalarValue::TimestampSecond(e) - | ScalarValue::TimestampMillisecond(e) - | ScalarValue::TimestampMicrosecond(e) - | ScalarValue::TimestampNanosecond(e) => match e { + ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e { Some(value) => dyn_to_array!(self, value, size, i64), None => new_null_array(self.get_datatype(), size).into(), }, @@ -1024,6 +1379,23 @@ impl ScalarValue { Some(value) => dyn_to_array!(self, value, size, u64), None => new_null_array(self.get_datatype(), size).into(), }, + ScalarValue::TimestampSecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampMillisecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + + ScalarValue::TimestampMicrosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, + ScalarValue::TimestampNanosecond(e, _) => match e { + Some(value) => dyn_to_array!(self, value, size, i64), + None => new_null_array(self.get_datatype(), size).into(), + }, ScalarValue::Utf8(e) => match e { Some(value) => Arc::new(Utf8Array::::from_trusted_len_values_iter( repeat(&value).take(size), @@ -1067,7 +1439,7 @@ impl ScalarValue { DataType::Float32 => build_list!(Float32Vec, Float32, values, size), DataType::Float64 => build_list!(Float64Vec, Float64, values, size), DataType::Timestamp(unit, tz) => { - build_timestamp_list!(*unit, tz.clone(), values, size) + build_timestamp_list!(*unit, values, size, tz.clone()) } DataType::Utf8 => build_list!(MutableStringArray, Utf8, values, size), DataType::LargeUtf8 => { @@ -1169,19 +1541,19 @@ impl ScalarValue { DataType::Date64 => { typed_cast!(array, index, Int64Array, 
Date64) } - DataType::Timestamp(TimeUnit::Second, _) => { - typed_cast!(array, index, Int64Array, TimestampSecond) + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_cast_tz!(array, index, TimestampSecond, tz_opt) } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - typed_cast!(array, index, Int64Array, TimestampMillisecond) + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampMillisecond, tz_opt) } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - typed_cast!(array, index, Int64Array, TimestampMicrosecond) + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampMicrosecond, tz_opt) } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - typed_cast!(array, index, Int64Array, TimestampNanosecond) + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + typed_cast_tz!(array, index, TimestampNanosecond, tz_opt) } - DataType::Dictionary(index_type, _) => { + DataType::Dictionary(index_type, _, _) => { let (values, values_index) = match index_type { IntegerType::Int8 => get_dict_value::(array, index)?, IntegerType::Int16 => get_dict_value::(array, index)?, @@ -1266,7 +1638,7 @@ impl ScalarValue { /// comparisons where comparing a single row at a time is necessary. #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - if let DataType::Dictionary(key_type, _) = array.data_type() { + if let DataType::Dictionary(key_type, _, _) = array.data_type() { return self.eq_array_dictionary(array, index, key_type); } @@ -1314,16 +1686,16 @@ impl ScalarValue { ScalarValue::Date64(val) => { eq_array_primitive!(array, index, Int64Array, val) } - ScalarValue::TimestampSecond(val) => { + ScalarValue::TimestampSecond(val, _) => { eq_array_primitive!(array, index, Int64Array, val) } - ScalarValue::TimestampMillisecond(val) => { + ScalarValue::TimestampMillisecond(val, _) => { eq_array_primitive!(array, index, Int64Array, val) } - ScalarValue::TimestampMicrosecond(val) => { + ScalarValue::TimestampMicrosecond(val, _) => { eq_array_primitive!(array, index, Int64Array, val) } - ScalarValue::TimestampNanosecond(val) => { + ScalarValue::TimestampNanosecond(val, _) => { eq_array_primitive!(array, index, Int64Array, val) } ScalarValue::IntervalYearMonth(val) => { @@ -1471,10 +1843,10 @@ impl TryFrom for i64 { match value { ScalarValue::Int64(Some(inner_value)) | ScalarValue::Date64(Some(inner_value)) - | ScalarValue::TimestampNanosecond(Some(inner_value)) - | ScalarValue::TimestampMicrosecond(Some(inner_value)) - | ScalarValue::TimestampMillisecond(Some(inner_value)) - | ScalarValue::TimestampSecond(Some(inner_value)) => Ok(inner_value), + | ScalarValue::TimestampNanosecond(Some(inner_value), _) + | ScalarValue::TimestampMicrosecond(Some(inner_value), _) + | ScalarValue::TimestampMillisecond(Some(inner_value), _) + | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value), _ => Err(DataFusionError::Internal(format!( "Cannot convert {:?} to {}", value, @@ -1541,25 +1913,27 @@ impl TryInto> for &ScalarValue { ScalarValue::Date64(i) => { Ok(Box::new(PrimitiveScalar::::new(DataType::Date64, *i))) } - ScalarValue::TimestampSecond(i) => Ok(Box::new(PrimitiveScalar::::new( - DataType::Timestamp(TimeUnit::Second, None), - *i, - ))), - ScalarValue::TimestampMillisecond(i) => { + ScalarValue::TimestampSecond(i, tz) => { + Ok(Box::new(PrimitiveScalar::::new( + DataType::Timestamp(TimeUnit::Second, tz.clone()), + *i, + ))) + } + ScalarValue::TimestampMillisecond(i, tz) => { 
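+ // Carry the scalar's timezone into the returned arrow DataType (it was previously hard-coded to None).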
Ok(Box::new(PrimitiveScalar::::new( - DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), *i, ))) } - ScalarValue::TimestampMicrosecond(i) => { + ScalarValue::TimestampMicrosecond(i, tz) => { Ok(Box::new(PrimitiveScalar::::new( - DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), *i, ))) } - ScalarValue::TimestampNanosecond(i) => { + ScalarValue::TimestampNanosecond(i, tz) => { Ok(Box::new(PrimitiveScalar::::new( - DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), *i, ))) } @@ -1583,21 +1957,21 @@ impl TryFrom> for ScalarValue { fn try_from(s: PrimitiveScalar) -> Result { match s.data_type() { - DataType::Timestamp(TimeUnit::Second, _) => { + DataType::Timestamp(TimeUnit::Second, tz) => { let s = s.as_any().downcast_ref::>().unwrap(); - Ok(ScalarValue::TimestampSecond(Some(s.value()))) + Ok(ScalarValue::TimestampSecond(s.value(), tz.clone())) } - DataType::Timestamp(TimeUnit::Microsecond, _) => { + DataType::Timestamp(TimeUnit::Microsecond, tz) => { let s = s.as_any().downcast_ref::>().unwrap(); - Ok(ScalarValue::TimestampMicrosecond(Some(s.value()))) + Ok(ScalarValue::TimestampMicrosecond(s.value(), tz.clone())) } - DataType::Timestamp(TimeUnit::Millisecond, _) => { + DataType::Timestamp(TimeUnit::Millisecond, tz) => { let s = s.as_any().downcast_ref::>().unwrap(); - Ok(ScalarValue::TimestampMillisecond(Some(s.value()))) + Ok(ScalarValue::TimestampMillisecond(s.value(), tz.clone())) } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { let s = s.as_any().downcast_ref::>().unwrap(); - Ok(ScalarValue::TimestampNanosecond(Some(s.value()))) + Ok(ScalarValue::TimestampNanosecond(s.value(), tz.clone())) } _ => Err(DataFusionError::Internal( format!( @@ -1631,19 +2005,19 @@ impl TryFrom<&DataType> for ScalarValue { DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), DataType::Date32 => ScalarValue::Date32(None), DataType::Date64 => ScalarValue::Date64(None), - DataType::Timestamp(TimeUnit::Second, _) => { - ScalarValue::TimestampSecond(None) + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - ScalarValue::TimestampMillisecond(None) + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - ScalarValue::TimestampMicrosecond(None) + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - ScalarValue::TimestampNanosecond(None) + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond(None, tz_opt.clone()) } - DataType::Dictionary(_index_type, value_type) => { + DataType::Dictionary(_index_type, value_type, _) => { value_type.as_ref().try_into()? 
} DataType::List(ref nested_type) => { @@ -1675,7 +2049,7 @@ impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(v, p, s) => { - write!(f, "{}", format!("{:?},{:?},{:?}", v, p, s))?; + write!(f, "{}", format_args!("{:?},{:?},{:?}", v, p, s))?; } ScalarValue::Boolean(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, @@ -1688,10 +2062,10 @@ impl fmt::Display for ScalarValue { ScalarValue::UInt16(e) => format_option!(f, e)?, ScalarValue::UInt32(e) => format_option!(f, e)?, ScalarValue::UInt64(e) => format_option!(f, e)?, - ScalarValue::TimestampSecond(e) => format_option!(f, e)?, - ScalarValue::TimestampMillisecond(e) => format_option!(f, e)?, - ScalarValue::TimestampMicrosecond(e) => format_option!(f, e)?, - ScalarValue::TimestampNanosecond(e) => format_option!(f, e)?, + ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::LargeUtf8(e) => format_option!(f, e)?, ScalarValue::Binary(e) => match e { @@ -1763,15 +2137,17 @@ impl fmt::Debug for ScalarValue { ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), - ScalarValue::TimestampSecond(_) => write!(f, "TimestampSecond({})", self), - ScalarValue::TimestampMillisecond(_) => { - write!(f, "TimestampMillisecond({})", self) + ScalarValue::TimestampSecond(_, tz_opt) => { + write!(f, "TimestampSecond({}, {:?})", self, tz_opt) } - ScalarValue::TimestampMicrosecond(_) => { - write!(f, "TimestampMicrosecond({})", self) + ScalarValue::TimestampMillisecond(_, tz_opt) => { + write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt) } - ScalarValue::TimestampNanosecond(_) => { - write!(f, "TimestampNanosecond({})", self) + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt) } ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), @@ -1781,7 +2157,7 @@ impl fmt::Debug for ScalarValue { ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), - ScalarValue::List(_, dt) => write!(f, "List[{}]([{}])", dt, self), + ScalarValue::List(_, dt) => write!(f, "List[{:?}]([{}])", dt, self), ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), ScalarValue::IntervalDayTime(_) => { @@ -1812,6 +2188,7 @@ impl fmt::Debug for ScalarValue { #[cfg(test)] mod tests { use super::*; + use crate::field_util::struct_array_from; #[test] fn scalar_decimal_test() { @@ -1994,7 +2371,24 @@ mod tests { let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); - let expected: Box = Box::new($ARRAYTYPE::from($INPUT)); + let expected = $ARRAYTYPE::from($INPUT).as_box(); + + assert_eq!(&array, &expected); + }}; + } + + /// Creates array directly and via ScalarValue and ensures they are the same + /// but for variants that carry a timezone 
field. + macro_rules! check_scalar_iter_tz { + ($SCALAR_T:ident, $INPUT:expr) => {{ + let scalars: Vec<_> = $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_T(*v, None)) + .collect(); + + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + + let expected: Box = Box::new(Int64Array::from($INPUT)); assert_eq!(&array, &expected); }}; @@ -2039,19 +2433,28 @@ mod tests { #[test] fn scalar_iter_to_array_boolean() { - check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); - check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); - check_scalar_iter!(Float64, Float64Array, vec![Some(1.9), None, Some(-2.1)]); + check_scalar_iter!( + Boolean, + MutableBooleanArray, + vec![Some(true), None, Some(false)] + ); + check_scalar_iter!(Float32, Float32Vec, vec![Some(1.9), None, Some(-2.1)]); + check_scalar_iter!(Float64, Float64Vec, vec![Some(1.9), None, Some(-2.1)]); + + check_scalar_iter!(Int8, Int8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int16, Int16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int32, Int32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(Int64, Int64Vec, vec![Some(1), None, Some(3)]); - check_scalar_iter!(Int8, Int8Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(Int16, Int16Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(Int32, Int32Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(Int64, Int64Array, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt8, UInt8Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt16, UInt16Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt32, UInt32Vec, vec![Some(1), None, Some(3)]); + check_scalar_iter!(UInt64, UInt64Vec, vec![Some(1), None, Some(3)]); - check_scalar_iter!(UInt8, UInt8Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(UInt16, UInt16Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(UInt32, UInt32Array, vec![Some(1), None, Some(3)]); - check_scalar_iter!(UInt64, UInt64Array, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampSecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMillisecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampMicrosecond, vec![Some(1), None, Some(3)]); + check_scalar_iter_tz!(TimestampNanosecond, vec![Some(1), None, Some(3)]); check_scalar_iter_string!( Utf8, @@ -2117,7 +2520,8 @@ mod tests { #[test] fn scalar_try_from_dict_datatype() { - let data_type = DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8)); + let data_type = + DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); let data_type = &data_type; assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap()) } @@ -2127,6 +2531,10 @@ mod tests { // Since ScalarValues are used in a non trivial number of places, // making it larger means significant more memory consumption // per distinct value. + #[cfg(target_arch = "aarch64")] + assert_eq!(std::mem::size_of::(), 64); + + #[cfg(target_arch = "amd64")] assert_eq!(std::mem::size_of::(), 48); } @@ -2175,6 +2583,17 @@ mod tests { scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), } }}; + + ($INPUT:expr, $ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ + let tz = $TZ; + TestCase { + array: Arc::new($INPUT.iter().collect::<$ARRAY_TY>()), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, tz.clone())) + .collect(), + } + }}; } macro_rules! make_date_test_case { @@ -2187,13 +2606,16 @@ mod tests { } macro_rules! 
make_ts_test_case { - ($INPUT:expr, $ARRAY_TY:ident, $ARROW_TU:ident, $SCALAR_TY:ident) => {{ + ($INPUT:expr, $ARROW_TU:ident, $SCALAR_TY:ident, $TZ:expr) => {{ TestCase { array: Arc::new( - $ARRAY_TY::from($INPUT) - .to(DataType::Timestamp(TimeUnit::$ARROW_TU, None)), + Int64Array::from($INPUT) + .to(DataType::Timestamp(TimeUnit::$ARROW_TU, $TZ)), ), - scalars: $INPUT.iter().map(|v| ScalarValue::$SCALAR_TY(*v)).collect(), + scalars: $INPUT + .iter() + .map(|v| ScalarValue::$SCALAR_TY(*v, $TZ)) + .collect(), } }}; } @@ -2256,7 +2678,7 @@ mod tests { } }}; } - + let utc_tz = Some("UTC".to_owned()); let cases = vec![ make_test_case!(bool_vals, BooleanArray, Boolean), make_test_case!(f32_vals, Float32Array, Float32), @@ -2275,10 +2697,29 @@ mod tests { make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), make_date_test_case!(&i32_vals, Int32Array, Date32), make_date_test_case!(&i64_vals, Int64Array, Date64), - make_ts_test_case!(&i64_vals, Int64Array, Second, TimestampSecond), - make_ts_test_case!(&i64_vals, Int64Array, Millisecond, TimestampMillisecond), - make_ts_test_case!(&i64_vals, Int64Array, Microsecond, TimestampMicrosecond), - make_ts_test_case!(&i64_vals, Int64Array, Nanosecond, TimestampNanosecond), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, utc_tz.clone()), + make_ts_test_case!( + &i64_vals, + Millisecond, + TimestampMillisecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Microsecond, + TimestampMicrosecond, + utc_tz.clone() + ), + make_ts_test_case!( + &i64_vals, + Nanosecond, + TimestampNanosecond, + utc_tz.clone() + ), + make_ts_test_case!(&i64_vals, Second, TimestampSecond, None), + make_ts_test_case!(&i64_vals, Millisecond, TimestampMillisecond, None), + make_ts_test_case!(&i64_vals, Microsecond, TimestampMicrosecond, None), + make_ts_test_case!(&i64_vals, Nanosecond, TimestampNanosecond, None), make_temporal_test_case!(&i32_vals, Int32Array, YearMonth, IntervalYearMonth), make_temporal_test_case!(days_ms_vals, DaysMsArray, DayTime, IntervalDayTime), make_str_dict_test_case!(str_vals, i8, Utf8), @@ -2423,7 +2864,11 @@ mod tests { let field_e = Field::new("e", DataType::Int16, false); let field_f = Field::new("f", DataType::Int64, false); - let field_d = Field::new("D", DataType::Struct(vec![field_e, field_f]), false); + let field_d = Field::new( + "D", + DataType::Struct(vec![field_e.clone(), field_f.clone()]), + false, + ); let scalar = ScalarValue::Struct( Some(Box::new(vec![ @@ -2435,10 +2880,15 @@ mod tests { ("f", ScalarValue::from(3i64)), ]), ])), - Box::new(vec![field_a, field_b, field_c, field_d.clone()]), + Box::new(vec![ + field_a.clone(), + field_b.clone(), + field_c.clone(), + field_d.clone(), + ]), ); - let dt = scalar.get_datatype(); - let sub_dt = field_d.data_type; + let _dt = scalar.get_datatype(); + let _sub_dt = field_d.data_type.clone(); // Check Display assert_eq!( @@ -2456,25 +2906,30 @@ mod tests { // Convert to length-2 array let array = scalar.to_array_of_size(2); - - let expected = Arc::new(StructArray::from_data( - dt.clone(), - vec![ - Arc::new(Int32Array::from_slice([23, 23])) as ArrayRef, - Arc::new(BooleanArray::from_slice([false, false])) as ArrayRef, - Arc::new(StringArray::from_slice(["Hello", "Hello"])) as ArrayRef, + let expected_vals = vec![ + (field_a.clone(), Int32Vec::from_slice(vec![23, 23]).as_arc()), + ( + field_b.clone(), + Arc::new(BooleanArray::from_slice(&vec![false, false])) as ArrayRef, + ), + ( + field_c.clone(), + Arc::new(StringArray::from_slice(&vec!["Hello", "Hello"])) as ArrayRef, 
+ ), + ( + field_d.clone(), Arc::new(StructArray::from_data( - sub_dt.clone(), + DataType::Struct(vec![field_e.clone(), field_f.clone()]), vec![ - Arc::new(Int16Array::from_slice([2, 2])) as ArrayRef, - Arc::new(Int64Array::from_slice([3, 3])) as ArrayRef, + Int16Vec::from_slice(vec![2, 2]).as_arc(), + Int64Vec::from_slice(vec![3, 3]).as_arc(), ], None, )) as ArrayRef, - ], - None, - )) as ArrayRef; + ), + ]; + let expected = Arc::new(struct_array_from(expected_vals)) as ArrayRef; assert_eq!(&array, &expected); // Construct from second element of ArrayRef @@ -2488,7 +2943,7 @@ mod tests { // Construct with convenience From> let constructed = ScalarValue::from(vec![ - ("A", ScalarValue::from(23)), + ("A", ScalarValue::from(23i32)), ("B", ScalarValue::from(false)), ("C", ScalarValue::from("Hello")), ( @@ -2504,7 +2959,7 @@ mod tests { // Build Array from Vec of structs let scalars = vec![ ScalarValue::from(vec![ - ("A", ScalarValue::from(23)), + ("A", ScalarValue::from(23i32)), ("B", ScalarValue::from(false)), ("C", ScalarValue::from("Hello")), ( @@ -2516,7 +2971,7 @@ mod tests { ), ]), ScalarValue::from(vec![ - ("A", ScalarValue::from(7)), + ("A", ScalarValue::from(7i32)), ("B", ScalarValue::from(true)), ("C", ScalarValue::from("World")), ( @@ -2528,7 +2983,7 @@ mod tests { ), ]), ScalarValue::from(vec![ - ("A", ScalarValue::from(-1000)), + ("A", ScalarValue::from(-1000i32)), ("B", ScalarValue::from(true)), ("C", ScalarValue::from("!!!!!")), ( @@ -2542,29 +2997,34 @@ mod tests { ]; let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap().into(); - let expected = Arc::new(StructArray::from_data( - dt, - vec![ - Arc::new(Int32Array::from_slice(&[23, 7, -1000])) as ArrayRef, - Arc::new(BooleanArray::from_slice(&[false, true, true])) as ArrayRef, - Arc::new(StringArray::from_slice(&["Hello", "World", "!!!!!"])) + let expected = Arc::new(struct_array_from(vec![ + (field_a, Int32Vec::from_slice(vec![23, 7, -1000]).as_arc()), + ( + field_b, + Arc::new(BooleanArray::from_slice(&vec![false, true, true])) as ArrayRef, + ), + ( + field_c, + Arc::new(StringArray::from_slice(&vec!["Hello", "World", "!!!!!"])) as ArrayRef, + ), + ( + field_d, Arc::new(StructArray::from_data( - sub_dt, + DataType::Struct(vec![field_e, field_f]), vec![ - Arc::new(Int16Array::from_slice(&[2, 4, 6])) as ArrayRef, - Arc::new(Int64Array::from_slice(&[3, 5, 7])) as ArrayRef, + Int16Vec::from_slice(vec![2, 4, 6]).as_arc(), + Int64Vec::from_slice(vec![3, 5, 7]).as_arc(), ], None, )) as ArrayRef, - ], - None, - )) as ArrayRef; + ), + ])) as ArrayRef; assert_eq!(&array, &expected); } - /*#[test] + #[test] fn test_lists_in_struct() { let field_a = Field::new("A", DataType::Utf8, false); let field_primitive_list = Field::new( @@ -2617,25 +3077,23 @@ mod tests { ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); let array = array.as_any().downcast_ref::().unwrap(); - let int_data = vec![ - Some(vec![Some(1), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - Some(vec![Some(6)]), - ]; - let mut primitive_expected = - MutableListArray::>::new(); - primitive_expected.try_extend(int_data).unwrap(); - let primitive_expected: ListArray = expected.into(); - - let expected = StructArray::from_data( - s0.get_datatype(), - vec![ - Arc::new(StringArray::from_slice(&["First", "Second", "Third"])) + let mut list_array = + MutableListArray::::new_with_capacity(Int32Vec::new(), 5); + list_array + .try_extend(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6)]), 
+ ]) + .unwrap(); + let expected = struct_array_from(vec![ + ( + field_a.clone(), + Arc::new(StringArray::from_slice(&vec!["First", "Second", "Third"])) as ArrayRef, - primitive_expected, - ], - None, - ); + ), + (field_primitive_list.clone(), list_array.as_arc()), + ]); assert_eq!(array, &expected); @@ -2656,137 +3114,37 @@ mod tests { let array = array.as_any().downcast_ref::>().unwrap(); // Construct expected array with array builders - let field_a_builder = StringBuilder::new(4); - let primitive_value_builder = Int32Array::builder(8); - let field_primitive_list_builder = ListBuilder::new(primitive_value_builder); - - let element_builder = StructBuilder::new( - vec![field_a, field_primitive_list], - vec![ - Box::new(field_a_builder), - Box::new(field_primitive_list_builder), - ], - ); - let mut list_builder = ListBuilder::new(element_builder); - - list_builder - .values() - .field_builder::(0) - .unwrap() - .append_value("First") - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(1) - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(2) - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(3) - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .append(true) + let field_a_builder = + Utf8Array::::from_slice(&vec!["First", "Second", "Third", "Second"]); + let primitive_value_builder = Int32Vec::with_capacity(5); + let mut field_primitive_list_builder = + MutableListArray::::new_with_capacity( + primitive_value_builder, + 0, + ); + field_primitive_list_builder + .try_push(Some(vec![1, 2, 3].into_iter().map(Option::Some))) .unwrap(); - list_builder.values().append(true).unwrap(); - - list_builder - .values() - .field_builder::(0) - .unwrap() - .append_value("Second") - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(4) + field_primitive_list_builder + .try_push(Some(vec![4, 5].into_iter().map(Option::Some))) .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(5) + field_primitive_list_builder + .try_push(Some(vec![6].into_iter().map(Option::Some))) .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .append(true) - .unwrap(); - list_builder.values().append(true).unwrap(); - list_builder.append(true).unwrap(); - - list_builder - .values() - .field_builder::(0) - .unwrap() - .append_value("Third") - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(6) + field_primitive_list_builder + .try_push(Some(vec![4, 5].into_iter().map(Option::Some))) .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .append(true) - .unwrap(); - list_builder.values().append(true).unwrap(); - list_builder.append(true).unwrap(); - - list_builder - .values() - .field_builder::(0) - .unwrap() - .append_value("Second") - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(4) - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .values() - .append_value(5) - .unwrap(); - list_builder - .values() - .field_builder::>>(1) - .unwrap() - .append(true) - .unwrap(); - list_builder.values().append(true).unwrap(); - list_builder.append(true).unwrap(); - - let expected = list_builder.finish(); - - assert_eq!(array, &expected); + let _element_builder = StructArray::from_data( + 
DataType::Struct(vec![field_a, field_primitive_list]), + vec![ + Arc::new(field_a_builder), + field_primitive_list_builder.as_arc(), + ], + None, + ); + //let expected = ListArray::(element_builder, 5); + eprintln!("array = {:?}", array); + //assert_eq!(array, &expected); } #[test] @@ -2851,37 +3209,301 @@ mod tests { ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - let array = array.as_any().downcast_ref::>().unwrap(); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); - let middle_builder = ListBuilder::new(inner_builder); - let mut outer_builder = ListBuilder::new(middle_builder); + let inner_builder = Int32Vec::with_capacity(8); + let middle_builder = + MutableListArray::::new_with_capacity(inner_builder, 0); + let mut outer_builder = + MutableListArray::>::new_with_capacity( + middle_builder, + 0, + ); + outer_builder + .try_push(Some(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ])) + .unwrap(); + outer_builder + .try_push(Some(vec![ + Some(vec![Some(6)]), + Some(vec![Some(7), Some(8)]), + ])) + .unwrap(); + outer_builder + .try_push(Some(vec![Some(vec![Some(9)])])) + .unwrap(); + + let expected = outer_builder.as_box(); + + assert_eq!(&array, &expected); + } + + #[test] + fn scalar_timestamp_ns_utc_timezone() { + let scalar = ScalarValue::TimestampNanosecond( + Some(1599566400000000000), + Some("UTC".to_owned()), + ); + + assert_eq!( + scalar.get_datatype(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) + ); - outer_builder.values().values().append_value(1).unwrap(); - outer_builder.values().values().append_value(2).unwrap(); - outer_builder.values().values().append_value(3).unwrap(); - outer_builder.values().append(true).unwrap(); + let array = scalar.to_array(); + assert_eq!(array.len(), 1); + assert_eq!( + array.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) + ); - outer_builder.values().values().append_value(4).unwrap(); - outer_builder.values().values().append_value(5).unwrap(); - outer_builder.values().append(true).unwrap(); - outer_builder.append(true).unwrap(); + let newscalar = ScalarValue::try_from_array(&array, 0).unwrap(); + assert_eq!( + newscalar.get_datatype(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) + ); + } - outer_builder.values().values().append_value(6).unwrap(); - outer_builder.values().append(true).unwrap(); + macro_rules! test_scalar_op { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + assert_eq!( + ScalarValue::$OP(v1, v2).unwrap(), + ScalarValue::from($RESULT as $RESULT_TYPE) + ); + }}; + } - outer_builder.values().values().append_value(7).unwrap(); - outer_builder.values().values().append_value(8).unwrap(); - outer_builder.values().append(true).unwrap(); - outer_builder.append(true).unwrap(); + macro_rules! 
test_scalar_op_err { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + let actual = ScalarValue::$OP(v1, v2).is_err(); + assert!(actual); + }}; + } + + #[test] + fn scalar_addition() { + test_scalar_op!(add, 1, f64, 2, f64, 3, f64); + test_scalar_op!(add, 1, f32, 2, f32, 3, f64); + test_scalar_op!(add, 1, i64, 2, i64, 3, i64); + test_scalar_op!(add, 100, i64, -32, i64, 68, i64); + test_scalar_op!(add, -102, i64, 32, i64, -70, i64); + test_scalar_op!(add, 1, i32, 2, i32, 3, i64); + test_scalar_op!( + add, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * 2, + i64 + ); + test_scalar_op!(add, 1, i16, 2, i16, 3, i32); + test_scalar_op!( + add, + std::i16::MAX, + i16, + std::i16::MAX, + i16, + std::i16::MAX as i32 * 2, + i32 + ); + test_scalar_op!(add, 1, i8, 2, i8, 3, i16); + test_scalar_op!( + add, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * 2, + i16 + ); + test_scalar_op!(add, 1, u64, 2, u64, 3, u64); + test_scalar_op!(add, 1, u32, 2, u32, 3, u64); + test_scalar_op!( + add, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * 2, + u64 + ); + test_scalar_op!(add, 1, u16, 2, u16, 3, u32); + test_scalar_op!( + add, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * 2, + u32 + ); + test_scalar_op!(add, 1, u8, 2, u8, 3, u16); + test_scalar_op!( + add, + std::u8::MAX, + u8, + std::u8::MAX, + u8, + std::u8::MAX as u16 * 2, + u16 + ); + test_scalar_op_err!(add, 1, i32, 2, u16); + test_scalar_op_err!(add, 1, i32, 2, u16); - outer_builder.values().values().append_value(9).unwrap(); - outer_builder.values().append(true).unwrap(); - outer_builder.append(true).unwrap(); + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::add(v1, v2).is_err()); - let expected = outer_builder.finish(); + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); - assert_eq!(array, &expected); - } */ + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::add(v1, v2).is_err()); + } + + #[test] + fn scalar_multiplication() { + test_scalar_op!(mul, 1, f64, 2, f64, 2, f64); + test_scalar_op!(mul, 1, f32, 2, f32, 2, f64); + test_scalar_op!(mul, 15, i64, 2, i64, 30, i64); + test_scalar_op!(mul, 100, i64, -32, i64, -3200, i64); + test_scalar_op!(mul, -1.1, f64, 2, f64, -2.2, f64); + test_scalar_op!(mul, 1, i32, 2, i32, 2, i64); + test_scalar_op!( + mul, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * std::i32::MAX as i64, + i64 + ); + test_scalar_op!(mul, 1, i16, 2, i16, 2, i32); + test_scalar_op!( + mul, + std::i16::MAX, + i16, + std::i16::MAX, + i16, + std::i16::MAX as i32 * std::i16::MAX as i32, + i32 + ); + test_scalar_op!(mul, 1, i8, 2, i8, 2, i16); + test_scalar_op!( + mul, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * std::i8::MAX as i16, + i16 + ); + test_scalar_op!(mul, 1, u64, 2, u64, 2, u64); + test_scalar_op!(mul, 1, u32, 2, u32, 2, u64); + test_scalar_op!( + mul, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * 
std::u32::MAX as u64, + u64 + ); + test_scalar_op!(mul, 1, u16, 2, u16, 2, u32); + test_scalar_op!( + mul, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * std::u16::MAX as u32, + u32 + ); + test_scalar_op!(mul, 1, u8, 2, u8, 2, u16); + test_scalar_op!( + mul, + std::u8::MAX, + u8, + std::u8::MAX, + u8, + std::u8::MAX as u16 * std::u8::MAX as u16, + u16 + ); + test_scalar_op_err!(mul, 1, i32, 2, u16); + test_scalar_op_err!(mul, 1, i32, 2, u16); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::mul(v1, v2).is_err()); + } + + #[test] + fn scalar_division() { + test_scalar_op!(div, 1, f64, 2, f64, 0.5, f64); + test_scalar_op!(div, 1, f32, 2, f32, 0.5, f64); + test_scalar_op!(div, 15, i64, 2, i64, 7.5, f64); + test_scalar_op!(div, 100, i64, -2, i64, -50, f64); + test_scalar_op!(div, 1, i32, 2, i32, 0.5, f64); + test_scalar_op!(div, 1, i16, 2, i16, 0.5, f64); + test_scalar_op!(div, 1, i8, 2, i8, 0.5, f64); + test_scalar_op!(div, 1, u64, 2, u64, 0.5, f64); + test_scalar_op!(div, 1, u32, 2, u32, 0.5, f64); + test_scalar_op!(div, 1, u16, 2, u16, 0.5, f64); + test_scalar_op!(div, 1, u8, 2, u8, 0.5, f64); + test_scalar_op_err!(div, 1, i32, 2, u16); + test_scalar_op_err!(div, 1, i32, 2, u16); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::div(v1, v2).is_err()); + } } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index d226ef1c2ce7..8a01287294ba 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1064,8 +1064,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let field = schema.field(field_index - 1); - let col_ident = SQLExpr::Identifier(Ident::new(field.qualified_name())); - self.sql_expr_to_logical_expr(&col_ident, schema)? + Expr::Column(field.qualified_column()) } e => self.sql_expr_to_logical_expr(e, schema)?, }; @@ -1325,9 +1324,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - // create a column expression based on raw user input, this column will be - // normalized with qualifer later by the SQL planner. - Ok(col(&id.value)) + // Don't use `col()` here because it will try to + // interpret names with '.' as if they were + // compound indenfiers, but this is not a compound + // identifier. (e.g. 
it is "foo.bar" not foo.bar) + Ok(Expr::Column(Column { + relation: None, + name: id.value.clone(), + })) } } @@ -1343,22 +1347,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::CompoundIdentifier(ids) => { - let mut var_names = vec![]; - for id in ids { - var_names.push(id.value.clone()); - } + let mut var_names: Vec<_> = + ids.iter().map(|id| id.value.clone()).collect(); + if &var_names[0][0..1] == "@" { Ok(Expr::ScalarVariable(var_names)) - } else if var_names.len() == 2 { - // table.column identifier - let name = var_names.pop().unwrap(); - let relation = Some(var_names.pop().unwrap()); - Ok(Expr::Column(Column { relation, name })) } else { - Err(DataFusionError::NotImplemented(format!( - "Unsupported compound identifier '{:?}'", - var_names, - ))) + match (var_names.pop(), var_names.pop()) { + (Some(name), Some(relation)) if var_names.is_empty() => { + // table.column identifier + Ok(Expr::Column(Column { + relation: Some(relation), + name, + })) + } + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported compound identifier '{:?}'", + var_names, + ))), + } } } diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index bce50e5610d3..0ede5ad8559e 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; use crate::logical_plan::{Expr, LogicalPlan}; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use crate::{ error::{DataFusionError, Result}, logical_plan::{Column, ExpressionVisitor, Recursion}, @@ -520,7 +520,7 @@ pub(crate) fn make_decimal_type( } (Some(p), Some(s)) => { // Arrow decimal is i128 meaning 38 maximum decimal digits - if p > 38 || s > p { + if (p as usize) > MAX_PRECISION_FOR_DECIMAL128 || s > p { return Err(DataFusionError::Internal(format!( "For decimal(precision, scale) precision must be less than or equal to 38 and scale can't be greater than precision. 
Got ({}, {})", p, s diff --git a/datafusion/src/test/variable.rs b/datafusion/src/test/variable.rs index 47d1370e8014..12597b832df6 100644 --- a/datafusion/src/test/variable.rs +++ b/datafusion/src/test/variable.rs @@ -34,7 +34,7 @@ impl SystemVar { impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { - let s = format!("{}-{}", "system-var".to_string(), var_names.concat()); + let s = format!("{}-{}", "system-var", var_names.concat()); Ok(ScalarValue::Utf8(Some(s))) } } @@ -52,7 +52,7 @@ impl UserDefinedVar { impl VarProvider for UserDefinedVar { /// Get user defined variable value fn get_value(&self, var_names: Vec) -> Result { - let s = format!("{}-{}", "user-defined-var".to_string(), var_names.concat()); + let s = format!("{}-{}", "user-defined-var", var_names.concat()); Ok(ScalarValue::Utf8(Some(s))) } } diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index aad014372981..5d5494fa58eb 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -231,9 +231,9 @@ fn get_data_dir(udf_env: &str, submodule_data: &str) -> Result Arc { let mut f1 = Field::new("c1", DataType::Utf8, false); - f1.set_metadata(Some(BTreeMap::from_iter( + f1 = f1.with_metadata(BTreeMap::from_iter( vec![("testing".into(), "test".into())].into_iter(), - ))); + )); let schema = Schema::new(vec![ f1, Field::new("c2", DataType::UInt32, false), diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs new file mode 100644 index 000000000000..b9277f4f5969 --- /dev/null +++ b/datafusion/tests/dataframe_functions.rs @@ -0,0 +1,665 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::Utf8Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::{array::Int32Array, record_batch::RecordBatch}; + +use datafusion::dataframe::DataFrame; +use datafusion::datasource::MemTable; + +use datafusion::error::Result; + +// use datafusion::logical_plan::Expr; +use datafusion::prelude::*; + +use datafusion::execution::context::ExecutionContext; + +use datafusion::assert_batches_eq; + +fn create_test_table() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])); + + // define data. 
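+ // Four rows of mixed-case strings in column "a" paired with Int32 values (1, 10, 10, 100) in column "b"; the expected tables asserted in the tests below assume exactly this fixture.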
+ let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Utf8Array::::from_slice(vec![ + "abcDEF", + "abc123", + "CBAdef", + "123AbcDef", + ])), + Arc::new(Int32Array::from_slice(vec![1, 10, 10, 100])), + ], + )?; + + let mut ctx = ExecutionContext::new(); + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + + ctx.register_table("test", Arc::new(table))?; + + ctx.table("test") +} + +/// Excutes an expression on the test dataframe as a select. +/// Compares formatted output of a record batch with an expected +/// vector of strings, using the assert_batch_eq! macro +macro_rules! assert_fn_batches { + ($EXPR:expr, $EXPECTED: expr) => { + assert_fn_batches!($EXPR, $EXPECTED, 10) + }; + ($EXPR:expr, $EXPECTED: expr, $LIMIT: expr) => { + let df = create_test_table()?; + let df = df.select(vec![$EXPR])?.limit($LIMIT)?; + let batches = df.collect().await?; + + assert_batches_eq!($EXPECTED, &batches); + }; +} + +#[tokio::test] +async fn test_fn_ascii() -> Result<()> { + let expr = ascii(col("a")); + + let expected = vec![ + "+---------------+", + "| ascii(test.a) |", + "+---------------+", + "| 97 |", + "+---------------+", + ]; + + assert_fn_batches!(expr, expected, 1); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_bit_length() -> Result<()> { + let expr = bit_length(col("a")); + + let expected = vec![ + "+-------------------+", + "| bitlength(test.a) |", + "+-------------------+", + "| 48 |", + "| 48 |", + "| 48 |", + "| 72 |", + "+-------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_btrim() -> Result<()> { + let expr = btrim(vec![lit(" a b c ")]); + + let expected = vec![ + "+-----------------------------------------+", + "| btrim(Utf8(\" a b c \")) |", + "+-----------------------------------------+", + "| a b c |", + "+-----------------------------------------+", + ]; + + assert_fn_batches!(expr, expected, 1); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_btrim_with_chars() -> Result<()> { + let expr = btrim(vec![col("a"), lit("ab")]); + + let expected = vec![ + "+--------------------------+", + "| btrim(test.a,Utf8(\"ab\")) |", + "+--------------------------+", + "| cDEF |", + "| c123 |", + "| CBAdef |", + "| 123AbcDef |", + "+--------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_character_length() -> Result<()> { + let expr = character_length(col("a")); + + let expected = vec![ + "+-------------------------+", + "| characterlength(test.a) |", + "+-------------------------+", + "| 6 |", + "| 6 |", + "| 6 |", + "| 9 |", + "+-------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_chr() -> Result<()> { + let expr = chr(lit(128175)); + + let expected = vec![ + "+--------------------+", + "| chr(Int32(128175)) |", + "+--------------------+", + "| 💯 |", + "+--------------------+", + ]; + + assert_fn_batches!(expr, expected, 1); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_initcap() -> Result<()> { + let expr = initcap(col("a")); + + let expected = vec![ + "+-----------------+", + "| initcap(test.a) |", + "+-----------------+", + "| Abcdef |", + "| Abc123 |", + "| Cbadef |", + "| 123abcdef |", + "+-----------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_left() -> Result<()> { + let expr = left(col("a"), lit(3)); + + let expected = vec![ + "+-----------------------+", + "| left(test.a,Int32(3)) |", + 
"+-----------------------+", + "| abc |", + "| abc |", + "| CBA |", + "| 123 |", + "+-----------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_lower() -> Result<()> { + let expr = lower(col("a")); + + let expected = vec![ + "+---------------+", + "| lower(test.a) |", + "+---------------+", + "| abcdef |", + "| abc123 |", + "| cbadef |", + "| 123abcdef |", + "+---------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_lpad() -> Result<()> { + let expr = lpad(vec![col("a"), lit(10)]); + + let expected = vec![ + "+------------------------+", + "| lpad(test.a,Int32(10)) |", + "+------------------------+", + "| abcDEF |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_lpad_with_string() -> Result<()> { + let expr = lpad(vec![col("a"), lit(10), lit("*")]); + + let expected = vec![ + "+----------------------------------+", + "| lpad(test.a,Int32(10),Utf8(\"*\")) |", + "+----------------------------------+", + "| ****abcDEF |", + "| ****abc123 |", + "| ****CBAdef |", + "| *123AbcDef |", + "+----------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_ltrim() -> Result<()> { + let expr = ltrim(lit(" a b c ")); + + let expected = vec![ + "+-----------------------------------------+", + "| ltrim(Utf8(\" a b c \")) |", + "+-----------------------------------------+", + "| a b c |", + "+-----------------------------------------+", + ]; + + assert_fn_batches!(expr, expected, 1); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_ltrim_with_columns() -> Result<()> { + let expr = ltrim(col("a")); + + let expected = vec![ + "+---------------+", + "| ltrim(test.a) |", + "+---------------+", + "| abcDEF |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+---------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_md5() -> Result<()> { + let expr = md5(col("a")); + + let expected = vec![ + "+----------------------------------+", + "| md5(test.a) |", + "+----------------------------------+", + "| ea2de8bd80f3a1f52c754214fc9b0ed1 |", + "| e99a18c428cb38d5f260853678922e03 |", + "| 11ed4a6e9985df40913eead67f022e27 |", + "| 8f5e60e523c9253e623ae38bb58c399a |", + "+----------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +// TODO: tobyhede - Issue #1429 +// https://github.com/apache/arrow-datafusion/issues/1429 +// g flag doesn't compile +#[tokio::test] +async fn test_fn_regexp_match() -> Result<()> { + let expr = regexp_match(vec![col("a"), lit("[a-z]")]); + // The below will fail + // let expr = regexp_match( vec![col("a"), lit("[a-z]"), lit("g")]); + + let expected = vec![ + "+-----------------------------------+", + "| regexpmatch(test.a,Utf8(\"[a-z]\")) |", + "+-----------------------------------+", + "| [] |", + "| [] |", + "| [] |", + "| [] |", + "+-----------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_regexp_replace() -> Result<()> { + let expr = regexp_replace(vec![col("a"), lit("[a-z]"), lit("x"), lit("g")]); + + let expected = vec![ + "+---------------------------------------------------------+", + "| regexpreplace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\"),Utf8(\"g\")) |", + 
"+---------------------------------------------------------+", + "| xxxDEF |", + "| xxx123 |", + "| CBAxxx |", + "| 123AxxDxx |", + "+---------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_replace() -> Result<()> { + let expr = replace(col("a"), lit("abc"), lit("x")); + + let expected = vec![ + "+---------------------------------------+", + "| replace(test.a,Utf8(\"abc\"),Utf8(\"x\")) |", + "+---------------------------------------+", + "| xDEF |", + "| x123 |", + "| CBAdef |", + "| 123AbcDef |", + "+---------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_repeat() -> Result<()> { + let expr = repeat(col("a"), lit(2)); + + let expected = vec![ + "+-------------------------+", + "| repeat(test.a,Int32(2)) |", + "+-------------------------+", + "| abcDEFabcDEF |", + "| abc123abc123 |", + "| CBAdefCBAdef |", + "| 123AbcDef123AbcDef |", + "+-------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_reverse() -> Result<()> { + let expr = reverse(col("a")); + + let expected = vec![ + "+-----------------+", + "| reverse(test.a) |", + "+-----------------+", + "| FEDcba |", + "| 321cba |", + "| fedABC |", + "| feDcbA321 |", + "+-----------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_right() -> Result<()> { + let expr = right(col("a"), lit(3)); + + let expected = vec![ + "+------------------------+", + "| right(test.a,Int32(3)) |", + "+------------------------+", + "| DEF |", + "| 123 |", + "| def |", + "| Def |", + "+------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_rpad() -> Result<()> { + let expr = rpad(vec![col("a"), lit(11)]); + + let expected = vec![ + "+------------------------+", + "| rpad(test.a,Int32(11)) |", + "+------------------------+", + "| abcDEF |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_rpad_with_characters() -> Result<()> { + let expr = rpad(vec![col("a"), lit(11), lit("x")]); + + let expected = vec![ + "+----------------------------------+", + "| rpad(test.a,Int32(11),Utf8(\"x\")) |", + "+----------------------------------+", + "| abcDEFxxxxx |", + "| abc123xxxxx |", + "| CBAdefxxxxx |", + "| 123AbcDefxx |", + "+----------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_sha224() -> Result<()> { + let expr = sha224(col("a")); + + let expected = vec![ + "+----------------------------------------------------------+", + "| sha224(test.a) |", + "+----------------------------------------------------------+", + "| 8b9ef961d2b19cfe7ee2a8452e3adeea98c7b22954b4073976bf80ee |", + "| 5c69bb695cc29b93d655e1a4bb5656cda624080d686f74477ea09349 |", + "| b3b3783b7470594e7ddb845eca0aec5270746dd6d0bc309bb948ceab |", + "| fc8a30d59386d78053328440c6670c3b583404a905cbe9bbd491a517 |", + "+----------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_split_part() -> Result<()> { + let expr = split_part(col("a"), lit("b"), lit(1)); + + let expected = vec![ + "+--------------------------------------+", + "| 
splitpart(test.a,Utf8(\"b\"),Int32(1)) |", + "+--------------------------------------+", + "| a |", + "| a |", + "| CBAdef |", + "| 123A |", + "+--------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_starts_with() -> Result<()> { + let expr = starts_with(col("a"), lit("abc")); + + let expected = vec![ + "+--------------------------------+", + "| startswith(test.a,Utf8(\"abc\")) |", + "+--------------------------------+", + "| true |", + "| true |", + "| false |", + "| false |", + "+--------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_strpos() -> Result<()> { + let expr = strpos(col("a"), lit("f")); + + let expected = vec![ + "+--------------------------+", + "| strpos(test.a,Utf8(\"f\")) |", + "+--------------------------+", + "| 0 |", + "| 0 |", + "| 6 |", + "| 9 |", + "+--------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_substr() -> Result<()> { + let expr = substr(col("a"), lit(2)); + + let expected = vec![ + "+-------------------------+", + "| substr(test.a,Int32(2)) |", + "+-------------------------+", + "| bcDEF |", + "| bc123 |", + "| BAdef |", + "| 23AbcDef |", + "+-------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_to_hex() -> Result<()> { + let expr = to_hex(col("b")); + + let expected = vec![ + "+---------------+", + "| tohex(test.b) |", + "+---------------+", + "| 1 |", + "| a |", + "| a |", + "| 64 |", + "+---------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_translate() -> Result<()> { + let expr = translate(col("a"), lit("bc"), lit("xx")); + + let expected = vec![ + "+-----------------------------------------+", + "| translate(test.a,Utf8(\"bc\"),Utf8(\"xx\")) |", + "+-----------------------------------------+", + "| axxDEF |", + "| axx123 |", + "| CBAdef |", + "| 123AxxDef |", + "+-----------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_upper() -> Result<()> { + let expr = upper(col("a")); + + let expected = vec![ + "+---------------+", + "| upper(test.a) |", + "+---------------+", + "| ABCDEF |", + "| ABC123 |", + "| CBADEF |", + "| 123ABCDEF |", + "+---------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 57611b8cd336..ed21fae8ad2f 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -639,11 +639,12 @@ async fn make_test_file(scenario: Scenario) -> NamedTempFile { .iter() .zip(descritors.clone()) .map(|(array, type_)| { - let encoding = if let DataType::Dictionary(_, _) = array.data_type() { - Encoding::RleDictionary - } else { - Encoding::Plain - }; + let encoding = + if let DataType::Dictionary(_, _, _) = array.data_type() { + Encoding::RleDictionary + } else { + Encoding::Plain + }; array_to_pages(array.as_ref(), type_, options, encoding).map( move |pages| { let encoded_pages = DynIter::new(pages.map(|x| Ok(x?))); diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index bc1ff554abfa..8b137891791f 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1,6408 +1 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license 
agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -//! This module contains end to end tests of running SQL queries using -//! DataFusion - -use std::sync::Arc; - -use chrono::{Duration, TimeZone}; - -use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; - -use datafusion::arrow::io::print; -use datafusion::assert_batches_eq; -use datafusion::assert_batches_sorted_eq; -use datafusion::assert_contains; -use datafusion::assert_not_contains; -use datafusion::logical_plan::plan::{Aggregate, Projection}; -use datafusion::logical_plan::LogicalPlan; -use datafusion::logical_plan::TableScan; -use datafusion::physical_plan::functions::Volatility; -use datafusion::physical_plan::metrics::MetricValue; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::ExecutionPlanVisitor; -use datafusion::prelude::*; -use datafusion::test_util; -use datafusion::{datasource::MemTable, physical_plan::collect}; -use datafusion::{ - error::{DataFusionError, Result}, - physical_plan::ColumnarValue, -}; -use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; - -#[tokio::test] -async fn nyc() -> Result<()> { - // schema for nyxtaxi csv files - let schema = Schema::new(vec![ - Field::new("VendorID", DataType::Utf8, true), - Field::new("tpep_pickup_datetime", DataType::Utf8, true), - Field::new("tpep_dropoff_datetime", DataType::Utf8, true), - Field::new("passenger_count", DataType::Utf8, true), - Field::new("trip_distance", DataType::Float64, true), - Field::new("RatecodeID", DataType::Utf8, true), - Field::new("store_and_fwd_flag", DataType::Utf8, true), - Field::new("PULocationID", DataType::Utf8, true), - Field::new("DOLocationID", DataType::Utf8, true), - Field::new("payment_type", DataType::Utf8, true), - Field::new("fare_amount", DataType::Float64, true), - Field::new("extra", DataType::Float64, true), - Field::new("mta_tax", DataType::Float64, true), - Field::new("tip_amount", DataType::Float64, true), - Field::new("tolls_amount", DataType::Float64, true), - Field::new("improvement_surcharge", DataType::Float64, true), - Field::new("total_amount", DataType::Float64, true), - ]); - - let mut ctx = ExecutionContext::new(); - ctx.register_csv( - "tripdata", - "file.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - - let logical_plan = ctx.create_logical_plan( - "SELECT passenger_count, MIN(fare_amount), MAX(fare_amount) \ - FROM tripdata GROUP BY passenger_count", - )?; - - let optimized_plan = ctx.optimize(&logical_plan)?; - - match &optimized_plan { - LogicalPlan::Projection(Projection { input, .. }) => match input.as_ref() { - LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() { - LogicalPlan::TableScan(TableScan { - ref projected_schema, - .. 
- }) => { - assert_eq!(2, projected_schema.fields().len()); - assert_eq!(projected_schema.field(0).name(), "passenger_count"); - assert_eq!(projected_schema.field(1).name(), "fare_amount"); - } - _ => unreachable!(), - }, - _ => unreachable!(), - }, - _ => unreachable!(false), - } - - Ok(()) -} - -#[tokio::test] -async fn parquet_query() { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn parquet_single_nan_schema() { - let mut ctx = ExecutionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] -async fn parquet_list_columns() { - let mut ctx = ExecutionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "list_columns", - &format!("{}/list_columns.parquet", testdata), - ) - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "int64_list", - DataType::List(Box::new(Field::new("item", DataType::Int64, true))), - true, - ), - Field::new( - "utf8_list", - DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), - true, - ), - ])); - - let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); - - // int64_list utf8_list - // 0 [1, 2, 3] [abc, efg, hij] - // 1 [None, 1] None - // 2 [4] [efg, None, hij, xyz] - - assert_eq!(1, results.len()); - let batch = &results[0]; - assert_eq!(3, batch.num_rows()); - assert_eq!(2, batch.num_columns()); - assert_eq!(schema.as_ref(), batch.schema().as_ref()); - - let int_list_array = batch - .column(0) - .as_any() - .downcast_ref::<ListArray<i32>>() - .unwrap(); - let utf8_list_array = batch - .column(1) - .as_any() - .downcast_ref::<ListArray<i32>>() - .unwrap(); - - assert_eq!( - int_list_array - .value(0) - .as_any() - .downcast_ref::<PrimitiveArray<i64>>() - .unwrap(), - &PrimitiveArray::<i64>::from(vec![Some(1), Some(2), Some(3)]) - ); - - assert_eq!( - utf8_list_array - .value(0) - .as_any() - .downcast_ref::<Utf8Array<i32>>() - .unwrap(), - &Utf8Array::<i32>::from(vec![Some("abc"), Some("efg"), Some("hij")]) - ); - - assert_eq!( - int_list_array - .value(1) - .as_any() - .downcast_ref::<PrimitiveArray<i64>>() - .unwrap(), - &PrimitiveArray::<i64>::from(vec![None, Some(1),]) - );
- - assert!(utf8_list_array.is_null(1)); - - assert_eq!( - int_list_array - .value(2) - .as_any() - .downcast_ref::<PrimitiveArray<i64>>() - .unwrap(), - &PrimitiveArray::<i64>::from(vec![Some(4),]) - ); - - let result = utf8_list_array.value(2); - let result = result.as_any().downcast_ref::<Utf8Array<i32>>().unwrap(); - - assert_eq!(result.value(0), "efg"); - assert!(result.is_null(1)); - assert_eq!(result.value(2), "hij"); - assert_eq!(result.value(3), "xyz"); -} - -#[tokio::test] -async fn csv_select_nested() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT o1, o2, c3 - FROM ( - SELECT c1 AS o1, c2 + 1 AS o2, c3 - FROM ( - SELECT c1, c2, c3, c4 - FROM aggregate_test_100 - WHERE c1 = 'a' AND c2 >= 4 - ORDER BY c2 ASC, c3 ASC - ) AS a - ) AS b"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+------+", - "| o1 | o2 | c3 |", - "+----+----+------+", - "| a | 5 | -101 |", - "| a | 5 | -54 |", - "| a | 5 | -38 |", - "| a | 5 | 65 |", - "| a | 6 | -101 |", - "| a | 6 | -31 |", - "| a | 6 | 36 |", - "+----+----+------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_count_star() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+-----+------------------------------+", - "| COUNT(UInt8(1)) | c | COUNT(aggregate_test_100.c1) |", - "+-----------------+-----+------------------------------+", - "| 100 | 100 | 100 |", - "+-----------------+-----+------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_with_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+---------------------+", - "| c1 | c12 |", - "+----+---------------------+", - "| e | 0.39144436569161134 |", - "| d | 0.38870280983958583 |", - "+----+---------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_with_negative_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+--------+", - "| c1 | c4 |", - "+----+--------+", - "| e | -31500 |", - "| c | -30187 |", - "+----+--------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_with_negated_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 21 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_with_is_not_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(1) 
FROM aggregate_test_100 WHERE c1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 100 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_with_is_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 0 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_int_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------+-----------------------------+", - "| c2 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", - "+----+-----------------------------+-----------------------------+", - "| 1 | 0.05636955101974106 | 0.9965400387585364 |", - "| 2 | 0.16301110515739792 | 0.991517828651004 |", - "| 3 | 0.047343434291126085 | 0.9293883502480845 |", - "| 4 | 0.02182578039211991 | 0.9237877978193884 |", - "| 5 | 0.01479305307777301 | 0.9723580396501548 |", - "+----+-----------------------------+-----------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = - "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----+---------+", - "| cnt | c1 |", - "+-----+---------+", - "| 5 | 0.00005 |", - "| 4 | 0.00004 |", - "| 3 | 0.00003 |", - "| 2 | 0.00002 |", - "| 1 | 0.00001 |", - "+-----+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn select_values_list() -> Result<()> { - let mut ctx = ExecutionContext::new(); - { - let sql = "VALUES (1)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (-1)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| -1 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (2+1,2-1,2>1)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+---------+", - "| column1 | column2 | column3 |", - "+---------+---------+---------+", - "| 3 | 1 | true |", - "+---------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES ()"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1),(2)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = 
vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (1),()"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1,'a'),(2,'b')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| 1 | a |", - "| 2 | b |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (1),(1,2)"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1),('2')"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1),(2.0)"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1,2), (1,'2')"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - } - { - let sql = "VALUES (1,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| 1 | a |", - "| | b |", - "| 3 | c |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (NULL,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| | a |", - "| | b |", - "| 3 | c |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (NULL,'a'),(NULL,'b'),(NULL,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| | a |", - "| | b |", - "| | c |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (1,'a'),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| 1 | a |", - "| 2 | |", - "| 3 | c |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (1,NULL),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+", - "| column1 | column2 |", - "+---------+---------+", - "| 1 | |", - "| 2 | |", - "| 3 | c |", - "+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", - "| column1 | column2 | column3 | column4 | column5 | column6 | column7 | column8 | column9 | column10 | column11 | column12 | column13 | column14 | column15 | column16 |", - "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", - "| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | | F | 3.5 |", - 
"+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "SELECT * FROM (VALUES (1,'a'),(2,NULL)) AS t(c1, c2)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 1 | a |", - "| 2 | |", - "+----+----+", - ]; - assert_batches_eq!(expected, &actual); - } - { - let sql = "EXPLAIN VALUES (1, 'a', -1, 1.1),(NULL, 'b', -3, 0.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", - "| logical_plan | Values: (Int64(1), Utf8(\"a\"), Int64(-1), Float64(1.1)), (Int64(NULL), Utf8(\"b\"), Int64(-3), Float64(0.5)) |", - "| physical_plan | ValuesExec |", - "| | |", - "+---------------+-----------------------------------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn select_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = "SELECT c1 FROM aggregate_simple order by c1"; - let results = execute_to_batches(&mut ctx, sql).await; - - let sql_all = "SELECT ALL c1 FROM aggregate_simple order by c1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; - - let expected = vec![ - "+---------+", - "| c1 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "| 0.00002 |", - "| 0.00003 |", - "| 0.00003 |", - "| 0.00003 |", - "| 0.00004 |", - "| 0.00004 |", - "| 0.00004 |", - "| 0.00004 |", - "| 0.00005 |", - "| 0.00005 |", - "| 0.00005 |", - "| 0.00005 |", - "| 0.00005 |", - "+---------+", - ]; - - assert_batches_eq!(expected, &results); - assert_batches_eq!(expected, &results_all); - - Ok(()) -} - -#[tokio::test] -async fn create_table_as() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; - ctx.sql(sql).await.unwrap(); - - let sql_all = "SELECT * FROM my_table order by c1 LIMIT 1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; - - let expected = vec![ - "+---------+----------------+------+", - "| c1 | c2 | c3 |", - "+---------+----------------+------+", - "| 0.00001 | 0.000000000001 | true |", - "+---------+----------------+------+", - ]; - - assert_batches_eq!(expected, &results_all); - - Ok(()) -} - -#[tokio::test] -async fn drop_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; - ctx.sql(sql).await.unwrap(); - - let sql = "DROP TABLE my_table"; - ctx.sql(sql).await.unwrap(); - - let result = ctx.table("my_table"); - assert!(result.is_err(), "drop table should deregister table."); - - let sql = "DROP TABLE IF EXISTS my_table"; - ctx.sql(sql).await.unwrap(); - - Ok(()) -} - -#[tokio::test] -async fn select_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = "SELECT DISTINCT * 
FROM aggregate_simple"; - let mut actual = execute(&mut ctx, sql).await; - actual.sort(); - - let mut dedup = actual.clone(); - dedup.dedup(); - - assert_eq!(actual, dedup); - - Ok(()) -} - -#[tokio::test] -async fn select_distinct_simple_1() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); - - let sql = "SELECT DISTINCT c1 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+---------+", - "| c1 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "| 0.00003 |", - "| 0.00004 |", - "| 0.00005 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn select_distinct_simple_2() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); - - let sql = "SELECT DISTINCT c1, c2 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+---------+----------------+", - "| c1 | c2 |", - "+---------+----------------+", - "| 0.00001 | 0.000000000001 |", - "| 0.00002 | 0.000000000002 |", - "| 0.00003 | 0.000000000003 |", - "| 0.00004 | 0.000000000004 |", - "| 0.00005 | 0.000000000005 |", - "+---------+----------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn select_distinct_simple_3() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); - - let sql = "SELECT distinct c3 FROM aggregate_simple order by c3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------+", - "| c3 |", - "+-------+", - "| false |", - "| true |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn select_distinct_simple_4() { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await.unwrap(); - - let sql = "SELECT distinct c1+c2 as a FROM aggregate_simple"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------------------------+", - "| a |", - "+-------------------------+", - "| 0.000030000002242136256 |", - "| 0.000040000002989515004 |", - "| 0.000010000000747378751 |", - "| 0.00005000000373689376 |", - "| 0.000020000001494757502 |", - "+-------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); -} - -#[tokio::test] -async fn select_distinct_from() { - let mut ctx = ExecutionContext::new(); - - let sql = "select - 1 IS DISTINCT FROM CAST(NULL as INT) as a, - 1 IS DISTINCT FROM 1 as b, - 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, - 1 IS NOT DISTINCT FROM 1 as d, - NULL IS DISTINCT FROM NULL as e, - NULL IS NOT DISTINCT FROM NULL as f - "; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------+-------+-------+------+-------+------+", - "| a | b | c | d | e | f |", - "+------+-------+-------+------+-------+------+", - "| true | false | false | true | false | true |", - "+------+-------+-------+------+-------+------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn select_distinct_from_utf8() { - let mut ctx = ExecutionContext::new(); - - let sql = "select - 'x' IS DISTINCT FROM NULL as a, - 'x' IS DISTINCT FROM 'x' as b, - 'x' IS NOT DISTINCT FROM NULL as c, - 'x' IS NOT DISTINCT FROM 'x' as d - "; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------+-------+-------+------+", - "| a | b | c | d |", - 
"+------+-------+-------+------+", - "| true | false | false | true |", - "+------+-------+-------+------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn projection_same_fields() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec!["+---+", "| a |", "+---+", "| 2 |", "+---+"]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = - "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----+----------------+", - "| cnt | c2 |", - "+-----+----------------+", - "| 5 | 0.000000000005 |", - "| 4 | 0.000000000004 |", - "| 3 | 0.000000000003 |", - "| 2 | 0.000000000002 |", - "| 1 | 0.000000000001 |", - "+-----+----------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_boolean() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; - - let sql = - "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----+-------+", - "| cnt | c3 |", - "+-----+-------+", - "| 9 | true |", - "| 6 | false |", - "+-----+-------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_two_columns() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+----------------------------+", - "| c1 | c2 | MIN(aggregate_test_100.c3) |", - "+----+----+----------------------------+", - "| a | 1 | -85 |", - "| a | 2 | -48 |", - "| a | 3 | -72 |", - "| a | 4 | -101 |", - "| a | 5 | -101 |", - "| b | 1 | 12 |", - "| b | 2 | -60 |", - "| b | 3 | -101 |", - "| b | 4 | -117 |", - "| b | 5 | -82 |", - "| c | 1 | -24 |", - "| c | 2 | -117 |", - "| c | 3 | -2 |", - "| c | 4 | -90 |", - "| c | 5 | -94 |", - "| d | 1 | -99 |", - "| d | 2 | 93 |", - "| d | 3 | -76 |", - "| d | 4 | 5 |", - "| d | 5 | -59 |", - "| e | 1 | 36 |", - "| e | 2 | -61 |", - "| e | 3 | -95 |", - "| e | 4 | -56 |", - "| e | 5 | -86 |", - "+----+----+----------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_and_having() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+------+", - "| c1 | m |", - "+----+------+", - "| a | -101 |", - "| c | -117 |", - "+----+------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_and_having_and_where() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, MIN(c3) AS m - FROM aggregate_test_100 - WHERE c1 IN ('a', 'b') - 
GROUP BY c1 - HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+------+", - "| c1 | m |", - "+----+------+", - "| a | -101 |", - "+----+------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn all_where_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT * - FROM aggregate_test_100 - WHERE 1=2"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["++", "++"]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_having_without_group_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, c2, c3 FROM aggregate_test_100 HAVING c2 >= 4 AND c3 > 90"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+-----+", - "| c1 | c2 | c3 |", - "+----+----+-----+", - "| c | 4 | 123 |", - "| c | 5 | 118 |", - "| d | 4 | 102 |", - "| e | 4 | 96 |", - "| e | 4 | 97 |", - "+----+----+-----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_boolean_eq_neq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); - // verify the plumbing is all hooked up for eq and neq - let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, a != true as neq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------+-------+-------+-----------+-------+------------+", - "| a | b | eq | eq_scalar | neq | neq_scalar |", - "+-------+-------+-------+-----------+-------+------------+", - "| true | true | true | true | false | false |", - "| true | | | | | false |", - "| true | false | false | false | true | false |", - "| | true | | true | | |", - "| | | | | | |", - "| | false | | false | | |", - "| false | true | false | true | true | true |", - "| false | | | | | true |", - "| false | false | true | false | false | true |", - "+-------+-------+-------+-----------+-------+------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -#[ignore] -async fn csv_query_boolean_lt_lt_eq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); - // verify the plumbing is all hooked up for < and <= - let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a <= b as lt_eq, a <= true as lt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------+-------+-------+-----------+-------+--------------+", - "| a | b | lt | lt_scalar | lt_eq | lt_eq_scalar |", - "+-------+-------+-------+-----------+-------+--------------+", - "| true | true | false | true | true | true |", - "| true | | | | | true |", - "| true | false | false | false | false | true |", - "| | true | | true | | |", - "| | | | | | |", - "| | false | | false | | |", - "| false | true | true | true | true | true |", - "| false | | | | | true |", - "| false | false | false | false | true | true |", - "+-------+-------+-------+-----------+-------+--------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_boolean_gt_gt_eq() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); - // verify the plumbing is all hooked up for > and >= - let sql = "SELECT a, b, a > b as gt, b = true as 
gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------+-------+-------+-----------+-------+--------------+", - "| a | b | gt | gt_scalar | gt_eq | gt_eq_scalar |", - "+-------+-------+-------+-----------+-------+--------------+", - "| true | true | false | true | true | true |", - "| true | | | | | true |", - "| true | false | true | false | true | true |", - "| | true | | true | | |", - "| | | | | | |", - "| | false | | false | | |", - "| false | true | false | true | false | false |", - "| false | | | | | false |", - "| false | false | false | false | true | false |", - "+-------+-------+-------+-----------+-------+--------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_boolean_distinct_from() { - let mut ctx = ExecutionContext::new(); - register_boolean(&mut ctx).await.unwrap(); - // verify the plumbing is all hooked up for is distinct from and is not distinct from - let sql = "SELECT a, b, \ - a is distinct from b as df, \ - b is distinct from true as df_scalar, \ - a is not distinct from b as ndf, \ - a is not distinct from true as ndf_scalar \ - FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------+-------+-------+-----------+-------+------------+", - "| a | b | df | df_scalar | ndf | ndf_scalar |", - "+-------+-------+-------+-----------+-------+------------+", - "| true | true | false | false | true | true |", - "| true | | true | true | false | true |", - "| true | false | true | true | false | true |", - "| | true | true | false | false | false |", - "| | | false | true | true | false |", - "| | false | true | true | false | false |", - "| false | true | true | false | false | false |", - "| false | | true | true | false | false |", - "| false | false | false | true | true | false |", - "+-------+-------+-------+-----------+-------+------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_avg_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; - actual.sort(); - let expected = vec![vec!["0.6706002946036462"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - -/// test that casting happens on udfs. -/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and -/// physical plan have the same schema. 
-#[tokio::test] -async fn csv_query_custom_udf_with_cast() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0.6584408483418833"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - -/// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double))) -#[tokio::test] -async fn sqrt_f32_vs_f64() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - // sqrt(f32)'s plan passes - let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0.658440933227539"]]; - - assert_eq!(actual, expected); - let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0.6584408483418833"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_error() -> Result<()> { - // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT sin(c1) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - Ok(()) -} - -// this query used to deadlock due to the call udf(udf()) -#[tokio::test] -async fn csv_query_sqrt_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT sqrt(sqrt(c12)) FROM aggregate_test_100 LIMIT 1"; - let actual = execute(&mut ctx, sql).await; - // sqrt(sqrt(c12=0.9294097332465232)) = 0.9818650561397431 - let expected = vec![vec!["0.9818650561397431"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - -#[allow(clippy::unnecessary_wraps)] -fn create_ctx() -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); - - // register a custom UDF - ctx.register_udf(create_udf( - "custom_sqrt", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(custom_sqrt), - )); - - Ok(ctx) -} - -fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> { - let arg = &args[0]; - if let ColumnarValue::Array(v) = arg { - let input = v - .as_any() - .downcast_ref::<Float64Array>() - .expect("cast failed"); - - let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); - Ok(ColumnarValue::Array(Arc::new(array))) - } else { - unimplemented!() - } -} - -#[tokio::test] -async fn csv_query_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT avg(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; - actual.sort(); - let expected = vec![vec!["0.5089725099127211"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------+", - "| c1 | AVG(aggregate_test_100.c12) |", - "+----+-----------------------------+", - "| a | 0.48754517466109415 |", - "| b | 0.41040709263815384 |", - "| c | 0.6600456536439785 |", - "| d | 0.48855379387549824 |", - "| e | 0.48600669271341557 |", - "+----+-----------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn 
csv_query_group_by_avg_with_projection() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------------------+----+", - "| AVG(aggregate_test_100.c12) | c1 |", - "+-----------------------------+----+", - "| 0.41040709263815384 | b |", - "| 0.48600669271341557 | e |", - "| 0.48754517466109415 | a |", - "| 0.48855379387549824 | d |", - "| 0.6600456536439785 | c |", - "+-----------------------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_avg_multi_batch() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT avg(c12) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); - let batch = &results[0]; - let column = batch.column(0); - let array = column.as_any().downcast_ref::<Float64Array>().unwrap(); - let actual = array.value(0); - let expected = 0.5089725; - // Due to float number's accuracy, different batch size will lead to different - // answers. - assert!((expected - actual).abs() < 0.01); - Ok(()) -} - -#[tokio::test] -async fn csv_query_nullif_divide_by_0() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; - let actual = &actual[80..90]; // We just want to compare rows 80-89 - let expected = vec![ - vec!["258"], - vec!["664"], - vec![""], - vec!["22"], - vec!["164"], - vec!["448"], - vec!["365"], - vec!["1640"], - vec!["671"], - vec!["203"], - ]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT count(c12) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------+", - "| COUNT(aggregate_test_100.c12) |", - "+-------------------------------+", - "| 100 |", - "+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_approx_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------+--------------+", - "| count_c9 | count_c9_str |", - "+----------+--------------+", - "| 100 | 99 |", - "+----------+--------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_count_without_from() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT count(1 + 1)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------------------+", - "| COUNT(Int64(1) + Int64(1)) |", - "+----------------------------+", - "| 1 |", - "+----------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_array_agg() -> 
Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = - "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------------------------------------------------------------------+", - "| ARRAYAGG(test.c13) |", - "+------------------------------------------------------------------+", - "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] |", - "+------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_array_agg_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = - "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------------+", - "| ARRAYAGG(test.c13) |", - "+--------------------+", - "| [] |", - "+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_array_agg_one() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = - "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------------------------+", - "| ARRAYAGG(test.c13) |", - "+----------------------------------+", - "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] |", - "+----------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -/// for window functions without order by the first, last, and nth function call does not make sense -#[tokio::test] -async fn csv_query_window_with_empty_over() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "select \ - c9, \ - count(c5) over (), \ - max(c5) over (), \ - min(c5) over () \ - from aggregate_test_100 \ - order by c9 \ - limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------+------------------------------+----------------------------+----------------------------+", - "| c9 | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) |", - "+-----------+------------------------------+----------------------------+----------------------------+", - "| 28774375 | 100 | 2143473091 | -2141999138 |", - "| 63044568 | 100 | 2143473091 | -2141999138 |", - "| 141047417 | 100 | 2143473091 | -2141999138 |", - "| 141680161 | 100 | 2143473091 | -2141999138 |", - "| 145294611 | 100 | 2143473091 | -2141999138 |", - "+-----------+------------------------------+----------------------------+----------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -/// for window functions without order by the first, last, and nth function call does not make sense -#[tokio::test] -#[ignore] -async fn csv_query_window_with_partition_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "select \ - c9, \ - sum(cast(c4 as Int)) over (partition by c3), \ - avg(cast(c4 as Int)) over (partition by c3), \ - count(cast(c4 as Int)) over (partition by c3), \ - max(cast(c4 as Int)) over (partition by c3), \ - min(cast(c4 as Int)) over (partition by c3) \ - 
from aggregate_test_100 \ - order by c9 \ - limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", - "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |", - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", - "| 28774375 | -16110 | -16110 | 1 | -16110 | -16110 |", - "| 63044568 | 3917 | 3917 | 1 | 3917 | 3917 |", - "| 141047417 | -38455 | -19227.5 | 2 | -16974 | -21481 |", - "| 141680161 | -1114 | -1114 | 1 | -1114 | -1114 |", - "| 145294611 | 15673 | 15673 | 1 | 15673 | 15673 |", - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_window_with_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "select \ - c9, \ - sum(c5) over (order by c9), \ - avg(c5) over (order by c9), \ - count(c5) over (order by c9), \ - max(c5) over (order by c9), \ - min(c5) over (order by c9), \ - first_value(c5) over (order by c9), \ - last_value(c5) over (order by c9), \ - nth_value(c5, 2) over (order by c9) \ - from aggregate_test_100 \ - order by c9 \ - limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", - "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", - "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", - "| 28774375 | 61035129 | 61035129 | 1 | 61035129 | 61035129 | 61035129 | 61035129 | |", - "| 63044568 | -47938237 | -23969118.5 | 2 | 61035129 | -108973366 | 61035129 | -108973366 | -108973366 |", - "| 141047417 | 575165281 | 191721760.33333334 | 3 | 623103518 | -108973366 | 61035129 | 623103518 | -108973366 |", - "| 141680161 | -1352462829 | -338115707.25 | 4 | 623103518 | -1927628110 | 61035129 | -1927628110 | -108973366 |", - "| 145294611 | -3251637940 | -650327588 | 5 | 623103518 | -1927628110 | 61035129 | -1899175111 | -108973366 |", - 
"+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_window_with_partition_by_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "select \ - c9, \ - sum(c5) over (partition by c4 order by c9), \ - avg(c5) over (partition by c4 order by c9), \ - count(c5) over (partition by c4 order by c9), \ - max(c5) over (partition by c4 order by c9), \ - min(c5) over (partition by c4 order by c9), \ - first_value(c5) over (partition by c4 order by c9), \ - last_value(c5) over (partition by c4 order by c9), \ - nth_value(c5, 2) over (partition by c4 order by c9) \ - from aggregate_test_100 \ - order by c9 \ - limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", - "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", - "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", - "| 28774375 | 61035129 | 61035129 | 1 | 61035129 | 61035129 | 61035129 | 61035129 | |", - "| 63044568 | -108973366 | -108973366 | 1 | -108973366 | -108973366 | -108973366 | -108973366 | |", - "| 141047417 | 623103518 | 623103518 | 1 | 623103518 | 623103518 | 623103518 | 623103518 | |", - "| 141680161 | -1927628110 | -1927628110 | 1 | -1927628110 | -1927628110 | -1927628110 | -1927628110 | |", - "| 145294611 | -1899175111 | -1899175111 | 1 | -1899175111 | -1899175111 | -1899175111 | -1899175111 | |", - "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+" - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_int_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-------------------------------+", - "| c1 | COUNT(aggregate_test_100.c12) |", - "+----+-------------------------------+", - "| a | 21 |", - "| b | 19 |", - "| c | 21 |", - "| d | 18 |", - "| e | 21 |", - "+----+-------------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_with_aliased_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - 
register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-------+", - "| c1 | count |", - "+----+-------+", - "| a | 21 |", - "| b | 19 |", - "| c | 21 |", - "| d | 18 |", - "| e | 21 |", - "+----+-------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_string_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------+-----------------------------+", - "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", - "+----+-----------------------------+-----------------------------+", - "| a | 0.02182578039211991 | 0.9800193410444061 |", - "| b | 0.04893135681998029 | 0.9185813970744787 |", - "| c | 0.0494924465469434 | 0.991517828651004 |", - "| d | 0.061029375346466685 | 0.9748360509016578 |", - "| e | 0.01479305307777301 | 0.9965400387585364 |", - "+----+-----------------------------+-----------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_cast() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----------------------------------------+", - "| CAST(aggregate_test_100.c12 AS Float32) |", - "+-----------------------------------------+", - "| 0.39144436 |", - "| 0.3887028 |", - "+-----------------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_cast_literal() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = - "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+--------------------+---------------------------+", - "| c12 | CAST(Int64(1) AS Float32) |", - "+--------------------+---------------------------+", - "| 0.9294097332465232 | 1 |", - "| 0.3114712539863804 | 1 |", - "+--------------------+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_cast_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![Arc::new(Int64Array::from_values(vec![ - 1235865600000, - 1235865660000, - 1238544000000, - ]))], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let sql = "SELECT to_timestamp_millis(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+--------------------------+", - "| totimestampmillis(t1.ts) |", - "+--------------------------+", - "| 2009-03-01 00:00:00 |", - "| 2009-03-01 00:01:00 |", - "| 2009-04-01 00:00:00 |", - "+--------------------------+", - ]; - 
assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_cast_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![Arc::new(Int64Array::from_values(vec![ - 1235865600000000, - 1235865660000000, - 1238544000000000, - ]))], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let sql = "SELECT to_timestamp_micros(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+--------------------------+", - "| totimestampmicros(t1.ts) |", - "+--------------------------+", - "| 2009-03-01 00:00:00 |", - "| 2009-03-01 00:01:00 |", - "| 2009-04-01 00:00:00 |", - "+--------------------------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_cast_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![Arc::new(Int64Array::from_values(vec![ - 1235865600, 1235865660, 1238544000, - ]))], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let sql = "SELECT to_timestamp_seconds(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+---------------------------+", - "| totimestampseconds(t1.ts) |", - "+---------------------------+", - "| 2009-03-01 00:00:00 |", - "| 2009-03-01 00:01:00 |", - "| 2009-04-01 00:00:00 |", - "+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_cast_timestamp_nanos_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_nano_table()?)?; - - // Original column is nanos, convert to millis and check timestamp - let sql = "SELECT to_timestamp_millis(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------------------------------+", - "| totimestampmillis(ts_data.ts) |", - "+-------------------------------+", - "| 2020-09-08 13:42:29.190 |", - "| 2020-09-08 12:42:29.190 |", - "| 2020-09-08 11:42:29.190 |", - "+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT to_timestamp_micros(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-------------------------------+", - "| totimestampmicros(ts_data.ts) |", - "+-------------------------------+", - "| 2020-09-08 13:42:29.190855 |", - "| 2020-09-08 12:42:29.190855 |", - "| 2020-09-08 11:42:29.190855 |", - "+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT to_timestamp_seconds(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------------------------+", - "| totimestampseconds(ts_data.ts) |", - "+--------------------------------+", - "| 2020-09-08 13:42:29 |", - "| 2020-09-08 12:42:29 |", - "| 2020-09-08 11:42:29 |", - "+--------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn 
query_cast_timestamp_seconds_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_secs", make_timestamp_table(TimeUnit::Second)?)?; - - // Original column is seconds, convert to millis and check timestamp - let sql = "SELECT to_timestamp_millis(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------+", - "| totimestampmillis(ts_secs.ts) |", - "+-------------------------------+", - "| 2020-09-08 13:42:29 |", - "| 2020-09-08 12:42:29 |", - "| 2020-09-08 11:42:29 |", - "+-------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); - - // Original column is seconds, convert to micros and check timestamp - let sql = "SELECT to_timestamp_micros(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------+", - "| totimestampmicros(ts_secs.ts) |", - "+-------------------------------+", - "| 2020-09-08 13:42:29 |", - "| 2020-09-08 12:42:29 |", - "| 2020-09-08 11:42:29 |", - "+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - // to nanos - let sql = "SELECT to_timestamp(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------+", - "| totimestamp(ts_secs.ts) |", - "+-------------------------+", - "| 2020-09-08 13:42:29 |", - "| 2020-09-08 12:42:29 |", - "| 2020-09-08 11:42:29 |", - "+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_cast_timestamp_micros_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_micros", make_timestamp_table(TimeUnit::Microsecond)?)?; - - // Original column is micros, convert to millis and check timestamp - let sql = "SELECT to_timestamp_millis(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------------------+", - "| totimestampmillis(ts_micros.ts) |", - "+---------------------------------+", - "| 2020-09-08 13:42:29.190 |", - "| 2020-09-08 12:42:29.190 |", - "| 2020-09-08 11:42:29.190 |", - "+---------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - // Original column is micros, convert to seconds and check timestamp - let sql = "SELECT to_timestamp_seconds(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------------------------+", - "| totimestampseconds(ts_micros.ts) |", - "+----------------------------------+", - "| 2020-09-08 13:42:29 |", - "| 2020-09-08 12:42:29 |", - "| 2020-09-08 11:42:29 |", - "+----------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - // Original column is micros, convert to nanos and check timestamp - let sql = "SELECT to_timestamp(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------------------+", - "| totimestamp(ts_micros.ts) |", - "+----------------------------+", - "| 2020-09-08 13:42:29.190855 |", - "| 2020-09-08 12:42:29.190855 |", - "| 2020-09-08 11:42:29.190855 |", - "+----------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT 1 as x UNION ALL SELECT 2 as x"; - 
let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "| 2 |", "+---+"]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = - "SELECT c1 FROM aggregate_test_100 UNION ALL SELECT c1 FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; - assert_eq!(actual.len(), 200); - Ok(()) -} - -#[tokio::test] -async fn csv_query_limit() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+----+", "| c1 |", "+----+", "| c |", "| d |", "+----+"]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", - "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", - "| 2 |", "| 1 |", "| 1 |", "| 2 |", "| 1 |", "| 3 |", "| 2 |", "| 4 |", - "| 1 |", "| 5 |", "| 4 |", "| 2 |", "| 1 |", "| 4 |", "| 5 |", "| 2 |", - "| 3 |", "| 4 |", "| 2 |", "| 1 |", "| 5 |", "| 3 |", "| 1 |", "| 2 |", - "| 3 |", "| 3 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", "| 2 |", - "| 5 |", "| 2 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 2 |", "| 5 |", - "| 4 |", "| 2 |", "| 3 |", "| 4 |", "| 4 |", "| 4 |", "| 5 |", "| 4 |", - "| 2 |", "| 1 |", "| 2 |", "| 4 |", "| 2 |", "| 3 |", "| 5 |", "| 1 |", - "| 1 |", "| 4 |", "| 2 |", "| 1 |", "| 2 |", "| 1 |", "| 1 |", "| 5 |", - "| 4 |", "| 5 |", "| 2 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", - "| 4 |", "| 3 |", "| 2 |", "| 5 |", "| 3 |", "| 3 |", "| 2 |", "| 5 |", - "| 5 |", "| 4 |", "| 1 |", "| 3 |", "| 3 |", "| 4 |", "| 4 |", "+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", - "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", - "| 2 |", "| 1 |", "| 1 |", "| 2 |", "| 1 |", "| 3 |", "| 2 |", "| 4 |", - "| 1 |", "| 5 |", "| 4 |", "| 2 |", "| 1 |", "| 4 |", "| 5 |", "| 2 |", - "| 3 |", "| 4 |", "| 2 |", "| 1 |", "| 5 |", "| 3 |", "| 1 |", "| 2 |", - "| 3 |", "| 3 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", "| 2 |", - "| 5 |", "| 2 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 2 |", "| 5 |", - "| 4 |", "| 2 |", "| 3 |", "| 4 |", "| 4 |", "| 4 |", "| 5 |", "| 4 |", - "| 2 |", "| 1 |", "| 2 |", "| 4 |", "| 2 |", "| 3 |", "| 5 |", "| 1 |", - "| 1 |", "| 4 |", "| 2 |", "| 1 |", "| 2 |", "| 1 |", "| 1 |", "| 5 |", - "| 4 |", "| 5 |", "| 2 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", - "| 4 |", "| 3 |", "| 2 |", "| 5 |", "| 3 |", "| 3 |", "| 2 |", "| 5 |", - "| 5 |", "| 4 |", "| 1 |", "| 3 |", "| 3 |", "| 4 |", "| 4 |", "+----+", - ]; - assert_batches_eq!(expected, 
&actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_limit_zero() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["++", "++"]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_create_external_table() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", - "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | Int64(10) | c11 | c12 | c13 |", - "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", - "| c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 10 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW |", - "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_external_table_count() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT COUNT(c12) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------+", - "| COUNT(aggregate_test_100.c12) |", - "+-------------------------------+", - "| 100 |", - "+-------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_external_table_sum() { - let mut ctx = ExecutionContext::new(); - // cast smallint and int to bigint to avoid overflow during calculation - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = - "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------------------+-------------------------------------------+", - "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |", - "+-------------------------------------------+-------------------------------------------+", - "| 13060 | 3017641 |", - "+-------------------------------------------+-------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_count_star() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT COUNT(*) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 100 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn csv_query_count_one() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT COUNT(1) FROM aggregate_test_100"; - let actual = 
execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 100 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn case_when() -> Result<()> { - let mut ctx = create_case_context()?; - let sql = "SELECT \ - CASE WHEN c1 = 'a' THEN 1 \ - WHEN c1 = 'b' THEN 2 \ - END \ - FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------------------------------------------------------------------------------+", - "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) END |", - "+--------------------------------------------------------------------------------------+", - "| 1 |", - "| 2 |", - "| |", - "| |", - "+--------------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn case_when_else() -> Result<()> { - let mut ctx = create_case_context()?; - let sql = "SELECT \ - CASE WHEN c1 = 'a' THEN 1 \ - WHEN c1 = 'b' THEN 2 \ - ELSE 999 END \ - FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------------------------------------------------------------------------------------------------------+", - "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", - "+------------------------------------------------------------------------------------------------------+", - "| 1 |", - "| 2 |", - "| 999 |", - "| 999 |", - "+------------------------------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn case_when_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; - let sql = "SELECT \ - CASE c1 WHEN 'a' THEN 1 \ - WHEN 'b' THEN 2 \ - END \ - FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------------------------------------------------------------+", - "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) END |", - "+---------------------------------------------------------------------------+", - "| 1 |", - "| 2 |", - "| |", - "| |", - "+---------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn case_when_else_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; - let sql = "SELECT \ - CASE c1 WHEN 'a' THEN 1 \ - WHEN 'b' THEN 2 \ - ELSE 999 END \ - FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------------------------------------------------------------------------+", - "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", - "+-------------------------------------------------------------------------------------------+", - "| 1 |", - "| 2 |", - "| 999 |", - "| 999 |", - "+-------------------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -fn create_case_context() -> Result { - let mut ctx = ExecutionContext::new(); - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); - let data = RecordBatch::try_new( - schema.clone(), - 
vec![Arc::new(Utf8Array::::from(vec![ - Some("a"), - Some("b"), - Some("c"), - None, - ]))], - )?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("t1", Arc::new(table))?; - Ok(ctx) -} - -#[tokio::test] -async fn equijoin() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - let mut ctx = create_join_context_qualified()?; - let equivalent_sql = [ - "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", - "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", - ]; - let expected = vec![ - "+---+-----+", - "| a | b |", - "+---+-----+", - "| 1 | 100 |", - "| 2 | 200 |", - "| 4 | 400 |", - "+---+-----+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_multiple_condition_ordering() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_and_other_condition() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn equijoin_left_and_condition_from_right() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let res = ctx.create_logical_plan(sql); - assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | |", - 
"+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn equijoin_right_and_condition_from_left() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 ORDER BY t2_name"; - let res = ctx.create_logical_plan(sql); - assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| | | w |", - "| 44 | d | x |", - "| 22 | b | y |", - "| | | z |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn equijoin_and_unsupported_condition() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= '44' ORDER BY t1_id"; - let res = ctx.create_logical_plan(sql); - - assert!(res.is_err()); - assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t1_id >= Utf8(\"44\")]"); - - Ok(()) -} - -#[tokio::test] -async fn left_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn left_join_unbalanced() -> Result<()> { - // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "| 77 | e | |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn right_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "| | | w |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn 
full_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "| | | w |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn left_join_using() -> Result<()> { - let mut ctx = create_join_context("id", "id")?; - let sql = "SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+---------+---------+", - "| id | t1_name | t2_name |", - "+----+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "+----+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn equijoin_implicit_syntax() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_implicit_syntax_with_filter() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = "SELECT t1_id, t1_name, t2_name \ - FROM t1, t2 \ - WHERE t1_id > 0 \ - AND t1_id = t2_id \ - AND t2_id < 99 \ - ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn equijoin_implicit_syntax_reversed() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn cross_join() { - let mut ctx = 
create_join_context("t1_id", "t2_id").unwrap(); - - let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; - - assert_eq!(4 * 4, actual.len()); - - let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; - - assert_eq!(4 * 4, actual.len()); - - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2"; - - let actual = execute(&mut ctx, sql).await; - assert_eq!(4 * 4, actual.len()); - - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 11 | a | y |", - "| 11 | a | x |", - "| 11 | a | w |", - "| 22 | b | z |", - "| 22 | b | y |", - "| 22 | b | x |", - "| 22 | b | w |", - "| 33 | c | z |", - "| 33 | c | y |", - "| 33 | c | x |", - "| 33 | c | w |", - "| 44 | d | z |", - "| 44 | d | y |", - "| 44 | d | x |", - "| 44 | d | w |", - "+-------+---------+---------+", - ]; - - assert_batches_eq!(expected, &actual); - - // Two partitions (from UNION) on the left - let sql = "SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2"; - let actual = execute(&mut ctx, sql).await; - - assert_eq!(4 * 4 * 2, actual.len()); - - // Two partitions (from UNION) on the right - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2"; - let actual = execute(&mut ctx, sql).await; - - assert_eq!(4 * 4 * 2, actual.len()); -} - -#[tokio::test] -async fn cross_join_unbalanced() { - // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); - - // the order of the values is not determinisitic, so we need to sort to check the values - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 11 | a | y |", - "| 11 | a | x |", - "| 11 | a | w |", - "| 22 | b | z |", - "| 22 | b | y |", - "| 22 | b | x |", - "| 22 | b | w |", - "| 33 | c | z |", - "| 33 | c | y |", - "| 33 | c | x |", - "| 33 | c | w |", - "| 44 | d | z |", - "| 44 | d | y |", - "| 44 | d | x |", - "| 44 | d | w |", - "| 77 | e | z |", - "| 77 | e | y |", - "| 77 | e | x |", - "| 77 | e | w |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn test_join_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - // register time table - let timestamp_schema = Arc::new(Schema::new(vec![Field::new( - "time", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - )])); - let timestamp_data = RecordBatch::try_new( - timestamp_schema.clone(), - vec![Arc::new( - Int64Array::from_slice(&[131964190213133, 131964190213134, 131964190213135]) - .to(DataType::Timestamp(TimeUnit::Nanosecond, None)), - )], - )?; - let timestamp_table = - MemTable::try_new(timestamp_schema, vec![vec![timestamp_data]])?; - ctx.register_table("timestamp", Arc::new(timestamp_table))?; - - let sql = "SELECT * \ - FROM timestamp as a \ - JOIN (SELECT * FROM timestamp) as b \ - ON a.time = b.time \ - ORDER BY a.time"; - let actual = execute_to_batches(&mut ctx, 
sql).await; - - let expected = vec![ - "+-------------------------------+-------------------------------+", - "| time | time |", - "+-------------------------------+-------------------------------+", - "| 1970-01-02 12:39:24.190213133 | 1970-01-02 12:39:24.190213133 |", - "| 1970-01-02 12:39:24.190213134 | 1970-01-02 12:39:24.190213134 |", - "| 1970-01-02 12:39:24.190213135 | 1970-01-02 12:39:24.190213135 |", - "+-------------------------------+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn test_join_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - // register population table - let population_schema = Arc::new(Schema::new(vec![ - Field::new("city", DataType::Utf8, true), - Field::new("population", DataType::Float32, true), - ])); - let population_data = RecordBatch::try_new( - population_schema.clone(), - vec![ - Arc::new(Utf8Array::::from(vec![ - Some("a"), - Some("b"), - Some("c"), - ])), - Arc::new(Float32Array::from_slice(vec![838.698, 1778.934, 626.443])), - ], - )?; - let population_table = - MemTable::try_new(population_schema, vec![vec![population_data]])?; - ctx.register_table("population", Arc::new(population_table))?; - - let sql = "SELECT * \ - FROM population as a \ - JOIN (SELECT * FROM population) as b \ - ON a.population = b.population \ - ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+------+------------+------+------------+", - "| city | population | city | population |", - "+------+------------+------+------------+", - "| c | 626.443 | c | 626.443 |", - "| a | 838.698 | a | 838.698 |", - "| b | 1778.934 | b | 1778.934 |", - "+------+------------+------+------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn test_join_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - // register population table - let population_schema = Arc::new(Schema::new(vec![ - Field::new("city", DataType::Utf8, true), - Field::new("population", DataType::Float64, true), - ])); - let population_data = RecordBatch::try_new( - population_schema.clone(), - vec![ - Arc::new(Utf8Array::::from(vec![ - Some("a"), - Some("b"), - Some("c"), - ])), - Arc::new(Float64Array::from_slice(vec![838.698, 1778.934, 626.443])), - ], - )?; - let population_table = - MemTable::try_new(population_schema, vec![vec![population_data]])?; - ctx.register_table("population", Arc::new(population_table))?; - - let sql = "SELECT * \ - FROM population as a \ - JOIN (SELECT * FROM population) as b \ - ON a.population = b.population \ - ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+------+------------+------+------------+", - "| city | population | city | population |", - "+------+------------+------+------------+", - "| c | 626.443 | c | 626.443 |", - "| a | 838.698 | a | 838.698 |", - "| b | 1778.934 | b | 1778.934 |", - "+------+------------+------+------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -fn create_join_context( - column_left: &str, - column_right: &str, -) -> Result { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![ - Arc::new(UInt32Array::from_slice(&[11, 22, 33, 44])), - 
Arc::new(Utf8Array::::from(&[ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - ], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema.clone(), - vec![ - Arc::new(UInt32Array::from_slice(&[11, 22, 44, 55])), - Arc::new(Utf8Array::::from(&[ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - ], - )?; - let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(t2_table))?; - - Ok(ctx) -} - -fn create_join_context_qualified() -> Result { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, true), - Field::new("b", DataType::UInt32, true), - Field::new("c", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![ - Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4])), - Arc::new(UInt32Array::from_slice(&[10, 20, 30, 40])), - Arc::new(UInt32Array::from_slice(&[50, 60, 70, 80])), - ], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, true), - Field::new("b", DataType::UInt32, true), - Field::new("c", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema.clone(), - vec![ - Arc::new(UInt32Array::from_slice(&[1, 2, 9, 4])), - Arc::new(UInt32Array::from_slice(&[100, 200, 300, 400])), - Arc::new(UInt32Array::from_slice(&[500, 600, 700, 800])), - ], - )?; - let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(t2_table))?; - - Ok(ctx) -} - -/// the table column_left has more rows than the table column_right -fn create_join_context_unbalanced( - column_left: &str, - column_right: &str, -) -> Result { - let mut ctx = ExecutionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema.clone(), - vec![ - Arc::new(UInt32Array::from_values(vec![11, 22, 33, 44, 77])), - Arc::new(Utf8Array::::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - Some("e"), - ])), - ], - )?; - let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(t1_table))?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema.clone(), - vec![ - Arc::new(UInt32Array::from_values(vec![11, 22, 44, 55])), - Arc::new(Utf8Array::::from(vec![ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - ], - )?; - let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(t2_table))?; - - Ok(ctx) -} - -#[tokio::test] -async fn csv_explain() { - // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, - // then execute the physical plan and return the final explain results - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 
> 10"; - let actual = execute(&mut ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - - // Note can't use `assert_batches_eq` as the plan needs to be - // normalized for filenames and number of cores - let expected = vec![ - vec![ - "logical_plan", - "Projection: #aggregate_test_100.c1\ - \n Filter: #aggregate_test_100.c2 > Int64(10)\ - \n TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]" - ], - vec!["physical_plan", - "ProjectionExec: expr=[c1@0 as c1]\ - \n CoalesceBatchesExec: target_batch_size=4096\ - \n FilterExec: CAST(c2@1 AS Int64) > 10\ - \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ - \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None\ - \n" - ]]; - assert_eq!(expected, actual); - - // Also, expect same result with lowercase explain - let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - assert_eq!(expected, actual); -} - -#[tokio::test] -async fn csv_explain_analyze() { - // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "EXPLAIN ANALYZE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let formatted = print::write(&actual); - - // Only test basic plumbing and try to avoid having to change too - // many things. explain_analyze_baseline_metrics covers the values - // in greater depth - let needle = "CoalescePartitionsExec, metrics=[output_rows=5, elapsed_compute="; - assert_contains!(&formatted, needle); - - let verbose_needle = "Output Rows"; - assert_not_contains!(formatted, verbose_needle); -} - -#[tokio::test] -async fn csv_explain_analyze_verbose() { - // This test uses the execute function to run an actual plan under EXPLAIN VERBOSE ANALYZE - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = - "EXPLAIN ANALYZE VERBOSE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let formatted = print::write(&actual); - - let verbose_needle = "Output Rows"; - assert_contains!(formatted, verbose_needle); -} - -/// A macro to assert that some particular line contains two substrings -/// -/// Usage: `assert_metrics!(actual, operator_name, metrics)` -/// -macro_rules! 
assert_metrics { - ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => { - let found = $ACTUAL - .lines() - .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS)); - assert!( - found, - "Can not find a line with both '{}' and '{}' in\n\n{}", - $OPERATOR_NAME, $METRICS, $ACTUAL - ); - }; -} - -#[tokio::test] -async fn explain_analyze_baseline_metrics() { - // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE - // and then validate the presence of baseline metrics for supported operators - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); - register_aggregate_csv_by_sql(&mut ctx).await; - // a query with as many operators as we have metrics for - let sql = "EXPLAIN ANALYZE \ - SELECT count(*) as cnt FROM \ - (SELECT count(*), c1 \ - FROM aggregate_test_100 \ - WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' \ - GROUP BY c1 \ - ORDER BY c1 ) AS a \ - UNION ALL \ - SELECT 1 as cnt \ - UNION ALL \ - SELECT lead(c1, 1) OVER () as cnt FROM (select 1 as c1) AS b \ - LIMIT 3"; - println!("running query: {}", sql); - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(physical_plan.clone()).await.unwrap(); - let formatted = print::write(&results); - println!("Query Output:\n\n{}", formatted); - - assert_metrics!( - &formatted, - "HashAggregateExec: mode=Partial, gby=[]", - "metrics=[output_rows=3, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "SortExec: [c1@0 ASC NULLS LAST]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "metrics=[output_rows=99, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "GlobalLimitExec: limit=3, ", - "metrics=[output_rows=1, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "LocalLimitExec: limit=3", - "metrics=[output_rows=3, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "ProjectionExec: expr=[COUNT(UInt8(1))", - "metrics=[output_rows=1, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "CoalesceBatchesExec: target_batch_size=4096", - "metrics=[output_rows=5, elapsed_compute" - ); - assert_metrics!( - &formatted, - "CoalescePartitionsExec", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "UnionExec", - "metrics=[output_rows=3, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "WindowAggExec", - "metrics=[output_rows=1, elapsed_compute=" - ); - - fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { - use datafusion::physical_plan; - - plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - // CoalescePartitionsExec doesn't do any work so is not included - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - } - - // Validate that the recorded elapsed compute time was more than - // zero for all operators as well as the start/end timestamp are 
set - struct TimeValidator {} - impl ExecutionPlanVisitor for TimeValidator { - type Error = std::convert::Infallible; - - fn pre_visit( - &mut self, - plan: &dyn ExecutionPlan, - ) -> std::result::Result { - if !expected_to_have_metrics(plan) { - return Ok(true); - } - let metrics = plan.metrics().unwrap().aggregate_by_partition(); - - assert!(metrics.output_rows().unwrap() > 0); - assert!(metrics.elapsed_compute().unwrap() > 0); - - let mut saw_start = false; - let mut saw_end = false; - metrics.iter().for_each(|m| match m.value() { - MetricValue::StartTimestamp(ts) => { - saw_start = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); - } - MetricValue::EndTimestamp(ts) => { - saw_end = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); - } - _ => {} - }); - - assert!(saw_start); - assert!(saw_end); - - Ok(true) - } - } - - datafusion::physical_plan::accept(physical_plan.as_ref(), &mut TimeValidator {}) - .unwrap(); -} - -#[tokio::test] -async fn csv_explain_plans() { - // This test verify the look of each plan in its full cycle plan creation - - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; - - // Logical plan - // Create plan - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let logical_schema = plan.schema(); - // - println!("SQL: {}", sql); - // - // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #aggregate_test_100.c1 [c1:Utf8]", - " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: #aggregate_test_100.c1", - " Filter: #aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100 projection=None", - ]; - let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 
7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; - let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - - // Optimized logical plan - // - let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); - let plan = ctx.optimize(&plan).expect(&msg); - let optimized_logical_schema = plan.schema(); - // Both schema has to be the same - assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); - // - // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #aggregate_test_100.c1 [c1:Utf8]", - " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", - " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: #aggregate_test_100.c1", - " Filter: #aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]", - ]; - let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > 
Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; - let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - - // Physical plan - // Create plan - let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); - let plan = ctx.create_physical_plan(&plan).await.expect(&msg); - // - // Execute plan - let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); - let actual = result_vec(&results); - // flatten to a single string - let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); - // Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content - assert_contains!(&actual, "logical_plan"); - assert_contains!(&actual, "Projection: #aggregate_test_100.c1"); - assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)"); -} - -#[tokio::test] -async fn csv_explain_verbose() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; - - // flatten to a single string - let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); - - // Don't actually test the contents of the debuging output (as - // that may change and keeping this test updated will be a - // pain). Instead just check for a few key pieces. - assert_contains!(&actual, "logical_plan"); - assert_contains!(&actual, "physical_plan"); - assert_contains!(&actual, "#aggregate_test_100.c2 > Int64(10)"); - - // ensure the "same text as above" optimization is working - assert_contains!(actual, "SAME TEXT AS ABOVE"); -} - -#[tokio::test] -async fn csv_explain_verbose_plans() { - // This test verify the look of each plan in its full cycle plan creation - - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; - - // Logical plan - // Create plan - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let logical_schema = plan.schema(); - // - println!("SQL: {}", sql); - - // - // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #aggregate_test_100.c1 [c1:Utf8]", - " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: #aggregate_test_100.c1", - " Filter: #aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100 projection=None", - ]; - let formatted = plan.display_indent().to_string(); 
- let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; - let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - - // Optimized logical plan - // - let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); - let plan = ctx.optimize(&plan).expect(&msg); - let optimized_logical_schema = plan.schema(); - // Both schema has to be the same - assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); - // - // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #aggregate_test_100.c1 [c1:Utf8]", - " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", - " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: #aggregate_test_100.c1", - " Filter: #aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]", - ]; - let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - // - // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", - "digraph {", - " subgraph cluster_1", - " {", - " 
graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; - let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - - // Physical plan - // Create plan - let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); - let plan = ctx.create_physical_plan(&plan).await.expect(&msg); - // - // Execute plan - let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); - let actual = result_vec(&results); - // flatten to a single string - let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); - // Since the plan contains path that are environmentally - // dependant(e.g. full path of the test file), only verify - // important content - assert_contains!(&actual, "logical_plan after projection_push_down"); - assert_contains!(&actual, "physical_plan"); - assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10"); - assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); -} - -#[tokio::test] -async fn explain_analyze_runs_optimizers() { - // repro for https://github.com/apache/arrow-datafusion/issues/917 - // where EXPLAIN ANALYZE was not correctly running optiimizer - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - - // This happens as an optimization pass where count(*) can be - // answered using statistics only. 
- let expected = "EmptyExec: produce_one_row=true"; - - let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let actual = print::write(&actual); - assert_contains!(actual, expected); - - // EXPLAIN ANALYZE should work the same - let sql = "EXPLAIN ANALYZE SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let actual = print::write(&actual); - assert_contains!(actual, expected); -} - -#[tokio::test] -async fn tpch_explain_q10() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - register_tpch_csv(&mut ctx, "customer").await?; - register_tpch_csv(&mut ctx, "orders").await?; - register_tpch_csv(&mut ctx, "lineitem").await?; - register_tpch_csv(&mut ctx, "nation").await?; - - let sql = "select - c_custkey, - c_name, - sum(l_extendedprice * (1 - l_discount)) as revenue, - c_acctbal, - n_name, - c_address, - c_phone, - c_comment -from - customer, - orders, - lineitem, - nation -where - c_custkey = o_custkey - and l_orderkey = o_orderkey - and o_orderdate >= date '1993-10-01' - and o_orderdate < date '1994-01-01' - and l_returnflag = 'R' - and c_nationkey = n_nationkey -group by - c_custkey, - c_name, - c_acctbal, - c_phone, - n_name, - c_address, - c_comment -order by - revenue desc;"; - - let mut plan = ctx.create_logical_plan(sql); - plan = ctx.optimize(&plan.unwrap()); - - let expected = "\ - Sort: #revenue DESC NULLS FIRST\ - \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ - \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\ - \n Join: #customer.c_nationkey = #nation.n_nationkey\ - \n Join: #orders.o_orderkey = #lineitem.l_orderkey\ - \n Join: #customer.c_custkey = #orders.o_custkey\ - \n TableScan: customer projection=Some([0, 1, 2, 3, 4, 5, 7])\ - \n Filter: #orders.o_orderdate >= Date32(\"8674\") AND #orders.o_orderdate < Date32(\"8766\")\ - \n TableScan: orders projection=Some([0, 1, 4]), filters=[#orders.o_orderdate >= Date32(\"8674\"), #orders.o_orderdate < Date32(\"8766\")]\ - \n Filter: #lineitem.l_returnflag = Utf8(\"R\")\ - \n TableScan: lineitem projection=Some([0, 5, 6, 8]), filters=[#lineitem.l_returnflag = Utf8(\"R\")]\ - \n TableScan: nation projection=Some([0, 1])"; - assert_eq!(format!("{:?}", plan.unwrap()), expected); - - Ok(()) -} - -fn get_tpch_table_schema(table: &str) -> Schema { - match table { - "customer" => Schema::new(vec![ - Field::new("c_custkey", DataType::Int64, false), - Field::new("c_name", DataType::Utf8, false), - Field::new("c_address", DataType::Utf8, false), - Field::new("c_nationkey", DataType::Int64, false), - Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), - Field::new("c_mktsegment", DataType::Utf8, false), - Field::new("c_comment", DataType::Utf8, false), - ]), - - "orders" => Schema::new(vec![ - Field::new("o_orderkey", DataType::Int64, false), - Field::new("o_custkey", DataType::Int64, false), - Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), - Field::new("o_orderdate", DataType::Date32, false), - Field::new("o_orderpriority", DataType::Utf8, false), - 
Field::new("o_clerk", DataType::Utf8, false), - Field::new("o_shippriority", DataType::Int32, false), - Field::new("o_comment", DataType::Utf8, false), - ]), - - "lineitem" => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int64, false), - Field::new("l_partkey", DataType::Int64, false), - Field::new("l_suppkey", DataType::Int64, false), - Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), - Field::new("l_extendedprice", DataType::Float64, false), - Field::new("l_discount", DataType::Float64, false), - Field::new("l_tax", DataType::Float64, false), - Field::new("l_returnflag", DataType::Utf8, false), - Field::new("l_linestatus", DataType::Utf8, false), - Field::new("l_shipdate", DataType::Date32, false), - Field::new("l_commitdate", DataType::Date32, false), - Field::new("l_receiptdate", DataType::Date32, false), - Field::new("l_shipinstruct", DataType::Utf8, false), - Field::new("l_shipmode", DataType::Utf8, false), - Field::new("l_comment", DataType::Utf8, false), - ]), - - "nation" => Schema::new(vec![ - Field::new("n_nationkey", DataType::Int64, false), - Field::new("n_name", DataType::Utf8, false), - Field::new("n_regionkey", DataType::Int64, false), - Field::new("n_comment", DataType::Utf8, false), - ]), - - _ => unimplemented!(), - } -} - -async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<()> { - let schema = get_tpch_table_schema(table); - - ctx.register_csv( - table, - format!("tests/tpch-csv/{}.csv", table).as_str(), - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - -async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { - let testdata = datafusion::test_util::arrow_test_data(); - - // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once - // unsigned is supported. 
- let df = ctx - .sql(&format!( - " - CREATE EXTERNAL TABLE aggregate_test_100 ( - c1 VARCHAR NOT NULL, - c2 INT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT NOT NULL, - c5 INT NOT NULL, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 BIGINT NOT NULL, - c10 VARCHAR NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL - ) - STORED AS CSV - WITH HEADER ROW - LOCATION '{}/csv/aggregate_test_100.csv' - ", - testdata - )) - .await - .expect("Creating dataframe for CREATE EXTERNAL TABLE"); - - // Mimic the CLI and execute the resulting plan -- even though it - // is effectively a no-op (returns zero rows) - let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE"); - assert!( - results.is_empty(), - "Expected no rows from executing CREATE EXTERNAL TABLE" - ); -} - -/// Create table "t1" with two boolean columns "a" and "b" -async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { - let a: BooleanArray = [ - Some(true), - Some(true), - Some(true), - None, - None, - None, - Some(false), - Some(false), - Some(false), - ] - .iter() - .collect(); - let b: BooleanArray = [ - Some(true), - None, - Some(false), - Some(true), - None, - Some(false), - Some(true), - None, - Some(false), - ] - .iter() - .collect(); - - let data = - RecordBatch::try_from_iter([("a", Arc::new(a) as _), ("b", Arc::new(b) as _)])?; - let table = MemTable::try_new(data.schema().clone(), vec![vec![data]])?; - ctx.register_table("t1", Arc::new(table))?; - Ok(()) -} - -async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); - let schema = test_util::aggr_test_schema(); - ctx.register_csv( - "aggregate_test_100", - &format!("{}/csv/aggregate_test_100.csv", testdata), - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - -async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) { - let df = ctx - .sql( - "CREATE EXTERNAL TABLE aggregate_simple ( - c1 DECIMAL(10,6) NOT NULL, - c2 DOUBLE NOT NULL, - c3 BOOLEAN NOT NULL - ) - STORED AS CSV - WITH HEADER ROW - LOCATION 'tests/aggregate_simple.csv'", - ) - .await - .expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type"); - - let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE"); - assert!( - results.is_empty(), - "Expected no rows from executing CREATE EXTERNAL TABLE" - ); -} - -async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> { - // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Float32, false), - Field::new("c2", DataType::Float64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - ctx.register_csv( - "aggregate_simple", - "tests/aggregate_simple.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - -async fn register_alltypes_parquet(ctx: &mut ExecutionContext) { - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{}/alltypes_plain.parquet", testdata), - ) - .await - .unwrap(); -} - -#[cfg(feature = "avro")] -async fn register_alltypes_avro(ctx: &mut ExecutionContext) { - let testdata = datafusion::test_util::arrow_test_data(); - ctx.register_avro( - "alltypes_plain", - &format!("{}/avro/alltypes_plain.avro", testdata), - AvroReadOptions::default(), - ) - .await - .unwrap(); -} - -/// 
Execute query and return result set as 2-d table of Vecs -/// `result[row][column]` -async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec { - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let logical_schema = plan.schema(); - - let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); - let plan = ctx.optimize(&plan).expect(&msg); - let optimized_logical_schema = plan.schema(); - - let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); - let plan = ctx.create_physical_plan(&plan).await.expect(&msg); - - let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); - - assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); - results -} - -/// Execute query and return result set as 2-d table of Vecs -/// `result[row][column]` -async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec> { - result_vec(&execute_to_batches(ctx, sql).await) -} - -/// Converts the results into a 2d array of strings, `result[row][column]` -/// Special cases nulls to NULL for testing -fn result_vec(results: &[RecordBatch]) -> Vec> { - let mut result = vec![]; - for batch in results { - let display_col = batch - .columns() - .iter() - .map(|x| get_display(x.as_ref())) - .collect::>(); - for row_index in 0..batch.num_rows() { - let row_vec = display_col - .iter() - .map(|display_col| display_col(row_index)) - .collect(); - result.push(row_vec); - } - } - result -} - -async fn generic_query_length(datatype: DataType) -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", datatype, false)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Utf8Array::::from_slice(vec![ - "", "a", "aa", "aaa", - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT length(c1) FROM test"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -#[cfg_attr(not(feature = "unicode_expressions"), ignore)] -async fn query_length() -> Result<()> { - generic_query_length::(DataType::Utf8).await -} - -#[tokio::test] -#[cfg_attr(not(feature = "unicode_expressions"), ignore)] -async fn query_large_length() -> Result<()> { - generic_query_length::(DataType::LargeUtf8).await -} - -#[tokio::test] -async fn query_not() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(BooleanArray::from(vec![ - Some(false), - None, - Some(true), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT NOT c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------+", - "| NOT test.c1 |", - "+-------------+", - "| true |", - "| |", - "| false |", - "+-------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_concat() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - 
Arc::new(Utf8Array::::from_slice(vec!["", "a", "aa", "aaa"])), - Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), - ], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------------------------------------------+", - "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |", - "+----------------------------------------------------+", - "| -hi-0 |", - "| a-hi-1 |", - "| aa-hi- |", - "| aaa-hi-3 |", - "+----------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -// Revisit after implementing https://github.com/apache/arrow-rs/issues/925 -#[tokio::test] -async fn query_array() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Utf8Array::::from_slice(vec!["", "a", "aa", "aaa"])), - Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), - ], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![ - vec!["[, 0]"], - vec!["[a, 1]"], - vec!["[aa, ]"], - vec!["[aaa, 3]"], - ]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn csv_query_sum_cast() { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - // c8 = i32; c9 = i64 - let sql = "SELECT c8 + c9 FROM aggregate_test_100"; - // check that the physical and logical schemas are equal - execute(&mut ctx, sql).await; -} - -#[tokio::test] -async fn query_where_neg_num() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - - // Negative numbers do not parse correctly as of Arrow 2.0.0 - let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2 and c7 < 10"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-------+", - "| c7 | c8 |", - "+----+-------+", - "| 7 | 45465 |", - "| 5 | 40622 |", - "| 0 | 61069 |", - "| 2 | 20120 |", - "| 4 | 39363 |", - "+----+-------+", - ]; - assert_batches_eq!(expected, &actual); - - // Also check floating point neg numbers - let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2.9 and c7 < 10"; - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn like() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; - // check that the physical and logical schemas are equal - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------------------------------+", - "| COUNT(aggregate_test_100.c1) |", - "+------------------------------+", - "| 1 |", - "+------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -fn make_timestamp_table(time_unit: TimeUnit) -> Result> { - let schema = Arc::new(Schema::new(vec![ - Field::new("ts", 
DataType::Timestamp(time_unit, None), false), - Field::new("value", DataType::Int32, true), - ])); - - let divisor = match time_unit { - TimeUnit::Nanosecond => 1i64, - TimeUnit::Microsecond => 1000, - TimeUnit::Millisecond => 1_000_000, - TimeUnit::Second => 1_000_000_000, - }; - - let nanotimestamps = vec![ - 1599572549190855000, // 2020-09-08T13:42:29.190855+00:00 - 1599568949190855000, // 2020-09-08T12:42:29.190855+00:00 - 1599565349190855000, //2020-09-08T11:42:29.190855+00:00 - ]; - let values = nanotimestamps - .into_iter() - .map(|x| x / divisor) - .collect::>(); - - let array = Int64Array::from_values(values).to(DataType::Timestamp(time_unit, None)); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(array), - Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), - ], - )?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - Ok(Arc::new(table)) -} - -fn make_timestamp_nano_table() -> Result> { - make_timestamp_table(TimeUnit::Nanosecond) -} - -#[tokio::test] -async fn to_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_nano_table()?)?; - - let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 2 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn to_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_table(TimeUnit::Millisecond)?)?; - - let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_millis('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 2 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn to_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_table(TimeUnit::Microsecond)?)?; - - let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_micros('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 2 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn to_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_table(TimeUnit::Second)?)?; - - let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_seconds('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 2 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn count_distinct_timestamps() -> Result<()> { - let mut ctx = ExecutionContext::new(); - ctx.register_table("ts_data", make_timestamp_nano_table()?)?; - - let sql = "SELECT COUNT(DISTINCT(ts)) FROM ts_data"; - let actual = execute_to_batches(&mut ctx, sql).await; - - let expected = vec![ - "+----------------------------+", - "| COUNT(DISTINCT ts_data.ts) |", - 
"+----------------------------+", - "| 3 |", - "+----------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_is_null() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Float64Array::from(vec![ - Some(1.0), - None, - Some(f64::NAN), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT c1 IS NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----------------+", - "| test.c1 IS NULL |", - "+-----------------+", - "| false |", - "| true |", - "| false |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_is_not_null() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Float64Array::from(vec![ - Some(1.0), - None, - Some(f64::NAN), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT c1 IS NOT NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------+", - "| test.c1 IS NOT NULL |", - "+---------------------+", - "| true |", - "| false |", - "| true |", - "+---------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_count_distinct() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![ - Some(0), - Some(1), - None, - Some(3), - Some(3), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT COUNT(DISTINCT c1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------------------------+", - "| COUNT(DISTINCT test.c1) |", - "+-------------------------+", - "| 3 |", - "+-------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_group_on_null() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![ - Some(0), - Some(3), - None, - Some(1), - Some(3), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; - - let actual = execute_to_batches(&mut ctx, sql).await; - - // Note that the results also - // include a row for NULL (c1=NULL, count = 1) - let expected = vec![ - "+-----------------+----+", - "| COUNT(UInt8(1)) | c1 |", - "+-----------------+----+", - "| 1 | |", - "| 1 | 0 |", - "| 1 | 1 |", - "| 2 | 3 |", - "+-----------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_group_on_null_multi_col() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - 
Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - Some(0), - Some(3), - None, - None, - Some(3), - Some(0), - None, - Some(3), - ])), - Arc::new(Utf8Array::::from(vec![ - None, - None, - Some("foo"), - None, - Some("bar"), - Some("foo"), - None, - Some("bar"), - Some("foo"), - ])), - ], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; - - let actual = execute_to_batches(&mut ctx, sql).await; - - // Note that the results also include values for null - // include a row for NULL (c1=NULL, count = 1) - let expected = vec![ - "+-----------------+----+-----+", - "| COUNT(UInt8(1)) | c1 | c2 |", - "+-----------------+----+-----+", - "| 1 | | |", - "| 2 | | bar |", - "| 3 | 0 | |", - "| 3 | 3 | foo |", - "+-----------------+----+-----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - - // Also run query with group columns reversed (results should be the same) - let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_on_string_dictionary() -> Result<()> { - // Test to ensure DataFusion can operate on dictionary types - // Use StringDictionary (32 bit indexes = keys) - let original_data = vec![Some("one"), None, Some("three")]; - let mut array = MutableDictionaryArray::>::new(); - array.try_extend(original_data)?; - let array: DictionaryArray = array.into(); - - let batch = - RecordBatch::try_from_iter(vec![("d1", Arc::new(array) as ArrayRef)]).unwrap(); - - let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - - // Basic SELECT - let sql = "SELECT * FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| d1 |", - "+-------+", - "| one |", - "| |", - "| three |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - // basic filtering - let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| d1 |", - "+-------+", - "| one |", - "| three |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - // filtering with constant - let sql = "SELECT * FROM test WHERE d1 = 'three'"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| d1 |", - "+-------+", - "| three |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - // Expression evaluation - let sql = "SELECT concat(d1, '-foo') FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------------------------------+", - "| concat(test.d1,Utf8(\"-foo\")) |", - "+------------------------------+", - "| one-foo |", - "| -foo |", - "| three-foo |", - "+------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - // aggregation - let sql = "SELECT COUNT(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------+", - "| COUNT(test.d1) |", - "+----------------+", - "| 2 |", - "+----------------+", - ]; - 
assert_batches_eq!(expected, &actual); - - // aggregation min - let sql = "SELECT MIN(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------+", - "| MIN(test.d1) |", - "+--------------+", - "| one |", - "+--------------+", - ]; - assert_batches_eq!(expected, &actual); - - // aggregation max - let sql = "SELECT MAX(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------+", - "| MAX(test.d1) |", - "+--------------+", - "| three |", - "+--------------+", - ]; - assert_batches_eq!(expected, &actual); - - // grouping - let sql = "SELECT d1, COUNT(*) FROM test group by d1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+-----------------+", - "| d1 | COUNT(UInt8(1)) |", - "+-------+-----------------+", - "| one | 1 |", - "| | 1 |", - "| three | 1 |", - "+-------+-----------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - - // window functions - let sql = "SELECT d1, row_number() OVER (partition by d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+--------------+", - "| d1 | ROW_NUMBER() |", - "+-------+--------------+", - "| | 1 |", - "| one | 1 |", - "| three | 1 |", - "+-------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn query_without_from() -> Result<()> { - // Test for SELECT without FROM. - // Should evaluate expressions in project position. - let mut ctx = ExecutionContext::new(); - - let sql = "SELECT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------+", - "| Int64(1) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT 1+2, 3/4, cos(0)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------+---------------------+---------------+", - "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |", - "+---------------------+---------------------+---------------+", - "| 3 | 0 | 1 |", - "+---------------------+---------------------+---------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn query_cte() -> Result<()> { - // Test for SELECT without FROM. - // Should evaluate expressions in project position. 
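// The cases below exercise common table expressions: a single WITH binding,
// WITH combined with UNION ALL, WITH combined with a JOIN, and a backward
// reference where a later CTE selects from an earlier one in the same WITH list.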
- let mut ctx = ExecutionContext::new(); - - // simple with - let sql = "WITH t AS (SELECT 1) SELECT * FROM t"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------+", - "| Int64(1) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - assert_batches_eq!(expected, &actual); - - // with + union - let sql = - "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "| 2 |", "+---+"]; - assert_batches_eq!(expected, &actual); - - // with + join - let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+---+", "| x |", "+---+", "| 5 |", "+---+"]; - assert_batches_eq!(expected, &actual); - - // backward reference - let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+-----+", "| id1 |", "+-----+", "| 1 |", "+-----+"]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn query_cte_incorrect() -> Result<()> { - let ctx = ExecutionContext::new(); - - // self reference - let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - assert_eq!( - format!("{}", plan.unwrap_err()), - "Error during planning: Table or CTE with name \'t\' not found" - ); - - // forward referencing - let sql = "WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - assert_eq!( - format!("{}", plan.unwrap_err()), - "Error during planning: Table or CTE with name \'u\' not found" - ); - - // wrapping should hide u - let sql = "WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u"; - let plan = ctx.create_logical_plan(sql); - assert!(plan.is_err()); - assert_eq!( - format!("{}", plan.unwrap_err()), - "Error during planning: Table or CTE with name \'u\' not found" - ); - - Ok(()) -} - -#[tokio::test] -async fn query_scalar_minus_array() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![ - Some(0), - Some(1), - None, - Some(3), - ]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT 4 - c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+------------------------+", - "| Int64(4) Minus test.c1 |", - "+------------------------+", - "| 4 |", - "| 3 |", - "| |", - "| 1 |", - "+------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -fn assert_float_eq(expected: &[Vec], received: &[Vec]) -where - T: AsRef, -{ - expected - .iter() - .flatten() - .zip(received.iter().flatten()) - .for_each(|(l, r)| { - let (l, r) = ( - l.as_ref().parse::().unwrap(), - r.as_str().parse::().unwrap(), - ); - assert!((l - r).abs() <= 2.0 * f64::EPSILON); - }); -} - -#[tokio::test] -async fn csv_between_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 BETWEEN 0.995 AND 1.0"; - let actual = 
execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| c4 |", - "+-------+", - "| 10837 |", - "+-------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_between_expr_negated() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 NOT BETWEEN 0 AND 0.995"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| c4 |", - "+-------+", - "| 10837 |", - "+-------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn csv_group_by_date() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let schema = Arc::new(Schema::new(vec![ - Field::new("date", DataType::Date32, false), - Field::new("cnt", DataType::Int32, false), - ])); - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new( - Int32Array::from([ - Some(100), - Some(100), - Some(100), - Some(101), - Some(101), - Some(101), - ]) - .to(DataType::Date32), - ), - Arc::new(Int32Array::from([ - Some(1), - Some(2), - Some(3), - Some(3), - Some(3), - Some(3), - ])), - ], - )?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - - ctx.register_table("dates", Arc::new(table))?; - let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------+", - "| SUM(dates.cnt) |", - "+----------------+", - "| 6 |", - "| 9 |", - "+----------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn group_by_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let data_type = DataType::Timestamp(TimeUnit::Millisecond, None); - let schema = Arc::new(Schema::new(vec![ - Field::new("timestamp", data_type.clone(), false), - Field::new("count", DataType::Int32, false), - ])); - let base_dt = chrono::Utc.ymd(2018, 7, 1).and_hms(6, 0, 0); // 2018-Jul-01 06:00 - let hour1 = Duration::hours(1); - let timestamps = vec![ - base_dt.timestamp_millis(), - (base_dt + hour1).timestamp_millis(), - base_dt.timestamp_millis(), - base_dt.timestamp_millis(), - (base_dt + hour1).timestamp_millis(), - (base_dt + hour1).timestamp_millis(), - ]; - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int64Array::from_slice(×tamps).to(data_type)), - Arc::new(Int32Array::from_slice(&[10, 20, 30, 40, 50, 60])), - ], - )?; - let t1_table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("t1", Arc::new(t1_table)).unwrap(); - - let sql = - "SELECT timestamp, SUM(count) FROM t1 GROUP BY timestamp ORDER BY timestamp ASC"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------+---------------+", - "| timestamp | SUM(t1.count) |", - "+---------------------+---------------+", - "| 2018-07-01 06:00:00 | 80 |", - "| 2018-07-01 07:00:00 | 130 |", - "+---------------------+---------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -macro_rules! 
test_expression { - ($SQL:expr, $EXPECTED:expr) => { - let mut ctx = ExecutionContext::new(); - let sql = format!("SELECT {}", $SQL); - let actual = execute(&mut ctx, sql.as_str()).await; - assert_eq!(actual[0][0], $EXPECTED); - }; -} - -#[tokio::test] -async fn test_boolean_expressions() -> Result<()> { - test_expression!("true", "true"); - test_expression!("false", "false"); - test_expression!("false = false", "true"); - test_expression!("true = false", "false"); - Ok(()) -} - -#[tokio::test] -#[cfg_attr(not(feature = "crypto_expressions"), ignore)] -#[ignore] -/// arrow2 use ":#010b" instead of ":02x" to represent binaries. -/// use "" instead of "NULL" to represent nulls. -async fn test_crypto_expressions() -> Result<()> { - test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); - test_expression!("digest('tom','md5')", "34b7da764b21d298ef307d04d8152dc5"); - test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); - test_expression!("digest('','md5')", "d41d8cd98f00b204e9800998ecf8427e"); - test_expression!("md5(NULL)", "NULL"); - test_expression!("digest(NULL,'md5')", "NULL"); - test_expression!( - "sha224('tom')", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" - ); - test_expression!( - "digest('tom','sha224')", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" - ); - test_expression!( - "sha224('')", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" - ); - test_expression!( - "digest('','sha224')", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" - ); - test_expression!("sha224(NULL)", "NULL"); - test_expression!("digest(NULL,'sha224')", "NULL"); - test_expression!( - "sha256('tom')", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" - ); - test_expression!( - "digest('tom','sha256')", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" - ); - test_expression!( - "sha256('')", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - ); - test_expression!( - "digest('','sha256')", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - ); - test_expression!("sha256(NULL)", "NULL"); - test_expression!("digest(NULL,'sha256')", "NULL"); - test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); - test_expression!("digest('tom','sha384')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); - test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); - test_expression!("digest('','sha384')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); - test_expression!("sha384(NULL)", "NULL"); - test_expression!("digest(NULL,'sha384')", "NULL"); - test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); - test_expression!("digest('tom','sha512')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); - test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); - test_expression!("digest('','sha512')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); 
- test_expression!("sha512(NULL)", "NULL"); - test_expression!("digest(NULL,'sha512')", "NULL"); - test_expression!("digest(NULL,'blake2s')", "NULL"); - test_expression!("digest(NULL,'blake2b')", "NULL"); - test_expression!("digest('','blake2b')", "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce"); - test_expression!("digest('tom','blake2b')", "482499a18da10a18d8d35ab5eb4c635551ec5b8d3ff37c3e87a632caf6680fe31566417834b4732e26e0203d1cad4f5366cb7ab57d89694e4c1fda3e26af2c23"); - test_expression!( - "digest('','blake2s')", - "69217a3079908094e11121d042354a7c1f55b6482ca1a51e1b250dfd1ed0eef9" - ); - test_expression!( - "digest('tom','blake2s')", - "5fc3f2b3a07cade5023c3df566e4d697d3823ba1b72bfb3e84cf7e768b2e7529" - ); - test_expression!( - "digest('','blake3')", - "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262" - ); - Ok(()) -} - -#[tokio::test] -async fn test_interval_expressions() -> Result<()> { - test_expression!("interval '1'", "0d1000ms"); - test_expression!("interval '1 second'", "0d1000ms"); - test_expression!("interval '500 milliseconds'", "0d500ms"); - test_expression!("interval '5 second'", "0d5000ms"); - test_expression!("interval '0.5 minute'", "0d30000ms"); - test_expression!("interval '.5 minute'", "0d30000ms"); - test_expression!("interval '5 minute'", "0d300000ms"); - test_expression!("interval '5 minute 1 second'", "0d301000ms"); - test_expression!("interval '1 hour'", "0d3600000ms"); - test_expression!("interval '5 hour'", "0d18000000ms"); - test_expression!("interval '1 day'", "1d0ms"); - test_expression!("interval '1 day 1'", "1d1000ms"); - test_expression!("interval '0.5'", "0d500ms"); - test_expression!("interval '0.5 day 1'", "0d43201000ms"); - test_expression!("interval '0.49 day'", "0d42336000ms"); - // TODO: precision here. 
- // test_expression!( - // "interval '0.499 day'", - // "0d43113600ms" - // ); - // test_expression!( - // "interval '0.4999 day'", - // "0d43191360ms" - // ); - // test_expression!( - // "interval '0.49999 day'", - // "0d43199136ms" - // ); - // test_expression!( - // "interval '0.49999999999 day'", - // "0d43199999.999136ms" - // ); - test_expression!("interval '5 day'", "5d0ms"); - // Hour is ignored, this matches PostgreSQL - test_expression!("interval '5 day' hour", "5d0ms"); - test_expression!( - "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", - "5d14582100ms" - ); - test_expression!("interval '0.5 month'", "15d0ms"); - test_expression!("interval '0.5' month", "15d0ms"); - test_expression!("interval '1 month'", "1m"); - test_expression!("interval '1' MONTH", "1m"); - test_expression!("interval '5 month'", "5m"); - test_expression!("interval '13 month'", "13m"); - test_expression!("interval '0.5 year'", "6m"); - test_expression!("interval '1 year'", "12m"); - test_expression!("interval '2 year'", "24m"); - test_expression!("interval '2' year", "24m"); - Ok(()) -} - -#[tokio::test] -async fn test_string_expressions() -> Result<()> { - test_expression!("ascii('')", "0"); - test_expression!("ascii('x')", "120"); - test_expression!("ascii(NULL)", ""); - test_expression!("bit_length('')", "0"); - test_expression!("bit_length('chars')", "40"); - test_expression!("bit_length('josé')", "40"); - test_expression!("bit_length(NULL)", ""); - test_expression!("btrim(' xyxtrimyyx ', NULL)", ""); - test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); - test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); - test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); - test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); - test_expression!("btrim(NULL, 'xyz')", ""); - test_expression!("chr(CAST(120 AS int))", "x"); - test_expression!("chr(CAST(128175 AS int))", "💯"); - test_expression!("chr(CAST(NULL AS int))", ""); - test_expression!("concat('a','b','c')", "abc"); - test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); - test_expression!("concat(NULL)", ""); - test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); - test_expression!("concat_ws('|','a','b','c')", "a|b|c"); - test_expression!("concat_ws('|',NULL)", ""); - test_expression!("concat_ws(NULL,'a',NULL,'b','c')", ""); - test_expression!("initcap('')", ""); - test_expression!("initcap('hi THOMAS')", "Hi Thomas"); - test_expression!("initcap(NULL)", ""); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", ""); - test_expression!("ltrim(' zzzytest ', NULL)", ""); - test_expression!("ltrim(' zzzytest ')", "zzzytest "); - test_expression!("ltrim('zzzytest', 'xyz')", "test"); - test_expression!("ltrim(NULL, 'xyz')", ""); - test_expression!("octet_length('')", "0"); - test_expression!("octet_length('chars')", "5"); - test_expression!("octet_length('josé')", "5"); - test_expression!("octet_length(NULL)", ""); - test_expression!("repeat('Pg', 4)", "PgPgPgPg"); - test_expression!("repeat('Pg', CAST(NULL AS INT))", ""); - test_expression!("repeat(NULL, 4)", ""); - test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); - test_expression!("replace('abcdefabcdef', 'cd', NULL)", ""); - test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); - test_expression!("replace('abcdefabcdef', NULL, 'XX')", ""); - test_expression!("replace(NULL, 'cd', 'XX')", ""); - test_expression!("rtrim(' testxxzx 
')", " testxxzx"); - test_expression!("rtrim(' zzzytest ', NULL)", ""); - test_expression!("rtrim('testxxzx', 'xyz')", "test"); - test_expression!("rtrim(NULL, 'xyz')", ""); - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); - test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); - test_expression!("split_part(NULL, '~@~', 20)", ""); - test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", ""); - test_expression!( - "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", - "" - ); - test_expression!("starts_with('alphabet', 'alph')", "true"); - test_expression!("starts_with('alphabet', 'blph')", "false"); - test_expression!("starts_with(NULL, 'blph')", ""); - test_expression!("starts_with('alphabet', NULL)", ""); - test_expression!("to_hex(2147483647)", "7fffffff"); - test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); - test_expression!("to_hex(CAST(NULL AS int))", ""); - test_expression!("trim(' tom ')", "tom"); - test_expression!("trim(LEADING ' ' FROM ' tom ')", "tom "); - test_expression!("trim(TRAILING ' ' FROM ' tom ')", " tom"); - test_expression!("trim(BOTH ' ' FROM ' tom ')", "tom"); - test_expression!("trim(LEADING 'x' FROM 'xxxtomxxx')", "tomxxx"); - test_expression!("trim(TRAILING 'x' FROM 'xxxtomxxx')", "xxxtom"); - test_expression!("trim(BOTH 'x' FROM 'xxxtomxx')", "tom"); - test_expression!("trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdefxyx"); - test_expression!("trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx')", "xyxabcxyzdef"); - test_expression!("trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); - test_expression!("trim(' tom')", "tom"); - test_expression!("trim('')", ""); - test_expression!("trim('tom ')", "tom"); - test_expression!("upper('')", ""); - test_expression!("upper('tom')", "TOM"); - test_expression!("upper(NULL)", ""); - Ok(()) -} - -#[tokio::test] -#[cfg_attr(not(feature = "unicode_expressions"), ignore)] -async fn test_unicode_expressions() -> Result<()> { - test_expression!("char_length('')", "0"); - test_expression!("char_length('chars')", "5"); - test_expression!("char_length('josé')", "4"); - test_expression!("char_length(NULL)", ""); - test_expression!("character_length('')", "0"); - test_expression!("character_length('chars')", "5"); - test_expression!("character_length('josé')", "4"); - test_expression!("character_length(NULL)", ""); - test_expression!("left('abcde', -2)", "abc"); - test_expression!("left('abcde', -200)", ""); - test_expression!("left('abcde', 0)", ""); - test_expression!("left('abcde', 2)", "ab"); - test_expression!("left('abcde', 200)", "abcde"); - test_expression!("left('abcde', CAST(NULL AS INT))", ""); - test_expression!("left(NULL, 2)", ""); - test_expression!("left(NULL, CAST(NULL AS INT))", ""); - test_expression!("length('')", "0"); - test_expression!("length('chars')", "5"); - test_expression!("length('josé')", "4"); - test_expression!("length(NULL)", ""); - test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); - test_expression!("lpad('hi', 0)", ""); - test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); - test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); - test_expression!("lpad('hi', 5, NULL)", ""); - test_expression!("lpad('hi', 5)", " hi"); - test_expression!("lpad('hi', CAST(NULL AS INT), 'xy')", ""); - test_expression!("lpad('hi', CAST(NULL AS INT))", ""); - test_expression!("lpad('xyxhi', 3)", "xyx"); - test_expression!("lpad(NULL, 0)", ""); - test_expression!("lpad(NULL, 5, 'xy')", ""); - test_expression!("reverse('abcde')", "edcba"); - 
test_expression!("reverse('loẅks')", "skẅol"); - test_expression!("reverse(NULL)", ""); - test_expression!("right('abcde', -2)", "cde"); - test_expression!("right('abcde', -200)", ""); - test_expression!("right('abcde', 0)", ""); - test_expression!("right('abcde', 2)", "de"); - test_expression!("right('abcde', 200)", "abcde"); - test_expression!("right('abcde', CAST(NULL AS INT))", ""); - test_expression!("right(NULL, 2)", ""); - test_expression!("right(NULL, CAST(NULL AS INT))", ""); - test_expression!("rpad('hi', 5, 'xy')", "hixyx"); - test_expression!("rpad('hi', 0)", ""); - test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); - test_expression!("rpad('hi', 5, 'xy')", "hixyx"); - test_expression!("rpad('hi', 5, NULL)", ""); - test_expression!("rpad('hi', 5)", "hi "); - test_expression!("rpad('hi', CAST(NULL AS INT), 'xy')", ""); - test_expression!("rpad('hi', CAST(NULL AS INT))", ""); - test_expression!("rpad('xyxhi', 3)", "xyx"); - test_expression!("strpos('abc', 'c')", "3"); - test_expression!("strpos('josé', 'é')", "4"); - test_expression!("strpos('joséésoj', 'so')", "6"); - test_expression!("strpos('joséésoj', 'abc')", "0"); - test_expression!("strpos(NULL, 'abc')", ""); - test_expression!("strpos('joséésoj', NULL)", ""); - test_expression!("substr('alphabet', -3)", "alphabet"); - test_expression!("substr('alphabet', 0)", "alphabet"); - test_expression!("substr('alphabet', 1)", "alphabet"); - test_expression!("substr('alphabet', 2)", "lphabet"); - test_expression!("substr('alphabet', 3)", "phabet"); - test_expression!("substr('alphabet', 30)", ""); - test_expression!("substr('alphabet', CAST(NULL AS int))", ""); - test_expression!("substr('alphabet', 3, 2)", "ph"); - test_expression!("substr('alphabet', 3, 20)", "phabet"); - test_expression!("substr('alphabet', CAST(NULL AS int), 20)", ""); - test_expression!("substr('alphabet', 3, CAST(NULL AS int))", ""); - test_expression!("translate('12345', '143', 'ax')", "a2x5"); - test_expression!("translate(NULL, '143', 'ax')", ""); - test_expression!("translate('12345', NULL, 'ax')", ""); - test_expression!("translate('12345', '143', NULL)", ""); - Ok(()) -} - -#[tokio::test] -#[cfg_attr(not(feature = "regex_expressions"), ignore)] -async fn test_regex_expressions() -> Result<()> { - test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'gi')", "XXX"); - test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'i')", "XabcABC"); - test_expression!("regexp_replace('foobarbaz', 'b..', 'X', 'g')", "fooXX"); - test_expression!("regexp_replace('foobarbaz', 'b..', 'X')", "fooXbaz"); - test_expression!( - "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g')", - "fooXarYXazY" - ); - test_expression!("regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL)", ""); - test_expression!("regexp_replace('foobarbaz', 'b(..)', NULL, 'g')", ""); - test_expression!("regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g')", ""); - test_expression!("regexp_replace('Thomas', '.[mN]a.', 'M')", "ThM"); - test_expression!("regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g')", ""); - test_expression!("regexp_match('foobarbequebaz', '')", "[]"); - test_expression!( - "regexp_match('foobarbequebaz', '(bar)(beque)')", - "[bar, beque]" - ); - test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", ""); - test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); - test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]"); - test_expression!("regexp_match('aa', '.*-(\\d)')", ""); - test_expression!("regexp_match(NULL, '.*-(\\d)')", ""); - 
test_expression!("regexp_match('aaa-0', NULL)", ""); - Ok(()) -} - -#[tokio::test] -async fn test_extract_date_part() -> Result<()> { - test_expression!("date_part('hour', CAST('2020-01-01' AS DATE))", "0"); - test_expression!("EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE))", "0"); - test_expression!( - "EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "12" - ); - test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000"); - test_expression!( - "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", - "2020" - ); - Ok(()) -} - -#[tokio::test] -async fn test_in_list_scalar() -> Result<()> { - test_expression!("'a' IN ('a','b')", "true"); - test_expression!("'c' IN ('a','b')", "false"); - test_expression!("'c' NOT IN ('a','b')", "true"); - test_expression!("'a' NOT IN ('a','b')", "false"); - test_expression!("NULL IN ('a','b')", ""); - test_expression!("NULL NOT IN ('a','b')", ""); - test_expression!("'a' IN ('a','b',NULL)", "true"); - test_expression!("'c' IN ('a','b',NULL)", ""); - test_expression!("'a' NOT IN ('a','b',NULL)", "false"); - test_expression!("'c' NOT IN ('a','b',NULL)", ""); - test_expression!("0 IN (0,1,2)", "true"); - test_expression!("3 IN (0,1,2)", "false"); - test_expression!("3 NOT IN (0,1,2)", "true"); - test_expression!("0 NOT IN (0,1,2)", "false"); - test_expression!("NULL IN (0,1,2)", ""); - test_expression!("NULL NOT IN (0,1,2)", ""); - test_expression!("0 IN (0,1,2,NULL)", "true"); - test_expression!("3 IN (0,1,2,NULL)", ""); - test_expression!("0 NOT IN (0,1,2,NULL)", "false"); - test_expression!("3 NOT IN (0,1,2,NULL)", ""); - test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); - test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); - test_expression!("NULL IN (0.0,0.1,0.2)", ""); - test_expression!("NULL NOT IN (0.0,0.1,0.2)", ""); - test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); - test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", ""); - test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); - test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", ""); - test_expression!("'1' IN ('a','b',1)", "true"); - test_expression!("'2' IN ('a','b',1)", "false"); - test_expression!("'2' NOT IN ('a','b',1)", "true"); - test_expression!("'1' NOT IN ('a','b',1)", "false"); - test_expression!("NULL IN ('a','b',1)", ""); - test_expression!("NULL NOT IN ('a','b',1)", ""); - test_expression!("'1' IN ('a','b',NULL,1)", "true"); - test_expression!("'2' IN ('a','b',NULL,1)", ""); - test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); - test_expression!("'2' NOT IN ('a','b',NULL,1)", ""); - Ok(()) -} - -#[tokio::test] -async fn in_list_array() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; - let sql = "SELECT - c1 IN ('a', 'c') AS utf8_in_true - ,c1 IN ('x', 'y') AS utf8_in_false - ,c1 NOT IN ('x', 'y') AS utf8_not_in_true - ,c1 NOT IN ('a', 'c') AS utf8_not_in_false - ,NULL IN ('a', 'c') AS utf8_in_null - FROM aggregate_test_100 WHERE c12 < 0.05"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------+---------------+------------------+-------------------+--------------+", - "| utf8_in_true | utf8_in_false | utf8_not_in_true | utf8_not_in_false | utf8_in_null |", - "+--------------+---------------+------------------+-------------------+--------------+", - "| true | false | true | false | |", - "| true | false | true | false | |", 
- "| true | false | true | false | |", - "| false | false | true | true | |", - "| false | false | true | true | |", - "| false | false | true | true | |", - "| false | false | true | true | |", - "+--------------+---------------+------------------+-------------------+--------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -// TODO Tests to prove correct implementation of INNER JOIN's with qualified names. -// https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. -#[tokio::test] -#[ignore] -async fn inner_join_qualified_names() -> Result<()> { - // Setup the statements that test qualified names function correctly. - let equivalent_sql = [ - "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c - FROM t1 - INNER JOIN t2 ON t1.a = t2.a - ORDER BY t1.a", - "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c - FROM t1 - INNER JOIN t2 ON t2.a = t1.a - ORDER BY t1.a", - ]; - - let expected = vec![ - "+---+----+----+---+-----+-----+", - "| a | b | c | a | b | c |", - "+---+----+----+---+-----+-----+", - "| 1 | 10 | 50 | 1 | 100 | 500 |", - "| 2 | 20 | 60 | 2 | 200 | 600 |", - "| 4 | 40 | 80 | 4 | 400 | 800 |", - "+---+----+----+---+-----+-----+", - ]; - - for sql in equivalent_sql.iter() { - let mut ctx = create_join_context_qualified()?; - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn inner_join_nulls() { - let sql = "SELECT * FROM (SELECT null AS id1) t1 - INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; - - let expected = vec!["++", "++"]; - - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; - - // left and right shouldn't match anything - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - - for table_ref in &[ - "aggregate_test_100", - "public.aggregate_test_100", - "datafusion.public.aggregate_test_100", - ] { - let sql = format!("SELECT COUNT(*) FROM {}", table_ref); - let actual = execute_to_batches(&mut ctx, &sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 100 |", - "+-----------------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn invalid_qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - - for table_ref in &[ - "nonexistentschema.aggregate_test_100", - "nonexistentcatalog.public.aggregate_test_100", - "way.too.many.namespaces.as.ident.prefixes.aggregate_test_100", - ] { - let sql = format!("SELECT COUNT(*) FROM {}", table_ref); - assert!(matches!(ctx.sql(&sql).await, Err(DataFusionError::Plan(_)))); - } - Ok(()) -} - -#[tokio::test] -async fn test_cast_expressions() -> Result<()> { - test_expression!("CAST('0' AS INT)", "0"); - test_expression!("CAST(NULL AS INT)", ""); - test_expression!("TRY_CAST('0' AS INT)", "0"); - test_expression!("TRY_CAST('x' AS INT)", ""); - Ok(()) -} - -#[tokio::test] -async fn test_current_timestamp_expressions() -> Result<()> { - let t1 = chrono::Utc::now().timestamp(); - let mut ctx = ExecutionContext::new(); - let actual = execute(&mut ctx, "SELECT NOW(), NOW() as t2").await; - let res1 = actual[0][0].as_str(); - let res2 = actual[0][1].as_str(); - let t3 = chrono::Utc::now().timestamp(); - let t2_naive = - chrono::NaiveDateTime::parse_from_str(res1, 
"%Y-%m-%d %H:%M:%S%.6f").unwrap(); - - let t2 = t2_naive.timestamp(); - assert!(t1 <= t2 && t2 <= t3); - assert_eq!(res2, res1); - - Ok(()) -} - -#[tokio::test] -async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { - let t1 = chrono::Utc::now().timestamp(); - let ctx = ExecutionContext::new(); - let sql = "SELECT NOW(), NOW() as t2"; - - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - - let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); - let plan = ctx.create_physical_plan(&plan).await.expect(&msg); - - let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let res = collect(plan).await.expect(&msg); - let actual = result_vec(&res); - - let res1 = actual[0][0].as_str(); - let res2 = actual[0][1].as_str(); - let t3 = chrono::Utc::now().timestamp(); - let t2_naive = - chrono::NaiveDateTime::parse_from_str(res1, "%Y-%m-%d %H:%M:%S%.6f").unwrap(); - - let t2 = t2_naive.timestamp(); - assert!(t1 <= t2 && t2 <= t3); - assert_eq!(res2, res1); - - Ok(()) -} - -#[tokio::test] -async fn test_random_expression() -> Result<()> { - let mut ctx = create_ctx()?; - let sql = "SELECT random() r1"; - let actual = execute(&mut ctx, sql).await; - let r1 = actual[0][0].parse::().unwrap(); - assert!(0.0 <= r1); - assert!(r1 < 1.0); - Ok(()) -} - -#[tokio::test] -async fn test_cast_expressions_error() -> Result<()> { - // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let result = collect(plan).await; - - match result { - Ok(_) => panic!("expected cast error"), - Err(e) => { - assert_contains!( - e.to_string(), - "Execution error: Could not cast Utf8[c, d, b, a, b, b, e, a, d, a, d, a, e, d, b, c, e, d, d, e, e, d, a, e, c, a, c, a, a, b, e, c, e, b, a, c, d, c, c, c, b, d, d, a, e, b, b, c, a, d, b, c, d, d, b, d, e, b, a, b, c, b, c, e, e, d, e, c, d, e, e, a, a, e, a, b, e, c, e, c, a, c, b, a, a, c, a, c, c, c, b, a, a, b, d, e, e, d, b, e] to value of type Int32" - ); - } - } - - Ok(()) -} - -#[tokio::test] -async fn test_physical_plan_display_indent() { - // Hard code target_partitions as it appears in the RepartitionExec output - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); - register_aggregate_csv(&mut ctx).await.unwrap(); - let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ - FROM aggregate_test_100 \ - WHERE c12 < 10 \ - GROUP BY c1 \ - ORDER BY the_min DESC \ - LIMIT 10"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - - let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let expected = vec![ - "GlobalLimitExec: limit=10", - " SortExec: [the_min@2 DESC]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", - " HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)", - " HashAggregateExec: mode=Partial, gby=[c1@0 as c1], 
aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < CAST(10 AS Float64)", - " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", - ]; - - let data_path = datafusion::test_util::arrow_test_data(); - let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) - .trim() - .lines() - // normalize paths - .map(|s| s.replace(&data_path, "ARROW_TEST_DATA")) - .collect::>(); - - assert_eq!( - expected, actual, - "expected:\n{:#?}\nactual:\n\n{:#?}\n", - expected, actual - ); -} - -#[tokio::test] -async fn test_physical_plan_display_indent_multi_children() { - // Hard code target_partitions as it appears in the RepartitionExec output - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); - // ensure indenting works for nodes with multiple children - register_aggregate_csv(&mut ctx).await.unwrap(); - let sql = "SELECT c1 \ - FROM (select c1 from aggregate_test_100) AS a \ - JOIN\ - (select c1 as c2 from aggregate_test_100) AS b \ - ON c1=c2\ - "; - - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - - let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let expected = vec![ - "ProjectionExec: expr=[c1@0 as c1]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)", - " ProjectionExec: expr=[c1@0 as c1]", - " ProjectionExec: expr=[c1@0 as c1]", - " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 3)", - " ProjectionExec: expr=[c2@0 as c2]", - " ProjectionExec: expr=[c1@0 as c2]", - " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", - ]; - - let data_path = datafusion::test_util::arrow_test_data(); - let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) - .trim() - .lines() - // normalize paths - .map(|s| s.replace(&data_path, "ARROW_TEST_DATA")) - .collect::>(); - - assert_eq!( - expected, actual, - "expected:\n{:#?}\nactual:\n\n{:#?}\n", - expected, actual - ); -} - -#[tokio::test] -async fn test_aggregation_with_bad_arguments() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; - let logical_plan = ctx.create_logical_plan(sql); - let err = logical_plan.unwrap_err(); - assert_eq!( - err.to_string(), - DataFusionError::Plan( - "The function Count expects 1 arguments, but 0 were provided".to_string() - ) - .to_string() - ); - Ok(()) -} - -// Normalizes parts of an explain plan that vary from run to run (such as path) -fn normalize_for_explain(s: &str) -> String { - // Convert things like /Users/alamb/Software/arrow/testing/data/csv/aggregate_test_100.csv - // to ARROW_TEST_DATA/csv/aggregate_test_100.csv - let data_path = 
datafusion::test_util::arrow_test_data();
-    let s = s.replace(&data_path, "ARROW_TEST_DATA");
-
-    // convert things like partitioning=RoundRobinBatch(16)
-    // to partitioning=RoundRobinBatch(NUM_CORES)
-    let needle = format!("RoundRobinBatch({})", num_cpus::get());
-    s.replace(&needle, "RoundRobinBatch(NUM_CORES)")
-}
-
-/// Applies normalize_for_explain to every line
-fn normalize_vec_for_explain(v: Vec<Vec<String>>) -> Vec<Vec<String>> {
-    v.into_iter()
-        .map(|l| {
-            l.into_iter()
-                .map(|s| normalize_for_explain(&s))
-                .collect::<Vec<_>>()
-        })
-        .collect::<Vec<_>>()
-}
-
-#[tokio::test]
-async fn test_partial_qualified_name() -> Result<()> {
-    let mut ctx = create_join_context("t1_id", "t2_id")?;
-    let sql = "SELECT t1.t1_id, t1_name FROM public.t1";
-    let expected = vec![
-        "+-------+---------+",
-        "| t1_id | t1_name |",
-        "+-------+---------+",
-        "| 11    | a       |",
-        "| 22    | b       |",
-        "| 33    | c       |",
-        "| 44    | d       |",
-        "+-------+---------+",
-    ];
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    assert_batches_eq!(expected, &actual);
-    Ok(())
-}
-
-#[tokio::test]
-async fn like_on_strings() -> Result<()> {
-    let input =
-        Utf8Array::<i32>::from(vec![Some("foo"), Some("bar"), None, Some("fazzz")]);
-
-    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
-
-    let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?;
-    let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table))?;
-
-    let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'";
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    let expected = vec![
-        "+-------+",
-        "| c1    |",
-        "+-------+",
-        "| bar   |",
-        "| fazzz |",
-        "+-------+",
-    ];
-
-    assert_batches_eq!(expected, &actual);
-    Ok(())
-}
-
-#[tokio::test]
-async fn like_on_string_dictionaries() -> Result<()> {
-    let original_data = vec![Some("foo"), Some("bar"), None, Some("fazzz")];
-    let mut input = MutableDictionaryArray::<i32, MutableUtf8Array<i32>>::new();
-    input.try_extend(original_data)?;
-    let input: DictionaryArray<i32> = input.into();
-
-    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
-
-    let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?;
-    let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table))?;
-
-    let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'";
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    let expected = vec![
-        "+-------+",
-        "| c1    |",
-        "+-------+",
-        "| bar   |",
-        "| fazzz |",
-        "+-------+",
-    ];
-
-    assert_batches_eq!(expected, &actual);
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_regexp_is_match() -> Result<()> {
-    let input = Utf8Array::<i32>::from(vec![
-        Some("foo"),
-        Some("Barrr"),
-        Some("Bazzz"),
-        Some("ZZZZZ"),
-    ]);
-
-    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
-
-    let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?;
-    let mut ctx = ExecutionContext::new();
-    ctx.register_table("test", Arc::new(table))?;
-
-    let sql = "SELECT * FROM test WHERE c1 ~ 'z'";
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    let expected = vec![
-        "+-------+",
-        "| c1    |",
-        "+-------+",
-        "| Bazzz |",
-        "+-------+",
-    ];
-    assert_batches_eq!(expected, &actual);
-
-    let sql = "SELECT * FROM test WHERE c1 ~* 'z'";
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    let expected = vec![
-        "+-------+",
-        "| c1    |",
-        "+-------+",
-        "| Bazzz |",
-        "| ZZZZZ |",
-        "+-------+",
-    ];
-    assert_batches_eq!(expected, &actual);
-
-    let sql = "SELECT * FROM test WHERE c1 !~ 'z'";
-    let actual =
execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| c1 |", - "+-------+", - "| foo |", - "| Barrr |", - "| ZZZZZ |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT * FROM test WHERE c1 !~* 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-------+", - "| c1 |", - "+-------+", - "| foo |", - "| Barrr |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Result<()> { - let batch = RecordBatch::try_from_iter(vec![ - ("id", Arc::new(Int32Array::from_slice(&[1, 2, 3])) as _), - ( - "country", - Arc::new(Utf8Array::::from_slice(&[ - "Germany", "Sweden", "Japan", - ])) as _, - ), - ]) - .unwrap(); - let countries = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?; - - let batch = RecordBatch::try_from_iter(vec![ - ( - "id", - Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5, 6, 7])) as _, - ), - ( - "city", - Arc::new(Utf8Array::::from_slice(&[ - "Hamburg", - "Stockholm", - "Osaka", - "Berlin", - "Göteborg", - "Tokyo", - "Kyoto", - ])) as _, - ), - ( - "country_id", - Arc::new(Int32Array::from_slice(&[1, 2, 3, 1, 2, 3, 3])) as _, - ), - ]) - .unwrap(); - let cities = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("countries", Arc::new(countries))?; - ctx.register_table("cities", Arc::new(cities))?; - - // city.id is not in the on constraint, but the output result will contain both city.id and - // country.id - let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+----+-----------+---------+", - "| id | id | city | country |", - "+----+----+-----------+---------+", - "| 1 | 1 | Hamburg | Germany |", - "| 2 | 2 | Stockholm | Sweden |", - "| 3 | 3 | Osaka | Japan |", - "| 4 | 1 | Berlin | Germany |", - "| 5 | 2 | Göteborg | Sweden |", - "| 6 | 3 | Tokyo | Japan |", - "| 7 | 3 | Kyoto | Japan |", - "+----+----+-----------+---------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[cfg(feature = "avro")] -#[tokio::test] -async fn avro_query() { - let mut ctx = ExecutionContext::new(); - register_alltypes_avro(&mut ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[cfg(feature = "avro")] -#[tokio::test] -async fn avro_query_multiple_files() { - let tempdir = tempfile::tempdir().unwrap(); - let table_path = tempdir.path(); - let testdata = datafusion::test_util::arrow_test_data(); - let alltypes_plain_file = format!("{}/avro/alltypes_plain.avro", testdata); - std::fs::copy( - &alltypes_plain_file, - format!("{}/alltypes_plain1.avro", table_path.display()), - ) - .unwrap(); - 
std::fs::copy( - &alltypes_plain_file, - format!("{}/alltypes_plain2.avro", table_path.display()), - ) - .unwrap(); - - let mut ctx = ExecutionContext::new(); - ctx.register_avro( - "alltypes_plain", - table_path.display().to_string().as_str(), - AvroReadOptions::default(), - ) - .await - .unwrap(); - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[cfg(feature = "avro")] -#[tokio::test] -async fn avro_single_nan_schema() { - let mut ctx = ExecutionContext::new(); - let testdata = datafusion::test_util::arrow_test_data(); - ctx.register_avro( - "single_nan", - &format!("{}/avro/single_nan.avro", testdata), - AvroReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let plan = ctx.create_logical_plan(sql).unwrap(); - let plan = ctx.optimize(&plan).unwrap(); - let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[cfg(feature = "avro")] -#[tokio::test] -async fn avro_explain() { - let mut ctx = ExecutionContext::new(); - register_alltypes_avro(&mut ctx).await; - - let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute(&mut ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - let expected = vec![ - vec![ - "logical_plan", - "Projection: #COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: alltypes_plain projection=Some([0])", - ], - vec![ - "physical_plan", - "ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\ - \n HashAggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\ - \n CoalescePartitionsExec\ - \n HashAggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\ - \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ - \n AvroExec: files=[ARROW_TEST_DATA/avro/alltypes_plain.avro], batch_size=8192, limit=None\ - \n", - ], - ]; - assert_eq!(expected, actual); -} - -#[tokio::test] -async fn union_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT 1 as x UNION SELECT 1 as x"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "+---+"]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn union_all_with_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = - "SELECT SUM(d) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d) as a"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------+", - "| SUM(a.d) |", - "+----------+", - "| 5 |", - "+----------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn case_with_bool_type_result() -> 
Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "select case when 'cpu' != 'cpu' then true else false end"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------------------------------------------------------------------------------+", - "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |", - "+---------------------------------------------------------------------------------+", - "| false |", - "+---------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn use_between_expression_in_select_query() -> Result<()> { - let mut ctx = ExecutionContext::new(); - - let sql = "SELECT 1 NOT BETWEEN 3 AND 5"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+--------------------------------------------+", - "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |", - "+--------------------------------------------+", - "| true |", - "+--------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - let input = Int64Array::from_slice(&[1, 2, 3, 4]); - let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); - let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?; - ctx.register_table("test", Arc::new(table))?; - - let sql = "SELECT abs(c1) BETWEEN 0 AND LoG(c1 * 100 ) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - // Expect field name to be correctly converted for expr, low and high. - let expected = vec![ - "+--------------------------------------------------------------------+", - "| abs(test.c1) BETWEEN Int64(0) AND log(test.c1 Multiply Int64(100)) |", - "+--------------------------------------------------------------------+", - "| true |", - "| true |", - "| false |", - "| false |", - "+--------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "EXPLAIN SELECT c1 BETWEEN 2 AND 3 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; - let formatted = print::write(&actual); - - // Only test that the projection exprs arecorrect, rather than entire output - let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]"; - assert_contains!(&formatted, needle); - let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)"; - assert_contains!(&formatted, needle); - - Ok(()) -} - -// --- End Test Porting --- - -#[tokio::test] -async fn query_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let schema = Arc::new(Schema::new(vec![Field::new( - "some_list", - DataType::List(Box::new(Field::new("item", DataType::Int64, true))), - false, - )])); - - let rows = vec![ - vec![Some(0), Some(1), Some(2)], - vec![Some(4), Some(5), Some(6)], - vec![Some(7), Some(8), Some(9)], - ]; - let mut array = - MutableListArray::>::with_capacity(rows.len()); - for int_vec in rows { - array.try_push(Some(int_vec))?; - } - - let data = RecordBatch::try_new(schema.clone(), vec![array.into_arc()])?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - let table_a = Arc::new(table); - - ctx.register_table("ints", table_a)?; - - // Original column is micros, convert to millis and check timestamp - let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - 
"+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 7 |", "+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_nested_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); - // Nested schema of { "some_list": [[i64]] } - let schema = Arc::new(Schema::new(vec![Field::new( - "some_list", - DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))), - false, - )])); - - let rows = vec![ - vec![vec![0, 1], vec![2, 3], vec![3, 4]], - vec![vec![5, 6], vec![7, 8], vec![9, 10]], - vec![vec![11, 12], vec![13, 14], vec![15, 16]], - ]; - let mut array = MutableListArray::< - i32, - MutableListArray>, - >::with_capacity(rows.len()); - for int_vec_vec in rows.into_iter() { - array.try_push(Some( - int_vec_vec - .into_iter() - .map(|v| Some(v.into_iter().map(Some))), - ))?; - } - - let data = RecordBatch::try_new(schema.clone(), vec![array.into_arc()])?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - let table_a = Arc::new(table); - - ctx.register_table("ints", table_a)?; - - // Original column is micros, convert to millis and check timestamp - let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------+", - "| i0 |", - "+----------+", - "| [0, 1] |", - "| [5, 6] |", - "| [11, 12] |", - "+----------+", - ]; - assert_batches_eq!(expected, &actual); - let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn query_nested_get_indexed_field_on_struct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); - // Nested schema of { "some_struct": { "bar": [i64] } } - let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)]; - let dt = DataType::Struct(struct_fields.clone()); - let schema = Arc::new(Schema::new(vec![Field::new( - "some_struct", - dt.clone(), - false, - )])); - - let rows = vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]]; - let mut list_array = - MutableListArray::>::with_capacity(rows.len()); - for int_vec in rows.into_iter() { - list_array.try_push(Some(int_vec.into_iter().map(Some)))?; - } - let array = StructArray::from_data(dt, vec![list_array.into_arc()], None); - - let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)])?; - let table = MemTable::try_new(schema, vec![vec![data]])?; - let table_a = Arc::new(table); - - ctx.register_table("structs", table_a)?; - - // Original column is micros, convert to millis and check timestamp - let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----------------+", - "| l0 |", - "+----------------+", - "| [0, 1, 2, 3] |", - "| [4, 5, 6, 7] |", - "| [8, 9, 10, 11] |", - "+----------------+", - ]; - assert_batches_eq!(expected, &actual); - let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] 
-async fn intersect_with_null_not_equal() { - let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 - INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; - - let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn intersect_with_null_equal() { - let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 - INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; - - let expected = vec![ - "+-----+-----+", - "| id1 | id2 |", - "+-----+-----+", - "| | 1 |", - "+-----+-----+", - ]; - - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn test_intersect_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // execute the query - let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "+---------+------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_intersect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // execute the query - let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 1 | 10.1 |", - "+---------+------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn except_with_null_not_equal() { - let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 - EXCEPT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; - - let expected = vec![ - "+-----+-----+", - "| id1 | id2 |", - "+-----+-----+", - "| | 1 |", - "+-----+-----+", - ]; - - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn except_with_null_equal() { - let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 - EXCEPT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; - - let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn test_expect_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // execute the query - let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "| 1 | 10.1 |", - "+---------+------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} 
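[Editorial aside, not part of the patch.] The set-operation tests above capture a subtlety worth calling out: INTERSECT and EXCEPT compare rows with NULL-safe equality, so two NULLs count as duplicates of one another, whereas a join predicate such as `id1 = id2` evaluates to NULL for NULL inputs and matches nothing. That is why `intersect_with_null_equal` keeps its single row while `inner_join_nulls` returns an empty result. A minimal standalone sketch of the same behaviour, using the `ExecutionContext` API that appears throughout this diff, might look like the following; the `main` wrapper, crate/feature setup, and row-count assertions are illustrative assumptions rather than code from the PR.

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = ExecutionContext::new();

    // INTERSECT uses NULL-safe row equality: the single (NULL, 1) row survives.
    let intersect = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 \
                     INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";
    let rows: usize = ctx
        .sql(intersect)
        .await?
        .collect()
        .await?
        .iter()
        .map(|batch| batch.num_rows())
        .sum();
    assert_eq!(rows, 1);

    // A join condition is ordinary SQL equality: NULL = NULL is not true, so nothing matches.
    let join = "SELECT * FROM (SELECT null AS id1) t1 \
                INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";
    let rows: usize = ctx
        .sql(join)
        .await?
        .collect()
        .await?
        .iter()
        .map(|batch| batch.num_rows())
        .sum();
    assert_eq!(rows, 0);

    Ok(())
}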
- -#[tokio::test] -async fn test_expect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // execute the query - let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 1 | 10.1 |", - "+---------+------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_sort_unprojected_col() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_alltypes_parquet(&mut ctx).await; - // execute the query - let sql = "SELECT id FROM alltypes_plain ORDER BY int_col, double_col"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+----+", "| id |", "+----+", "| 4 |", "| 6 |", "| 2 |", "| 0 |", "| 5 |", - "| 7 |", "| 3 |", "| 1 |", "+----+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----+--------+", - "| num | letter |", - "+-----+--------+", - "| 1 | one |", - "| 2 | two |", - "| | three |", - "+-----+--------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_nulls_first_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----+--------+", - "| num | letter |", - "+-----+--------+", - "| | three |", - "| 2 | two |", - "| 1 | one |", - "+-----+--------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_specific_nulls_last_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC NULLS LAST"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----+--------+", - "| num | letter |", - "+-----+--------+", - "| 2 | two |", - "| 1 | one |", - "| | three |", - "+-----+--------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_specific_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num ASC NULLS FIRST"; - let actual = execute_to_batches(&mut ctx, sql).await; - let expected = vec![ - "+-----+--------+", - "| num | letter |", - "+-----+--------+", - "| | three |", - "| 1 | one |", - "| 2 | two |", - "+-----+--------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn test_select_wildcard_without_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT * "; - let actual = ctx.sql(sql).await; - match actual { - Ok(_) => panic!("expect err"), - Err(e) => { - assert_contains!( - e.to_string(), - "Error during planning: SELECT * with no tables specified is not valid" - ); - } - } - Ok(()) -} - -#[tokio::test] -#[ignore] 
-async fn csv_query_with_decimal_by_sql() -> Result<()> {
-    let mut ctx = ExecutionContext::new();
-    register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await;
-    let sql = "SELECT c1 from aggregate_simple";
-    let actual = execute_to_batches(&mut ctx, sql).await;
-    let expected = vec![
-        "+----------+",
-        "| c1       |",
-        "+----------+",
-        "| 0.000010 |",
-        "| 0.000020 |",
-        "| 0.000020 |",
-        "| 0.000030 |",
-        "| 0.000030 |",
-        "| 0.000030 |",
-        "| 0.000040 |",
-        "| 0.000040 |",
-        "| 0.000040 |",
-        "| 0.000040 |",
-        "| 0.000050 |",
-        "| 0.000050 |",
-        "| 0.000050 |",
-        "| 0.000050 |",
-        "| 0.000050 |",
-        "+----------+",
-    ];
-    assert_batches_eq!(expected, &actual);
-    Ok(())
-}
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
new file mode 100644
index 000000000000..edf530be8b7d
--- /dev/null
+++ b/datafusion/tests/sql/aggregates.rs
@@ -0,0 +1,387 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+
+#[tokio::test]
+async fn csv_query_avg_multi_batch() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT avg(c12) FROM aggregate_test_100";
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let plan = ctx.create_physical_plan(&plan).await.unwrap();
+    let results = collect(plan).await.unwrap();
+    let batch = &results[0];
+    let column = batch.column(0);
+    let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+    let actual = array.value(0);
+    let expected = 0.5089725;
+    // Due to float number's accuracy, different batch size will lead to different
+    // answers.
+ assert!((expected - actual).abs() < 0.01); + Ok(()) +} + +#[tokio::test] +async fn csv_query_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT avg(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.5089725099127211"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8675"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["26156334342021890000000000000000000000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.09234223721582163"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_4() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_5() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.3665650368716449"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["5114326382039172000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.30387865541334363"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_4() -> Result<()> { + let mut ctx = 
ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_5() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_samp(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_6() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.9504384952922168"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_external_table_count() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT COUNT(c12) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------+", + "| COUNT(aggregate_test_100.c12) |", + "+-------------------------------+", + "| 100 |", + "+-------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_external_table_sum() { + let mut ctx = ExecutionContext::new(); + // cast smallint and int to bigint to avoid overflow during calculation + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = + "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------------------+-------------------------------------------+", + "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |", + "+-------------------------------------------+-------------------------------------------+", + "| 13060 | 3017641 |", + "+-------------------------------------------+-------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT count(c12) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------+", + "| COUNT(aggregate_test_100.c12) |", + "+-------------------------------+", + "| 100 |", + "+-------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_count_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT count(distinct c2) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------------------------+", + "| COUNT(DISTINCT aggregate_test_100.c2) |", + "+---------------------------------------+", + "| 5 |", + "+---------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn 
csv_query_count_distinct_expr() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT count(distinct c2 % 2) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------------------------------------------+", + "| COUNT(DISTINCT aggregate_test_100.c2 % Int64(2)) |", + "+--------------------------------------------------+", + "| 2 |", + "+--------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_count_star() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT COUNT(*) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 100 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_count_one() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT COUNT(1) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 100 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_approx_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+--------------+", + "| count_c9 | count_c9_str |", + "+----------+--------------+", + "| 100 | 99 |", + "+----------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_count_without_from() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT count(1 + 1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+", + "| COUNT(Int64(1) + Int64(1)) |", + "+----------------------------+", + "| 1 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_array_agg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = + "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------------------------------------------------------------------+", + "| ARRAYAGG(test.c13) |", + "+------------------------------------------------------------------+", + "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] |", + "+------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_array_agg_empty() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = + "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------------+", + "| 
ARRAYAGG(test.c13) |", + "+--------------------+", + "| [] |", + "+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_array_agg_one() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = + "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------------+", + "| ARRAYAGG(test.c13) |", + "+----------------------------------+", + "| [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] |", + "+----------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs new file mode 100644 index 000000000000..3983389dae34 --- /dev/null +++ b/datafusion/tests/sql/avro.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +async fn register_alltypes_avro(ctx: &mut ExecutionContext) { + let testdata = datafusion::test_util::arrow_test_data(); + ctx.register_avro( + "alltypes_plain", + &format!("{}/avro/alltypes_plain.avro", testdata), + AvroReadOptions::default(), + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn avro_query() { + let mut ctx = ExecutionContext::new(); + register_alltypes_avro(&mut ctx).await; + // NOTE that string_col is actually a binary column and does not have the UTF8 logical type + // so we need an explicit cast + let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------------------+", + "| id | CAST(alltypes_plain.string_col AS Utf8) |", + "+----+-----------------------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+-----------------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn avro_query_multiple_files() { + let tempdir = tempfile::tempdir().unwrap(); + let table_path = tempdir.path(); + let testdata = datafusion::test_util::arrow_test_data(); + let alltypes_plain_file = format!("{}/avro/alltypes_plain.avro", testdata); + std::fs::copy( + &alltypes_plain_file, + format!("{}/alltypes_plain1.avro", table_path.display()), + ) + .unwrap(); + std::fs::copy( + &alltypes_plain_file, + format!("{}/alltypes_plain2.avro", table_path.display()), + ) + .unwrap(); + + let mut ctx = ExecutionContext::new(); + ctx.register_avro( + "alltypes_plain", + table_path.display().to_string().as_str(), + AvroReadOptions::default(), + ) + .await + .unwrap(); + // NOTE that 
string_col is actually a binary column and does not have the UTF8 logical type + // so we need an explicit cast + let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------------------+", + "| id | CAST(alltypes_plain.string_col AS Utf8) |", + "+----+-----------------------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+-----------------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn avro_single_nan_schema() { + let mut ctx = ExecutionContext::new(); + let testdata = datafusion::test_util::arrow_test_data(); + ctx.register_avro( + "single_nan", + &format!("{}/avro/single_nan.avro", testdata), + AvroReadOptions::default(), + ) + .await + .unwrap(); + let sql = "SELECT mycol FROM single_nan"; + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).await.unwrap(); + let results = collect(plan).await.unwrap(); + for batch in results { + assert_eq!(1, batch.num_rows()); + assert_eq!(1, batch.num_columns()); + } +} + +#[tokio::test] +async fn avro_explain() { + let mut ctx = ExecutionContext::new(); + register_alltypes_avro(&mut ctx).await; + + let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; + let actual = execute(&mut ctx, sql).await; + let actual = normalize_vec_for_explain(actual); + let expected = vec![ + vec![ + "logical_plan", + "Projection: #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: alltypes_plain projection=Some([0])", + ], + vec![ + "physical_plan", + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as COUNT(UInt8(1))]\ + \n HashAggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]\ + \n CoalescePartitionsExec\ + \n HashAggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\ + \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ + \n AvroExec: files=[ARROW_TEST_DATA/avro/alltypes_plain.avro], batch_size=8192, limit=None\ + \n", + ], + ]; + assert_eq!(expected, actual); +} diff --git a/datafusion/tests/sql/create_drop.rs b/datafusion/tests/sql/create_drop.rs new file mode 100644 index 000000000000..7dcca46710b7 --- /dev/null +++ b/datafusion/tests/sql/create_drop.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn create_table_as() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; + ctx.sql(sql).await.unwrap(); + + let sql_all = "SELECT * FROM my_table order by c1 LIMIT 1"; + let results_all = execute_to_batches(&mut ctx, sql_all).await; + + let expected = vec![ + "+---------+----------------+------+", + "| c1 | c2 | c3 |", + "+---------+----------------+------+", + "| 0.00001 | 0.000000000001 | true |", + "+---------+----------------+------+", + ]; + + assert_batches_eq!(expected, &results_all); + + Ok(()) +} + +#[tokio::test] +async fn drop_table() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; + ctx.sql(sql).await.unwrap(); + + let sql = "DROP TABLE my_table"; + ctx.sql(sql).await.unwrap(); + + let result = ctx.table("my_table"); + assert!(result.is_err(), "drop table should deregister table."); + + let sql = "DROP TABLE IF EXISTS my_table"; + ctx.sql(sql).await.unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_create_external_table() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", + "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | Int64(10) | c11 | c12 | c13 |", + "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", + "| c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 10 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW |", + "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs new file mode 100644 index 000000000000..9cd7bc96ff89 --- /dev/null +++ b/datafusion/tests/sql/errors.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn csv_query_error() -> Result<()> { + // sin(utf8) should error + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT sin(c1) FROM aggregate_test_100"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_cast_expressions_error() -> Result<()> { + // sin(utf8) should error + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100"; + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).await.unwrap(); + let result = collect(plan).await; + + match result { + Ok(_) => panic!("expected error"), + Err(e) => { + assert_contains!(e.to_string(), + "Cast error: Cannot cast string 'c' to value of arrow::datatypes::types::Int32Type type" + ); + } + } + + Ok(()) +} + +#[tokio::test] +async fn test_aggregation_with_bad_arguments() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; + let logical_plan = ctx.create_logical_plan(sql); + let err = logical_plan.unwrap_err(); + assert_eq!( + err.to_string(), + DataFusionError::Plan( + "The function Count expects 1 arguments, but 0 were provided".to_string() + ) + .to_string() + ); + Ok(()) +} + +#[tokio::test] +async fn query_cte_incorrect() -> Result<()> { + let ctx = ExecutionContext::new(); + + // self reference + let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'t\' not found" + ); + + // forward referencing + let sql = "WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'u\' not found" + ); + + // wrapping should hide u + let sql = "WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + assert_eq!( + format!("{}", plan.unwrap_err()), + "Error during planning: Table or CTE with name \'u\' not found" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_select_wildcard_without_table() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT * "; + let actual = ctx.sql(sql).await; + match actual { + Ok(_) => panic!("expect err"), + Err(e) => { + assert_contains!( + e.to_string(), + "Error during planning: SELECT * with no tables specified is not valid" + ); + } + } + Ok(()) +} + +#[tokio::test] +async fn invalid_qualified_table_references() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + + for table_ref in &[ + "nonexistentschema.aggregate_test_100", + "nonexistentcatalog.public.aggregate_test_100", + "way.too.many.namespaces.as.ident.prefixes.aggregate_test_100", + ] { + let sql = format!("SELECT COUNT(*) FROM {}", table_ref); + assert!(matches!(ctx.sql(&sql).await, Err(DataFusionError::Plan(_)))); + } + Ok(()) +} diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs new file mode 100644 index 000000000000..47e729038c3b --- /dev/null +++ b/datafusion/tests/sql/explain_analyze.rs @@ -0,0 +1,787 @@ +// Licensed 
to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn explain_analyze_baseline_metrics() { + // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE + // and then validate the presence of baseline metrics for supported operators + let config = ExecutionConfig::new().with_target_partitions(3); + let mut ctx = ExecutionContext::with_config(config); + register_aggregate_csv_by_sql(&mut ctx).await; + // a query with as many operators as we have metrics for + let sql = "EXPLAIN ANALYZE \ + SELECT count(*) as cnt FROM \ + (SELECT count(*), c1 \ + FROM aggregate_test_100 \ + WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' \ + GROUP BY c1 \ + ORDER BY c1 ) AS a \ + UNION ALL \ + SELECT 1 as cnt \ + UNION ALL \ + SELECT lead(c1, 1) OVER () as cnt FROM (select 1 as c1) AS b \ + LIMIT 3"; + println!("running query: {}", sql); + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); + let results = collect(physical_plan.clone()).await.unwrap(); + let formatted = arrow::util::pretty::pretty_format_batches(&results).unwrap(); + println!("Query Output:\n\n{}", formatted); + + assert_metrics!( + &formatted, + "HashAggregateExec: mode=Partial, gby=[]", + "metrics=[output_rows=3, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", + "metrics=[output_rows=5, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "SortExec: [c1@0 ASC NULLS LAST]", + "metrics=[output_rows=5, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + "metrics=[output_rows=99, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "GlobalLimitExec: limit=3, ", + "metrics=[output_rows=1, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "LocalLimitExec: limit=3", + "metrics=[output_rows=3, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "ProjectionExec: expr=[COUNT(UInt8(1))", + "metrics=[output_rows=1, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "CoalesceBatchesExec: target_batch_size=4096", + "metrics=[output_rows=5, elapsed_compute" + ); + assert_metrics!( + &formatted, + "CoalescePartitionsExec", + "metrics=[output_rows=5, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "UnionExec", + "metrics=[output_rows=3, elapsed_compute=" + ); + assert_metrics!( + &formatted, + "WindowAggExec", + "metrics=[output_rows=1, elapsed_compute=" + ); + + fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { + use datafusion::physical_plan; + + plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + // CoalescePartitionsExec 
doesn't do any work so is not included + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + } + + // Validate that the recorded elapsed compute time was more than + // zero for all operators as well as the start/end timestamp are set + struct TimeValidator {} + impl ExecutionPlanVisitor for TimeValidator { + type Error = std::convert::Infallible; + + fn pre_visit( + &mut self, + plan: &dyn ExecutionPlan, + ) -> std::result::Result { + if !expected_to_have_metrics(plan) { + return Ok(true); + } + let metrics = plan.metrics().unwrap().aggregate_by_partition(); + + assert!(metrics.output_rows().unwrap() > 0); + assert!(metrics.elapsed_compute().unwrap() > 0); + + let mut saw_start = false; + let mut saw_end = false; + metrics.iter().for_each(|m| match m.value() { + MetricValue::StartTimestamp(ts) => { + saw_start = true; + assert!(ts.value().unwrap().timestamp_nanos() > 0); + } + MetricValue::EndTimestamp(ts) => { + saw_end = true; + assert!(ts.value().unwrap().timestamp_nanos() > 0); + } + _ => {} + }); + + assert!(saw_start); + assert!(saw_end); + + Ok(true) + } + } + + datafusion::physical_plan::accept(physical_plan.as_ref(), &mut TimeValidator {}) + .unwrap(); +} + +#[tokio::test] +async fn csv_explain_plans() { + // This test verify the look of each plan in its full cycle plan creation + + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; + + // Logical plan + // Create plan + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let logical_schema = plan.schema(); + // + println!("SQL: {}", sql); + // + // Verify schema + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #aggregate_test_100.c1 [c1:Utf8]", + " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", + " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // Verify the text format of the plan + let expected = vec![ + "Explain", + " Projection: #aggregate_test_100.c1", + " Filter: #aggregate_test_100.c2 > Int64(10)", + " TableScan: aggregate_test_100 projection=None", + ]; + let formatted = plan.display_indent().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // verify the grahviz format of the plan + let expected = vec![ + "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "digraph {", + " subgraph cluster_1", + " {", + " graph[label=\"LogicalPlan\"]", + " 2[shape=box label=\"Explain\"]", + " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", + " 2 -> 3 
[arrowhead=none, arrowtail=normal, dir=back]", + " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", + " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", + " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]", + " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + " subgraph cluster_6", + " {", + " graph[label=\"Detailed LogicalPlan\"]", + " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", + " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", + " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + "}", + "// End DataFusion GraphViz Plan", + ]; + let formatted = plan.display_graphviz().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + // Optimized logical plan + // + let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); + let optimized_logical_schema = plan.schema(); + // Both schema has to be the same + assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); + // + // Verify schema + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #aggregate_test_100.c1 [c1:Utf8]", + " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", + " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // Verify the text format of the plan + let expected = vec![ + "Explain", + " Projection: #aggregate_test_100.c1", + " Filter: #aggregate_test_100.c2 > Int64(10)", + " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]", + ]; + let formatted = plan.display_indent().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // verify the grahviz format of the plan + let expected = vec![ + "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "digraph {", + " subgraph cluster_1", + " {", + " graph[label=\"LogicalPlan\"]", + " 2[shape=box label=\"Explain\"]", + " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", + " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", + " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", + " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", + " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\"]", + " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + " 
subgraph cluster_6", + " {", + " graph[label=\"Detailed LogicalPlan\"]", + " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", + " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", + " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + "}", + "// End DataFusion GraphViz Plan", + ]; + let formatted = plan.display_graphviz().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + // Physical plan + // Create plan + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).await.expect(&msg); + // + // Execute plan + let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); + let results = collect(plan).await.expect(&msg); + let actual = result_vec(&results); + // flatten to a single string + let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>(); + // Since the plan contains paths that depend on the environment (e.g. the full path of the test file), only verify important content + assert_contains!(&actual, "logical_plan"); + assert_contains!(&actual, "Projection: #aggregate_test_100.c1"); + assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)"); +} + +#[tokio::test] +async fn csv_explain_verbose() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; + let actual = execute(&mut ctx, sql).await; + + // flatten to a single string + let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>(); + + // Don't actually test the contents of the debugging output (as + // that may change and keeping this test updated will be a + // pain). Instead just check for a few key pieces.
+ assert_contains!(&actual, "logical_plan"); + assert_contains!(&actual, "physical_plan"); + assert_contains!(&actual, "#aggregate_test_100.c2 > Int64(10)"); + + // ensure the "same text as above" optimization is working + assert_contains!(actual, "SAME TEXT AS ABOVE"); +} + +#[tokio::test] +async fn csv_explain_verbose_plans() { + // This test verify the look of each plan in its full cycle plan creation + + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; + + // Logical plan + // Create plan + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let logical_schema = plan.schema(); + // + println!("SQL: {}", sql); + + // + // Verify schema + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #aggregate_test_100.c1 [c1:Utf8]", + " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", + " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // Verify the text format of the plan + let expected = vec![ + "Explain", + " Projection: #aggregate_test_100.c1", + " Filter: #aggregate_test_100.c2 > Int64(10)", + " TableScan: aggregate_test_100 projection=None", + ]; + let formatted = plan.display_indent().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // verify the grahviz format of the plan + let expected = vec![ + "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "digraph {", + " subgraph cluster_1", + " {", + " graph[label=\"LogicalPlan\"]", + " 2[shape=box label=\"Explain\"]", + " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", + " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", + " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", + " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", + " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]", + " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + " subgraph cluster_6", + " {", + " graph[label=\"Detailed LogicalPlan\"]", + " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", + " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", + " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]", + " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + "}", + "// End DataFusion GraphViz 
Plan", + ]; + let formatted = plan.display_graphviz().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + // Optimized logical plan + // + let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); + let optimized_logical_schema = plan.schema(); + // Both schema has to be the same + assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); + // + // Verify schema + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #aggregate_test_100.c1 [c1:Utf8]", + " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", + " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // Verify the text format of the plan + let expected = vec![ + "Explain", + " Projection: #aggregate_test_100.c1", + " Filter: #aggregate_test_100.c2 > Int64(10)", + " TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]", + ]; + let formatted = plan.display_indent().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + // + // verify the grahviz format of the plan + let expected = vec![ + "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "digraph {", + " subgraph cluster_1", + " {", + " graph[label=\"LogicalPlan\"]", + " 2[shape=box label=\"Explain\"]", + " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", + " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", + " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", + " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", + " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\"]", + " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + " subgraph cluster_6", + " {", + " graph[label=\"Detailed LogicalPlan\"]", + " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", + " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", + " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", + " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", + " }", + "}", + "// End DataFusion GraphViz Plan", + ]; + let formatted = plan.display_graphviz().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + // Physical plan + // Create plan + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).await.expect(&msg); + // + // Execute plan + let msg = format!("Executing physical plan for '{}': {:?}", sql, 
plan); + let results = collect(plan).await.expect(&msg); + let actual = result_vec(&results); + // flatten to a single string + let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>(); + // Since the plan contains paths that depend on the + // environment (e.g. the full path of the test file), only verify + // important content + assert_contains!(&actual, "logical_plan after projection_push_down"); + assert_contains!(&actual, "physical_plan"); + assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10"); + assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); +} + +#[tokio::test] +async fn explain_analyze_runs_optimizers() { + // repro for https://github.com/apache/arrow-datafusion/issues/917 + // where EXPLAIN ANALYZE was not correctly running the optimizer + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + + // This happens as an optimization pass where count(*) can be + // answered using statistics only. + let expected = "EmptyExec: produce_one_row=true"; + + let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let actual = arrow::util::pretty::pretty_format_batches(&actual).unwrap(); + assert_contains!(actual, expected); + + // EXPLAIN ANALYZE should work the same + let sql = "EXPLAIN ANALYZE SELECT count(*) from alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let actual = arrow::util::pretty::pretty_format_batches(&actual).unwrap(); + assert_contains!(actual, expected); +} + +#[tokio::test] +async fn tpch_explain_q10() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + register_tpch_csv(&mut ctx, "customer").await?; + register_tpch_csv(&mut ctx, "orders").await?; + register_tpch_csv(&mut ctx, "lineitem").await?; + register_tpch_csv(&mut ctx, "nation").await?; + + let sql = "select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1994-01-01' + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc;"; + + let mut plan = ctx.create_logical_plan(sql); + plan = ctx.optimize(&plan.unwrap()); + + let expected = "\ + Sort: #revenue DESC NULLS FIRST\ + \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ + \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\ + \n Join: #customer.c_nationkey = #nation.n_nationkey\ + \n Join: #orders.o_orderkey = #lineitem.l_orderkey\ + \n Join: #customer.c_custkey = #orders.o_custkey\ + \n TableScan: customer projection=Some([0, 1, 2, 3, 4, 5, 7])\ + \n Filter: #orders.o_orderdate >= Date32(\"8674\") AND #orders.o_orderdate < Date32(\"8766\")\ + \n TableScan: orders projection=Some([0, 1, 4]), filters=[#orders.o_orderdate >= Date32(\"8674\"), #orders.o_orderdate < Date32(\"8766\")]\ + \n Filter: #lineitem.l_returnflag = Utf8(\"R\")\ + \n TableScan: lineitem projection=Some([0, 5, 6,
8]), filters=[#lineitem.l_returnflag = Utf8(\"R\")]\ + \n TableScan: nation projection=Some([0, 1])"; + assert_eq!(format!("{:?}", plan.unwrap()), expected); + + Ok(()) +} + +#[tokio::test] +async fn test_physical_plan_display_indent() { + // Hard code target_partitions as it appears in the RepartitionExec output + let config = ExecutionConfig::new().with_target_partitions(3); + let mut ctx = ExecutionContext::with_config(config); + register_aggregate_csv(&mut ctx).await.unwrap(); + let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ + FROM aggregate_test_100 \ + WHERE c12 < 10 \ + GROUP BY c1 \ + ORDER BY the_min DESC \ + LIMIT 10"; + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + + let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); + let expected = vec![ + "GlobalLimitExec: limit=10", + " SortExec: [the_min@2 DESC]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", + " HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)", + " HashAggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", + " CoalesceBatchesExec: target_batch_size=4096", + " FilterExec: c12@1 < CAST(10 AS Float64)", + " RepartitionExec: partitioning=RoundRobinBatch(3)", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + ]; + + let data_path = datafusion::test_util::arrow_test_data(); + let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) + .trim() + .lines() + // normalize paths + .map(|s| s.replace(&data_path, "ARROW_TEST_DATA")) + .collect::>(); + + assert_eq!( + expected, actual, + "expected:\n{:#?}\nactual:\n\n{:#?}\n", + expected, actual + ); +} + +#[tokio::test] +async fn test_physical_plan_display_indent_multi_children() { + // Hard code target_partitions as it appears in the RepartitionExec output + let config = ExecutionConfig::new().with_target_partitions(3); + let mut ctx = ExecutionContext::with_config(config); + // ensure indenting works for nodes with multiple children + register_aggregate_csv(&mut ctx).await.unwrap(); + let sql = "SELECT c1 \ + FROM (select c1 from aggregate_test_100) AS a \ + JOIN\ + (select c1 as c2 from aggregate_test_100) AS b \ + ON c1=c2\ + "; + + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + + let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); + let expected = vec![ + "ProjectionExec: expr=[c1@0 as c1]", + " CoalesceBatchesExec: target_batch_size=4096", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 0 })]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)", + " ProjectionExec: expr=[c1@0 as c1]", + " ProjectionExec: expr=[c1@0 as c1]", + " RepartitionExec: partitioning=RoundRobinBatch(3)", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 3)", + " 
ProjectionExec: expr=[c2@0 as c2]", + " ProjectionExec: expr=[c1@0 as c2]", + " RepartitionExec: partitioning=RoundRobinBatch(3)", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + ]; + + let data_path = datafusion::test_util::arrow_test_data(); + let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) + .trim() + .lines() + // normalize paths + .map(|s| s.replace(&data_path, "ARROW_TEST_DATA")) + .collect::>(); + + assert_eq!( + expected, actual, + "expected:\n{:#?}\nactual:\n\n{:#?}\n", + expected, actual + ); +} + +#[tokio::test] +async fn csv_explain() { + // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, + // then execute the physical plan and return the final explain results + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; + let actual = execute(&mut ctx, sql).await; + let actual = normalize_vec_for_explain(actual); + + // Note can't use `assert_batches_eq` as the plan needs to be + // normalized for filenames and number of cores + let expected = vec![ + vec![ + "logical_plan", + "Projection: #aggregate_test_100.c1\ + \n Filter: #aggregate_test_100.c2 > Int64(10)\ + \n TableScan: aggregate_test_100 projection=Some([0, 1]), filters=[#aggregate_test_100.c2 > Int64(10)]" + ], + vec!["physical_plan", + "ProjectionExec: expr=[c1@0 as c1]\ + \n CoalesceBatchesExec: target_batch_size=4096\ + \n FilterExec: CAST(c2@1 AS Int64) > 10\ + \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ + \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None\ + \n" + ]]; + assert_eq!(expected, actual); + + // Also, expect same result with lowercase explain + let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; + let actual = execute(&mut ctx, sql).await; + let actual = normalize_vec_for_explain(actual); + assert_eq!(expected, actual); +} + +#[tokio::test] +async fn csv_explain_analyze() { + // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "EXPLAIN ANALYZE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual).unwrap(); + + // Only test basic plumbing and try to avoid having to change too + // many things. 
explain_analyze_baseline_metrics covers the values + // in greater depth + let needle = "CoalescePartitionsExec, metrics=[output_rows=5, elapsed_compute="; + assert_contains!(&formatted, needle); + + let verbose_needle = "Output Rows"; + assert_not_contains!(formatted, verbose_needle); +} + +#[tokio::test] +async fn csv_explain_analyze_verbose() { + // This test uses the execute function to run an actual plan under EXPLAIN VERBOSE ANALYZE + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = + "EXPLAIN ANALYZE VERBOSE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual).unwrap(); + + let verbose_needle = "Output Rows"; + assert_contains!(formatted, verbose_needle); +} diff --git a/datafusion/tests/sql/expr.rs b/datafusion/tests/sql/expr.rs new file mode 100644 index 000000000000..8c2f6b970165 --- /dev/null +++ b/datafusion/tests/sql/expr.rs @@ -0,0 +1,917 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
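Nearly every test in this file goes through a `test_expression!` macro from the shared `super` module, which evaluates a single scalar expression via `SELECT <expr>` and compares its printed value against the expected string. The exact definition is not part of this excerpt; a hedged sketch of what it presumably expands to:

// Hedged sketch only; the real macro lives in datafusion/tests/sql/mod.rs.
macro_rules! test_expression {
    ($SQL:expr, $EXPECTED:expr) => {
        // run `SELECT <expr>` against a fresh context and compare the
        // single printed value against the expected string
        let mut ctx = ExecutionContext::new();
        let sql = format!("SELECT {}", $SQL);
        let actual = execute(&mut ctx, sql.as_str()).await;
        assert_eq!(actual[0][0], $EXPECTED);
    };
}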
+ +use super::*; + +#[tokio::test] +async fn case_when() -> Result<()> { + let mut ctx = create_case_context()?; + let sql = "SELECT \ + CASE WHEN c1 = 'a' THEN 1 \ + WHEN c1 = 'b' THEN 2 \ + END \ + FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------------------------------------------------------------------------------+", + "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) END |", + "+--------------------------------------------------------------------------------------+", + "| 1 |", + "| 2 |", + "| |", + "| |", + "+--------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn case_when_else() -> Result<()> { + let mut ctx = create_case_context()?; + let sql = "SELECT \ + CASE WHEN c1 = 'a' THEN 1 \ + WHEN c1 = 'b' THEN 2 \ + ELSE 999 END \ + FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------------------------------------------------------------------------------------------------------+", + "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", + "+------------------------------------------------------------------------------------------------------+", + "| 1 |", + "| 2 |", + "| 999 |", + "| 999 |", + "+------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn case_when_with_base_expr() -> Result<()> { + let mut ctx = create_case_context()?; + let sql = "SELECT \ + CASE c1 WHEN 'a' THEN 1 \ + WHEN 'b' THEN 2 \ + END \ + FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------------------------------------------------------------+", + "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) END |", + "+---------------------------------------------------------------------------+", + "| 1 |", + "| 2 |", + "| |", + "| |", + "+---------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn case_when_else_with_base_expr() -> Result<()> { + let mut ctx = create_case_context()?; + let sql = "SELECT \ + CASE c1 WHEN 'a' THEN 1 \ + WHEN 'b' THEN 2 \ + ELSE 999 END \ + FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------------------------------------------------------------------+", + "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", + "+-------------------------------------------------------------------------------------------+", + "| 1 |", + "| 2 |", + "| 999 |", + "| 999 |", + "+-------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_not() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(BooleanArray::from(vec![ + Some(false), + None, + Some(true), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", 
Arc::new(table))?; + let sql = "SELECT NOT c1 FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------+", + "| NOT test.c1 |", + "+-------------+", + "| true |", + "| |", + "| false |", + "+-------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_sum_cast() { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + // c8 = i32; c9 = i64 + let sql = "SELECT c8 + c9 FROM aggregate_test_100"; + // check that the physical and logical schemas are equal + execute(&mut ctx, sql).await; +} + +#[tokio::test] +async fn query_is_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(f64::NAN), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT c1 IS NULL FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| test.c1 IS NULL |", + "+-----------------+", + "| false |", + "| true |", + "| false |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_is_not_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(f64::NAN), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT c1 IS NOT NULL FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+", + "| test.c1 IS NOT NULL |", + "+---------------------+", + "| true |", + "| false |", + "| true |", + "+---------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_without_from() -> Result<()> { + // Test for SELECT without FROM. + // Should evaluate expressions in project position. 
+ let mut ctx = ExecutionContext::new(); + + let sql = "SELECT 1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| Int64(1) |", + "+----------+", + "| 1 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT 1+2, 3/4, cos(0)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+---------------------+---------------+", + "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |", + "+---------------------+---------------------+---------------+", + "| 3 | 0 | 1 |", + "+---------------------+---------------------+---------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn query_scalar_minus_array() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT 4 - c1 FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------------------------+", + "| Int64(4) Minus test.c1 |", + "+------------------------+", + "| 4 |", + "| 3 |", + "| |", + "| 1 |", + "+------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_boolean_expressions() -> Result<()> { + test_expression!("true", "true"); + test_expression!("false", "false"); + test_expression!("false = false", "true"); + test_expression!("true = false", "false"); + Ok(()) +} + +#[tokio::test] +#[cfg_attr(not(feature = "crypto_expressions"), ignore)] +async fn test_crypto_expressions() -> Result<()> { + test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("digest('tom','md5')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("digest('','md5')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("md5(NULL)", "NULL"); + test_expression!("digest(NULL,'md5')", "NULL"); + test_expression!( + "sha224('tom')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "digest('tom','sha224')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "sha224('')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!( + "digest('','sha224')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!("sha224(NULL)", "NULL"); + test_expression!("digest(NULL,'sha224')", "NULL"); + test_expression!( + "sha256('tom')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "digest('tom','sha256')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "sha256('')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!( + "digest('','sha256')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!("sha256(NULL)", "NULL"); + test_expression!("digest(NULL,'sha256')", "NULL"); + test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + 
test_expression!("digest('tom','sha384')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("digest('','sha384')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("sha384(NULL)", "NULL"); + test_expression!("digest(NULL,'sha384')", "NULL"); + test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("digest('tom','sha512')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("digest('','sha512')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("sha512(NULL)", "NULL"); + test_expression!("digest(NULL,'sha512')", "NULL"); + test_expression!("digest(NULL,'blake2s')", "NULL"); + test_expression!("digest(NULL,'blake2b')", "NULL"); + test_expression!("digest('','blake2b')", "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce"); + test_expression!("digest('tom','blake2b')", "482499a18da10a18d8d35ab5eb4c635551ec5b8d3ff37c3e87a632caf6680fe31566417834b4732e26e0203d1cad4f5366cb7ab57d89694e4c1fda3e26af2c23"); + test_expression!( + "digest('','blake2s')", + "69217a3079908094e11121d042354a7c1f55b6482ca1a51e1b250dfd1ed0eef9" + ); + test_expression!( + "digest('tom','blake2s')", + "5fc3f2b3a07cade5023c3df566e4d697d3823ba1b72bfb3e84cf7e768b2e7529" + ); + test_expression!( + "digest('','blake3')", + "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262" + ); + Ok(()) +} + +#[tokio::test] +async fn test_interval_expressions() -> Result<()> { + test_expression!( + "interval '1'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '1 second'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '500 milliseconds'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '5 second'", + "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs" + ); + test_expression!( + "interval '0.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '5 minute'", + "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs" + ); + test_expression!( + "interval '5 minute 1 second'", + "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs" + ); + test_expression!( + "interval '1 hour'", + "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 hour'", + "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day'", + "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day 1'", + "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.5'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + 
"interval '0.5 day 1'", + "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.49 day'", + "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs" + ); + test_expression!( + "interval '0.499 day'", + "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs" + ); + test_expression!( + "interval '0.4999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs" + ); + test_expression!( + "interval '0.49999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs" + ); + test_expression!( + "interval '0.49999999999 day'", + "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day'", + "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs" + ); + // Hour is ignored, this matches PostgreSQL + test_expression!( + "interval '5 day' hour", + "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", + "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs" + ); + test_expression!( + "interval '0.5 month'", + "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '0.5' month", + "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 month'", + "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1' MONTH", + "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 month'", + "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '13 month'", + "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '0.5 year'", + "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 year'", + "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '2 year'", + "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '2' year", + "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + Ok(()) +} + +#[tokio::test] +async fn test_string_expressions() -> Result<()> { + test_expression!("ascii('')", "0"); + test_expression!("ascii('x')", "120"); + test_expression!("ascii(NULL)", "NULL"); + test_expression!("bit_length('')", "0"); + test_expression!("bit_length('chars')", "40"); + test_expression!("bit_length('josé')", "40"); + test_expression!("bit_length(NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); + test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); + test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); + test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); + test_expression!("btrim(NULL, 'xyz')", "NULL"); + test_expression!("chr(CAST(120 AS int))", "x"); + test_expression!("chr(CAST(128175 AS int))", "💯"); + test_expression!("chr(CAST(NULL AS int))", "NULL"); + test_expression!("concat('a','b','c')", "abc"); + test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); + test_expression!("concat(NULL)", ""); + test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); + test_expression!("concat_ws('|','a','b','c')", "a|b|c"); + test_expression!("concat_ws('|',NULL)", ""); + test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("initcap('')", ""); + test_expression!("initcap('hi THOMAS')", "Hi Thomas"); + test_expression!("initcap(NULL)", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + 
test_expression!("lower(NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ')", "zzzytest "); + test_expression!("ltrim('zzzytest', 'xyz')", "test"); + test_expression!("ltrim(NULL, 'xyz')", "NULL"); + test_expression!("octet_length('')", "0"); + test_expression!("octet_length('chars')", "5"); + test_expression!("octet_length('josé')", "5"); + test_expression!("octet_length(NULL)", "NULL"); + test_expression!("repeat('Pg', 4)", "PgPgPgPg"); + test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); + test_expression!("repeat(NULL, 4)", "NULL"); + test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); + test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); + test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); + test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); + test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); + test_expression!("rtrim(' testxxzx ')", " testxxzx"); + test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); + test_expression!("rtrim('testxxzx', 'xyz')", "test"); + test_expression!("rtrim(NULL, 'xyz')", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); + test_expression!("split_part(NULL, '~@~', 20)", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); + test_expression!( + "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", + "NULL" + ); + test_expression!("starts_with('alphabet', 'alph')", "true"); + test_expression!("starts_with('alphabet', 'blph')", "false"); + test_expression!("starts_with(NULL, 'blph')", "NULL"); + test_expression!("starts_with('alphabet', NULL)", "NULL"); + test_expression!("to_hex(2147483647)", "7fffffff"); + test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); + test_expression!("to_hex(CAST(NULL AS int))", "NULL"); + test_expression!("trim(' tom ')", "tom"); + test_expression!("trim(LEADING ' ' FROM ' tom ')", "tom "); + test_expression!("trim(TRAILING ' ' FROM ' tom ')", " tom"); + test_expression!("trim(BOTH ' ' FROM ' tom ')", "tom"); + test_expression!("trim(LEADING 'x' FROM 'xxxtomxxx')", "tomxxx"); + test_expression!("trim(TRAILING 'x' FROM 'xxxtomxxx')", "xxxtom"); + test_expression!("trim(BOTH 'x' FROM 'xxxtomxx')", "tom"); + test_expression!("trim(LEADING 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdefxyx"); + test_expression!("trim(TRAILING 'xy' FROM 'xyxabcxyzdefxyx')", "xyxabcxyzdef"); + test_expression!("trim(BOTH 'xy' FROM 'xyxabcxyzdefxyx')", "abcxyzdef"); + test_expression!("trim(' tom')", "tom"); + test_expression!("trim('')", ""); + test_expression!("trim('tom ')", "tom"); + test_expression!("upper('')", ""); + test_expression!("upper('tom')", "TOM"); + test_expression!("upper(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +#[cfg_attr(not(feature = "regex_expressions"), ignore)] +async fn test_regex_expressions() -> Result<()> { + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'gi')", "XXX"); + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'i')", "XabcABC"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X', 'g')", "fooXX"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X')", "fooXbaz"); + test_expression!( + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g')", + "fooXarYXazY" + ); + test_expression!( + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL)", + "NULL" + ); + 
test_expression!("regexp_replace('foobarbaz', 'b(..)', NULL, 'g')", "NULL"); + test_expression!("regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g')", "NULL"); + test_expression!("regexp_replace('Thomas', '.[mN]a.', 'M')", "ThM"); + test_expression!("regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g')", "NULL"); + test_expression!("regexp_match('foobarbequebaz', '')", "[]"); + test_expression!( + "regexp_match('foobarbequebaz', '(bar)(beque)')", + "[bar, beque]" + ); + test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL"); + test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); + test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]"); + test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL"); + test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL"); + test_expression!("regexp_match('aaa-0', NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_cast_expressions() -> Result<()> { + test_expression!("CAST('0' AS INT)", "0"); + test_expression!("CAST(NULL AS INT)", "NULL"); + test_expression!("TRY_CAST('0' AS INT)", "0"); + test_expression!("TRY_CAST('x' AS INT)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_random_expression() -> Result<()> { + let mut ctx = create_ctx()?; + let sql = "SELECT random() r1"; + let actual = execute(&mut ctx, sql).await; + let r1 = actual[0][0].parse::().unwrap(); + assert!(0.0 <= r1); + assert!(r1 < 1.0); + Ok(()) +} + +#[tokio::test] +async fn case_with_bool_type_result() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "select case when 'cpu' != 'cpu' then true else false end"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------------------------------------------------------------------+", + "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |", + "+---------------------------------------------------------------------------------+", + "| false |", + "+---------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn in_list_array() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT + c1 IN ('a', 'c') AS utf8_in_true + ,c1 IN ('x', 'y') AS utf8_in_false + ,c1 NOT IN ('x', 'y') AS utf8_not_in_true + ,c1 NOT IN ('a', 'c') AS utf8_not_in_false + ,NULL IN ('a', 'c') AS utf8_in_null + FROM aggregate_test_100 WHERE c12 < 0.05"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------+---------------+------------------+-------------------+--------------+", + "| utf8_in_true | utf8_in_false | utf8_not_in_true | utf8_not_in_false | utf8_in_null |", + "+--------------+---------------+------------------+-------------------+--------------+", + "| true | false | true | false | |", + "| true | false | true | false | |", + "| true | false | true | false | |", + "| false | false | true | true | |", + "| false | false | true | true | |", + "| false | false | true | true | |", + "| false | false | true | true | |", + "+--------------+---------------+------------------+-------------------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_extract_date_part() -> Result<()> { + test_expression!("date_part('hour', CAST('2020-01-01' AS DATE))", "0"); + test_expression!("EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE))", "0"); + test_expression!( + 
"EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "12" + ); + test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000"); + test_expression!( + "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "2020" + ); + Ok(()) +} + +#[tokio::test] +async fn test_in_list_scalar() -> Result<()> { + test_expression!("'a' IN ('a','b')", "true"); + test_expression!("'c' IN ('a','b')", "false"); + test_expression!("'c' NOT IN ('a','b')", "true"); + test_expression!("'a' NOT IN ('a','b')", "false"); + test_expression!("NULL IN ('a','b')", "NULL"); + test_expression!("NULL NOT IN ('a','b')", "NULL"); + test_expression!("'a' IN ('a','b',NULL)", "true"); + test_expression!("'c' IN ('a','b',NULL)", "NULL"); + test_expression!("'a' NOT IN ('a','b',NULL)", "false"); + test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); + test_expression!("0 IN (0,1,2)", "true"); + test_expression!("3 IN (0,1,2)", "false"); + test_expression!("3 NOT IN (0,1,2)", "true"); + test_expression!("0 NOT IN (0,1,2)", "false"); + test_expression!("NULL IN (0,1,2)", "NULL"); + test_expression!("NULL NOT IN (0,1,2)", "NULL"); + test_expression!("0 IN (0,1,2,NULL)", "true"); + test_expression!("3 IN (0,1,2,NULL)", "NULL"); + test_expression!("0 NOT IN (0,1,2,NULL)", "false"); + test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); + test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); + test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("'1' IN ('a','b',1)", "true"); + test_expression!("'2' IN ('a','b',1)", "false"); + test_expression!("'2' NOT IN ('a','b',1)", "true"); + test_expression!("'1' NOT IN ('a','b',1)", "false"); + test_expression!("NULL IN ('a','b',1)", "NULL"); + test_expression!("NULL NOT IN ('a','b',1)", "NULL"); + test_expression!("'1' IN ('a','b',NULL,1)", "true"); + test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); + test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); + test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn csv_query_boolean_eq_neq() { + let mut ctx = ExecutionContext::new(); + register_boolean(&mut ctx).await.unwrap(); + // verify the plumbing is all hooked up for eq and neq + let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, a != true as neq_scalar FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------+-------+-------+-----------+-------+------------+", + "| a | b | eq | eq_scalar | neq | neq_scalar |", + "+-------+-------+-------+-----------+-------+------------+", + "| true | true | true | true | false | false |", + "| true | | | | | false |", + "| true | false | false | false | true | false |", + "| | true | | true | | |", + "| | | | | | |", + "| | false | | false | | |", + "| false | true | false | true | true | true |", + "| false | | | | | true |", + "| false | false | true | false | false | true |", + "+-------+-------+-------+-----------+-------+------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_boolean_lt_lt_eq() 
{ + let mut ctx = ExecutionContext::new(); + register_boolean(&mut ctx).await.unwrap(); + // verify the plumbing is all hooked up for < and <= + let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a <= b as lt_eq, a <= true as lt_eq_scalar FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------+-------+-------+-----------+-------+--------------+", + "| a | b | lt | lt_scalar | lt_eq | lt_eq_scalar |", + "+-------+-------+-------+-----------+-------+--------------+", + "| true | true | false | true | true | true |", + "| true | | | | | true |", + "| true | false | false | false | false | true |", + "| | true | | true | | |", + "| | | | | | |", + "| | false | | false | | |", + "| false | true | true | true | true | true |", + "| false | | | | | true |", + "| false | false | false | false | true | true |", + "+-------+-------+-------+-----------+-------+--------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_boolean_gt_gt_eq() { + let mut ctx = ExecutionContext::new(); + register_boolean(&mut ctx).await.unwrap(); + // verify the plumbing is all hooked up for > and >= + let sql = "SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------+-------+-------+-----------+-------+--------------+", + "| a | b | gt | gt_scalar | gt_eq | gt_eq_scalar |", + "+-------+-------+-------+-----------+-------+--------------+", + "| true | true | false | true | true | true |", + "| true | | | | | true |", + "| true | false | true | false | true | true |", + "| | true | | true | | |", + "| | | | | | |", + "| | false | | false | | |", + "| false | true | false | true | false | false |", + "| false | | | | | false |", + "| false | false | false | false | true | false |", + "+-------+-------+-------+-----------+-------+--------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_boolean_distinct_from() { + let mut ctx = ExecutionContext::new(); + register_boolean(&mut ctx).await.unwrap(); + // verify the plumbing is all hooked up for is distinct from and is not distinct from + let sql = "SELECT a, b, \ + a is distinct from b as df, \ + b is distinct from true as df_scalar, \ + a is not distinct from b as ndf, \ + a is not distinct from true as ndf_scalar \ + FROM t1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------+-------+-------+-----------+-------+------------+", + "| a | b | df | df_scalar | ndf | ndf_scalar |", + "+-------+-------+-------+-----------+-------+------------+", + "| true | true | false | false | true | true |", + "| true | | true | true | false | true |", + "| true | false | true | true | false | true |", + "| | true | true | false | false | false |", + "| | | false | true | true | false |", + "| | false | true | true | false | false |", + "| false | true | true | false | false | false |", + "| false | | true | true | false | false |", + "| false | false | false | true | true | false |", + "+-------+-------+-------+-----------+-------+------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_nullif_divide_by_0() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let actual = 
&actual[80..90]; // We just want to compare rows 80-89 + let expected = vec![ + vec!["258"], + vec!["664"], + vec!["NULL"], + vec!["22"], + vec!["164"], + vec!["448"], + vec!["365"], + vec!["1640"], + vec!["671"], + vec!["203"], + ]; + assert_eq!(expected, actual); + Ok(()) +} +#[tokio::test] +async fn csv_count_star() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+-----+------------------------------+", + "| COUNT(UInt8(1)) | c | COUNT(aggregate_test_100.c1) |", + "+-----------------+-----+------------------------------+", + "| 100 | 100 | 100 |", + "+-----------------+-----+------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_avg_sqrt() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.6706002946036462"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +// this query used to deadlock due to the call udf(udf()) +#[tokio::test] +async fn csv_query_sqrt_sqrt() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT sqrt(sqrt(c12)) FROM aggregate_test_100 LIMIT 1"; + let actual = execute(&mut ctx, sql).await; + // sqrt(sqrt(c12=0.9294097332465232)) = 0.9818650561397431 + let expected = vec![vec!["0.9818650561397431"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs new file mode 100644 index 000000000000..224f8ba1c008 --- /dev/null +++ b/datafusion/tests/sql/functions.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +/// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double))) +#[tokio::test] +async fn sqrt_f32_vs_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + // sqrt(f32)'s plan passes + let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.6584407806396484"]]; + + assert_eq!(actual, expected); + let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.6584408483418833"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_cast() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----------------------------------------+", + "| CAST(aggregate_test_100.c12 AS Float32) |", + "+-----------------------------------------+", + "| 0.39144436 |", + "| 0.3887028 |", + "+-----------------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_cast_literal() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = + "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+--------------------+---------------------------+", + "| c12 | CAST(Int64(1) AS Float32) |", + "+--------------------+---------------------------+", + "| 0.9294097332465232 | 1 |", + "| 0.3114712539863804 | 1 |", + "+--------------------+---------------------------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_concat() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Int32, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["", "a", "aa", "aaa"])), + Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------------------------------+", + "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |", + "+----------------------------------------------------+", + "| -hi-0 |", + "| a-hi-1 |", + "| aa-hi- |", + "| aaa-hi-3 |", + "+----------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +// Revisit after implementing https://github.com/apache/arrow-rs/issues/925 +#[tokio::test] +async fn query_array() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Int32, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["", "a", "aa", "aaa"])), + Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), + ], + )?; + + let table = MemTable::try_new(schema, 
vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["[,0]"], + vec!["[a,1]"], + vec!["[aa,NULL]"], + vec!["[aaa,3]"], + ]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn query_count_distinct() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + Some(3), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT COUNT(DISTINCT c1) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+", + "| COUNT(DISTINCT test.c1) |", + "+-------------------------+", + "| 3 |", + "+-------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/group_by.rs b/datafusion/tests/sql/group_by.rs new file mode 100644 index 000000000000..38a0c2e44204 --- /dev/null +++ b/datafusion/tests/sql/group_by.rs @@ -0,0 +1,444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn csv_query_group_by_int_min_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+-----------------------------+", + "| c2 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", + "+----+-----------------------------+-----------------------------+", + "| 1 | 0.05636955101974106 | 0.9965400387585364 |", + "| 2 | 0.16301110515739792 | 0.991517828651004 |", + "| 3 | 0.047343434291126085 | 0.9293883502480845 |", + "| 4 | 0.02182578039211991 | 0.9237877978193884 |", + "| 5 | 0.01479305307777301 | 0.9723580396501548 |", + "+----+-----------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_float32() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = + "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----+---------+", + "| cnt | c1 |", + "+-----+---------+", + "| 5 | 0.00005 |", + "| 4 | 0.00004 |", + "| 3 | 0.00003 |", + "| 2 | 0.00002 |", + "| 1 | 0.00001 |", + "+-----+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_float64() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = + "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----+----------------+", + "| cnt | c2 |", + "+-----+----------------+", + "| 5 | 0.000000000005 |", + "| 4 | 0.000000000004 |", + "| 3 | 0.000000000003 |", + "| 2 | 0.000000000002 |", + "| 1 | 0.000000000001 |", + "+-----+----------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_boolean() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = + "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----+-------+", + "| cnt | c3 |", + "+-----+-------+", + "| 9 | true |", + "| 6 | false |", + "+-----+-------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_two_columns() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+----------------------------+", + "| c1 | c2 | MIN(aggregate_test_100.c3) |", + "+----+----+----------------------------+", + "| a | 1 | -85 |", + "| a | 2 | -48 |", + "| a | 3 | -72 |", + "| a | 4 | -101 |", + "| a | 5 | -101 |", + "| b | 1 | 12 |", + "| b | 2 | -60 |", + "| b | 3 | -101 |", + "| b | 4 | -117 |", + "| b | 5 | -82 |", + "| c | 1 | -24 |", + "| c | 2 | -117 |", + "| c | 3 | -2 |", + "| c | 4 | -90 |", + "| c | 5 | -94 |", + "| d | 1 | -99 |", + "| d | 2 | 93 |", + "| d | 3 
| -76 |", + "| d | 4 | 5 |", + "| d | 5 | -59 |", + "| e | 1 | 36 |", + "| e | 2 | -61 |", + "| e | 3 | -95 |", + "| e | 4 | -56 |", + "| e | 5 | -86 |", + "+----+----+----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_and_having() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+------+", + "| c1 | m |", + "+----+------+", + "| a | -101 |", + "| c | -117 |", + "+----+------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_and_having_and_where() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, MIN(c3) AS m + FROM aggregate_test_100 + WHERE c1 IN ('a', 'b') + GROUP BY c1 + HAVING m < -100 AND MAX(c3) > 70"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+------+", + "| c1 | m |", + "+----+------+", + "| a | -101 |", + "+----+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_having_without_group_by() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, c2, c3 FROM aggregate_test_100 HAVING c2 >= 4 AND c3 > 90"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+-----+", + "| c1 | c2 | c3 |", + "+----+----+-----+", + "| c | 4 | 123 |", + "| c | 5 | 118 |", + "| d | 4 | 102 |", + "| e | 4 | 96 |", + "| e | 4 | 97 |", + "+----+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+", + "| c1 | AVG(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+----+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_int_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-------------------------------+", + "| c1 | COUNT(aggregate_test_100.c12) |", + "+----+-------------------------------+", + "| a | 21 |", + "| b | 19 |", + "| c | 21 |", + "| d | 18 |", + "| e | 21 |", + "+----+-------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_with_aliased_aggregate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let 
expected = vec![ + "+----+-------+", + "| c1 | count |", + "+----+-------+", + "| a | 21 |", + "| b | 19 |", + "| c | 21 |", + "| d | 18 |", + "| e | 21 |", + "+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_string_min_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+-----------------------------+", + "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", + "+----+-----------------------------+-----------------------------+", + "| a | 0.02182578039211991 | 0.9800193410444061 |", + "| b | 0.04893135681998029 | 0.9185813970744787 |", + "| c | 0.0494924465469434 | 0.991517828651004 |", + "| d | 0.061029375346466685 | 0.9748360509016578 |", + "| e | 0.01479305307777301 | 0.9965400387585364 |", + "+----+-----------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_group_on_null() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(3), + None, + Some(1), + Some(3), + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; + + let actual = execute_to_batches(&mut ctx, sql).await; + + // Note that the results also + // include a row for NULL (c1=NULL, count = 1) + let expected = vec![ + "+-----------------+----+", + "| COUNT(UInt8(1)) | c1 |", + "+-----------------+----+", + "| 1 | |", + "| 1 | 0 |", + "| 1 | 1 |", + "| 2 | 3 |", + "+-----------------+----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_group_on_null_multi_col() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![ + Some(0), + Some(0), + Some(3), + None, + None, + Some(3), + Some(0), + None, + Some(3), + ])), + Arc::new(StringArray::from(vec![ + None, + None, + Some("foo"), + None, + Some("bar"), + Some("foo"), + None, + Some("bar"), + Some("foo"), + ])), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; + + let actual = execute_to_batches(&mut ctx, sql).await; + + // Note that the results also include values for null + // include a row for NULL (c1=NULL, count = 1) + let expected = vec![ + "+-----------------+----+-----+", + "| COUNT(UInt8(1)) | c1 | c2 |", + "+-----------------+----+-----+", + "| 1 | | |", + "| 2 | | bar |", + "| 3 | 0 | |", + "| 3 | 3 | foo |", + "+-----------------+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + // Also run query with group columns reversed (results should be the same) + let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; + let actual = execute_to_batches(&mut ctx, sql).await; 
+ assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_group_by_date() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("date", DataType::Date32, false), + Field::new("cnt", DataType::Int32, false), + ])); + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Date32Array::from(vec![ + Some(100), + Some(100), + Some(100), + Some(101), + Some(101), + Some(101), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(3), + Some(3), + Some(3), + ])), + ], + )?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + + ctx.register_table("dates", Arc::new(table))?; + let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------+", + "| SUM(dates.cnt) |", + "+----------------+", + "| 6 |", + "| 9 |", + "+----------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/intersection.rs b/datafusion/tests/sql/intersection.rs new file mode 100644 index 000000000000..d28dd8079fa9 --- /dev/null +++ b/datafusion/tests/sql/intersection.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn intersect_with_null_not_equal() { + let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 + INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; + + let expected = vec!["++", "++"]; + let mut ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn intersect_with_null_equal() { + let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 + INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; + + let expected = vec![ + "+-----+-----+", + "| id1 | id2 |", + "+-----+-----+", + "| | 1 |", + "+-----+-----+", + ]; + + let mut ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&mut ctx, sql).await; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn test_intersect_all() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // execute the query + let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+------------+", + "| int_col | double_col |", + "+---------+------------+", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "+---------+------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_intersect_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // execute the query + let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+------------+", + "| int_col | double_col |", + "+---------+------------+", + "| 1 | 10.1 |", + "+---------+------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs new file mode 100644 index 000000000000..1613463550f0 --- /dev/null +++ b/datafusion/tests/sql/joins.rs @@ -0,0 +1,687 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn equijoin() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + + let mut ctx = create_join_context_qualified()?; + let equivalent_sql = [ + "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", + "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", + ]; + let expected = vec![ + "+---+-----+", + "| a | b |", + "+---+-----+", + "| 1 | 100 |", + "| 2 | 200 |", + "| 4 | 400 |", + "+---+-----+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn equijoin_multiple_condition_ordering() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn equijoin_and_other_condition() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn equijoin_left_and_condition_from_right() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; + let res = ctx.create_logical_plan(sql); + assert!(res.is_ok()); + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 33 | c | |", + "| 44 | d | |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn equijoin_right_and_condition_from_left() -> Result<()> { + let mut ctx = 
create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 ORDER BY t2_name"; + let res = ctx.create_logical_plan(sql); + assert!(res.is_ok()); + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| | | w |", + "| 44 | d | x |", + "| 22 | b | y |", + "| | | z |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn equijoin_and_unsupported_condition() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= '44' ORDER BY t1_id"; + let res = ctx.create_logical_plan(sql); + + assert!(res.is_err()); + assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t1_id >= Utf8(\"44\")]"); + + Ok(()) +} + +#[tokio::test] +async fn left_join() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 33 | c | |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn left_join_unbalanced() -> Result<()> { + // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in + let mut ctx = create_join_context_unbalanced("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 33 | c | |", + "| 44 | d | x |", + "| 77 | e | |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn right_join() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "| | | w |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn full_join() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY 
t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 33 | c | |", + "| 44 | d | x |", + "| | | w |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + + Ok(()) +} + +#[tokio::test] +async fn left_join_using() -> Result<()> { + let mut ctx = create_join_context("id", "id")?; + let sql = "SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+---------+---------+", + "| id | t1_name | t2_name |", + "+----+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 33 | c | |", + "| 44 | d | x |", + "+----+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn equijoin_implicit_syntax() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + for sql in equivalent_sql.iter() { + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn equijoin_implicit_syntax_with_filter() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = "SELECT t1_id, t1_name, t2_name \ + FROM t1, t2 \ + WHERE t1_id > 0 \ + AND t1_id = t2_id \ + AND t2_id < 99 \ + ORDER BY t1_id"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn equijoin_implicit_syntax_reversed() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 22 | b | y |", + "| 44 | d | x |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn cross_join() { + let mut ctx = create_join_context("t1_id", "t2_id").unwrap(); + + let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id"; + let actual = execute(&mut ctx, sql).await; + + assert_eq!(4 * 4, 
actual.len()); + + let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id"; + let actual = execute(&mut ctx, sql).await; + + assert_eq!(4 * 4, actual.len()); + + let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2"; + + let actual = execute(&mut ctx, sql).await; + assert_eq!(4 * 4, actual.len()); + + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 11 | a | y |", + "| 11 | a | x |", + "| 11 | a | w |", + "| 22 | b | z |", + "| 22 | b | y |", + "| 22 | b | x |", + "| 22 | b | w |", + "| 33 | c | z |", + "| 33 | c | y |", + "| 33 | c | x |", + "| 33 | c | w |", + "| 44 | d | z |", + "| 44 | d | y |", + "| 44 | d | x |", + "| 44 | d | w |", + "+-------+---------+---------+", + ]; + + assert_batches_eq!(expected, &actual); + + // Two partitions (from UNION) on the left + let sql = "SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2"; + let actual = execute(&mut ctx, sql).await; + + assert_eq!(4 * 4 * 2, actual.len()); + + // Two partitions (from UNION) on the right + let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2"; + let actual = execute(&mut ctx, sql).await; + + assert_eq!(4 * 4 * 2, actual.len()); +} + +#[tokio::test] +async fn cross_join_unbalanced() { + // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in + let mut ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); + + // the order of the values is not determinisitic, so we need to sort to check the values + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+---------+---------+", + "| t1_id | t1_name | t2_name |", + "+-------+---------+---------+", + "| 11 | a | z |", + "| 11 | a | y |", + "| 11 | a | x |", + "| 11 | a | w |", + "| 22 | b | z |", + "| 22 | b | y |", + "| 22 | b | x |", + "| 22 | b | w |", + "| 33 | c | z |", + "| 33 | c | y |", + "| 33 | c | x |", + "| 33 | c | w |", + "| 44 | d | z |", + "| 44 | d | y |", + "| 44 | d | x |", + "| 44 | d | w |", + "| 77 | e | z |", + "| 77 | e | y |", + "| 77 | e | x |", + "| 77 | e | w |", + "+-------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn test_join_timestamp() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + // register time table + let timestamp_schema = Arc::new(Schema::new(vec![Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )])); + let timestamp_data = RecordBatch::try_new( + timestamp_schema.clone(), + vec![Arc::new(TimestampNanosecondArray::from(vec![ + 131964190213133, + 131964190213134, + 131964190213135, + ]))], + )?; + let timestamp_table = + MemTable::try_new(timestamp_schema, vec![vec![timestamp_data]])?; + ctx.register_table("timestamp", Arc::new(timestamp_table))?; + + let sql = "SELECT * \ + FROM timestamp as a \ + JOIN (SELECT * FROM timestamp) as b \ + ON a.time = b.time \ + ORDER BY a.time"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------------------------------+-------------------------------+", + "| time | time |", + "+-------------------------------+-------------------------------+", + "| 1970-01-02 
12:39:24.190213133 | 1970-01-02 12:39:24.190213133 |", + "| 1970-01-02 12:39:24.190213134 | 1970-01-02 12:39:24.190213134 |", + "| 1970-01-02 12:39:24.190213135 | 1970-01-02 12:39:24.190213135 |", + "+-------------------------------+-------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_join_float32() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + // register population table + let population_schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, true), + Field::new("population", DataType::Float32, true), + ])); + let population_data = RecordBatch::try_new( + population_schema.clone(), + vec![ + Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])), + Arc::new(Float32Array::from(vec![838.698, 1778.934, 626.443])), + ], + )?; + let population_table = + MemTable::try_new(population_schema, vec![vec![population_data]])?; + ctx.register_table("population", Arc::new(population_table))?; + + let sql = "SELECT * \ + FROM population as a \ + JOIN (SELECT * FROM population) as b \ + ON a.population = b.population \ + ORDER BY a.population"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+------+------------+------+------------+", + "| city | population | city | population |", + "+------+------------+------+------------+", + "| c | 626.443 | c | 626.443 |", + "| a | 838.698 | a | 838.698 |", + "| b | 1778.934 | b | 1778.934 |", + "+------+------------+------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_join_float64() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + // register population table + let population_schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, true), + Field::new("population", DataType::Float64, true), + ])); + let population_data = RecordBatch::try_new( + population_schema.clone(), + vec![ + Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])), + Arc::new(Float64Array::from(vec![838.698, 1778.934, 626.443])), + ], + )?; + let population_table = + MemTable::try_new(population_schema, vec![vec![population_data]])?; + ctx.register_table("population", Arc::new(population_table))?; + + let sql = "SELECT * \ + FROM population as a \ + JOIN (SELECT * FROM population) as b \ + ON a.population = b.population \ + ORDER BY a.population"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+------+------------+------+------------+", + "| city | population | city | population |", + "+------+------------+------+------------+", + "| c | 626.443 | c | 626.443 |", + "| a | 838.698 | a | 838.698 |", + "| b | 1778.934 | b | 1778.934 |", + "+------+------------+------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO Tests to prove correct implementation of INNER JOIN's with qualified names. +// https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. +#[tokio::test] +#[ignore] +async fn inner_join_qualified_names() -> Result<()> { + // Setup the statements that test qualified names function correctly. 
+ let equivalent_sql = [ + "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c + FROM t1 + INNER JOIN t2 ON t1.a = t2.a + ORDER BY t1.a", + "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c + FROM t1 + INNER JOIN t2 ON t2.a = t1.a + ORDER BY t1.a", + ]; + + let expected = vec![ + "+---+----+----+---+-----+-----+", + "| a | b | c | a | b | c |", + "+---+----+----+---+-----+-----+", + "| 1 | 10 | 50 | 1 | 100 | 500 |", + "| 2 | 20 | 60 | 2 | 200 | 600 |", + "| 4 | 40 | 80 | 4 | 400 | 800 |", + "+---+----+----+---+-----+-----+", + ]; + + for sql in equivalent_sql.iter() { + let mut ctx = create_join_context_qualified()?; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn inner_join_nulls() { + let sql = "SELECT * FROM (SELECT null AS id1) t1 + INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; + + let expected = vec!["++", "++"]; + + let mut ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&mut ctx, sql).await; + + // left and right shouldn't match anything + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Result<()> { + let batch = RecordBatch::try_from_iter(vec![ + ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), + ( + "country", + Arc::new(StringArray::from(vec!["Germany", "Sweden", "Japan"])) as _, + ), + ]) + .unwrap(); + let countries = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + + let batch = RecordBatch::try_from_iter(vec![ + ( + "id", + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])) as _, + ), + ( + "city", + Arc::new(StringArray::from(vec![ + "Hamburg", + "Stockholm", + "Osaka", + "Berlin", + "Göteborg", + "Tokyo", + "Kyoto", + ])) as _, + ), + ( + "country_id", + Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3, 3])) as _, + ), + ]) + .unwrap(); + let cities = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("countries", Arc::new(countries))?; + ctx.register_table("cities", Arc::new(cities))?; + + // city.id is not in the on constraint, but the output result will contain both city.id and + // country.id + let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+-----------+---------+", + "| id | id | city | country |", + "+----+----+-----------+---------+", + "| 1 | 1 | Hamburg | Germany |", + "| 2 | 2 | Stockholm | Sweden |", + "| 3 | 3 | Osaka | Japan |", + "| 4 | 1 | Berlin | Germany |", + "| 5 | 2 | Göteborg | Sweden |", + "| 6 | 3 | Tokyo | Japan |", + "| 7 | 3 | Kyoto | Japan |", + "+----+----+-----------+---------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/limit.rs b/datafusion/tests/sql/limit.rs new file mode 100644 index 000000000000..fd68e330bee1 --- /dev/null +++ b/datafusion/tests/sql/limit.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn csv_query_limit() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+----+", "| c1 |", "+----+", "| c |", "| d |", "+----+"]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; + let actual = execute_to_batches(&mut ctx, sql).await; + // println!("{}", pretty_format_batches(&a).unwrap()); + let expected = vec![ + "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", + "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", + "| 2 |", "| 1 |", "| 1 |", "| 2 |", "| 1 |", "| 3 |", "| 2 |", "| 4 |", + "| 1 |", "| 5 |", "| 4 |", "| 2 |", "| 1 |", "| 4 |", "| 5 |", "| 2 |", + "| 3 |", "| 4 |", "| 2 |", "| 1 |", "| 5 |", "| 3 |", "| 1 |", "| 2 |", + "| 3 |", "| 3 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", "| 2 |", + "| 5 |", "| 2 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 2 |", "| 5 |", + "| 4 |", "| 2 |", "| 3 |", "| 4 |", "| 4 |", "| 4 |", "| 5 |", "| 4 |", + "| 2 |", "| 1 |", "| 2 |", "| 4 |", "| 2 |", "| 3 |", "| 5 |", "| 1 |", + "| 1 |", "| 4 |", "| 2 |", "| 1 |", "| 2 |", "| 1 |", "| 1 |", "| 5 |", + "| 4 |", "| 5 |", "| 2 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", + "| 4 |", "| 3 |", "| 2 |", "| 5 |", "| 3 |", "| 3 |", "| 2 |", "| 5 |", + "| 5 |", "| 4 |", "| 1 |", "| 3 |", "| 3 |", "| 4 |", "| 4 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", + "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", + "| 2 |", "| 1 |", "| 1 |", "| 2 |", "| 1 |", "| 3 |", "| 2 |", "| 4 |", + "| 1 |", "| 5 |", "| 4 |", "| 2 |", "| 1 |", "| 4 |", "| 5 |", "| 2 |", + "| 3 |", "| 4 |", "| 2 |", "| 1 |", "| 5 |", "| 3 |", "| 1 |", "| 2 |", + "| 3 |", "| 3 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", "| 2 |", + "| 5 |", "| 2 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 2 |", "| 5 |", + "| 4 |", "| 2 |", "| 3 |", "| 4 |", "| 4 |", "| 4 |", "| 5 |", "| 4 |", + "| 2 |", "| 1 |", "| 2 |", "| 4 |", "| 2 |", "| 3 |", "| 5 |", "| 1 |", + "| 1 |", "| 4 |", "| 2 |", "| 1 |", "| 2 |", "| 1 |", "| 1 |", "| 5 |", + "| 4 |", "| 5 |", "| 2 |", "| 3 |", "| 2 |", "| 4 |", "| 1 |", "| 3 |", + "| 4 |", "| 3 |", "| 2 |", "| 5 |", "| 3 |", "| 3 |", "| 2 |", "| 5 |", + "| 5 |", "| 4 |", "| 1 |", "| 3 |", "| 3 |", "| 4 |", "| 4 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} 
+ +#[tokio::test] +async fn csv_query_limit_zero() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs new file mode 100644 index 000000000000..3cc129e73115 --- /dev/null +++ b/datafusion/tests/sql/mod.rs @@ -0,0 +1,726 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::convert::TryFrom; +use std::sync::Arc; + +use arrow::{ + array::*, datatypes::*, record_batch::RecordBatch, + util::display::array_value_to_string, +}; +use chrono::prelude::*; +use chrono::Duration; + +use datafusion::assert_batches_eq; +use datafusion::assert_batches_sorted_eq; +use datafusion::assert_contains; +use datafusion::assert_not_contains; +use datafusion::logical_plan::plan::{Aggregate, Projection}; +use datafusion::logical_plan::LogicalPlan; +use datafusion::logical_plan::TableScan; +use datafusion::physical_plan::functions::Volatility; +use datafusion::physical_plan::metrics::MetricValue; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::ExecutionPlanVisitor; +use datafusion::prelude::*; +use datafusion::test_util; +use datafusion::{datasource::MemTable, physical_plan::collect}; +use datafusion::{ + error::{DataFusionError, Result}, + physical_plan::ColumnarValue, +}; +use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; + +/// A macro to assert that some particular line contains two substrings +/// +/// Usage: `assert_metrics!(actual, operator_name, metrics)` +/// +macro_rules! assert_metrics { + ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => { + let found = $ACTUAL + .lines() + .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS)); + assert!( + found, + "Can not find a line with both '{}' and '{}' in\n\n{}", + $OPERATOR_NAME, $METRICS, $ACTUAL + ); + }; +} + +macro_rules! 
test_expression { + ($SQL:expr, $EXPECTED:expr) => { + let mut ctx = ExecutionContext::new(); + let sql = format!("SELECT {}", $SQL); + let actual = execute(&mut ctx, sql.as_str()).await; + assert_eq!(actual[0][0], $EXPECTED); + }; +} + +pub mod aggregates; +#[cfg(feature = "avro")] +pub mod avro; +pub mod create_drop; +pub mod errors; +pub mod explain_analyze; +pub mod expr; +pub mod functions; +pub mod group_by; +pub mod intersection; +pub mod joins; +pub mod limit; +pub mod order; +pub mod parquet; +pub mod predicates; +pub mod projection; +pub mod references; +pub mod select; +pub mod timestamp; +pub mod udf; +pub mod union; +pub mod window; + +#[cfg_attr(not(feature = "unicode_expressions"), ignore)] +pub mod unicode; + +fn assert_float_eq<T>(expected: &[Vec<T>], received: &[Vec<String>]) +where + T: AsRef<str>, +{ + expected + .iter() + .flatten() + .zip(received.iter().flatten()) + .for_each(|(l, r)| { + let (l, r) = ( + l.as_ref().parse::<f64>().unwrap(), + r.as_str().parse::<f64>().unwrap(), + ); + assert!((l - r).abs() <= 2.0 * f64::EPSILON); + }); +} + +#[allow(clippy::unnecessary_wraps)] +fn create_ctx() -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + + // register a custom UDF + ctx.register_udf(create_udf( + "custom_sqrt", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(custom_sqrt), + )); + + Ok(ctx) +} + +fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> { + let arg = &args[0]; + if let ColumnarValue::Array(v) = arg { + let input = v + .as_any() + .downcast_ref::<Float64Array>() + .expect("cast failed"); + + let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); + Ok(ColumnarValue::Array(Arc::new(array))) + } else { + unimplemented!() + } +} + +fn create_case_context() -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + None, + ]))], + )?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + ctx.register_table("t1", Arc::new(table))?; + Ok(ctx) +} + +fn create_join_context( + column_left: &str, + column_right: &str, +) -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new(column_left, DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + ])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new(column_right, DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), + Arc::new(StringArray::from(vec![ + Some("z"), + Some("y"), + Some("x"), + Some("w"), + ])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table))?; + + Ok(ctx) +} + +fn create_join_context_qualified() -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, true), + Field::new("b", DataType::UInt32, true), + Field::new("c", DataType::UInt32, 
true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), + Arc::new(UInt32Array::from(vec![10, 20, 30, 40])), + Arc::new(UInt32Array::from(vec![50, 60, 70, 80])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, true), + Field::new("b", DataType::UInt32, true), + Field::new("c", DataType::UInt32, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 9, 4])), + Arc::new(UInt32Array::from(vec![100, 200, 300, 400])), + Arc::new(UInt32Array::from(vec![500, 600, 700, 800])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table))?; + + Ok(ctx) +} + +/// the table column_left has more rows than the table column_right +fn create_join_context_unbalanced( + column_left: &str, + column_right: &str, +) -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new(column_left, DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 33, 44, 77])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new(column_right, DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), + Arc::new(StringArray::from(vec![ + Some("z"), + Some("y"), + Some("x"), + Some("w"), + ])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table))?; + + Ok(ctx) +} + +fn get_tpch_table_schema(table: &str) -> Schema { + match table { + "customer" => Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]), + + "orders" => Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Float64, false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]), + + "lineitem" => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Float64, false), + Field::new("l_extendedprice", DataType::Float64, false), + 
Field::new("l_discount", DataType::Float64, false), + Field::new("l_tax", DataType::Float64, false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]), + + "nation" => Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]), + + _ => unimplemented!(), + } +} + +async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<()> { + let schema = get_tpch_table_schema(table); + + ctx.register_csv( + table, + format!("tests/tpch-csv/{}.csv", table).as_str(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + +async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { + let testdata = datafusion::test_util::arrow_test_data(); + + // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once + // unsigned is supported. + let df = ctx + .sql(&format!( + " + CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 INT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INT NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) + STORED AS CSV + WITH HEADER ROW + LOCATION '{}/csv/aggregate_test_100.csv' + ", + testdata + )) + .await + .expect("Creating dataframe for CREATE EXTERNAL TABLE"); + + // Mimic the CLI and execute the resulting plan -- even though it + // is effectively a no-op (returns zero rows) + let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE"); + assert!( + results.is_empty(), + "Expected no rows from executing CREATE EXTERNAL TABLE" + ); +} + +/// Create table "t1" with two boolean columns "a" and "b" +async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { + let a: BooleanArray = [ + Some(true), + Some(true), + Some(true), + None, + None, + None, + Some(false), + Some(false), + Some(false), + ] + .iter() + .collect(); + let b: BooleanArray = [ + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + ] + .iter() + .collect(); + + let data = + RecordBatch::try_from_iter([("a", Arc::new(a) as _), ("b", Arc::new(b) as _)])?; + let table = MemTable::try_new(data.schema(), vec![vec![data]])?; + ctx.register_table("t1", Arc::new(table))?; + Ok(()) +} + +async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> { + // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Float32, false), + Field::new("c2", DataType::Float64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + ctx.register_csv( + "aggregate_simple", + "tests/aggregate_simple.csv", + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + +async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + let schema = 
test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{}/csv/aggregate_test_100.csv", testdata), + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + +/// Execute query and return the resulting record batches +async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec<RecordBatch> { + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let logical_schema = plan.schema(); + + let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); + let optimized_logical_schema = plan.schema(); + + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).await.expect(&msg); + + let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); + let results = collect(plan).await.expect(&msg); + + assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); + results +} + +/// Execute query and return result set as 2-d table of Vecs +/// `result[row][column]` +async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> { + result_vec(&execute_to_batches(ctx, sql).await) +} + +/// Specialised String representation +fn col_str(column: &ArrayRef, row_index: usize) -> String { + if column.is_null(row_index) { + return "NULL".to_string(); + } + + // Special case ListArray as there is no pretty print support for it yet + if let DataType::FixedSizeList(_, n) = column.data_type() { + let array = column + .as_any() + .downcast_ref::<FixedSizeListArray>() + .unwrap() + .value(row_index); + + let mut r = Vec::with_capacity(*n as usize); + for i in 0..*n { + r.push(col_str(&array, i as usize)); + } + return format!("[{}]", r.join(",")); + } + + array_value_to_string(column, row_index) + .ok() + .unwrap_or_else(|| "???".to_string()) +} + +/// Converts the results into a 2d array of strings, `result[row][column]` +/// Special cases nulls to NULL for testing +fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> { + let mut result = vec![]; + for batch in results { + for row_index in 0..batch.num_rows() { + let row_vec = batch + .columns() + .iter() + .map(|column| col_str(column, row_index)) + .collect(); + result.push(row_vec); + } + } + result +} + +async fn generic_query_length<T: 'static + Array + From<Vec<&'static str>>>( + datatype: DataType, +) -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", datatype, false)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(T::from(vec!["", "a", "aa", "aaa"]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT length(c1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; + assert_eq!(expected, actual); + Ok(()) +} + +async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) { + let df = ctx + .sql( + "CREATE EXTERNAL TABLE aggregate_simple ( + c1 DECIMAL(10,6) NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL + ) + STORED AS CSV + WITH HEADER ROW + LOCATION 'tests/aggregate_simple.csv'", + ) + .await + .expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type"); + + let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE"); + assert!( + results.is_empty(), + "Expected no rows from executing CREATE EXTERNAL TABLE" + ); +} + +async fn 
register_alltypes_parquet(ctx: &mut ExecutionContext) { + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{}/alltypes_plain.parquet", testdata), + ) + .await + .unwrap(); +} + +fn make_timestamp_table<A>() -> Result<Arc<MemTable>> +where + A: ArrowTimestampType, +{ + make_timestamp_tz_table::<A>(None) +} + +fn make_timestamp_tz_table<A>(tz: Option<String>) -> Result<Arc<MemTable>> +where + A: ArrowTimestampType, +{ + let schema = Arc::new(Schema::new(vec![ + Field::new( + "ts", + DataType::Timestamp(A::get_time_unit(), tz.clone()), + false, + ), + Field::new("value", DataType::Int32, true), + ])); + + let divisor = match A::get_time_unit() { + TimeUnit::Nanosecond => 1, + TimeUnit::Microsecond => 1000, + TimeUnit::Millisecond => 1_000_000, + TimeUnit::Second => 1_000_000_000, + }; + + let timestamps = vec![ + 1599572549190855000i64 / divisor, // 2020-09-08T13:42:29.190855+00:00 + 1599568949190855000 / divisor, // 2020-09-08T12:42:29.190855+00:00 + 1599565349190855000 / divisor, // 2020-09-08T11:42:29.190855+00:00 + ]; + + let array = PrimitiveArray::<A>::from_vec(timestamps, tz); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(array), + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), + ], + )?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + Ok(Arc::new(table)) +} + +fn make_timestamp_nano_table() -> Result<Arc<MemTable>> { + make_timestamp_table::<TimestampNanosecondType>() +} + +// Normalizes parts of an explain plan that vary from run to run (such as path) +fn normalize_for_explain(s: &str) -> String { + // Convert things like /Users/alamb/Software/arrow/testing/data/csv/aggregate_test_100.csv + // to ARROW_TEST_DATA/csv/aggregate_test_100.csv + let data_path = datafusion::test_util::arrow_test_data(); + let s = s.replace(&data_path, "ARROW_TEST_DATA"); + + // convert things like partitioning=RoundRobinBatch(16) + // to partitioning=RoundRobinBatch(NUM_CORES) + let needle = format!("RoundRobinBatch({})", num_cpus::get()); + s.replace(&needle, "RoundRobinBatch(NUM_CORES)") +} + +/// Applies normalize_for_explain to every line +fn normalize_vec_for_explain(v: Vec<Vec<String>>) -> Vec<Vec<String>> { + v.into_iter() + .map(|l| { + l.into_iter() + .map(|s| normalize_for_explain(&s)) + .collect::<Vec<_>>() + }) + .collect::<Vec<_>>() +} + +#[tokio::test] +async fn nyc() -> Result<()> { + // schema for nyc taxi csv files + let schema = Schema::new(vec![ + Field::new("VendorID", DataType::Utf8, true), + Field::new("tpep_pickup_datetime", DataType::Utf8, true), + Field::new("tpep_dropoff_datetime", DataType::Utf8, true), + Field::new("passenger_count", DataType::Utf8, true), + Field::new("trip_distance", DataType::Float64, true), + Field::new("RatecodeID", DataType::Utf8, true), + Field::new("store_and_fwd_flag", DataType::Utf8, true), + Field::new("PULocationID", DataType::Utf8, true), + Field::new("DOLocationID", DataType::Utf8, true), + Field::new("payment_type", DataType::Utf8, true), + Field::new("fare_amount", DataType::Float64, true), + Field::new("extra", DataType::Float64, true), + Field::new("mta_tax", DataType::Float64, true), + Field::new("tip_amount", DataType::Float64, true), + Field::new("tolls_amount", DataType::Float64, true), + Field::new("improvement_surcharge", DataType::Float64, true), + Field::new("total_amount", DataType::Float64, true), + ]); + + let mut ctx = ExecutionContext::new(); + ctx.register_csv( + "tripdata", + "file.csv", + CsvReadOptions::new().schema(&schema), + ) + .await?; + + let logical_plan = ctx.create_logical_plan( + "SELECT 
passenger_count, MIN(fare_amount), MAX(fare_amount) \ + FROM tripdata GROUP BY passenger_count", + )?; + + let optimized_plan = ctx.optimize(&logical_plan)?; + + match &optimized_plan { + LogicalPlan::Projection(Projection { input, .. }) => match input.as_ref() { + LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() { + LogicalPlan::TableScan(TableScan { + ref projected_schema, + .. + }) => { + assert_eq!(2, projected_schema.fields().len()); + assert_eq!(projected_schema.field(0).name(), "passenger_count"); + assert_eq!(projected_schema.field(1).name(), "fare_amount"); + } + _ => unreachable!(), + }, + _ => unreachable!(), + }, + _ => unreachable!(false), + } + + Ok(()) +} diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs new file mode 100644 index 000000000000..fa59d9d19661 --- /dev/null +++ b/datafusion/tests/sql/order.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn test_sort_unprojected_col() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // execute the query + let sql = "SELECT id FROM alltypes_plain ORDER BY int_col, double_col"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+", "| id |", "+----+", "| 4 |", "| 6 |", "| 2 |", "| 0 |", "| 5 |", + "| 7 |", "| 3 |", "| 1 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_order_by_agg_expr() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------------------+", + "| MIN(aggregate_test_100.c12) |", + "+-----------------------------+", + "| 0.01479305307777301 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1"; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_nulls_first_asc() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----+--------+", + "| num | letter |", + "+-----+--------+", + "| 1 | one |", + "| 2 | two |", + "| | three |", + "+-----+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_nulls_first_desc() -> Result<()> { + let mut ctx = ExecutionContext::new(); 
+ let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----+--------+", + "| num | letter |", + "+-----+--------+", + "| | three |", + "| 2 | two |", + "| 1 | one |", + "+-----+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_specific_nulls_last_desc() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC NULLS LAST"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----+--------+", + "| num | letter |", + "+-----+--------+", + "| 2 | two |", + "| 1 | one |", + "| | three |", + "+-----+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_specific_nulls_first_asc() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num ASC NULLS FIRST"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----+--------+", + "| num | letter |", + "+-----+--------+", + "| | three |", + "| 1 | one |", + "| 2 | two |", + "+-----+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs new file mode 100644 index 000000000000..b4f08d143963 --- /dev/null +++ b/datafusion/tests/sql/parquet.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn parquet_query() { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // NOTE that string_col is actually a binary column and does not have the UTF8 logical type + // so we need an explicit cast + let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-----------------------------------------+", + "| id | CAST(alltypes_plain.string_col AS Utf8) |", + "+----+-----------------------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+-----------------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn parquet_single_nan_schema() { + let mut ctx = ExecutionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) + .await + .unwrap(); + let sql = "SELECT mycol FROM single_nan"; + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).await.unwrap(); + let results = collect(plan).await.unwrap(); + for batch in results { + assert_eq!(1, batch.num_rows()); + assert_eq!(1, batch.num_columns()); + } +} + +#[tokio::test] +#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] +async fn parquet_list_columns() { + let mut ctx = ExecutionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "list_columns", + &format!("{}/list_columns.parquet", testdata), + ) + .await + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "int64_list", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + true, + ), + Field::new( + "utf8_list", + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), + true, + ), + ])); + + let sql = "SELECT int64_list, utf8_list FROM list_columns"; + let plan = ctx.create_logical_plan(sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan).await.unwrap(); + let results = collect(plan).await.unwrap(); + + // int64_list utf8_list + // 0 [1, 2, 3] [abc, efg, hij] + // 1 [None, 1] None + // 2 [4] [efg, None, hij, xyz] + + assert_eq!(1, results.len()); + let batch = &results[0]; + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + assert_eq!(schema, batch.schema()); + + let int_list_array = batch + .column(0) + .as_any() + .downcast_ref::<ListArray>() + .unwrap(); + let utf8_list_array = batch + .column(1) + .as_any() + .downcast_ref::<ListArray>() + .unwrap(); + + assert_eq!( + int_list_array + .value(0) + .as_any() + .downcast_ref::<PrimitiveArray<Int64Type>>() + .unwrap(), + &PrimitiveArray::<Int64Type>::from(vec![Some(1), Some(2), Some(3),]) + ); + + assert_eq!( + utf8_list_array + .value(0) + .as_any() + .downcast_ref::<StringArray>() + .unwrap(), + &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + ); + + assert_eq!( + int_list_array + .value(1) + .as_any() + .downcast_ref::<PrimitiveArray<Int64Type>>() + .unwrap(), + &PrimitiveArray::<Int64Type>::from(vec![None, Some(1),]) + ); + + assert!(utf8_list_array.is_null(1)); + + assert_eq!( + int_list_array + .value(2) + .as_any() + .downcast_ref::<PrimitiveArray<Int64Type>>() + .unwrap(), + &PrimitiveArray::<Int64Type>::from(vec![Some(4),]) + ); + + let result = utf8_list_array.value(2); + let result = result.as_any().downcast_ref::<StringArray>().unwrap(); + + 
assert_eq!(result.value(0), "efg"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "hij"); + assert_eq!(result.value(3), "xyz"); +} diff --git a/datafusion/tests/sql/predicates.rs b/datafusion/tests/sql/predicates.rs new file mode 100644 index 000000000000..f4e1f4f4deef --- /dev/null +++ b/datafusion/tests/sql/predicates.rs @@ -0,0 +1,371 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn csv_query_with_predicate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+---------------------+", + "| c1 | c12 |", + "+----+---------------------+", + "| e | 0.39144436569161134 |", + "| d | 0.38870280983958583 |", + "+----+---------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_negative_predicate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+--------+", + "| c1 | c4 |", + "+----+--------+", + "| e | -31500 |", + "| c | -30187 |", + "+----+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_negated_predicate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 21 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_is_not_null_predicate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NOT NULL"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 100 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_is_null_predicate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + 
"+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 0 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_where_neg_num() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + + // Negative numbers do not parse correctly as of Arrow 2.0.0 + let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2 and c7 < 10"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+-------+", + "| c7 | c8 |", + "+----+-------+", + "| 7 | 45465 |", + "| 5 | 40622 |", + "| 0 | 61069 |", + "| 2 | 20120 |", + "| 4 | 39363 |", + "+----+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // Also check floating point neg numbers + let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2.9 and c7 < 10"; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn like() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; + // check that the physical and logical schemas are equal + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------------------------------+", + "| COUNT(aggregate_test_100.c1) |", + "+------------------------------+", + "| 1 |", + "+------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_between_expr() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 BETWEEN 0.995 AND 1.0"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c4 |", + "+-------+", + "| 10837 |", + "+-------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_between_expr_negated() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 NOT BETWEEN 0 AND 0.995"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c4 |", + "+-------+", + "| 10837 |", + "+-------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn like_on_strings() -> Result<()> { + let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")] + .into_iter() + .collect::(); + + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + + let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| bar |", + "| fazzz |", + "+-------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn like_on_string_dictionaries() -> Result<()> { + let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")] + .into_iter() + .collect::>(); + + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + let mut ctx = ExecutionContext::new(); + 
ctx.register_table("test", Arc::new(table))?; + + let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| bar |", + "| fazzz |", + "+-------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_regexp_is_match() -> Result<()> { + let input = vec![Some("foo"), Some("Barrr"), Some("Bazzz"), Some("ZZZZZ")] + .into_iter() + .collect::(); + + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + + let sql = "SELECT * FROM test WHERE c1 ~ 'z'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| Bazzz |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 ~* 'z'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| Bazzz |", + "| ZZZZZ |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 !~ 'z'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| foo |", + "| Barrr |", + "| ZZZZZ |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT * FROM test WHERE c1 !~* 'z'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| c1 |", + "+-------+", + "| foo |", + "| Barrr |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn except_with_null_not_equal() { + let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 + EXCEPT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; + + let expected = vec![ + "+-----+-----+", + "| id1 | id2 |", + "+-----+-----+", + "| | 1 |", + "+-----+-----+", + ]; + + let mut ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&mut ctx, sql).await; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn except_with_null_equal() { + let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 + EXCEPT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; + + let expected = vec!["++", "++"]; + let mut ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&mut ctx, sql).await; + + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn test_expect_all() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // execute the query + let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+------------+", + "| int_col | double_col |", + "+---------+------------+", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "+---------+------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_expect_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_alltypes_parquet(&mut ctx).await; + // execute the query + let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 
0 EXCEPT SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+------------+", + "| int_col | double_col |", + "+---------+------------+", + "| 1 | 10.1 |", + "+---------+------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/projection.rs b/datafusion/tests/sql/projection.rs new file mode 100644 index 000000000000..57fa598bb754 --- /dev/null +++ b/datafusion/tests/sql/projection.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn projection_same_fields() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let sql = "select (1+1) as a from (select 1 as a) as b;"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec!["+---+", "| a |", "+---+", "| 2 |", "+---+"]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn projection_type_alias() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + // Query that aliases one column to the name of a different column + // that also has a different type (c1 == float32, c3 == boolean) + let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c3 |", + "+---------+", + "| 0.00001 |", + "| 0.00002 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_query_group_by_avg_with_projection() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------------------+----+", + "| AVG(aggregate_test_100.c12) | c1 |", + "+-----------------------------+----+", + "| 0.41040709263815384 | b |", + "| 0.48600669271341534 | e |", + "| 0.48754517466109415 | a |", + "| 0.48855379387549824 | d |", + "| 0.6600456536439784 | c |", + "+-----------------------------+----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/references.rs b/datafusion/tests/sql/references.rs new file mode 100644 index 000000000000..779c6a336673 --- /dev/null +++ b/datafusion/tests/sql/references.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn qualified_table_references() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + + for table_ref in &[ + "aggregate_test_100", + "public.aggregate_test_100", + "datafusion.public.aggregate_test_100", + ] { + let sql = format!("SELECT COUNT(*) FROM {}", table_ref); + let actual = execute_to_batches(&mut ctx, &sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 100 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn qualified_table_references_and_fields() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let c1: StringArray = vec!["foofoo", "foobar", "foobaz"] + .into_iter() + .map(Some) + .collect(); + let c2: Int64Array = vec![1, 2, 3].into_iter().map(Some).collect(); + let c3: Int64Array = vec![10, 20, 30].into_iter().map(Some).collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("f.c1", Arc::new(c1) as ArrayRef), + // evil -- use the same name as the table + ("test.c2", Arc::new(c2) as ArrayRef), + // more evil still + ("....", Arc::new(c3) as ArrayRef), + ])?; + + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("test", Arc::new(table))?; + + // referring to the unquoted column is an error + let sql = r#"SELECT f1.c1 from test"#; + let error = ctx.create_logical_plan(sql).unwrap_err(); + assert_contains!( + error.to_string(), + "No field named 'f1.c1'. Valid fields are 'test.f.c1', 'test.test.c2'" + ); + + // however, enclosing it in double quotes is ok + let sql = r#"SELECT "f.c1" from test"#; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------+", + "| f.c1 |", + "+--------+", + "| foofoo |", + "| foobar |", + "| foobaz |", + "+--------+", + ]; + assert_batches_eq!(expected, &actual); + // Works fully qualified too + let sql = r#"SELECT test."f.c1" from test"#; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + + // check that duplicated table name and column name are ok + let sql = r#"SELECT "test.c2" as expr1, test."test.c2" as expr2 from test"#; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+-------+", + "| expr1 | expr2 |", + "+-------+-------+", + "| 1 | 1 |", + "| 2 | 2 |", + "| 3 | 3 |", + "+-------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // check that '....' is also an ok column name (in the sense that + // datafusion should run the query, not that someone should write + // this + let sql = r#"SELECT "....", "...." as c3 from test order by "....""#; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------+----+", + "| .... 
| c3 |", + "+------+----+", + "| 10 | 10 |", + "| 20 | 20 |", + "| 30 | 30 |", + "+------+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_partial_qualified_name() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = "SELECT t1.t1_id, t1_name FROM public.t1"; + let expected = vec![ + "+-------+---------+", + "| t1_id | t1_name |", + "+-------+---------+", + "| 11 | a |", + "| 22 | b |", + "| 33 | c |", + "| 44 | d |", + "+-------+---------+", + ]; + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs new file mode 100644 index 000000000000..8d0d12f18d1e --- /dev/null +++ b/datafusion/tests/sql/select.rs @@ -0,0 +1,856 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn all_where_empty() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT * + FROM aggregate_test_100 + WHERE 1=2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["++", "++"]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn select_values_list() -> Result<()> { + let mut ctx = ExecutionContext::new(); + { + let sql = "VALUES (1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+", + "| column1 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (-1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+", + "| column1 |", + "+---------+", + "| -1 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (2+1,2-1,2>1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+---------+", + "| column1 | column2 | column3 |", + "+---------+---------+---------+", + "| 3 | 1 | true |", + "+---------+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES ()"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1),(2)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+", + "| column1 |", + "+---------+", + "| 1 |", + "| 2 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (1),()"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1,'a'),(2,'b')"; + let 
actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| 1 | a |", + "| 2 | b |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (1),(1,2)"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1),('2')"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1),(2.0)"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1,2), (1,'2')"; + let plan = ctx.create_logical_plan(sql); + assert!(plan.is_err()); + } + { + let sql = "VALUES (1,'a'),(NULL,'b'),(3,'c')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| 1 | a |", + "| | b |", + "| 3 | c |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (NULL,'a'),(NULL,'b'),(3,'c')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| | a |", + "| | b |", + "| 3 | c |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (NULL,'a'),(NULL,'b'),(NULL,'c')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| | a |", + "| | b |", + "| | c |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (1,'a'),(2,NULL),(3,'c')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| 1 | a |", + "| 2 | |", + "| 3 | c |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (1,NULL),(2,NULL),(3,'c')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+", + "| column1 | column2 |", + "+---------+---------+", + "| 1 | |", + "| 2 | |", + "| 3 | c |", + "+---------+---------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", + "| column1 | column2 | column3 | column4 | column5 | column6 | column7 | column8 | column9 | column10 | column11 | column12 | column13 | column14 | column15 | column16 |", + "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", + "| 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | | F | 3.5 |", + "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "SELECT * FROM (VALUES (1,'a'),(2,NULL)) AS t(c1, c2)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | a |", + 
"| 2 | |", + "+----+----+", + ]; + assert_batches_eq!(expected, &actual); + } + { + let sql = "EXPLAIN VALUES (1, 'a', -1, 1.1),(NULL, 'b', -3, 0.5)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------+-----------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+-----------------------------------------------------------------------------------------------------------+", + "| logical_plan | Values: (Int64(1), Utf8(\"a\"), Int64(-1), Float64(1.1)), (Int64(NULL), Utf8(\"b\"), Int64(-3), Float64(0.5)) |", + "| physical_plan | ValuesExec |", + "| | |", + "+---------------+-----------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + Ok(()) +} + +#[tokio::test] +async fn select_all() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = "SELECT c1 FROM aggregate_simple order by c1"; + let results = execute_to_batches(&mut ctx, sql).await; + + let sql_all = "SELECT ALL c1 FROM aggregate_simple order by c1"; + let results_all = execute_to_batches(&mut ctx, sql_all).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00002 |", + "| 0.00002 |", + "| 0.00003 |", + "| 0.00003 |", + "| 0.00003 |", + "| 0.00004 |", + "| 0.00004 |", + "| 0.00004 |", + "| 0.00004 |", + "| 0.00005 |", + "| 0.00005 |", + "| 0.00005 |", + "| 0.00005 |", + "| 0.00005 |", + "+---------+", + ]; + + assert_batches_eq!(expected, &results); + assert_batches_eq!(expected, &results_all); + + Ok(()) +} + +#[tokio::test] +async fn select_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = "SELECT DISTINCT * FROM aggregate_simple"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + + let mut dedup = actual.clone(); + dedup.dedup(); + + assert_eq!(actual, dedup); + + Ok(()) +} + +#[tokio::test] +async fn select_distinct_simple_1() { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await.unwrap(); + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple order by c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00002 |", + "| 0.00003 |", + "| 0.00004 |", + "| 0.00005 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn select_distinct_simple_2() { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await.unwrap(); + + let sql = "SELECT DISTINCT c1, c2 FROM aggregate_simple order by c1"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+---------+----------------+", + "| c1 | c2 |", + "+---------+----------------+", + "| 0.00001 | 0.000000000001 |", + "| 0.00002 | 0.000000000002 |", + "| 0.00003 | 0.000000000003 |", + "| 0.00004 | 0.000000000004 |", + "| 0.00005 | 0.000000000005 |", + "+---------+----------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn select_distinct_simple_3() { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await.unwrap(); + + let sql = "SELECT distinct c3 FROM aggregate_simple order by c3"; + let actual = execute_to_batches(&mut ctx, sql).await; + 
+ let expected = vec![ + "+-------+", + "| c3 |", + "+-------+", + "| false |", + "| true |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn select_distinct_simple_4() { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await.unwrap(); + + let sql = "SELECT distinct c1+c2 as a FROM aggregate_simple"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------------------------+", + "| a |", + "+-------------------------+", + "| 0.000030000002242136256 |", + "| 0.000040000002989515004 |", + "| 0.000010000000747378751 |", + "| 0.00005000000373689376 |", + "| 0.000020000001494757502 |", + "+-------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); +} + +#[tokio::test] +async fn select_distinct_from() { + let mut ctx = ExecutionContext::new(); + + let sql = "select + 1 IS DISTINCT FROM CAST(NULL as INT) as a, + 1 IS DISTINCT FROM 1 as b, + 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, + 1 IS NOT DISTINCT FROM 1 as d, + NULL IS DISTINCT FROM NULL as e, + NULL IS NOT DISTINCT FROM NULL as f + "; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------+-------+-------+------+-------+------+", + "| a | b | c | d | e | f |", + "+------+-------+-------+------+-------+------+", + "| true | false | false | true | false | true |", + "+------+-------+-------+------+-------+------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn select_distinct_from_utf8() { + let mut ctx = ExecutionContext::new(); + + let sql = "select + 'x' IS DISTINCT FROM NULL as a, + 'x' IS DISTINCT FROM 'x' as b, + 'x' IS NOT DISTINCT FROM NULL as c, + 'x' IS NOT DISTINCT FROM 'x' as d + "; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------+-------+-------+------+", + "| a | b | c | d |", + "+------+-------+-------+------+", + "| true | false | false | true |", + "+------+-------+-------+------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_with_decimal_by_sql() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await; + let sql = "SELECT c1 from aggregate_simple"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| c1 |", + "+----------+", + "| 0.000010 |", + "| 0.000020 |", + "| 0.000020 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn use_between_expression_in_select_query() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let sql = "SELECT 1 NOT BETWEEN 3 AND 5"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------------------------------------+", + "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |", + "+--------------------------------------------+", + "| true |", + "+--------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let input = Int64Array::from(vec![1, 2, 3, 4]); + let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("test", Arc::new(table))?; 
+ + let sql = "SELECT abs(c1) BETWEEN 0 AND LoG(c1 * 100 ) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + // Expect field name to be correctly converted for expr, low and high. + let expected = vec![ + "+--------------------------------------------------------------------+", + "| abs(test.c1) BETWEEN Int64(0) AND log(test.c1 Multiply Int64(100)) |", + "+--------------------------------------------------------------------+", + "| true |", + "| true |", + "| false |", + "| false |", + "+--------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "EXPLAIN SELECT c1 BETWEEN 2 AND 3 FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual).unwrap(); + + // Only test that the projection exprs arecorrect, rather than entire output + let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]"; + assert_contains!(&formatted, needle); + let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)"; + assert_contains!(&formatted, needle); + + Ok(()) +} + +#[tokio::test] +async fn query_get_indexed_field() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new( + "some_list", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + false, + )])); + let builder = PrimitiveBuilder::::new(3); + let mut lb = ListBuilder::new(builder); + for int_vec in vec![vec![0, 1, 2], vec![4, 5, 6], vec![7, 8, 9]] { + let builder = lb.values(); + for int in int_vec { + builder.append_value(int).unwrap(); + } + lb.append(true).unwrap(); + } + + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("ints", table_a)?; + + // Original column is micros, convert to millis and check timestamp + let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 7 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_nested_get_indexed_field() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); + // Nested schema of { "some_list": [[i64]] } + let schema = Arc::new(Schema::new(vec![Field::new( + "some_list", + DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))), + false, + )])); + + let builder = PrimitiveBuilder::::new(3); + let nested_lb = ListBuilder::new(builder); + let mut lb = ListBuilder::new(nested_lb); + for int_vec_vec in vec![ + vec![vec![0, 1], vec![2, 3], vec![3, 4]], + vec![vec![5, 6], vec![7, 8], vec![9, 10]], + vec![vec![11, 12], vec![13, 14], vec![15, 16]], + ] { + let nested_builder = lb.values(); + for int_vec in int_vec_vec { + let builder = nested_builder.values(); + for int in int_vec { + builder.append_value(int).unwrap(); + } + nested_builder.append(true).unwrap(); + } + lb.append(true).unwrap(); + } + + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("ints", table_a)?; + + // Original column is micros, convert to millis and check timestamp + let sql = 
"SELECT some_list[0] as i0 FROM ints LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| i0 |", + "+----------+", + "| [0, 1] |", + "| [5, 6] |", + "| [11, 12] |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_nested_get_indexed_field_on_struct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); + // Nested schema of { "some_struct": { "bar": [i64] } } + let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)]; + let schema = Arc::new(Schema::new(vec![Field::new( + "some_struct", + DataType::Struct(struct_fields.clone()), + false, + )])); + + let builder = PrimitiveBuilder::::new(3); + let nested_lb = ListBuilder::new(builder); + let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]); + for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]] { + let lb = sb.field_builder::>(0).unwrap(); + for int in int_vec { + lb.values().append_value(int).unwrap(); + } + lb.append(true).unwrap(); + } + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(sb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("structs", table_a)?; + + // Original column is micros, convert to millis and check timestamp + let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------+", + "| l0 |", + "+----------------+", + "| [0, 1, 2, 3] |", + "| [4, 5, 6, 7] |", + "| [8, 9, 10, 11] |", + "+----------------+", + ]; + assert_batches_eq!(expected, &actual); + let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_on_string_dictionary() -> Result<()> { + // Test to ensure DataFusion can operate on dictionary types + // Use StringDictionary (32 bit indexes = keys) + let array = vec![Some("one"), None, Some("three")] + .into_iter() + .collect::>(); + + let batch = + RecordBatch::try_from_iter(vec![("d1", Arc::new(array) as ArrayRef)]).unwrap(); + + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + + // Basic SELECT + let sql = "SELECT * FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| one |", + "| |", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // basic filtering + let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| one |", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // filtering with constant + let sql = "SELECT * FROM test WHERE d1 = 'three'"; + let actual = execute_to_batches(&mut 
ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // Expression evaluation + let sql = "SELECT concat(d1, '-foo') FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+------------------------------+", + "| concat(test.d1,Utf8(\"-foo\")) |", + "+------------------------------+", + "| one-foo |", + "| -foo |", + "| three-foo |", + "+------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + // aggregation + let sql = "SELECT COUNT(d1) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------+", + "| COUNT(test.d1) |", + "+----------------+", + "| 2 |", + "+----------------+", + ]; + assert_batches_eq!(expected, &actual); + + // aggregation min + let sql = "SELECT MIN(d1) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------+", + "| MIN(test.d1) |", + "+--------------+", + "| one |", + "+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + // aggregation max + let sql = "SELECT MAX(d1) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------+", + "| MAX(test.d1) |", + "+--------------+", + "| three |", + "+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + // grouping + let sql = "SELECT d1, COUNT(*) FROM test group by d1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+-----------------+", + "| d1 | COUNT(UInt8(1)) |", + "+-------+-----------------+", + "| one | 1 |", + "| | 1 |", + "| three | 1 |", + "+-------+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + // window functions + let sql = "SELECT d1, row_number() OVER (partition by d1) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+--------------+", + "| d1 | ROW_NUMBER() |", + "+-------+--------------+", + "| | 1 |", + "| one | 1 |", + "| three | 1 |", + "+-------+--------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn query_cte() -> Result<()> { + // Test for SELECT without FROM. + // Should evaluate expressions in project position. 
+ let mut ctx = ExecutionContext::new(); + + // simple with + let sql = "WITH t AS (SELECT 1) SELECT * FROM t"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| Int64(1) |", + "+----------+", + "| 1 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + + // with + union + let sql = + "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "| 2 |", "+---+"]; + assert_batches_eq!(expected, &actual); + + // with + join + let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+---+", "| x |", "+---+", "| 5 |", "+---+"]; + assert_batches_eq!(expected, &actual); + + // backward reference + let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+-----+", "| id1 |", "+-----+", "| 1 |", "+-----+"]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn csv_select_nested() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT o1, o2, c3 + FROM ( + SELECT c1 AS o1, c2 + 1 AS o2, c3 + FROM ( + SELECT c1, c2, c3, c4 + FROM aggregate_test_100 + WHERE c1 = 'a' AND c2 >= 4 + ORDER BY c2 ASC, c3 ASC + ) AS a + ) AS b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----+----+------+", + "| o1 | o2 | c3 |", + "+----+----+------+", + "| a | 5 | -101 |", + "| a | 5 | -54 |", + "| a | 5 | -38 |", + "| a | 5 | 65 |", + "| a | 6 | -101 |", + "| a | 6 | -31 |", + "| a | 6 | 36 |", + "+----+----+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs new file mode 100644 index 000000000000..9c5d59e5a937 --- /dev/null +++ b/datafusion/tests/sql/timestamp.rs @@ -0,0 +1,814 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
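The tests below build their inputs with helpers such as make_timestamp_table, make_timestamp_nano_table and make_timestamp_tz_table, which live in the shared tests/sql module and are not part of this hunk; judging by the turbofish call sites, they are generic over the Arrow timestamp type. As a rough sketch only (not the project's actual helper), a fixed millisecond-resolution variant consistent with the expected outputs below could look like this:

use std::sync::Arc;

use datafusion::arrow::array::TimestampMillisecondArray;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;

// Three rows, one hour apart, matching the timestamps that appear in the
// expected output of the tests below (2020-09-08T11/12/13:42:29.190 UTC).
fn make_timestamp_millis_table() -> Result<Arc<MemTable>> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "ts",
        DataType::Timestamp(TimeUnit::Millisecond, None),
        false,
    )]));

    let ts = TimestampMillisecondArray::from(vec![
        1_599_572_549_190, // 2020-09-08T13:42:29.190Z
        1_599_568_949_190, // 2020-09-08T12:42:29.190Z
        1_599_565_349_190, // 2020-09-08T11:42:29.190Z
    ]);

    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ts)])?;
    Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
}

Such a table would then be registered exactly the way the tests do it, e.g. ctx.register_table("ts_data", make_timestamp_millis_table()?)?;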
+ +use super::*; + +#[tokio::test] +async fn query_cast_timestamp_millis() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![Arc::new(Int64Array::from(vec![ + 1235865600000, + 1235865660000, + 1238544000000, + ]))], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let sql = "SELECT to_timestamp_millis(ts) FROM t1 LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+--------------------------+", + "| totimestampmillis(t1.ts) |", + "+--------------------------+", + "| 2009-03-01 00:00:00 |", + "| 2009-03-01 00:01:00 |", + "| 2009-04-01 00:00:00 |", + "+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_cast_timestamp_micros() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![Arc::new(Int64Array::from(vec![ + 1235865600000000, + 1235865660000000, + 1238544000000000, + ]))], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let sql = "SELECT to_timestamp_micros(ts) FROM t1 LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+--------------------------+", + "| totimestampmicros(t1.ts) |", + "+--------------------------+", + "| 2009-03-01 00:00:00 |", + "| 2009-03-01 00:01:00 |", + "| 2009-04-01 00:00:00 |", + "+--------------------------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_cast_timestamp_seconds() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![Arc::new(Int64Array::from(vec![ + 1235865600, 1235865660, 1238544000, + ]))], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let sql = "SELECT to_timestamp_seconds(ts) FROM t1 LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+---------------------------+", + "| totimestampseconds(t1.ts) |", + "+---------------------------+", + "| 2009-03-01 00:00:00 |", + "| 2009-03-01 00:01:00 |", + "| 2009-04-01 00:00:00 |", + "+---------------------------+", + ]; + + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_cast_timestamp_nanos_to_others() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("ts_data", make_timestamp_nano_table()?)?; + + // Original column is nanos, convert to millis and check timestamp + let sql = "SELECT to_timestamp_millis(ts) FROM ts_data LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------------------------------+", + "| totimestampmillis(ts_data.ts) |", + "+-------------------------------+", + "| 2020-09-08 13:42:29.190 |", + "| 2020-09-08 12:42:29.190 |", + "| 2020-09-08 11:42:29.190 |", + "+-------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT to_timestamp_micros(ts) FROM ts_data LIMIT 3"; + let actual = 
execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------------------------------+", + "| totimestampmicros(ts_data.ts) |", + "+-------------------------------+", + "| 2020-09-08 13:42:29.190855 |", + "| 2020-09-08 12:42:29.190855 |", + "| 2020-09-08 11:42:29.190855 |", + "+-------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT to_timestamp_seconds(ts) FROM ts_data LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+--------------------------------+", + "| totimestampseconds(ts_data.ts) |", + "+--------------------------------+", + "| 2020-09-08 13:42:29 |", + "| 2020-09-08 12:42:29 |", + "| 2020-09-08 11:42:29 |", + "+--------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn query_cast_timestamp_seconds_to_others() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("ts_secs", make_timestamp_table::()?)?; + + // Original column is seconds, convert to millis and check timestamp + let sql = "SELECT to_timestamp_millis(ts) FROM ts_secs LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------+", + "| totimestampmillis(ts_secs.ts) |", + "+-------------------------------+", + "| 2020-09-08 13:42:29 |", + "| 2020-09-08 12:42:29 |", + "| 2020-09-08 11:42:29 |", + "+-------------------------------+", + ]; + + assert_batches_eq!(expected, &actual); + + // Original column is seconds, convert to micros and check timestamp + let sql = "SELECT to_timestamp_micros(ts) FROM ts_secs LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------------+", + "| totimestampmicros(ts_secs.ts) |", + "+-------------------------------+", + "| 2020-09-08 13:42:29 |", + "| 2020-09-08 12:42:29 |", + "| 2020-09-08 11:42:29 |", + "+-------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + // to nanos + let sql = "SELECT to_timestamp(ts) FROM ts_secs LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+", + "| totimestamp(ts_secs.ts) |", + "+-------------------------+", + "| 2020-09-08 13:42:29 |", + "| 2020-09-08 12:42:29 |", + "| 2020-09-08 11:42:29 |", + "+-------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn query_cast_timestamp_micros_to_others() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table( + "ts_micros", + make_timestamp_table::()?, + )?; + + // Original column is micros, convert to millis and check timestamp + let sql = "SELECT to_timestamp_millis(ts) FROM ts_micros LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------------------+", + "| totimestampmillis(ts_micros.ts) |", + "+---------------------------------+", + "| 2020-09-08 13:42:29.190 |", + "| 2020-09-08 12:42:29.190 |", + "| 2020-09-08 11:42:29.190 |", + "+---------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + // Original column is micros, convert to seconds and check timestamp + let sql = "SELECT to_timestamp_seconds(ts) FROM ts_micros LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------------+", + "| totimestampseconds(ts_micros.ts) |", + "+----------------------------------+", + "| 
2020-09-08 13:42:29 |", + "| 2020-09-08 12:42:29 |", + "| 2020-09-08 11:42:29 |", + "+----------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + // Original column is micros, convert to nanos and check timestamp + let sql = "SELECT to_timestamp(ts) FROM ts_micros LIMIT 3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+", + "| totimestamp(ts_micros.ts) |", + "+----------------------------+", + "| 2020-09-08 13:42:29.190855 |", + "| 2020-09-08 12:42:29.190855 |", + "| 2020-09-08 11:42:29.190855 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn to_timestamp() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("ts_data", make_timestamp_nano_table()?)?; + + let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 2 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn to_timestamp_millis() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table( + "ts_data", + make_timestamp_table::()?, + )?; + + let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_millis('2020-09-08T12:00:00+00:00')"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 2 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn to_timestamp_micros() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table( + "ts_data", + make_timestamp_table::()?, + )?; + + let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_micros('2020-09-08T12:00:00+00:00')"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 2 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn to_timestamp_seconds() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("ts_data", make_timestamp_table::()?)?; + + let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_seconds('2020-09-08T12:00:00+00:00')"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 2 |", + "+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn count_distinct_timestamps() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("ts_data", make_timestamp_nano_table()?)?; + + let sql = "SELECT COUNT(DISTINCT(ts)) FROM ts_data"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+----------------------------+", + "| COUNT(DISTINCT ts_data.ts) |", + "+----------------------------+", + "| 3 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn test_current_timestamp_expressions() -> Result<()> { + let t1 = chrono::Utc::now().timestamp(); + let mut ctx = ExecutionContext::new(); + let actual = execute(&mut ctx, "SELECT NOW(), NOW() as 
t2").await; + let res1 = actual[0][0].as_str(); + let res2 = actual[0][1].as_str(); + let t3 = chrono::Utc::now().timestamp(); + let t2_naive = + chrono::NaiveDateTime::parse_from_str(res1, "%Y-%m-%d %H:%M:%S%.6f").unwrap(); + + let t2 = t2_naive.timestamp(); + assert!(t1 <= t2 && t2 <= t3); + assert_eq!(res2, res1); + + Ok(()) +} + +#[tokio::test] +async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { + let t1 = chrono::Utc::now().timestamp(); + let ctx = ExecutionContext::new(); + let sql = "SELECT NOW(), NOW() as t2"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).await.expect(&msg); + + let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); + let res = collect(plan).await.expect(&msg); + let actual = result_vec(&res); + + let res1 = actual[0][0].as_str(); + let res2 = actual[0][1].as_str(); + let t3 = chrono::Utc::now().timestamp(); + let t2_naive = + chrono::NaiveDateTime::parse_from_str(res1, "%Y-%m-%d %H:%M:%S%.6f").unwrap(); + + let t2 = t2_naive.timestamp(); + assert!(t1 <= t2 && t2 <= t3); + assert_eq!(res2, res1); + + Ok(()) +} + +#[tokio::test] +async fn timestamp_minmax() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_tz_table::(None)?; + let table_b = + make_timestamp_tz_table::(Some("UTC".to_owned()))?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT MIN(table_a.ts), MAX(table_b.ts) FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+----------------------------+", + "| MIN(table_a.ts) | MAX(table_b.ts) |", + "+-------------------------+----------------------------+", + "| 2020-09-08 11:42:29.190 | 2020-09-08 13:42:29.190855 |", + "+-------------------------+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn timestamp_coercion() -> Result<()> { + { + let mut ctx = ExecutionContext::new(); + let table_a = + make_timestamp_tz_table::(Some("UTC".to_owned()))?; + let table_b = + make_timestamp_tz_table::(Some("UTC".to_owned()))?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+-------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+---------------------+-------------------------+--------------------------+", + "| 2020-09-08 13:42:29 | 2020-09-08 13:42:29.190 | true |", + "| 2020-09-08 13:42:29 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 13:42:29 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 12:42:29.190 | true |", + "| 2020-09-08 12:42:29 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 11:42:29.190 | true |", + "+---------------------+-------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = 
ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+---------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29 | 2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 11:42:29.190855 | true |", + "+---------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+---------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29 | 2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29 | 2020-09-08 11:42:29.190855 | true |", + "+---------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+---------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+-------------------------+---------------------+--------------------------+", + "| 2020-09-08 13:42:29.190 | 2020-09-08 13:42:29 | true |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 12:42:29 | true |", 
+ "| 2020-09-08 12:42:29.190 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 11:42:29 | true |", + "+-------------------------+---------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+-------------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190 | 2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 11:42:29.190855 | true |", + "+-------------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+-------------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190 | 2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29.190 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29.190 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190 | 2020-09-08 11:42:29.190855 | true |", + "+-------------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + 
"+----------------------------+---------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+---------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 13:42:29 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 11:42:29 | true |", + "+----------------------------+---------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+-------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+-------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 13:42:29.190 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29.190 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 11:42:29.190 | true |", + "+----------------------------+-------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 
2020-09-08 11:42:29.190855 | true |", + "+----------------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+---------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+---------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 13:42:29 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 11:42:29 | true |", + "+----------------------------+---------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+-------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+-------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 13:42:29.190 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29.190 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29.190 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 11:42:29.190 | true |", + "+----------------------------+-------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + { + let mut ctx = ExecutionContext::new(); + let table_a = make_timestamp_table::()?; + let table_b = make_timestamp_table::()?; + ctx.register_table("table_a", table_a)?; + ctx.register_table("table_b", table_b)?; + + let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+----------------------------+--------------------------+", + "| ts | ts | table_a.ts Eq table_b.ts |", + "+----------------------------+----------------------------+--------------------------+", + "| 2020-09-08 13:42:29.190855 | 
2020-09-08 13:42:29.190855 | true |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 13:42:29.190855 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 12:42:29.190855 | true |", + "| 2020-09-08 12:42:29.190855 | 2020-09-08 11:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 13:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 12:42:29.190855 | false |", + "| 2020-09-08 11:42:29.190855 | 2020-09-08 11:42:29.190855 | true |", + "+----------------------------+----------------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &actual); + } + + Ok(()) +} + +#[tokio::test] +async fn group_by_timestamp_millis() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("count", DataType::Int32, false), + ])); + let base_dt = Utc.ymd(2018, 7, 1).and_hms(6, 0, 0); // 2018-Jul-01 06:00 + let hour1 = Duration::hours(1); + let timestamps = vec![ + base_dt.timestamp_millis(), + (base_dt + hour1).timestamp_millis(), + base_dt.timestamp_millis(), + base_dt.timestamp_millis(), + (base_dt + hour1).timestamp_millis(), + (base_dt + hour1).timestamp_millis(), + ]; + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(TimestampMillisecondArray::from(timestamps)), + Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50, 60])), + ], + )?; + let t1_table = MemTable::try_new(schema, vec![vec![data]])?; + ctx.register_table("t1", Arc::new(t1_table)).unwrap(); + + let sql = + "SELECT timestamp, SUM(count) FROM t1 GROUP BY timestamp ORDER BY timestamp ASC"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+---------------------+---------------+", + "| timestamp | SUM(t1.count) |", + "+---------------------+---------------+", + "| 2018-07-01 06:00:00 | 80 |", + "| 2018-07-01 07:00:00 | 130 |", + "+---------------------+---------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/udf.rs b/datafusion/tests/sql/udf.rs new file mode 100644 index 000000000000..db42574c1bd0 --- /dev/null +++ b/datafusion/tests/sql/udf.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +/// test that casting happens on udfs. +/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and +/// physical plan have the same schema. 
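The test below relies on create_ctx() and the custom_sqrt UDF, both defined in the shared tests/sql module rather than in this hunk. As a hedged sketch only (not the project's actual helper), a context with such a UDF might be assembled roughly as follows; note that the exact create_udf signature (in particular the Volatility argument) and the import paths vary between DataFusion versions:

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::physical_plan::functions::{make_scalar_function, Volatility};
use datafusion::prelude::*;

fn create_ctx() -> Result<ExecutionContext> {
    let mut ctx = ExecutionContext::new();

    // The UDF body sees already-evaluated argument arrays; because the UDF is
    // declared over Float64, the planner casts the Float32 column before the
    // function runs, which is the behaviour the test below exercises.
    let custom_sqrt = |args: &[ArrayRef]| -> Result<ArrayRef> {
        let input = args[0]
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("argument was not cast to Float64");
        let result: Float64Array = input.iter().map(|v| v.map(f64::sqrt)).collect();
        Ok(Arc::new(result) as ArrayRef)
    };

    ctx.register_udf(create_udf(
        "custom_sqrt",
        vec![DataType::Float64],     // declared input type
        Arc::new(DataType::Float64), // return type
        Volatility::Immutable,
        make_scalar_function(custom_sqrt),
    ));

    Ok(ctx)
}

Declaring the argument type as Float64 is what triggers the implicit cast of the f32 column c11 described in the comment above.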
+#[tokio::test] +async fn csv_query_custom_udf_with_cast() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.6584408483418833"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/unicode.rs b/datafusion/tests/sql/unicode.rs new file mode 100644 index 000000000000..28a0c83d17d9 --- /dev/null +++ b/datafusion/tests/sql/unicode.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn query_length() -> Result<()> { + generic_query_length::(DataType::Utf8).await +} + +#[tokio::test] +async fn query_large_length() -> Result<()> { + generic_query_length::(DataType::LargeUtf8).await +} + +#[tokio::test] +async fn test_unicode_expressions() -> Result<()> { + test_expression!("char_length('')", "0"); + test_expression!("char_length('chars')", "5"); + test_expression!("char_length('josé')", "4"); + test_expression!("char_length(NULL)", "NULL"); + test_expression!("character_length('')", "0"); + test_expression!("character_length('chars')", "5"); + test_expression!("character_length('josé')", "4"); + test_expression!("character_length(NULL)", "NULL"); + test_expression!("left('abcde', -2)", "abc"); + test_expression!("left('abcde', -200)", ""); + test_expression!("left('abcde', 0)", ""); + test_expression!("left('abcde', 2)", "ab"); + test_expression!("left('abcde', 200)", "abcde"); + test_expression!("left('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("left(NULL, 2)", "NULL"); + test_expression!("left(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("length('')", "0"); + test_expression!("length('chars')", "5"); + test_expression!("length('josé')", "4"); + test_expression!("length(NULL)", "NULL"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 0)", ""); + test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 5, NULL)", "NULL"); + test_expression!("lpad('hi', 5)", " hi"); + test_expression!("lpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("lpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("lpad('xyxhi', 3)", "xyx"); + test_expression!("lpad(NULL, 0)", "NULL"); + test_expression!("lpad(NULL, 5, 'xy')", "NULL"); + test_expression!("reverse('abcde')", "edcba"); + test_expression!("reverse('loẅks')", "skẅol"); + test_expression!("reverse(NULL)", "NULL"); + test_expression!("right('abcde', -2)", "cde"); + test_expression!("right('abcde', -200)", ""); + test_expression!("right('abcde', 0)", ""); + 
test_expression!("right('abcde', 2)", "de"); + test_expression!("right('abcde', 200)", "abcde"); + test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("right(NULL, 2)", "NULL"); + test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 0)", ""); + test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 5, NULL)", "NULL"); + test_expression!("rpad('hi', 5)", "hi "); + test_expression!("rpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("rpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('xyxhi', 3)", "xyx"); + test_expression!("strpos('abc', 'c')", "3"); + test_expression!("strpos('josé', 'é')", "4"); + test_expression!("strpos('joséésoj', 'so')", "6"); + test_expression!("strpos('joséésoj', 'abc')", "0"); + test_expression!("strpos(NULL, 'abc')", "NULL"); + test_expression!("strpos('joséésoj', NULL)", "NULL"); + test_expression!("substr('alphabet', -3)", "alphabet"); + test_expression!("substr('alphabet', 0)", "alphabet"); + test_expression!("substr('alphabet', 1)", "alphabet"); + test_expression!("substr('alphabet', 2)", "lphabet"); + test_expression!("substr('alphabet', 3)", "phabet"); + test_expression!("substr('alphabet', 30)", ""); + test_expression!("substr('alphabet', CAST(NULL AS int))", "NULL"); + test_expression!("substr('alphabet', 3, 2)", "ph"); + test_expression!("substr('alphabet', 3, 20)", "phabet"); + test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); + test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); + test_expression!("translate('12345', '143', 'ax')", "a2x5"); + test_expression!("translate(NULL, '143', 'ax')", "NULL"); + test_expression!("translate('12345', NULL, 'ax')", "NULL"); + test_expression!("translate('12345', '143', NULL)", "NULL"); + Ok(()) +} diff --git a/datafusion/tests/sql/union.rs b/datafusion/tests/sql/union.rs new file mode 100644 index 000000000000..a1f81d24f456 --- /dev/null +++ b/datafusion/tests/sql/union.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +#[tokio::test] +async fn union_all() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT 1 as x UNION ALL SELECT 2 as x"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "| 2 |", "+---+"]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_union_all() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = + "SELECT c1 FROM aggregate_test_100 UNION ALL SELECT c1 FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + assert_eq!(actual.len(), 200); + Ok(()) +} + +#[tokio::test] +async fn union_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT 1 as x UNION SELECT 1 as x"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "+---+"]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn union_all_with_aggregate() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = + "SELECT SUM(d) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d) as a"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| SUM(a.d) |", + "+----------+", + "| 5 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/datafusion/tests/sql/window.rs b/datafusion/tests/sql/window.rs new file mode 100644 index 000000000000..321ab320f5be --- /dev/null +++ b/datafusion/tests/sql/window.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::*; + +/// for window functions without order by the first, last, and nth function call does not make sense +#[tokio::test] +async fn csv_query_window_with_empty_over() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select \ + c9, \ + count(c5) over (), \ + max(c5) over (), \ + min(c5) over () \ + from aggregate_test_100 \ + order by c9 \ + limit 5"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------+------------------------------+----------------------------+----------------------------+", + "| c9 | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) |", + "+-----------+------------------------------+----------------------------+----------------------------+", + "| 28774375 | 100 | 2143473091 | -2141999138 |", + "| 63044568 | 100 | 2143473091 | -2141999138 |", + "| 141047417 | 100 | 2143473091 | -2141999138 |", + "| 141680161 | 100 | 2143473091 | -2141999138 |", + "| 145294611 | 100 | 2143473091 | -2141999138 |", + "+-----------+------------------------------+----------------------------+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +/// for window functions without order by the first, last, and nth function call does not make sense +#[tokio::test] +async fn csv_query_window_with_partition_by() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select \ + c9, \ + sum(cast(c4 as Int)) over (partition by c3), \ + avg(cast(c4 as Int)) over (partition by c3), \ + count(cast(c4 as Int)) over (partition by c3), \ + max(cast(c4 as Int)) over (partition by c3), \ + min(cast(c4 as Int)) over (partition by c3) \ + from aggregate_test_100 \ + order by c9 \ + limit 5"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", + "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |", + "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", + "| 28774375 | -16110 | -16110 | 1 | -16110 | -16110 |", + "| 63044568 | 3917 | 3917 | 1 | 3917 | 3917 |", + "| 141047417 | -38455 | -19227.5 | 2 | -16974 | -21481 |", + "| 141680161 | -1114 | -1114 | 1 | -1114 | -1114 |", + "| 145294611 | 15673 | 15673 | 1 | 15673 | 15673 |", + "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_window_with_order_by() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select \ + c9, \ + sum(c5) over (order by c9), \ + avg(c5) over (order by c9), \ + count(c5) over (order by c9), \ + max(c5) over (order by c9), \ + min(c5) 
over (order by c9), \ + first_value(c5) over (order by c9), \ + last_value(c5) over (order by c9), \ + nth_value(c5, 2) over (order by c9) \ + from aggregate_test_100 \ + order by c9 \ + limit 5"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", + "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", + "| 28774375 | 61035129 | 61035129 | 1 | 61035129 | 61035129 | 61035129 | 61035129 | |", + "| 63044568 | -47938237 | -23969118.5 | 2 | 61035129 | -108973366 | 61035129 | -108973366 | -108973366 |", + "| 141047417 | 575165281 | 191721760.33333334 | 3 | 623103518 | -108973366 | 61035129 | 623103518 | -108973366 |", + "| 141680161 | -1352462829 | -338115707.25 | 4 | 623103518 | -1927628110 | 61035129 | -1927628110 | -108973366 |", + "| 145294611 | -3251637940 | -650327588 | 5 | 623103518 | -1927628110 | 61035129 | -1899175111 | -108973366 |", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_window_with_partition_by_order_by() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select \ + c9, \ + sum(c5) over (partition by c4 order by c9), \ + avg(c5) over (partition by c4 order by c9), \ + count(c5) over (partition by c4 order by c9), \ + max(c5) over (partition by c4 order by c9), \ + min(c5) over (partition by c4 order by c9), \ + first_value(c5) over (partition by c4 order by c9), \ + last_value(c5) over (partition by c4 order by c9), \ + nth_value(c5, 2) over (partition by c4 order by c9) \ + from aggregate_test_100 \ + order by c9 \ + limit 5"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", + "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", + 
"+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", + "| 28774375 | 61035129 | 61035129 | 1 | 61035129 | 61035129 | 61035129 | 61035129 | |", + "| 63044568 | -108973366 | -108973366 | 1 | -108973366 | -108973366 | -108973366 | -108973366 | |", + "| 141047417 | 623103518 | 623103518 | 1 | 623103518 | 623103518 | 623103518 | 623103518 | |", + "| 141680161 | -1927628110 | -1927628110 | 1 | -1927628110 | -1927628110 | -1927628110 | -1927628110 | |", + "| 145294611 | -1899175111 | -1899175111 | 1 | -1899175111 | -1899175111 | -1899175111 | -1899175111 | |", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+" + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} diff --git a/docs/source/community/communication.md b/docs/source/community/communication.md index 76aa0ea36fa7..b34b913c6f56 100644 --- a/docs/source/community/communication.md +++ b/docs/source/community/communication.md @@ -76,9 +76,8 @@ Our source code is hosted on [GitHub](https://github.com/apache/arrow-datafusion). For developers new to the project, we have curated a [good-first-issue](https://github.com/apache/arrow-datafusion/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) -list to help you get started. +list to help you get started. You can find datafusion's major designs in docs/source/specification. We use GitHub issues for maintaining a queue of development work and as the public record. We often use Google docs, Github issues and pull requests for -quick and small design discussions. For major design change proposals, please -make sure to send them to the dev list for more visibility. +quick and small design discussions. For major design change proposals, we encourage you to write a rfc. 
diff --git a/docs/source/specification/rfcs/template.md b/docs/source/specification/rfcs/template.md new file mode 100644 index 000000000000..98704fd46fe9 --- /dev/null +++ b/docs/source/specification/rfcs/template.md @@ -0,0 +1,58 @@ + + +Feature Name: + +Status: draft/in-progress/completed + +Start Date: YYYY-MM-DD + +Authors: + +RFC PR: # + +DataFusion Issue: # + +--- + +### Background + +--- + +### Goals + +--- + +### Non-Goals + +--- + +### Survey + +--- + +### General design + +--- + +### Detailed design + +--- + +### Others diff --git a/docs/source/specification/roadmap.md b/docs/source/specification/roadmap.md index 09f636f3bb7f..76b2896aa71c 100644 --- a/docs/source/specification/roadmap.md +++ b/docs/source/specification/roadmap.md @@ -49,16 +49,15 @@ to provide: ## Additional SQL Language Features +- Decimal Support [#122](https://github.com/apache/arrow-datafusion/issues/122) - Complete support list on [status](https://github.com/apache/arrow-datafusion/blob/master/README.md#status) - Timestamp Arithmetic [#194](https://github.com/apache/arrow-datafusion/issues/194) - SQL Parser extension point [#533](https://github.com/apache/arrow-datafusion/issues/533) - Support for nested structures (fields, lists, structs) [#119](https://github.com/apache/arrow-datafusion/issues/119) -- Remaining Set Operators (`INTERSECT` / `EXCEPT`) [#1082](https://github.com/apache/arrow-datafusion/issues/1082) - Run all queries from the TPCH benchmark (see [milestone](https://github.com/apache/arrow-datafusion/milestone/2) for more details) ## Query Optimizer -- Additional constant folding / partial evaluation [#1070](https://github.com/apache/arrow-datafusion/issues/1070) - More sophisticated cost based optimizer for join ordering - Implement advanced query optimization framework (Tokomak) #440 - Finer optimizations for group by and aggregate functions @@ -66,7 +65,6 @@ to provide: ## Datasources - Better support for reading data from remote filesystems (e.g. S3) without caching it locally [#907](https://github.com/apache/arrow-datafusion/issues/907) [#1060](https://github.com/apache/arrow-datafusion/issues/1060) -- Support for partitioned datasources [#1139](https://github.com/apache/arrow-datafusion/issues/1139) and make the integration of other table formats (Delta, Iceberg...) simpler - Improve performances of file format datasources (parallelize file listings, async Arrow readers, file chunk prefetching capability...) ## Runtime / Infrastructure diff --git a/python/.cargo/config b/python/.cargo/config deleted file mode 100644 index 0b24f30cf908..000000000000 --- a/python/.cargo/config +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements.  See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership.  The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License.  You may obtain a copy of the License at -# -#   http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied.  See the License for the -# specific language governing permissions and limitations -# under the License.
- -[target.x86_64-apple-darwin] -rustflags = [ - "-C", "link-arg=-undefined", - "-C", "link-arg=dynamic_lookup", -] diff --git a/python/.dockerignore b/python/.dockerignore deleted file mode 100644 index 08c131c2e7d6..000000000000 --- a/python/.dockerignore +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -target -venv diff --git a/python/.gitignore b/python/.gitignore deleted file mode 100644 index 586db7c4a5b3..000000000000 --- a/python/.gitignore +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -/target -venv -.venv diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md deleted file mode 100644 index a07cb003c5cd..000000000000 --- a/python/CHANGELOG.md +++ /dev/null @@ -1,129 +0,0 @@ - - -# Changelog - -## [python-0.4.0](https://github.com/apache/arrow-datafusion/tree/python-0.4.0) (2021-11-13) - -[Full Changelog](https://github.com/apache/arrow-datafusion/compare/python-0.3.0...python-0.4.0) - -**Breaking changes:** - -- Add function volatility to Signature [\#1071](https://github.com/apache/arrow-datafusion/pull/1071) [[sql](https://github.com/apache/arrow-datafusion/labels/sql)] ([pjmore](https://github.com/pjmore)) -- Make TableProvider.scan\(\) and PhysicalPlanner::create\_physical\_plan\(\) async [\#1013](https://github.com/apache/arrow-datafusion/pull/1013) ([rdettai](https://github.com/rdettai)) -- Reorganize table providers by table format [\#1010](https://github.com/apache/arrow-datafusion/pull/1010) ([rdettai](https://github.com/rdettai)) - -**Implemented enhancements:** - -- Build abi3 wheels for python binding [\#921](https://github.com/apache/arrow-datafusion/issues/921) -- Release documentation for python binding [\#837](https://github.com/apache/arrow-datafusion/issues/837) -- use arrow 6.1.0 [\#1255](https://github.com/apache/arrow-datafusion/pull/1255) ([Jimexist](https://github.com/Jimexist)) -- python `lit` function to support bool and byte vec [\#1152](https://github.com/apache/arrow-datafusion/pull/1152) ([Jimexist](https://github.com/Jimexist)) -- add python binding for `approx_distinct` aggregate function [\#1134](https://github.com/apache/arrow-datafusion/pull/1134) ([Jimexist](https://github.com/Jimexist)) -- refactor datafusion python `lit` function to allow different types [\#1130](https://github.com/apache/arrow-datafusion/pull/1130) ([Jimexist](https://github.com/Jimexist)) -- \[python\] add digest python function [\#1127](https://github.com/apache/arrow-datafusion/pull/1127) ([Jimexist](https://github.com/Jimexist)) -- \[crypto\] add `blake3` algorithm to `digest` function [\#1086](https://github.com/apache/arrow-datafusion/pull/1086) ([Jimexist](https://github.com/Jimexist)) -- \[crypto\] add blake2b and blake2s functions [\#1081](https://github.com/apache/arrow-datafusion/pull/1081) ([Jimexist](https://github.com/Jimexist)) -- fix: fix joins on Float32/Float64 columns bug [\#1054](https://github.com/apache/arrow-datafusion/pull/1054) ([francis-du](https://github.com/francis-du)) -- Update DataFusion to arrow 6.0 [\#984](https://github.com/apache/arrow-datafusion/pull/984) ([alamb](https://github.com/alamb)) -- \[Python\] Add support to perform sql query on in-memory datasource. 
[\#981](https://github.com/apache/arrow-datafusion/pull/981) ([mmuru](https://github.com/mmuru)) -- \[Python\] - Support show function for DataFrame api of python library [\#942](https://github.com/apache/arrow-datafusion/pull/942) ([francis-du](https://github.com/francis-du)) -- Rework the python bindings using conversion traits from arrow-rs [\#873](https://github.com/apache/arrow-datafusion/pull/873) ([kszucs](https://github.com/kszucs)) - -**Fixed bugs:** - -- Error in `python test` check / maturn python build: `function or associated item not found in `proc_macro::Literal` [\#961](https://github.com/apache/arrow-datafusion/issues/961) -- Use UUID to create unique table names in python binding [\#1111](https://github.com/apache/arrow-datafusion/pull/1111) ([hippowdon](https://github.com/hippowdon)) -- python: fix generated table name in dataframe creation [\#1078](https://github.com/apache/arrow-datafusion/pull/1078) ([houqp](https://github.com/houqp)) -- fix: joins on Timestamp columns [\#1055](https://github.com/apache/arrow-datafusion/pull/1055) ([francis-du](https://github.com/francis-du)) -- register datafusion.functions as a python package [\#995](https://github.com/apache/arrow-datafusion/pull/995) ([houqp](https://github.com/houqp)) - -**Documentation updates:** - -- python: update docs to use new APIs [\#1287](https://github.com/apache/arrow-datafusion/pull/1287) ([houqp](https://github.com/houqp)) -- Fix typo on Python functions [\#1207](https://github.com/apache/arrow-datafusion/pull/1207) ([j-a-m-l](https://github.com/j-a-m-l)) -- fix deadlink in python/readme [\#1002](https://github.com/apache/arrow-datafusion/pull/1002) ([waynexia](https://github.com/waynexia)) - -**Performance improvements:** - -- optimize build profile for datafusion python binding, cli and ballista [\#1137](https://github.com/apache/arrow-datafusion/pull/1137) ([houqp](https://github.com/houqp)) - -**Closed issues:** - -- InList expr with NULL literals do not work [\#1190](https://github.com/apache/arrow-datafusion/issues/1190) -- update the homepage README to include values, `approx_distinct`, etc. [\#1171](https://github.com/apache/arrow-datafusion/issues/1171) -- \[Python\]: Inconsistencies with Python package name [\#1011](https://github.com/apache/arrow-datafusion/issues/1011) -- Wanting to contribute to project where to start? 
[\#983](https://github.com/apache/arrow-datafusion/issues/983) -- delete redundant code [\#973](https://github.com/apache/arrow-datafusion/issues/973) -- \[Python\]: register custom datasource [\#906](https://github.com/apache/arrow-datafusion/issues/906) -- How to build DataFusion python wheel [\#853](https://github.com/apache/arrow-datafusion/issues/853) -- Produce a design for a metrics framework [\#21](https://github.com/apache/arrow-datafusion/issues/21) - - -For older versions, see [apache/arrow/CHANGELOG.md](https://github.com/apache/arrow/blob/master/CHANGELOG.md) - -## [python-0.3.0](https://github.com/apache/arrow-datafusion/tree/python-0.3.0) (2021-08-10) - -[Full Changelog](https://github.com/apache/arrow-datafusion/compare/4.0.0...python-0.3.0) - -**Implemented enhancements:** - -- add more math functions and unit tests to `python` crate [\#748](https://github.com/apache/arrow-datafusion/pull/748) ([Jimexist](https://github.com/Jimexist)) -- Expose ExecutionContext.register\_csv to the python bindings [\#524](https://github.com/apache/arrow-datafusion/pull/524) ([kszucs](https://github.com/kszucs)) -- Implement missing join types for Python dataframe [\#503](https://github.com/apache/arrow-datafusion/pull/503) ([Dandandan](https://github.com/Dandandan)) -- Add missing functions to python [\#388](https://github.com/apache/arrow-datafusion/pull/388) ([jgoday](https://github.com/jgoday)) - -**Fixed bugs:** - -- fix maturin version in pyproject.toml [\#756](https://github.com/apache/arrow-datafusion/pull/756) ([Jimexist](https://github.com/Jimexist)) -- fix pyarrow type id mapping in `python` crate [\#742](https://github.com/apache/arrow-datafusion/pull/742) ([Jimexist](https://github.com/Jimexist)) - -**Closed issues:** - -- Confirm git tagging strategy for releases [\#770](https://github.com/apache/arrow-datafusion/issues/770) -- arrow::util::pretty::pretty\_format\_batches missing [\#769](https://github.com/apache/arrow-datafusion/issues/769) -- move the `assert_batches_eq!` macros to a non part of datafusion [\#745](https://github.com/apache/arrow-datafusion/issues/745) -- fix an issue where aliases are not respected in generating downstream schemas in window expr [\#592](https://github.com/apache/arrow-datafusion/issues/592) -- make the planner to print more succinct and useful information in window function explain clause [\#526](https://github.com/apache/arrow-datafusion/issues/526) -- move window frame module to be in `logical_plan` [\#517](https://github.com/apache/arrow-datafusion/issues/517) -- use a more rust idiomatic way of handling nth\_value [\#448](https://github.com/apache/arrow-datafusion/issues/448) -- create a test with more than one partition for window functions [\#435](https://github.com/apache/arrow-datafusion/issues/435) -- Implement hash-partitioned hash aggregate [\#27](https://github.com/apache/arrow-datafusion/issues/27) -- Consider using GitHub pages for DataFusion/Ballista documentation [\#18](https://github.com/apache/arrow-datafusion/issues/18) -- Update "repository" in Cargo.toml [\#16](https://github.com/apache/arrow-datafusion/issues/16) - -**Merged pull requests:** - -- fix python binding for `concat`, `concat_ws`, and `random` [\#768](https://github.com/apache/arrow-datafusion/pull/768) ([Jimexist](https://github.com/Jimexist)) -- fix 226, make `concat`, `concat_ws`, and `random` work with `Python` crate [\#761](https://github.com/apache/arrow-datafusion/pull/761) ([Jimexist](https://github.com/Jimexist)) -- fix python crate with the changes 
to logical plan builder [\#650](https://github.com/apache/arrow-datafusion/pull/650) ([Jimexist](https://github.com/Jimexist)) -- use nightly nightly-2021-05-10 [\#536](https://github.com/apache/arrow-datafusion/pull/536) ([Jimexist](https://github.com/Jimexist)) -- Define the unittests using pytest [\#493](https://github.com/apache/arrow-datafusion/pull/493) ([kszucs](https://github.com/kszucs)) -- use requirements.txt to formalize python deps [\#484](https://github.com/apache/arrow-datafusion/pull/484) ([Jimexist](https://github.com/Jimexist)) -- update cargo.toml in python crate and fix unit test due to hash joins [\#483](https://github.com/apache/arrow-datafusion/pull/483) ([Jimexist](https://github.com/Jimexist)) -- simplify python function definitions [\#477](https://github.com/apache/arrow-datafusion/pull/477) ([Jimexist](https://github.com/Jimexist)) -- Expose DataFrame::sort in the python bindings [\#469](https://github.com/apache/arrow-datafusion/pull/469) ([kszucs](https://github.com/kszucs)) -- Revert "Revert "Add datafusion-python \(\#69\)" \(\#257\)" [\#270](https://github.com/apache/arrow-datafusion/pull/270) ([andygrove](https://github.com/andygrove)) -- Revert "Add datafusion-python \(\#69\)" [\#257](https://github.com/apache/arrow-datafusion/pull/257) ([andygrove](https://github.com/andygrove)) -- update arrow-rs deps to latest master [\#216](https://github.com/apache/arrow-datafusion/pull/216) ([alamb](https://github.com/alamb)) -- Add datafusion-python [\#69](https://github.com/apache/arrow-datafusion/pull/69) ([jorgecarleitao](https://github.com/jorgecarleitao)) - - - -\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/python/Cargo.lock b/python/Cargo.lock deleted file mode 100644 index fa84a54ced7b..000000000000 --- a/python/Cargo.lock +++ /dev/null @@ -1,1456 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom 0.2.3", - "once_cell", - "version_check", -] - -[[package]] -name = "aho-corasick" -version = "0.7.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" -dependencies = [ - "memchr", -] - -[[package]] -name = "alloc-no-stdlib" -version = "2.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35ef4730490ad1c4eae5c4325b2a95f521d023e5c885853ff7aca0a6a1631db3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697ed7edc0f1711de49ce108c541623a0af97c6c60b2f6e2b65229847ac843c2" -dependencies = [ - "alloc-no-stdlib", -] - -[[package]] -name = "arrayref" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" - -[[package]] -name = "arrayvec" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4dc07131ffa69b8072d35f5007352af944213cde02545e2103680baed38fcd" - -[[package]] -name = "arrow" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "337e668497751234149fd607f5cb41a6ae7b286b6329589126fe67f0ac55d637" -dependencies = [ - "bitflags", - "chrono", - "comfy-table", - "csv", - "flatbuffers", - "hex", - "indexmap", - "lazy_static", - "lexical-core", - "multiversion", - "num", - "pyo3", - "rand 0.8.4", - "regex", - "serde", - "serde_derive", - "serde_json", -] - -[[package]] -name = "async-trait" -version = "0.1.51" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "autocfg" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" - -[[package]] -name = "base64" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "blake2" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" -dependencies = [ - "crypto-mac", - "digest", - "opaque-debug", -] - -[[package]] -name = "blake3" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2607a74355ce2e252d0c483b2d8a348e1bba36036e786ccc2dcd777213c86ffd" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq", - "digest", -] - -[[package]] -name = "block-buffer" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" -dependencies = [ - "generic-array", -] - -[[package]] -name = "brotli" -version = "3.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71cb90ade945043d3d53597b2fc359bb063db8ade2bcffe7997351d0756e9d50" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ad2d4653bf5ca36ae797b1f4bb4dbddb60ce49ca4aed8a2ce4829f60425b80" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - -[[package]] -name = "byteorder" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" - -[[package]] -name = "cc" -version = "1.0.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79c2681d6594606957bbb8631c4b90a7fcaaa72cdb714743a437b156d6a7eedd" -dependencies = [ - "jobserver", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "chrono" -version = "0.4.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" -dependencies = [ - "libc", - "num-integer", - "num-traits", - "time", - "winapi", -] - -[[package]] -name = "comfy-table" -version = "4.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11e95a3e867422fd8d04049041f5671f94d53c32a9dcd82e2be268714942f3f3" -dependencies = [ - "strum", - "strum_macros", - "unicode-width", -] - -[[package]] -name = "constant_time_eq" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" - -[[package]] -name = "cpufeatures" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" -dependencies = [ - "libc", -] - -[[package]] -name = "crc32fast" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "crypto-mac" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" -dependencies = [ - "generic-array", - "subtle", -] - -[[package]] -name = "csv" -version = "1.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" -dependencies = [ - "bstr", - "csv-core", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" -dependencies = [ - "memchr", -] - 
-[[package]] -name = "datafusion" -version = "5.1.0" -dependencies = [ - "ahash", - "arrow", - "async-trait", - "blake2", - "blake3", - "chrono", - "futures", - "hashbrown", - "lazy_static", - "log", - "md-5", - "num_cpus", - "ordered-float 2.8.0", - "parquet", - "paste 1.0.5", - "pin-project-lite", - "pyo3", - "rand 0.8.4", - "regex", - "sha2", - "smallvec", - "sqlparser", - "tokio", - "tokio-stream", - "unicode-segmentation", -] - -[[package]] -name = "datafusion-python" -version = "0.3.0" -dependencies = [ - "datafusion", - "pyo3", - "rand 0.7.3", - "tokio", - "uuid", -] - -[[package]] -name = "digest" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" -dependencies = [ - "generic-array", -] - -[[package]] -name = "flatbuffers" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4c5738bcd7fad10315029c50026f83c9da5e4a21f8ed66826f43e0e2bde5f6" -dependencies = [ - "bitflags", - "smallvec", - "thiserror", -] - -[[package]] -name = "flate2" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" -dependencies = [ - "cfg-if", - "crc32fast", - "libc", - "miniz_oxide", -] - -[[package]] -name = "futures" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" - -[[package]] -name = "futures-executor" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377" - -[[package]] -name = "futures-macro" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb" -dependencies = [ - "autocfg", - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11" - -[[package]] -name = "futures-task" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" - -[[package]] -name = "futures-util" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" -dependencies = [ - "autocfg", - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "proc-macro-hack", - "proc-macro-nested", - "slab", -] - -[[package]] -name = "generic-array" -version = "0.14.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501466ecc8a30d1d3b7fc9229b122b2ce8ed6e9d9223f1138d4babb253e51817" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - -[[package]] -name = "getrandom" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.10.2+wasi-snapshot-preview1", -] - -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" -dependencies = [ - "ahash", -] - -[[package]] -name = "heck" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" -dependencies = [ - "unicode-segmentation", -] - -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "indexmap" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" -dependencies = [ - "autocfg", - "hashbrown", -] - -[[package]] -name = "indoc" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" -dependencies = [ - "indoc-impl", - "proc-macro-hack", -] - -[[package]] -name = "indoc-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" -dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", - "unindent", -] - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "integer-encoding" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48dc51180a9b377fd75814d0cc02199c20f8e99433d6762f650d39cdbbd3b56f" - -[[package]] -name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "jobserver" -version = "0.1.24" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa" -dependencies = [ - "libc", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "lexical-core" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a3926d8f156019890be4abe5fd3785e0cff1001e06f59c597641fd513a5a284" -dependencies = [ - "lexical-parse-float", - "lexical-parse-integer", - "lexical-util", - "lexical-write-float", - "lexical-write-integer", -] - -[[package]] -name = "lexical-parse-float" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4d066d004fa762d9da995ed21aa8845bb9f6e4265f540d716fb4b315197bf0e" -dependencies = [ - "lexical-parse-integer", - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-parse-integer" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c92badda8cc0fc4f3d3cc1c30aaefafb830510c8781ce4e8669881f3ed53ac" -dependencies = [ - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-util" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff669ccaae16ee33af90dc51125755efed17f1309626ba5c12052512b11e291" -dependencies = [ - "static_assertions", -] - -[[package]] -name = "lexical-write-float" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b5186948c7b297abaaa51560f2581dae625e5ce7dfc2d8fdc56345adb6dc576" -dependencies = [ - "lexical-util", - "lexical-write-integer", - "static_assertions", -] - -[[package]] -name = "lexical-write-integer" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ece956492e0e40fd95ef8658a34d53a3b8c2015762fdcaaff2167b28de1f56ef" -dependencies = [ - "lexical-util", - "static_assertions", -] - -[[package]] -name = "libc" -version = "0.2.105" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" - -[[package]] -name = "lock_api" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" -dependencies = [ - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "lz4" -version = "1.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aac20ed6991e01bf6a2e68cc73df2b389707403662a8ba89f68511fb340f724c" -dependencies = [ - "libc", - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dca79aa95d8b3226213ad454d328369853be3a1382d89532a854f4d69640acae" -dependencies = [ - "cc", - "libc", -] - -[[package]] -name = "md-5" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" -dependencies = [ - "block-buffer", - "digest", - "opaque-debug", -] - -[[package]] -name = "memchr" -version = "2.4.1" -source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" - -[[package]] -name = "miniz_oxide" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" -dependencies = [ - "adler", - "autocfg", -] - -[[package]] -name = "multiversion" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "025c962a3dd3cc5e0e520aa9c612201d127dcdf28616974961a649dca64f5373" -dependencies = [ - "multiversion-macros", -] - -[[package]] -name = "multiversion-macros" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8a3e2bde382ebf960c1f3e79689fa5941625fe9bf694a1cb64af3e85faff3af" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "num" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] -name = "num-bigint" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74e768dff5fb39a41b3bcd30bb25cf989706c90d028d1ad71971987aa309d535" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-complex" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" -dependencies = [ - "autocfg", - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "once_cell" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" - -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - -[[package]] -name = "ordered-float" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3305af35278dd29f46fcdd139e0b1fbfae2153f0e5928b39b035542dd31e37b7" -dependencies = [ - "num-traits", -] - -[[package]] -name = "ordered-float" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97c9d06878b3a851e8026ef94bf7fef9ba93062cd412601da4d9cf369b1cc62d" -dependencies = [ - "num-traits", -] - -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall", - "smallvec", - "winapi", -] - -[[package]] -name = "parquet" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d263b9b59ba260518de9e57bd65931c3f765fea0fabacfe84f40d6fde38e841a" -dependencies = [ - "arrow", - "base64", - "brotli", - "byteorder", - "chrono", - "flate2", - "lz4", - "num-bigint", - "parquet-format", - "rand 0.8.4", - "snap", - "thrift", - "zstd", -] - -[[package]] -name = "parquet-format" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5bc6b23543b5dedc8f6cce50758a35e5582e148e0cfa26bd0cacd569cda5b71" -dependencies = [ - "thrift", -] - -[[package]] -name = "paste" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" -dependencies = [ - "paste-impl", - "proc-macro-hack", -] - -[[package]] -name = "paste" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58" - -[[package]] -name = "paste-impl" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" -dependencies = [ - "proc-macro-hack", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "ppv-lite86" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed0cfbc8191465bed66e1718596ee0b0b35d5ee1f41c5df2189d0fe8bde535ba" - -[[package]] -name = "proc-macro-hack" -version = "0.5.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" - -[[package]] -name = "proc-macro-nested" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" - -[[package]] -name = "proc-macro2" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70" -dependencies = [ - "unicode-xid", -] - -[[package]] -name = "pyo3" -version = "0.14.5" -source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35100f9347670a566a67aa623369293703322bb9db77d99d7df7313b575ae0c8" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "parking_lot", - "paste 0.1.18", - "pyo3-build-config", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d12961738cacbd7f91b7c43bc25cfeeaa2698ad07a04b3be0aa88b950865738f" -dependencies = [ - "once_cell", -] - -[[package]] -name = "pyo3-macros" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0bc5215d704824dfddddc03f93cb572e1155c68b6761c37005e1c288808ea8" -dependencies = [ - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71623fc593224afaab918aa3afcaf86ed2f43d34f6afde7f3922608f253240df" -dependencies = [ - "proc-macro2", - "pyo3-build-config", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc 0.2.0", -] - -[[package]] -name = "rand" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.3", - "rand_hc 0.3.1", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.3", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - -[[package]] -name = "rand_core" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" -dependencies = [ - "getrandom 0.2.3", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - -[[package]] -name = "rand_hc" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7" -dependencies = [ - "rand_core 0.6.3", -] - -[[package]] -name = "redox_syscall" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" -dependencies = [ - "bitflags", -] - -[[package]] -name = "regex" -version = "1.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" - -[[package]] -name = "regex-syntax" -version = "0.6.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" - -[[package]] -name = "ryu" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" - -[[package]] -name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" - -[[package]] -name = "serde" -version = "1.0.130" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f12d06de37cf59146fbdecab66aa99f9fe4f78722e3607577a5375d66bd0c913" - -[[package]] -name = "serde_derive" -version = "1.0.130" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7bc1a1ab1961464eae040d96713baa5a724a8152c1222492465b54322ec508b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.68" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f690853975602e1bfe1ccbf50504d67174e3bcf340f23b5ea9992e0587a52d8" -dependencies = [ - "indexmap", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sha2" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" -dependencies = [ - "block-buffer", - "cfg-if", - "cpufeatures", - "digest", - "opaque-debug", -] - -[[package]] -name = "slab" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" - -[[package]] -name = "smallvec" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" - -[[package]] -name = "snap" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" - -[[package]] -name = "sqlparser" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760e624412a15d5838ae04fad01037beeff1047781431d74360cddd6b3c1c784" -dependencies = [ - "log", -] - -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - -[[package]] -name = "strum" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf86bbcfd1fa9670b7a129f64fc0c9fcbbfe4f1bc4210e9e98fe71ffc12cde2" - -[[package]] -name = "strum_macros" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d06aaeeee809dbc59eb4556183dd927df67db1540de5be8d3ec0b6636358a5ec" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "subtle" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" - -[[package]] -name = "syn" -version = "1.0.80" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d010a1623fbd906d51d650a9916aaefc05ffa0e4053ff7fe601167f3e715d194" -dependencies = [ - "proc-macro2", - "quote", - "unicode-xid", -] - -[[package]] -name = "thiserror" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "threadpool" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" -dependencies = [ - "num_cpus", -] - -[[package]] -name = "thrift" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6d965454947cc7266d22716ebfd07b18d84ebaf35eec558586bbb2a8cb6b5b" -dependencies = [ - "byteorder", - "integer-encoding", - "log", - "ordered-float 1.1.1", - "threadpool", -] - -[[package]] -name = "time" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "tokio" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" -dependencies = [ - "autocfg", - "num_cpus", - "pin-project-lite", - "tokio-macros", -] - -[[package]] -name = "tokio-macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-stream" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "typenum" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b63708a265f51345575b27fe43f9500ad611579e764c79edbc2037b1121959ec" - -[[package]] -name = "unicode-segmentation" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b" - -[[package]] -name = "unicode-width" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" - -[[package]] -name = "unicode-xid" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" - -[[package]] -name = "unindent" -version = 
"0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7" - -[[package]] -name = "uuid" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" -dependencies = [ - "getrandom 0.2.3", -] - -[[package]] -name = "version_check" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" - -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - -[[package]] -name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "zstd" -version = "0.9.0+zstd.1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07749a5dc2cb6b36661290245e350f15ec3bbb304e493db54a1d354480522ccd" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "4.1.1+zstd.1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91c90f2c593b003603e5e0493c837088df4469da25aafff8bce42ba48caf079" -dependencies = [ - "libc", - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "1.6.1+zstd.1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "615120c7a2431d16cf1cf979e7fc31ba7a5b5e5707b29c8a99e5dbf8a8392a33" -dependencies = [ - "cc", - "libc", -] diff --git a/python/Cargo.toml b/python/Cargo.toml deleted file mode 100644 index 974a6140644e..000000000000 --- a/python/Cargo.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -[package] -name = "datafusion-python" -version = "0.4.0" -homepage = "https://github.com/apache/arrow" -repository = "https://github.com/apache/arrow" -authors = ["Apache Arrow "] -description = "Build and run queries against data" -readme = "README.md" -license = "Apache-2.0" -edition = "2021" -rust-version = "1.57" - -[dependencies] -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } -rand = "0.7" -pyo3 = { version = "0.14", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { path = "../datafusion", version = "6.0.0", features = ["pyarrow"] } -uuid = { version = "0.8", features = ["v4"] } - -[lib] -name = "_internal" -crate-type = ["cdylib"] - -[package.metadata.maturin] -name = "datafusion._internal" - -[profile.release] -lto = true -codegen-units = 1 diff --git a/python/LICENSE.txt b/python/LICENSE.txt deleted file mode 100644 index d64569567334..000000000000 --- a/python/LICENSE.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/python/README.md b/python/README.md index 5979803dc31c..b3a2a3061ce9 100644 --- a/python/README.md +++ b/python/README.md @@ -17,161 +17,6 @@ under the License. --> -## DataFusion in Python +# DataFusion in Python -This is a Python library that binds to [Apache Arrow](https://arrow.apache.org/) in-memory query engine [DataFusion](https://github.com/apache/arrow-datafusion). - -Like pyspark, it allows you to build a plan through SQL or a DataFrame API against in-memory data, parquet or CSV files, run it in a multi-threaded environment, and obtain the result back in Python. - -It also allows you to use UDFs and UDAFs for complex operations. - -The major advantage of this library over other execution engines is that this library achieves zero-copy between Python and its execution engine: there is no cost in using UDFs, UDAFs, and collecting the results to Python apart from having to lock the GIL when running those operations. - -Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which makes strong assumptions about thread safety and lack of memory leaks. - -Technically, zero-copy is achieved via the [c data interface](https://arrow.apache.org/docs/format/CDataInterface.html). - -## How to use it - -Simple usage: - -```python -import datafusion -import pyarrow - -# an alias -f = datafusion.functions - -# create a context -ctx = datafusion.ExecutionContext() - -# create a RecordBatch and a new DataFrame from it -batch = pyarrow.RecordBatch.from_arrays( - [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])], - names=["a", "b"], -) -df = ctx.create_dataframe([[batch]]) - -# create a new statement -df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), -) - -# execute and collect the first (and only) batch -result = df.collect()[0] - -assert result.column(0) == pyarrow.array([5, 7, 9]) -assert result.column(1) == pyarrow.array([-3, -3, -3]) -``` - -### UDFs - -```python -def is_null(array: pyarrow.Array) -> pyarrow.Array: - return array.is_null() - -udf = f.udf(is_null, [pyarrow.int64()], pyarrow.bool_()) - -df = df.select(udf(f.col("a"))) -``` - -### UDAF - -```python -import pyarrow -import pyarrow.compute - - -class Accumulator: - """ - Interface of a user-defined accumulation. 
- """ - def __init__(self): - self._sum = pyarrow.scalar(0.0) - - def to_scalars(self) -> [pyarrow.Scalar]: - return [self._sum] - - def update(self, values: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py()) - - def merge(self, states: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py()) - - def evaluate(self) -> pyarrow.Scalar: - return self._sum - - -df = ... - -udaf = f.udaf(Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]) - -df = df.aggregate( - [], - [udaf(f.col("a"))] -) -``` - -## How to install (from pip) - -```bash -pip install datafusion -# or -python -m pip install datafusion -``` - -## How to develop - -This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). - -Bootstrap: - -```bash -# fetch this repo -git clone git@github.com:apache/arrow-datafusion.git -# change to python directory -cd arrow-datafusion/python -# prepare development environment (used to build wheel / install in development) -python3 -m venv venv -# activate the venv -source venv/bin/activate -# update pip itself if necessary -python -m pip install -U pip -# if python -V gives python 3.7 -python -m pip install -r requirements-37.txt -# if python -V gives python 3.8/3.9/3.10 -python -m pip install -r requirements.txt -``` - -Whenever rust code changes (your changes or via `git pull`): - -```bash -# make sure you activate the venv using "source venv/bin/activate" first -maturin develop -python -m pytest -``` - -## How to update dependencies - -To change test dependencies, change the `requirements.in` and run - -```bash -# install pip-tools (this can be done only once), also consider running in venv -python -m pip install pip-tools - -# change requirements.in and then run -python -m piptools compile --generate-hashes -o requirements-37.txt -# or run this is you are on python 3.8/3.9/3.10 -python -m piptools compile --generate-hashes -o requirements.txt -``` - -To update dependencies, run with `-U` - -```bash -python -m piptools compile -U --generate-hashes -o requirements-310.txt -``` - -More details [here](https://github.com/jazzband/pip-tools) +This directory is now moved to its [dedicated repository](https://github.com/datafusion-contrib/datafusion-python). diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py deleted file mode 100644 index 0a25592f80ae..000000000000 --- a/python/datafusion/__init__.py +++ /dev/null @@ -1,111 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -from abc import ABCMeta, abstractmethod -from typing import List - -import pyarrow as pa - -from ._internal import ( - AggregateUDF, - DataFrame, - ExecutionContext, - Expression, - ScalarUDF, -) - - -__all__ = [ - "DataFrame", - "ExecutionContext", - "Expression", - "AggregateUDF", - "ScalarUDF", - "column", - "literal", -] - - -class Accumulator(metaclass=ABCMeta): - @abstractmethod - def state(self) -> List[pa.Scalar]: - pass - - @abstractmethod - def update(self, values: pa.Array) -> None: - pass - - @abstractmethod - def merge(self, states: pa.Array) -> None: - pass - - @abstractmethod - def evaluate(self) -> pa.Scalar: - pass - - -def column(value): - return Expression.column(value) - - -col = column - - -def literal(value): - if not isinstance(value, pa.Scalar): - value = pa.scalar(value) - return Expression.literal(value) - - -lit = literal - - -def udf(func, input_types, return_type, volatility, name=None): - """ - Create a new User Defined Function - """ - if not callable(func): - raise TypeError("`func` argument must be callable") - if name is None: - name = func.__qualname__ - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - -def udaf(accum, input_type, return_type, state_type, volatility, name=None): - """ - Create a new User Defined Aggregate Function - """ - if not issubclass(accum, Accumulator): - raise TypeError( - "`accum` must implement the abstract base class Accumulator" - ) - if name is None: - name = accum.__qualname__ - return AggregateUDF( - name=name, - accumulator=accum, - input_type=input_type, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py deleted file mode 100644 index 782ecba22191..000000000000 --- a/python/datafusion/functions.py +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from ._internal import functions - - -def __getattr__(name): - return getattr(functions, name) diff --git a/python/datafusion/tests/__init__.py b/python/datafusion/tests/__init__.py deleted file mode 100644 index 13a83393a912..000000000000 --- a/python/datafusion/tests/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/datafusion/tests/generic.py b/python/datafusion/tests/generic.py deleted file mode 100644 index 1f984a40adaa..000000000000 --- a/python/datafusion/tests/generic.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime - -import numpy as np -import pyarrow as pa -import pyarrow.csv - -# used to write parquet files -import pyarrow.parquet as pq - - -def data(): - np.random.seed(1) - data = np.concatenate( - [ - np.random.normal(0, 0.01, size=50), - np.random.normal(50, 0.01, size=50), - ] - ) - return pa.array(data) - - -def data_with_nans(): - np.random.seed(0) - data = np.random.normal(0, 0.01, size=50) - mask = np.random.randint(0, 2, size=50) - data[mask == 0] = np.NaN - return data - - -def data_datetime(f): - data = [ - datetime.datetime.now(), - datetime.datetime.now() - datetime.timedelta(days=1), - datetime.datetime.now() + datetime.timedelta(days=1), - ] - return pa.array( - data, type=pa.timestamp(f), mask=np.array([False, True, False]) - ) - - -def data_date32(): - data = [ - datetime.date(2000, 1, 1), - datetime.date(1980, 1, 1), - datetime.date(2030, 1, 1), - ] - return pa.array( - data, type=pa.date32(), mask=np.array([False, True, False]) - ) - - -def data_timedelta(f): - data = [ - datetime.timedelta(days=100), - datetime.timedelta(days=1), - datetime.timedelta(seconds=1), - ] - return pa.array( - data, type=pa.duration(f), mask=np.array([False, True, False]) - ) - - -def data_binary_other(): - return np.array([1, 0, 0], dtype="u4") - - -def write_parquet(path, data): - table = pa.Table.from_arrays([data], names=["a"]) - pq.write_table(table, path) - return str(path) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py deleted file mode 100644 index d539c44585a6..000000000000 --- a/python/datafusion/tests/test_aggregation.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext, column -from datafusion import functions as f - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 4, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) - - -def test_built_in_aggregation(df): - col_a = column("a") - col_b = column("b") - df = df.aggregate( - [], - [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], - ) - result = df.collect()[0] - assert result.column(0) == pa.array([3]) - assert result.column(1) == pa.array([1]) - assert result.column(2) == pa.array([3], type=pa.uint64()) - assert result.column(3) == pa.array([2], type=pa.uint64()) diff --git a/python/datafusion/tests/test_catalog.py b/python/datafusion/tests/test_catalog.py deleted file mode 100644 index 2e64a810a718..000000000000 --- a/python/datafusion/tests/test_catalog.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext - - -@pytest.fixture -def ctx(): - return ExecutionContext() - - -@pytest.fixture -def database(ctx, tmp_path): - path = tmp_path / "test.csv" - - table = pa.Table.from_arrays( - [ - [1, 2, 3, 4], - ["a", "b", "c", "d"], - [1.1, 2.2, 3.3, 4.4], - ], - names=["int", "str", "float"], - ) - pa.csv.write_csv(table, path) - - ctx.register_csv("csv", path) - ctx.register_csv("csv1", str(path)) - ctx.register_csv( - "csv2", - path, - has_header=True, - delimiter=",", - schema_infer_max_records=10, - ) - - -def test_basic(ctx, database): - with pytest.raises(KeyError): - ctx.catalog("non-existent") - - default = ctx.catalog() - assert default.names() == ["public"] - - for database in [default.database("public"), default.database()]: - assert database.names() == {"csv1", "csv", "csv2"} - - table = database.table("csv") - assert table.kind == "physical" - assert table.schema == pa.schema( - [ - pa.field("int", pa.int64(), nullable=False), - pa.field("str", pa.string(), nullable=False), - pa.field("float", pa.float64(), nullable=False), - ] - ) diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py deleted file mode 100644 index 60beea4a01be..000000000000 --- a/python/datafusion/tests/test_context.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext - - -@pytest.fixture -def ctx(): - return ExecutionContext() - - -def test_register_record_batches(ctx): - # create a RecordBatch and register it as memtable - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - - ctx.register_record_batches("t", [[batch]]) - - assert ctx.tables() == {"t"} - - result = ctx.sql("SELECT a+b, a-b FROM t").collect() - - assert result[0].column(0) == pa.array([5, 7, 9]) - assert result[0].column(1) == pa.array([-3, -3, -3]) - - -def test_create_dataframe_registers_unique_table_name(ctx): - # create a RecordBatch and register it as memtable - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - - df = ctx.create_dataframe([[batch]]) - tables = list(ctx.tables()) - - assert df - assert len(tables) == 1 - assert len(tables[0]) == 33 - assert tables[0].startswith("c") - # ensure that the rest of the table name contains - # only hexadecimal numbers - for c in tables[0][1:]: - assert c in "0123456789abcdef" diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py deleted file mode 100644 index 9040b6f807f9..000000000000 --- a/python/datafusion/tests/test_dataframe.py +++ /dev/null @@ -1,155 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pyarrow as pa -import pytest - -from datafusion import functions as f -from datafusion import DataFrame, ExecutionContext, column, literal, udf - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - - return ctx.create_dataframe([[batch]]) - - -def test_select(df): - df = df.select( - column("a") + column("b"), - column("a") - column("b"), - ) - - # execute and collect the first (and only) batch - result = df.collect()[0] - - assert result.column(0) == pa.array([5, 7, 9]) - assert result.column(1) == pa.array([-3, -3, -3]) - - -def test_filter(df): - df = df.select( - column("a") + column("b"), - column("a") - column("b"), - ).filter(column("a") > literal(2)) - - # execute and collect the first (and only) batch - result = df.collect()[0] - - assert result.column(0) == pa.array([9]) - assert result.column(1) == pa.array([-3]) - - -def test_sort(df): - df = df.sort(column("b").sort(ascending=False)) - - table = pa.Table.from_batches(df.collect()) - expected = {"a": [3, 2, 1], "b": [6, 5, 4]} - - assert table.to_pydict() == expected - - -def test_limit(df): - df = df.limit(1) - - # execute and collect the first (and only) batch - result = df.collect()[0] - - assert len(result.column(0)) == 1 - assert len(result.column(1)) == 1 - - -def test_udf(df): - # is_null is a pa function over arrays - is_null = udf( - lambda x: x.is_null(), - [pa.int64()], - pa.bool_(), - volatility="immutable", - ) - - df = df.select(is_null(column("a"))) - result = df.collect()[0].column(0) - - assert result == pa.array([False, False, False]) - - -def test_join(): - ctx = ExecutionContext() - - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - df = ctx.create_dataframe([[batch]]) - - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2]), pa.array([8, 10])], - names=["a", "c"], - ) - df1 = ctx.create_dataframe([[batch]]) - - df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") - df = df.sort(column("a").sort(ascending=True)) - table = pa.Table.from_batches(df.collect()) - - expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} - assert table.to_pydict() == expected - - -def test_window_lead(df): - df = df.select( - column("a"), - f.alias( - f.window( - "lead", [column("b")], order_by=[f.order_by(column("b"))] - ), - "a_next", - ), - ) - - table = pa.Table.from_batches(df.collect()) - - expected = {"a": [1, 2, 3], "a_next": [5, 6, None]} - assert table.to_pydict() == expected - - -def test_get_dataframe(tmp_path): - ctx = ExecutionContext() - - path = tmp_path / "test.csv" - table = pa.Table.from_arrays( - [ - [1, 2, 3, 4], - ["a", "b", "c", "d"], - [1.1, 2.2, 3.3, 4.4], - ], - names=["int", "str", "float"], - ) - pa.csv.write_csv(table, path) - - ctx.register_csv("csv", path) - - df = ctx.table("csv") - assert isinstance(df, DataFrame) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py deleted file mode 100644 index 84718eaf0ce6..000000000000 --- a/python/datafusion/tests/test_functions.py +++ /dev/null @@ -1,219 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext, column -from datafusion import functions as f -from datafusion import literal - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) - - -def test_literal(df): - df = df.select( - literal(1), - literal("1"), - literal("OK"), - literal(3.14), - literal(True), - literal(b"hello world"), - ) - result = df.collect() - assert len(result) == 1 - result = result[0] - assert result.column(0) == pa.array([1] * 3) - assert result.column(1) == pa.array(["1"] * 3) - assert result.column(2) == pa.array(["OK"] * 3) - assert result.column(3) == pa.array([3.14] * 3) - assert result.column(4) == pa.array([True] * 3) - assert result.column(5) == pa.array([b"hello world"] * 3) - - -def test_lit_arith(df): - """ - Test literals with arithmetic operations - """ - df = df.select( - literal(1) + column("b"), f.concat(column("a"), literal("!")) - ) - result = df.collect() - assert len(result) == 1 - result = result[0] - assert result.column(0) == pa.array([5, 6, 7]) - assert result.column(1) == pa.array(["Hello!", "World!", "!!"]) - - -def test_math_functions(): - ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([0.1, -0.7, 0.55])], names=["value"] - ) - df = ctx.create_dataframe([[batch]]) - - values = np.array([0.1, -0.7, 0.55]) - col_v = column("value") - df = df.select( - f.abs(col_v), - f.sin(col_v), - f.cos(col_v), - f.tan(col_v), - f.asin(col_v), - f.acos(col_v), - f.exp(col_v), - f.ln(col_v + literal(pa.scalar(1))), - f.log2(col_v + literal(pa.scalar(1))), - f.log10(col_v + literal(pa.scalar(1))), - f.random(), - ) - batches = df.collect() - assert len(batches) == 1 - result = batches[0] - - np.testing.assert_array_almost_equal(result.column(0), np.abs(values)) - np.testing.assert_array_almost_equal(result.column(1), np.sin(values)) - np.testing.assert_array_almost_equal(result.column(2), np.cos(values)) - np.testing.assert_array_almost_equal(result.column(3), np.tan(values)) - np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values)) - np.testing.assert_array_almost_equal(result.column(5), np.arccos(values)) - np.testing.assert_array_almost_equal(result.column(6), np.exp(values)) - np.testing.assert_array_almost_equal( - result.column(7), np.log(values + 1.0) - ) - np.testing.assert_array_almost_equal( - result.column(8), np.log2(values + 1.0) - ) - np.testing.assert_array_almost_equal( - result.column(9), np.log10(values + 1.0) - ) - np.testing.assert_array_less(result.column(10), np.ones_like(values)) - - -def test_string_functions(df): - df = df.select(f.md5(column("a")), f.lower(column("a"))) - result = df.collect() - 
assert len(result) == 1 - result = result[0] - assert result.column(0) == pa.array( - [ - "8b1a9953c4611296a827abf8c47804d7", - "f5a7924e621e84c9280a9a27e1bcb7f6", - "9033e0e305f247c0c3c80d0c7848c8b3", - ] - ) - assert result.column(1) == pa.array(["hello", "world", "!"]) - - -def test_hash_functions(df): - exprs = [ - f.digest(column("a"), literal(m)) - for m in ("md5", "sha256", "sha512", "blake2s", "blake3") - ] - df = df.select(*exprs) - result = df.collect() - assert len(result) == 1 - result = result[0] - b = bytearray.fromhex - assert result.column(0) == pa.array( - [ - b("8B1A9953C4611296A827ABF8C47804D7"), - b("F5A7924E621E84C9280A9A27E1BCB7F6"), - b("9033E0E305F247C0C3C80D0C7848C8B3"), - ] - ) - assert result.column(1) == pa.array( - [ - b( - "185F8DB32271FE25F561A6FC938B2E26" - "4306EC304EDA518007D1764826381969" - ), - b( - "78AE647DC5544D227130A0682A51E30B" - "C7777FBB6D8A8F17007463A3ECD1D524" - ), - b( - "BB7208BC9B5D7C04F1236A82A0093A5E" - "33F40423D5BA8D4266F7092C3BA43B62" - ), - ] - ) - assert result.column(2) == pa.array( - [ - b( - "3615F80C9D293ED7402687F94B22D58E" - "529B8CC7916F8FAC7FDDF7FBD5AF4CF7" - "77D3D795A7A00A16BF7E7F3FB9561EE9" - "BAAE480DA9FE7A18769E71886B03F315" - ), - b( - "8EA77393A42AB8FA92500FB077A9509C" - "C32BC95E72712EFA116EDAF2EDFAE34F" - "BB682EFDD6C5DD13C117E08BD4AAEF71" - "291D8AACE2F890273081D0677C16DF0F" - ), - b( - "3831A6A6155E509DEE59A7F451EB3532" - "4D8F8F2DF6E3708894740F98FDEE2388" - "9F4DE5ADB0C5010DFB555CDA77C8AB5D" - "C902094C52DE3278F35A75EBC25F093A" - ), - ] - ) - assert result.column(3) == pa.array( - [ - b( - "F73A5FBF881F89B814871F46E26AD3FA" - "37CB2921C5E8561618639015B3CCBB71" - ), - b( - "B792A0383FB9E7A189EC150686579532" - "854E44B71AC394831DAED169BA85CCC5" - ), - b( - "27988A0E51812297C77A433F63523334" - "6AEE29A829DCF4F46E0F58F402C6CFCB" - ), - ] - ) - assert result.column(4) == pa.array( - [ - b( - "FBC2B0516EE8744D293B980779178A35" - "08850FDCFE965985782C39601B65794F" - ), - b( - "BF73D18575A736E4037D45F9E316085B" - "86C19BE6363DE6AA789E13DEAACC1C4E" - ), - b( - "C8D11B9F7237E4034ADBCD2005735F9B" - "C4C597C75AD89F4492BEC8F77D15F7EB" - ), - ] - ) diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py deleted file mode 100644 index 423800248a5c..000000000000 --- a/python/datafusion/tests/test_imports.py +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pytest - -import datafusion -from datafusion import ( - AggregateUDF, - DataFrame, - ExecutionContext, - Expression, - ScalarUDF, - functions, -) - - -def test_import_datafusion(): - assert datafusion.__name__ == "datafusion" - - -def test_class_module_is_datafusion(): - for klass in [ - ExecutionContext, - Expression, - DataFrame, - ScalarUDF, - AggregateUDF, - ]: - assert klass.__module__ == "datafusion" - - -def test_import_from_functions_submodule(): - from datafusion.functions import abs, sin # noqa - - assert functions.abs is abs - assert functions.sin is sin - - msg = "cannot import name 'foobar' from 'datafusion.functions'" - with pytest.raises(ImportError, match=msg): - from datafusion.functions import foobar # noqa - - -def test_classes_are_inheritable(): - class MyExecContext(ExecutionContext): - pass - - class MyExpression(Expression): - pass - - class MyDataFrame(DataFrame): - pass diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py deleted file mode 100644 index 23f20079f0da..000000000000 --- a/python/datafusion/tests/test_sql.py +++ /dev/null @@ -1,250 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext, udf - -from . 
import generic as helpers - - -@pytest.fixture -def ctx(): - return ExecutionContext() - - -def test_no_table(ctx): - with pytest.raises(Exception, match="DataFusion error"): - ctx.sql("SELECT a FROM b").collect() - - -def test_register_csv(ctx, tmp_path): - path = tmp_path / "test.csv" - - table = pa.Table.from_arrays( - [ - [1, 2, 3, 4], - ["a", "b", "c", "d"], - [1.1, 2.2, 3.3, 4.4], - ], - names=["int", "str", "float"], - ) - pa.csv.write_csv(table, path) - - ctx.register_csv("csv", path) - ctx.register_csv("csv1", str(path)) - ctx.register_csv( - "csv2", - path, - has_header=True, - delimiter=",", - schema_infer_max_records=10, - ) - alternative_schema = pa.schema( - [ - ("some_int", pa.int16()), - ("some_bytes", pa.string()), - ("some_floats", pa.float32()), - ] - ) - ctx.register_csv("csv3", path, schema=alternative_schema) - - assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"} - - for table in ["csv", "csv1", "csv2"]: - result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() - result = pa.Table.from_batches(result) - assert result.to_pydict() == {"cnt": [4]} - - result = ctx.sql("SELECT * FROM csv3").collect() - result = pa.Table.from_batches(result) - assert result.schema == alternative_schema - - with pytest.raises( - ValueError, match="Delimiter must be a single character" - ): - ctx.register_csv("csv4", path, delimiter="wrong") - - -def test_register_parquet(ctx, tmp_path): - path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) - ctx.register_parquet("t", path) - assert ctx.tables() == {"t"} - - result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() - result = pa.Table.from_batches(result) - assert result.to_pydict() == {"cnt": [100]} - - -def test_execute(ctx, tmp_path): - data = [1, 1, 2, 2, 3, 11, 12] - - # single column, "a" - path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) - ctx.register_parquet("t", path) - - assert ctx.tables() == {"t"} - - # count - result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() - - expected = pa.array([7], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] - assert result == expected - - # where - expected = pa.array([2], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] - result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a > 10").collect() - assert result == expected - - # group by - results = ctx.sql( - "SELECT CAST(a as int) AS a, COUNT(a) AS cnt FROM t GROUP BY a" - ).collect() - - # group by returns batches - result_keys = [] - result_values = [] - for result in results: - pydict = result.to_pydict() - result_keys.extend(pydict["a"]) - result_values.extend(pydict["cnt"]) - - result_keys, result_values = ( - list(t) for t in zip(*sorted(zip(result_keys, result_values))) - ) - - assert result_keys == [1, 2, 3, 11, 12] - assert result_values == [2, 2, 1, 1, 1] - - # order by - result = ctx.sql( - "SELECT a, CAST(a AS int) AS a_int FROM t ORDER BY a DESC LIMIT 2" - ).collect() - expected_a = pa.array([50.0219, 50.0152], pa.float64()) - expected_cast = pa.array([50, 50], pa.int32()) - expected = [ - pa.RecordBatch.from_arrays([expected_a, expected_cast], ["a", "a_int"]) - ] - np.testing.assert_equal(expected[0].column(1), expected[0].column(1)) - - -def test_cast(ctx, tmp_path): - """ - Verify that we can cast - """ - path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) - ctx.register_parquet("t", path) - - valid_types = [ - "smallint", - "int", - "bigint", - "float(32)", - "float(64)", - "float", - ] - - select 
= ", ".join( - [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)] - ) - - # can execute, which implies that we can cast - ctx.sql(f"SELECT {select} FROM t").collect() - - -@pytest.mark.parametrize( - ("fn", "input_types", "output_type", "input_values", "expected_values"), - [ - ( - lambda x: x, - [pa.float64()], - pa.float64(), - [-1.2, None, 1.2], - [-1.2, None, 1.2], - ), - ( - lambda x: x.is_null(), - [pa.float64()], - pa.bool_(), - [-1.2, None, 1.2], - [False, True, False], - ), - ], -) -def test_udf( - ctx, tmp_path, fn, input_types, output_type, input_values, expected_values -): - # write to disk - path = helpers.write_parquet( - tmp_path / "a.parquet", pa.array(input_values) - ) - ctx.register_parquet("t", path) - - func = udf( - fn, input_types, output_type, name="func", volatility="immutable" - ) - ctx.register_udf(func) - - batches = ctx.sql("SELECT func(a) AS tt FROM t").collect() - result = batches[0].column(0) - - assert result == pa.array(expected_values) - - -_null_mask = np.array([False, True, False]) - - -@pytest.mark.parametrize( - "arr", - [ - pa.array(["a", "b", "c"], pa.utf8(), _null_mask), - pa.array(["a", "b", "c"], pa.large_utf8(), _null_mask), - pa.array([b"1", b"2", b"3"], pa.binary(), _null_mask), - pa.array([b"1111", b"2222", b"3333"], pa.large_binary(), _null_mask), - pa.array([False, True, True], None, _null_mask), - pa.array([0, 1, 2], None), - helpers.data_binary_other(), - helpers.data_date32(), - helpers.data_with_nans(), - # C data interface missing - pytest.param( - pa.array([b"1111", b"2222", b"3333"], pa.binary(4), _null_mask), - marks=pytest.mark.xfail, - ), - pytest.param(helpers.data_datetime("s"), marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("ms"), marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("us"), marks=pytest.mark.xfail), - pytest.param(helpers.data_datetime("ns"), marks=pytest.mark.xfail), - # Not writtable to parquet - pytest.param(helpers.data_timedelta("s"), marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("ms"), marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("us"), marks=pytest.mark.xfail), - pytest.param(helpers.data_timedelta("ns"), marks=pytest.mark.xfail), - ], -) -def test_simple_select(ctx, tmp_path, arr): - path = helpers.write_parquet(tmp_path / "a.parquet", arr) - ctx.register_parquet("t", path) - - batches = ctx.sql("SELECT a AS tt FROM t").collect() - result = batches[0].column(0) - - np.testing.assert_equal(result, arr) diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py deleted file mode 100644 index 2f286ba105dd..000000000000 --- a/python/datafusion/tests/test_udaf.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import List - -import pyarrow as pa -import pyarrow.compute as pc -import pytest - -from datafusion import Accumulator, ExecutionContext, column, udaf - - -class Summarize(Accumulator): - """ - Interface of a user-defined accumulation. - """ - - def __init__(self): - self._sum = pa.scalar(0.0) - - def state(self) -> List[pa.Scalar]: - return [self._sum] - - def update(self, values: pa.Array) -> None: - # Not nice since pyarrow scalars can't be summed yet. - # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - - def merge(self, states: pa.Array) -> None: - # Not nice since pyarrow scalars can't be summed yet. - # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py()) - - def evaluate(self) -> pa.Scalar: - return self._sum - - -class NotSubclassOfAccumulator: - pass - - -class MissingMethods(Accumulator): - def __init__(self): - self._sum = pa.scalar(0) - - def state(self) -> List[pa.Scalar]: - return [self._sum] - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 4, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) - - -def test_errors(df): - with pytest.raises(TypeError): - udaf( - NotSubclassOfAccumulator, - pa.float64(), - pa.float64(), - [pa.float64()], - volatility="immutable", - ) - - accum = udaf( - MissingMethods, - pa.int64(), - pa.int64(), - [pa.int64()], - volatility="immutable", - ) - df = df.aggregate([], [accum(column("a"))]) - - msg = ( - "Can't instantiate abstract class MissingMethods with abstract " - "methods evaluate, merge, update" - ) - with pytest.raises(Exception, match=msg): - df.collect() - - -def test_aggregate(df): - summarize = udaf( - Summarize, - pa.float64(), - pa.float64(), - [pa.float64()], - volatility="immutable", - ) - - df = df.aggregate([], [summarize(column("a"))]) - - # execute and collect the first (and only) batch - result = df.collect()[0] - - assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) - - -def test_group_by(df): - summarize = udaf( - Summarize, - pa.float64(), - pa.float64(), - [pa.float64()], - volatility="immutable", - ) - - df = df.aggregate([column("b")], [summarize(column("a"))]) - - batches = df.collect() - - arrays = [batch.column(1) for batch in batches] - joined = pa.concat_arrays(arrays) - assert joined == pa.array([1.0 + 2.0, 3.0]) diff --git a/python/pyproject.toml b/python/pyproject.toml deleted file mode 100644 index c6ee363497d7..000000000000 --- a/python/pyproject.toml +++ /dev/null @@ -1,55 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -[build-system] -requires = ["maturin>=0.11,<0.12"] -build-backend = "maturin" - -[project] -name = "datafusion" -description = "Build and run queries against data" -readme = "README.md" -license = {file = "LICENSE.txt"} -requires-python = ">=3.6" -keywords = ["datafusion", "dataframe", "rust", "query-engine"] -classifier = [ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "License :: OSI Approved", - "Operating System :: MacOS", - "Operating System :: Microsoft :: Windows", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python", - "Programming Language :: Rust", -] -dependencies = [ - "pyarrow>=1", -] - -[project.urls] -documentation = "https://arrow.apache.org/datafusion/python" -repository = "https://github.com/apache/arrow-datafusion" - -[tool.isort] -profile = "black" diff --git a/python/requirements-37.txt b/python/requirements-37.txt deleted file mode 100644 index e64bebf3201f..000000000000 --- a/python/requirements-37.txt +++ /dev/null @@ -1,329 +0,0 @@ -# -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: -# -# pip-compile --generate-hashes -# -attrs==21.2.0 \ - --hash=sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1 \ - --hash=sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb - # via pytest -black==21.9b0 \ - --hash=sha256:380f1b5da05e5a1429225676655dddb96f5ae8c75bdf91e53d798871b902a115 \ - --hash=sha256:7de4cfc7eb6b710de325712d40125689101d21d25283eed7e9998722cf10eb91 - # via -r requirements.in -click==8.0.3 \ - --hash=sha256:353f466495adaeb40b6b5f592f9f91cb22372351c84caeb068132442a4518ef3 \ - --hash=sha256:410e932b050f5eed773c4cda94de75971c89cdb3155a72a0831139a79e5ecb5b - # via black -flake8==4.0.1 \ - --hash=sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d \ - --hash=sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d - # via -r requirements.in -importlib-metadata==4.2.0 \ - --hash=sha256:057e92c15bc8d9e8109738a48db0ccb31b4d9d5cfbee5a8670879a30be66304b \ - --hash=sha256:b7e52a1f8dec14a75ea73e0891f3060099ca1d8e6a462a4dff11c3e119ea1b31 - # via - # click - # flake8 - # pluggy - # pytest -iniconfig==1.1.1 \ - --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ - --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 - # via pytest -isort==5.9.3 \ - --hash=sha256:9c2ea1e62d871267b78307fe511c0838ba0da28698c5732d54e2790bf3ba9899 \ - --hash=sha256:e17d6e2b81095c9db0a03a8025a957f334d6ea30b26f9ec70805411e5c7c81f2 - # via -r requirements.in -maturin==0.11.5 \ - --hash=sha256:07074778b063a439fdfd5501bd1d1823a216ec5b657d3ecde78fd7f2c4782422 \ - --hash=sha256:1ce666c386ff9c3c2b5d7d3ca4b1f9f675c38d7540ffbda0d5d5bc7d6ddde49a \ - --hash=sha256:20f9c30701c9932ed8026ceaf896fc77ecc76cebd6a182668dbc10ed597f8789 \ - --hash=sha256:3354d030b88c938a33bf407a6c0f79ccdd2cce3e1e3e4a2d0c92dc2e063adc6e \ - --hash=sha256:4191b0b7362b3025096faf126ff15cb682fbff324ac4a6ca18d55bb16e2b759b \ - --hash=sha256:70381be1585cb9fa5c02b83af80ae661aaad959e8aa0fddcfe195b004054bd69 \ - --hash=sha256:7bf96e7586bfdb5b0fadc6d662534b8a41123b33dff084fa383a81ded0ce5334 \ - 
--hash=sha256:ab2b3ccf66f5e0f9c3904d215835337b1bd305e79e3bf53b65bbc80a5755e01b \ - --hash=sha256:b0ac45879a7d624b47d72b093ae3370270894c19779f42aad7568a92951c5d47 \ - --hash=sha256:c2ded8b4ef9210d627bb966bc67661b7db259535f6062afe1ce5605406b50f3f \ - --hash=sha256:d78f24561a5e02f7d119b348b26e5772ad5698a43ca49e8facb9ce77cf273714 - # via -r requirements.in -mccabe==0.6.1 \ - --hash=sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42 \ - --hash=sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f - # via flake8 -mypy==0.910 \ - --hash=sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9 \ - --hash=sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a \ - --hash=sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9 \ - --hash=sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e \ - --hash=sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2 \ - --hash=sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212 \ - --hash=sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b \ - --hash=sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885 \ - --hash=sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150 \ - --hash=sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703 \ - --hash=sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072 \ - --hash=sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457 \ - --hash=sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e \ - --hash=sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0 \ - --hash=sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb \ - --hash=sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97 \ - --hash=sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8 \ - --hash=sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811 \ - --hash=sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6 \ - --hash=sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de \ - --hash=sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504 \ - --hash=sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921 \ - --hash=sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d - # via -r requirements.in -mypy-extensions==0.4.3 \ - --hash=sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d \ - --hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 - # via - # black - # mypy -numpy==1.21.3 \ - --hash=sha256:043e83bfc274649c82a6f09836943e4a4aebe5e33656271c7dbf9621dd58b8ec \ - --hash=sha256:160ccc1bed3a8371bf0d760971f09bfe80a3e18646620e9ded0ad159d9749baa \ - --hash=sha256:188031f833bbb623637e66006cf75e933e00e7231f67e2b45cf8189612bb5dc3 \ - --hash=sha256:28f15209fb535dd4c504a7762d3bc440779b0e37d50ed810ced209e5cea60d96 \ - --hash=sha256:29fb3dcd0468b7715f8ce2c0c2d9bbbaf5ae686334951343a41bd8d155c6ea27 \ - --hash=sha256:2a6ee9620061b2a722749b391c0d80a0e2ae97290f1b32e28d5a362e21941ee4 \ - --hash=sha256:300321e3985c968e3ae7fbda187237b225f3ffe6528395a5b7a5407f73cf093e \ - --hash=sha256:32437f0b275c1d09d9c3add782516413e98cd7c09e6baf4715cbce781fc29912 \ - --hash=sha256:3c09418a14471c7ae69ba682e2428cae5b4420a766659605566c0fa6987f6b7e \ - 
--hash=sha256:49c6249260890e05b8111ebfc391ed58b3cb4b33e63197b2ec7f776e45330721 \ - --hash=sha256:4cc9b512e9fb590797474f58b7f6d1f1b654b3a94f4fa8558b48ca8b3cfc97cf \ - --hash=sha256:508b0b513fa1266875524ba8a9ecc27b02ad771fe1704a16314dc1a816a68737 \ - --hash=sha256:50cd26b0cf6664cb3b3dd161ba0a09c9c1343db064e7c69f9f8b551f5104d654 \ - --hash=sha256:5c4193f70f8069550a1788bd0cd3268ab7d3a2b70583dfe3b2e7f421e9aace06 \ - --hash=sha256:5dfe9d6a4c39b8b6edd7990091fea4f852888e41919d0e6722fe78dd421db0eb \ - --hash=sha256:63571bb7897a584ca3249c86dd01c10bcb5fe4296e3568b2e9c1a55356b6410e \ - --hash=sha256:75621882d2230ab77fb6a03d4cbccd2038511491076e7964ef87306623aa5272 \ - --hash=sha256:75eb7cadc8da49302f5b659d40ba4f6d94d5045fbd9569c9d058e77b0514c9e4 \ - --hash=sha256:88a5d6b268e9ad18f3533e184744acdaa2e913b13148160b1152300c949bbb5f \ - --hash=sha256:8a10968963640e75cc0193e1847616ab4c718e83b6938ae74dea44953950f6b7 \ - --hash=sha256:90bec6a86b348b4559b6482e2b684db4a9a7eed1fa054b86115a48d58fbbf62a \ - --hash=sha256:98339aa9911853f131de11010f6dd94c8cec254d3d1f7261528c3b3e3219f139 \ - --hash=sha256:a99a6b067e5190ac6d12005a4d85aa6227c5606fa93211f86b1dafb16233e57d \ - --hash=sha256:bffa2eee3b87376cc6b31eee36d05349571c236d1de1175b804b348dc0941e3f \ - --hash=sha256:c6c2d535a7beb1f8790aaa98fd089ceab2e3dd7ca48aca0af7dc60e6ef93ffe1 \ - --hash=sha256:cc14e7519fab2a4ed87d31f99c31a3796e4e1fe63a86ebdd1c5a1ea78ebd5896 \ - --hash=sha256:dd0482f3fc547f1b1b5d6a8b8e08f63fdc250c58ce688dedd8851e6e26cff0f3 \ - --hash=sha256:dde972a1e11bb7b702ed0e447953e7617723760f420decb97305e66fb4afc54f \ - --hash=sha256:e54af82d68ef8255535a6cdb353f55d6b8cf418a83e2be3569243787a4f4866f \ - --hash=sha256:e606e6316911471c8d9b4618e082635cfe98876007556e89ce03d52ff5e8fcf0 \ - --hash=sha256:f41b018f126aac18583956c54544db437f25c7ee4794bcb23eb38bef8e5e192a \ - --hash=sha256:f8f4625536926a155b80ad2bbff44f8cc59e9f2ad14cdda7acf4c135b4dc8ff2 \ - --hash=sha256:fe52dbe47d9deb69b05084abd4b0df7abb39a3c51957c09f635520abd49b29dd - # via - # -r requirements.in - # pandas - # pyarrow -packaging==21.0 \ - --hash=sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7 \ - --hash=sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14 - # via pytest -pandas==1.3.4 \ - --hash=sha256:003ba92db58b71a5f8add604a17a059f3068ef4e8c0c365b088468d0d64935fd \ - --hash=sha256:10e10a2527db79af6e830c3d5842a4d60383b162885270f8cffc15abca4ba4a9 \ - --hash=sha256:22808afb8f96e2269dcc5b846decacb2f526dd0b47baebc63d913bf847317c8f \ - --hash=sha256:2d1dc09c0013d8faa7474574d61b575f9af6257ab95c93dcf33a14fd8d2c1bab \ - --hash=sha256:35c77609acd2e4d517da41bae0c11c70d31c87aae8dd1aabd2670906c6d2c143 \ - --hash=sha256:372d72a3d8a5f2dbaf566a5fa5fa7f230842ac80f29a931fb4b071502cf86b9a \ - --hash=sha256:42493f8ae67918bf129869abea8204df899902287a7f5eaf596c8e54e0ac7ff4 \ - --hash=sha256:5298a733e5bfbb761181fd4672c36d0c627320eb999c59c65156c6a90c7e1b4f \ - --hash=sha256:5ba0aac1397e1d7b654fccf263a4798a9e84ef749866060d19e577e927d66e1b \ - --hash=sha256:a2aa18d3f0b7d538e21932f637fbfe8518d085238b429e4790a35e1e44a96ffc \ - --hash=sha256:a388960f979665b447f0847626e40f99af8cf191bce9dc571d716433130cb3a7 \ - --hash=sha256:a51528192755f7429c5bcc9e80832c517340317c861318fea9cea081b57c9afd \ - --hash=sha256:b528e126c13816a4374e56b7b18bfe91f7a7f6576d1aadba5dee6a87a7f479ae \ - --hash=sha256:c1aa4de4919358c5ef119f6377bc5964b3a7023c23e845d9db7d9016fa0c5b1c \ - --hash=sha256:c2646458e1dce44df9f71a01dc65f7e8fa4307f29e5c0f2f92c97f47a5bf22f5 \ - 
--hash=sha256:d47750cf07dee6b55d8423471be70d627314277976ff2edd1381f02d52dbadf9 \ - --hash=sha256:d99d2350adb7b6c3f7f8f0e5dfb7d34ff8dd4bc0a53e62c445b7e43e163fce63 \ - --hash=sha256:dd324f8ee05925ee85de0ea3f0d66e1362e8c80799eb4eb04927d32335a3e44a \ - --hash=sha256:eaca36a80acaacb8183930e2e5ad7f71539a66805d6204ea88736570b2876a7b \ - --hash=sha256:f567e972dce3bbc3a8076e0b675273b4a9e8576ac629149cf8286ee13c259ae5 \ - --hash=sha256:fe48e4925455c964db914b958f6e7032d285848b7538a5e1b19aeb26ffaea3ec - # via -r requirements.in -pathspec==0.9.0 \ - --hash=sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a \ - --hash=sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1 - # via black -platformdirs==2.4.0 \ - --hash=sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2 \ - --hash=sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d - # via black -pluggy==1.0.0 \ - --hash=sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159 \ - --hash=sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3 - # via pytest -py==1.10.0 \ - --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \ - --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a - # via pytest -pyarrow==6.0.0 \ - --hash=sha256:004185e0babc6f3c3fba6ba4f106e406a0113d0f82bb9ad9a8571a1978c45d04 \ - --hash=sha256:0204e80777ab8f4e9abd3a765a8ec07ed1e3c4630bacda50d2ce212ef0f3826f \ - --hash=sha256:072c1a0fca4509eefd7d018b78542fb7e5c63aaf5698f1c0a6e45628ae17ba44 \ - --hash=sha256:15dc0d673d3f865ca63c877bd7a2eced70b0a08969fb733a28247134b8a1f18b \ - --hash=sha256:1c38263ea438a1666b13372e7565450cfeec32dbcd1c2595749476a58465eaec \ - --hash=sha256:281ce5fa03621d786a9beb514abb09846db7f0221b50eabf543caa24037eaacd \ - --hash=sha256:2d2c681659396c745e4f1988d5dd41dcc3ad557bb8d4a8c2e44030edafc08a91 \ - --hash=sha256:376c4b5f248ae63df21fe15c194e9013753164be2d38f4b3fb8bde63ac5a1958 \ - --hash=sha256:465f87fa0be0b2928b2beeba22b5813a0203fb05d90fd8563eea48e08ecc030e \ - --hash=sha256:477c746ef42c039348a288584800e299456c80c5691401bb9b19aa9c02a427b7 \ - --hash=sha256:5144bd9db2920c7cb566c96462d62443cc239104f94771d110f74393f2fb42a2 \ - --hash=sha256:5408fa8d623e66a0445f3fb0e4027fd219bf99bfb57422d543d7b7876e2c5b55 \ - --hash=sha256:5be62679201c441356d3f2a739895dcc8d4d299f2a6eabcd2163bfb6a898abba \ - --hash=sha256:5c666bc6a1cebf01206e2dc1ab05f25f39f35d3a499e0ef5cd635225e07306ca \ - --hash=sha256:6163d82cca7541774b00503c295fe86a1722820eddb958b57f091bb6f5b0a6db \ - --hash=sha256:6a1d9a2f4ee812ed0bd4182cabef99ea914ac297274f0de086f2488093d284ef \ - --hash=sha256:7a683f71b848eb6310b4ec48c0def55dac839e9994c1ac874c9b2d3d5625def1 \ - --hash=sha256:82fe80309e01acf29e3943a1f6d3c98ec109fe1d356bc1ac37d639bcaadcf684 \ - --hash=sha256:8c23f8cdecd3d9e49f9b0f9a651ae5549d1d32fd4901fb1bdc2d327edfba844f \ - --hash=sha256:8d41dfb09ba9236cca6245f33088eb42f3c54023da281139241e0f9f3b4b754e \ - --hash=sha256:a19e58dfb04e451cd8b7bdec3ac8848373b95dfc53492c9a69789aa9074a3c1b \ - --hash=sha256:a50d2f77b86af38ceabf45617208b9105d20e7a5eebc584e7c8c0acededd82ce \ - --hash=sha256:a5bed4f948c032c40597302e9bdfa65f62295240306976ecbe43a54924c6f94f \ - --hash=sha256:ac941a147d14993987cc8b605b721735a34b3e54d167302501fb4db1ad7382c7 \ - --hash=sha256:b86d175262db1eb46afdceb36d459409eb6f8e532d3dec162f8bf572c7f57623 \ - --hash=sha256:bf3400780c4d3c9cb43b1e8a1aaf2e1b7199a0572d0a645529d2784e4d0d8497 \ - 
--hash=sha256:c7a6e7e0bf8779e9c3428ced85507541f3da9a0675e2f4781d4eb2c7042cbf81 \ - --hash=sha256:cc1d4a70efd583befe92d4ea6f74ed2e0aa31ccdde767cd5cae8e77c65a1c2d4 \ - --hash=sha256:d046dc78a9337baa6415be915c5a16222505233e238a1017f368243c89817eea \ - --hash=sha256:da7860688c33ca88ac05f1a487d32d96d9caa091412496c35f3d1d832145675a \ - --hash=sha256:ddf2e6e3b321adaaf716f2d5af8e92d205a9671e0cb7c0779710a567fd1dd580 \ - --hash=sha256:e81508239a71943759cee272ce625ae208092dd36ef2c6713fccee30bbcf52bb \ - --hash=sha256:ea64a48a85c631eb2a0ea13ccdec5143c85b5897836b16331ee4289d27a57247 \ - --hash=sha256:ed0be080cf595ea15ff1c9ff4097bbf1fcc4b50847d98c0a3c0412fbc6ede7e9 \ - --hash=sha256:fb701ec4a94b92102606d4e88f0b8eba34f09a5ad8e014eaa4af76f42b7f62ae \ - --hash=sha256:fbda7595f24a639bcef3419ecfac17216efacb09f7b0f1b4c4c97f900d65ca0e - # via -r requirements.in -pycodestyle==2.8.0 \ - --hash=sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20 \ - --hash=sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f - # via flake8 -pyflakes==2.4.0 \ - --hash=sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c \ - --hash=sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e - # via flake8 -pyparsing==3.0.3 \ - --hash=sha256:9e3511118010f112a4b4b435ae50e1eaa610cda191acb9e421d60cf5fde83455 \ - --hash=sha256:f8d3fe9fc404576c5164f0f0c4e382c96b85265e023c409c43d48f65da9d60d0 - # via packaging -pytest==6.2.5 \ - --hash=sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89 \ - --hash=sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134 - # via -r requirements.in -python-dateutil==2.8.2 \ - --hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \ - --hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 - # via pandas -pytz==2021.3 \ - --hash=sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c \ - --hash=sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326 - # via pandas -regex==2021.10.23 \ - --hash=sha256:0c186691a7995ef1db61205e00545bf161fb7b59cdb8c1201c89b333141c438a \ - --hash=sha256:0dcc0e71118be8c69252c207630faf13ca5e1b8583d57012aae191e7d6d28b84 \ - --hash=sha256:0f7552429dd39f70057ac5d0e897e5bfe211629652399a21671e53f2a9693a4e \ - --hash=sha256:129472cd06062fb13e7b4670a102951a3e655e9b91634432cfbdb7810af9d710 \ - --hash=sha256:13ec99df95003f56edcd307db44f06fbeb708c4ccdcf940478067dd62353181e \ - --hash=sha256:1f2b59c28afc53973d22e7bc18428721ee8ca6079becf1b36571c42627321c65 \ - --hash=sha256:2b20f544cbbeffe171911f6ce90388ad36fe3fad26b7c7a35d4762817e9ea69c \ - --hash=sha256:2fb698037c35109d3c2e30f2beb499e5ebae6e4bb8ff2e60c50b9a805a716f79 \ - --hash=sha256:34d870f9f27f2161709054d73646fc9aca49480617a65533fc2b4611c518e455 \ - --hash=sha256:391703a2abf8013d95bae39145d26b4e21531ab82e22f26cd3a181ee2644c234 \ - --hash=sha256:450dc27483548214314640c89a0f275dbc557968ed088da40bde7ef8fb52829e \ - --hash=sha256:45b65d6a275a478ac2cbd7fdbf7cc93c1982d613de4574b56fd6972ceadb8395 \ - --hash=sha256:5095a411c8479e715784a0c9236568ae72509450ee2226b649083730f3fadfc6 \ - --hash=sha256:530fc2bbb3dc1ebb17f70f7b234f90a1dd43b1b489ea38cea7be95fb21cdb5c7 \ - --hash=sha256:56f0c81c44638dfd0e2367df1a331b4ddf2e771366c4b9c5d9a473de75e3e1c7 \ - --hash=sha256:5e9c9e0ce92f27cef79e28e877c6b6988c48b16942258f3bc55d39b5f911df4f \ - --hash=sha256:6d7722136c6ed75caf84e1788df36397efdc5dbadab95e59c2bba82d4d808a4c \ - 
--hash=sha256:74d071dbe4b53c602edd87a7476ab23015a991374ddb228d941929ad7c8c922e \ - --hash=sha256:7b568809dca44cb75c8ebb260844ea98252c8c88396f9d203f5094e50a70355f \ - --hash=sha256:80bb5d2e92b2258188e7dcae5b188c7bf868eafdf800ea6edd0fbfc029984a88 \ - --hash=sha256:8d1cdcda6bd16268316d5db1038965acf948f2a6f43acc2e0b1641ceab443623 \ - --hash=sha256:9f665677e46c5a4d288ece12fdedf4f4204a422bb28ff05f0e6b08b7447796d1 \ - --hash=sha256:a30513828180264294953cecd942202dfda64e85195ae36c265daf4052af0464 \ - --hash=sha256:a7a986c45d1099a5de766a15de7bee3840b1e0e1a344430926af08e5297cf666 \ - --hash=sha256:a940ca7e7189d23da2bfbb38973832813eab6bd83f3bf89a977668c2f813deae \ - --hash=sha256:ab7c5684ff3538b67df3f93d66bd3369b749087871ae3786e70ef39e601345b0 \ - --hash=sha256:be04739a27be55631069b348dda0c81d8ea9822b5da10b8019b789e42d1fe452 \ - --hash=sha256:c0938ddd60cc04e8f1faf7a14a166ac939aac703745bfcd8e8f20322a7373019 \ - --hash=sha256:cb46b542133999580ffb691baf67410306833ee1e4f58ed06b6a7aaf4e046952 \ - --hash=sha256:d134757a37d8640f3c0abb41f5e68b7cf66c644f54ef1cb0573b7ea1c63e1509 \ - --hash=sha256:de557502c3bec8e634246588a94e82f1ee1b9dfcfdc453267c4fb652ff531570 \ - --hash=sha256:ded0c4a3eee56b57fcb2315e40812b173cafe79d2f992d50015f4387445737fa \ - --hash=sha256:e1dae12321b31059a1a72aaa0e6ba30156fe7e633355e445451e4021b8e122b6 \ - --hash=sha256:eb672217f7bd640411cfc69756ce721d00ae600814708d35c930930f18e8029f \ - --hash=sha256:ee684f139c91e69fe09b8e83d18b4d63bf87d9440c1eb2eeb52ee851883b1b29 \ - --hash=sha256:f3f9a91d3cc5e5b0ddf1043c0ae5fa4852f18a1c0050318baf5fc7930ecc1f9c - # via black -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -toml==0.10.2 \ - --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ - --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f - # via - # -r requirements.in - # maturin - # mypy - # pytest -tomli==1.2.2 \ - --hash=sha256:c6ce0015eb38820eaf32b5db832dbc26deb3dd427bd5f6556cf0acac2c214fee \ - --hash=sha256:f04066f68f5554911363063a30b108d2b5a5b1a010aa8b6132af78489fe3aade - # via black -typed-ast==1.4.3 \ - --hash=sha256:01ae5f73431d21eead5015997ab41afa53aa1fbe252f9da060be5dad2c730ace \ - --hash=sha256:067a74454df670dcaa4e59349a2e5c81e567d8d65458d480a5b3dfecec08c5ff \ - --hash=sha256:0fb71b8c643187d7492c1f8352f2c15b4c4af3f6338f21681d3681b3dc31a266 \ - --hash=sha256:1b3ead4a96c9101bef08f9f7d1217c096f31667617b58de957f690c92378b528 \ - --hash=sha256:2068531575a125b87a41802130fa7e29f26c09a2833fea68d9a40cf33902eba6 \ - --hash=sha256:209596a4ec71d990d71d5e0d312ac935d86930e6eecff6ccc7007fe54d703808 \ - --hash=sha256:2c726c276d09fc5c414693a2de063f521052d9ea7c240ce553316f70656c84d4 \ - --hash=sha256:398e44cd480f4d2b7ee8d98385ca104e35c81525dd98c519acff1b79bdaac363 \ - --hash=sha256:52b1eb8c83f178ab787f3a4283f68258525f8d70f778a2f6dd54d3b5e5fb4341 \ - --hash=sha256:5feca99c17af94057417d744607b82dd0a664fd5e4ca98061480fd8b14b18d04 \ - --hash=sha256:7538e495704e2ccda9b234b82423a4038f324f3a10c43bc088a1636180f11a41 \ - --hash=sha256:760ad187b1041a154f0e4d0f6aae3e40fdb51d6de16e5c99aedadd9246450e9e \ - --hash=sha256:777a26c84bea6cd934422ac2e3b78863a37017618b6e5c08f92ef69853e765d3 \ - --hash=sha256:95431a26309a21874005845c21118c83991c63ea800dd44843e42a916aec5899 \ - --hash=sha256:9ad2c92ec681e02baf81fdfa056fe0d818645efa9af1f1cd5fd6f1bd2bdfd805 \ - 
--hash=sha256:9c6d1a54552b5330bc657b7ef0eae25d00ba7ffe85d9ea8ae6540d2197a3788c \ - --hash=sha256:aee0c1256be6c07bd3e1263ff920c325b59849dc95392a05f258bb9b259cf39c \ - --hash=sha256:af3d4a73793725138d6b334d9d247ce7e5f084d96284ed23f22ee626a7b88e39 \ - --hash=sha256:b36b4f3920103a25e1d5d024d155c504080959582b928e91cb608a65c3a49e1a \ - --hash=sha256:b9574c6f03f685070d859e75c7f9eeca02d6933273b5e69572e5ff9d5e3931c3 \ - --hash=sha256:bff6ad71c81b3bba8fa35f0f1921fb24ff4476235a6e94a26ada2e54370e6da7 \ - --hash=sha256:c190f0899e9f9f8b6b7863debfb739abcb21a5c054f911ca3596d12b8a4c4c7f \ - --hash=sha256:c907f561b1e83e93fad565bac5ba9c22d96a54e7ea0267c708bffe863cbe4075 \ - --hash=sha256:cae53c389825d3b46fb37538441f75d6aecc4174f615d048321b716df2757fb0 \ - --hash=sha256:dd4a21253f42b8d2b48410cb31fe501d32f8b9fbeb1f55063ad102fe9c425e40 \ - --hash=sha256:dde816ca9dac1d9c01dd504ea5967821606f02e510438120091b84e852367428 \ - --hash=sha256:f2362f3cb0f3172c42938946dbc5b7843c2a28aec307c49100c8b38764eb6927 \ - --hash=sha256:f328adcfebed9f11301eaedfa48e15bdece9b519fb27e6a8c01aa52a17ec31b3 \ - --hash=sha256:f8afcf15cc511ada719a88e013cec87c11aff7b91f019295eb4530f96fe5ef2f \ - --hash=sha256:fb1bbeac803adea29cedd70781399c99138358c26d05fcbd23c13016b7f5ec65 - # via - # black - # mypy -typing-extensions==3.10.0.2 \ - --hash=sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e \ - --hash=sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7 \ - --hash=sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34 - # via - # black - # importlib-metadata - # mypy -zipp==3.6.0 \ - --hash=sha256:71c644c5369f4a6e07636f0aa966270449561fcea2e3d6747b8d23efaa9d7832 \ - --hash=sha256:9fe5ea21568a0a70e50f273397638d39b03353731e6cbbb3fd8502a33fec40bc - # via importlib-metadata diff --git a/python/requirements.in b/python/requirements.in deleted file mode 100644 index 7e54705fc8ab..000000000000 --- a/python/requirements.in +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -black -flake8 -isort -maturin -mypy -numpy -pandas -pyarrow -pytest -toml diff --git a/python/requirements.txt b/python/requirements.txt deleted file mode 100644 index 358578ecb923..000000000000 --- a/python/requirements.txt +++ /dev/null @@ -1,282 +0,0 @@ -# -# This file is autogenerated by pip-compile with python 3.10 -# To update, run: -# -# pip-compile --generate-hashes -# -attrs==21.2.0 \ - --hash=sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1 \ - --hash=sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb - # via pytest -black==21.9b0 \ - --hash=sha256:380f1b5da05e5a1429225676655dddb96f5ae8c75bdf91e53d798871b902a115 \ - --hash=sha256:7de4cfc7eb6b710de325712d40125689101d21d25283eed7e9998722cf10eb91 - # via -r requirements.in -click==8.0.3 \ - --hash=sha256:353f466495adaeb40b6b5f592f9f91cb22372351c84caeb068132442a4518ef3 \ - --hash=sha256:410e932b050f5eed773c4cda94de75971c89cdb3155a72a0831139a79e5ecb5b - # via black -flake8==4.0.1 \ - --hash=sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d \ - --hash=sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d - # via -r requirements.in -iniconfig==1.1.1 \ - --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ - --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 - # via pytest -isort==5.9.3 \ - --hash=sha256:9c2ea1e62d871267b78307fe511c0838ba0da28698c5732d54e2790bf3ba9899 \ - --hash=sha256:e17d6e2b81095c9db0a03a8025a957f334d6ea30b26f9ec70805411e5c7c81f2 - # via -r requirements.in -maturin==0.11.5 \ - --hash=sha256:07074778b063a439fdfd5501bd1d1823a216ec5b657d3ecde78fd7f2c4782422 \ - --hash=sha256:1ce666c386ff9c3c2b5d7d3ca4b1f9f675c38d7540ffbda0d5d5bc7d6ddde49a \ - --hash=sha256:20f9c30701c9932ed8026ceaf896fc77ecc76cebd6a182668dbc10ed597f8789 \ - --hash=sha256:3354d030b88c938a33bf407a6c0f79ccdd2cce3e1e3e4a2d0c92dc2e063adc6e \ - --hash=sha256:4191b0b7362b3025096faf126ff15cb682fbff324ac4a6ca18d55bb16e2b759b \ - --hash=sha256:70381be1585cb9fa5c02b83af80ae661aaad959e8aa0fddcfe195b004054bd69 \ - --hash=sha256:7bf96e7586bfdb5b0fadc6d662534b8a41123b33dff084fa383a81ded0ce5334 \ - --hash=sha256:ab2b3ccf66f5e0f9c3904d215835337b1bd305e79e3bf53b65bbc80a5755e01b \ - --hash=sha256:b0ac45879a7d624b47d72b093ae3370270894c19779f42aad7568a92951c5d47 \ - --hash=sha256:c2ded8b4ef9210d627bb966bc67661b7db259535f6062afe1ce5605406b50f3f \ - --hash=sha256:d78f24561a5e02f7d119b348b26e5772ad5698a43ca49e8facb9ce77cf273714 - # via -r requirements.in -mccabe==0.6.1 \ - --hash=sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42 \ - --hash=sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f - # via flake8 -mypy==0.910 \ - --hash=sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9 \ - --hash=sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a \ - --hash=sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9 \ - --hash=sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e \ - --hash=sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2 \ - --hash=sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212 \ - --hash=sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b \ - --hash=sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885 \ - --hash=sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150 \ - 
--hash=sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703 \ - --hash=sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072 \ - --hash=sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457 \ - --hash=sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e \ - --hash=sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0 \ - --hash=sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb \ - --hash=sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97 \ - --hash=sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8 \ - --hash=sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811 \ - --hash=sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6 \ - --hash=sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de \ - --hash=sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504 \ - --hash=sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921 \ - --hash=sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d - # via -r requirements.in -mypy-extensions==0.4.3 \ - --hash=sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d \ - --hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 - # via - # black - # mypy -numpy==1.21.3 \ - --hash=sha256:043e83bfc274649c82a6f09836943e4a4aebe5e33656271c7dbf9621dd58b8ec \ - --hash=sha256:160ccc1bed3a8371bf0d760971f09bfe80a3e18646620e9ded0ad159d9749baa \ - --hash=sha256:188031f833bbb623637e66006cf75e933e00e7231f67e2b45cf8189612bb5dc3 \ - --hash=sha256:28f15209fb535dd4c504a7762d3bc440779b0e37d50ed810ced209e5cea60d96 \ - --hash=sha256:29fb3dcd0468b7715f8ce2c0c2d9bbbaf5ae686334951343a41bd8d155c6ea27 \ - --hash=sha256:2a6ee9620061b2a722749b391c0d80a0e2ae97290f1b32e28d5a362e21941ee4 \ - --hash=sha256:300321e3985c968e3ae7fbda187237b225f3ffe6528395a5b7a5407f73cf093e \ - --hash=sha256:32437f0b275c1d09d9c3add782516413e98cd7c09e6baf4715cbce781fc29912 \ - --hash=sha256:3c09418a14471c7ae69ba682e2428cae5b4420a766659605566c0fa6987f6b7e \ - --hash=sha256:49c6249260890e05b8111ebfc391ed58b3cb4b33e63197b2ec7f776e45330721 \ - --hash=sha256:4cc9b512e9fb590797474f58b7f6d1f1b654b3a94f4fa8558b48ca8b3cfc97cf \ - --hash=sha256:508b0b513fa1266875524ba8a9ecc27b02ad771fe1704a16314dc1a816a68737 \ - --hash=sha256:50cd26b0cf6664cb3b3dd161ba0a09c9c1343db064e7c69f9f8b551f5104d654 \ - --hash=sha256:5c4193f70f8069550a1788bd0cd3268ab7d3a2b70583dfe3b2e7f421e9aace06 \ - --hash=sha256:5dfe9d6a4c39b8b6edd7990091fea4f852888e41919d0e6722fe78dd421db0eb \ - --hash=sha256:63571bb7897a584ca3249c86dd01c10bcb5fe4296e3568b2e9c1a55356b6410e \ - --hash=sha256:75621882d2230ab77fb6a03d4cbccd2038511491076e7964ef87306623aa5272 \ - --hash=sha256:75eb7cadc8da49302f5b659d40ba4f6d94d5045fbd9569c9d058e77b0514c9e4 \ - --hash=sha256:88a5d6b268e9ad18f3533e184744acdaa2e913b13148160b1152300c949bbb5f \ - --hash=sha256:8a10968963640e75cc0193e1847616ab4c718e83b6938ae74dea44953950f6b7 \ - --hash=sha256:90bec6a86b348b4559b6482e2b684db4a9a7eed1fa054b86115a48d58fbbf62a \ - --hash=sha256:98339aa9911853f131de11010f6dd94c8cec254d3d1f7261528c3b3e3219f139 \ - --hash=sha256:a99a6b067e5190ac6d12005a4d85aa6227c5606fa93211f86b1dafb16233e57d \ - --hash=sha256:bffa2eee3b87376cc6b31eee36d05349571c236d1de1175b804b348dc0941e3f \ - --hash=sha256:c6c2d535a7beb1f8790aaa98fd089ceab2e3dd7ca48aca0af7dc60e6ef93ffe1 \ - 
--hash=sha256:cc14e7519fab2a4ed87d31f99c31a3796e4e1fe63a86ebdd1c5a1ea78ebd5896 \ - --hash=sha256:dd0482f3fc547f1b1b5d6a8b8e08f63fdc250c58ce688dedd8851e6e26cff0f3 \ - --hash=sha256:dde972a1e11bb7b702ed0e447953e7617723760f420decb97305e66fb4afc54f \ - --hash=sha256:e54af82d68ef8255535a6cdb353f55d6b8cf418a83e2be3569243787a4f4866f \ - --hash=sha256:e606e6316911471c8d9b4618e082635cfe98876007556e89ce03d52ff5e8fcf0 \ - --hash=sha256:f41b018f126aac18583956c54544db437f25c7ee4794bcb23eb38bef8e5e192a \ - --hash=sha256:f8f4625536926a155b80ad2bbff44f8cc59e9f2ad14cdda7acf4c135b4dc8ff2 \ - --hash=sha256:fe52dbe47d9deb69b05084abd4b0df7abb39a3c51957c09f635520abd49b29dd - # via - # -r requirements.in - # pandas - # pyarrow -packaging==21.0 \ - --hash=sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7 \ - --hash=sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14 - # via pytest -pandas==1.3.4 \ - --hash=sha256:003ba92db58b71a5f8add604a17a059f3068ef4e8c0c365b088468d0d64935fd \ - --hash=sha256:10e10a2527db79af6e830c3d5842a4d60383b162885270f8cffc15abca4ba4a9 \ - --hash=sha256:22808afb8f96e2269dcc5b846decacb2f526dd0b47baebc63d913bf847317c8f \ - --hash=sha256:2d1dc09c0013d8faa7474574d61b575f9af6257ab95c93dcf33a14fd8d2c1bab \ - --hash=sha256:35c77609acd2e4d517da41bae0c11c70d31c87aae8dd1aabd2670906c6d2c143 \ - --hash=sha256:372d72a3d8a5f2dbaf566a5fa5fa7f230842ac80f29a931fb4b071502cf86b9a \ - --hash=sha256:42493f8ae67918bf129869abea8204df899902287a7f5eaf596c8e54e0ac7ff4 \ - --hash=sha256:5298a733e5bfbb761181fd4672c36d0c627320eb999c59c65156c6a90c7e1b4f \ - --hash=sha256:5ba0aac1397e1d7b654fccf263a4798a9e84ef749866060d19e577e927d66e1b \ - --hash=sha256:a2aa18d3f0b7d538e21932f637fbfe8518d085238b429e4790a35e1e44a96ffc \ - --hash=sha256:a388960f979665b447f0847626e40f99af8cf191bce9dc571d716433130cb3a7 \ - --hash=sha256:a51528192755f7429c5bcc9e80832c517340317c861318fea9cea081b57c9afd \ - --hash=sha256:b528e126c13816a4374e56b7b18bfe91f7a7f6576d1aadba5dee6a87a7f479ae \ - --hash=sha256:c1aa4de4919358c5ef119f6377bc5964b3a7023c23e845d9db7d9016fa0c5b1c \ - --hash=sha256:c2646458e1dce44df9f71a01dc65f7e8fa4307f29e5c0f2f92c97f47a5bf22f5 \ - --hash=sha256:d47750cf07dee6b55d8423471be70d627314277976ff2edd1381f02d52dbadf9 \ - --hash=sha256:d99d2350adb7b6c3f7f8f0e5dfb7d34ff8dd4bc0a53e62c445b7e43e163fce63 \ - --hash=sha256:dd324f8ee05925ee85de0ea3f0d66e1362e8c80799eb4eb04927d32335a3e44a \ - --hash=sha256:eaca36a80acaacb8183930e2e5ad7f71539a66805d6204ea88736570b2876a7b \ - --hash=sha256:f567e972dce3bbc3a8076e0b675273b4a9e8576ac629149cf8286ee13c259ae5 \ - --hash=sha256:fe48e4925455c964db914b958f6e7032d285848b7538a5e1b19aeb26ffaea3ec - # via -r requirements.in -pathspec==0.9.0 \ - --hash=sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a \ - --hash=sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1 - # via black -platformdirs==2.4.0 \ - --hash=sha256:367a5e80b3d04d2428ffa76d33f124cf11e8fff2acdaa9b43d545f5c7d661ef2 \ - --hash=sha256:8868bbe3c3c80d42f20156f22e7131d2fb321f5bc86a2a345375c6481a67021d - # via black -pluggy==1.0.0 \ - --hash=sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159 \ - --hash=sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3 - # via pytest -py==1.10.0 \ - --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \ - --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a - # via pytest -pyarrow==6.0.0 \ - 
--hash=sha256:004185e0babc6f3c3fba6ba4f106e406a0113d0f82bb9ad9a8571a1978c45d04 \ - --hash=sha256:0204e80777ab8f4e9abd3a765a8ec07ed1e3c4630bacda50d2ce212ef0f3826f \ - --hash=sha256:072c1a0fca4509eefd7d018b78542fb7e5c63aaf5698f1c0a6e45628ae17ba44 \ - --hash=sha256:15dc0d673d3f865ca63c877bd7a2eced70b0a08969fb733a28247134b8a1f18b \ - --hash=sha256:1c38263ea438a1666b13372e7565450cfeec32dbcd1c2595749476a58465eaec \ - --hash=sha256:281ce5fa03621d786a9beb514abb09846db7f0221b50eabf543caa24037eaacd \ - --hash=sha256:2d2c681659396c745e4f1988d5dd41dcc3ad557bb8d4a8c2e44030edafc08a91 \ - --hash=sha256:376c4b5f248ae63df21fe15c194e9013753164be2d38f4b3fb8bde63ac5a1958 \ - --hash=sha256:465f87fa0be0b2928b2beeba22b5813a0203fb05d90fd8563eea48e08ecc030e \ - --hash=sha256:477c746ef42c039348a288584800e299456c80c5691401bb9b19aa9c02a427b7 \ - --hash=sha256:5144bd9db2920c7cb566c96462d62443cc239104f94771d110f74393f2fb42a2 \ - --hash=sha256:5408fa8d623e66a0445f3fb0e4027fd219bf99bfb57422d543d7b7876e2c5b55 \ - --hash=sha256:5be62679201c441356d3f2a739895dcc8d4d299f2a6eabcd2163bfb6a898abba \ - --hash=sha256:5c666bc6a1cebf01206e2dc1ab05f25f39f35d3a499e0ef5cd635225e07306ca \ - --hash=sha256:6163d82cca7541774b00503c295fe86a1722820eddb958b57f091bb6f5b0a6db \ - --hash=sha256:6a1d9a2f4ee812ed0bd4182cabef99ea914ac297274f0de086f2488093d284ef \ - --hash=sha256:7a683f71b848eb6310b4ec48c0def55dac839e9994c1ac874c9b2d3d5625def1 \ - --hash=sha256:82fe80309e01acf29e3943a1f6d3c98ec109fe1d356bc1ac37d639bcaadcf684 \ - --hash=sha256:8c23f8cdecd3d9e49f9b0f9a651ae5549d1d32fd4901fb1bdc2d327edfba844f \ - --hash=sha256:8d41dfb09ba9236cca6245f33088eb42f3c54023da281139241e0f9f3b4b754e \ - --hash=sha256:a19e58dfb04e451cd8b7bdec3ac8848373b95dfc53492c9a69789aa9074a3c1b \ - --hash=sha256:a50d2f77b86af38ceabf45617208b9105d20e7a5eebc584e7c8c0acededd82ce \ - --hash=sha256:a5bed4f948c032c40597302e9bdfa65f62295240306976ecbe43a54924c6f94f \ - --hash=sha256:ac941a147d14993987cc8b605b721735a34b3e54d167302501fb4db1ad7382c7 \ - --hash=sha256:b86d175262db1eb46afdceb36d459409eb6f8e532d3dec162f8bf572c7f57623 \ - --hash=sha256:bf3400780c4d3c9cb43b1e8a1aaf2e1b7199a0572d0a645529d2784e4d0d8497 \ - --hash=sha256:c7a6e7e0bf8779e9c3428ced85507541f3da9a0675e2f4781d4eb2c7042cbf81 \ - --hash=sha256:cc1d4a70efd583befe92d4ea6f74ed2e0aa31ccdde767cd5cae8e77c65a1c2d4 \ - --hash=sha256:d046dc78a9337baa6415be915c5a16222505233e238a1017f368243c89817eea \ - --hash=sha256:da7860688c33ca88ac05f1a487d32d96d9caa091412496c35f3d1d832145675a \ - --hash=sha256:ddf2e6e3b321adaaf716f2d5af8e92d205a9671e0cb7c0779710a567fd1dd580 \ - --hash=sha256:e81508239a71943759cee272ce625ae208092dd36ef2c6713fccee30bbcf52bb \ - --hash=sha256:ea64a48a85c631eb2a0ea13ccdec5143c85b5897836b16331ee4289d27a57247 \ - --hash=sha256:ed0be080cf595ea15ff1c9ff4097bbf1fcc4b50847d98c0a3c0412fbc6ede7e9 \ - --hash=sha256:fb701ec4a94b92102606d4e88f0b8eba34f09a5ad8e014eaa4af76f42b7f62ae \ - --hash=sha256:fbda7595f24a639bcef3419ecfac17216efacb09f7b0f1b4c4c97f900d65ca0e - # via -r requirements.in -pycodestyle==2.8.0 \ - --hash=sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20 \ - --hash=sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f - # via flake8 -pyflakes==2.4.0 \ - --hash=sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c \ - --hash=sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e - # via flake8 -pyparsing==3.0.3 \ - --hash=sha256:9e3511118010f112a4b4b435ae50e1eaa610cda191acb9e421d60cf5fde83455 \ - 
--hash=sha256:f8d3fe9fc404576c5164f0f0c4e382c96b85265e023c409c43d48f65da9d60d0 - # via packaging -pytest==6.2.5 \ - --hash=sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89 \ - --hash=sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134 - # via -r requirements.in -python-dateutil==2.8.2 \ - --hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \ - --hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 - # via pandas -pytz==2021.3 \ - --hash=sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c \ - --hash=sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326 - # via pandas -regex==2021.10.23 \ - --hash=sha256:0c186691a7995ef1db61205e00545bf161fb7b59cdb8c1201c89b333141c438a \ - --hash=sha256:0dcc0e71118be8c69252c207630faf13ca5e1b8583d57012aae191e7d6d28b84 \ - --hash=sha256:0f7552429dd39f70057ac5d0e897e5bfe211629652399a21671e53f2a9693a4e \ - --hash=sha256:129472cd06062fb13e7b4670a102951a3e655e9b91634432cfbdb7810af9d710 \ - --hash=sha256:13ec99df95003f56edcd307db44f06fbeb708c4ccdcf940478067dd62353181e \ - --hash=sha256:1f2b59c28afc53973d22e7bc18428721ee8ca6079becf1b36571c42627321c65 \ - --hash=sha256:2b20f544cbbeffe171911f6ce90388ad36fe3fad26b7c7a35d4762817e9ea69c \ - --hash=sha256:2fb698037c35109d3c2e30f2beb499e5ebae6e4bb8ff2e60c50b9a805a716f79 \ - --hash=sha256:34d870f9f27f2161709054d73646fc9aca49480617a65533fc2b4611c518e455 \ - --hash=sha256:391703a2abf8013d95bae39145d26b4e21531ab82e22f26cd3a181ee2644c234 \ - --hash=sha256:450dc27483548214314640c89a0f275dbc557968ed088da40bde7ef8fb52829e \ - --hash=sha256:45b65d6a275a478ac2cbd7fdbf7cc93c1982d613de4574b56fd6972ceadb8395 \ - --hash=sha256:5095a411c8479e715784a0c9236568ae72509450ee2226b649083730f3fadfc6 \ - --hash=sha256:530fc2bbb3dc1ebb17f70f7b234f90a1dd43b1b489ea38cea7be95fb21cdb5c7 \ - --hash=sha256:56f0c81c44638dfd0e2367df1a331b4ddf2e771366c4b9c5d9a473de75e3e1c7 \ - --hash=sha256:5e9c9e0ce92f27cef79e28e877c6b6988c48b16942258f3bc55d39b5f911df4f \ - --hash=sha256:6d7722136c6ed75caf84e1788df36397efdc5dbadab95e59c2bba82d4d808a4c \ - --hash=sha256:74d071dbe4b53c602edd87a7476ab23015a991374ddb228d941929ad7c8c922e \ - --hash=sha256:7b568809dca44cb75c8ebb260844ea98252c8c88396f9d203f5094e50a70355f \ - --hash=sha256:80bb5d2e92b2258188e7dcae5b188c7bf868eafdf800ea6edd0fbfc029984a88 \ - --hash=sha256:8d1cdcda6bd16268316d5db1038965acf948f2a6f43acc2e0b1641ceab443623 \ - --hash=sha256:9f665677e46c5a4d288ece12fdedf4f4204a422bb28ff05f0e6b08b7447796d1 \ - --hash=sha256:a30513828180264294953cecd942202dfda64e85195ae36c265daf4052af0464 \ - --hash=sha256:a7a986c45d1099a5de766a15de7bee3840b1e0e1a344430926af08e5297cf666 \ - --hash=sha256:a940ca7e7189d23da2bfbb38973832813eab6bd83f3bf89a977668c2f813deae \ - --hash=sha256:ab7c5684ff3538b67df3f93d66bd3369b749087871ae3786e70ef39e601345b0 \ - --hash=sha256:be04739a27be55631069b348dda0c81d8ea9822b5da10b8019b789e42d1fe452 \ - --hash=sha256:c0938ddd60cc04e8f1faf7a14a166ac939aac703745bfcd8e8f20322a7373019 \ - --hash=sha256:cb46b542133999580ffb691baf67410306833ee1e4f58ed06b6a7aaf4e046952 \ - --hash=sha256:d134757a37d8640f3c0abb41f5e68b7cf66c644f54ef1cb0573b7ea1c63e1509 \ - --hash=sha256:de557502c3bec8e634246588a94e82f1ee1b9dfcfdc453267c4fb652ff531570 \ - --hash=sha256:ded0c4a3eee56b57fcb2315e40812b173cafe79d2f992d50015f4387445737fa \ - --hash=sha256:e1dae12321b31059a1a72aaa0e6ba30156fe7e633355e445451e4021b8e122b6 \ - --hash=sha256:eb672217f7bd640411cfc69756ce721d00ae600814708d35c930930f18e8029f \ - 
--hash=sha256:ee684f139c91e69fe09b8e83d18b4d63bf87d9440c1eb2eeb52ee851883b1b29 \ - --hash=sha256:f3f9a91d3cc5e5b0ddf1043c0ae5fa4852f18a1c0050318baf5fc7930ecc1f9c - # via black -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -toml==0.10.2 \ - --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ - --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f - # via - # -r requirements.in - # maturin - # mypy - # pytest -tomli==1.2.2 \ - --hash=sha256:c6ce0015eb38820eaf32b5db832dbc26deb3dd427bd5f6556cf0acac2c214fee \ - --hash=sha256:f04066f68f5554911363063a30b108d2b5a5b1a010aa8b6132af78489fe3aade - # via black -typing-extensions==3.10.0.2 \ - --hash=sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e \ - --hash=sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7 \ - --hash=sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34 - # via - # black - # mypy diff --git a/python/rust-toolchain b/python/rust-toolchain deleted file mode 100644 index 12b27c03a24a..000000000000 --- a/python/rust-toolchain +++ /dev/null @@ -1 +0,0 @@ -nightly-2021-10-23 diff --git a/python/src/catalog.rs b/python/src/catalog.rs deleted file mode 100644 index f93c795ec34c..000000000000 --- a/python/src/catalog.rs +++ /dev/null @@ -1,123 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use datafusion::{ - arrow::pyarrow::PyArrowConvert, - catalog::{catalog::CatalogProvider, schema::SchemaProvider}, - datasource::{TableProvider, TableType}, -}; - -#[pyclass(name = "Catalog", module = "datafusion", subclass)] -pub(crate) struct PyCatalog { - catalog: Arc<dyn CatalogProvider>, -} - -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub(crate) struct PyDatabase { - database: Arc<dyn SchemaProvider>, -} - -#[pyclass(name = "Table", module = "datafusion", subclass)] -pub(crate) struct PyTable { - table: Arc<dyn TableProvider>, -} - -impl PyCatalog { - pub fn new(catalog: Arc<dyn CatalogProvider>) -> Self { - Self { catalog } - } -} - -impl PyDatabase { - pub fn new(database: Arc<dyn SchemaProvider>) -> Self { - Self { database } - } -} - -impl PyTable { - pub fn new(table: Arc<dyn TableProvider>) -> Self { - Self { table } - } -} - -#[pymethods] -impl PyCatalog { - fn names(&self) -> Vec<String> { - self.catalog.schema_names() - } - - #[args(name = "\"public\"")] - fn database(&self, name: &str) -> PyResult<PyDatabase> { - match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), - None => Err(PyKeyError::new_err(format!( - "Database with name {} doesn't exist.", - name - ))), - } - } -} - -#[pymethods] -impl PyDatabase { - fn names(&self) -> HashSet<String> { - self.database.table_names().into_iter().collect() - } - - fn table(&self, name: &str) -> PyResult<PyTable> { - match self.database.table(name) { - Some(table) => Ok(PyTable::new(table)), - None => Err(PyKeyError::new_err(format!( - "Table with name {} doesn't exist.", - name - ))), - } - } - - // register_table - // deregister_table -} - -#[pymethods] -impl PyTable { - /// Get a reference to the schema for this table - #[getter] - fn schema(&self, py: Python) -> PyResult<PyObject> { - self.table.schema().to_pyarrow(py) - } - - /// Get the type of this table for metadata/catalog purposes. - #[getter] - fn kind(&self) -> &str { - match self.table.table_type() { - TableType::Base => "physical", - TableType::View => "view", - TableType::Temporary => "temporary", - } - } - - // fn scan - // fn statistics - // fn has_exact_statistics - // fn supports_filter_pushdown -} diff --git a/python/src/context.rs b/python/src/context.rs deleted file mode 100644 index 7f386bac398d..000000000000 --- a/python/src/context.rs +++ /dev/null @@ -1,173 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::path::PathBuf; -use std::{collections::HashSet, sync::Arc}; - -use uuid::Uuid; - -use pyo3::exceptions::{PyKeyError, PyValueError}; -use pyo3::prelude::*; - -use datafusion::arrow::datatypes::Schema; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext; -use datafusion::prelude::CsvReadOptions; - -use crate::catalog::PyCatalog; -use crate::dataframe::PyDataFrame; -use crate::errors::DataFusionError; -use crate::udf::PyScalarUDF; -use crate::utils::wait_for_future; - -/// `PyExecutionContext` is able to plan and execute DataFusion plans. -/// It has a powerful optimizer, a physical planner for local execution, and a -/// multi-threaded execution engine to perform the execution. -#[pyclass(name = "ExecutionContext", module = "datafusion", subclass, unsendable)] -pub(crate) struct PyExecutionContext { - ctx: ExecutionContext, -} - -#[pymethods] -impl PyExecutionContext { - // TODO(kszucs): should expose the configuration options as keyword arguments - #[new] - fn new() -> Self { - PyExecutionContext { - ctx: ExecutionContext::new(), - } - } - - /// Returns a PyDataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> { - let result = self.ctx.sql(query); - let df = wait_for_future(py, result).map_err(DataFusionError::from)?; - Ok(PyDataFrame::new(df)) - } - - fn create_dataframe( - &mut self, - partitions: Vec<Vec<RecordBatch>>, - ) -> PyResult<PyDataFrame> { - let table = MemTable::try_new(partitions[0][0].schema(), partitions) - .map_err(DataFusionError::from)?; - - // generate a random (unique) name for this table - // table name cannot start with numeric digit - let name = "c".to_owned() - + &Uuid::new_v4() - .to_simple() - .encode_lower(&mut Uuid::encode_buffer()); - - self.ctx - .register_table(&*name, Arc::new(table)) - .map_err(DataFusionError::from)?; - let table = self.ctx.table(&*name).map_err(DataFusionError::from)?; - - let df = PyDataFrame::new(table); - Ok(df) - } - - fn register_record_batches( - &mut self, - name: &str, - partitions: Vec<Vec<RecordBatch>>, - ) -> PyResult<()> { - let schema = partitions[0][0].schema(); - let table = MemTable::try_new(schema, partitions)?; - self.ctx - .register_table(name, Arc::new(table)) - .map_err(DataFusionError::from)?; - Ok(()) - } - - fn register_parquet(&mut self, name: &str, path: &str, py: Python) -> PyResult<()> { - let result = self.ctx.register_parquet(name, path); - wait_for_future(py, result).map_err(DataFusionError::from)?; - Ok(()) - } - - #[args( - schema = "None", - has_header = "true", - delimiter = "\",\"", - schema_infer_max_records = "1000", - file_extension = "\".csv\"" - )] - fn register_csv( - &mut self, - name: &str, - path: PathBuf, - schema: Option<Schema>, - has_header: bool, - delimiter: &str, - schema_infer_max_records: usize, - file_extension: &str, - py: Python, - ) -> PyResult<()> { - let path = path - .to_str() - .ok_or(PyValueError::new_err("Unable to convert path to a string"))?; - let delimiter = delimiter.as_bytes(); - if delimiter.len() != 1 { - return Err(PyValueError::new_err( - "Delimiter must be a single character", - )); - } - - let mut options = CsvReadOptions::new() - .has_header(has_header) - .delimiter(delimiter[0]) - .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension); - options.schema = schema.as_ref(); - - let result = self.ctx.register_csv(name, path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; - - Ok(()) - } - - fn 
register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { - self.ctx.register_udf(udf.function); - Ok(()) - } - - #[args(name = "\"datafusion\"")] - fn catalog(&self, name: &str) -> PyResult<PyCatalog> { - match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::new(catalog)), - None => Err(PyKeyError::new_err(format!( - "Catalog with name {} doesn't exist.", - &name - ))), - } - } - - fn tables(&self) -> HashSet<String> { - self.ctx.tables().unwrap() - } - - fn table(&self, name: &str) -> PyResult<PyDataFrame> { - Ok(PyDataFrame::new(self.ctx.table(name)?)) - } - - fn empty_table(&self) -> PyResult<PyDataFrame> { - Ok(PyDataFrame::new(self.ctx.read_empty()?)) - } -} diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index d66b50dd75b9..8b137891791f 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -1,130 +1 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use std::sync::Arc; - -use pyo3::prelude::*; - -use datafusion::arrow::datatypes::Schema; -use datafusion::arrow::io::print; -use datafusion::arrow::pyarrow::PyArrowConvert; -use datafusion::dataframe::DataFrame; -use datafusion::logical_plan::JoinType; - -use crate::utils::wait_for_future; -use crate::{errors::DataFusionError, expression::PyExpr}; - -/// A PyDataFrame is a representation of a logical plan and an API to compose statements. -/// Use it to build a plan and `.collect()` to execute the plan and collect the result. -/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. 
-#[pyclass(name = "DataFrame", module = "datafusion", subclass)] -#[derive(Clone)] -pub(crate) struct PyDataFrame { - df: Arc, -} - -impl PyDataFrame { - /// creates a new PyDataFrame - pub fn new(df: Arc) -> Self { - Self { df } - } -} - -#[pymethods] -impl PyDataFrame { - /// Returns the schema from the logical plan - fn schema(&self) -> Schema { - self.df.schema().into() - } - - #[args(args = "*")] - fn select(&self, args: Vec) -> PyResult { - let expr = args.into_iter().map(|e| e.into()).collect(); - let df = self.df.select(expr)?; - Ok(Self::new(df)) - } - - fn filter(&self, predicate: PyExpr) -> PyResult { - let df = self.df.filter(predicate.into())?; - Ok(Self::new(df)) - } - - fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyResult { - let group_by = group_by.into_iter().map(|e| e.into()).collect(); - let aggs = aggs.into_iter().map(|e| e.into()).collect(); - let df = self.df.aggregate(group_by, aggs)?; - Ok(Self::new(df)) - } - - #[args(exprs = "*")] - fn sort(&self, exprs: Vec) -> PyResult { - let exprs = exprs.into_iter().map(|e| e.into()).collect(); - let df = self.df.sort(exprs)?; - Ok(Self::new(df)) - } - - fn limit(&self, count: usize) -> PyResult { - let df = self.df.limit(count)?; - Ok(Self::new(df)) - } - - /// Executes the plan, returning a list of `RecordBatch`es. - /// Unless some order is specified in the plan, there is no - /// guarantee of the order of the result. - fn collect(&self, py: Python) -> PyResult> { - let batches = wait_for_future(py, self.df.collect())?; - // cannot use PyResult> return type due to - // https://github.com/PyO3/pyo3/issues/1813 - batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() - } - - /// Print the result, 20 lines by default - #[args(num = "20")] - fn show(&self, py: Python, num: usize) -> PyResult<()> { - let df = self.df.limit(num)?; - let batches = wait_for_future(py, df.collect())?; - Ok(print::print(&batches)) - } - - fn join( - &self, - right: PyDataFrame, - join_keys: (Vec<&str>, Vec<&str>), - how: &str, - ) -> PyResult { - let join_type = match how { - "inner" => JoinType::Inner, - "left" => JoinType::Left, - "right" => JoinType::Right, - "full" => JoinType::Full, - "semi" => JoinType::Semi, - "anti" => JoinType::Anti, - how => { - return Err(DataFusionError::Common(format!( - "The join type {} does not exist or is not implemented", - how - )) - .into()) - } - }; - - let df = self - .df - .join(right.df, join_type, &join_keys.0, &join_keys.1)?; - Ok(Self::new(df)) - } -} diff --git a/python/src/errors.rs b/python/src/errors.rs deleted file mode 100644 index 655ed8441cb4..000000000000 --- a/python/src/errors.rs +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use core::fmt; - -use datafusion::arrow::error::ArrowError; -use datafusion::error::DataFusionError as InnerDataFusionError; -use pyo3::{exceptions::PyException, PyErr}; - -#[derive(Debug)] -pub enum DataFusionError { - ExecutionError(InnerDataFusionError), - ArrowError(ArrowError), - Common(String), -} - -impl fmt::Display for DataFusionError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {:?}", e), - DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}", e), - DataFusionError::Common(e) => write!(f, "{}", e), - } - } -} - -impl From<ArrowError> for DataFusionError { - fn from(err: ArrowError) -> DataFusionError { - DataFusionError::ArrowError(err) - } -} - -impl From<InnerDataFusionError> for DataFusionError { - fn from(err: InnerDataFusionError) -> DataFusionError { - DataFusionError::ExecutionError(err) - } -} - -impl From<DataFusionError> for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} diff --git a/python/src/expression.rs b/python/src/expression.rs deleted file mode 100644 index 5e1cad246bf8..000000000000 --- a/python/src/expression.rs +++ /dev/null @@ -1,135 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; -use std::convert::{From, Into}; - -use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{col, lit, Expr}; - -use datafusion::scalar::ScalarValue; - -/// An PyExpr that can be used on a DataFrame -#[pyclass(name = "Expression", module = "datafusion", subclass)] -#[derive(Debug, Clone)] -pub(crate) struct PyExpr { - pub(crate) expr: Expr, -} - -impl From<PyExpr> for Expr { - fn from(expr: PyExpr) -> Expr { - expr.expr - } -} - -impl Into<PyExpr> for Expr { - fn into(self) -> PyExpr { - PyExpr { expr: self } - } -} - -#[pyproto] -impl PyNumberProtocol for PyExpr { - fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok((lhs.expr + rhs.expr).into()) - } - - fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok((lhs.expr - rhs.expr).into()) - } - - fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok((lhs.expr / rhs.expr).into()) - } - - fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok((lhs.expr * rhs.expr).into()) - } - - fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok(lhs.expr.clone().modulus(rhs.expr).into()) - } - - fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok(lhs.expr.clone().and(rhs.expr).into()) - } - - fn __or__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> { - Ok(lhs.expr.clone().or(rhs.expr).into()) - } - - fn __invert__(&self) -> PyResult<PyExpr> { - Ok(self.expr.clone().not().into()) - } -} - -#[pyproto] -impl PyObjectProtocol for PyExpr { - fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr { - let expr = match op { - CompareOp::Lt => self.expr.clone().lt(other.expr), - CompareOp::Le => self.expr.clone().lt_eq(other.expr), - CompareOp::Eq => self.expr.clone().eq(other.expr), - CompareOp::Ne => self.expr.clone().not_eq(other.expr), - CompareOp::Gt => self.expr.clone().gt(other.expr), - CompareOp::Ge => self.expr.clone().gt_eq(other.expr), - }; - expr.into() - } - - fn __str__(&self) -> PyResult<String> { - Ok(format!("{}", self.expr)) - } -} - -#[pymethods] -impl PyExpr { - #[staticmethod] - pub fn literal(value: ScalarValue) -> PyExpr { - lit(value).into() - } - - #[staticmethod] - pub fn column(value: &str) -> PyExpr { - col(value).into() - } - - /// assign a name to the PyExpr - pub fn alias(&self, name: &str) -> PyExpr { - self.expr.clone().alias(name).into() - } - - /// Create a sort PyExpr from an existing PyExpr. - #[args(ascending = true, nulls_first = true)] - pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { - self.expr.clone().sort(ascending, nulls_first).into() - } - - pub fn is_null(&self) -> PyExpr { - self.expr.clone().is_null().into() - } - - pub fn cast(&self, to: DataType) -> PyExpr { - // self.expr.cast_to() requires DFSchema to validate that the cast - // is supported, omit that for now - let expr = Expr::Cast { - expr: Box::new(self.expr.clone()), - data_type: to, - }; - expr.into() - } -} diff --git a/python/src/functions.rs b/python/src/functions.rs deleted file mode 100644 index c0b4e5989012..000000000000 --- a/python/src/functions.rs +++ /dev/null @@ -1,343 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use pyo3::{prelude::*, wrap_pyfunction}; - -use datafusion::logical_plan; - -use datafusion::physical_plan::{ - aggregates::AggregateFunction, functions::BuiltinScalarFunction, -}; - -use crate::errors; -use crate::expression::PyExpr; - -#[pyfunction] -fn array(value: Vec) -> PyExpr { - PyExpr { - expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::>()), - } -} - -#[pyfunction] -fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { - logical_plan::in_list( - expr.expr, - value.into_iter().map(|x| x.expr).collect::>(), - negated, - ) - .into() -} - -/// Current date and time -#[pyfunction] -fn now() -> PyExpr { - PyExpr { - // here lit(0) is a stub for conform to arity - expr: logical_plan::now(logical_plan::lit(0)), - } -} - -/// Returns a random value in the range 0.0 <= x < 1.0 -#[pyfunction] -fn random() -> PyExpr { - PyExpr { - expr: logical_plan::random(), - } -} - -/// Computes a binary hash of the given data. type is the algorithm to use. -/// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. -#[pyfunction(value, method)] -fn digest(value: PyExpr, method: PyExpr) -> PyExpr { - PyExpr { - expr: logical_plan::digest(value.expr, method.expr), - } -} - -/// Concatenates the text representations of all the arguments. -/// NULL arguments are ignored. -#[pyfunction(args = "*")] -fn concat(args: Vec) -> PyResult { - let args = args.into_iter().map(|e| e.expr).collect::>(); - Ok(logical_plan::concat(&args).into()) -} - -/// Concatenates all but the first argument, with separators. -/// The first argument is used as the separator string, and should not be NULL. -/// Other NULL arguments are ignored. 
-#[pyfunction(sep, args = "*")] -fn concat_ws(sep: String, args: Vec) -> PyResult { - let args = args.into_iter().map(|e| e.expr).collect::>(); - Ok(logical_plan::concat_ws(sep, &args).into()) -} - -/// Creates a new Sort expression -#[pyfunction] -fn order_by( - expr: PyExpr, - asc: Option, - nulls_first: Option, -) -> PyResult { - Ok(PyExpr { - expr: datafusion::logical_plan::Expr::Sort { - expr: Box::new(expr.expr), - asc: asc.unwrap_or(true), - nulls_first: nulls_first.unwrap_or(true), - }, - }) -} - -/// Creates a new Alias expression -#[pyfunction] -fn alias(expr: PyExpr, name: &str) -> PyResult { - Ok(PyExpr { - expr: datafusion::logical_plan::Expr::Alias( - Box::new(expr.expr), - String::from(name), - ), - }) -} - -/// Creates a new Window function expression -#[pyfunction] -fn window( - name: &str, - args: Vec, - partition_by: Option>, - order_by: Option>, -) -> PyResult { - use std::str::FromStr; - let fun = datafusion::physical_plan::window_functions::WindowFunction::from_str(name) - .map_err(|e| -> errors::DataFusionError { e.into() })?; - Ok(PyExpr { - expr: datafusion::logical_plan::Expr::WindowFunction { - fun, - args: args.into_iter().map(|x| x.expr).collect::>(), - partition_by: partition_by - .unwrap_or(vec![]) - .into_iter() - .map(|x| x.expr) - .collect::>(), - order_by: order_by - .unwrap_or(vec![]) - .into_iter() - .map(|x| x.expr) - .collect::>(), - window_frame: None, - }, - }) -} - -macro_rules! scalar_function { - ($NAME: ident, $FUNC: ident) => { - scalar_function!($NAME, $FUNC, stringify!($NAME)); - }; - ($NAME: ident, $FUNC: ident, $DOC: expr) => { - #[doc = $DOC] - #[pyfunction(args = "*")] - fn $NAME(args: Vec) -> PyExpr { - let expr = logical_plan::Expr::ScalarFunction { - fun: BuiltinScalarFunction::$FUNC, - args: args.into_iter().map(|e| e.into()).collect(), - }; - expr.into() - } - }; -} - -macro_rules! aggregate_function { - ($NAME: ident, $FUNC: ident) => { - aggregate_function!($NAME, $FUNC, stringify!($NAME)); - }; - ($NAME: ident, $FUNC: ident, $DOC: expr) => { - #[doc = $DOC] - #[pyfunction(args = "*", distinct = "false")] - fn $NAME(args: Vec, distinct: bool) -> PyExpr { - let expr = logical_plan::Expr::AggregateFunction { - fun: AggregateFunction::$FUNC, - args: args.into_iter().map(|e| e.into()).collect(), - distinct, - }; - expr.into() - } - }; -} - -scalar_function!(abs, Abs); -scalar_function!(acos, Acos); -scalar_function!(ascii, Ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character."); -scalar_function!(asin, Asin); -scalar_function!(atan, Atan); -scalar_function!( - bit_length, - BitLength, - "Returns number of bits in the string (8 times the octet_length)." -); -scalar_function!(btrim, Btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); -scalar_function!(ceil, Ceil); -scalar_function!( - character_length, - CharacterLength, - "Returns number of characters in the string." -); -scalar_function!(chr, Chr, "Returns the character with the given code."); -scalar_function!(cos, Cos); -scalar_function!(exp, Exp); -scalar_function!(floor, Floor); -scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. 
Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); -scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); -scalar_function!(ln, Ln); -scalar_function!(log10, Log10); -scalar_function!(log2, Log2); -scalar_function!(lower, Lower, "Converts the string to all lower case"); -scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); -scalar_function!(ltrim, Ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); -scalar_function!( - md5, - MD5, - "Computes the MD5 hash of the argument, with the result written in hexadecimal." -); -scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); -scalar_function!(regexp_match, RegexpMatch); -scalar_function!( - regexp_replace, - RegexpReplace, - "Replaces substring(s) matching a POSIX regular expression" -); -scalar_function!( - repeat, - Repeat, - "Repeats string the specified number of times." -); -scalar_function!( - replace, - Replace, - "Replaces all occurrences in string of substring from with substring to." -); -scalar_function!( - reverse, - Reverse, - "Reverses the order of the characters in the string." -); -scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); -scalar_function!(round, Round); -scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); -scalar_function!(rtrim, Rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); -scalar_function!(sha224, SHA224); -scalar_function!(sha256, SHA256); -scalar_function!(sha384, SHA384); -scalar_function!(sha512, SHA512); -scalar_function!(signum, Signum); -scalar_function!(sin, Sin); -scalar_function!(split_part, SplitPart, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."); -scalar_function!(sqrt, Sqrt); -scalar_function!( - starts_with, - StartsWith, - "Returns true if string starts with prefix." -); -scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); -scalar_function!(substr, Substr); -scalar_function!(tan, Tan); -scalar_function!( - to_hex, - ToHex, - "Converts the number to its equivalent hexadecimal representation." -); -scalar_function!(to_timestamp, ToTimestamp); -scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. 
If from is longer than to, occurrences of the extra characters in from are deleted."); -scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); -scalar_function!(trunc, Trunc); -scalar_function!(upper, Upper, "Converts the string to all upper case."); - -aggregate_function!(avg, Avg); -aggregate_function!(count, Count); -aggregate_function!(max, Max); -aggregate_function!(min, Min); -aggregate_function!(sum, Sum); -aggregate_function!(approx_distinct, ApproxDistinct); - -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(abs))?; - m.add_wrapped(wrap_pyfunction!(acos))?; - m.add_wrapped(wrap_pyfunction!(approx_distinct))?; - m.add_wrapped(wrap_pyfunction!(alias))?; - m.add_wrapped(wrap_pyfunction!(array))?; - m.add_wrapped(wrap_pyfunction!(ascii))?; - m.add_wrapped(wrap_pyfunction!(asin))?; - m.add_wrapped(wrap_pyfunction!(atan))?; - m.add_wrapped(wrap_pyfunction!(avg))?; - m.add_wrapped(wrap_pyfunction!(bit_length))?; - m.add_wrapped(wrap_pyfunction!(btrim))?; - m.add_wrapped(wrap_pyfunction!(ceil))?; - m.add_wrapped(wrap_pyfunction!(character_length))?; - m.add_wrapped(wrap_pyfunction!(chr))?; - m.add_wrapped(wrap_pyfunction!(concat_ws))?; - m.add_wrapped(wrap_pyfunction!(concat))?; - m.add_wrapped(wrap_pyfunction!(cos))?; - m.add_wrapped(wrap_pyfunction!(count))?; - m.add_wrapped(wrap_pyfunction!(digest))?; - m.add_wrapped(wrap_pyfunction!(exp))?; - m.add_wrapped(wrap_pyfunction!(floor))?; - m.add_wrapped(wrap_pyfunction!(in_list))?; - m.add_wrapped(wrap_pyfunction!(initcap))?; - m.add_wrapped(wrap_pyfunction!(left))?; - m.add_wrapped(wrap_pyfunction!(ln))?; - m.add_wrapped(wrap_pyfunction!(log10))?; - m.add_wrapped(wrap_pyfunction!(log2))?; - m.add_wrapped(wrap_pyfunction!(lower))?; - m.add_wrapped(wrap_pyfunction!(lpad))?; - m.add_wrapped(wrap_pyfunction!(ltrim))?; - m.add_wrapped(wrap_pyfunction!(max))?; - m.add_wrapped(wrap_pyfunction!(md5))?; - m.add_wrapped(wrap_pyfunction!(min))?; - m.add_wrapped(wrap_pyfunction!(now))?; - m.add_wrapped(wrap_pyfunction!(octet_length))?; - m.add_wrapped(wrap_pyfunction!(order_by))?; - m.add_wrapped(wrap_pyfunction!(random))?; - m.add_wrapped(wrap_pyfunction!(regexp_match))?; - m.add_wrapped(wrap_pyfunction!(regexp_replace))?; - m.add_wrapped(wrap_pyfunction!(repeat))?; - m.add_wrapped(wrap_pyfunction!(replace))?; - m.add_wrapped(wrap_pyfunction!(reverse))?; - m.add_wrapped(wrap_pyfunction!(right))?; - m.add_wrapped(wrap_pyfunction!(round))?; - m.add_wrapped(wrap_pyfunction!(rpad))?; - m.add_wrapped(wrap_pyfunction!(rtrim))?; - m.add_wrapped(wrap_pyfunction!(sha224))?; - m.add_wrapped(wrap_pyfunction!(sha256))?; - m.add_wrapped(wrap_pyfunction!(sha384))?; - m.add_wrapped(wrap_pyfunction!(sha512))?; - m.add_wrapped(wrap_pyfunction!(signum))?; - m.add_wrapped(wrap_pyfunction!(sin))?; - m.add_wrapped(wrap_pyfunction!(split_part))?; - m.add_wrapped(wrap_pyfunction!(sqrt))?; - m.add_wrapped(wrap_pyfunction!(starts_with))?; - m.add_wrapped(wrap_pyfunction!(strpos))?; - m.add_wrapped(wrap_pyfunction!(substr))?; - m.add_wrapped(wrap_pyfunction!(sum))?; - m.add_wrapped(wrap_pyfunction!(tan))?; - m.add_wrapped(wrap_pyfunction!(to_hex))?; - m.add_wrapped(wrap_pyfunction!(to_timestamp))?; - m.add_wrapped(wrap_pyfunction!(translate))?; - m.add_wrapped(wrap_pyfunction!(trim))?; - m.add_wrapped(wrap_pyfunction!(trunc))?; - m.add_wrapped(wrap_pyfunction!(upper))?; - 
m.add_wrapped(wrap_pyfunction!(window))?;
-    Ok(())
-}
diff --git a/python/src/lib.rs b/python/src/lib.rs
deleted file mode 100644
index d40bae251c86..000000000000
--- a/python/src/lib.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use pyo3::prelude::*;
-
-mod catalog;
-mod context;
-mod dataframe;
-mod errors;
-mod expression;
-mod functions;
-mod udaf;
-mod udf;
-mod utils;
-
-/// Low-level DataFusion internal package.
-///
-/// The higher-level public API is defined in pure python files under the
-/// datafusion directory.
-#[pymodule]
-fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
-    // Register the python classes
-    m.add_class::<catalog::PyCatalog>()?;
-    m.add_class::<catalog::PyDatabase>()?;
-    m.add_class::<catalog::PyTable>()?;
-    m.add_class::<context::PyExecutionContext>()?;
-    m.add_class::<dataframe::PyDataFrame>()?;
-    m.add_class::<expression::PyExpr>()?;
-    m.add_class::<udf::PyScalarUDF>()?;
-    m.add_class::<udaf::PyAggregateUDF>()?;
-
-    // Register the functions as a submodule
-    let funcs = PyModule::new(py, "functions")?;
-    functions::init_module(funcs)?;
-    m.add_submodule(funcs)?;
-
-    Ok(())
-}
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
deleted file mode 100644
index 1de6e63205ed..000000000000
--- a/python/src/udaf.rs
+++ /dev/null
@@ -1,153 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
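The deleted lib.rs above follows the usual PyO3 layering: Rust classes are registered on the top-level `_internal` module and the free functions live in a nested `functions` submodule. A hedged sketch of the same registration pattern with placeholder names (`Widget`, `make_widget`, `_demo` are illustrative, not part of this diff), written against the same pre-0.16 PyO3 API the removed code uses:

use pyo3::{prelude::*, wrap_pyfunction};

#[pyclass]
struct Widget {
    #[pyo3(get)]
    size: usize,
}

#[pyfunction]
fn make_widget(size: usize) -> Widget {
    Widget { size }
}

// Same structure as the deleted lib.rs: classes on the parent module,
// functions gathered in a child module that Python sees as `_demo.helpers`.
#[pymodule]
fn _demo(py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Widget>()?;

    let helpers = PyModule::new(py, "helpers")?;
    helpers.add_wrapped(wrap_pyfunction!(make_widget))?;
    m.add_submodule(helpers)?;

    Ok(())
}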
- -use std::sync::Arc; - -use pyo3::{prelude::*, types::PyTuple}; - -use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::pyarrow::PyArrowConvert; -use datafusion::error::{DataFusionError, Result}; -use datafusion::logical_plan; -use datafusion::physical_plan::aggregates::AccumulatorFunctionImplementation; -use datafusion::physical_plan::udaf::AggregateUDF; -use datafusion::physical_plan::Accumulator; -use datafusion::scalar::ScalarValue; - -use crate::expression::PyExpr; -use crate::utils::parse_volatility; - -#[derive(Debug)] -struct RustAccumulator { - accum: PyObject, -} - -impl RustAccumulator { - fn new(accum: PyObject) -> Self { - Self { accum } - } -} - -impl Accumulator for RustAccumulator { - fn state(&self) -> Result> { - Python::with_gil(|py| self.accum.as_ref(py).call_method0("state")?.extract()) - .map_err(|e| DataFusionError::Execution(format!("{}", e))) - } - - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - - fn evaluate(&self) -> Result { - Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) - .map_err(|e| DataFusionError::Execution(format!("{}", e))) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - Python::with_gil(|py| { - // 1. cast args to Pyarrow array - let py_args = values - .iter() - .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) - .collect::>(); - let py_args = PyTuple::new(py, py_args); - - // 2. call function - self.accum - .as_ref(py) - .call_method1("update", py_args) - .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - - Ok(()) - }) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - Python::with_gil(|py| { - let state = &states[0]; - - // 1. cast states to Pyarrow array - let state = state - .to_pyarrow(py) - .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - - // 2. 
call merge - self.accum - .as_ref(py) - .call_method1("merge", (state,)) - .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - - Ok(()) - }) - } -} - -pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFunctionImplementation { - Arc::new(move || -> Result> { - let accum = Python::with_gil(|py| { - accum - .call0(py) - .map_err(|e| DataFusionError::Execution(format!("{}", e))) - })?; - Ok(Box::new(RustAccumulator::new(accum))) - }) -} - -/// Represents a AggregateUDF -#[pyclass(name = "AggregateUDF", module = "datafusion", subclass)] -#[derive(Debug, Clone)] -pub struct PyAggregateUDF { - pub(crate) function: AggregateUDF, -} - -#[pymethods] -impl PyAggregateUDF { - #[new(name, accumulator, input_type, return_type, state_type, volatility)] - fn new( - name: &str, - accumulator: PyObject, - input_type: DataType, - return_type: DataType, - state_type: Vec, - volatility: &str, - ) -> PyResult { - let function = logical_plan::create_udaf( - &name, - input_type, - Arc::new(return_type), - parse_volatility(volatility)?, - to_rust_accumulator(accumulator), - Arc::new(state_type), - ); - Ok(Self { function }) - } - - /// creates a new PyExpr with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: Vec) -> PyResult { - let args = args.iter().map(|e| e.expr.clone()).collect(); - Ok(self.function.call(args).into()) - } -} diff --git a/python/src/udf.rs b/python/src/udf.rs deleted file mode 100644 index 379c449870b2..000000000000 --- a/python/src/udf.rs +++ /dev/null @@ -1,98 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use pyo3::{prelude::*, types::PyTuple}; - -use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::pyarrow::PyArrowConvert; -use datafusion::error::DataFusionError; -use datafusion::logical_plan; -use datafusion::physical_plan::functions::{ - make_scalar_function, ScalarFunctionImplementation, -}; -use datafusion::physical_plan::udf::ScalarUDF; - -use crate::expression::PyExpr; -use crate::utils::parse_volatility; - -/// Create a DataFusion's UDF implementation from a python function -/// that expects pyarrow arrays. This is more efficient as it performs -/// a zero-copy of the contents. -fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation { - make_scalar_function( - move |args: &[ArrayRef]| -> Result { - Python::with_gil(|py| { - // 1. cast args to Pyarrow arrays - let py_args = args - .iter() - .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) - .collect::>(); - let py_args = PyTuple::new(py, py_args); - - // 2. 
call function
-                let value = func.as_ref(py).call(py_args, None);
-                let value = match value {
-                    Ok(n) => Ok(n),
-                    Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
-                }?;
-
-                // 3. cast to arrow::array::Array
-                let array = ArrayRef::from_pyarrow(value).unwrap();
-                Ok(array)
-            })
-        },
-    )
-}
-
-/// Represents a PyScalarUDF
-#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
-#[derive(Debug, Clone)]
-pub struct PyScalarUDF {
-    pub(crate) function: ScalarUDF,
-}
-
-#[pymethods]
-impl PyScalarUDF {
-    #[new(name, func, input_types, return_type, volatility)]
-    fn new(
-        name: &str,
-        func: PyObject,
-        input_types: Vec<DataType>,
-        return_type: DataType,
-        volatility: &str,
-    ) -> PyResult<Self> {
-        let function = logical_plan::create_udf(
-            name,
-            input_types,
-            Arc::new(return_type),
-            parse_volatility(volatility)?,
-            to_rust_function(func),
-        );
-        Ok(Self { function })
-    }
-
-    /// creates a new PyExpr with the call of the udf
-    #[call]
-    #[args(args = "*")]
-    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
-        let args = args.iter().map(|e| e.expr.clone()).collect();
-        Ok(self.function.call(args).into())
-    }
-}
diff --git a/python/src/utils.rs b/python/src/utils.rs
deleted file mode 100644
index c8e1c63b1d0f..000000000000
--- a/python/src/utils.rs
+++ /dev/null
@@ -1,50 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::future::Future;
-
-use pyo3::prelude::*;
-use tokio::runtime::Runtime;
-
-use datafusion::physical_plan::functions::Volatility;
-
-use crate::errors::DataFusionError;
-
-/// Utility to collect rust futures with GIL released
-pub(crate) fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
-where
-    F: Send,
-    F::Output: Send,
-{
-    let rt = Runtime::new().unwrap();
-    py.allow_threads(|| rt.block_on(f))
-}
-
-pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
-    Ok(match value {
-        "immutable" => Volatility::Immutable,
-        "stable" => Volatility::Stable,
-        "volatile" => Volatility::Volatile,
-        value => {
-            return Err(DataFusionError::Common(format!(
-                "Unsupportad volatility type: `{}`, supported \
-                values are: immutable, stable and volatile.",
-                value
-            )))
-        }
-    })
-}
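Both helpers in the deleted utils.rs exist to keep the binding methods short: `parse_volatility` maps the strings the Python API accepts onto DataFusion's `Volatility`, and `wait_for_future` creates a Tokio runtime and blocks on a future with the GIL released so other Python threads can keep running. A hedged usage sketch of the blocking helper follows; `pretend_query` and `collect_rows` are illustrative stand-ins, not functions from this diff:

use std::future::Future;

use pyo3::prelude::*;
use tokio::runtime::Runtime;

// Same shape as utils::wait_for_future: spin up a runtime and block on the
// future inside allow_threads, which releases the GIL while Rust is busy.
fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
    F: Send,
    F::Output: Send,
{
    let rt = Runtime::new().unwrap();
    py.allow_threads(|| rt.block_on(f))
}

// Illustrative async work standing in for a DataFusion call; the real
// bindings block on futures such as a DataFrame collect.
async fn pretend_query(limit: u64) -> Vec<u64> {
    (0..limit).collect()
}

// A binding method would typically take `py: Python` and hand it to the
// helper so the GIL is released for the duration of the query.
#[pyfunction]
fn collect_rows(py: Python, limit: u64) -> PyResult<Vec<u64>> {
    Ok(wait_for_future(py, pretend_query(limit)))
}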