diff --git a/3rd_party/apache-arrow-adbc/.asf.yaml b/3rd_party/apache-arrow-adbc/.asf.yaml deleted file mode 100644 index ddee959..0000000 --- a/3rd_party/apache-arrow-adbc/.asf.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -github: - description: "Database connectivity API standard and libraries for Apache Arrow" - homepage: https://arrow.apache.org/adbc/ - enabled_merge_buttons: - merge: false - rebase: false - squash: true - features: - issues: true - -notifications: - commits: commits@arrow.apache.org - issues_status: issues@arrow.apache.org - issues: github@arrow.apache.org - pullrequests: github@arrow.apache.org - jira_options: link label worklog - -publish: - whoami: asf-site - subdir: adbc diff --git a/3rd_party/apache-arrow-adbc/.clang-format b/3rd_party/apache-arrow-adbc/.clang-format deleted file mode 100644 index 9448dc8..0000000 --- a/3rd_party/apache-arrow-adbc/.clang-format +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. ---- -BasedOnStyle: Google -ColumnLimit: 90 -DerivePointerAlignment: false -IncludeBlocks: Preserve diff --git a/3rd_party/apache-arrow-adbc/.cmake-format b/3rd_party/apache-arrow-adbc/.cmake-format deleted file mode 100644 index 3e77733..0000000 --- a/3rd_party/apache-arrow-adbc/.cmake-format +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# cmake-format configuration file -# Use `archery lint --cmake-format --fix` to reformat all cmake files in the -# source tree - -# ----------------------------- -# Options affecting formatting. 
-# ----------------------------- -with section("format"): - # How wide to allow formatted cmake files - line_width = 90 - - # How many spaces to tab for indent - tab_size = 2 - - # If a positional argument group contains more than this many arguments, - # then force it to a vertical layout. - max_pargs_hwrap = 4 - - # If the statement spelling length (including space and parenthesis) is - # smaller than this amount, then force reject nested layouts. - # This value only comes into play when considering whether or not to nest - # arguments below their parent. If the number of characters in the parent - # is less than this value, we will not nest. - min_prefix_chars = 32 - - # If true, separate flow control names from their parentheses with a space - separate_ctrl_name_with_space = False - - # If true, separate function names from parentheses with a space - separate_fn_name_with_space = False - - # If a statement is wrapped to more than one line, than dangle the closing - # parenthesis on it's own line - dangle_parens = False - - # What style line endings to use in the output. - line_ending = 'unix' - - # Format command names consistently as 'lower' or 'upper' case - command_case = 'lower' - - # Format keywords consistently as 'lower' or 'upper' case - keyword_case = 'unchanged' - -# ------------------------------------------------ -# Options affecting comment reflow and formatting. -# ------------------------------------------------ -with section("markup"): - # enable comment markup parsing and reflow - enable_markup = False - - # If comment markup is enabled, don't reflow the first comment block in - # eachlistfile. Use this to preserve formatting of your - # copyright/licensestatements. - first_comment_is_literal = True - - # If comment markup is enabled, don't reflow any comment block which - # matchesthis (regex) pattern. Default is `None` (disabled). 
- literal_comment_pattern = None diff --git a/3rd_party/apache-arrow-adbc/.env b/3rd_party/apache-arrow-adbc/.env deleted file mode 100644 index 2b73edb..0000000 --- a/3rd_party/apache-arrow-adbc/.env +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# All of the following environment variables are required to set default values -# for the parameters in docker-compose.yml. - -# Default repository to pull and push images from -REPO=apache/arrow-dev - -# different architecture notations -ARCH=amd64 -ARCH_ALIAS=x86_64 -ARCH_SHORT=amd64 -ARCH_CONDA_FORGE=linux_64_ - -# Default versions for various dependencies -JDK=8 -MANYLINUX=2014 -MAVEN=3.5.4 -PYTHON=3.10 -GO=1.19.5 -ARROW_MAJOR_VERSION=12 - -# Used through docker-compose.yml and serves as the default version for the -# ci/scripts/install_vcpkg.sh script. -VCPKG="2871ddd918cecb9cb642bcb9c56897f397283192" diff --git a/3rd_party/apache-arrow-adbc/.flake8 b/3rd_party/apache-arrow-adbc/.flake8 deleted file mode 100644 index cda8c86..0000000 --- a/3rd_party/apache-arrow-adbc/.flake8 +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[flake8] -max-line-length = 88 -extend-ignore = E203 diff --git a/3rd_party/apache-arrow-adbc/.gitattributes b/3rd_party/apache-arrow-adbc/.gitattributes deleted file mode 100644 index 7f39f6a..0000000 --- a/3rd_party/apache-arrow-adbc/.gitattributes +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -c/vendor/* linguist-vendored -python/adbc_driver_flightsql/adbc_driver_flightsql/_static_version.py export-subst -python/adbc_driver_manager/adbc_driver_manager/_static_version.py export-subst -python/adbc_driver_postgresql/adbc_driver_postgresql/_static_version.py export-subst -python/adbc_driver_snowflake/adbc_driver_snowflake/_static_version.py export-subst -python/adbc_driver_sqlite/adbc_driver_sqlite/_static_version.py export-subst diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/dev.yml b/3rd_party/apache-arrow-adbc/.github/workflows/dev.yml deleted file mode 100644 index 3acb30f..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/dev.yml +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Dev - -on: - pull_request: {} - push: {} - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - pre-commit: - name: "pre-commit" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - uses: actions/setup-go@v3 - with: - go-version-file: 'go/adbc/go.mod' - check-latest: true - - uses: actions/setup-python@v4 - - name: install golangci-lint - run: | - go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.49.0 - - name: pre-commit (cache) - uses: actions/cache@v3 - with: - path: ~/.cache/pre-commit - key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} - - name: pre-commit (--all-files) - run: | - python -m pip install pre-commit - pre-commit run --show-diff-on-failure --color=always --all-files diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr.yml b/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr.yml deleted file mode 100644 index ad629df..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr.yml +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Dev PR - -on: - pull_request_target: - types: - - opened - - edited - - synchronize - -permissions: - contents: read - pull-requests: write - -jobs: - process: - name: Process - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - persist-credentials: false - - - name: Check title for Conventional Commits format - if: | - github.event_name == 'pull_request_target' && - (github.event.action == 'opened' || - github.event.action == 'edited') - uses: actions/github-script@v6 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - const script = require(`${process.env.GITHUB_WORKSPACE}/.github/workflows/dev_pr/title_check.js`); - script({github, context}); diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr/title_check.js b/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr/title_check.js deleted file mode 100644 index 2f9f1c2..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/dev_pr/title_check.js +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -const COMMIT_TYPES = [ - 'build', - 'chore', - 'ci', - 'docs', - 'feat', - 'fix', - 'perf', - 'refactor', - 'revert', - 'style', - 'test', -]; - -const COMMENT_BODY = ":warning: Please follow the [Conventional Commits format in CONTRIBUTING.md](https://github.com/apache/arrow-adbc/blob/main/CONTRIBUTING.md) for PR titles."; - -function matchesCommitFormat(title) { - const commitType = `(${COMMIT_TYPES.join('|')})`; - const scope = "(\\([a-zA-Z0-9_/\\-,]+\\))?"; - const delimiter = "!?:"; - const subject = " .+"; - const regexp = new RegExp(`^${commitType}${scope}${delimiter}${subject}$`); - return title.match(regexp) != null; -} - -async function commentCommitFormat(github, context, pullRequestNumber) { - const {data: comments} = await github.rest.issues.listComments({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: pullRequestNumber, - per_page: 100, - }); - - let found = false; - for (const comment of comments) { - if (comment.body.includes("Conventional Commits format in CONTRIBUTING.md")) { - found = true; - break; - } - } - - if (!found) { - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: pullRequestNumber, - body: COMMENT_BODY, - }); - } -} - -module.exports = async ({github, context}) => { - const pullRequestNumber = context.payload.number; - const title = context.payload.pull_request.title; - if (!matchesCommitFormat(title)) { - await commentCommitFormat(github, context, pullRequestNumber); - } -}; diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/integration.yml b/3rd_party/apache-arrow-adbc/.github/workflows/integration.yml deleted file mode 100644 index e0b1efe..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/integration.yml +++ /dev/null @@ -1,265 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Integration - -on: - pull_request: - branches: - - main - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "go/**" - - "python/**" - - ".github/workflows/integration.yml" - push: - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "go/**" - - "python/**" - - ".github/workflows/integration.yml" - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -permissions: - contents: read - -env: - # Increment this to reset cache manually - CACHE_NUMBER: "0" - -jobs: - flightsql: - name: "FlightSQL Integration Tests (Dremio and SQLite)" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache/restore@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file 
ci/conda_env_cpp.txt \ - --file ci/conda_env_python.txt - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - name: Start SQLite server and Dremio - shell: bash -l {0} - run: | - docker-compose up -d golang-sqlite-flightsql dremio dremio-init - - - name: Build FlightSQL Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_FLIGHTSQL: "1" - run: | - ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - - name: Test FlightSQL Driver against Dremio and SQLite - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_FLIGHTSQL: "1" - ADBC_DREMIO_FLIGHTSQL_URI: "grpc+tcp://localhost:32010" - ADBC_DREMIO_FLIGHTSQL_USER: "dremio" - ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" - ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" - run: | - ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - ./ci/scripts/cpp_test.sh "$(pwd)" "$(pwd)/build" - - name: Build Python Flight SQL driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_FLIGHTSQL: "1" - BUILD_DRIVER_MANAGER: "1" - run: | - ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" - - name: Test Python Flight SQL driver against Dremio - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_FLIGHTSQL: "1" - ADBC_DREMIO_FLIGHTSQL_URI: "grpc+tcp://localhost:32010" - ADBC_DREMIO_FLIGHTSQL_USER: "dremio" - ADBC_DREMIO_FLIGHTSQL_PASS: "dremio123" - run: | - ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" - - name: Stop SQLite server and Dremio - shell: bash -l {0} - run: | - docker-compose down - - postgresql: - name: "PostgreSQL Integration Tests" - runs-on: ubuntu-latest - services: - postgres: - image: postgres - env: - POSTGRES_DB: tempdb - POSTGRES_PASSWORD: password - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: 
get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache/restore@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_python.txt - - name: Build PostgreSQL Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - ADBC_USE_ASAN: "OFF" - ADBC_USE_UBSAN: "OFF" - run: | - ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - - name: Test PostgreSQL Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - ADBC_POSTGRESQL_TEST_URI: "postgres://localhost:5432/postgres?user=postgres&password=password" - run: | - ./ci/scripts/cpp_test.sh "$(pwd)" "$(pwd)/build" - - - name: Build Python PostgreSQL Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - run: | - ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" - - name: Test Python PostgreSQL Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - ADBC_POSTGRESQL_TEST_URI: "postgres://localhost:5432/postgres?user=postgres&password=password" - run: | - ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" - - snowflake: - name: "Snowflake Integration Tests" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache/restore@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ 
env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_python.txt - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - name: Build and Test Snowflake Driver - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_SNOWFLAKE: "1" - ADBC_SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} - run: | - ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - ./ci/scripts/cpp_test.sh "$(pwd)" "$(pwd)/build" - - name: Build and Test Snowflake Driver (Python) - shell: bash -l {0} - env: - BUILD_ALL: "0" - BUILD_DRIVER_SNOWFLAKE: "1" - ADBC_SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} - run: | - ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" - ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/java.yml b/3rd_party/apache-arrow-adbc/.github/workflows/java.yml deleted file mode 100644 index 057f960..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/java.yml +++ /dev/null @@ -1,115 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Java - -on: - pull_request: - branches: - - main - paths: - - "java/**" - - ".github/workflows/java.yml" - push: - paths: - - "java/**" - - ".github/workflows/java.yml" - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - java: - name: "Java ${{ matrix.java }}/Linux" - runs-on: ubuntu-latest - strategy: - matrix: - java: ['8', '11'] - services: - postgres: - image: postgres - env: - POSTGRES_DB: postgres - POSTGRES_PASSWORD: password - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - uses: actions/setup-java@v3 - with: - cache: "maven" - distribution: "temurin" - java-version: ${{ matrix.java }} - - name: Start SQLite server - shell: bash -l {0} - run: | - docker-compose up -d golang-sqlite-flightsql - - name: Build/Test - env: - ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080" - ADBC_JDBC_POSTGRESQL_URL: "localhost:5432/postgres" - ADBC_JDBC_POSTGRESQL_USER: "postgres" - ADBC_JDBC_POSTGRESQL_PASSWORD: "password" - run: | - cd java - mvn install - - java-errorprone: - name: "Java ${{ matrix.java }}/Linux with ErrorProne" - runs-on: ubuntu-latest - strategy: - matrix: - java: ['11'] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - uses: actions/setup-java@v3 - with: - cache: "maven" - distribution: "temurin" - 
java-version: ${{ matrix.java }} - - name: Build/Test - run: | - cd java - mkdir .mvn - cat < .mvn/jvm.config - --add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED - --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED - --add-opens jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED - --add-opens jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED - HERE - mvn -P errorprone install diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/native-unix.yml b/3rd_party/apache-arrow-adbc/.github/workflows/native-unix.yml deleted file mode 100644 index 57b9dca..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/native-unix.yml +++ /dev/null @@ -1,542 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Native Libraries (Unix) - -on: - pull_request: - branches: - - main - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "glib/**" - - "go/**" - - "python/**" - - "r/**" - - "ruby/**" - - ".github/workflows/native-unix.yml" - push: - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "glib/**" - - "go/**" - - "python/**" - - "r/**" - - "ruby/**" - - ".github/workflows/native-unix.yml" - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -permissions: - contents: read - -env: - # Increment this to reset cache manually - CACHE_NUMBER: "1" - -jobs: - # ------------------------------------------------------------ - # Common build (builds libraries used in GLib, Python, Ruby) - # ------------------------------------------------------------ - drivers-build-conda: - name: "Common Libraries (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest"] - env: - # Required for macOS - # https://conda-forge.org/docs/maintainer/knowledge_base.html#newer-c-features-with-old-sdk - CXXFLAGS: "-D_LIBCPP_DISABLE_AVAILABILITY" - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - - name: Build and Install 
(No ASan) - shell: bash -l {0} - run: | - # Python and others need something that don't use the ASAN runtime - rm -rf "$(pwd)/build" - export BUILD_ALL=1 - export ADBC_BUILD_TESTS=OFF - export ADBC_USE_ASAN=OFF - export ADBC_USE_UBSAN=OFF - export PATH=$RUNNER_TOOL_CACHE/go/1.18.6/x64/bin:$PATH - ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Go Build - shell: bash -l {0} - env: - CGO_ENABLED: "1" - run: | - export PATH=$RUNNER_TOOL_CACHE/go/1.18.6/x64/bin:$PATH - ./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - - uses: actions/upload-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - retention-days: 3 - path: | - ~/local - - # ------------------------------------------------------------ - # C/C++ (builds and tests) - # ------------------------------------------------------------ - drivers-test-conda: - name: "C/C++ (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest"] - env: - # Required for macOS - # https://conda-forge.org/docs/maintainer/knowledge_base.html#newer-c-features-with-old-sdk - CXXFLAGS: "-D_LIBCPP_DISABLE_AVAILABILITY" - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - - - name: Build SQLite3 Driver - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SQLITE=1 ./ci/scripts/cpp_build.sh "$(pwd)" 
"$(pwd)/build" - - name: Test SQLite3 Driver - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SQLITE=1 ./ci/scripts/cpp_test.sh "$(pwd)" "$(pwd)/build" - - name: Build PostgreSQL Driver - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_POSTGRESQL=1 ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - - name: Build Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - - name: Test Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/cpp_test.sh "$(pwd)" "$(pwd)/build" - - # ------------------------------------------------------------ - # GLib/Ruby - # ------------------------------------------------------------ - glib-conda: - name: "GLib/Ruby (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - needs: - - drivers-build-conda - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest"] - env: - # Required for macOS - # https://conda-forge.org/docs/maintainer/knowledge_base.html#newer-c-features-with-old-sdk - CXXFLAGS: "-D_LIBCPP_DISABLE_AVAILABILITY" - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - 'arrow-c-glib>=10.0.1' \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_glib.txt - - - uses: actions/download-artifact@v3 
- with: - name: driver-manager-${{ matrix.os }} - path: ~/local - - - name: Build GLib Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/glib_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test GLib/Ruby Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/glib_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - # ------------------------------------------------------------ - # Go - # ------------------------------------------------------------ - go-no-cgo: - name: "Go (No CGO) (${{ matrix.os }})" - env: - CGO_ENABLED: "0" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest", "windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - uses: actions/setup-go@v3 - with: - go-version-file: 'go/adbc/go.mod' - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - name: Install staticcheck - run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3 - - name: Go Build - run: | - ./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Run Staticcheck - run: | - pushd go/adbc - staticcheck -f stylish ./... 
- popd - - name: Go Test - env: - SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} - run: | - ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - go-conda: - name: "Go (CGO) (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - needs: - - drivers-build-conda - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest"] - env: - CGO_ENABLED: "1" - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - name: Install staticcheck - shell: bash -l {0} - if: ${{ !contains('macos-latest', matrix.os) }} - run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3 - - - uses: actions/download-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - path: ~/local - - - name: Go Build - shell: bash -l {0} - run: | - export PATH=$RUNNER_TOOL_CACHE/go/1.18.6/x64/bin:$PATH - ./ci/scripts/go_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Run Staticcheck - if: ${{ !contains('macos-latest', matrix.os) }} - shell: bash -l {0} - run: | - pushd go/adbc - staticcheck -f stylish ./... 
- popd - - name: Go Test - shell: bash -l {0} - env: - SNOWFLAKE_URI: ${{ secrets.SNOWFLAKE_URI }} - run: | - export PATH=$RUNNER_TOOL_CACHE/go/1.18.6/x64/bin:$PATH - ./ci/scripts/go_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - # ------------------------------------------------------------ - # Python/doctests - # ------------------------------------------------------------ - python-conda: - name: "Python ${{ matrix.python }} (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - needs: - - drivers-build-conda - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest"] - python: ["3.9", "3.11"] - env: - # Required for macOS - # https://conda-forge.org/docs/maintainer/knowledge_base.html#newer-c-features-with-old-sdk - CXXFLAGS: "-D_LIBCPP_DISABLE_AVAILABILITY" - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - python=${{ matrix.python }} \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_docs.txt \ - --file ci/conda_env_python.txt - - - uses: actions/download-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - path: ~/local - - - name: Build Python Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test Python Driver Manager - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_MANAGER=1 ./ci/scripts/python_test.sh "$(pwd)" 
"$(pwd)/build" "$HOME/local" - - name: Build Python Driver Flight SQL - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test Python Driver Flight SQL - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Build Python Driver PostgreSQL - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_POSTGRESQL=1 ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test Python Driver PostgreSQL - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_POSTGRESQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Build Python Driver SQLite - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SQLITE=1 ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test Python Driver SQLite - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SQLITE=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Build Python Driver Snowflake - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SNOWFLAKE=1 ./ci/scripts/python_build.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - - name: Test Python Driver Snowflake - shell: bash -l {0} - run: | - env BUILD_ALL=0 BUILD_DRIVER_SNOWFLAKE=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - # Docs requires Python packages since it runs doctests - - name: Build Docs - shell: bash -l {0} - run: | - ./ci/scripts/docs_build.sh "$(pwd)" - - # ------------------------------------------------------------ - # R - # ------------------------------------------------------------ - r: - name: "R/${{ matrix.config.pkg }} (${{ matrix.config.os }})" - runs-on: ${{ matrix.config.os }} - strategy: - matrix: - config: - - {os: macOS-latest, r: 'release', pkg: 'adbcdrivermanager'} - - {os: windows-latest, r: 'release', pkg: 
'adbcdrivermanager'} - - {os: ubuntu-latest, r: 'release', pkg: 'adbcdrivermanager'} - - {os: macOS-latest, r: 'release', pkg: 'adbcsqlite'} - - {os: windows-latest, r: 'release', pkg: 'adbcsqlite'} - - {os: ubuntu-latest, r: 'release', pkg: 'adbcsqlite'} - - {os: macOS-latest, r: 'release', pkg: 'adbcpostgresql'} - - {os: windows-latest, r: 'release', pkg: 'adbcpostgresql'} - - {os: ubuntu-latest, r: 'release', pkg: 'adbcpostgresql'} - - {os: macOS-latest, r: 'release', pkg: 'adbcsnowflake'} - - {os: windows-latest, r: 'release', pkg: 'adbcsnowflake'} - - {os: ubuntu-latest, r: 'release', pkg: 'adbcsnowflake'} - - env: - GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} - R_KEEP_PKG_SOURCE: yes - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - - uses: r-lib/actions/setup-pandoc@v2 - - uses: r-lib/actions/setup-r@v2 - with: - r-version: ${{ matrix.config.r }} - http-user-agent: ${{ matrix.config.http-user-agent }} - use-public-rspm: true - - - name: Set PKG_CONFIG_PATH on MacOS - if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'macOS' - run: | - PKG_CONFIG_PATH="${PKG_CONFIG_PATH}:$(brew --prefix libpq)/lib/pkgconfig:$(brew --prefix openssl)/lib/pkgconfig" - echo "PKG_CONFIG_PATH=${PKG_CONFIG_PATH}" >> $GITHUB_ENV - - - name: Prepare sources (driver manager) - if: matrix.config.pkg == 'adbcdrivermanager' - run: | - R -e 'install.packages("nanoarrow", repos = "https://cloud.r-project.org")' - R CMD INSTALL r/${{ matrix.config.pkg }} - shell: bash - - - name: Prepare sources - if: matrix.config.pkg != 'adbcdrivermanager' - run: | - R -e 'install.packages("nanoarrow", repos = "https://cloud.r-project.org")' - R CMD INSTALL r/adbcdrivermanager - R CMD INSTALL r/${{ matrix.config.pkg }} - shell: bash - - - uses: r-lib/actions/setup-r-dependencies@v2 - with: - extra-packages: any::rcmdcheck, local::../adbcdrivermanager - needs: check - working-directory: r/${{ matrix.config.pkg }} - - - name: Start postgres test database - 
if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'Linux' - run: | - cd r/adbcpostgresql - docker compose up --detach postgres_test - ADBC_POSTGRESQL_TEST_URI="postgresql://localhost:5432/postgres?user=postgres&password=password" - echo "ADBC_POSTGRESQL_TEST_URI=${ADBC_POSTGRESQL_TEST_URI}" >> $GITHUB_ENV - - - uses: r-lib/actions/check-r-package@v2 - env: - ADBC_SNOWFLAKE_TEST_URI: ${{ secrets.SNOWFLAKE_URI }} - with: - upload-snapshots: true - working-directory: r/${{ matrix.config.pkg }} - - - name: Stop postgres test database - if: matrix.config.pkg == 'adbcpostgresql' && runner.os == 'Linux' - run: | - cd r/adbcpostgresql - docker compose down diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/native-windows.yml b/3rd_party/apache-arrow-adbc/.github/workflows/native-windows.yml deleted file mode 100644 index 3537fc4..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/native-windows.yml +++ /dev/null @@ -1,341 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Native Libraries (Windows) - -on: - pull_request: - branches: - - main - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "glib/**" - - "go/**" - - "python/**" - - "ruby/**" - - ".github/workflows/native-windows.yml" - push: - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "glib/**" - - "go/**" - - "python/**" - - "ruby/**" - - ".github/workflows/native-windows.yml" - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - # ------------------------------------------------------------ - # Common build (builds libraries used in GLib, Python, Ruby) - # ------------------------------------------------------------ - drivers-build-conda: - name: "Common C/C++ Libraries (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - env: - # Increment this to reset cache manually - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - # Force bundled gtest - mamba uninstall gtest - - - name: Build and Install (No ASan) - shell: pwsh - env: - BUILD_ALL: "1" - # TODO(apache/arrow-adbc#634) - BUILD_DRIVER_FLIGHTSQL: "0" - BUILD_DRIVER_SNOWFLAKE: "0" - run: | - .\ci\scripts\cpp_build.ps1 $pwd ${{ github.workspace }}\build - - - uses: actions/upload-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - 
retention-days: 3 - path: | - ${{ github.workspace }}/build - - # ------------------------------------------------------------ - # C/C++ build (builds and tests) - # ------------------------------------------------------------ - drivers-test-conda: - name: "C/C++ (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - env: - # Increment this to reset cache manually - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - # Force bundled gtest - mamba uninstall gtest - - - name: Build Driver Manager - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_MANAGER: "1" - run: - .\ci\scripts\cpp_build.ps1 $pwd $pwd\build - - name: Build Driver PostgreSQL - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - run: - .\ci\scripts\cpp_build.ps1 $pwd $pwd\build - - name: Build Driver SQLite - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_SQLITE: "1" - run: - .\ci\scripts\cpp_build.ps1 $pwd $pwd\build - - name: Test Driver Manager - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_MANAGER: "1" - run: - .\ci\scripts\cpp_test.ps1 $pwd $pwd\build - - name: Test Driver SQLite - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_SQLITE: "1" - run: - .\ci\scripts\cpp_test.ps1 $pwd $pwd\build - - # ------------------------------------------------------------ - # Go build - # 
------------------------------------------------------------ - go-conda: - name: "Go (CGO) (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - needs: - - drivers-build-conda - strategy: - matrix: - os: ["windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - env: - # Increment this to reset cache manually - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - --file ci/conda_env_cpp.txt - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: go/adbc/go.sum - - - uses: actions/download-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - path: ${{ github.workspace }}/build - - - name: Go Build - shell: pwsh - env: - CGO_ENABLED: "1" - run: | - $env:PATH="$($env:RUNNER_TOOL_CACHE)\go\1.18.6\x64\bin;" + $env:PATH - .\ci\scripts\go_build.ps1 $pwd $pwd\build - # TODO(apache/arrow#358): enable these tests on Windows - # - name: Go Test - # shell: pwsh - # env: - # CGO_ENABLED: "1" - # run: | - # $env:PATH="$($env:RUNNER_TOOL_CACHE)\go\1.18.6\x64\bin;" + $env:PATH - # .\ci\scripts\go_test.ps1 $pwd $pwd\build - - # ------------------------------------------------------------ - # Python build - # ------------------------------------------------------------ - python-conda: - name: "Python ${{ matrix.python }} (Conda/${{ matrix.os }})" - runs-on: ${{ matrix.os }} - needs: - - drivers-build-conda - strategy: - matrix: - os: ["windows-latest"] - python: 
["3.9", "3.11"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Get Date - id: get-date - shell: bash - run: | - echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT - - name: Cache Conda - uses: actions/cache@v3 - env: - # Increment this to reset cache manually - CACHE_NUMBER: 0 - with: - path: ~/conda_pkgs_dir - key: conda-${{ runner.os }}-${{ steps.get-date.outputs.today }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('ci/**') }} - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge \ - python=${{ matrix.python }} \ - --file ci/conda_env_cpp.txt \ - --file ci/conda_env_python.txt - - - uses: actions/download-artifact@v3 - with: - name: driver-manager-${{ matrix.os }} - path: ${{ github.workspace }}/build - - - name: Build Python Driver Manager - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_MANAGER: "1" - run: - .\ci\scripts\python_build.ps1 $pwd $pwd\build - - name: Build Python Driver PostgreSQL - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - run: - .\ci\scripts\python_build.ps1 $pwd $pwd\build - - name: Build Python Driver SQLite - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_SQLITE: "1" - run: - .\ci\scripts\python_build.ps1 $pwd $pwd\build - - name: Test Python Driver Manager - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_MANAGER: "1" - run: - .\ci\scripts\python_test.ps1 $pwd $pwd\build - - name: Test Python Driver PostgreSQL - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_POSTGRESQL: "1" - run: - .\ci\scripts\python_test.ps1 $pwd $pwd\build - - name: Test Python Driver SQLite - shell: pwsh - env: - BUILD_ALL: "0" - BUILD_DRIVER_SQLITE: "1" - run: - .\ci\scripts\python_test.ps1 $pwd $pwd\build diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/nightly-verify.yml 
b/3rd_party/apache-arrow-adbc/.github/workflows/nightly-verify.yml deleted file mode 100644 index 9d9f1a2..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/nightly-verify.yml +++ /dev/null @@ -1,157 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Verification (Nightly) - -on: - schedule: - - cron: "0 0 * * *" - workflow_dispatch: {} - -permissions: - contents: read - -jobs: - source: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Prepare version - shell: bash - run: | - VERSION=$(grep 'set(ADBC_VERSION' c/cmake_modules/AdbcVersion.cmake | \ - grep -E -o '[0-9]+\.[0-9]+\.[0-9]+') - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - - name: Create archive - shell: bash - run: | - git config --global user.name 'github-actions[bot]' - git config --global user.email 'github-actions[bot]@users.noreply.github.com' - git tag -f \ - -a apache-arrow-adbc-${VERSION}-rc0 \ - -m "ADBC Libraries ${VERSION} RC 0" - ci/scripts/source_build.sh \ - apache-arrow-adbc-${VERSION} \ - apache-arrow-adbc-${VERSION}-rc0 - - - name: Create fake GPG key - shell: bash - run: | - gpg \ - --quick-gen-key \ - --batch \ - --passphrase '' \ - user@localhost - - gpg \ - --list-sigs >> KEYS - - gpg \ - --armor \ - --export >> KEYS - - - name: Create sum/signature - shell: bash - run: | - gpg \ - --armor \ - --detach-sign \ - --output apache-arrow-adbc-${VERSION}.tar.gz.asc \ - apache-arrow-adbc-${VERSION}.tar.gz - - shasum --algorithm 512 \ - apache-arrow-adbc-${VERSION}.tar.gz > apache-arrow-adbc-${VERSION}.tar.gz.sha512 - - - uses: actions/upload-artifact@v3 - with: - name: source - retention-days: 7 - path: | - KEYS - apache-arrow-adbc-${{ env.VERSION }}.tar.gz - apache-arrow-adbc-${{ env.VERSION }}.tar.gz.asc - apache-arrow-adbc-${{ env.VERSION }}.tar.gz.sha512 - - source-conda: - name: "Verify Source (Conda)/${{ matrix.os }}" - runs-on: ${{ matrix.os }} - needs: - - source - strategy: - fail-fast: false - matrix: - os: ["macos-latest", "ubuntu-latest", "windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - path: arrow-adbc - persist-credentials: false - - - name: Prepare version - shell: bash - run: | - VERSION=$(grep 'set(ADBC_VERSION' 
arrow-adbc/c/cmake_modules/AdbcVersion.cmake | \ - grep -E -o '[0-9]+\.[0-9]+\.[0-9]+') - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - - uses: actions/download-artifact@v3 - with: - name: source - path: ${{ github.workspace }}/apache-arrow-adbc-${{ env.VERSION }}-rc0/ - - - name: Setup directory structure - shell: bash - run: | - mv apache-arrow-adbc-${{ env.VERSION }}-rc0/KEYS . - - - uses: conda-incubator/setup-miniconda@v2 - # The Unix script will set up conda itself - if: matrix.os == 'windows-latest' - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - - - name: Verify - if: matrix.os != 'windows-latest' - env: - REPOSITORY: ${{ github.repository }} - TEST_DEFAULT: "0" - TEST_SOURCE: "1" - USE_CONDA: "1" - VERBOSE: "1" - VERIFICATION_MOCK_DIST_DIR: ${{ github.workspace }} - run: | - ./arrow-adbc/dev/release/verify-release-candidate.sh $VERSION 0 - - - name: Verify - if: matrix.os == 'windows-latest' - shell: pwsh - env: - REPOSITORY: ${{ github.repository }} - TEST_DEFAULT: "0" - TEST_SOURCE: "1" - USE_CONDA: "1" - VERBOSE: "1" - VERIFICATION_MOCK_DIST_DIR: ${{ github.workspace }}\apache-arrow-adbc-${{ env.VERSION }}-rc0 - run: | - .\arrow-adbc\dev\release\verify-release-candidate.ps1 $env:VERSION 0 diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/nightly-website.yml b/3rd_party/apache-arrow-adbc/.github/workflows/nightly-website.yml deleted file mode 100644 index 38083c8..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/nightly-website.yml +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Automated - Website - -on: - push: - branches: - - main - tags: - - 'apache-arrow-adbc-*' - - '!apache-arrow-adbc-*-rc*' - workflow_dispatch: {} - -# Ensure concurrent builds don't stomp on each other -concurrency: - group: ${{ github.repository }}-${{ github.workflow }} - cancel-in-progress: false - -jobs: - build: - name: "Build Website" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Build - shell: bash - run: | - docker-compose run docs - - name: Archive docs - uses: actions/upload-artifact@v3 - with: - name: docs - retention-days: 2 - path: | - docs/build/html - - publish: - name: "Publish Website" - runs-on: ubuntu-latest - needs: [build] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - path: site - # NOTE: needed to push at the end - persist-credentials: true - ref: asf-site - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - path: scripts - persist-credentials: false - - name: Download docs - uses: actions/download-artifact@v3 - with: - name: docs - path: temp - - name: Build - shell: bash - run: | - pip install sphobjinv - ./scripts/ci/scripts/website_build.sh "$(pwd)/scripts" "$(pwd)/site" "$(pwd)/temp" - - name: Push changes to asf-site branch - run: | - cd site - git config --global user.name 'github-actions[bot]' - git config --global user.email 'github-actions[bot]@users.noreply.github.com' - git commit -m "publish documentation" --allow-empty - git push origin asf-site:asf-site diff --git 
a/3rd_party/apache-arrow-adbc/.github/workflows/packaging.yml b/3rd_party/apache-arrow-adbc/.github/workflows/packaging.yml deleted file mode 100644 index bc5397a..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/packaging.yml +++ /dev/null @@ -1,978 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Packaging - -on: - pull_request: - branches: - - main - paths: - - "adbc.h" - - "c/**" - - "ci/**" - - "glib/**" - - "python/**" - - "ruby/**" - - ".github/workflows/packaging.yml" - - push: - # Automatically build on RC tags - branches-ignore: - - '**' - tags: - - 'apache-arrow-adbc-*-rc*' - schedule: - - cron: "0 0 * * *" - workflow_dispatch: - inputs: - upload_artifacts: - description: "Upload artifacts to Gemfury" - required: true - type: boolean - default: false - -concurrency: - group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} - cancel-in-progress: true - -jobs: - source: - name: Source - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - name: Update tags - shell: bash - run: | - git fetch --tags --force origin - - - name: Prepare version - shell: bash - run: | - if [ "${GITHUB_REF_TYPE}" = "tag" ]; then - VERSION=${GITHUB_REF_NAME#apache-arrow-adbc-} - VERSION=${VERSION%-rc*} - else - VERSION=$(grep 'set(ADBC_VERSION' c/cmake_modules/AdbcVersion.cmake | \ - grep -E -o '[0-9]+\.[0-9]+\.[0-9]+') - description=$(git describe \ - --always \ - --dirty \ - --long \ - --match "apache-arrow-adbc-[0-9]*.*" \ - --tags) - case "${description}" in - # apache-arrow-adbc-0.1.0-10-1234567-dirty - apache-arrow-adbc-*) - # 10-1234567-dirty - distance="${description#apache-arrow-adbc-*.*.*-}" - # 10-1234567 - distance="${distance%-dirty}" - # 10 - distance="${distance%-*}" - ;; - *) - distance=$(git log --format=oneline | wc -l) - ;; - esac - VERSION="${VERSION}.dev${distance}" - fi - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - - name: Create archive - shell: bash - run: | - ci/scripts/source_build.sh \ - apache-arrow-adbc-${VERSION} \ - $(git log -n 1 --format=%h) - - - uses: actions/upload-artifact@v3 - with: - name: source - retention-days: 7 - path: | - apache-arrow-adbc-${{ env.VERSION }}.tar.gz - - docs: - name: "Documentation" - runs-on: ubuntu-latest - needs: - - source - steps: - - uses: 
actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Build and test - shell: bash - run: | - pushd adbc - docker-compose run \ - -e SETUPTOOLS_SCM_PRETEND_VERSION=${VERSION} \ - docs - popd - - - name: Compress docs - shell: bash - run: | - pushd adbc - tar --transform "s|docs/build/html|adbc-docs|" -czf ../docs.tgz docs/build/html - popd - - - name: Archive docs - uses: actions/upload-artifact@v3 - with: - name: docs - retention-days: 2 - path: | - docs.tgz - - java: - name: "Java 1.8" - runs-on: ubuntu-latest - needs: - - source - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Build and test - shell: bash - run: | - pushd adbc/ - docker-compose run java-dist - popd - cp -a adbc/dist/ ./ - - - name: Archive JARs - uses: actions/upload-artifact@v3 - with: - name: java - retention-days: 7 - path: | - dist/*.jar - dist/*.pom - - linux: - name: Linux ${{ matrix.target }} - runs-on: ubuntu-latest - needs: - - source - strategy: - fail-fast: false - matrix: - target: - - 
almalinux-8 - - almalinux-9 - - debian-bookworm - - debian-bullseye - - ubuntu-jammy - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - uses: actions/checkout@v3 - with: - repository: apache/arrow - path: arrow - - - name: Set environment variables - run: | - echo "ARROW_SOURCE=$(pwd)/arrow" >> $GITHUB_ENV - case ${{ matrix.target }} in - almalinux-*) - echo "TASK_NAMESPACE=yum" >> $GITHUB_ENV - echo "YUM_TARGETS=${{ matrix.target }}" >> $GITHUB_ENV - ;; - debian-*|ubuntu-*) - echo "TASK_NAMESPACE=apt" >> $GITHUB_ENV - echo "APT_TARGETS=${{ matrix.target }}" >> $GITHUB_ENV - ;; - esac - distribution=$(echo ${{ matrix.target }} | cut -d- -f1) - echo "DISTRIBUTION=${distribution}" >> $GITHUB_ENV - - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - - name: Extract source archive - run: | - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - mv apache-arrow-adbc-${VERSION}.tar.gz adbc/ci/linux-packages/ - - - name: Set up Ruby - uses: ruby/setup-ruby@v1 - with: - ruby-version: ruby - - - name: Cache ccache - uses: actions/cache@v3 - with: - path: adbc/ci/linux-packages/${{ env.TASK_NAMESPACE }}/build/${{ matrix.target }}/ccache - key: linux-${{ env.TASK_NAMESPACE }}-ccache-${{ matrix.target }}-{{ "${{ hashFiles('adbc.h', 'c/**', 'glib/**') }}" }} - restore-keys: linux-${{ env.TASK_NAMESPACE }}-ccache-${{ matrix.target }}- - - - name: Login to GitHub Container registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ github.token }} - - - name: Build - run: | - pushd adbc/ci/linux-packages - if [ "${GITHUB_REF_TYPE}" != "tag" ]; then - rake version:update - fi - rake docker:pull || : - rake --trace ${TASK_NAMESPACE}:build BUILD_DIR=build - popd - - - name: Prepare artifacts - run: | - cp -a \ - adbc/ci/linux-packages/${{ 
env.TASK_NAMESPACE }}/repositories/${DISTRIBUTION} \ - ./ - tar czf ${{ matrix.target }}.tar.gz ${DISTRIBUTION} - - - name: Upload artifacts - uses: actions/upload-artifact@v3 - with: - name: ${{ matrix.target }} - retention-days: 7 - path: | - ${{ matrix.target }}.tar.gz - - - name: Push Docker image - run: | - pushd adbc/ci/linux-packages - rake docker:push || : - popd - - - name: Set up test - run: | - sudo apt update - sudo apt install -y \ - apt-utils \ - createrepo-c \ - devscripts \ - gpg \ - rpm - gem install apt-dists-merge - (echo "Key-Type: RSA"; \ - echo "Key-Length: 4096"; \ - echo "Name-Real: Test"; \ - echo "Name-Email: test@example.com"; \ - echo "%no-protection") | \ - gpg --full-generate-key --batch - GPG_KEY_ID=$(gpg --list-keys --with-colon test@example.com | grep fpr | cut -d: -f10) - echo "GPG_KEY_ID=${GPG_KEY_ID}" >> ${GITHUB_ENV} - gpg --export --armor test@example.com > adbc/ci/linux-packages/KEYS - - - name: Test - run: | - pushd adbc/ci/linux-packages - rake --trace ${TASK_NAMESPACE}:test - popd - - python-conda-linux: - name: "Python ${{ matrix.arch }} Conda" - runs-on: ubuntu-latest - # No need for Conda packages during release - # if: "!startsWith(github.ref, 'refs/tags/')" - # TODO(apache/arrow-adbc#468): disable for now - if: false - needs: - - source - strategy: - matrix: - # TODO: "linux_aarch64_" - arch: ["linux_64_"] - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ github.event.inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Build Conda package - shell: bash - env: - 
ARCH_CONDA_FORGE: ${{ matrix.arch }} - run: | - pushd adbc - docker-compose run \ - -e HOST_USER_ID=$(id -u) \ - python-conda - popd - - - name: Archive Conda packages - uses: actions/upload-artifact@v3 - with: - name: python-${{ matrix.arch }}-conda - retention-days: 7 - path: | - adbc/build/conda/*/*.tar.bz2 - - - name: Test Conda packages - if: matrix.arch == 'linux_64_' - shell: bash - env: - ARCH_CONDA_FORGE: ${{ matrix.arch }} - run: | - pushd adbc - docker-compose run \ - python-conda-test - popd - - python-conda-macos: - name: "Python ${{ matrix.arch }} Conda" - runs-on: macos-latest - # No need for Conda packages during release - # TODO(apache/arrow-adbc#468): re-enable - if: false - needs: - - source - strategy: - matrix: - # TODO: "osx_arm64_" - arch: ["osx_64_"] - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ github.event.inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge boa conda-smithy conda-verify - conda config --remove channels defaults - conda config --add channels conda-forge - - - name: Build Conda package - shell: bash -l {0} - env: - ARCH_CONDA_FORGE: ${{ matrix.arch }} - run: | - ./adbc/ci/scripts/python_conda_build.sh $(pwd)/adbc ${ARCH_CONDA_FORGE}.yaml $(pwd)/adbc/build - - - name: Archive Conda packages - uses: actions/upload-artifact@v3 - 
with: - name: python-${{ matrix.arch }}-conda - retention-days: 7 - path: | - adbc/build/conda/*/*.tar.bz2 - - - name: Test Conda packages - shell: bash -l {0} - if: matrix.arch == 'osx_64_' - env: - ARCH_CONDA_FORGE: ${{ matrix.arch }} - run: | - ./adbc/ci/scripts/python_conda_test.sh $(pwd)/adbc $(pwd)/adbc/build - - python-manylinux: - name: "Python ${{ matrix.arch }} manylinux${{ matrix.manylinux_version }}" - runs-on: ubuntu-latest - needs: - - source - strategy: - matrix: - arch: ["amd64", "arm64v8"] - manylinux_version: ["2014"] - is_pr: - - ${{ startsWith(github.ref, 'refs/pull/') }} - exclude: - # Don't run arm64v8 build on PRs since the build is excessively slow - - arch: arm64v8 - is_pr: true - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ github.event.inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Build wheel - shell: bash - env: - ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} - run: | - pushd adbc - docker-compose run \ - -e SETUPTOOLS_SCM_PRETEND_VERSION=${VERSION} \ - python-wheel-manylinux - popd - - - name: Archive wheels - uses: actions/upload-artifact@v3 - with: - name: python-${{ matrix.arch }}-manylinux${{ matrix.manylinux_version }} - retention-days: 7 - path: | - adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl - adbc/python/adbc_driver_manager/repaired_wheels/*.whl - adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl - adbc/python/adbc_driver_sqlite/repaired_wheels/*.whl - 
adbc/python/adbc_driver_snowflake/repaired_wheels/*.whl - - - name: Test wheel - shell: bash - env: - ARCH: ${{ matrix.arch }} - MANYLINUX: ${{ matrix.manylinux_version }} - run: | - pushd adbc - env PYTHON=3.9 docker-compose run python-wheel-manylinux-test - env PYTHON=3.10 docker-compose run python-wheel-manylinux-test - env PYTHON=3.11 docker-compose run python-wheel-manylinux-test - - python-macos: - name: "Python ${{ matrix.arch }} macOS" - runs-on: macos-latest - needs: - - source - strategy: - matrix: - arch: ["amd64", "arm64v8"] - env: - MACOSX_DEPLOYMENT_TARGET: "10.15" - PYTHON: "/Library/Frameworks/Python.framework/Versions/3.10/bin/python3.10" - # Where to install vcpkg - VCPKG_ROOT: "${{ github.workspace }}/vcpkg" - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - shell: bash - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - echo "SETUPTOOLS_SCM_PRETEND_VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ github.event.inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Install Homebrew dependencies - shell: bash - run: brew install autoconf bash pkg-config ninja - - - name: Retrieve VCPKG version from .env - shell: bash - run: | - pushd adbc - vcpkg_version=$(cat ".env" | grep "VCPKG" | cut -d "=" -f2 | tr -d '"') - echo "VCPKG_VERSION=$vcpkg_version" | tee -a $GITHUB_ENV - popd - - - name: Install vcpkg - shell: bash - run: | - pushd adbc - ci/scripts/install_vcpkg.sh $VCPKG_ROOT $VCPKG_VERSION - popd - - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: adbc/go/adbc/go.sum - - - name: 
Install Python - shell: bash - run: | - pushd adbc - sudo ci/scripts/install_python.sh macos 3.9 - sudo ci/scripts/install_python.sh macos 3.10 - sudo ci/scripts/install_python.sh macos 3.11 - popd - - - name: Build wheel - shell: bash - env: - ARCH: ${{ matrix.arch }} - run: | - pushd adbc - $PYTHON -m venv build-env - source build-env/bin/activate - ./ci/scripts/python_wheel_unix_build.sh $ARCH $(pwd) $(pwd)/build - popd - - - name: Archive wheels - uses: actions/upload-artifact@v3 - with: - name: python-${{ matrix.arch }}-macos - retention-days: 7 - path: | - adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl - adbc/python/adbc_driver_manager/repaired_wheels/*.whl - adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl - adbc/python/adbc_driver_sqlite/repaired_wheels/*.whl - adbc/python/adbc_driver_snowflake/repaired_wheels/*.whl - - - name: Test wheel - if: matrix.arch == 'amd64' - shell: bash - run: | - pushd adbc - - /Library/Frameworks/Python.framework/Versions/3.9/bin/python3.9 -m venv test-env-39 - source test-env-39/bin/activate - export PYTHON_VERSION=3.9 - ./ci/scripts/python_wheel_unix_test.sh $(pwd) - deactivate - - /Library/Frameworks/Python.framework/Versions/3.10/bin/python3.10 -m venv test-env-310 - source test-env-310/bin/activate - export PYTHON_VERSION=3.10 - ./ci/scripts/python_wheel_unix_test.sh $(pwd) - deactivate - - /Library/Frameworks/Python.framework/Versions/3.11/bin/python3.11 -m venv test-env-311 - source test-env-311/bin/activate - export PYTHON_VERSION=3.11 - ./ci/scripts/python_wheel_unix_test.sh $(pwd) - deactivate - - popd - - python-windows: - name: "Python ${{ matrix.python_version }} Windows" - runs-on: windows-latest - needs: - - source - strategy: - matrix: - python_version: ["3.9", "3.10", "3.11"] - env: - PYTHON_VERSION: "${{ matrix.python_version }}" - # Where to install vcpkg - VCPKG_ROOT: "${{ github.workspace }}\\vcpkg" - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract 
source archive - shell: bash - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: pwsh - run: | - echo "upload_artifacts: ${{ inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Install Chocolatey Dependencies - shell: pwsh - run: | - choco install --no-progress -y cmake --installargs 'ADD_CMAKE_TO_PATH=System' - choco install --no-progress -y visualcpp-build-tools - - - name: Retrieve VCPKG version from .env - shell: pwsh - run: | - pushd adbc - Select-String -Path .env -Pattern 'VCPKG="(.+)"' | % {"VCPKG_VERSION=$($_.matches.groups[1])"} >> $env:GITHUB_ENV - popd - - - name: Install vcpkg - shell: pwsh - run: | - echo $env:VCPKG_VERSION - git clone --shallow-since=2022-06-01 https://github.com/microsoft/vcpkg $env:VCPKG_ROOT - pushd $env:VCPKG_ROOT - .\bootstrap-vcpkg.bat -disableMetrics - popd - - - uses: actions/setup-go@v3 - with: - go-version: 1.18.6 - check-latest: true - cache: true - cache-dependency-path: adbc/go/adbc/go.sum - - - name: Install Python ${{ matrix.python_version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python_version }} - - - name: Build wheel - shell: cmd - run: | - where python.exe - CALL "C:\Program Files (x86)\Microsoft Visual Studio\2017\BuildTools\VC\Auxiliary\Build\vcvars64.bat" - pushd adbc - set SETUPTOOLS_SCM_PRETEND_VERSION=%VERSION% - .\ci\scripts\python_wheel_windows_build.bat %cd% %cd%\build - popd - - - name: Archive wheels - uses: actions/upload-artifact@v3 - with: - name: python${{ matrix.python_version }}-windows - retention-days: 7 - path: | - adbc/python/adbc_driver_flightsql/repaired_wheels/*.whl - adbc/python/adbc_driver_manager/repaired_wheels/*.whl - 
adbc/python/adbc_driver_postgresql/repaired_wheels/*.whl - adbc/python/adbc_driver_sqlite/repaired_wheels/*.whl - adbc/python/adbc_driver_snowflake/repaired_wheels/*.whl - - - name: Test wheel - shell: cmd - run: | - pushd adbc - where python.exe - python -m venv venv - CALL venv\Scripts\activate.bat - .\ci\scripts\python_wheel_windows_test.bat %cd% - popd - - python-sdist: - name: "Python sdist" - runs-on: ubuntu-latest - needs: - - source - steps: - - uses: actions/download-artifact@v3 - with: - name: source - - - name: Extract source archive - run: | - source_archive=$(echo apache-arrow-adbc-*.tar.gz) - VERSION=${source_archive#apache-arrow-adbc-} - VERSION=${VERSION%.tar.gz} - echo "VERSION=${VERSION}" >> $GITHUB_ENV - - tar xf apache-arrow-adbc-${VERSION}.tar.gz - mv apache-arrow-adbc-${VERSION} adbc - - - name: Show inputs - shell: bash - run: | - echo "upload_artifacts: ${{ github.event.inputs.upload_artifacts }}" - echo "schedule: ${{ github.event.schedule }}" - echo "ref: ${{ github.ref }}" - - - name: Build sdist - shell: bash - run: | - pushd adbc - docker-compose run \ - -e SETUPTOOLS_SCM_PRETEND_VERSION=${VERSION} \ - python-sdist - popd - - - name: Archive sdist - uses: actions/upload-artifact@v3 - with: - name: python${{ matrix.python_version }}-manylinux${{ matrix.manylinux_version }} - retention-days: 7 - path: | - adbc/python/adbc_driver_flightsql/dist/*.tar.gz - adbc/python/adbc_driver_manager/dist/*.tar.gz - adbc/python/adbc_driver_postgresql/dist/*.tar.gz - adbc/python/adbc_driver_sqlite/dist/*.tar.gz - adbc/python/adbc_driver_snowflake/dist/*.tar.gz - - - name: Test sdist - shell: bash - run: | - pushd adbc - docker-compose run python-sdist-test - popd - - release: - name: "Create release" - runs-on: ubuntu-latest - if: startsWith(github.ref, 'refs/tags/') - needs: - - docs - - source - - java - - linux - - python-manylinux - - python-macos - - python-windows - - python-sdist - steps: - - name: Get All Artifacts - uses: 
actions/download-artifact@v3 - with: - path: release-artifacts - - name: Release - shell: bash - run: | - RELEASE_TAG=${GITHUB_REF#refs/*/} - - # Deduplicate wheels built in different jobs with same tag - mkdir -p upload-staging - find ./release-artifacts/ \ - '(' \ - -name docs.tgz -or \ - -name '*.jar' -or \ - -name '*.pom' -or \ - -name '*.whl' -or \ - -name 'adbc_*.tar.gz' -or \ - -name 'almalinux-*.tar.gz' -or \ - -name 'apache-arrow-adbc-*.tar.gz' -or \ - -name 'debian-*.tar.gz' -or \ - -name 'ubuntu-*.tar.gz' \ - ')' \ - -exec mv '{}' upload-staging \; - - UPLOAD=$(find upload-staging -type f | sort | uniq) - - echo "Uploading files:" - echo ${UPLOAD} - - gh release create "${RELEASE_TAG}" \ - --repo ${{ github.repository }} \ - --prerelease \ - --title "ADBC Libraries ${RELEASE_TAG}" \ - ${UPLOAD} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - upload-anaconda: - name: "Upload packages to Anaconda.org" - runs-on: ubuntu-latest - if: github.ref == 'refs/heads/main' && (github.event.schedule || inputs.upload_artifacts) - needs: - - python-conda-linux - # TODO(apache/arrow-adbc#468): re-enable - # - python-conda-macos - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: true - - name: Get All Artifacts - uses: actions/download-artifact@v3 - with: - path: conda-packages - - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-only-tar-bz2: false - use-mamba: true - - name: Install Dependencies - shell: bash -l {0} - run: | - mamba install -c conda-forge anaconda-client - - name: Clean - shell: bash -l {0} - continue-on-error: true - run: | - # Clean all existing packages, OK if we fail - ./ci/scripts/python_conda_clean.sh - env: - ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }} - - name: Upload - shell: bash -l {0} - run: | - ls -laR conda-packages - # Upload fresh packages - ./ci/scripts/python_conda_upload.sh conda-packages/python-*-conda/*/*.tar.bz2 - 
env: - ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }} - - upload-gemfury: - name: "Upload packages to Gemfury" - runs-on: ubuntu-latest - if: github.ref == 'refs/heads/main' && (github.event.schedule || inputs.upload_artifacts) - needs: - - java - - python-manylinux - - python-macos - - python-windows - - python-sdist - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: true - - name: Get All Artifacts - uses: actions/download-artifact@v3 - with: - path: nightly-artifacts - - name: Upload - shell: bash - run: | - # Deduplicate wheels built in different jobs with same tag - mkdir -p upload-staging - find ./nightly-artifacts/ \ - '(' \ - -name '*.jar' -or \ - -name '*.pom' -or \ - -name '*.whl' -or \ - -name 'adbc_*.tar.gz' \ - ')' \ - -exec mv '{}' upload-staging \; - - # Java - ./ci/scripts/java_jar_upload.sh upload-staging/*.pom - # Python - ./ci/scripts/python_wheel_upload.sh upload-staging/adbc_*.tar.gz upload-staging/*.whl - env: - GEMFURY_PUSH_TOKEN: ${{ secrets.GEMFURY_PUSH_TOKEN }} - - clean-gemfury: - name: "Clean old releases" - runs-on: ubuntu-latest - if: "!startsWith(github.ref, 'refs/tags/')" - needs: - - upload-gemfury - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: true - - - name: Clean old releases - shell: bash - run: | - gem install --user-install gemfury - ruby ./ci/scripts/gemfury_clean.rb - env: - GEMFURY_API_TOKEN: ${{ secrets.GEMFURY_API_TOKEN }} diff --git a/3rd_party/apache-arrow-adbc/.github/workflows/verify.yml b/3rd_party/apache-arrow-adbc/.github/workflows/verify.yml deleted file mode 100644 index fbde892..0000000 --- a/3rd_party/apache-arrow-adbc/.github/workflows/verify.yml +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Verification - -on: - workflow_dispatch: - inputs: - version: - description: "Version to verify" - required: false - type: string - default: "" - rc: - description: "RC to verify" - required: false - type: string - default: "" - pull_request: - branches: - - main - paths: - - '.github/workflows/verify.yml' - - 'dev/release/verify-release-candidate.sh' - - 'dev/release/verify-release-candidate.ps1' - -permissions: - contents: read - -jobs: - binary-unix: - name: "Verify Binaries/${{ matrix.os }}" - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: ["macos-latest", "ubuntu-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - name: Verify - shell: bash - env: - REPOSITORY: ${{ github.repository }} - TEST_DEFAULT: "0" - TEST_BINARIES: "1" - USE_CONDA: "1" - VERBOSE: "1" - run: | - ./dev/release/verify-release-candidate.sh ${{ inputs.version }} ${{ inputs.rc }} - - source-conda: - name: "Verify Source (Conda)/${{ matrix.os }}" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: ["macos-latest", "ubuntu-latest", "windows-latest"] - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - uses: conda-incubator/setup-miniconda@v2 - # The Unix script will set up conda itself - if: matrix.os == 'windows-latest' - with: - miniforge-variant: Mambaforge - miniforge-version: latest - 
use-mamba: true - - name: Verify - if: matrix.os != 'windows-latest' - env: - REPOSITORY: ${{ github.repository }} - TEST_DEFAULT: "0" - TEST_SOURCE: "1" - USE_CONDA: "1" - VERBOSE: "1" - run: | - ./dev/release/verify-release-candidate.sh ${{ inputs.version }} ${{ inputs.rc }} - - name: Verify - if: matrix.os == 'windows-latest' - shell: pwsh - env: - REPOSITORY: ${{ github.repository }} - TEST_DEFAULT: "0" - TEST_SOURCE: "1" - USE_CONDA: "1" - VERBOSE: "1" - run: | - .\dev\release\verify-release-candidate.ps1 ${{ inputs.version }} ${{ inputs.rc }} diff --git a/3rd_party/apache-arrow-adbc/.gitignore b/3rd_party/apache-arrow-adbc/.gitignore deleted file mode 100644 index b1ede18..0000000 --- a/3rd_party/apache-arrow-adbc/.gitignore +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Release artifacts -apache-arrow-adbc-*.tar.gz -apache-arrow-adbc-*.tar.gz.asc -apache-arrow-adbc-*.tar.gz.sha256 -apache-arrow-adbc-*.tar.gz.sha512 -apache-rat-*.jar -dev/release/.env -filtered_rat.txt -packages/ -rat.txt - -# Compiled source -*.a -*.dll -*.o -*.py[ocd] -*.so -*.so.* -*.bundle -*.dylib -.build_cache_dir -dependency-reduced-pom.xml -MANIFEST -compile_commands.json -build.ninja -.clangd - -# Generated Visual Studio files -*.vcxproj -*.vcxproj.* -*.sln -*.iml - -# Linux perf sample data -perf.data -perf.data.old - -cpp/.idea/ -c/apidoc/html/ -c/apidoc/latex/ -c/apidoc/xml/ -docs/example.gz -docs/example1.dat -docs/example3.dat -python/.eggs/ -python/doc/ -# Egg metadata -*.egg-info - -.vscode -.idea/ -.pytest_cache/ -pkgs -docker_cache -.gdb_history -*.orig -.*.swp -.*.swo -CMakeUserPresets.json -build/ - -site/ - -# Python -dist/ -.hypothesis/ -repaired_wheels/ - -# R files -**/.Rproj.user -**/*.Rcheck/ -**/.Rhistory -.Rproj.user - -# macOS -cpp/Brewfile.lock.json -.DS_Store - -# docker volumes used for caching -.docker - -# generated native binaries created by java JNI build -java-dist/ -java-native-c/ -java-native-cpp/ -target/ - -*.log - -# Linux packages -/ci/linux-packages/*.tar.gz -/ci/linux-packages/KEYS -/ci/linux-packages/apt/build.sh -/ci/linux-packages/apt/build/ -/ci/linux-packages/apt/env.sh -/ci/linux-packages/apt/merged/ -/ci/linux-packages/apt/repositories/ -/ci/linux-packages/apt/tmp/ -/ci/linux-packages/yum/build.sh -/ci/linux-packages/yum/build/ -/ci/linux-packages/yum/env.sh -/ci/linux-packages/yum/merged/ -/ci/linux-packages/yum/repositories/ -/ci/linux-packages/yum/tmp/ diff --git a/3rd_party/apache-arrow-adbc/.isort.cfg b/3rd_party/apache-arrow-adbc/.isort.cfg deleted file mode 100644 index 1d6b380..0000000 --- a/3rd_party/apache-arrow-adbc/.isort.cfg +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[settings] -known_first_party = adbc_driver_flightsql, adbc_driver_manager, adbc_driver_postgresql, adbc_driver_sqlite, adbc_driver_snowflake -profile = black diff --git a/3rd_party/apache-arrow-adbc/.pre-commit-config.yaml b/3rd_party/apache-arrow-adbc/.pre-commit-config.yaml deleted file mode 100644 index aacf930..0000000 --- a/3rd_party/apache-arrow-adbc/.pre-commit-config.yaml +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# To use this, install the python package `pre-commit` and -# run once `pre-commit install`. 
This will setup a git pre-commit-hook -# that is executed on each commit and will report the linting problems. -# To run all hooks on all files use `pre-commit run -a` - -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 - hooks: - - id: check-xml - - id: check-yaml - exclude: ci/conda/meta.yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/pocc/pre-commit-hooks - rev: v1.3.5 - hooks: - - id: clang-format - args: [-i] - types_or: [c, c++] - - repo: https://github.com/cheshirekow/cmake-format-precommit - rev: v0.6.13 - hooks: - - id: cmake-format - args: [--in-place] - - repo: https://github.com/cpplint/cpplint - rev: 1.6.0 - hooks: - - id: cpplint - args: - # From Arrow's config - - "--filter=-whitespace/comments,-readability/casting,-readability/todo,-readability/alt_tokens,-build/header_guard,-build/c++11,-build/include_order,-build/include_subdir" - - "--linelength=90" - - "--verbose=2" - - repo: https://github.com/golangci/golangci-lint - rev: v1.49.0 - hooks: - - id: golangci-lint - entry: bash -c 'cd go/adbc && golangci-lint run --fix --timeout 5m' - types_or: [go] - - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.3.0 - hooks: - - id: pretty-format-golang - - id: pretty-format-java - args: [--autofix] - types_or: [java] - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - types_or: [python] - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - types_or: [python] - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - types_or: [python] - - repo: local - hooks: - - id: apache-rat - name: Check for unapproved licenses - language: script - pass_filenames: false - entry: "./ci/scripts/run_rat_local.sh" - -exclude: "^c/vendor/.*" diff --git a/3rd_party/apache-arrow-adbc/CHANGELOG.md b/3rd_party/apache-arrow-adbc/CHANGELOG.md index 186f710..e67a721 100644 --- 
a/3rd_party/apache-arrow-adbc/CHANGELOG.md +++ b/3rd_party/apache-arrow-adbc/CHANGELOG.md @@ -238,24 +238,86 @@ - **go/adbc/driver/flightsql**: filter by schema in getObjectsTables (#726) -## ADBC Libraries 0.5.1 (2023-06-22) +## ADBC Libraries 0.6.0 (2023-08-23) ### Feat -- **r**: Add FlightSQL driver wrapper (#835) -- **python/adbc_driver_flightsql**: add cookie middleware option to DatabaseOptions (#830) -- **go/adbc/driver/flightsql**: Add cookie middleware option (#825) -- **c/driver/postgresql**: Implement GetObjects with table_types argument (#799) -- **c/driver/postgresql**: Binary ingest (#808) -- **c/driver/postgresql**: Support float type (#807) +- **python/adbc_driver_manager**: add fetch_record_batch (#989) +- **c/driver**: Date32 support (#948) +- **c/driver/postgresql**: Interval support (#908) +- **go/adbc/driver/flightsql**: add context to gRPC errors (#921) +- **c/driver/sqlite**: SQLite timestamp write support (#897) +- **c/driver/postgresql**: Handle NUMERIC type by converting to string (#883) +- **python/adbc_driver_postgresql**: add PostgreSQL options enum (#886) +- **c/driver/postgresql**: TimestampTz write (#868) +- **c/driver/postgresql**: Implement streaming/chunked output (#870) +- **c/driver/postgresql**: Timestamp write support (#861) +- **c/driver_manager,go/adbc,python**: trim down error messages (#866) +- **c/driver/postgresql**: Int8 support (#858) +- **c/driver/postgresql**: Better type error messages (#860) ### Fix -- **go/adbc/driver/snowflake**: fix potential deadlock and error handling (#828) -- **csharp**: submodule not pulling correctly (#824) -- **go/adbc/driver/snowflake**: initialize Params, add DLL build (#820) -- **dev/release**: add missing duckdb dependency (#810) +- **go/adbc/driver/flightsql**: Have GetTableSchema check for table name match instead of the first schema it receives (#980) +- **r**: Ensure that info_codes are coerced to integer (#986) +- **go/adbc/sqldriver**: fix handling of decimal types (#970) +- 
**c/driver/postgresql**: Fix segfault associated with uninitialized copy_reader_ (#964) +- **c/driver/sqlite**: add table types by default from arrow types (#955) +- **csharp**: include GetTableTypes and GetTableSchema call for .NET 4.7.2 (#950) +- **csharp**: include GetInfo and GetObjects call for .NET 4.7.2 (#945) +- **c/driver/sqlite**: Wrap bulk ingests in a single begin/commit txn (#910) +- **csharp**: fix C api to work under .NET 4.7.2 (#931) +- **python/adbc_driver_snowflake**: allow connecting without URI (#923) +- **go/adbc/pkg**: export Adbc* symbols on Windows (#916) +- **go/adbc/driver/snowflake**: handle non-arrow result sets (#909) +- **c/driver/sqlite**: fix escaping of sqlite TABLE CREATE columns (#906) +- **go/adbc/pkg**: follow CGO rules properly (#902) +- **go/adbc/driver/snowflake**: Fix integration tests by fixing timestamp handling (#889) +- **go/adbc/driver/snowflake**: fix failing integration tests (#888) +- **c/validation**: Fix ASAN-detected leak (#879) +- **go/adbc**: fix crash on map type (#854) +- **go/adbc/driver/snowflake**: handle result sets without Arrow data (#864) + +### Perf + +- **go/adbc/driver/snowflake**: Implement concurrency limit (#974) + +### Refactor + +- **c**: Vendor portable-snippets for overflow checks (#951) +- **c/driver/postgresql**: Use ArrowArrayViewGetIntervalUnsafe from nanoarrow (#957) +- **c/driver/postgresql**: Simplify current database querying (#880) + +## ADBC Libraries 0.7.0 (2023-09-20) + +### Feat + +- **r**: Add quoting/escaping generics (#1083) +- **r**: Implement temporary table option in R driver manager (#1084) +- **python/adbc_driver_flightsql**: add adbc.flight.sql.client_option.authority to DatabaseOptions (#1069) +- **go/adbc/driver/snowflake**: improve XDBC support (#1034) +- **go/adbc/driver/flightsql**: add adbc.flight.sql.client_option.authority (#1060) +- **c/driver**: support ingesting into temporary tables (#1057) +- **c/driver**: support target catalog/schema for ingestion (#1056) 
+- **go**: add basic driver logging (#1048) +- **c/driver/postgresql**: Support ingesting LARGE_STRING types (#1050) +- **c/driver/postgresql**: Duration support (#907) +- ADBC API revision 1.1.0 (#971) + +### Fix + +- **java/driver/flight-sql**: fix leak in InfoMetadataBuilder (#1070) +- **c/driver/postgresql**: Fix overflow in statement.cc (#1072) +- **r/adbcdrivermanager**: Ensure nullable arguments `adbc_connection_get_objects()` can be specified (#1032) +- **c/driver/sqlite**: Escape table name in sqlite GetTableSchema (#1036) +- **c/driver**: return NOT_FOUND for GetTableSchema (#1026) +- **c/driver_manager**: fix crash when error is null (#1029) +- **c/driver/postgresql**: suppress console spam (#1027) +- **c/driver/sqlite**: escape table names in INSERT, too (#1003) +- **go/adbc/driver/snowflake**: properly handle time fields (#1021) +- **r/adbcdrivermanager**: Make `adbc_xptr_is_valid()` return `FALSE` for external pointer to NULL (#1007) +- **go/adbc**: don't include NUL in error messages (#998) ### Refactor -- **csharp**: cleanup load of imported drivers (#818) +- **c/driver/postgresql**: hardcode overflow checks (#1051) diff --git a/3rd_party/apache-arrow-adbc/CONTRIBUTING.md b/3rd_party/apache-arrow-adbc/CONTRIBUTING.md index bcd2315..d8b2fe3 100644 --- a/3rd_party/apache-arrow-adbc/CONTRIBUTING.md +++ b/3rd_party/apache-arrow-adbc/CONTRIBUTING.md @@ -26,6 +26,34 @@ https://github.com/apache/arrow-adbc/issues ## Building +### Environment Setup + +Some dependencies are required to build and test the various ADBC packages. + +For C/C++, you will most likely want a [Conda][conda] installation, +with [Mambaforge][mambaforge] being the most convenient distribution. +If you have Mambaforge installed, you can set up a development +environment as follows: + +```shell +$ mamba create -n adbc --file ci/conda_env_cpp.txt +$ mamba activate adbc +``` + +(For other Conda distributions, you will likely need `create ... -c +conda-forge --file ...`). 
+ +There are additional environment definitions for development on Python +and GLib/Ruby packages. + +Conda is not required; you may also use a package manager like Nix or +Homebrew, the system package manager, etc. so long as you configure +CMake or other build tool appropriately. However, we primarily +develop and support Conda users. + +[conda]: https://docs.conda.io/en/latest/ +[mambaforge]: https://mamba.readthedocs.io/en/latest/installation.html + ### C/C++ All libraries here contained within one CMake project. To build any @@ -35,34 +63,44 @@ replacing `_COMPONENT` with the name of the library/libraries. _Note:_ unlike the Arrow C++ build system, the CMake projects will **not** automatically download and build dependencies—you should configure CMake appropriately to find dependencies in system or -package manager locations. +package manager locations, if you are not using something like Conda. -For example, the driver manager and postgres driver may be built -together as follows: +You can use CMake presets to build and test: ```shell $ mkdir build $ cd build +$ cmake ../c --preset debug +# ctest reads presets from PWD +$ cd ../c +$ ctest --preset debug --test-dir ../build +``` + +You can also manually configure CMake. For example, the driver manager and +postgres driver may be built together as follows: + +```shell +$ mkdir build +$ cd build +$ export CMAKE_EXPORT_COMPILE_COMMANDS=ON $ cmake ../c -DADBC_DRIVER_POSTGRESQL=ON -DADBC_DRIVER_MANAGER=ON $ make -j ``` +[`export CMAKE_EXPORT_COMPILE_COMMANDS=ON`][cmake-compile-commands] is +not required, but is useful if you are using Visual Studio Code, +Emacs, or another editor that integrates with a C/C++ language server. + For information on what each library can do and their dependencies, see their individual READMEs. To specify where dependencies are to the build, use standard CMake -options such as [`CMAKE_PREFIX_PATH`][cmake-prefix-path]. 
A list of -dependencies for Conda (conda-forge) is included, and can be used as -follows: - -```shell -$ conda create -n adbc -c conda-forge --file ci/conda_env_cpp.txt -$ conda activate adbc -``` +options such as [`CMAKE_PREFIX_PATH`][cmake-prefix-path]. -Some of Arrow's build options are supported (under a different prefix): +Some build options are supported: -- `ADBC_BUILD_SHARED`, `ADBC_BUILD_STATIC`: build the shared/static libraries. +- `ADBC_BUILD_SHARED`, `ADBC_BUILD_STATIC`: toggle building the + shared/static libraries. - `ADBC_BUILD_TESTS`: build the unit tests (requires googletest/gmock). - `ADBC_INSTALL_NAME_RPATH`: set `install_name` to `@rpath` on MacOS. Usually it is more convenient to explicitly set this to `OFF` for @@ -85,6 +123,7 @@ test-time dependencies. For instance, the PostgreSQL and Flight SQL drivers require servers to test against. See their individual READMEs for details. +[cmake-compile-commands]: https://cmake.org/cmake/help/latest/variable/CMAKE_EXPORT_COMPILE_COMMANDS.html [cmake-prefix-path]: https://cmake.org/cmake/help/latest/variable/CMAKE_PREFIX_PATH.html [gtest]: https://github.com/google/googletest/ @@ -98,6 +137,9 @@ used as follows: ```shell $ conda create -n adbc -c conda-forge --file ci/conda_env_docs.txt $ conda activate adbc +# Mermaid must be installed separately +# While "global", it will end up in your Conda environment +$ npm install -g @mermaid-js/mermaid-cli ``` To build the HTML documentation: diff --git a/3rd_party/apache-arrow-adbc/LICENSE.txt b/3rd_party/apache-arrow-adbc/LICENSE.txt index 316c69e..7eb6402 100644 --- a/3rd_party/apache-arrow-adbc/LICENSE.txt +++ b/3rd_party/apache-arrow-adbc/LICENSE.txt @@ -213,6 +213,22 @@ License: http://www.apache.org/licenses/LICENSE-2.0 -------------------------------------------------------------------------------- +The files in c/vendor/portable-snippets/ contain code from + +https://github.com/nemequ/portable-snippets + +and have the following copyright notice: + +Each 
source file contains a preamble explaining the license situation +for that file, which takes priority over this file. With the +exception of some code pulled in from other repositories (such as +µnit, an MIT-licensed project which is used for testing), the code is +public domain, released using the CC0 1.0 Universal dedication (*). + +(*) https://creativecommons.org/publicdomain/zero/1.0/legalcode + +-------------------------------------------------------------------------------- + The files python/*/*/_version.py and python/*/*/_static_version.py contain code from diff --git a/3rd_party/apache-arrow-adbc/adbc.h b/3rd_party/apache-arrow-adbc/adbc.h index 154e881..1ec2f05 100644 --- a/3rd_party/apache-arrow-adbc/adbc.h +++ b/3rd_party/apache-arrow-adbc/adbc.h @@ -35,7 +35,7 @@ /// but not concurrent access. Specific implementations may permit /// multiple threads. /// -/// \version 1.0.0 +/// \version 1.1.0 #pragma once @@ -248,7 +248,24 @@ typedef uint8_t AdbcStatusCode; /// May indicate a database-side error only. #define ADBC_STATUS_UNAUTHORIZED 14 +/// \brief Inform the driver/driver manager that we are using the extended +/// AdbcError struct from ADBC 1.1.0. +/// +/// See the AdbcError documentation for usage. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA INT32_MIN + /// \brief A detailed error message for an operation. +/// +/// The caller must zero-initialize this struct (clarified in ADBC 1.1.0). +/// +/// The structure was extended in ADBC 1.1.0. Drivers and clients using ADBC +/// 1.0.0 will not have the private_data or private_driver fields. Drivers +/// should read/write these fields if and only if vendor_code is equal to +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. Clients are required to initialize +/// this struct to avoid the possibility of uninitialized values confusing the +/// driver. struct ADBC_EXPORT AdbcError { /// \brief The error message. 
char* message; @@ -266,8 +283,112 @@ struct ADBC_EXPORT AdbcError { /// Unlike other structures, this is an embedded callback to make it /// easier for the driver manager and driver to cooperate. void (*release)(struct AdbcError* error); + + /// \brief Opaque implementation-defined state. + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. If present, this field is NULLPTR + /// iff the error is unintialized/freed. + /// + /// \since ADBC API revision 1.1.0 + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + /// + /// This field may not be used unless vendor_code is + /// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. + /// + /// \since ADBC API revision 1.1.0 + struct AdbcDriver* private_driver; }; +#ifdef __cplusplus +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + (AdbcError{nullptr, \ + ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, \ + {0, 0, 0, 0, 0}, \ + nullptr, \ + nullptr, \ + nullptr}) +#else +/// \brief A helper to initialize the full AdbcError structure. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_INIT \ + ((struct AdbcError){ \ + NULL, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA, {0, 0, 0, 0, 0}, NULL, NULL, NULL}) +#endif + +/// \brief The size of the AdbcError structure in ADBC 1.0.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is not +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_0_0_SIZE (offsetof(struct AdbcError, private_data)) +/// \brief The size of the AdbcError structure in ADBC 1.1.0. +/// +/// Drivers written for ADBC 1.1.0 and later should never touch more than this +/// portion of an AdbcDriver struct when vendor_code is +/// ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA. 
+/// +/// \since ADBC API revision 1.1.0 +#define ADBC_ERROR_1_1_0_SIZE (sizeof(struct AdbcError)) + +/// \brief Extra key-value metadata for an error. +/// +/// The fields here are owned by the driver and should not be freed. The +/// fields here are invalidated when the release callback in AdbcError is +/// called. +/// +/// \since ADBC API revision 1.1.0 +struct ADBC_EXPORT AdbcErrorDetail { + /// \brief The metadata key. + const char* key; + /// \brief The binary metadata value. + const uint8_t* value; + /// \brief The length of the metadata value. + size_t value_length; +}; + +/// \brief Get the number of metadata values available in an error. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +int AdbcErrorGetDetailCount(const struct AdbcError* error); + +/// \brief Get a metadata value in an error by index. +/// +/// If index is invalid, returns an AdbcErrorDetail initialized with NULL/0 +/// fields. +/// +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index); + +/// \brief Get an ADBC error from an ArrowArrayStream created by a driver. +/// +/// This allows retrieving error details and other metadata that would +/// normally be suppressed by the Arrow C Stream Interface. +/// +/// The caller MUST NOT release the error; it is managed by the release +/// callback in the stream itself. +/// +/// \param[in] stream The stream to query. +/// \param[out] status The ADBC status code, or ADBC_STATUS_OK if there is no +/// error. Not written to if the stream does not contain an ADBC error or +/// if the pointer is NULL. +/// \return NULL if not supported. +/// \since ADBC API revision 1.1.0 +ADBC_EXPORT +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + /// @} /// \defgroup adbc-constants Constants @@ -279,6 +400,14 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. 
#define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -288,6 +417,34 @@ struct ADBC_EXPORT AdbcError { /// For use as the value in SetOption calls. #define ADBC_OPTION_VALUE_DISABLED "false" +/// \brief Canonical option name for URIs. +/// +/// Should be used as the expected option name to specify a URI for +/// any ADBC driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_URI "uri" +/// \brief Canonical option name for usernames. +/// +/// Should be used as the expected option name to specify a username +/// to a driver for authentication. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_USERNAME "username" +/// \brief Canonical option name for passwords. +/// +/// Should be used as the expected option name to specify a password +/// for authentication to a driver. +/// +/// The type is char*. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_OPTION_PASSWORD "password" + /// \brief The database vendor/product name (e.g. the server name). /// (type: utf8). /// @@ -315,6 +472,15 @@ struct ADBC_EXPORT AdbcError { /// /// \see AdbcConnectionGetInfo #define ADBC_INFO_DRIVER_ARROW_VERSION 102 +/// \brief The driver ADBC API version (type: int64). +/// +/// The value should be one of the ADBC_VERSION constants. +/// +/// \since ADBC API revision 1.1.0 +/// \see AdbcConnectionGetInfo +/// \see ADBC_VERSION_1_0_0 +/// \see ADBC_VERSION_1_1_0 +#define ADBC_INFO_DRIVER_ADBC_VERSION 103 /// \brief Return metadata on catalogs, schemas, tables, and columns. 
/// @@ -337,18 +503,133 @@ struct ADBC_EXPORT AdbcError { /// \see AdbcConnectionGetObjects #define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL +/// \defgroup adbc-table-statistics ADBC Statistic Types +/// Standard statistic names for AdbcConnectionGetStatistics. +/// @{ + +/// \brief The dictionary-encoded name of the average byte width statistic. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY 0 +/// \brief The average byte width statistic. The average size in bytes of a +/// row in the column. Value type is float64. +/// +/// For example, this is roughly the average length of a string for a string +/// column. +#define ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the distinct value count statistic. +#define ADBC_STATISTIC_DISTINCT_COUNT_KEY 1 +/// \brief The distinct value count (NDV) statistic. The number of distinct +/// values in the column. Value type is int64 (when not approximate) or +/// float64 (when approximate). +#define ADBC_STATISTIC_DISTINCT_COUNT_NAME "adbc.statistic.distinct_count" +/// \brief The dictionary-encoded name of the max byte width statistic. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_KEY 2 +/// \brief The max byte width statistic. The maximum size in bytes of a row +/// in the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +/// +/// For example, this is the maximum length of a string for a string column. +#define ADBC_STATISTIC_MAX_BYTE_WIDTH_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the max value statistic. +#define ADBC_STATISTIC_MAX_VALUE_KEY 3 +/// \brief The max value statistic. Value type is column-dependent. +#define ADBC_STATISTIC_MAX_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the min value statistic. +#define ADBC_STATISTIC_MIN_VALUE_KEY 4 +/// \brief The min value statistic. Value type is column-dependent. 
+#define ADBC_STATISTIC_MIN_VALUE_NAME "adbc.statistic.byte_width" +/// \brief The dictionary-encoded name of the null count statistic. +#define ADBC_STATISTIC_NULL_COUNT_KEY 5 +/// \brief The null count statistic. The number of values that are null in +/// the column. Value type is int64 (when not approximate) or float64 +/// (when approximate). +#define ADBC_STATISTIC_NULL_COUNT_NAME "adbc.statistic.null_count" +/// \brief The dictionary-encoded name of the row count statistic. +#define ADBC_STATISTIC_ROW_COUNT_KEY 6 +/// \brief The row count statistic. The number of rows in the column or +/// table. Value type is int64 (when not approximate) or float64 (when +/// approximate). +#define ADBC_STATISTIC_ROW_COUNT_NAME "adbc.statistic.row_count" +/// @} + /// \brief The name of the canonical option for whether autocommit is /// enabled. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" /// \brief The name of the canonical option for whether the current /// connection should be restricted to being read-only. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" +/// \brief The name of the canonical option for the current catalog. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_CATALOG "adbc.connection.catalog" + +/// \brief The name of the canonical option for the current schema. +/// +/// The type is char*. +/// +/// \see AdbcConnectionGetOption +/// \see AdbcConnectionSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA "adbc.connection.db_schema" + +/// \brief The name of the canonical option for making query execution +/// nonblocking. 
+/// +/// When enabled, AdbcStatementExecutePartitions will return +/// partitions as soon as they are available, instead of returning +/// them all at the end. When there are no more to return, it will +/// return an empty set of partitions. AdbcStatementExecuteQuery and +/// AdbcStatementExecuteSchema are not affected. +/// +/// The default is ADBC_OPTION_VALUE_DISABLED. +/// +/// The type is char*. +/// +/// \see AdbcStatementSetOption +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_INCREMENTAL "adbc.statement.exec.incremental" + +/// \brief The name of the option for getting the progress of a query. +/// +/// The value is not necessarily in any particular range or have any +/// particular units. (For example, it might be a percentage, bytes of data, +/// rows of data, number of workers, etc.) The max value can be retrieved via +/// ADBC_STATEMENT_OPTION_MAX_PROGRESS. This represents the progress of +/// execution, not of consumption (i.e., it is independent of how much of the +/// result set has been read by the client via ArrowArrayStream.get_next().) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_PROGRESS "adbc.statement.exec.progress" + +/// \brief The name of the option for getting the maximum progress of a query. +/// +/// This is the value of ADBC_STATEMENT_OPTION_PROGRESS for a completed query. +/// If not supported, or if the value is nonpositive, then the maximum is not +/// known. (For instance, the query may be fully streaming and the driver +/// does not know when the result set will end.) +/// +/// The type is double. +/// +/// \see AdbcStatementGetOptionDouble +/// \since ADBC API revision 1.1.0 +#define ADBC_STATEMENT_OPTION_MAX_PROGRESS "adbc.statement.exec.max_progress" + /// \brief The name of the canonical option for setting the isolation /// level of a transaction. 
/// @@ -357,6 +638,8 @@ struct ADBC_EXPORT AdbcError { /// isolation level is not supported by a driver, it should return an /// appropriate error. /// +/// The type is char*. +/// /// \see AdbcConnectionSetOption #define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ "adbc.connection.transaction.isolation_level" @@ -449,8 +732,12 @@ struct ADBC_EXPORT AdbcError { /// exist. If the table exists but has a different schema, /// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be /// appended to the target table. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" /// \brief Whether to create (the default) or append. +/// +/// The type is char*. #define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" /// \brief Create the table and insert data; error if the table exists. #define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" @@ -458,6 +745,15 @@ struct ADBC_EXPORT AdbcError { /// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match /// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). #define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" +/// \brief Create the table and insert data; drop the original table +/// if it already exists. +/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_REPLACE "adbc.ingest.mode.replace" +/// \brief Insert data; create the table if it does not exist, or +/// error if the table exists, but the schema does not match the +/// schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). 
+/// \since ADBC API revision 1.1.0 +#define ADBC_INGEST_OPTION_MODE_CREATE_APPEND "adbc.ingest.mode.create_append" /// @} @@ -624,7 +920,7 @@ struct ADBC_EXPORT AdbcDriver { AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); - AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, const uint32_t*, size_t, struct ArrowArrayStream*, struct AdbcError*); AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, const char*, const char*, const char**, @@ -667,8 +963,108 @@ struct ADBC_EXPORT AdbcDriver { struct AdbcError*); AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, size_t, struct AdbcError*); + + /// \defgroup adbc-1.1.0 ADBC API Revision 1.1.0 + /// + /// Functions added in ADBC 1.1.0. For backwards compatibility, + /// these members must not be accessed unless the version passed to + /// the AdbcDriverInitFunc is greater than or equal to + /// ADBC_VERSION_1_1_0. + /// + /// For a 1.0.0 driver being loaded by a 1.1.0 driver manager: the + /// 1.1.0 manager will allocate the new, expanded AdbcDriver struct + /// and attempt to have the driver initialize it with + /// ADBC_VERSION_1_1_0. This must return an error, after which the + /// driver will try again with ADBC_VERSION_1_0_0. The driver must + /// not access the new fields, which will carry undefined values. + /// + /// For a 1.1.0 driver being loaded by a 1.0.0 driver manager: the + /// 1.0.0 manager will allocate the old AdbcDriver struct and + /// attempt to have the driver initialize it with + /// ADBC_VERSION_1_0_0. The driver must not access the new fields, + /// and should initialize the old fields. 
+ /// + /// @{ + + int (*ErrorGetDetailCount)(const struct AdbcError* error); + struct AdbcErrorDetail (*ErrorGetDetail)(const struct AdbcError* error, int index); + const struct AdbcError* (*ErrorFromArrayStream)(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + + AdbcStatusCode (*DatabaseGetOption)(struct AdbcDatabase*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionBytes)(struct AdbcDatabase*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionDouble)(struct AdbcDatabase*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*DatabaseGetOptionInt)(struct AdbcDatabase*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionBytes)(struct AdbcDatabase*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionDouble)(struct AdbcDatabase*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*DatabaseSetOptionInt)(struct AdbcDatabase*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*ConnectionCancel)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOption)(struct AdbcConnection*, const char*, char*, + size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionBytes)(struct AdbcConnection*, const char*, + uint8_t*, size_t*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionDouble)(struct AdbcConnection*, const char*, + double*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetOptionInt)(struct AdbcConnection*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatistics)(struct AdbcConnection*, const char*, + const char*, const char*, char, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetStatisticNames)(struct AdbcConnection*, + struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionBytes)(struct AdbcConnection*, const char*, + const uint8_t*, 
size_t, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionDouble)(struct AdbcConnection*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*ConnectionSetOptionInt)(struct AdbcConnection*, const char*, int64_t, + struct AdbcError*); + + AdbcStatusCode (*StatementCancel)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementExecuteSchema)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOption)(struct AdbcStatement*, const char*, char*, size_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionBytes)(struct AdbcStatement*, const char*, uint8_t*, + size_t*, struct AdbcError*); + AdbcStatusCode (*StatementGetOptionDouble)(struct AdbcStatement*, const char*, double*, + struct AdbcError*); + AdbcStatusCode (*StatementGetOptionInt)(struct AdbcStatement*, const char*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionBytes)(struct AdbcStatement*, const char*, + const uint8_t*, size_t, struct AdbcError*); + AdbcStatusCode (*StatementSetOptionDouble)(struct AdbcStatement*, const char*, double, + struct AdbcError*); + AdbcStatusCode (*StatementSetOptionInt)(struct AdbcStatement*, const char*, int64_t, + struct AdbcError*); + + /// @} }; +/// \brief The size of the AdbcDriver structure in ADBC 1.0.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_0_0. +/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_0_0_SIZE (offsetof(struct AdbcDriver, ErrorGetDetailCount)) + +/// \brief The size of the AdbcDriver structure in ADBC 1.1.0. +/// Drivers written for ADBC 1.1.0 and later should never touch more +/// than this portion of an AdbcDriver struct when given +/// ADBC_VERSION_1_1_0. 
+/// +/// \since ADBC API revision 1.1.0 +#define ADBC_DRIVER_1_1_0_SIZE (sizeof(struct AdbcDriver)) + /// @} /// \addtogroup adbc-database @@ -684,16 +1080,189 @@ struct ADBC_EXPORT AdbcDriver { ADBC_EXPORT AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); +/// \brief Get a string option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. 
+ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the database. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a double option of the database. 
+/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the double +/// representation of an integer option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error); + +/// \brief Get an integer option of the database. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the integer +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. 
+ADBC_EXPORT +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error); + /// \brief Set a char* option. /// /// Options may be set before AdbcDatabaseInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error); + +/// \brief Set a double option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error); + +/// \brief Set an integer option on a database. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] database The database. 
+/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error); + /// \brief Finish setting options and initialize the database. /// /// Some drivers may support setting options after initialization @@ -730,11 +1299,65 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, /// Options may be set before AdbcConnectionInit. Some drivers may /// support setting options after initialization as well. /// +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. /// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized ADBC_EXPORT AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a connection. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. 
+/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error); + +/// \brief Set a double option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error); + /// \brief Finish setting options and initialize the connection. /// /// Some drivers may support setting options after initialization @@ -752,6 +1375,30 @@ ADBC_EXPORT AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, struct AdbcError* error); +/// \brief Cancel the in-progress operation on a connection. +/// +/// This can be called during AdbcConnectionGetObjects (or similar), +/// or while consuming an ArrowArrayStream returned from such. +/// Calling this function should make the other functions return +/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from +/// methods of ArrowArrayStream). (It is not guaranteed to, for +/// instance, the result set may be buffered in memory already.) +/// +/// This must always be thread-safe (other operations are not). It is +/// not necessarily signal-safe. 
+/// +/// \since ADBC API revision 1.1.0 +/// +/// \param[in] connection The connection to cancel. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// +/// \return ADBC_STATUS_INVALID_STATE if there is no operation to cancel. +/// \return ADBC_STATUS_UNKNOWN if the operation could not be cancelled. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error); + /// \defgroup adbc-connection-metadata Metadata /// Functions for retrieving metadata about the database. /// @@ -765,6 +1412,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// concurrent active statements and it must execute a SQL query /// internally in order to implement the metadata function). /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// Some functions accept "search pattern" arguments, which are /// strings that can contain the special character "%" to match zero /// or more characters, or "_" to match exactly one character. (See @@ -799,6 +1448,10 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// for ADBC usage. Drivers/vendors will ignore requests for /// unrecognized codes (the row will be omitted from the result). /// +/// Since ADBC 1.1.0: the range [500, 1_000) is reserved for "XDBC" +/// information, which is the same metadata provided by the same info +/// code range in the Arrow Flight SQL GetSqlInfo RPC. +/// /// \param[in] connection The connection to query. /// \param[in] info_codes A list of metadata codes to fetch, or NULL /// to fetch all. @@ -808,7 +1461,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, /// \param[out] error Error details, if an error occurs. 
ADBC_EXPORT AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); @@ -891,6 +1544,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, /// | fk_table | utf8 not null | /// | fk_column_name | utf8 not null | /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[in] depth The level of nesting to display. If 0, display /// all levels. If 1, display only catalogs (i.e. catalog_schemas @@ -922,6 +1577,212 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d struct ArrowArrayStream* out, struct AdbcError* error); +/// \brief Get a string option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the connection. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. 
+ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error); + +/// \brief Get a double option of the connection. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] key The option to get. +/// \param[out] value The option value. 
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized.
+ADBC_EXPORT
+AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection,
+                                             const char* key, double* value,
+                                             struct AdbcError* error);
+
+/// \brief Get statistics about the data distribution of table(s).
+///
+/// The result is an Arrow dataset with the following schema:
+///
+/// | Field Name               | Field Type                       |
+/// |--------------------------|----------------------------------|
+/// | catalog_name             | utf8                             |
+/// | catalog_db_schemas       | list<DB_SCHEMA_SCHEMA> not null  |
+///
+/// DB_SCHEMA_SCHEMA is a Struct with fields:
+///
+/// | Field Name               | Field Type                       |
+/// |--------------------------|----------------------------------|
+/// | db_schema_name           | utf8                             |
+/// | db_schema_statistics     | list<STATISTICS_SCHEMA> not null |
+///
+/// STATISTICS_SCHEMA is a Struct with fields:
+///
+/// | Field Name               | Field Type                       | Comments |
+/// |--------------------------|----------------------------------|----------|
+/// | table_name               | utf8 not null                    |          |
+/// | column_name              | utf8                             | (1)      |
+/// | statistic_key            | int16 not null                   | (2)      |
+/// | statistic_value          | VALUE_SCHEMA not null            |          |
+/// | statistic_is_approximate | bool not null                    | (3)      |
+///
+/// 1. If null, then the statistic applies to the entire table.
+/// 2. A dictionary-encoded statistic name (although we do not use the Arrow
+///    dictionary type). Values in [0, 1024) are reserved for ADBC. Other
+///    values are for implementation-specific statistics. For the definitions
+///    of predefined statistic types, see \ref adbc-table-statistics. To get
+///    driver-specific statistic names, use AdbcConnectionGetStatisticNames.
+/// 3. If true, then the value is approximate or best-effort.
+/// +/// VALUE_SCHEMA is a dense union with members: +/// +/// | Field Name | Field Type | +/// |--------------------------|----------------------------------| +/// | int64 | int64 | +/// | uint64 | uint64 | +/// | float64 | float64 | +/// | binary | binary | +/// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr). May be a search +/// pattern (see section documentation). +/// \param[in] db_schema The database schema (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] table_name The table name (or nullptr). May be a +/// search pattern (see section documentation). +/// \param[in] approximate If zero, request exact values of +/// statistics, else allow for best-effort, approximate, or cached +/// values. The database may return approximate values regardless, +/// as indicated in the result. Requesting exact values may be +/// expensive or unsupported. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the names of statistics specific to this driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|---------------- +/// statistic_name | utf8 not null +/// statistic_key | int16 not null +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] connection The database connection. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. 
+ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + /// \brief Get the Arrow schema of a table. /// /// \param[in] connection The database connection. @@ -945,6 +1806,8 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, /// ---------------|-------------- /// table_type | utf8 not null /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The database connection. /// \param[out] out The result set. /// \param[out] error Error details, if an error occurs. @@ -973,6 +1836,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, /// /// A partition can be retrieved from AdbcPartitions. /// +/// This AdbcConnection must outlive the returned ArrowArrayStream. +/// /// \param[in] connection The connection to use. This does not have /// to be the same connection that the partition was created on. /// \param[in] serialized_partition The partition descriptor. @@ -1042,7 +1907,11 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, /// \brief Execute a statement and get the results. /// -/// This invalidates any prior result sets. +/// This invalidates any prior result sets. This AdbcStatement must +/// outlive the returned ArrowArrayStream. +/// +/// Since ADBC 1.1.0: releasing the returned ArrowArrayStream without +/// consuming it fully is equivalent to calling AdbcStatementCancel. /// /// \param[in] statement The statement to execute. /// \param[out] out The results. Pass NULL if the client does not @@ -1056,6 +1925,27 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, struct ArrowArrayStream* out, int64_t* rows_affected, struct AdbcError* error); +/// \brief Get the schema of the result set of a query without +/// executing it. +/// +/// This invalidates any prior result sets. 
+///
+/// Depending on the driver, this may require first executing
+/// AdbcStatementPrepare.
+///
+/// \since ADBC API revision 1.1.0
+///
+/// \param[in] statement The statement to execute.
+/// \param[out] schema The result schema.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+///
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support this.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement,
+                                          struct ArrowSchema* schema,
+                                          struct AdbcError* error);
+
 /// \brief Turn this statement into a prepared statement to be
 /// executed multiple times.
 ///
@@ -1138,6 +2028,158 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
                                        struct ArrowArrayStream* stream,
                                        struct AdbcError* error);

+/// \brief Cancel execution of an in-progress query.
+///
+/// This can be called during AdbcStatementExecuteQuery (or similar),
+/// or while consuming an ArrowArrayStream returned from such.
+/// Calling this function should make the other functions return
+/// ADBC_STATUS_CANCELLED (from ADBC functions) or ECANCELED (from
+/// methods of ArrowArrayStream). (It is not guaranteed to, for
+/// instance, the result set may be buffered in memory already.)
+///
+/// This must always be thread-safe (other operations are not). It is
+/// not necessarily signal-safe.
+///
+/// \since ADBC API revision 1.1.0
+///
+/// \param[in] statement The statement to cancel.
+/// \param[out] error An optional location to return an error
+/// message if necessary.
+///
+/// \return ADBC_STATUS_INVALID_STATE if there is no query to cancel.
+/// \return ADBC_STATUS_UNKNOWN if the query could not be cancelled.
+ADBC_EXPORT
+AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement,
+                                   struct AdbcError* error);
+
+/// \brief Get a string option of the statement.
+/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call GetOption +/// concurrently with itself. +/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value (including the null terminator) to buffer and set +/// length to the size of the actual value. If the buffer is too +/// small, no data will be written and length will be set to the +/// required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The length of value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error); + +/// \brief Get a bytestring option of the statement. +/// +/// This must always be thread-safe (other operations are not), though +/// given the semantics here, it is not recommended to call +/// GetOptionBytes concurrently with itself. 
+/// +/// length must be provided and must be the size of the buffer pointed +/// to by value. If there is sufficient space, the driver will copy +/// the option value to buffer and set length to the size of the +/// actual value. If the buffer is too small, no data will be written +/// and length will be set to the required length. +/// +/// In other words: +/// +/// - If output length <= input length, value will contain a value +/// with length bytes. +/// - If output length > input length, nothing has been written to +/// value. +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[in,out] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error); + +/// \brief Get an integer option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) 
Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error); + +/// \brief Get a double option of the statement. +/// +/// This must always be thread-safe (other operations are not). +/// +/// For standard options, drivers must always support getting the +/// option value (if they support getting option values at all) via +/// the type specified in the option. (For example, an option set via +/// SetOptionDouble must be retrievable via GetOptionDouble.) Drivers +/// may also support getting a converted option value via other +/// getters if needed. (For example, getting the string +/// representation of a double option.) +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to get. +/// \param[out] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_FOUND if the option is not recognized. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error); + /// \brief Get the schema for bound parameters. /// /// This retrieves an Arrow schema describing the number, names, and @@ -1159,10 +2201,58 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct AdbcError* error); /// \brief Set a string option on a statement. 
+/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized. ADBC_EXPORT AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, const char* value, struct AdbcError* error); +/// \brief Set a bytestring option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[in] length The option value length. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error); + +/// \brief Set an integer option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error); + +/// \brief Set a double option on a statement. +/// +/// \since ADBC API revision 1.1.0 +/// \param[in] statement The statement. +/// \param[in] key The option to set. +/// \param[in] value The option value. +/// \param[out] error An optional location to return an error +/// message if necessary. 
+/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error); + /// \addtogroup adbc-statement-partition /// @{ @@ -1198,7 +2288,15 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, /// driver. /// /// Although drivers may choose any name for this function, the -/// recommended name is "AdbcDriverInit". +/// recommended name is "AdbcDriverInit", or a name derived from the +/// name of the driver's shared library as follows: remove the 'lib' +/// prefix (on Unix systems) and all file extensions, then PascalCase +/// the driver name, append Init, and prepend Adbc (if not already +/// there). For example: +/// +/// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit +/// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit +/// - proprietary_driver.dll -> AdbcProprietaryDriverInit /// /// \param[in] version The ADBC revision to attempt to initialize (see /// ADBC_VERSION_1_0_0). 
diff --git a/3rd_party/apache-arrow-adbc/c/CMakePresets.json b/3rd_party/apache-arrow-adbc/c/CMakePresets.json new file mode 100644 index 0000000..fc4fdcb --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/CMakePresets.json @@ -0,0 +1,70 @@ +{ + "version": 3, + "cmakeMinimumRequired": { + "major": 3, + "minor": 21, + "patch": 0 + }, + "configurePresets": [ + { + "name": "debug", + "displayName": "debug, all drivers, with tests, with ASan/UBSan (not usable from Python)", + "generator": "Ninja", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "ADBC_BUILD_TESTS": "ON", + "ADBC_DRIVER_FLIGHTSQL": "ON", + "ADBC_DRIVER_MANAGER": "ON", + "ADBC_DRIVER_POSTGRESQL": "ON", + "ADBC_DRIVER_SNOWFLAKE": "ON", + "ADBC_DRIVER_SQLITE": "ON", + "ADBC_USE_ASAN": "ON", + "ADBC_USE_UBSAN": "ON" + } + }, + { + "name": "debug-python", + "displayName": "debug, all drivers, with tests, without ASan/UBSan (usable from Python)", + "generator": "Ninja", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "ADBC_BUILD_TESTS": "ON", + "ADBC_DRIVER_FLIGHTSQL": "ON", + "ADBC_DRIVER_MANAGER": "ON", + "ADBC_DRIVER_POSTGRESQL": "ON", + "ADBC_DRIVER_SNOWFLAKE": "ON", + "ADBC_DRIVER_SQLITE": "ON", + "ADBC_USE_ASAN": "OFF", + "ADBC_USE_UBSAN": "OFF" + } + } + ], + "testPresets": [ + { + "name": "debug", + "description": "run tests (except Snowflake)", + "displayName": "debug, all drivers (except Snowflake)", + "configurePreset": "debug", + "environment": { + "ADBC_DREMIO_FLIGHTSQL_PASS": "dremio123", + "ADBC_DREMIO_FLIGHTSQL_URI": "grpc+tcp://localhost:32010", + "ADBC_DREMIO_FLIGHTSQL_USER": "dremio", + "ADBC_POSTGRESQL_TEST_URI": "postgresql://localhost:5432/postgres?user=postgres&password=password", + "ADBC_SQLITE_FLIGHTSQL_URI": "grpc://localhost:8080" + }, + "execution": { + "jobs": 4 + }, + "filter": { + "exclude": { + "label": "driver-snowflake" + } + }, + "output": { + "outputOnFailure": true + } + } + ] +} 
diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcDefines.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcDefines.cmake index 25dfcba..5cc2eaa 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcDefines.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcDefines.cmake @@ -60,16 +60,35 @@ if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT) endif() # Set common build options -macro(adbc_configure_target TARGET) - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - target_compile_options(${TARGET} - PRIVATE -Wall - -Werror - -Wextra - -Wpedantic - -Wno-unused-parameter - -Wunused-result) +if("${ADBC_BUILD_WARNING_LEVEL}" STREQUAL "") + string(TOLOWER "${CMAKE_BUILD_TYPE}" _lower_build_type) + if("${_lower_build_type}" STREQUAL "release") + set(ADBC_BUILD_WARNING_LEVEL "PRODUCTION") + else() + set(ADBC_BUILD_WARNING_LEVEL "CHECKIN") endif() +endif() + +if(MSVC) + set(ADBC_C_CXX_FLAGS_CHECKIN /Wall /WX) + set(ADBC_C_CXX_FLAGS_PRODUCTION /Wall) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" + OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang" + OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(ADBC_C_CXX_FLAGS_CHECKIN + -Wall + -Wextra + -Wpedantic + -Werror + -Wno-unused-parameter) + set(ADBC_C_CXX_FLAGS_PRODUCTION -Wall) +else() + message(WARNING "Unknown compiler: ${CMAKE_CXX_COMPILER_ID}") +endif() + +macro(adbc_configure_target TARGET) + target_compile_options(${TARGET} + PRIVATE ${ADBC_C_CXX_FLAGS_${ADBC_BUILD_WARNING_LEVEL}}) endmacro() # Common testing setup diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake index 7275918..9d9c2af 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake @@ -21,7 +21,7 @@ # ------------------------------------------------------------ # Version definitions -set(ADBC_VERSION "0.5.1") +set(ADBC_VERSION "0.7.0") string(REGEX MATCH 
"^[0-9]+\\.[0-9]+\\.[0-9]+" ADBC_BASE_VERSION "${ADBC_VERSION}") string(REPLACE "." ";" _adbc_version_list "${ADBC_BASE_VERSION}") list(GET _adbc_version_list 0 ADBC_VERSION_MAJOR) diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/BuildUtils.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/BuildUtils.cmake index df2590a..de1a7b2 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/BuildUtils.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/BuildUtils.cmake @@ -166,7 +166,6 @@ function(ADD_ARROW_LIB LIB_NAME) add_library(${LIB_NAME}_objlib OBJECT ${ARG_SOURCES}) # Necessary to make static linking into other shared libraries work properly set_property(TARGET ${LIB_NAME}_objlib PROPERTY POSITION_INDEPENDENT_CODE 1) - set_property(TARGET ${LIB_NAME}_objlib PROPERTY CXX_STANDARD 17) set_property(TARGET ${LIB_NAME}_objlib PROPERTY CXX_STANDARD_REQUIRED ON) if(ARG_DEPENDENCIES) add_dependencies(${LIB_NAME}_objlib ${ARG_DEPENDENCIES}) @@ -194,6 +193,9 @@ function(ADD_ARROW_LIB LIB_NAME) target_link_libraries(${LIB_NAME}_objlib PRIVATE ${ARG_SHARED_LINK_LIBS} ${ARG_SHARED_PRIVATE_LINK_LIBS} ${ARG_STATIC_LINK_LIBS}) + adbc_configure_target(${LIB_NAME}_objlib) + # https://github.com/apache/arrow-adbc/issues/81 + target_compile_features(${LIB_NAME}_objlib PRIVATE cxx_std_11) else() # Prepare arguments for separate compilation of static and shared libs below # TODO: add PCH directives @@ -209,7 +211,7 @@ function(ADD_ARROW_LIB LIB_NAME) if(BUILD_SHARED) add_library(${LIB_NAME}_shared SHARED ${LIB_DEPS}) - set_property(TARGET ${LIB_NAME}_shared PROPERTY CXX_STANDARD 17) + target_compile_features(${LIB_NAME}_shared PRIVATE cxx_std_11) set_property(TARGET ${LIB_NAME}_shared PROPERTY CXX_STANDARD_REQUIRED ON) adbc_configure_target(${LIB_NAME}_shared) if(EXTRA_DEPS) @@ -255,6 +257,9 @@ function(ADD_ARROW_LIB LIB_NAME) VERSION "${ADBC_FULL_SO_VERSION}" SOVERSION "${ADBC_SO_VERSION}") + # https://github.com/apache/arrow-adbc/issues/81 + 
target_compile_features(${LIB_NAME}_shared PRIVATE cxx_std_11) + target_link_libraries(${LIB_NAME}_shared LINK_PUBLIC "$" @@ -304,7 +309,7 @@ function(ADD_ARROW_LIB LIB_NAME) if(BUILD_STATIC) add_library(${LIB_NAME}_static STATIC ${LIB_DEPS}) - set_property(TARGET ${LIB_NAME}_shared PROPERTY CXX_STANDARD 17) + target_compile_features(${LIB_NAME}_static PRIVATE cxx_std_11) set_property(TARGET ${LIB_NAME}_shared PROPERTY CXX_STANDARD_REQUIRED ON) adbc_configure_target(${LIB_NAME}_static) if(EXTRA_DEPS) @@ -342,6 +347,9 @@ function(ADD_ARROW_LIB LIB_NAME) PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_PATH}" OUTPUT_NAME ${LIB_NAME_STATIC}) + # https://github.com/apache/arrow-adbc/issues/81 + target_compile_features(${LIB_NAME}_static PRIVATE cxx_std_11) + if(ARG_STATIC_INSTALL_INTERFACE_LIBS) target_link_libraries(${LIB_NAME}_static LINK_PUBLIC "$") @@ -584,6 +592,7 @@ function(ADD_TEST_CASE REL_TEST_NAME) set(TEST_PATH "${EXECUTABLE_OUTPUT_PATH}/${TEST_NAME}") add_executable(${TEST_NAME} ${SOURCES}) + adbc_configure_target(${TEST_NAME}) # With OSX and conda, we need to set the correct RPATH so that dependencies # are found. 
The installed libraries with conda have an RPATH that matches @@ -637,8 +646,6 @@ function(ADD_TEST_CASE REL_TEST_NAME) add_test(${TEST_NAME} ${TEST_PATH} ${ARG_TEST_ARGUMENTS}) endif() - adbc_configure_target(${TEST_NAME}) - # Add test as dependency of relevant targets add_dependencies(all-tests ${TEST_NAME}) foreach(TARGET ${ARG_LABELS}) diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/DefineOptions.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/DefineOptions.cmake index 42b8f4f..b6dd107 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/DefineOptions.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/DefineOptions.cmake @@ -86,6 +86,9 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") #---------------------------------------------------------------------- set_option_category("Compile and link") + define_option_string(ADBC_BUILD_WARNING_LEVEL + "CHECKIN to enable Werror, PRODUCTION otherwise" "") + define_option_string(ADBC_CXXFLAGS "Compiler flags to append when compiling ADBC C++ libraries" "") define_option_string(ADBC_GO_BUILD_TAGS diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/GoUtils.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/GoUtils.cmake index d485214..aac6f5a 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/GoUtils.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/GoUtils.cmake @@ -202,7 +202,10 @@ function(add_go_lib GO_MOD_DIR GO_LIBNAME) DESTINATION ${CMAKE_INSTALL_LIBDIR}) endif() - if(NOT WIN32) + if(WIN32) + # This symlink doesn't get installed + install(FILES "${LIBOUT_SHARED}.${ADBC_SO_VERSION}" TYPE BIN) + else() install(FILES "${LIBOUT_SHARED}" "${LIBOUT_SHARED}.${ADBC_SO_VERSION}" TYPE LIB) endif() endif() diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/san-config.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/san-config.cmake index a678d87..0c5e59b 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/san-config.cmake +++ 
b/3rd_party/apache-arrow-adbc/c/cmake_modules/san-config.cmake @@ -25,6 +25,7 @@ if(${ADBC_USE_ASAN}) OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.8")) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -DADDRESS_SANITIZER") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -DADDRESS_SANITIZER") else() message(SEND_ERROR "Cannot use ASAN without clang or gcc >= 4.8") endif() @@ -46,11 +47,17 @@ if(${ADBC_USE_UBSAN}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero -fno-sanitize-recover=all" ) + set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero -fno-sanitize-recover=all" + ) elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "5.1") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr -fno-sanitize-recover=all" ) + set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr -fno-sanitize-recover=all" + ) else() message(SEND_ERROR "Cannot use UBSAN without clang or gcc >= 5.1") endif() diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver/common/CMakeLists.txt index 33dd1c8..0da24bb 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver/common/CMakeLists.txt @@ -16,6 +16,7 @@ # under the License. 
add_library(adbc_driver_common STATIC utils.c) +adbc_configure_target(adbc_driver_common) set_target_properties(adbc_driver_common PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(adbc_driver_common PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/vendor") diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/options.h b/3rd_party/apache-arrow-adbc/c/driver/common/options.h new file mode 100644 index 0000000..f42bb09 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/common/options.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Common options that haven't yet been formally standardized. +/// https://github.com/apache/arrow-adbc/issues/1055 + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +/// \brief The catalog of the table for bulk insert. +/// +/// The type is char*. +#define ADBC_INGEST_OPTION_TARGET_CATALOG "adbc.ingest.target_catalog" + +/// \brief The schema of the table for bulk insert. +/// +/// The type is char*. +#define ADBC_INGEST_OPTION_TARGET_DB_SCHEMA "adbc.ingest.target_db_schema" + +/// \brief Use a temporary table for ingestion. 
+/// +/// The value should be ADBC_OPTION_VALUE_ENABLED or +/// ADBC_OPTION_VALUE_DISABLED (the default). +/// +/// This is not supported with ADBC_INGEST_OPTION_TARGET_CATALOG and +/// ADBC_INGEST_OPTION_TARGET_DB_SCHEMA. +/// +/// The type is char*. +#define ADBC_INGEST_OPTION_TEMPORARY "adbc.ingest.temporary" + +#ifdef __cplusplus +} +#endif diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/utils.c b/3rd_party/apache-arrow-adbc/c/driver/common/utils.c index dfac14f..ad82e79 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/utils.c +++ b/3rd_party/apache-arrow-adbc/c/driver/common/utils.c @@ -17,15 +17,80 @@ #include "utils.h" +#include #include -#include #include #include #include -#include +#include + +static size_t kErrorBufferSize = 1024; + +int AdbcStatusCodeToErrno(AdbcStatusCode code) { + switch (code) { + case ADBC_STATUS_OK: + return 0; + case ADBC_STATUS_UNKNOWN: + return EIO; + case ADBC_STATUS_NOT_IMPLEMENTED: + return ENOTSUP; + case ADBC_STATUS_NOT_FOUND: + return ENOENT; + case ADBC_STATUS_ALREADY_EXISTS: + return EEXIST; + case ADBC_STATUS_INVALID_ARGUMENT: + case ADBC_STATUS_INVALID_STATE: + return EINVAL; + case ADBC_STATUS_INVALID_DATA: + case ADBC_STATUS_INTEGRITY: + case ADBC_STATUS_INTERNAL: + case ADBC_STATUS_IO: + return EIO; + case ADBC_STATUS_CANCELLED: + return ECANCELED; + case ADBC_STATUS_TIMEOUT: + return ETIMEDOUT; + case ADBC_STATUS_UNAUTHENTICATED: + // FreeBSD/macOS have EAUTH, but not other platforms + case ADBC_STATUS_UNAUTHORIZED: + return EACCES; + default: + return EIO; + } +} -static size_t kErrorBufferSize = 256; +/// For ADBC 1.1.0, the structure held in private_data. +struct AdbcErrorDetails { + char* message; + + // The metadata keys (may be NULL). + char** keys; + // The metadata values (may be NULL). + uint8_t** values; + // The metadata value lengths (may be NULL). + size_t* lengths; + // The number of initialized metadata. + int count; + // The length of the keys/values/lengths arrays above. 
+ int capacity; +}; + +static void ReleaseErrorWithDetails(struct AdbcError* error) { + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + free(details->message); + + for (int i = 0; i < details->count; i++) { + free(details->keys[i]); + free(details->values[i]); + } + + free(details->keys); + free(details->values); + free(details->lengths); + free(error->private_data); + *error = ADBC_ERROR_INIT; +} static void ReleaseError(struct AdbcError* error) { free(error->message); @@ -34,20 +99,132 @@ static void ReleaseError(struct AdbcError* error) { } void SetError(struct AdbcError* error, const char* format, ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); +} + +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args) { if (!error) return; if (error->release) { // TODO: combine the errors if possible error->release(error); } - error->message = malloc(kErrorBufferSize); - if (!error->message) return; - error->release = &ReleaseError; + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + error->private_data = malloc(sizeof(struct AdbcErrorDetails)); + if (!error->private_data) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + + details->message = malloc(kErrorBufferSize); + if (!details->message) { + free(details); + return; + } + details->keys = NULL; + details->values = NULL; + details->lengths = NULL; + details->count = 0; + details->capacity = 0; + + error->message = details->message; + error->release = &ReleaseErrorWithDetails; + } else { + error->message = malloc(kErrorBufferSize); + if (!error->message) return; + + error->release = &ReleaseError; + } - va_list args; - va_start(args, format); vsnprintf(error->message, kErrorBufferSize, format, args); - va_end(args); +} + +void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* detail, + size_t detail_length) { + if 
(error->release != ReleaseErrorWithDetails) return; + + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + if (details->count >= details->capacity) { + int new_capacity = (details->capacity == 0) ? 4 : (2 * details->capacity); + char** new_keys = calloc(new_capacity, sizeof(char*)); + if (!new_keys) { + return; + } + + uint8_t** new_values = calloc(new_capacity, sizeof(uint8_t*)); + if (!new_values) { + free(new_keys); + return; + } + + size_t* new_lengths = calloc(new_capacity, sizeof(size_t*)); + if (!new_lengths) { + free(new_keys); + free(new_values); + return; + } + + if (details->keys != NULL) { + memcpy(new_keys, details->keys, sizeof(char*) * details->count); + free(details->keys); + } + details->keys = new_keys; + + if (details->values != NULL) { + memcpy(new_values, details->values, sizeof(uint8_t*) * details->count); + free(details->values); + } + details->values = new_values; + + if (details->lengths != NULL) { + memcpy(new_lengths, details->lengths, sizeof(size_t) * details->count); + free(details->lengths); + } + details->lengths = new_lengths; + + details->capacity = new_capacity; + } + + char* key_data = strdup(key); + if (!key_data) return; + uint8_t* value_data = malloc(detail_length); + if (!value_data) { + free(key_data); + return; + } + memcpy(value_data, detail, detail_length); + + int index = details->count; + details->keys[index] = key_data; + details->values[index] = value_data; + details->lengths[index] = detail_length; + + details->count++; +} + +int CommonErrorGetDetailCount(const struct AdbcError* error) { + if (error->release != ReleaseErrorWithDetails) { + return 0; + } + struct AdbcErrorDetails* details = (struct AdbcErrorDetails*)error->private_data; + return details->count; +} + +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index) { + if (error->release != ReleaseErrorWithDetails) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + struct AdbcErrorDetails* 
details = (struct AdbcErrorDetails*)error->private_data; + if (index < 0 || index >= details->count) { + return (struct AdbcErrorDetail){NULL, NULL, 0}; + } + return (struct AdbcErrorDetail){ + .key = details->keys[index], + .value = details->values[index], + .value_length = details->lengths[index], + }; } struct SingleBatchArrayStream { @@ -244,6 +421,19 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, return ADBC_STATUS_OK; } +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error) { + CHECK_NA(INTERNAL, ArrowArrayAppendUInt(array->children[0], info_code), error); + // Append to type variant + CHECK_NA(INTERNAL, ArrowArrayAppendInt(array->children[1]->children[2], info_value), + error); + // Append type code/offset + CHECK_NA(INTERNAL, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2), + error); + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error) { ArrowSchemaInit(schema); diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/utils.h b/3rd_party/apache-arrow-adbc/c/driver/common/utils.h index f0b5fa3..e3d81cb 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/utils.h +++ b/3rd_party/apache-arrow-adbc/c/driver/common/utils.h @@ -17,31 +17,40 @@ #pragma once +#include #include #include #include #include "nanoarrow/nanoarrow.h" -#if defined(__GNUC__) -#define SET_ERROR_ATTRIBUTE __attribute__((format(printf, 2, 3))) -#else -#define SET_ERROR_ATTRIBUTE -#endif - #ifdef __cplusplus extern "C" { #endif -/// Set error details using a format string. -void SetError(struct AdbcError* error, const char* format, ...) 
SET_ERROR_ATTRIBUTE; +int AdbcStatusCodeToErrno(AdbcStatusCode code); + +// The printf checking attribute doesn't work properly on gcc 4.8 +// and results in spurious compiler warnings +#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) +#define ADBC_CHECK_PRINTF_ATTRIBUTE __attribute__((format(printf, 2, 3))) +#else +#define ADBC_CHECK_PRINTF_ATTRIBUTE +#endif -#undef SET_ERROR_ATTRIBUTE +/// Set error message using a format string. +void SetError(struct AdbcError* error, const char* format, + ...) ADBC_CHECK_PRINTF_ATTRIBUTE; -/// Wrap a single batch as a stream. -AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* schema, - struct ArrowArrayStream* stream, - struct AdbcError* error); +/// Set error message using a format string. +void SetErrorVariadic(struct AdbcError* error, const char* format, va_list args); + +/// Add an error detail. +void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* detail, + size_t detail_length); + +int CommonErrorGetDetailCount(const struct AdbcError* error); +struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int index); struct StringBuilder { char* buffer; @@ -51,15 +60,17 @@ struct StringBuilder { }; int StringBuilderInit(struct StringBuilder* builder, size_t initial_size); -#if defined(__GNUC__) -#define ADBC_STRING_BUILDER_FORMAT_CHECK __attribute__((format(printf, 2, 3))) -#else -#define ADBC_STRING_BUILDER_FORMAT_CHECK -#endif -int ADBC_STRING_BUILDER_FORMAT_CHECK StringBuilderAppend(struct StringBuilder* builder, - const char* fmt, ...); +int ADBC_CHECK_PRINTF_ATTRIBUTE StringBuilderAppend(struct StringBuilder* builder, + const char* fmt, ...); void StringBuilderReset(struct StringBuilder* builder); +#undef ADBC_CHECK_PRINTF_ATTRIBUTE + +/// Wrap a single batch as a stream. 
+AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* schema, + struct ArrowArrayStream* stream, + struct AdbcError* error); + /// Check an NanoArrow status code. #define CHECK_NA(CODE, EXPR, ERROR) \ do { \ @@ -119,6 +130,9 @@ AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, const char* info_value, struct AdbcError* error); +AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, + uint32_t info_code, int64_t info_value, + struct AdbcError* error); AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, struct AdbcError* error); diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/utils_test.cc b/3rd_party/apache-arrow-adbc/c/driver/common/utils_test.cc index 6fa7e25..d5c202b 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/utils_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/common/utils_test.cc @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+#include +#include +#include +#include + +#include #include #include "utils.h" @@ -72,3 +78,92 @@ TEST(TestStringBuilder, TestMultipleAppends) { StringBuilderReset(&str); } + +TEST(ErrorDetails, Adbc100) { + struct AdbcError error; + std::memset(&error, 0, ADBC_ERROR_1_1_0_SIZE); + + SetError(&error, "My message"); + + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(0, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); +} + +TEST(ErrorDetails, Adbc110) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + ASSERT_NE(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); + + { + std::string detail = "detail"; + AppendErrorDetail(&error, "key", reinterpret_cast(detail.data()), + detail.size()); + } + + ASSERT_EQ(1, CommonErrorGetDetailCount(&error)); + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, 0); + ASSERT_STREQ("key", detail.key); + ASSERT_EQ("detail", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + + detail = CommonErrorGetDetail(&error, -1); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + detail = CommonErrorGetDetail(&error, 2); + ASSERT_EQ(nullptr, detail.key); + ASSERT_EQ(nullptr, detail.value); + ASSERT_EQ(0, detail.value_length); + + error.release(&error); + ASSERT_EQ(nullptr, error.private_data); + ASSERT_EQ(nullptr, error.private_driver); +} + +TEST(ErrorDetails, RoundTripValues) { + struct AdbcError error = ADBC_ERROR_INIT; + SetError(&error, "My message"); + + struct Detail { + std::string key; + std::vector value; + }; + + std::vector 
details = { + {"x-key-1", {0, 1, 2, 3}}, {"x-key-2", {1, 1}}, {"x-key-3", {128, 129, 200, 0, 1}}, + {"x-key-4", {97, 98, 99}}, {"x-key-5", {42}}, + }; + + for (const auto& detail : details) { + AppendErrorDetail(&error, detail.key.c_str(), detail.value.data(), + detail.value.size()); + } + + ASSERT_EQ(details.size(), CommonErrorGetDetailCount(&error)); + for (int i = 0; i < static_cast(details.size()); i++) { + struct AdbcErrorDetail detail = CommonErrorGetDetail(&error, i); + ASSERT_EQ(details[i].key, detail.key); + ASSERT_EQ(details[i].value.size(), detail.value_length); + ASSERT_THAT(std::vector(detail.value, detail.value + detail.value_length), + ::testing::ElementsAreArray(details[i].value)); + } + + error.release(&error); +} diff --git a/3rd_party/apache-arrow-adbc/c/driver/flightsql/dremio_flightsql_test.cc b/3rd_party/apache-arrow-adbc/c/driver/flightsql/dremio_flightsql_test.cc index f4e0642..52c184f 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/flightsql/dremio_flightsql_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/flightsql/dremio_flightsql_test.cc @@ -42,11 +42,11 @@ class DremioFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } }; @@ -87,6 +87,12 @@ class DremioFlightSqlStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); 
} + void TestResultInvalidation() { GTEST_SKIP() << "Dremio generates a CANCELLED"; } + void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } + void TestSqlIngestColumnEscaping() { + GTEST_SKIP() << "Column escaping not implemented"; + } + protected: DremioFlightSqlQuirks quirks_; }; diff --git a/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc b/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc index 2bf0441..41a7ded 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc @@ -15,17 +15,28 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include + #include #include #include #include #include #include + #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" +using adbc_validation::IsOkErrno; using adbc_validation::IsOkStatus; +extern "C" { +AdbcStatusCode FlightSQLDriverInit(int, void*, struct AdbcError*); +} + #define CHECK_OK(EXPR) \ do { \ if (auto adbc_status = (EXPR); adbc_status != ADBC_STATUS_OK) { \ @@ -83,11 +94,31 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { } std::string BindParameter(int index) const override { return "?"; } + + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return false; } bool supports_get_sql_info() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC Flight SQL Driver - Go"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown or development build)"; + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_VENDOR_NAME: + return 
"db_name"; + case ADBC_INFO_VENDOR_VERSION: + return "sqlite 3"; + case ADBC_INFO_VENDOR_ARROW_VERSION: + return "12.0.0"; + default: + return std::nullopt; + } + } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return false; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } }; @@ -103,6 +134,120 @@ class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::Data }; ADBCV_TEST_DATABASE(SqliteFlightSqlTest) +TEST_F(SqliteFlightSqlTest, TestGarbageInput) { + // Regression test for https://github.com/apache/arrow-adbc/issues/729 + + // 0xc000000000 is the base of the Go heap. Go's write barriers ask + // the GC to mark both the pointer being written, and the pointer + // being *overwritten*. So if Go overwrites a value in a C + // structure that looks like a Go pointer, the GC may get confused + // and error. + void* bad_pointer = reinterpret_cast(uintptr_t(0xc000000240)); + + // ADBC functions are expected not to blindly overwrite an + // already-allocated value/callers are expected to zero-initialize. 
+ database.private_data = bad_pointer; + database.private_driver = reinterpret_cast(bad_pointer); + ASSERT_THAT(AdbcDatabaseNew(&database, &error), ::testing::Not(IsOkStatus(&error))); + + std::memset(&database, 0, sizeof(database)); + ASSERT_THAT(AdbcDatabaseNew(&database, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->SetupDatabase(&database, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseInit(&database, &error), IsOkStatus(&error)); + + struct AdbcConnection connection; + connection.private_data = bad_pointer; + connection.private_driver = reinterpret_cast(bad_pointer); + ASSERT_THAT(AdbcConnectionNew(&connection, &error), ::testing::Not(IsOkStatus(&error))); + + std::memset(&connection, 0, sizeof(connection)); + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + struct AdbcStatement statement; + statement.private_data = bad_pointer; + statement.private_driver = reinterpret_cast(bad_pointer); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), + ::testing::Not(IsOkStatus(&error))); + + // This needs to happen in parallel since we need to trigger the + // write barrier buffer, which means we need to trigger a GC. The + // Go FFI bridge deterministically triggers GC on Release calls. 
+ + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5); + while (std::chrono::steady_clock::now() < deadline) { + std::vector threads; + std::random_device rd; + for (int i = 0; i < 23; i++) { + auto seed = rd(); + threads.emplace_back([&, seed]() { + std::mt19937 gen(seed); + std::uniform_int_distribution dist(0xc000000000L, 0xc000002000L); + for (int i = 0; i < 23; i++) { + void* bad_pointer = reinterpret_cast(uintptr_t(dist(gen))); + + struct AdbcStatement statement; + std::memset(&statement, 0, sizeof(statement)); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 1", &error), + IsOkStatus(&error)); + // This is not expected to be zero-initialized + struct ArrowArrayStream stream; + stream.private_data = bad_pointer; + stream.release = + reinterpret_cast(bad_pointer); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &stream, nullptr, &error), + IsOkStatus(&error)); + + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + schema.name = reinterpret_cast(bad_pointer); + schema.format = reinterpret_cast(bad_pointer); + schema.private_data = bad_pointer; + ASSERT_THAT(stream.get_schema(&stream, &schema), IsOkErrno()); + + while (true) { + struct ArrowArray array; + array.private_data = bad_pointer; + ASSERT_THAT(stream.get_next(&stream, &array), IsOkErrno()); + if (array.release) { + array.release(&array); + } else { + break; + } + } + + schema.release(&schema); + stream.release(&stream); + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + } + + ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error)); +} + +TEST_F(SqliteFlightSqlTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = 
static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + ASSERT_THAT(::FlightSQLDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::FlightSQLDriverInit(424242, driver, &error), + adbc_validation::IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class SqliteFlightSqlConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -122,7 +267,151 @@ class SqliteFlightSqlStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } + void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; } + void TestSqlIngestColumnEscaping() { + GTEST_SKIP() << "Column escaping not implemented"; + } + void TestSqlIngestInterval() { + GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; + } + protected: SqliteFlightSqlQuirks quirks_; }; ADBCV_TEST_STATEMENT(SqliteFlightSqlStatementTest) + +// Test what happens when using the ADBC 1.1.0 error structure +TEST_F(SqliteFlightSqlStatementTest, NonexistentTable) { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "SELECT * FROM tabledoesnotexist", &error), + IsOkStatus(&error)); + + for (auto vendor_code : {0, ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA}) { + error.vendor_code = vendor_code; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(&error)); + error.release(&error); + } +} + +TEST_F(SqliteFlightSqlStatementTest, CancelError) { + // Ensure cancellation propagates properly through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, 
&error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT 1", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement.value, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); +} + +TEST_F(SqliteFlightSqlStatementTest, RpcError) { + // Ensure errors that happen at the start of the stream propagate properly + // through the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + + int count = AdbcErrorGetDetailCount(&error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("afsql-sqlite-query", detail.key); + EXPECT_EQ("SELECT", std::string_view(reinterpret_cast(detail.value), + detail.value_length)); + } +} + 
+TEST_F(SqliteFlightSqlStatementTest, StreamError) { + // Ensure errors that happen during the stream propagate properly through + // the Go FFI boundary + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + R"( +DROP TABLE IF EXISTS foo; +CREATE TABLE foo (a INT); +WITH RECURSIVE sequence(x) AS + (SELECT 1 UNION ALL SELECT x+1 FROM sequence LIMIT 1024) +INSERT INTO foo(a) +SELECT x FROM sequence; +INSERT INTO foo(a) VALUES ('foo');)", + &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM foo", &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, + &reader.rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_NE(0, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* adbc_error = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, adbc_error); + ASSERT_EQ(ADBC_STATUS_UNKNOWN, status); + + int count = AdbcErrorGetDetailCount(adbc_error); + ASSERT_NE(0, count); + for (int i = 0; i < count; i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(adbc_error, i); + ASSERT_NE(nullptr, detail.key); + ASSERT_NE(nullptr, detail.value); + ASSERT_NE(0, detail.value_length); + EXPECT_STREQ("grpc-status-details-bin", detail.key); + } +} diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver/postgresql/CMakeLists.txt index d169794..6deae29 100644 --- 
a/3rd_party/apache-arrow-adbc/c/driver/postgresql/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/CMakeLists.txt @@ -29,8 +29,10 @@ endif() add_arrow_lib(adbc_driver_postgresql SOURCES connection.cc + error.cc database.cc postgresql.cc + result_helper.cc statement.cc OUTPUTS ADBC_LIBRARIES diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/README.md b/3rd_party/apache-arrow-adbc/c/driver/postgresql/README.md index cc5a3df..8ccffb6 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/README.md +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/README.md @@ -54,9 +54,9 @@ Alternatively use the `docker compose` provided by ADBC to manage the test database container. ```shell -$ docker compose up postgres_test +$ docker compose up postgres-test # When finished: -# docker compose down postgres_test +# docker compose down postgres-test ``` Then, to run the tests, set the environment variable specifying the diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc index 611cd51..a9f7405 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -32,145 +33,22 @@ #include "common/utils.h" #include "database.h" +#include "error.h" +#include "result_helper.h" +namespace adbcpq { namespace { static const uint32_t kSupportedInfoCodes[] = { - ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, }; static const std::unordered_map kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, 
{"partitioned_table", "p"}}; -struct PqRecord { - const char* data; - const int len; - const bool is_null; -}; - -// Used by PqResultHelper to provide index-based access to the records within each -// row of a pg_result -class PqResultRow { - public: - PqResultRow(pg_result* result, int row_num) : result_(result), row_num_(row_num) { - ncols_ = PQnfields(result); - } - - PqRecord operator[](const int& col_num) { - assert(col_num < ncols_); - const char* data = PQgetvalue(result_, row_num_, col_num); - const int len = PQgetlength(result_, row_num_, col_num); - const bool is_null = PQgetisnull(result_, row_num_, col_num); - - return PqRecord{data, len, is_null}; - } - - private: - pg_result* result_ = nullptr; - int row_num_; - int ncols_; -}; - -// Helper to manager the lifecycle of a PQResult. The query argument -// will be evaluated as part of the constructor, with the desctructor handling cleanup -// Caller must call Prepare then Execute, checking both for an OK AdbcStatusCode -// prior to iterating -class PqResultHelper { - public: - explicit PqResultHelper(PGconn* conn, std::string query, struct AdbcError* error) - : conn_(conn), query_(std::move(query)), error_(error) {} - - explicit PqResultHelper(PGconn* conn, std::string query, - std::vector param_values, struct AdbcError* error) - : conn_(conn), - query_(std::move(query)), - param_values_(param_values), - error_(error) {} - - AdbcStatusCode Prepare() { - // TODO: make stmtName a unique identifier? 
- PGresult* result = - PQprepare(conn_, /*stmtName=*/"", query_.c_str(), param_values_.size(), NULL); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error_, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn_), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - - PQclear(result); - return ADBC_STATUS_OK; - } - - AdbcStatusCode Execute() { - std::vector param_c_strs; - - for (auto index = 0; index < param_values_.size(); index++) { - param_c_strs.push_back(param_values_[index].c_str()); - } - - result_ = PQexecPrepared(conn_, "", param_values_.size(), param_c_strs.data(), NULL, - NULL, 0); - - if (PQresultStatus(result_) != PGRES_TUPLES_OK) { - SetError(error_, "[libpq] Failed to execute query: %s", PQerrorMessage(conn_)); - return ADBC_STATUS_IO; - } - - return ADBC_STATUS_OK; - } - - ~PqResultHelper() { - if (result_ != nullptr) { - PQclear(result_); - } - } - - int NumRows() { return PQntuples(result_); } - - int NumColumns() { return PQnfields(result_); } - - class iterator { - const PqResultHelper& outer_; - int curr_row_ = 0; - - public: - explicit iterator(const PqResultHelper& outer, int curr_row = 0) - : outer_(outer), curr_row_(curr_row) {} - iterator& operator++() { - curr_row_++; - return *this; - } - iterator operator++(int) { - iterator retval = *this; - ++(*this); - return retval; - } - bool operator==(iterator other) const { - return outer_.result_ == other.outer_.result_ && curr_row_ == other.curr_row_; - } - bool operator!=(iterator other) const { return !(*this == other); } - PqResultRow operator*() { return PqResultRow(outer_.result_, curr_row_); } - using iterator_category = std::forward_iterator_tag; - using difference_type = std::ptrdiff_t; - using value_type = std::vector; - using pointer = const std::vector*; - using reference = const std::vector&; - }; - - iterator begin() { return iterator(*this); } - iterator end() { return iterator(*this, NumRows()); } - - private: - pg_result* 
result_ = nullptr; - PGconn* conn_; - std::string query_; - std::vector param_values_; - struct AdbcError* error_; -}; - class PqGetObjectsHelper { public: PqGetObjectsHelper(PGconn* conn, int depth, const char* catalog, const char* db_schema, @@ -191,17 +69,6 @@ class PqGetObjectsHelper { } AdbcStatusCode GetObjects() { - PqResultHelper curr_db_helper = - PqResultHelper{conn_, std::string("SELECT current_database()"), error_}; - - RAISE_ADBC(curr_db_helper.Prepare()); - RAISE_ADBC(curr_db_helper.Execute()); - - assert(curr_db_helper.NumRows() == 1); - auto curr_iter = curr_db_helper.begin(); - PqResultRow db_row = *curr_iter; - current_db_ = std::string(db_row[0].data); - RAISE_ADBC(InitArrowArray()); catalog_name_col_ = array_->children[0]; @@ -252,8 +119,9 @@ class PqGetObjectsHelper { AdbcStatusCode AppendSchemas(std::string db_name) { // postgres only allows you to list schemas for the currently connected db - if (db_name == current_db_) { - struct StringBuilder query = {0}; + if (!strcmp(db_name.c_str(), PQdb(conn_))) { + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); if (StringBuilderInit(&query, /*initial_size*/ 256)) { return ADBC_STATUS_INTERNAL; } @@ -302,7 +170,8 @@ class PqGetObjectsHelper { } AdbcStatusCode AppendCatalogs() { - struct StringBuilder query = {0}; + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL; if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) { @@ -341,7 +210,8 @@ class PqGetObjectsHelper { } AdbcStatusCode AppendTables(std::string schema_name) { - struct StringBuilder query = {0}; + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); if (StringBuilderInit(&query, /*initial_size*/ 512)) { return ADBC_STATUS_INTERNAL; } @@ -375,8 +245,8 @@ class PqGetObjectsHelper { const char** table_types = table_types_; while (*table_types != NULL) { auto table_type_str 
= std::string(*table_types); - if (auto search = kPgTableTypes.find(table_type_str); - search != kPgTableTypes.end()) { + auto search = kPgTableTypes.find(table_type_str); + if (search != kPgTableTypes.end()) { table_type_filter.push_back(search->second); } table_types++; @@ -440,7 +310,8 @@ class PqGetObjectsHelper { } AdbcStatusCode AppendColumns(std::string schema_name, std::string table_name) { - struct StringBuilder query = {0}; + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); if (StringBuilderInit(&query, /*initial_size*/ 512)) { return ADBC_STATUS_INTERNAL; } @@ -527,7 +398,8 @@ class PqGetObjectsHelper { } AdbcStatusCode AppendConstraints(std::string schema_name, std::string table_name) { - struct StringBuilder query = {0}; + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); if (StringBuilderInit(&query, /*initial_size*/ 4096)) { return ADBC_STATUS_INTERNAL; } @@ -658,10 +530,9 @@ class PqGetObjectsHelper { const char* constraint_ftable_name = row[4].data; auto constraint_fcolumn_names = PqTextArrayToVector(std::string(row[5].data)); for (const auto& constraint_fcolumn_name : constraint_fcolumn_names) { - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(fk_catalog_col_, ArrowCharView(current_db_.c_str())), - error_); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(fk_catalog_col_, ArrowCharView(PQdb(conn_))), + error_); CHECK_NA(INTERNAL, ArrowArrayAppendString(fk_db_schema_col_, ArrowCharView(constraint_ftable_schema)), @@ -705,7 +576,6 @@ class PqGetObjectsHelper { struct ArrowArray* array_; struct AdbcError* error_; struct ArrowError na_error_; - std::string current_db_; struct ArrowArray* catalog_name_col_; struct ArrowArray* catalog_db_schemas_col_; struct ArrowArray* catalog_db_schemas_items_; @@ -733,9 +603,25 @@ class PqGetObjectsHelper { struct ArrowArray* fk_column_name_col_; }; +// A notice processor that does nothing with notices. 
In the future we can log +// these, but this suppresses the default of printing to stderr. +void SilentNoticeProcessor(void* /*arg*/, const char* /*message*/) {} + } // namespace -namespace adbcpq { +AdbcStatusCode PostgresConnection::Cancel(struct AdbcError* error) { + // > errbuf must be a char array of size errbufsize (the recommended size is + // > 256 bytes). + // https://www.postgresql.org/docs/current/libpq-cancel.html + char errbuf[256]; + // > The return value is 1 if the cancel request was successfully dispatched + // > and 0 if not. + if (PQcancel(cancel_, errbuf, sizeof(errbuf)) != 1) { + SetError(error, "[libpq] Failed to cancel operation: %s", errbuf); + return ADBC_STATUS_UNKNOWN; + } + return ADBC_STATUS_OK; +} AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { if (autocommit_) { @@ -745,19 +631,18 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { PGresult* result = PQexec(conn_, "COMMIT"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", "[libpq] Failed to commit: ", PQerrorMessage(conn_)); + AdbcStatusCode code = SetError(error, result, "%s%s", + "[libpq] Failed to commit: ", PQerrorMessage(conn_)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; } -AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, - size_t info_codes_length, - struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error) { +AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl( + const uint32_t* info_codes, size_t info_codes_length, struct ArrowSchema* schema, + struct ArrowArray* array, struct AdbcError* error) { RAISE_ADBC(AdbcInitConnectionGetInfoSchema(info_codes, info_codes_length, schema, array, error)); @@ -767,10 +652,22 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, RAISE_ADBC( AdbcConnectionGetInfoAppendString(array, info_codes[i], "PostgreSQL", error)); break; - case 
ADBC_INFO_VENDOR_VERSION: - RAISE_ADBC(AdbcConnectionGetInfoAppendString( - array, info_codes[i], std::to_string(PQlibVersion()).c_str(), error)); + case ADBC_INFO_VENDOR_VERSION: { + const char* stmt = "SHOW server_version_num"; + auto result_helper = PqResultHelper{conn_, std::string(stmt), error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); + return ADBC_STATUS_INTERNAL; + } + const char* server_version_num = (*it)[0].data; + + RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], + server_version_num, error)); break; + } case ADBC_INFO_DRIVER_NAME: RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], "ADBC PostgreSQL Driver", error)); @@ -784,6 +681,10 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], NANOARROW_VERSION, error)); break; + case ADBC_INFO_DRIVER_ADBC_VERSION: + RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i], + ADBC_VERSION_1_1_0, error)); + break; default: // Ignore continue; @@ -799,21 +700,22 @@ AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, } AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { - // XXX: mistake in adbc.h (should have been const pointer) - const uint32_t* codes = info_codes; if (!info_codes) { - codes = kSupportedInfoCodes; + info_codes = kSupportedInfoCodes; info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); } - struct ArrowSchema schema = {0}; - struct ArrowArray array = {0}; + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + 
std::memset(&array, 0, sizeof(array)); - AdbcStatusCode status = - PostgresConnectionGetInfoImpl(codes, info_codes_length, &schema, &array, error); + AdbcStatusCode status = PostgresConnectionGetInfoImpl(info_codes, info_codes_length, + &schema, &array, error); if (status != ADBC_STATUS_OK) { if (schema.release) schema.release(&schema); if (array.release) array.release(&array); @@ -827,8 +729,10 @@ AdbcStatusCode PostgresConnection::GetObjects( struct AdbcConnection* connection, int depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_types, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema = {0}; - struct ArrowArray array = {0}; + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); PqGetObjectsHelper helper = PqGetObjectsHelper(conn_, depth, catalog, db_schema, table_name, table_types, @@ -844,13 +748,407 @@ AdbcStatusCode PostgresConnection::GetObjects( return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + std::string output; + if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { + output = PQdb(conn_); + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + return ADBC_STATUS_INTERNAL; + } + output = (*it)[0].data; + } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { + output = autocommit_ ? 
ADBC_OPTION_VALUE_ENABLED : ADBC_OPTION_VALUE_DISABLED; + } else { + return ADBC_STATUS_NOT_FOUND; + } + + if (output.size() + 1 <= *length) { + std::memcpy(value, output.c_str(), output.size() + 1); + } + *length = output.size() + 1; + return ADBC_STATUS_OK; +} +AdbcStatusCode PostgresConnection::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresConnection::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + // Set up schema + auto uschema = nanoarrow::UniqueSchema(); + { + ArrowSchemaInit(uschema.get()); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), /*num_columns=*/2), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[0], "catalog_name"), error); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema->children[1], "catalog_db_schemas"), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema->children[1]->children[0], 2), + error); + uschema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* db_schema_schema = uschema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name"), error); + CHECK_NA(INTERNAL, + 
ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_statistics"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 5), + error); + db_schema_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* statistics_schema = db_schema_schema->children[1]->children[0]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[0], "table_name"), + error); + statistics_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[1], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(statistics_schema->children[1], "column_name"), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[2], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[2], "statistic_key"), error); + statistics_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetTypeUnion(statistics_schema->children[3], + NANOARROW_TYPE_DENSE_UNION, 4), + error); + CHECK_NA(INTERNAL, + ArrowSchemaSetName(statistics_schema->children[3], "statistic_value"), + error); + statistics_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(statistics_schema->children[4], NANOARROW_TYPE_BOOL), + error); + CHECK_NA( + INTERNAL, + ArrowSchemaSetName(statistics_schema->children[4], "statistic_is_approximate"), + error); + statistics_schema->children[4]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* value_schema = statistics_schema->children[3]; + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[0], NANOARROW_TYPE_INT64), error); + CHECK_NA(INTERNAL, 
ArrowSchemaSetName(value_schema->children[0], "int64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[1], NANOARROW_TYPE_UINT64), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[1], "uint64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[2], NANOARROW_TYPE_DOUBLE), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[2], "float64"), error); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(value_schema->children[3], NANOARROW_TYPE_BINARY), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(value_schema->children[3], "binary"), error); + } + + // Set up builders + struct ArrowError na_error = {0}; + CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), &na_error), + &na_error, error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + + struct ArrowArray* catalog_name_col = array->children[0]; + struct ArrowArray* catalog_db_schemas_col = array->children[1]; + struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0]; + struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0]; + struct ArrowArray* db_schema_statistics_col = catalog_db_schemas_items->children[1]; + struct ArrowArray* db_schema_statistics_items = db_schema_statistics_col->children[0]; + struct ArrowArray* statistics_table_name_col = db_schema_statistics_items->children[0]; + struct ArrowArray* statistics_column_name_col = db_schema_statistics_items->children[1]; + struct ArrowArray* statistics_key_col = db_schema_statistics_items->children[2]; + struct ArrowArray* statistics_value_col = db_schema_statistics_items->children[3]; + struct ArrowArray* statistics_is_approximate_col = + db_schema_statistics_items->children[4]; + // struct ArrowArray* value_int64_col = statistics_value_col->children[0]; + // struct ArrowArray* value_uint64_col = statistics_value_col->children[1]; + struct ArrowArray* value_float64_col = 
statistics_value_col->children[2]; + // struct ArrowArray* value_binary_col = statistics_value_col->children[3]; + + // Query (could probably be massively improved) + std::string query = R"( + WITH + class AS ( + SELECT nspname, relname, reltuples + FROM pg_namespace + INNER JOIN pg_class ON pg_class.relnamespace = pg_namespace.oid + ) + SELECT tablename, attname, null_frac, avg_width, n_distinct, reltuples + FROM pg_stats + INNER JOIN class ON pg_stats.schemaname = class.nspname AND pg_stats.tablename = class.relname + WHERE pg_stats.schemaname = $1 AND tablename LIKE $2 + ORDER BY tablename +)"; + + CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(PQdb(conn))), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendString(db_schema_name_col, ArrowCharView(db_schema)), + error); + + constexpr int8_t kStatsVariantFloat64 = 2; + + std::string prev_table; + + { + PqResultHelper result_helper{ + conn, query, {db_schema, table_name ? table_name : "%"}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + + for (PqResultRow row : result_helper) { + auto reltuples = row[5].ParseDouble(); + if (!reltuples.first) { + SetError(error, "[libpq] Invalid double value in reltuples: '%s'", row[5].data); + return ADBC_STATUS_INTERNAL; + } + + if (std::strcmp(prev_table.c_str(), row[0].data) != 0) { + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendNull(statistics_column_name_col, 1), error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_ROW_COUNT_KEY), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendDouble(value_float64_col, reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, 
ArrowArrayFinishElement(db_schema_statistics_items), error); + prev_table = std::string(row[0].data, row[0].len); + } + + auto null_frac = row[2].ParseDouble(); + if (!null_frac.first) { + SetError(error, "[libpq] Invalid double value in null_frac: '%s'", row[2].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_NULL_COUNT_KEY), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, null_frac.second * reltuples.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto average_byte_width = row[3].ParseDouble(); + if (!average_byte_width.first) { + SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[3].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA( + INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendDouble(value_float64_col, average_byte_width.second), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, 
ArrowArrayFinishElement(db_schema_statistics_items), error); + + auto n_distinct = row[4].ParseDouble(); + if (!n_distinct.first) { + SetError(error, "[libpq] Invalid double value in avg_width: '%s'", row[4].data); + return ADBC_STATUS_INTERNAL; + } + + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_table_name_col, + ArrowStringView{row[0].data, row[0].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendString(statistics_column_name_col, + ArrowStringView{row[1].data, row[1].len}), + error); + CHECK_NA(INTERNAL, + ArrowArrayAppendInt(statistics_key_col, ADBC_STATISTIC_DISTINCT_COUNT_KEY), + error); + // > If greater than zero, the estimated number of distinct values in + // > the column. If less than zero, the negative of the number of + // > distinct values divided by the number of rows. + // https://www.postgresql.org/docs/current/view-pg-stats.html + CHECK_NA( + INTERNAL, + ArrowArrayAppendDouble(value_float64_col, + n_distinct.second > 0 + ? n_distinct.second + : (std::fabs(n_distinct.second) * reltuples.second)), + error); + CHECK_NA(INTERNAL, + ArrowArrayFinishUnionElement(statistics_value_col, kStatsVariantFloat64), + error); + CHECK_NA(INTERNAL, ArrowArrayAppendInt(statistics_is_approximate_col, 1), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_items), error); + } + } + + CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_statistics_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col), error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); + + CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error, + error); + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, + const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + // 
Simplify our jobs here + if (!approximate) { + SetError(error, "[libpq] Exact statistics are not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (!db_schema) { + SetError(error, "[libpq] Must request statistics for a single schema"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else if (catalog && std::strcmp(catalog, PQdb(conn_)) != 0) { + SetError(error, "[libpq] Can only request statistics for current catalog"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticsImpl( + conn_, db_schema, table_name, &schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + + return BatchToArrayStream(&array, &schema, out, error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error) { + auto uschema = nanoarrow::UniqueSchema(); + ArrowSchemaInit(uschema.get()); + + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); + CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/2), + error); + + ArrowSchemaInit(uschema.get()->children[0]); + CHECK_NA(INTERNAL, + ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "statistic_name"), + error); + uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + ArrowSchemaInit(uschema.get()->children[1]); + CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get()->children[1], NANOARROW_TYPE_INT16), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[1], "statistic_key"), + error); + uschema.get()->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + + CHECK_NA(INTERNAL, 
ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); + CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); + + uschema.move(schema); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* out, + struct AdbcError* error) { + // We don't support any extended statistics, just return an empty stream + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); + + AdbcStatusCode status = PostgresConnectionGetStatisticNamesImpl(&schema, &array, error); + if (status != ADBC_STATUS_OK) { + if (schema.release) schema.release(&schema); + if (array.release) array.release(&array); + return status; + } + return BatchToArrayStream(&array, &schema, out, error); + + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { AdbcStatusCode final_status = ADBC_STATUS_OK; - struct StringBuilder query = {0}; + struct StringBuilder query; + std::memset(&query, 0, sizeof(query)); std::vector params; if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL; @@ -883,7 +1181,14 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, StringBuilderReset(&query); RAISE_ADBC(result_helper.Prepare()); - RAISE_ADBC(result_helper.Execute()); + auto result = result_helper.Execute(); + if (result != ADBC_STATUS_OK) { + auto error_code = std::string(error->sqlstate, 5); + if ((error_code == "42P01") || (error_code == "42602")) { + return ADBC_STATUS_NOT_FOUND; + } + return result; + } auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); @@ -950,8 +1255,10 @@ AdbcStatusCode PostgresConnectionGetTableTypesImpl(struct ArrowSchema* schema, AdbcStatusCode 
PostgresConnection::GetTableTypes(struct AdbcConnection* connection, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema = {0}; - struct ArrowArray array = {0}; + struct ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + struct ArrowArray array; + std::memset(&array, 0, sizeof(array)); AdbcStatusCode status = PostgresConnectionGetTableTypesImpl(&schema, &array, error); if (status != ADBC_STATUS_OK) { @@ -965,16 +1272,31 @@ AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connecti AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, struct AdbcError* error) { if (!database || !database->private_data) { - SetError(error, "%s", "[libpq] Must provide an initialized AdbcDatabase"); + SetError(error, "[libpq] Must provide an initialized AdbcDatabase"); return ADBC_STATUS_INVALID_ARGUMENT; } database_ = *reinterpret_cast*>(database->private_data); type_resolver_ = database_->type_resolver(); - return database_->Connect(&conn_, error); + + RAISE_ADBC(database_->Connect(&conn_, error)); + + cancel_ = PQgetCancel(conn_); + if (!cancel_) { + SetError(error, "[libpq] Could not initialize PGcancel"); + return ADBC_STATUS_UNKNOWN; + } + + std::ignore = PQsetNoticeProcessor(conn_, SilentNoticeProcessor, nullptr); + + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Release(struct AdbcError* error) { + if (cancel_) { + PQfreeCancel(cancel_); + cancel_ = nullptr; + } if (conn_) { return database_->Disconnect(&conn_, error); } @@ -1024,8 +1346,35 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, autocommit_ = autocommit; } return ADBC_STATUS_OK; + } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { + // PostgreSQL doesn't accept a parameter here + PqResultHelper result_helper{ + conn_, std::string("SET search_path TO ") + value, {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + return 
ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } + +AdbcStatusCode PostgresConnection::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h index 99770c2..5e45e90 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h @@ -29,10 +29,12 @@ namespace adbcpq { class PostgresDatabase; class PostgresConnection { public: - PostgresConnection() : database_(nullptr), conn_(nullptr), autocommit_(true) {} + PostgresConnection() + : database_(nullptr), conn_(nullptr), cancel_(nullptr), autocommit_(true) {} + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode Commit(struct AdbcError* error); - AdbcStatusCode GetInfo(struct AdbcConnection* connection, uint32_t* info_codes, + AdbcStatusCode GetInfo(struct AdbcConnection* connection, const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetObjects(struct AdbcConnection* connection, int depth, @@ -40,6 +42,18 @@ class PostgresConnection { const char* table_name, const char** table_types, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode 
GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); + AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, bool approximate, + struct ArrowArrayStream* out, struct AdbcError* error); + AdbcStatusCode GetStatisticNames(struct ArrowArrayStream* out, struct AdbcError* error); AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error); @@ -49,16 +63,27 @@ class PostgresConnection { AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode Rollback(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); PGconn* conn() const { return conn_; } const std::shared_ptr& type_resolver() const { return type_resolver_; } + bool autocommit() const { return autocommit_; } private: + AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, + size_t info_codes_length, + struct ArrowSchema* schema, + struct ArrowArray* array, + struct AdbcError* error); std::shared_ptr database_; std::shared_ptr type_resolver_; PGconn* conn_; + PGcancel* cancel_; bool autocommit_; }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc index 3976c4b..5de8628 100644 --- 
a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc @@ -36,6 +36,23 @@ PostgresDatabase::PostgresDatabase() : open_connections_(0) { } PostgresDatabase::~PostgresDatabase() = default; +AdbcStatusCode PostgresDatabase::GetOption(const char* option, char* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionBytes(const char* option, uint8_t* value, + size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} +AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { // Connect to validate the parameters. return RebuildTypeResolver(error); @@ -61,6 +78,24 @@ AdbcStatusCode PostgresDatabase::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresDatabase::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresDatabase::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode PostgresDatabase::Connect(PGconn** conn, struct AdbcError* error) { if (uri_.empty()) { SetError(error, "%s", @@ -90,10 +125,10 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** 
conn, struct AdbcError* err // Helpers for building the type resolver from queries static inline int32_t InsertPgAttributeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver); + PGresult* result, const std::shared_ptr& resolver); AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { PGconn* conn = nullptr; @@ -142,7 +177,7 @@ ORDER BY auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - pg_result* result = PQexec(conn, kColumnsQuery.c_str()); + PGresult* result = PQexec(conn, kColumnsQuery.c_str()); ExecStatusType pq_status = PQresultStatus(result); if (pq_status == PGRES_TUPLES_OK) { InsertPgAttributeResult(result, resolver); @@ -187,7 +222,7 @@ ORDER BY } static inline int32_t InsertPgAttributeResult( - pg_result* result, const std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); std::vector> columns; uint32_t current_type_oid = 0; @@ -219,7 +254,7 @@ static inline int32_t InsertPgAttributeResult( } static inline int32_t InsertPgTypeResult( - pg_result* result, const std::shared_ptr& resolver) { + PGresult* result, const std::shared_ptr& resolver) { int num_rows = PQntuples(result); PostgresTypeResolver::Item item; int32_t n_added = 0; diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h index f104647..6c3da58 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h @@ -36,7 +36,19 @@ class PostgresDatabase { AdbcStatusCode Init(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); + AdbcStatusCode GetOption(const char* option, char* value, size_t* length, + struct AdbcError* error); + 
AdbcStatusCode GetOptionBytes(const char* option, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* option, double* value, + struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* option, int64_t* value, + struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); // Internal implementation @@ -54,3 +66,10 @@ class PostgresDatabase { std::shared_ptr type_resolver_; }; } // namespace adbcpq + +extern "C" { +/// For applications that want to use the driver struct directly, this gives +/// them access to the Init routine. +ADBC_EXPORT +AdbcStatusCode PostgresqlDriverInit(int, void*, struct AdbcError*); +} diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc new file mode 100644 index 0000000..ed93d17 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "error.h" + +#include +#include +#include +#include +#include + +#include + +#include "common/utils.h" + +namespace adbcpq { + +namespace { +struct DetailField { + int code; + std::string key; +}; + +static const std::vector kDetailFields = { + {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, + {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, + {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, + {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, + {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, + {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, + {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, + {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, + {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, + {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, + {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, + {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, + {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, + {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, +}; +} // namespace + +AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) { + va_list args; + va_start(args, format); + SetErrorVariadic(error, format, args); + va_end(args); + + AdbcStatusCode code = ADBC_STATUS_IO; + + const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); + if (sqlstate) { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + // This can be extended in the future + if (std::strcmp(sqlstate, "57014") == 0) { + code = ADBC_STATUS_CANCELLED; + } else if (std::strcmp(sqlstate, "42P01") == 0 || + std::strcmp(sqlstate, "42602") == 0) { + code = ADBC_STATUS_NOT_FOUND; + } else if (std::strncmp(sqlstate, "42", 0) == 0) { + // Class 42 — Syntax Error or Access Rule Violation + code = ADBC_STATUS_INVALID_ARGUMENT; + } + + static_assert(sizeof(error->sqlstate) == 5, ""); + // N.B. 
strncpy generates warnings when used for this purpose + int i = 0; + for (; sqlstate[i] != '\0' && i < 5; i++) { + error->sqlstate[i] = sqlstate[i]; + } + for (; i < 5; i++) { + error->sqlstate[i] = '\0'; + } + } + + for (const auto& field : kDetailFields) { + const char* value = PQresultErrorField(result, field.code); + if (value) { + AppendErrorDetail(error, field.key.c_str(), reinterpret_cast(value), + std::strlen(value)); + } + } + return code; +} + +} // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h new file mode 100644 index 0000000..75c52b4 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Error handling utilities. 
+ +#pragma once + +#include +#include + +namespace adbcpq { + +// The printf checking attribute doesn't work properly on gcc 4.8 +// and results in spurious compiler warnings +#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) __attribute__((format(printf, x, y))) +#else +#define ADBC_CHECK_PRINTF_ATTRIBUTE(x, y) +#endif + +/// \brief Set an error based on a PGresult, inferring the proper ADBC status +/// code from the PGresult. +AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, + ...) ADBC_CHECK_PRINTF_ATTRIBUTE(3, 4); + +#undef ADBC_CHECK_PRINTF_ATTRIBUTE + +} // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader.h index 78358a9..5a58970 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader.h @@ -17,7 +17,9 @@ #pragma once +#include #include +#include #include #include #include @@ -29,6 +31,11 @@ #include "postgres_type.h" #include "postgres_util.h" +// R 3.6 / Windows builds on a very old toolchain that does not define ENODATA +#if defined(_WIN32) && !defined(MSVC) && !defined(ENODATA) +#define ENODATA 120 +#endif + namespace adbcpq { // "PGCOPY\n\377\r\n\0" @@ -36,6 +43,30 @@ static int8_t kPgCopyBinarySignature[] = {0x50, 0x47, 0x43, 0x4F, 0x50, 0x59, 0x0A, static_cast(0xFF), 0x0D, 0x0A, 0x00}; +// The maximum value in seconds that can be converted into microseconds +// without overflow +constexpr int64_t kMaxSafeSecondsToMicros = 9223372036854L; + +// The minimum value in seconds that can be converted into microseconds +// without overflow +constexpr int64_t kMinSafeSecondsToMicros = -9223372036854L; + +// The maximum value in milliseconds that can be converted into microseconds +// without overflow +constexpr int64_t kMaxSafeMillisToMicros = 
9223372036854775L; + +// The minimum value in milliseconds that can be converted into microseconds +// without overflow +constexpr int64_t kMinSafeMillisToMicros = -9223372036854775L; + +// The maximum value in microseconds that can be converted into nanoseconds +// without overflow +constexpr int64_t kMaxSafeMicrosToNanos = 9223372036854775L; + +// The minimum value in microseconds that can be converted into nanoseconds +// without overflow +constexpr int64_t kMinSafeMicrosToNanos = -9223372036854775L; + // Read a value from the buffer without checking the buffer size. Advances // the cursor of data and reduces its size by sizeof(T). template @@ -206,6 +237,212 @@ class PostgresCopyNetworkEndianFieldReader : public PostgresCopyFieldReader { } }; +// Reader for Intervals +class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader { + public: + ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array, + ArrowError* error) override { + if (field_size_bytes <= 0) { + return ArrowArrayAppendNull(array, 1); + } + + if (field_size_bytes != 16) { + ArrowErrorSet(error, "Expected field with %d bytes but found field with %d bytes", + 16, + static_cast(field_size_bytes)); // NOLINT(runtime/int) + return EINVAL; + } + + // postgres stores time as usec, arrow stores as ns + const int64_t time_usec = ReadUnsafe(data); + int64_t time; + + if (time_usec > kMaxSafeMicrosToNanos || time_usec < kMinSafeMicrosToNanos) { + ArrowErrorSet(error, + "[libpq] Interval with time value %" PRId64 + " usec would overflow when converting to nanoseconds", + time_usec); + return EINVAL; + } + + time = time_usec * 1000; + + const int32_t days = ReadUnsafe(data); + const int32_t months = ReadUnsafe(data); + + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &months, sizeof(int32_t))); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &days, sizeof(int32_t))); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &time, sizeof(int64_t))); + return 
AppendValid(array); + } +}; + +// // Converts COPY resulting from the Postgres NUMERIC type into a string. +// Rewritten based on the Postgres implementation of NUMERIC cast to string in +// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source, +// DEC_DIGITS is always 4 and DBASE is always 10000). +// +// Briefly, the Postgres representation of "numeric" is an array of int16_t ("digits") +// from most significant to least significant. Each "digit" is a value between 0000 and +// 9999. There are weight + 1 digits before the decimal point and dscale digits after the +// decimal point. Both of those values can be zero or negative. A "sign" component +// encodes the positive or negativeness of the value and is also used to encode special +// values (inf, -inf, and nan). +class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { + public: + ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array, + ArrowError* error) override { + // -1 for NULL + if (field_size_bytes < 0) { + return ArrowArrayAppendNull(array, 1); + } + + // Read the input + if (data->size_bytes < static_cast(4 * sizeof(int16_t))) { + ArrowErrorSet(error, + "Expected at least %d bytes of field data for numeric copy data but " + "only %d bytes of input remain", + static_cast(4 * sizeof(int16_t)), + static_cast(data->size_bytes)); // NOLINT(runtime/int) + return EINVAL; + } + + int16_t ndigits = ReadUnsafe(data); + int16_t weight = ReadUnsafe(data); + uint16_t sign = ReadUnsafe(data); + uint16_t dscale = ReadUnsafe(data); + + if (data->size_bytes < static_cast(ndigits * sizeof(int16_t))) { + ArrowErrorSet(error, + "Expected at least %d bytes of field data for numeric digits copy " + "data but only %d bytes of input remain", + static_cast(ndigits * sizeof(int16_t)), + static_cast(data->size_bytes)); // NOLINT(runtime/int) + return EINVAL; + } + + digits_.clear(); + for (int16_t i = 0; i < ndigits; i++) { + 
digits_.push_back(ReadUnsafe(data)); + } + + // Handle special values + std::string special_value; + switch (sign) { + case kNumericNAN: + special_value = std::string("nan"); + break; + case kNumericPinf: + special_value = std::string("inf"); + break; + case kNumericNinf: + special_value = std::string("-inf"); + break; + case kNumericPos: + case kNumericNeg: + special_value = std::string(""); + break; + default: + ArrowErrorSet(error, + "Unexpected value for sign read from Postgres numeric field: %d", + static_cast(sign)); + return EINVAL; + } + + if (!special_value.empty()) { + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_, special_value.data(), special_value.size())); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes)); + return AppendValid(array); + } + + // Calculate string space requirement + int64_t max_chars_required = std::max(1, (weight + 1) * kDecDigits); + max_chars_required += dscale + kDecDigits + 2; + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(data_, max_chars_required)); + char* out0 = reinterpret_cast(data_->data + data_->size_bytes); + char* out = out0; + + // Build output string in-place, starting with the negative sign + if (sign == kNumericNeg) { + *out++ = '-'; + } + + // ...then digits before the decimal point + int d; + int d1; + int16_t dig; + + if (weight < 0) { + d = weight + 1; + *out++ = '0'; + } else { + for (d = 0; d <= weight; d++) { + if (d < ndigits) { + dig = digits_[d]; + } else { + dig = 0; + } + + // To strip leading zeroes + int append = (d > 0); + + for (const auto pow10 : {1000, 100, 10, 1}) { + d1 = dig / pow10; + dig -= d1 * pow10; + append |= (d1 > 0); + if (append) { + *out++ = d1 + '0'; + } + } + } + } + + // ...then the decimal point + digits after it. This may write more digits + // than specified by dscale so we need to keep track of how many we want to + // keep here. 
+ int64_t actual_chars_required = out - out0; + + if (dscale > 0) { + *out++ = '.'; + actual_chars_required += dscale + 1; + + for (int i = 0; i < dscale; i++, d++, i += kDecDigits) { + if (d >= 0 && d < ndigits) { + dig = digits_[d]; + } else { + dig = 0; + } + + for (const auto pow10 : {1000, 100, 10, 1}) { + d1 = dig / pow10; + dig -= d1 * pow10; + *out++ = d1 + '0'; + } + } + } + + // Update data buffer size and add offsets + data_->size_bytes += actual_chars_required; + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes)); + return AppendValid(array); + } + + private: + std::vector digits_; + + // Number of decimal digits per Postgres digit + static const int kDecDigits = 4; + // The "base" of the Postgres representation (i.e., each "digit" is 0 to 9999) + static const int kNBase = 10000; + // Valid values for the sign component + static const uint16_t kNumericPos = 0x0000; + static const uint16_t kNumericNeg = 0x4000; + static const uint16_t kNumericNAN = 0xC000; + static const uint16_t kNumericPinf = 0xD000; + static const uint16_t kNumericNinf = 0xF000; +}; + // Reader for Pg->Arrow conversions whose Arrow representation is simply the // bytes of the field representation. This can be used with binary and string // Arrow types and any Postgres type. 
@@ -569,6 +806,9 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, case PostgresTypeId::kName: *out = new PostgresCopyBinaryFieldReader(); return NANOARROW_OK; + case PostgresTypeId::kNumeric: + *out = new PostgresCopyNumericFieldReader(); + return NANOARROW_OK; default: return ErrorCantConvert(error, pg_type, schema_view); } @@ -661,6 +901,16 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, default: return ErrorCantConvert(error, pg_type, schema_view); } + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + switch (pg_type.type_id()) { + case PostgresTypeId::kInterval: { + *out = new PostgresCopyIntervalFieldReader(); + return NANOARROW_OK; + } + default: + return ErrorCantConvert(error, pg_type, schema_view); + } + default: return ErrorCantConvert(error, pg_type, schema_view); } @@ -668,15 +918,19 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type, class PostgresCopyStreamReader { public: - ArrowErrorCode Init(const PostgresType& pg_type) { + ArrowErrorCode Init(PostgresType pg_type) { if (pg_type.type_id() != PostgresTypeId::kRecord) { return EINVAL; } - root_reader_.Init(pg_type); + pg_type_ = std::move(pg_type); + root_reader_.Init(pg_type_); + array_size_approx_bytes_ = 0; return NANOARROW_OK; } + int64_t array_size_approx_bytes() const { return array_size_approx_bytes_; } + ArrowErrorCode SetOutputSchema(ArrowSchema* schema, ArrowError* error) { if (std::string(schema_->format) != "+s") { ArrowErrorSet( @@ -771,9 +1025,12 @@ class PostgresCopyStreamReader { ArrowArrayInitFromSchema(array_.get(), schema_.get(), error)); NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(array_.get())); NANOARROW_RETURN_NOT_OK(root_reader_.InitArray(array_.get())); + array_size_approx_bytes_ = 0; } + const uint8_t* start = data->data.as_uint8; NANOARROW_RETURN_NOT_OK(root_reader_.Read(data, -1, array_.get(), error)); + array_size_approx_bytes_ += (data->data.as_uint8 - start); return 
NANOARROW_OK; } @@ -791,10 +1048,14 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } + const PostgresType& pg_type() const { return pg_type_; } + private: + PostgresType pg_type_; PostgresCopyFieldTupleReader root_reader_; nanoarrow::UniqueSchema schema_; nanoarrow::UniqueArray array_; + int64_t array_size_approx_bytes_; }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader_test.cc index 26a1ab3..44ad060 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_copy_reader_test.cc @@ -334,6 +334,88 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadDoublePrecision) { ASSERT_EQ(data_buffer[4], 0); } +// For full coverage, ensure that this contains NUMERIC examples that: +// - Have >= four zeroes to the left of the decimal point +// - Have >= four zeroes to the right of the decimal point +// - Include special values (nan, -inf, inf, NULL) +// - Have >= four trailing zeroes to the right of the decimal point +// - Have >= four leading zeroes before the first digit to the right of the decimal point +// - Is < 0 (negative) +// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (1000000), ('0.00001234'), +// ('1.0000'), (-123.456), (123.456), ('nan'), ('-inf'), ('inf'), (NULL)) AS drvd(col)) TO +// STDOUT WITH (FORMAT binary); +static uint8_t kTestPgCopyNumeric[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, + 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x40, 
0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, + 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0xf0, 0x00, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x00, 0xd0, 0x00, 0x00, 0x20, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyNumeric; + data.size_bytes = sizeof(kTestPgCopyNumeric); + + auto col_type = PostgresType(PostgresTypeId::kNumeric); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data), ENODATA); + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyNumeric, sizeof(kTestPgCopyNumeric)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + ASSERT_EQ(tester.GetArray(array.get()), NANOARROW_OK); + ASSERT_EQ(array->length, 9); + ASSERT_EQ(array->n_children, 1); + + nanoarrow::UniqueSchema schema; + tester.GetSchema(schema.get()); + + nanoarrow::UniqueArrayView array_view; + ASSERT_EQ(ArrowArrayViewInitFromSchema(array_view.get(), schema.get(), nullptr), + NANOARROW_OK); + ASSERT_EQ(array_view->children[0]->storage_type, NANOARROW_TYPE_STRING); + ASSERT_EQ(ArrowArrayViewSetArray(array_view.get(), array.get(), nullptr), NANOARROW_OK); + + auto validity = array_view->children[0]->buffer_views[0].data.as_uint8; + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_TRUE(ArrowBitGet(validity, 2)); + ASSERT_TRUE(ArrowBitGet(validity, 3)); + ASSERT_TRUE(ArrowBitGet(validity, 4)); + ASSERT_TRUE(ArrowBitGet(validity, 5)); + ASSERT_TRUE(ArrowBitGet(validity, 6)); + ASSERT_TRUE(ArrowBitGet(validity, 7)); + 
ASSERT_FALSE(ArrowBitGet(validity, 8)); + + struct ArrowStringView item; + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 0); + EXPECT_EQ(std::string(item.data, item.size_bytes), "1000000"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 1); + EXPECT_EQ(std::string(item.data, item.size_bytes), "0.00001234"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 2); + EXPECT_EQ(std::string(item.data, item.size_bytes), "1.0000"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 3); + EXPECT_EQ(std::string(item.data, item.size_bytes), "-123.456"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 4); + EXPECT_EQ(std::string(item.data, item.size_bytes), "123.456"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 5); + EXPECT_EQ(std::string(item.data, item.size_bytes), "nan"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 6); + EXPECT_EQ(std::string(item.data, item.size_bytes), "-inf"); + item = ArrowArrayViewGetStringUnsafe(array_view->children[0], 7); + EXPECT_EQ(std::string(item.data, item.size_bytes), "inf"); +} + // COPY (SELECT CAST("col" AS TEXT) AS "col" FROM ( VALUES ('abc'), ('1234'), // (NULL::text)) AS drvd("col")) TO STDOUT WITH (FORMAT binary); static uint8_t kTestPgCopyText[] = { diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h index e234e36..1dfcbe2 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h @@ -214,6 +214,11 @@ class PostgresType { NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_DOUBLE)); break; + // ---- Numeric/Decimal------------------- + case PostgresTypeId::kNumeric: + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); + break; + // ---- Binary/string -------------------- case PostgresTypeId::kChar: case 
PostgresTypeId::kBpchar: @@ -253,6 +258,11 @@ class PostgresType { NANOARROW_TIME_UNIT_MICRO, /*timezone=*/"UTC")); break; + case PostgresTypeId::kInterval: + NANOARROW_RETURN_NOT_OK( + ArrowSchemaSetType(schema, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO)); + break; + // ---- Nested -------------------- case PostgresTypeId::kRecord: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children())); diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc index 29fd04c..2e25c4b 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc @@ -34,7 +34,7 @@ using adbcpq::PostgresStatement; // --------------------------------------------------------------------- // ADBC interface implementation - as private functions so that these // don't get replaced by the dynamic linker. If we implemented these -// under the Adbc* names, then DriverInit, the linker may resolve +// under the Adbc* names, then in DriverInit, the linker may resolve // functions to the address of the functions provided by the driver // manager instead of our functions. // @@ -47,6 +47,30 @@ using adbcpq::PostgresStatement; // // So in the end some manual effort here was chosen. 
+// --------------------------------------------------------------------- +// AdbcError + +namespace { +const struct AdbcError* PostgresErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + // Currently only valid for TupleReader + return adbcpq::TupleReader::ErrorFromArrayStream(stream, status); +} +} // namespace + +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + return CommonErrorGetDetailCount(error); +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + return CommonErrorGetDetail(error, index); +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return PostgresErrorFromArrayStream(stream, status); +} + // --------------------------------------------------------------------- // AdbcDatabase @@ -83,14 +107,92 @@ AdbcStatusCode PostgresDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode PostgresDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode 
PostgresDatabaseGetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + AdbcStatusCode PostgresDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE; auto ptr = reinterpret_cast*>(database->private_data); return (*ptr)->SetOption(key, value, error); } + +AdbcStatusCode PostgresDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresDatabaseSetOptionInt(struct AdbcDatabase* database, + const char* key, int64_t value, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>(database->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} } // namespace +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresDatabaseGetOption(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + 
return PostgresDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return PostgresDatabaseGetOptionDouble(database, key, value, error); +} + AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return PostgresDatabaseInit(database, error); } @@ -109,10 +211,34 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return PostgresDatabaseSetOption(database, key, value, error); } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return PostgresDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return PostgresDatabaseSetOptionDouble(database, key, value, error); +} + // --------------------------------------------------------------------- // AdbcConnection namespace { +AdbcStatusCode PostgresConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_data) return 
ADBC_STATUS_INVALID_STATE; @@ -122,7 +248,8 @@ AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection, } AdbcStatusCode PostgresConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; @@ -142,6 +269,63 @@ AdbcStatusCode PostgresConnectionGetObjects( table_types, column_name, stream, error); } +AdbcStatusCode PostgresConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetOptionInt(key, value, error); +} + +AdbcStatusCode PostgresConnectionGetStatistics(struct AdbcConnection* connection, + const 
char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatistics(catalog, db_schema, table_name, approximate == 1, out, + error); +} + +AdbcStatusCode PostgresConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->GetStatisticNames(out, error); +} + AdbcStatusCode PostgresConnectionGetTableSchema( struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { @@ -213,14 +397,47 @@ AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + 
reinterpret_cast*>(connection->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + } // namespace + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return PostgresConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { return PostgresConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* stream, struct AdbcError* error) { return PostgresConnectionGetInfo(connection, info_codes, info_codes_length, stream, @@ -237,6 +454,45 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_types, column_name, stream, error); } +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return PostgresConnectionGetOptionDouble(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct 
ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatistics(connection, catalog, db_schema, table_name, + approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return PostgresConnectionGetStatisticNames(connection, out, error); +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -287,6 +543,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return PostgresConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return PostgresConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return PostgresConnectionSetOptionDouble(connection, key, value, error); +} + // --------------------------------------------------------------------- // AdbcStatement @@ -310,6 +584,14 @@ AdbcStatusCode PostgresStatementBindStream(struct AdbcStatement* statement, return (*ptr)->Bind(stream, error); } +AdbcStatusCode PostgresStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Cancel(error); +} + AdbcStatusCode PostgresStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* 
schema, struct AdbcPartitions* partitions, @@ -329,16 +611,49 @@ AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement, return (*ptr)->ExecuteQuery(output, rows_affected, error); } -AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecuteSchema(schema, error); } -AdbcStatusCode PostgresStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresStatementGetOption(struct AdbcStatement* statement, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOption(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* 
value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetOptionInt(key, value, error); } AdbcStatusCode PostgresStatementGetParameterSchema(struct AdbcStatement* statement, @@ -386,6 +701,33 @@ AdbcStatusCode PostgresStatementSetOption(struct AdbcStatement* statement, return (*ptr)->SetOption(key, value, error); } +AdbcStatusCode PostgresStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionBytes(key, value, length, error); +} + +AdbcStatusCode PostgresStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionDouble(key, value, error); +} + +AdbcStatusCode PostgresStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOptionInt(key, value, error); +} + AdbcStatusCode PostgresStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; @@ -407,6 +749,11 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return PostgresStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return PostgresStatementCancel(statement, error); +} + AdbcStatusCode 
AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -423,16 +770,32 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return PostgresStatementExecuteQuery(statement, output, rows_affected, error); } -AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement, - uint8_t* partition_desc, - struct AdbcError* error) { - return PostgresStatementGetPartitionDesc(statement, partition_desc, error); +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + ArrowSchema* schema, struct AdbcError* error) { + return PostgresStatementExecuteSchema(statement, schema, error); } -AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement, - size_t* length, - struct AdbcError* error) { - return PostgresStatementGetPartitionDescSize(statement, length, error); +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return PostgresStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return PostgresStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return PostgresStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return PostgresStatementGetOptionDouble(statement, key, value, error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -462,6 +825,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return 
PostgresStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return PostgresStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return PostgresStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return PostgresStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { return PostgresStatementSetSqlQuery(statement, query, error); @@ -469,11 +849,53 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, extern "C" { ADBC_EXPORT -AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { - if (version != ADBC_VERSION_1_0_0) return ADBC_STATUS_NOT_IMPLEMENTED; +AdbcStatusCode PostgresqlDriverInit(int version, void* raw_driver, + struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + if (!raw_driver) return ADBC_STATUS_INVALID_ARGUMENT; auto* driver = reinterpret_cast(raw_driver); - std::memset(driver, 0, sizeof(*driver)); + if (version >= ADBC_VERSION_1_1_0) { + std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); + + driver->ErrorGetDetailCount = CommonErrorGetDetailCount; + driver->ErrorGetDetail = CommonErrorGetDetail; + driver->ErrorFromArrayStream = PostgresErrorFromArrayStream; + + driver->DatabaseGetOption = PostgresDatabaseGetOption; + driver->DatabaseGetOptionBytes = PostgresDatabaseGetOptionBytes; + driver->DatabaseGetOptionDouble = 
PostgresDatabaseGetOptionDouble; + driver->DatabaseGetOptionInt = PostgresDatabaseGetOptionInt; + driver->DatabaseSetOptionBytes = PostgresDatabaseSetOptionBytes; + driver->DatabaseSetOptionDouble = PostgresDatabaseSetOptionDouble; + driver->DatabaseSetOptionInt = PostgresDatabaseSetOptionInt; + + driver->ConnectionCancel = PostgresConnectionCancel; + driver->ConnectionGetOption = PostgresConnectionGetOption; + driver->ConnectionGetOptionBytes = PostgresConnectionGetOptionBytes; + driver->ConnectionGetOptionDouble = PostgresConnectionGetOptionDouble; + driver->ConnectionGetOptionInt = PostgresConnectionGetOptionInt; + driver->ConnectionGetStatistics = PostgresConnectionGetStatistics; + driver->ConnectionGetStatisticNames = PostgresConnectionGetStatisticNames; + driver->ConnectionSetOptionBytes = PostgresConnectionSetOptionBytes; + driver->ConnectionSetOptionDouble = PostgresConnectionSetOptionDouble; + driver->ConnectionSetOptionInt = PostgresConnectionSetOptionInt; + + driver->StatementCancel = PostgresStatementCancel; + driver->StatementExecuteSchema = PostgresStatementExecuteSchema; + driver->StatementGetOption = PostgresStatementGetOption; + driver->StatementGetOptionBytes = PostgresStatementGetOptionBytes; + driver->StatementGetOptionDouble = PostgresStatementGetOptionDouble; + driver->StatementGetOptionInt = PostgresStatementGetOptionInt; + driver->StatementSetOptionBytes = PostgresStatementSetOptionBytes; + driver->StatementSetOptionDouble = PostgresStatementSetOptionDouble; + driver->StatementSetOptionInt = PostgresStatementSetOptionInt; + } else { + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + } + driver->DatabaseInit = PostgresDatabaseInit; driver->DatabaseNew = PostgresDatabaseNew; driver->DatabaseRelease = PostgresDatabaseRelease; @@ -501,6 +923,12 @@ AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* e driver->StatementRelease = PostgresStatementRelease; driver->StatementSetOption = PostgresStatementSetOption; 
driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery; + return ADBC_STATUS_OK; } + +ADBC_EXPORT +AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { + return PostgresqlDriverInit(version, raw_driver, error); +} } diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc index 429d59d..2afc4db 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include +#include #include #include #include @@ -25,11 +27,14 @@ #include #include #include -#include "common/utils.h" +#include "common/options.h" +#include "common/utils.h" +#include "database.h" #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" +using adbc_validation::Handle; using adbc_validation::IsOkStatus; using adbc_validation::IsStatus; @@ -47,44 +52,54 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name, struct AdbcError* error) const override { - struct AdbcStatement statement; - std::memset(&statement, 0, sizeof(statement)); - AdbcStatusCode status = AdbcStatementNew(connection, &statement, error); - if (status != ADBC_STATUS_OK) return status; - - std::string query = "DROP TABLE IF EXISTS " + name; - status = AdbcStatementSetSqlQuery(&statement, query.c_str(), error); - if (status != ADBC_STATUS_OK) { - std::ignore = AdbcStatementRelease(&statement, error); - return status; - } - status = AdbcStatementExecuteQuery(&statement, nullptr, nullptr, error); - std::ignore = AdbcStatementRelease(&statement, error); - return status; + Handle statement; + RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error)); + + std::string query = "DROP TABLE IF 
EXISTS \"" + name + "\""; + RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error)); + RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + return AdbcStatementRelease(&statement.value, error); + } + + AdbcStatusCode DropTempTable(struct AdbcConnection* connection, const std::string& name, + struct AdbcError* error) const override { + Handle statement; + RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error)); + + std::string query = "DROP TABLE IF EXISTS pg_temp . \"" + name + "\""; + RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error)); + RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + return AdbcStatementRelease(&statement.value, error); } AdbcStatusCode DropView(struct AdbcConnection* connection, const std::string& name, struct AdbcError* error) const override { - struct AdbcStatement statement; - std::memset(&statement, 0, sizeof(statement)); - AdbcStatusCode status = AdbcStatementNew(connection, &statement, error); - if (status != ADBC_STATUS_OK) return status; - - std::string query = "DROP VIEW IF EXISTS " + name; - status = AdbcStatementSetSqlQuery(&statement, query.c_str(), error); - if (status != ADBC_STATUS_OK) { - std::ignore = AdbcStatementRelease(&statement, error); - return status; - } - status = AdbcStatementExecuteQuery(&statement, nullptr, nullptr, error); - std::ignore = AdbcStatementRelease(&statement, error); - return status; + Handle statement; + RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error)); + + std::string query = "DROP VIEW IF EXISTS \"" + name + "\""; + RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error)); + RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + return AdbcStatementRelease(&statement.value, error); } std::string BindParameter(int index) const override { return "$" + std::to_string(index + 1); } + ArrowType 
IngestSelectRoundTripType(ArrowType ingest_type) const override { + switch (ingest_type) { + case NANOARROW_TYPE_INT8: + return NANOARROW_TYPE_INT16; + case NANOARROW_TYPE_DURATION: + return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO; + case NANOARROW_TYPE_LARGE_STRING: + return NANOARROW_TYPE_STRING; + default: + return ingest_type; + } + } + std::optional PrimaryKeyTableDdl(std::string_view name) const override { std::string ddl = "CREATE TABLE "; ddl += name; @@ -94,6 +109,30 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { std::string catalog() const override { return "postgres"; } std::string db_schema() const override { return "public"; } + + bool supports_bulk_ingest_catalog() const override { return false; } + bool supports_bulk_ingest_db_schema() const override { return true; } + bool supports_bulk_ingest_temporary() const override { return true; } + bool supports_cancel() const override { return true; } + bool supports_execute_schema() const override { return true; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_ADBC_VERSION: + return ADBC_VERSION_1_1_0; + case ADBC_INFO_DRIVER_NAME: + return "ADBC PostgreSQL Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "PostgreSQL"; + default: + return std::nullopt; + } + } + bool supports_metadata_current_catalog() const override { return true; } + bool supports_metadata_current_db_schema() const override { return true; } + bool supports_statistics() const override { return true; } }; class PostgresDatabaseTest : public ::testing::Test, @@ -108,6 +147,20 @@ class PostgresDatabaseTest : public ::testing::Test, }; ADBCV_TEST_DATABASE(PostgresDatabaseTest) +TEST_F(PostgresDatabaseTest, AdbcDriverBackwardsCompatibility) { + // XXX: sketchy cast + auto* driver = static_cast(malloc(ADBC_DRIVER_1_0_0_SIZE)); + std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); + + 
ASSERT_THAT(::PostgresqlDriverInit(ADBC_VERSION_1_0_0, driver, &error), + IsOkStatus(&error)); + + ASSERT_THAT(::PostgresqlDriverInit(424242, driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + free(driver); +} + class PostgresConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -125,10 +178,8 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { adbc_validation::StreamReader reader; std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, + ADBC_INFO_DRIVER_NAME, ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, }; ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), &reader.stream.value, &error), @@ -144,29 +195,30 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); const uint32_t code = reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + const uint32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; seen.push_back(code); - int str_child_index = 0; - struct ArrowArrayView* str_child = - reader.array_view->children[1]->children[str_child_index]; + struct ArrowArrayView* str_child = reader.array_view->children[1]->children[0]; + struct ArrowArrayView* int_child = reader.array_view->children[1]->children[2]; switch (code) { case ADBC_INFO_DRIVER_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 0); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("ADBC PostgreSQL Driver", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_DRIVER_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 1); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("(unknown)", std::string(val.data, val.size_bytes)); break; } case 
ADBC_INFO_VENDOR_NAME: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 2); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); EXPECT_EQ("PostgreSQL", std::string(val.data, val.size_bytes)); break; } case ADBC_INFO_VENDOR_VERSION: { - ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, 3); + ArrowStringView val = ArrowArrayViewGetStringUnsafe(str_child, offset); #ifdef __WIN32 const char* pater = "\\d\\d\\d\\d\\d\\d"; #else @@ -176,6 +228,10 @@ TEST_F(PostgresConnectionTest, GetInfoMetadata) { ::testing::MatchesRegex(pater)); break; } + case ADBC_INFO_DRIVER_ADBC_VERSION: { + EXPECT_EQ(ADBC_VERSION_1_1_0, ArrowArrayViewGetIntUnsafe(int_child, offset)); + break; + } default: // Ignored break; @@ -189,10 +245,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetCatalogs) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT( AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, @@ -219,10 +271,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - adbc_validation::StreamReader reader; ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_DB_SCHEMAS, nullptr, nullptr, nullptr, nullptr, nullptr, @@ -246,10 +294,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - 
ASSERT_THAT(quirks()->DropTable(&connection, "adbc_pkey_test", &error), IsOkStatus(&error)); @@ -320,10 +364,6 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test_base", &error), @@ -389,11 +429,12 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) { << "expected 1 constraint on adbc_fkey_test table, found: " << table->n_table_constraints; + const std::string version = adbc_validation::GetDriverVendorVersion(&connection); + const std::string search_name = + version < "120000" ? "adbc_fkey_test_fid1_fkey" : "adbc_fkey_test_fid1_fid2_fkey"; struct AdbcGetObjectsConstraint* constraint = AdbcGetObjectsDataGetConstraintByName( - *get_objects_data, "postgres", "public", "adbc_fkey_test", - "adbc_fkey_test_fid1_fid2_fkey"); - ASSERT_NE(constraint, nullptr) - << "could not find adbc_fkey_test_fid1_fid2_fkey constraint"; + *get_objects_data, "postgres", "public", "adbc_fkey_test", search_name.c_str()); + ASSERT_NE(constraint, nullptr) << "could not find " << search_name << " constraint"; auto constraint_type = std::string(constraint->constraint_type.data, constraint->constraint_type.size_bytes); @@ -441,10 +482,6 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - if (!quirks()->supports_get_objects()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropView(&connection, "adbc_table_types_view_test", &error), IsOkStatus(&error)); ASSERT_THAT(quirks()->DropTable(&connection, "adbc_table_types_table_test", 
&error), @@ -506,38 +543,234 @@ TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) { ASSERT_NE(view, nullptr) << "did not find view adbc_table_types_view_test"; } -TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) { - if (!quirks()->supports_bulk_ingest()) { - GTEST_SKIP(); - } +TEST_F(PostgresConnectionTest, MetadataSetCurrentDbSchema) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); - ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), - IsOkStatus(&error)); - ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "bulk_ingest", &error), + + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement.value, "CREATE SCHEMA IF NOT EXISTS testschema", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "CREATE TABLE IF NOT EXISTS testschema.schematable (ints INT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), IsOkStatus(&error)); - adbc_validation::Handle schema; - ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, - /*db_schema=*/nullptr, - "0'::int; DROP TABLE bulk_ingest;--", - &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); + // Table does not exist in this schema + error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", 
&error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); + // 42P01 = table not found + ASSERT_EQ("42P01", std::string_view(error.sqlstate, 5)); + ASSERT_NE(0, AdbcErrorGetDetailCount(&error)); + bool found = false; + for (int i = 0; i < AdbcErrorGetDetailCount(&error); i++) { + struct AdbcErrorDetail detail = AdbcErrorGetDetail(&error, i); + if (std::strcmp(detail.key, "PG_DIAG_MESSAGE_PRIMARY") == 0) { + found = true; + std::string_view message(reinterpret_cast(detail.value), + detail.value_length); + ASSERT_THAT(message, ::testing::HasSubstr("schematable")); + } + } + error.release(&error); + ASSERT_TRUE(found) << "Did not find expected error detail"; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + "testschema", &error), + IsOkStatus(&error)); ASSERT_THAT( - AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, - /*db_schema=*/"0'::int; DROP TABLE bulk_ingest;--", - "DROP TABLE bulk_ingest;", &schema.value, &error), - IsStatus(ADBC_STATUS_IO, &error)); - - ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, - /*db_schema=*/nullptr, "bulk_ingest", - &schema.value, &error), + AdbcStatementSetSqlQuery(&statement.value, "SELECT * FROM schematable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +TEST_F(PostgresConnectionTest, MetadataGetStatistics) { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + // Create sample table + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, 
&error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, + "DROP TABLE IF EXISTS statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement.value, + "CREATE TABLE statstable (ints INT, strs TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement.value, + "INSERT INTO statstable VALUES (1, 'a'), (NULL, 'bcd'), (-5, NULL)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, "ANALYZE statstable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + adbc_validation::StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetStatistics(&connection, nullptr, quirks()->db_schema().c_str(), + "statstable", 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + &reader.schema.value, { + {"catalog_name", NANOARROW_TYPE_STRING, true}, + {"catalog_db_schemas", NANOARROW_TYPE_LIST, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0], + { + {"db_schema_name", NANOARROW_TYPE_STRING, true}, + {"db_schema_statistics", NANOARROW_TYPE_LIST, false}, + })); + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( - &schema.value, {{"int64s", NANOARROW_TYPE_INT64, true}, - {"strings", NANOARROW_TYPE_STRING, true}})); + 
reader.schema->children[1]->children[0]->children[1]->children[0], + { + {"table_name", NANOARROW_TYPE_STRING, false}, + {"column_name", NANOARROW_TYPE_STRING, true}, + {"statistic_key", NANOARROW_TYPE_INT16, false}, + {"statistic_value", NANOARROW_TYPE_DENSE_UNION, false}, + {"statistic_is_approximate", NANOARROW_TYPE_BOOL, false}, + })); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( + reader.schema->children[1]->children[0]->children[1]->children[0]->children[3], + { + {"int64", NANOARROW_TYPE_INT64, true}, + {"uint64", NANOARROW_TYPE_UINT64, true}, + {"float64", NANOARROW_TYPE_DOUBLE, true}, + {"binary", NANOARROW_TYPE_BINARY, true}, + })); + + std::vector, int16_t, int64_t>> seen; + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + + for (int64_t catalog_index = 0; catalog_index < reader.array->length; + catalog_index++) { + struct ArrowStringView catalog_name = + ArrowArrayViewGetStringUnsafe(reader.array_view->children[0], catalog_index); + ASSERT_EQ(quirks()->catalog(), + std::string_view(catalog_name.data, + static_cast(catalog_name.size_bytes))); + + struct ArrowArrayView* catalog_db_schemas = reader.array_view->children[1]; + struct ArrowArrayView* schema_stats = catalog_db_schemas->children[0]->children[1]; + struct ArrowArrayView* stats = + catalog_db_schemas->children[0]->children[1]->children[0]; + for (int64_t schema_index = + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index); + schema_index < + ArrowArrayViewListChildOffset(catalog_db_schemas, catalog_index + 1); + schema_index++) { + struct ArrowStringView schema_name = ArrowArrayViewGetStringUnsafe( + catalog_db_schemas->children[0]->children[0], schema_index); + ASSERT_EQ(quirks()->db_schema(), + std::string_view(schema_name.data, + static_cast(schema_name.size_bytes))); + + for (int64_t stat_index = + ArrowArrayViewListChildOffset(schema_stats, schema_index); + stat_index < ArrowArrayViewListChildOffset(schema_stats, 
schema_index + 1); + stat_index++) { + struct ArrowStringView table_name = + ArrowArrayViewGetStringUnsafe(stats->children[0], stat_index); + ASSERT_EQ("statstable", + std::string_view(table_name.data, + static_cast(table_name.size_bytes))); + std::optional column_name; + if (!ArrowArrayViewIsNull(stats->children[1], stat_index)) { + struct ArrowStringView value = + ArrowArrayViewGetStringUnsafe(stats->children[1], stat_index); + column_name = std::string(value.data, value.size_bytes); + } + ASSERT_TRUE(ArrowArrayViewGetIntUnsafe(stats->children[4], stat_index)); + + const int16_t stat_key = static_cast( + ArrowArrayViewGetIntUnsafe(stats->children[2], stat_index)); + const int32_t offset = + stats->children[3]->buffer_views[1].data.as_int32[stat_index]; + int64_t stat_value; + switch (stat_key) { + case ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY: + case ADBC_STATISTIC_DISTINCT_COUNT_KEY: + case ADBC_STATISTIC_NULL_COUNT_KEY: + case ADBC_STATISTIC_ROW_COUNT_KEY: + stat_value = static_cast( + std::round(100 * ArrowArrayViewGetDoubleUnsafe( + stats->children[3]->children[2], offset))); + break; + default: + continue; + } + seen.emplace_back(std::move(column_name), stat_key, stat_value); + } + } + } + } + + ASSERT_THAT(seen, + ::testing::UnorderedElementsAreArray( + std::vector, int16_t, int64_t>>{ + {"ints", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 400}, + {"strs", ADBC_STATISTIC_AVERAGE_BYTE_WIDTH_KEY, 300}, + {"ints", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"strs", ADBC_STATISTIC_NULL_COUNT_KEY, 100}, + {"ints", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {"strs", ADBC_STATISTIC_DISTINCT_COUNT_KEY, 200}, + {std::nullopt, ADBC_STATISTIC_ROW_COUNT_KEY, 300}, + })); } ADBCV_TEST_CONNECTION(PostgresConnectionTest) @@ -549,7 +782,6 @@ class PostgresStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } - void TestSqlIngestInt8() { GTEST_SKIP() << "Not 
implemented"; } void TestSqlIngestUInt8() { GTEST_SKIP() << "Not implemented"; } void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; } void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; } @@ -568,10 +800,271 @@ class PostgresStatementTest : public ::testing::Test, } protected: + void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, + enum ArrowTimeUnit unit, + const char* timezone) override { + switch (type) { + case NANOARROW_TYPE_TIMESTAMP: { + std::vector> expected; + switch (unit) { + case (NANOARROW_TIME_UNIT_SECOND): + expected.insert(expected.end(), {std::nullopt, -42000000, 0, 42000000}); + break; + case (NANOARROW_TIME_UNIT_MILLI): + expected.insert(expected.end(), {std::nullopt, -42000, 0, 42000}); + break; + case (NANOARROW_TIME_UNIT_MICRO): + expected.insert(expected.end(), {std::nullopt, -42, 0, 42}); + break; + case (NANOARROW_TIME_UNIT_NANO): + expected.insert(expected.end(), {std::nullopt, 0, 0, 0}); + break; + } + ASSERT_NO_FATAL_FAILURE( + adbc_validation::CompareArray(values, expected)); + break; + } + case NANOARROW_TYPE_DURATION: { + struct ArrowInterval neg_interval; + struct ArrowInterval zero_interval; + struct ArrowInterval pos_interval; + + ArrowIntervalInit(&neg_interval, type); + ArrowIntervalInit(&zero_interval, type); + ArrowIntervalInit(&pos_interval, type); + + neg_interval.months = 0; + neg_interval.days = 0; + zero_interval.months = 0; + zero_interval.days = 0; + pos_interval.months = 0; + pos_interval.days = 0; + + switch (unit) { + case (NANOARROW_TIME_UNIT_SECOND): + neg_interval.ns = -42000000000; + zero_interval.ns = 0; + pos_interval.ns = 42000000000; + break; + case (NANOARROW_TIME_UNIT_MILLI): + neg_interval.ns = -42000000; + zero_interval.ns = 0; + pos_interval.ns = 42000000; + break; + case (NANOARROW_TIME_UNIT_MICRO): + neg_interval.ns = -42000; + zero_interval.ns = 0; + pos_interval.ns = 42000; + break; + case (NANOARROW_TIME_UNIT_NANO): + // lower than us precision is 
lost + neg_interval.ns = 0; + zero_interval.ns = 0; + pos_interval.ns = 0; + break; + } + const std::vector> expected = { + std::nullopt, &neg_interval, &zero_interval, &pos_interval}; + ASSERT_NO_FATAL_FAILURE( + adbc_validation::CompareArray(values, expected)); + break; + } + default: + FAIL() << "ValidateIngestedTemporalData not implemented for type " << type; + } + } + PostgresQuirks quirks_; }; ADBCV_TEST_STATEMENT(PostgresStatementTest) +TEST_F(PostgresStatementTest, SqlIngestTemporaryTable) { + ASSERT_THAT(quirks()->DropTempTable(&connection, "temptable", &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TEMPORARY TABLE temptable (ints BIGINT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcConnectionCommit(&connection, &error), IsOkStatus(&error)); + + { + adbc_validation::Handle schema; + adbc_validation::Handle batch; + + ArrowSchemaInit(&schema.value); + ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_INT64), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "ints"), + adbc_validation::IsOkErrno()); + + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &batch.value, static_cast(nullptr), + {-1, 0, 1, std::nullopt})), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "temptable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + 
ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error), + IsOkStatus(&error)); + // because temporary table + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); + } + + ASSERT_THAT(AdbcConnectionRollback(&connection, &error), IsOkStatus(&error)); + + { + adbc_validation::Handle schema; + adbc_validation::Handle batch; + + ArrowSchemaInit(&schema.value); + ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_INT64), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "ints"), + adbc_validation::IsOkErrno()); + + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &batch.value, static_cast(nullptr), + {-1, 0, 1, std::nullopt})), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "temptable", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + } +} + +TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + adbc_validation::Handle schema; + adbc_validation::Handle batch; + + ArrowSchemaInit(&schema.value); + ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"), + adbc_validation::IsOkErrno()); + 
ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TIME_UNIT_SECOND, nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &batch.value, static_cast(nullptr), + {std::numeric_limits::max()})), + adbc_validation::IsOkErrno()); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + ASSERT_THAT(error.message, + ::testing::HasSubstr("Row #1 has value '9223372036854775807' which " + "exceeds PostgreSQL timestamp limits")); + } + + { + adbc_validation::Handle schema; + adbc_validation::Handle batch; + + ArrowSchemaInit(&schema.value); + ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "$1"), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TIME_UNIT_SECOND, nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &batch.value, static_cast(nullptr), + {std::numeric_limits::min()})), + adbc_validation::IsOkErrno()); + + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT CAST($1 AS TIMESTAMP)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &batch.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + ASSERT_THAT(error.message, + 
::testing::HasSubstr("Row #1 has value '-9223372036854775808' which " + "exceeds PostgreSQL timestamp limits")); + } +} + +TEST_F(PostgresStatementTest, SqlReadIntervalOverflow) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT CAST('P0Y0M0DT2562048H0M0S' AS INTERVAL)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_THAT(reader.MaybeNext(), + adbc_validation::IsErrno(EINVAL, &reader.stream.value, nullptr)); + ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value), + ::testing::HasSubstr("Interval with time value 9223372800000000 usec " + "would overflow when converting to nanoseconds")); + ASSERT_EQ(reader.array->release, nullptr); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT CAST('P0Y0M0DT-2562048H0M0S' AS INTERVAL)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_THAT(reader.MaybeNext(), + adbc_validation::IsErrno(EINVAL, &reader.stream.value, nullptr)); + ASSERT_THAT(reader.stream->get_last_error(&reader.stream.value), + ::testing::HasSubstr("Interval with time value -9223372800000000 usec " + "would overflow when converting to nanoseconds")); + ASSERT_EQ(reader.array->release, nullptr); + } +} + TEST_F(PostgresStatementTest, UpdateInExecuteQuery) { ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error)); @@ -630,6 +1123,111 @@ TEST_F(PostgresStatementTest, UpdateInExecuteQuery) { } } 
+TEST_F(PostgresStatementTest, BatchSizeHint) { + ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + // Setting the batch size hint to a negative or non-integer value should fail + ASSERT_EQ(AdbcStatementSetOption(&statement, "adbc.postgresql.batch_size_hint_bytes", + "-1", nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + ASSERT_EQ(AdbcStatementSetOption(&statement, "adbc.postgresql.batch_size_hint_bytes", + "not a valid number", nullptr), + ADBC_STATUS_INVALID_ARGUMENT); + + // For this test, use a batch size of 1 byte to force every row to be its own batch + ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.batch_size_hint_bytes", + "1", &error), + IsOkStatus(&error)); + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT int64s from batch_size_hint_test ORDER BY int64s LIMIT 3", + &error), + IsOkStatus(&error)); + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(ArrowArrayViewGetIntUnsafe(reader.array_view->children[0], 0), -42); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(ArrowArrayViewGetIntUnsafe(reader.array_view->children[0], 0), 42); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_TRUE(ArrowArrayViewIsNull(reader.array_view->children[0], 0)); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + +// Test that an ADBC 1.0.0-sized error still works +TEST_F(PostgresStatementTest, AdbcErrorBackwardsCompatibility) { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + 
std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + IsStatus(ADBC_STATUS_NOT_FOUND, error)); + + ASSERT_EQ("42P01", std::string_view(error->sqlstate, 5)); + ASSERT_EQ(0, AdbcErrorGetDetailCount(error)); + + error->release(error); + free(error); +} + +TEST_F(PostgresStatementTest, Cancel) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + for (const char* query : { + "DROP TABLE IF EXISTS test_cancel", + "CREATE TABLE test_cancel (ints INT)", + R"(INSERT INTO test_cancel (ints) + SELECT g :: INT FROM GENERATE_SERIES(1, 65536) temp(g))", + }) { + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM test_cancel", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + + int retcode = 0; + while (true) { + retcode = reader.MaybeNext(); + if (retcode != 0 || !reader.array->release) break; + } + + ASSERT_EQ(ECANCELED, retcode); + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* detail = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(ADBC_STATUS_CANCELLED, status); + ASSERT_EQ("57014", std::string_view(detail->sqlstate, 5)); + ASSERT_NE(0, AdbcErrorGetDetailCount(detail)); +} + struct 
TypeTestCase { std::string name; std::string sql_type; @@ -683,6 +1281,13 @@ class PostgresTypeTest : public ::testing::TestWithParam { }; TEST_P(PostgresTypeTest, SelectValue) { + std::string value = GetParam().sql_literal; + if ((value == "'-inf'") || (value == "'inf'")) { + const std::string version = adbc_validation::GetDriverVendorVersion(&connection_); + if (version < "140000") { + GTEST_SKIP() << "-inf and inf not implemented until postgres 14"; + } + } // create table std::string query = "CREATE TABLE foo (col "; query += GetParam().sql_type; @@ -784,6 +1389,16 @@ static std::initializer_list kIntTypeCases = { {"BIGSERIAL", "BIGSERIAL", std::to_string(std::numeric_limits::max()), NANOARROW_TYPE_INT64, std::numeric_limits::max()}, }; +static std::initializer_list kNumericTypeCases = { + {"NUMERIC_TRAILING0", "NUMERIC", "1000000", NANOARROW_TYPE_STRING, "1000000"}, + {"NUMERIC_LEADING0", "NUMERIC", "0.00001234", NANOARROW_TYPE_STRING, "0.00001234"}, + {"NUMERIC_TRAILING02", "NUMERIC", "'1.0000'", NANOARROW_TYPE_STRING, "1.0000"}, + {"NUMERIC_NEGATIVE", "NUMERIC", "-123.456", NANOARROW_TYPE_STRING, "-123.456"}, + {"NUMERIC_POSITIVE", "NUMERIC", "123.456", NANOARROW_TYPE_STRING, "123.456"}, + {"NUMERIC_NAN", "NUMERIC", "'nan'", NANOARROW_TYPE_STRING, "nan"}, + {"NUMERIC_NINF", "NUMERIC", "'-inf'", NANOARROW_TYPE_STRING, "-inf"}, + {"NUMERIC_PINF", "NUMERIC", "'inf'", NANOARROW_TYPE_STRING, "inf"}, +}; static std::initializer_list kDateTypeCases = { {"DATE0", "DATE", "'1970-01-01'", NANOARROW_TYPE_DATE32, int64_t(0)}, {"DATE1", "DATE", "'2000-01-01'", NANOARROW_TYPE_DATE32, int64_t(10957)}, @@ -914,6 +1529,8 @@ INSTANTIATE_TEST_SUITE_P(FloatTypes, PostgresTypeTest, testing::ValuesIn(kFloatT TypeTestCase::FormatName); INSTANTIATE_TEST_SUITE_P(IntTypes, PostgresTypeTest, testing::ValuesIn(kIntTypeCases), TypeTestCase::FormatName); +INSTANTIATE_TEST_SUITE_P(NumericType, PostgresTypeTest, + testing::ValuesIn(kNumericTypeCases), TypeTestCase::FormatName); 
INSTANTIATE_TEST_SUITE_P(DateTypes, PostgresTypeTest, testing::ValuesIn(kDateTypeCases), TypeTestCase::FormatName); INSTANTIATE_TEST_SUITE_P(TimeTypes, PostgresTypeTest, testing::ValuesIn(kTimeTypeCases), diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc new file mode 100644 index 0000000..3a2a0d0 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "result_helper.h" + +#include "common/utils.h" +#include "error.h" + +namespace adbcpq { + +PqResultHelper::~PqResultHelper() { + if (result_ != nullptr) { + PQclear(result_); + } +} + +AdbcStatusCode PqResultHelper::Prepare() { + // TODO: make stmtName a unique identifier? 
+ PGresult* result = + PQprepare(conn_, /*stmtName=*/"", query_.c_str(), param_values_.size(), NULL); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error_, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn_), query_.c_str()); + PQclear(result); + return code; + } + + PQclear(result); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PqResultHelper::Execute() { + std::vector param_c_strs; + + for (size_t index = 0; index < param_values_.size(); index++) { + param_c_strs.push_back(param_values_[index].c_str()); + } + + result_ = + PQexecPrepared(conn_, "", param_values_.size(), param_c_strs.data(), NULL, NULL, 0); + + ExecStatusType status = PQresultStatus(result_); + if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { + AdbcStatusCode error = + SetError(error_, result_, "[libpq] Failed to execute query '%s': %s", + query_.c_str(), PQerrorMessage(conn_)); + return error; + } + + return ADBC_STATUS_OK; +} + +} // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h new file mode 100644 index 0000000..e9307dc --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace adbcpq { + +/// \brief A single column in a single row of a result set. +struct PqRecord { + const char* data; + const int len; + const bool is_null; + + // XXX: can't use optional due to R + std::pair ParseDouble() const { + char* end; + double result = std::strtod(data, &end); + if (errno != 0 || end == data) { + return std::make_pair(false, 0.0); + } + return std::make_pair(true, result); + } +}; + +// Used by PqResultHelper to provide index-based access to the records within each +// row of a PGresult +class PqResultRow { + public: + PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) { + ncols_ = PQnfields(result); + } + + PqRecord operator[](const int& col_num) { + assert(col_num < ncols_); + const char* data = PQgetvalue(result_, row_num_, col_num); + const int len = PQgetlength(result_, row_num_, col_num); + const bool is_null = PQgetisnull(result_, row_num_, col_num); + + return PqRecord{data, len, is_null}; + } + + private: + PGresult* result_ = nullptr; + int row_num_; + int ncols_; +}; + +// Helper to manager the lifecycle of a PQResult. 
The query argument +// will be evaluated as part of the constructor, with the desctructor handling cleanup +// Caller must call Prepare then Execute, checking both for an OK AdbcStatusCode +// prior to iterating +class PqResultHelper { + public: + explicit PqResultHelper(PGconn* conn, std::string query, struct AdbcError* error) + : conn_(conn), query_(std::move(query)), error_(error) {} + + explicit PqResultHelper(PGconn* conn, std::string query, + std::vector param_values, struct AdbcError* error) + : conn_(conn), + query_(std::move(query)), + param_values_(std::move(param_values)), + error_(error) {} + + ~PqResultHelper(); + + AdbcStatusCode Prepare(); + AdbcStatusCode Execute(); + + int NumRows() const { return PQntuples(result_); } + + int NumColumns() const { return PQnfields(result_); } + + class iterator { + const PqResultHelper& outer_; + int curr_row_ = 0; + + public: + explicit iterator(const PqResultHelper& outer, int curr_row = 0) + : outer_(outer), curr_row_(curr_row) {} + iterator& operator++() { + curr_row_++; + return *this; + } + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + bool operator==(iterator other) const { + return outer_.result_ == other.outer_.result_ && curr_row_ == other.curr_row_; + } + bool operator!=(iterator other) const { return !(*this == other); } + PqResultRow operator*() { return PqResultRow(outer_.result_, curr_row_); } + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = std::vector; + using pointer = const std::vector*; + using reference = const std::vector&; + }; + + iterator begin() { return iterator(*this); } + iterator end() { return iterator(*this, NumRows()); } + + private: + PGresult* result_ = nullptr; + PGconn* conn_; + std::string query_; + std::vector param_values_; + struct AdbcError* error_; +}; +} // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc 
b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc index 3092046..3910378 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc @@ -18,6 +18,7 @@ #include "statement.h" #include +#include #include #include #include @@ -30,11 +31,14 @@ #include #include +#include "common/options.h" #include "common/utils.h" #include "connection.h" +#include "error.h" #include "postgres_copy_reader.h" #include "postgres_type.h" #include "postgres_util.h" +#include "result_helper.h" namespace adbcpq { @@ -145,6 +149,9 @@ struct BindStream { // XXX: this assumes fixed-length fields only - will need more // consideration to deal with variable-length fields + bool has_tz_field = false; + std::string tz_setting; + struct ArrowError na_error; explicit BindStream(struct ArrowArrayStream&& bind) { @@ -187,6 +194,7 @@ struct BindStream { for (size_t i = 0; i < bind_schema_fields.size(); i++) { PostgresTypeId type_id; switch (bind_schema_fields[i].type) { + case ArrowType::NANOARROW_TYPE_INT8: case ArrowType::NANOARROW_TYPE_INT16: type_id = PostgresTypeId::kInt2; param_lengths[i] = 2; @@ -208,6 +216,7 @@ struct BindStream { param_lengths[i] = 8; break; case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: type_id = PostgresTypeId::kText; param_lengths[i] = 0; break; @@ -215,6 +224,19 @@ struct BindStream { type_id = PostgresTypeId::kBytea; param_lengths[i] = 0; break; + case ArrowType::NANOARROW_TYPE_DATE32: + type_id = PostgresTypeId::kDate; + param_lengths[i] = 4; + break; + case ArrowType::NANOARROW_TYPE_TIMESTAMP: + type_id = PostgresTypeId::kTimestamp; + param_lengths[i] = 8; + break; + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + type_id = PostgresTypeId::kInterval; + param_lengths[i] = 16; + break; default: SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", static_cast(i + 1), " ('", 
bind_schema->children[i]->name, @@ -242,15 +264,61 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode Prepare(PGconn* conn, const std::string& query, - struct AdbcError* error) { + AdbcStatusCode Prepare(PGconn* conn, const std::string& query, struct AdbcError* error, + const bool autocommit) { + // tz-aware timestamps require special handling to set the timezone to UTC + // prior to sending over the binary protocol; must be reset after execute + for (int64_t col = 0; col < bind_schema->n_children; col++) { + if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && + (strcmp("", bind_schema_fields[col].timezone))) { + has_tz_field = true; + + if (autocommit) { + PGresult* begin_result = PQexec(conn, "BEGIN"); + if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, begin_result, + "[libpq] Failed to begin transaction for timezone data: %s", + PQerrorMessage(conn)); + PQclear(begin_result); + return code; + } + PQclear(begin_result); + } + + PGresult* get_tz_result = PQexec(conn, "SELECT current_setting('TIMEZONE')"); + if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { + AdbcStatusCode code = SetError(error, get_tz_result, + "[libpq] Could not query current timezone: %s", + PQerrorMessage(conn)); + PQclear(get_tz_result); + return code; + } + + tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); + PQclear(get_tz_result); + + PGresult* set_utc_result = PQexec(conn, "SET TIME ZONE 'UTC'"); + if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = SetError(error, set_utc_result, + "[libpq] Failed to set time zone to UTC: %s", + PQerrorMessage(conn)); + PQclear(set_utc_result); + return code; + } + PQclear(set_utc_result); + break; + } + } + PGresult* result = PQprepare(conn, /*stmtName=*/"", query.c_str(), /*nParams=*/bind_schema->n_children, param_types.data()); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to 
prepare query: %s\nQuery was:%s", - PQerrorMessage(conn), query.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(conn), query.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -290,6 +358,13 @@ struct BindStream { param_values[col] = param_values_buffer.data() + param_values_offsets[col]; } switch (bind_schema_fields[col].type) { + case ArrowType::NANOARROW_TYPE_INT8: { + const int16_t val = + array_view->children[col]->buffer_views[1].data.as_int8[row]; + const uint16_t value = ToNetworkInt16(val); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } case ArrowType::NANOARROW_TYPE_INT16: { const uint16_t value = ToNetworkInt16( array_view->children[col]->buffer_views[1].data.as_int16[row]); @@ -321,6 +396,7 @@ struct BindStream { break; } case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: case ArrowType::NANOARROW_TYPE_BINARY: { const ArrowBufferView view = ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); @@ -329,12 +405,97 @@ struct BindStream { param_values[col] = const_cast(view.data.as_char); break; } + case ArrowType::NANOARROW_TYPE_DATE32: { + // 2000-01-01 + constexpr int32_t kPostgresDateEpoch = 10957; + const int32_t raw_value = + array_view->children[col]->buffer_views[1].data.as_int32[row]; + if (raw_value < INT32_MIN + kPostgresDateEpoch) { + SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, + "('", bind_schema->children[col]->name, "') Row #", row + 1, + "has value which exceeds postgres date limits"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_TIMESTAMP: { + int64_t val = 
array_view->children[col]->buffer_views[1].data.as_int64[row]; + + // 2000-01-01 00:00:00.000000 in microseconds + constexpr int64_t kPostgresTimestampEpoch = 946684800000000; + bool overflow_safe = true; + + auto unit = bind_schema_fields[col].time_unit; + + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + if ((overflow_safe = val <= kMaxSafeSecondsToMicros && + val >= kMinSafeSecondsToMicros)) { + val *= 1000000; + } + + break; + case NANOARROW_TIME_UNIT_MILLI: + if ((overflow_safe = val <= kMaxSafeMillisToMicros && + val >= kMinSafeMillisToMicros)) { + val *= 1000; + } + break; + case NANOARROW_TIME_UNIT_MICRO: + break; + case NANOARROW_TIME_UNIT_NANO: + val /= 1000; + break; + } + + if (!overflow_safe) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 + "' which exceeds PostgreSQL timestamp limits", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { + const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + } else if (bind_schema_fields[col].type == + ArrowType::NANOARROW_TYPE_DURATION) { + // postgres stores an interval as a 64 bit offset in microsecond + // resolution alongside a 32 bit day and 32 bit month + // for now we just send 0 for the day / month values + const uint64_t value = ToNetworkInt64(val); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); + } + break; + } + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + struct ArrowInterval interval; + ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); + ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); + + const uint32_t months = ToNetworkInt32(interval.months); 
+ const uint32_t days = ToNetworkInt32(interval.days); + const uint64_t ms = ToNetworkInt64(interval.ns / 1000); + + std::memcpy(param_values[col], &ms, sizeof(uint64_t)); + std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); + std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), + &months, sizeof(uint32_t)); + break; + } default: - // TODO: data type to string - SetError(error, "%s%" PRId64 "%s%s%s%ud", "[libpq] Field #", col + 1, " ('", + SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", bind_schema->children[col]->name, "') has unsupported type for ingestion ", - bind_schema_fields[col].type); + ArrowTypeString(bind_schema_fields[col].type)); return ADBC_STATUS_NOT_IMPLEMENTED; } } @@ -344,16 +505,41 @@ struct BindStream { param_lengths.data(), param_formats.data(), /*resultFormat=*/0 /*text*/); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "%s%s", "[libpq] Failed to execute prepared statement: ", - PQerrorMessage(conn)); + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + AdbcStatusCode code = SetError( + error, result, "[libpq] Failed to execute prepared statement: %s %s", + PQresStatus(pg_status), PQerrorMessage(conn)); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); } if (rows_affected) *rows_affected += array->length; + + if (has_tz_field) { + std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; + PGresult* reset_tz_result = PQexec(conn, reset_query.c_str()); + if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", + PQerrorMessage(conn)); + PQclear(reset_tz_result); + return code; + } + PQclear(reset_tz_result); + + PGresult* commit_result = PQexec(conn, "COMMIT"); + if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, commit_result, 
"[libpq] Failed to commit transaction: %s", + PQerrorMessage(conn)); + PQclear(commit_result); + return code; + } + PQclear(commit_result); + } } return ADBC_STATUS_OK; } @@ -361,118 +547,183 @@ struct BindStream { } // namespace int TupleReader::GetSchema(struct ArrowSchema* out) { + assert(copy_reader_ != nullptr); + int na_res = copy_reader_->GetSchema(out); if (out->release == nullptr) { - StringBuilderAppend(&error_builder_, - "[libpq] Result set was already consumed or freed"); - return EINVAL; + SetError(&error_, "[libpq] Result set was already consumed or freed"); + status_ = ADBC_STATUS_INVALID_STATE; + return AdbcStatusCodeToErrno(status_); } else if (na_res != NANOARROW_OK) { // e.g., Can't allocate memory - StringBuilderAppend(&error_builder_, "[libpq] Error copying schema"); + SetError(&error_, "[libpq] Error copying schema"); + status_ = ADBC_STATUS_INTERNAL; } return na_res; } -int TupleReader::GetNext(struct ArrowArray* out) { - if (!result_) { - out->release = nullptr; - return 0; +int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) { + // Fetch + parse the header + int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); + data_.size_bytes = get_copy_res; + data_.data.as_char = pgbuf_; + + if (get_copy_res == -2) { + SetError(&error_, "[libpq] Fetch header failed: %s", PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); } - // Clear the result, since the data is actually read from the connection - PQclear(result_); - result_ = nullptr; + int na_res = copy_reader_->ReadHeader(&data_, error); + if (na_res != NANOARROW_OK) { + SetError(&error_, "[libpq] ReadHeader failed: %s", error->message); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); + } + + return NANOARROW_OK; +} - // Clear the error builder - error_builder_.size = 0; +int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { + // Parse the result (the header AND the first row are included in the first 
+ // call to PQgetCopyData()) + int na_res = copy_reader_->ReadRecord(&data_, error); + if (na_res != NANOARROW_OK && na_res != ENODATA) { + SetError(&error_, "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, + error->message); + status_ = ADBC_STATUS_IO; + return na_res; + } - struct ArrowError error; - error.message[0] = '\0'; - struct ArrowBufferView data; - data.data.data = nullptr; - data.size_bytes = 0; + row_id_++; - // Fetch + parse the header + // Fetch + check + PQfreemem(pgbuf_); + pgbuf_ = nullptr; int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); + data_.size_bytes = get_copy_res; + data_.data.as_char = pgbuf_; + if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, "[libpq] Fetch header failed: %s", - PQerrorMessage(conn_)); - return EIO; + SetError(&error_, "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, + PQerrorMessage(conn_)); + status_ = ADBC_STATUS_IO; + return AdbcStatusCodeToErrno(status_); + } else if (get_copy_res == -1) { + // Returned when COPY has finished successfully + return ENODATA; + } else if ((copy_reader_->array_size_approx_bytes() + get_copy_res) >= + batch_size_hint_bytes_) { + // Appending the next row will result in an array larger than requested. + // Return EOVERFLOW to force GetNext() to build the current result and return. 
+ return EOVERFLOW; + } else { + return NANOARROW_OK; } +} - data.size_bytes = get_copy_res; - data.data.as_char = pgbuf_; - int na_res = copy_reader_->ReadHeader(&data, &error); +int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) { + if (copy_reader_->array_size_approx_bytes() == 0) { + out->release = nullptr; + return NANOARROW_OK; + } + + int na_res = copy_reader_->GetArray(out, error); if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] ReadHeader failed: %s", error.message); + SetError(&error_, "[libpq] Failed to build result array: %s", error->message); + status_ = ADBC_STATUS_INTERNAL; return na_res; } - int64_t row_id = 0; - do { - // Parse the result (the header AND the first row are included in the first - // call to PQgetCopyData()) - na_res = copy_reader_->ReadRecord(&data, &error); - if (na_res != NANOARROW_OK && na_res != ENODATA) { - StringBuilderAppend(&error_builder_, "[libpq] ReadRecord failed at row %ld: %s", - static_cast(row_id), // NOLINT(runtime/int) - error.message); - return na_res; - } + return NANOARROW_OK; +} - row_id++; +int TupleReader::GetNext(struct ArrowArray* out) { + if (is_finished_) { + out->release = nullptr; + return 0; + } - // Fetch + check - PQfreemem(pgbuf_); - pgbuf_ = nullptr; - get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); - if (get_copy_res == -2) { - StringBuilderAppend(&error_builder_, "[libpq] Fetch row %ld failed: %s", - static_cast(row_id), // NOLINT(runtime/int) - PQerrorMessage(conn_)); - return EIO; - } else if (get_copy_res == -1) { - // Returned when COPY has finished - break; - } + struct ArrowError error; + error.message[0] = '\0'; - data.size_bytes = get_copy_res; - data.data.as_char = pgbuf_; - } while (true); + if (row_id_ == -1) { + NANOARROW_RETURN_NOT_OK(InitQueryAndFetchFirst(&error)); + row_id_++; + } - na_res = copy_reader_->GetArray(out, &error); - if (na_res != NANOARROW_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Failed 
to build result array: %s", - error.message); + int na_res; + do { + na_res = AppendRowAndFetchNext(&error); + if (na_res == EOVERFLOW) { + // The result would be too big to return if we appended the row. When EOVERFLOW is + // returned, the copy reader leaves the output in a valid state. The data is left in + // pg_buf_/data_ and will attempt to be appended on the next call to GetNext() + return BuildOutput(out, &error); + } + } while (na_res == NANOARROW_OK); + + if (na_res != ENODATA) { return na_res; } + is_finished_ = true; + + // Finish the result properly and return the last result. Note that BuildOutput() may + // set tmp.release = nullptr if there were zero rows in the copy reader (can + // occur in an overflow scenario). + struct ArrowArray tmp; + NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error)); + + PQclear(result_); // Check the server-side response result_ = PQgetResult(conn_); - const int pq_status = PQresultStatus(result_); + const ExecStatusType pq_status = PQresultStatus(result_); if (pq_status != PGRES_COMMAND_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status, - PQresultErrorMessage(result_)); - return EIO; + const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); + SetError(&error_, result_, "[libpq] Query failed [%s]: %s", PQresStatus(pq_status), + PQresultErrorMessage(result_)); + + if (tmp.release != nullptr) { + tmp.release(&tmp); + } + + if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { + status_ = ADBC_STATUS_CANCELLED; + } else { + status_ = ADBC_STATUS_IO; + } + return AdbcStatusCodeToErrno(status_); } - PQclear(result_); - result_ = nullptr; + ArrowArrayMove(&tmp, out); return NANOARROW_OK; } void TupleReader::Release() { - StringBuilderReset(&error_builder_); + if (error_.release) { + error_.release(&error_); + } + error_ = ADBC_ERROR_INIT; + status_ = ADBC_STATUS_OK; if (result_) { PQclear(result_); result_ = nullptr; } + if (pgbuf_) { PQfreemem(pgbuf_); pgbuf_ = 
nullptr; } + + if (copy_reader_) { + copy_reader_.reset(); + } + + is_finished_ = false; + row_id_ = -1; } void TupleReader::ExportTo(struct ArrowArrayStream* stream) { @@ -483,6 +734,19 @@ void TupleReader::ExportTo(struct ArrowArrayStream* stream) { stream->private_data = this; } +const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != &ReleaseTrampoline) { + return nullptr; + } + + TupleReader* reader = static_cast(stream->private_data); + if (status) { + *status = reader->status_; + } + return &reader->error_; +} + int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out) { if (!self || !self->private_data) return EINVAL; @@ -564,18 +828,113 @@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { + // Ultimately the same underlying PGconn + return connection_->Cancel(error); +} + AdbcStatusCode PostgresStatement::CreateBulkTable( - const struct ArrowSchema& source_schema, + const std::string& current_schema, const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, - struct AdbcError* error) { - std::string create = "CREATE TABLE "; - create += ingest_.target; + std::string* escaped_table, struct AdbcError* error) { + PGconn* conn = connection_->conn(); + + if (!ingest_.db_schema.empty() && ingest_.temporary) { + SetError(error, "[libpq] Cannot set both %s and %s", + ADBC_INGEST_OPTION_TARGET_DB_SCHEMA, ADBC_INGEST_OPTION_TEMPORARY); + return ADBC_STATUS_INVALID_STATE; + } + + { + if (!ingest_.db_schema.empty()) { + char* escaped = + PQescapeIdentifier(conn, ingest_.db_schema.c_str(), ingest_.db_schema.size()); + if (escaped == nullptr) { + SetError(error, "[libpq] Failed to escape target schema %s for ingestion: %s", + ingest_.db_schema.c_str(), PQerrorMessage(conn)); + return 
ADBC_STATUS_INTERNAL; + } + *escaped_table += escaped; + *escaped_table += " . "; + PQfreemem(escaped); + } else if (ingest_.temporary) { + // OK to be redundant (CREATE TEMPORARY TABLE pg_temp.foo) + *escaped_table += "pg_temp . "; + } else { + // Explicitly specify the current schema to avoid any temporary tables + // shadowing this table + char* escaped = + PQescapeIdentifier(conn, current_schema.c_str(), current_schema.size()); + *escaped_table += escaped; + *escaped_table += " . "; + PQfreemem(escaped); + } + + if (!ingest_.target.empty()) { + char* escaped = + PQescapeIdentifier(conn, ingest_.target.c_str(), ingest_.target.size()); + if (escaped == nullptr) { + SetError(error, "[libpq] Failed to escape target table %s for ingestion: %s", + ingest_.target.c_str(), PQerrorMessage(conn)); + return ADBC_STATUS_INTERNAL; + } + *escaped_table += escaped; + PQfreemem(escaped); + } + } + + std::string create; + + if (ingest_.temporary) { + create = "CREATE TEMPORARY TABLE "; + } else { + create = "CREATE TABLE "; + } + + switch (ingest_.mode) { + case IngestMode::kCreate: + // Nothing to do + break; + case IngestMode::kAppend: + return ADBC_STATUS_OK; + case IngestMode::kReplace: { + std::string drop = "DROP TABLE IF EXISTS " + *escaped_table; + PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0, + /*paramTypes=*/nullptr, /*paramValues=*/nullptr, + /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, + /*resultFormat=*/1 /*(binary)*/); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to drop table: %s\nQuery was: %s", + PQerrorMessage(conn), drop.c_str()); + PQclear(result); + return code; + } + PQclear(result); + break; + } + case IngestMode::kCreateAppend: + create += "IF NOT EXISTS "; + break; + } + create += *escaped_table; create += " ("; for (size_t i = 0; i < source_schema_fields.size(); i++) { if (i > 0) create += ", "; - create += source_schema.children[i]->name; + + const 
char* unescaped = source_schema.children[i]->name; + char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped)); + if (escaped == nullptr) { + SetError(error, "[libpq] Failed to escape column %s for ingestion: %s", unescaped, + PQerrorMessage(conn)); + return ADBC_STATUS_INTERNAL; + } + create += escaped; + PQfreemem(escaped); + switch (source_schema_fields[i].type) { + case ArrowType::NANOARROW_TYPE_INT8: case ArrowType::NANOARROW_TYPE_INT16: create += " SMALLINT"; break; @@ -592,31 +951,47 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( create += " DOUBLE PRECISION"; break; case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: create += " TEXT"; break; case ArrowType::NANOARROW_TYPE_BINARY: create += " BYTEA"; break; + case ArrowType::NANOARROW_TYPE_DATE32: + create += " DATE"; + break; + case ArrowType::NANOARROW_TYPE_TIMESTAMP: + if (strcmp("", source_schema_fields[i].timezone)) { + create += " TIMESTAMPTZ"; + } else { + create += " TIMESTAMP"; + } + break; + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + create += " INTERVAL"; + break; default: - // TODO: data type to string - SetError(error, "%s%" PRIu64 "%s%s%s%ud", "[libpq] Field #", + SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", static_cast(i + 1), " ('", source_schema.children[i]->name, - "') has unsupported type for ingestion ", source_schema_fields[i].type); + "') has unsupported type for ingestion ", + ArrowTypeString(source_schema_fields[i].type)); return ADBC_STATUS_NOT_IMPLEMENTED; } } create += ")"; SetError(error, "%s%s", "[libpq] ", create.c_str()); - PGresult* result = PQexecParams(connection_->conn(), create.c_str(), /*nParams=*/0, + PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0, /*paramTypes=*/nullptr, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/1 /*(binary)*/); if (PQresultStatus(result) != 
PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to create table: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), create.c_str()); + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to create table: %s\nQuery was: %s", + PQerrorMessage(conn), create.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } PQclear(result); return ADBC_STATUS_OK; @@ -642,7 +1017,8 @@ AdbcStatusCode PostgresStatement::ExecutePreparedStatement( RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); - RAISE_ADBC(bind_stream.Prepare(connection_->conn(), query_, error)); + RAISE_ADBC( + bind_stream.Prepare(connection_->conn(), query_, error, connection_->autocommit())); RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error)); return ADBC_STATUS_OK; } @@ -676,50 +1052,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // 1. Prepare the query to get the schema { - // TODO: we should pipeline here and assume this will succeed - PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), - /*nParams=*/0, nullptr); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "prepare query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - PQclear(result); - result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, - "[libpq] Failed to execute query: could not infer schema: failed to " - "describe prepared statement: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); - PQclear(result); - return ADBC_STATUS_IO; - } - - // Resolve the information from the PGresult into a PostgresType - PostgresType root_type; - AdbcStatusCode status = - 
ResolvePostgresType(*type_resolver_, result, &root_type, error); - PQclear(result); - if (status != ADBC_STATUS_OK) return status; - - // Initialize the copy reader and infer the output schema (i.e., error for - // unsupported types before issuing the COPY query) - reader_.copy_reader_.reset(new PostgresCopyStreamReader()); - reader_.copy_reader_->Init(root_type); - struct ArrowError na_error; - int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); - if (na_res != NANOARROW_OK) { - SetError(error, "[libpq] Failed to infer output schema: %s", na_error.message); - return na_res; - } + RAISE_ADBC(SetupReader(error)); // If the caller did not request a result set or if there are no // inferred output columns (e.g. a CREATE or UPDATE), then don't // use COPY (which would fail anyways) - if (!stream || root_type.n_children() == 0) { + if (!stream || reader_.copy_reader_->pg_type().n_children() == 0) { RAISE_ADBC(ExecuteUpdateQuery(rows_affected, error)); if (stream) { struct ArrowSchema schema; @@ -733,7 +1071,8 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // This resolves the reader specific to each PostgresType -> ArrowSchema // conversion. It is unlikely that this will fail given that we have just // inferred these conversions ourselves. 
- na_res = reader_.copy_reader_->InitFieldReaders(&na_error); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InitFieldReaders(&na_error); if (na_res != NANOARROW_OK) { SetError(error, "[libpq] Failed to initialize field readers: %s", na_error.message); return na_res; @@ -748,11 +1087,12 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, /*paramTypes=*/nullptr, /*paramValues=*/nullptr, /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat); if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) { - SetError(error, - "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", - PQerrorMessage(connection_->conn()), copy_query.c_str()); + AdbcStatusCode code = SetError( + error, reader_.result_, + "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", + PQerrorMessage(connection_->conn()), copy_query.c_str()); ClearResult(); - return ADBC_STATUS_IO; + return code; } // Result is read from the connection, not the result, but we won't clear it here } @@ -762,6 +1102,23 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, + struct AdbcError* error) { + ClearResult(); + if (query_.empty()) { + SetError(error, "%s", "[libpq] Must SetSqlQuery before ExecuteQuery"); + return ADBC_STATUS_INVALID_STATE; + } else if (bind_.release) { + // TODO: if we have parameters, bind them (since they can affect the output schema) + SetError(error, "[libpq] ExecuteSchema with parameters is not implemented"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + RAISE_ADBC(SetupReader(error)); + CHECK_NA(INTERNAL, reader_.copy_reader_->GetSchema(schema), error); + return ADBC_STATUS_OK; +} + AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error) { if (!bind_.release) { @@ -769,22 +1126,34 @@ AdbcStatusCode 
PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, return ADBC_STATUS_INVALID_STATE; } + // Need the current schema to avoid being shadowed by temp tables + // This is a little unfortunate; we need another DB roundtrip + std::string current_schema; + { + PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA", {}, error}; + RAISE_ADBC(result_helper.Prepare()); + RAISE_ADBC(result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + return ADBC_STATUS_INTERNAL; + } + current_schema = (*it)[0].data; + } + BindStream bind_stream(std::move(bind_)); std::memset(&bind_, 0, sizeof(bind_)); + std::string escaped_table; RAISE_ADBC(bind_stream.Begin( [&]() -> AdbcStatusCode { - if (!ingest_.append) { - // CREATE TABLE - return CreateBulkTable(bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, error); - } - return ADBC_STATUS_OK; + return CreateBulkTable(current_schema, bind_stream.bind_schema.value, + bind_stream.bind_schema_fields, &escaped_table, error); }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); std::string insert = "INSERT INTO "; - insert += ingest_.target; + insert += escaped_table; insert += " VALUES ("; for (size_t i = 0; i < bind_stream.bind_schema_fields.size(); i++) { if (i > 0) insert += ", "; @@ -793,7 +1162,8 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, } insert += ")"; - RAISE_ADBC(bind_stream.Prepare(connection_->conn(), insert, error)); + RAISE_ADBC( + bind_stream.Prepare(connection_->conn(), insert, error, connection_->autocommit())); RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error)); return ADBC_STATUS_OK; } @@ -805,17 +1175,79 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected, PQexecPrepared(connection_->conn(), /*stmtName=*/"", /*nParams=*/0, /*paramValues=*/nullptr, 
/*paramLengths=*/nullptr, /*paramFormats=*/nullptr, /*resultFormat=*/kPgBinaryFormat); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - SetError(error, "[libpq] Failed to execute query: %s\nQuery was:%s", - PQerrorMessage(connection_->conn()), query_.c_str()); + ExecStatusType status = PQresultStatus(result); + if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); PQclear(result); - return ADBC_STATUS_IO; + return code; } if (rows_affected) *rows_affected = PQntuples(reader_.result_); PQclear(result); return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { + result = ingest_.target; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) { + result = ingest_.db_schema; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { + switch (ingest_.mode) { + case IngestMode::kCreate: + result = ADBC_INGEST_OPTION_MODE_CREATE; + break; + case IngestMode::kAppend: + result = ADBC_INGEST_OPTION_MODE_APPEND; + break; + case IngestMode::kReplace: + result = ADBC_INGEST_OPTION_MODE_REPLACE; + break; + case IngestMode::kCreateAppend: + result = ADBC_INGEST_OPTION_MODE_CREATE_APPEND; + break; + } + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + result = std::to_string(reader_.batch_size_hint_bytes_); + } else { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; + } + + if (result.size() + 1 <= *length) { + std::memcpy(value, result.data(), result.size() + 1); + } + *length = static_cast(result.size() + 1); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PostgresStatement::GetOptionBytes(const char* key, uint8_t* value, + size_t* 
length, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionDouble(const char* key, double* value, + struct AdbcError* error) { + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value, + struct AdbcError* error) { + std::string result; + if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + *value = reader_.batch_size_hint_bytes_; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode PostgresStatement::GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -844,6 +1276,7 @@ AdbcStatusCode PostgresStatement::Release(struct AdbcError* error) { AdbcStatusCode PostgresStatement::SetSqlQuery(const char* query, struct AdbcError* error) { ingest_.target.clear(); + ingest_.db_schema.clear(); query_ = query; prepared_ = false; return ADBC_STATUS_OK; @@ -854,22 +1287,127 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { query_.clear(); ingest_.target = value; + prepared_ = false; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) { + query_.clear(); + if (value == nullptr) { + ingest_.db_schema.clear(); + } else { + ingest_.db_schema = value; + } + prepared_ = false; } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { - ingest_.append = false; + ingest_.mode = IngestMode::kCreate; } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { - ingest_.append = true; + ingest_.mode = IngestMode::kAppend; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_REPLACE) 
== 0) { + ingest_.mode = IngestMode::kReplace; + } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE_APPEND) == 0) { + ingest_.mode = IngestMode::kCreateAppend; + } else { + SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + prepared_ = false; + } else if (std::strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) { + if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + ingest_.temporary = true; + } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + ingest_.temporary = false; } else { - SetError(error, "%s%s%s%s", "[libpq] Invalid value ", value, " for option ", key); + SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + ingest_.db_schema.clear(); + prepared_ = false; + } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + int64_t int_value = std::atol(value); + if (int_value <= 0) { + SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); return ADBC_STATUS_INVALID_ARGUMENT; } + + this->reader_.batch_size_hint_bytes_ = int_value; } else { - SetError(error, "%s%s", "[libq] Unknown statement option ", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; } return ADBC_STATUS_OK; } +AdbcStatusCode PostgresStatement::SetOptionBytes(const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionDouble(const char* key, double value, + struct AdbcError* error) { + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value, + struct AdbcError* error) { + if (std::strcmp(key, 
ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { + if (value <= 0) { + SetError(error, "[libpq] Invalid value '%" PRIi64 "' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + this->reader_.batch_size_hint_bytes_ = value; + return ADBC_STATUS_OK; + } + SetError(error, "[libpq] Unknown statement option '%s'", key); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { + // TODO: we should pipeline here and assume this will succeed + PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", query_.c_str(), + /*nParams=*/0, nullptr); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "prepare query: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + PQclear(result); + result = PQdescribePrepared(connection_->conn(), /*stmtName=*/""); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, + "[libpq] Failed to execute query: could not infer schema: failed to " + "describe prepared statement: %s\nQuery was:%s", + PQerrorMessage(connection_->conn()), query_.c_str()); + PQclear(result); + return code; + } + + // Resolve the information from the PGresult into a PostgresType + PostgresType root_type; + AdbcStatusCode status = ResolvePostgresType(*type_resolver_, result, &root_type, error); + PQclear(result); + if (status != ADBC_STATUS_OK) return status; + + // Initialize the copy reader and infer the output schema (i.e., error for + // unsupported types before issuing the COPY query) + reader_.copy_reader_.reset(new PostgresCopyStreamReader()); + reader_.copy_reader_->Init(root_type); + struct ArrowError na_error; + int na_res = reader_.copy_reader_->InferOutputSchema(&na_error); + if (na_res != NANOARROW_OK) { + SetError(error, "[libpq] 
Failed to infer output schema: (%d) %s: %s", na_res, + std::strerror(na_res), na_error.message); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; +} + void PostgresStatement::ClearResult() { // TODO: we may want to synchronize here for safety reader_.Release(); diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h index 0ff6cb8..20bb3b7 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h @@ -30,6 +30,9 @@ #include "postgres_copy_reader.h" #include "postgres_type.h" +#define ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES \ + "adbc.postgresql.batch_size_hint_bytes" + namespace adbcpq { class PostgresConnection; class PostgresStatement; @@ -38,35 +41,50 @@ class PostgresStatement; class TupleReader final { public: TupleReader(PGconn* conn) - : conn_(conn), result_(nullptr), pgbuf_(nullptr), copy_reader_(nullptr) { - StringBuilderInit(&error_builder_, 0); + : status_(ADBC_STATUS_OK), + error_(ADBC_ERROR_INIT), + conn_(conn), + result_(nullptr), + pgbuf_(nullptr), + copy_reader_(nullptr), + row_id_(-1), + batch_size_hint_bytes_(16777216), + is_finished_(false) { + data_.data.as_char = nullptr; + data_.size_bytes = 0; } int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); - const char* last_error() const { - if (error_builder_.size > 0) { - return error_builder_.buffer; - } else { - return nullptr; - } - } + const char* last_error() const { return error_.message; } void Release(); void ExportTo(struct ArrowArrayStream* stream); + static const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status); + private: friend class PostgresStatement; + int InitQueryAndFetchFirst(struct ArrowError* error); + int AppendRowAndFetchNext(struct ArrowError* error); + int BuildOutput(struct ArrowArray* out, struct ArrowError* error); + static int 
GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out); static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out); static const char* GetLastErrorTrampoline(struct ArrowArrayStream* self); static void ReleaseTrampoline(struct ArrowArrayStream* self); + AdbcStatusCode status_; + struct AdbcError error_; PGconn* conn_; PGresult* result_; char* pgbuf_; - struct StringBuilder error_builder_; + struct ArrowBufferView data_; std::unique_ptr copy_reader_; + int64_t row_id_; + int64_t batch_size_hint_bytes_; + bool is_finished_; }; class PostgresStatement { @@ -82,13 +100,25 @@ class PostgresStatement { AdbcStatusCode Bind(struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error); + AdbcStatusCode Cancel(struct AdbcError* error); AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode ExecuteSchema(struct ArrowSchema* schema, struct AdbcError* error); + AdbcStatusCode GetOption(const char* key, char* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionBytes(const char* key, uint8_t* value, size_t* length, + struct AdbcError* error); + AdbcStatusCode GetOptionDouble(const char* key, double* value, struct AdbcError* error); + AdbcStatusCode GetOptionInt(const char* key, int64_t* value, struct AdbcError* error); AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error); AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error); AdbcStatusCode Prepare(struct AdbcError* error); AdbcStatusCode Release(struct AdbcError* error); AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error); + AdbcStatusCode SetOptionBytes(const char* key, const uint8_t* value, size_t length, + struct AdbcError* error); + AdbcStatusCode SetOptionDouble(const char* key, double value, 
struct AdbcError* error); + AdbcStatusCode SetOptionInt(const char* key, int64_t value, struct AdbcError* error); AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error); // --------------------------------------------------------------------- @@ -96,14 +126,15 @@ class PostgresStatement { void ClearResult(); AdbcStatusCode CreateBulkTable( - const struct ArrowSchema& source_schema, + const std::string& current_schema, const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, - struct AdbcError* error); + std::string* escaped_table, struct AdbcError* error); AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error); AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error); AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); + AdbcStatusCode SetupReader(struct AdbcError* error); private: std::shared_ptr type_resolver_; @@ -115,9 +146,18 @@ class PostgresStatement { struct ArrowArrayStream bind_; // Bulk ingest state + enum class IngestMode { + kCreate, + kAppend, + kReplace, + kCreateAppend, + }; + struct { + std::string db_schema; std::string target; - bool append = false; + IngestMode mode = IngestMode::kCreate; + bool temporary = false; } ingest_; TupleReader reader_; diff --git a/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc b/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc index 3b2eb65..2a9f692 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc @@ -55,8 +55,9 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { adbc_validation::Handle statement; CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); - std::string drop = "DROP TABLE IF EXISTS "; + std::string drop = "DROP TABLE IF EXISTS \""; drop += name; + drop += "\""; 
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, drop.c_str(), error)); CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); @@ -100,17 +101,22 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_FLOAT: case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + return NANOARROW_TYPE_STRING; default: return ingest_type; } } std::string BindParameter(int index) const override { return "?"; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return true; } bool supports_concurrent_statements() const override { return true; } bool supports_transactions() const override { return true; } bool supports_get_sql_info() const override { return false; } bool supports_get_objects() const override { return true; } - bool supports_bulk_ingest() const override { return true; } + bool supports_metadata_current_catalog() const override { return false; } + bool supports_metadata_current_db_schema() const override { return false; } bool supports_partitioned_data() const override { return false; } bool supports_dynamic_parameter_binding() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } @@ -156,6 +162,10 @@ class SnowflakeConnectionTest : public ::testing::Test, } } + // Supported, but we don't validate the values + void TestMetadataCurrentCatalog() { GTEST_SKIP(); } + void TestMetadataCurrentDbSchema() { GTEST_SKIP(); } + protected: SnowflakeQuirks quirks_; }; @@ -177,7 +187,41 @@ class SnowflakeStatementTest : public ::testing::Test, } } + void TestSqlIngestInterval() { GTEST_SKIP(); } + void TestSqlIngestDuration() { GTEST_SKIP(); } + + void TestSqlIngestColumnEscaping() { GTEST_SKIP(); } + protected: + void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, + enum ArrowTimeUnit unit, + const char* timezone) override { + switch (type) { + case 
NANOARROW_TYPE_TIMESTAMP: { + std::vector> expected; + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + expected = {std::nullopt, -42, 0, 42}; + break; + case NANOARROW_TIME_UNIT_MILLI: + expected = {std::nullopt, -42000, 0, 42000}; + break; + case NANOARROW_TIME_UNIT_MICRO: + expected = {std::nullopt, -42, 0, 42}; + break; + case NANOARROW_TIME_UNIT_NANO: + expected = {std::nullopt, -42, 0, 42}; + break; + } + ASSERT_NO_FATAL_FAILURE( + adbc_validation::CompareArray(values, expected)); + break; + } + default: + FAIL() << "ValidateIngestedTemporalData not implemented for type " << type; + } + } + SnowflakeQuirks quirks_; }; ADBCV_TEST_STATEMENT(SnowflakeStatementTest) diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver/sqlite/CMakeLists.txt index eb5a845..3914ee8 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/CMakeLists.txt @@ -56,6 +56,10 @@ foreach(LIB_TARGET ${ADBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING) endforeach() +include(CheckTypeSize) +check_type_size("time_t" SIZEOF_TIME_T) +add_definitions(-DSIZEOF_TIME_T=${SIZEOF_TIME_T}) + if(ADBC_TEST_LINKAGE STREQUAL "shared") set(TEST_LINK_LIBS adbc_driver_sqlite_shared) else() diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.c b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.c index 4124098..83cebec 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.c +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.c @@ -28,6 +28,7 @@ #include #include +#include "common/options.h" #include "common/utils.h" #include "statement_reader.h" #include "types.h" @@ -86,6 +87,26 @@ AdbcStatusCode SqliteDatabaseSetOption(struct AdbcDatabase* database, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteDatabaseSetOptionBytes(struct AdbcDatabase* database, + const char* key, const uint8_t* value, + size_t 
length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionDouble(struct AdbcDatabase* database, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) { const char* uri = maybe_uri ? maybe_uri : kDefaultUri; int rc = sqlite3_open_v2(uri, db, @@ -120,6 +141,33 @@ AdbcStatusCode ExecuteQuery(struct SqliteConnection* conn, const char* query, return ADBC_STATUS_OK; } +AdbcStatusCode SqliteDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionBytes(struct AdbcDatabase* database, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionDouble(struct AdbcDatabase* database, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + CHECK_DB_INIT(database, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { CHECK_DB_INIT(database, error); @@ -204,6 +252,27 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode 
SqliteConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { @@ -282,7 +351,8 @@ AdbcStatusCode SqliteConnectionGetInfoImpl(const uint32_t* info_codes, } // NOLINT(whitespace/indent) AdbcStatusCode SqliteConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, + size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { CHECK_CONN_INIT(connection, error); @@ -754,6 +824,34 @@ AdbcStatusCode SqliteConnectionGetObjects(struct AdbcConnection* connection, int return BatchToArrayStream(&array, &schema, out, error); } +AdbcStatusCode SqliteConnectionGetOption(struct AdbcConnection* connection, + const char* key, char* value, size_t* length, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, 
error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(connection, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -774,25 +872,26 @@ AdbcStatusCode SqliteConnectionGetTableSchema(struct AdbcConnection* connection, return ADBC_STATUS_INVALID_ARGUMENT; } - struct StringBuilder query = {0}; - if (StringBuilderInit(&query, /*initial_size=*/64) != 0) { - SetError(error, "[SQLite] Could not initiate StringBuilder"); + sqlite3_str* query = sqlite3_str_new(NULL); + if (sqlite3_str_errcode(query)) { + SetError(error, "[SQLite] %s", sqlite3_errmsg(conn->conn)); return ADBC_STATUS_INTERNAL; } - if (StringBuilderAppend(&query, "%s%s", "SELECT * FROM ", table_name) != 0) { - StringBuilderReset(&query); - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(query, "%s%Q", "SELECT * FROM ", table_name); + if (sqlite3_str_errcode(query)) { + SetError(error, "[SQLite] %s", sqlite3_errmsg(conn->conn)); + sqlite3_free(sqlite3_str_finish(query)); return ADBC_STATUS_INTERNAL; } sqlite3_stmt* stmt = NULL; - int rc = - sqlite3_prepare_v2(conn->conn, query.buffer, query.size, &stmt, /*pzTail=*/NULL); - StringBuilderReset(&query); + int rc = sqlite3_prepare_v2(conn->conn, sqlite3_str_value(query), + sqlite3_str_length(query), &stmt, /*pzTail=*/NULL); + sqlite3_free(sqlite3_str_finish(query)); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare query: %s", sqlite3_errmsg(conn->conn)); - return ADBC_STATUS_INTERNAL; + SetError(error, "[SQLite] GetTableSchema: %s", sqlite3_errmsg(conn->conn)); + return ADBC_STATUS_NOT_FOUND; } struct ArrowArrayStream stream = {0}; @@ -927,6 +1026,7 @@ AdbcStatusCode SqliteStatementRelease(struct AdbcStatement* 
statement, } if (stmt->query) free(stmt->query); AdbcSqliteBinderRelease(&stmt->binder); + if (stmt->target_catalog) free(stmt->target_catalog); if (stmt->target_table) free(stmt->target_table); if (rc != SQLITE_OK) { SetError(error, @@ -981,66 +1081,139 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, AdbcStatusCode code = ADBC_STATUS_OK; // Create statements for CREATE TABLE / INSERT - struct StringBuilder create_query = {0}; - struct StringBuilder insert_query = {0}; + sqlite3_str* create_query = NULL; + sqlite3_str* insert_query = NULL; + char* table = NULL; - if (StringBuilderInit(&create_query, /*initial_size=*/256) != 0) { - SetError(error, "[SQLite] Could not initiate StringBuilder"); - return ADBC_STATUS_INTERNAL; + create_query = sqlite3_str_new(NULL); + if (sqlite3_str_errcode(create_query)) { + SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; } - if (StringBuilderInit(&insert_query, /*initial_size=*/256) != 0) { - SetError(error, "[SQLite] Could not initiate StringBuilder"); - StringBuilderReset(&create_query); - return ADBC_STATUS_INTERNAL; + insert_query = sqlite3_str_new(NULL); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + + if (stmt->target_catalog != NULL && stmt->temporary != 0) { + SetError(error, "[SQLite] Cannot specify both %s and %s", + ADBC_INGEST_OPTION_TARGET_CATALOG, ADBC_INGEST_OPTION_TEMPORARY); + code = ADBC_STATUS_INVALID_STATE; + goto cleanup; + } + + if (stmt->target_catalog != NULL) { + table = sqlite3_mprintf("\"%w\" . \"%w\"", stmt->target_catalog, stmt->target_table); + } else if (stmt->temporary == 0) { + // If not temporary, explicitly target the main database + table = sqlite3_mprintf("main . \"%w\"", stmt->target_table); + } else { + // OK to be redundant (CREATE TEMP TABLE temp.foo) + table = sqlite3_mprintf("temp . 
\"%w\"", stmt->target_table); } - if (StringBuilderAppend(&create_query, "%s%s%s", "CREATE TABLE ", stmt->target_table, - " (") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + if (table == NULL) { + // Allocation failure code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s%s%s", "INSERT INTO ", stmt->target_table, - " VALUES (") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + if (stmt->temporary != 0) { + sqlite3_str_appendf(create_query, "CREATE TEMPORARY TABLE %s (", table); + } else { + sqlite3_str_appendf(create_query, "CREATE TABLE %s (", table); + } + if (sqlite3_str_errcode(create_query)) { + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - for (int i = 0; i < stmt->binder.schema.n_children; i++) { - if (i > 0) StringBuilderAppend(&create_query, "%s", ", "); - // XXX: should escape the column name too - if (StringBuilderAppend(&create_query, "%s", stmt->binder.schema.children[i]->name) != - 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); - code = ADBC_STATUS_INTERNAL; - goto cleanup; - } + sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + struct ArrowError arrow_error = {0}; + struct ArrowSchemaView view = {0}; + for (int i = 0; i < stmt->binder.schema.n_children; i++) { if (i > 0) { - if (StringBuilderAppend(&insert_query, "%s", ", ") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(create_query, "%s", ", "); + if (sqlite3_str_errcode(create_query)) { + SetError(error, "[SQLite] Failed to build CREATE: %s", + sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } } - if 
(StringBuilderAppend(&insert_query, "%s", "?") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name); + if (sqlite3_str_errcode(create_query)) { + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + + int status = + ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error); + if (status != 0) { + SetError(error, "[SQLite] Failed to parse schema for column %d: %s (%d): %s", i, + strerror(status), status, arrow_error.message); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + + switch (view.type) { + case NANOARROW_TYPE_UINT8: + case NANOARROW_TYPE_UINT16: + case NANOARROW_TYPE_UINT32: + case NANOARROW_TYPE_UINT64: + case NANOARROW_TYPE_INT8: + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_INT64: + sqlite3_str_appendf(create_query, " INTEGER"); + break; + case NANOARROW_TYPE_FLOAT: + case NANOARROW_TYPE_DOUBLE: + sqlite3_str_appendf(create_query, " REAL"); + break; + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_DATE32: + sqlite3_str_appendf(create_query, " TEXT"); + break; + case NANOARROW_TYPE_BINARY: + sqlite3_str_appendf(create_query, " BLOB"); + break; + default: + break; + } + + sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? 
", " : "")); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } } - if (StringBuilderAppend(&create_query, "%s", ")") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + + sqlite3_str_appendchar(create_query, 1, ')'); + if (sqlite3_str_errcode(create_query)) { + SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } - if (StringBuilderAppend(&insert_query, "%s", ")") != 0) { - SetError(error, "[SQLite] Call to StringBuilderAppend failed"); + sqlite3_str_appendchar(insert_query, 1, ')'); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; goto cleanup; } @@ -1048,25 +1221,29 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_stmt* create = NULL; if (!stmt->append) { // Create table - int rc = sqlite3_prepare_v2(stmt->conn, create_query.buffer, (int)create_query.size, - &create, /*pzTail=*/NULL); + int rc = + sqlite3_prepare_v2(stmt->conn, sqlite3_str_value(create_query), + sqlite3_str_length(create_query), &create, /*pzTail=*/NULL); if (rc == SQLITE_OK) { rc = sqlite3_step(create); } if (rc != SQLITE_OK && rc != SQLITE_DONE) { - SetError(error, "[SQLite] Failed to create table: %s (executed '%s')", - sqlite3_errmsg(stmt->conn), create_query.buffer); + SetError(error, "[SQLite] Failed to create table: %s (executed '%.*s')", + sqlite3_errmsg(stmt->conn), sqlite3_str_length(create_query), + sqlite3_str_value(create_query)); code = ADBC_STATUS_INTERNAL; } } if (code == ADBC_STATUS_OK) { - int rc = sqlite3_prepare_v2(stmt->conn, insert_query.buffer, (int)insert_query.size, - insert_statement, /*pzTail=*/NULL); + int rc = sqlite3_prepare_v2(stmt->conn, sqlite3_str_value(insert_query), + sqlite3_str_length(insert_query), 
insert_statement, + /*pzTail=*/NULL); if (rc != SQLITE_OK) { - SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%s')", - sqlite3_errmsg(stmt->conn), insert_query.buffer); + SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%.*s')", + sqlite3_errmsg(stmt->conn), sqlite3_str_length(insert_query), + sqlite3_str_value(insert_query)); code = ADBC_STATUS_INTERNAL; } } @@ -1074,8 +1251,9 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, sqlite3_finalize(create); cleanup: - StringBuilderReset(&create_query); - StringBuilderReset(&insert_query); + sqlite3_free(sqlite3_str_finish(create_query)); + sqlite3_free(sqlite3_str_finish(insert_query)); + if (table != NULL) sqlite3_free(table); return code; } @@ -1091,7 +1269,10 @@ AdbcStatusCode SqliteStatementExecuteIngest(struct SqliteStatement* stmt, AdbcStatusCode status = SqliteStatementInitIngest(stmt, &insert, error); int64_t row_count = 0; + int is_autocommit = sqlite3_get_autocommit(stmt->conn); if (status == ADBC_STATUS_OK) { + if (is_autocommit) sqlite3_exec(stmt->conn, "BEGIN TRANSACTION", 0, 0, 0); + while (1) { char finished = 0; status = @@ -1110,6 +1291,8 @@ AdbcStatusCode SqliteStatementExecuteIngest(struct SqliteStatement* stmt, } row_count++; } + + if (is_autocommit) sqlite3_exec(stmt->conn, "COMMIT", 0, 0, 0); } if (rows_affected) *rows_affected = row_count; @@ -1197,6 +1380,10 @@ AdbcStatusCode SqliteStatementSetSqlQuery(struct AdbcStatement* statement, free(stmt->query); stmt->query = NULL; } + if (stmt->target_catalog) { + free(stmt->target_catalog); + stmt->target_catalog = NULL; + } if (stmt->target_table) { free(stmt->target_table); stmt->target_table = NULL; @@ -1233,6 +1420,34 @@ AdbcStatusCode SqliteStatementBindStream(struct AdbcStatement* statement, return AdbcSqliteBinderSetArrayStream(&stmt->binder, stream, error); } +AdbcStatusCode SqliteStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + 
struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode SqliteStatementGetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t* value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode SqliteStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1284,6 +1499,20 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c stmt->target_table = (char*)malloc(len); strncpy(stmt->target_table, value, len); return ADBC_STATUS_OK; + } else if (strcmp(key, ADBC_INGEST_OPTION_TARGET_CATALOG) == 0) { + if (stmt->query) { + free(stmt->query); + stmt->query = NULL; + } + if (stmt->target_catalog) { + free(stmt->target_catalog); + stmt->target_catalog = NULL; + } + + size_t len = strlen(value) + 1; + stmt->target_catalog = (char*)malloc(len); + strncpy(stmt->target_catalog, value, len); + return ADBC_STATUS_OK; } else if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { stmt->append = 1; @@ -1294,6 +1523,16 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c return ADBC_STATUS_INVALID_ARGUMENT; } return ADBC_STATUS_OK; + } else if (strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) { + if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + stmt->temporary = 1; + } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + 
stmt->temporary = 0; + } else { + SetError(error, "[SQLite] Invalid statement option value %s=%s", key, value); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; } else if (strcmp(key, kStatementOptionBatchRows) == 0) { char* end = NULL; long batch_size = strtol(value, &end, /*base=*/10); // NOLINT(runtime/int) @@ -1322,6 +1561,27 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode SqliteStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode SqliteStatementSetOptionInt(struct AdbcStatement* statement, + const char* key, int64_t value, + struct AdbcError* error) { + CHECK_DB_INIT(statement, error); + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode SqliteStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -1340,7 +1600,7 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* } struct AdbcDriver* driver = (struct AdbcDriver*)raw_driver; - memset(driver, 0, sizeof(*driver)); + memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE); driver->DatabaseInit = SqliteDatabaseInit; driver->DatabaseNew = SqliteDatabaseNew; driver->DatabaseRelease = SqliteDatabaseRelease; @@ -1372,24 +1632,91 @@ AdbcStatusCode SqliteDriverInit(int version, void* raw_driver, struct AdbcError* // Public names -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { - return SqliteDatabaseNew(database, error); +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* 
database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOption(database, key, value, length, error); } -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, - const char* value, struct AdbcError* error) { - return SqliteDatabaseSetOption(database, key, value, error); +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return SqliteDatabaseGetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return SqliteDatabaseGetOptionDouble(database, key, value, error); } AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseInit(database, error); } +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return SqliteDatabaseNew(database, error); +} + AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, struct AdbcError* error) { return SqliteDatabaseRelease(database, error); } +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return SqliteDatabaseSetOption(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return SqliteDatabaseSetOptionBytes(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + 
return SqliteDatabaseSetOptionInt(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return SqliteDatabaseSetOptionDouble(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteConnectionGetOption(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteConnectionGetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return SqliteConnectionGetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, struct AdbcError* error) { return SqliteConnectionNew(connection, error); @@ -1400,6 +1727,24 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const return SqliteConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteConnectionSetOptionBytes(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, 
int64_t value, + struct AdbcError* error) { + return SqliteConnectionSetOptionInt(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return SqliteConnectionSetOptionDouble(connection, key, value, error); +} + AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, struct AdbcDatabase* database, struct AdbcError* error) { @@ -1412,7 +1757,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { return SqliteConnectionGetInfo(connection, info_codes, info_codes_length, out, error); @@ -1428,6 +1773,20 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d table_type, column_name, out, error); } +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, const char* catalog, const char* db_schema, const char* table_name, @@ -1462,6 +1821,11 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, return SqliteConnectionRollback(connection, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, 
struct AdbcStatement* statement, struct AdbcError* error) { @@ -1480,6 +1844,12 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, return SqliteStatementExecuteQuery(statement, out, rows_affected, error); } +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, struct AdbcError* error) { return SqliteStatementPrepare(statement, error); @@ -1508,6 +1878,29 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, return SqliteStatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return SqliteStatementGetOption(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return SqliteStatementGetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return SqliteStatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return SqliteStatementGetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -1519,6 +1912,23 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha return SqliteStatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const 
char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return SqliteStatementSetOptionBytes(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return SqliteStatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return SqliteStatementSetOptionDouble(statement, key, value, error); +} + AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc index 8a580cd..c95a3f1 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc @@ -43,6 +43,28 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { database, "uri", "file:Sqlite_Transactions?mode=memory&cache=shared", error); } + AdbcStatusCode DropTable(struct AdbcConnection* connection, const std::string& name, + struct AdbcError* error) const override { + adbc_validation::Handle statement; + RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error)); + + std::string query = "DROP TABLE IF EXISTS \"" + name + "\""; + RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error)); + RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + return AdbcStatementRelease(&statement.value, error); + } + + AdbcStatusCode DropTempTable(struct AdbcConnection* connection, const std::string& name, + struct AdbcError* error) const override { + adbc_validation::Handle statement; + RAISE_ADBC(AdbcStatementNew(connection, &statement.value, error)); + + std::string query = "DROP TABLE IF EXISTS temp 
. \"" + name + "\""; + RAISE_ADBC(AdbcStatementSetSqlQuery(&statement.value, query.c_str(), error)); + RAISE_ADBC(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); + return AdbcStatementRelease(&statement.value, error); + } + std::string BindParameter(int index) const override { return "?"; } ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override { @@ -59,6 +81,10 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_FLOAT: case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_DATE32: + case NANOARROW_TYPE_TIMESTAMP: + return NANOARROW_TYPE_STRING; default: return ingest_type; } @@ -71,7 +97,29 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { return ddl; } + bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } + bool supports_bulk_ingest_catalog() const override { return true; } + bool supports_bulk_ingest_temporary() const override { return true; } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } std::string catalog() const override { return "main"; } std::string db_schema() const override { return ""; } @@ -169,14 +217,93 @@ class SqliteStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } - void TestSqlIngestUInt64() { GTEST_SKIP() << 
"Cannot ingest UINT64 (out of range)"; } + void TestSqlIngestUInt64() { + std::vector> values = {std::nullopt, 0, INT64_MAX}; + return TestSqlIngestType(NANOARROW_TYPE_UINT64, values); + } + void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; } + void TestSqlIngestDuration() { + GTEST_SKIP() << "Cannot ingest DURATION (not implemented)"; + } + void TestSqlIngestInterval() { + GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; + } protected: + void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, + enum ArrowTimeUnit unit, + const char* timezone) override { + switch (type) { + case NANOARROW_TYPE_TIMESTAMP: { + std::vector> expected; + switch (unit) { + case (NANOARROW_TIME_UNIT_SECOND): + expected.insert(expected.end(), + {std::nullopt, "1969-12-31T23:59:18", "1970-01-01T00:00:00", + "1970-01-01T00:00:42"}); + break; + case (NANOARROW_TIME_UNIT_MILLI): + expected.insert(expected.end(), + {std::nullopt, "1969-12-31T23:59:59.958", + "1970-01-01T00:00:00.000", "1970-01-01T00:00:00.042"}); + break; + case (NANOARROW_TIME_UNIT_MICRO): + expected.insert(expected.end(), + {std::nullopt, "1969-12-31T23:59:59.999958", + "1970-01-01T00:00:00.000000", "1970-01-01T00:00:00.000042"}); + break; + case (NANOARROW_TIME_UNIT_NANO): + expected.insert( + expected.end(), + {std::nullopt, "1969-12-31T23:59:59.999999958", + "1970-01-01T00:00:00.000000000", "1970-01-01T00:00:00.000000042"}); + break; + } + ASSERT_NO_FATAL_FAILURE( + adbc_validation::CompareArray(values, expected)); + break; + } + default: + FAIL() << "ValidateIngestedTemporalData not implemented for type " << type; + } + } + SqliteQuirks quirks_; }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST_F(SqliteStatementTest, SqlIngestNameEscaping) { + ASSERT_THAT(quirks()->DropTable(&connection, "test-table", &error), + adbc_validation::IsOkStatus(&error)); + + std::string table = "test-table"; + adbc_validation::Handle schema; + adbc_validation::Handle 
array; + struct ArrowError na_error; + ASSERT_THAT( + adbc_validation::MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64}, + {"create", NANOARROW_TYPE_STRING}}), + adbc_validation::IsOkErrno()); + ASSERT_THAT((adbc_validation::MakeBatch( + &schema.value, &array.value, &na_error, {42, -42, std::nullopt}, + {"foo", std::nullopt, ""})), + adbc_validation::IsOkErrno(&na_error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + table.c_str(), &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + adbc_validation::IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_EQ(3, rows_affected); +} + // -- SQLite Specific Tests ------------------------------------------ constexpr size_t kInferRows = 16; diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c index 504a4d8..08bd27d 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c @@ -17,9 +17,12 @@ #include "statement_reader.h" +#include #include #include +#include #include +#include #include #include @@ -89,6 +92,160 @@ AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, memset(values, 0, sizeof(*values)); return AdbcSqliteBinderSet(binder, error); } + +#define SECONDS_PER_DAY 86400 + +/* + Allocates to buf on success. Caller is responsible for freeing. + On failure sets error and contents of buf are undefined. 
+*/ +static AdbcStatusCode ArrowDate32ToIsoString(int32_t value, char** buf, + struct AdbcError* error) { + int strlen = 10; + +#if SIZEOF_TIME_T < 8 + if ((value > INT32_MAX / SECONDS_PER_DAY) || (value < INT32_MIN / SECONDS_PER_DAY)) { + SetError(error, "Date %" PRId32 " exceeds platform time_t bounds", value); + + return ADBC_STATUS_INVALID_ARGUMENT; + } + time_t time = (time_t)(value * SECONDS_PER_DAY); +#else + time_t time = value * SECONDS_PER_DAY; +#endif + + struct tm broken_down_time; + +#if defined(_WIN32) + if (gmtime_s(&broken_down_time, &time) != 0) { + SetError(error, "Could not convert date %" PRId32 " to broken down time", value); + + return ADBC_STATUS_INVALID_ARGUMENT; + } +#else + if (gmtime_r(&time, &broken_down_time) != &broken_down_time) { + SetError(error, "Could not convert date %" PRId32 " to broken down time", value); + + return ADBC_STATUS_INVALID_ARGUMENT; + } +#endif + + char* tsstr = malloc(strlen + 1); + if (tsstr == NULL) { + return ADBC_STATUS_IO; + } + + if (strftime(tsstr, strlen + 1, "%Y-%m-%d", &broken_down_time) == 0) { + SetError(error, "Call to strftime for date %" PRId32 " with failed", value); + free(tsstr); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + *buf = tsstr; + return ADBC_STATUS_OK; +} + +/* + Allocates to buf on success. Caller is responsible for freeing. + On failure sets error and contents of buf are undefined. 
+*/ +static AdbcStatusCode ArrowTimestampToIsoString(int64_t value, enum ArrowTimeUnit unit, + char** buf, struct AdbcError* error) { + int scale = 1; + int strlen = 20; + int rem = 0; + + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + break; + case NANOARROW_TIME_UNIT_MILLI: + scale = 1000; + strlen = 24; + break; + case NANOARROW_TIME_UNIT_MICRO: + scale = 1000000; + strlen = 27; + break; + case NANOARROW_TIME_UNIT_NANO: + scale = 1000000000; + strlen = 30; + break; + } + + rem = value % scale; + if (rem < 0) { + value -= scale; + rem = scale + rem; + } + + const int64_t seconds = value / scale; + +#if SIZEOF_TIME_T < 8 + if ((seconds > INT32_MAX) || (seconds < INT32_MIN)) { + SetError(error, "Timestamp %" PRId64 " with unit %d exceeds platform time_t bounds", + value, unit); + + return ADBC_STATUS_INVALID_ARGUMENT; + } + const time_t time = (time_t)seconds; +#else + const time_t time = seconds; +#endif + + struct tm broken_down_time; + +#if defined(_WIN32) + if (gmtime_s(&broken_down_time, &time) != 0) { + SetError(error, + "Could not convert timestamp %" PRId64 " with unit %d to broken down time", + value, unit); + + return ADBC_STATUS_INVALID_ARGUMENT; + } +#else + if (gmtime_r(&time, &broken_down_time) != &broken_down_time) { + SetError(error, + "Could not convert timestamp %" PRId64 " with unit %d to broken down time", + value, unit); + + return ADBC_STATUS_INVALID_ARGUMENT; + } +#endif + + char* tsstr = malloc(strlen + 1); + if (tsstr == NULL) { + return ADBC_STATUS_IO; + } + + if (strftime(tsstr, strlen, "%Y-%m-%dT%H:%M:%S", &broken_down_time) == 0) { + SetError(error, "Call to strftime for timestamp %" PRId64 " with unit %d failed", + value, unit); + free(tsstr); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + assert(rem >= 0); + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + break; + case NANOARROW_TIME_UNIT_MILLI: + tsstr[19] = '.'; + snprintf(tsstr + 20, strlen - 20, "%03d", rem % 1000u); + break; + case NANOARROW_TIME_UNIT_MICRO: + 
tsstr[19] = '.'; + snprintf(tsstr + 20, strlen - 20, "%06d", rem % 1000000u); + break; + case NANOARROW_TIME_UNIT_NANO: + tsstr[19] = '.'; + snprintf(tsstr + 20, strlen - 20, "%09d", rem % 1000000000u); + break; + } + + *buf = tsstr; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3* conn, sqlite3_stmt* stmt, char* finished, struct AdbcError* error) { @@ -195,6 +352,45 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 SQLITE_STATIC); break; } + case NANOARROW_TYPE_DATE32: { + int64_t value = + ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); + char* tsstr; + + if ((value > INT32_MAX) || (value < INT32_MIN)) { + SetError(error, + "Column %d has value %" PRId64 + " which exceeds the expected range " + "for an Arrow DATE32 type", + col, value); + return ADBC_STATUS_INVALID_DATA; + } + + RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error)); + // SQLITE_TRANSIENT ensures the value is copied during bind + status = + sqlite3_bind_text(stmt, col + 1, tsstr, strlen(tsstr), SQLITE_TRANSIENT); + + free(tsstr); + break; + } + case NANOARROW_TYPE_TIMESTAMP: { + struct ArrowSchemaView bind_schema_view; + RAISE_ADBC(ArrowSchemaViewInit(&bind_schema_view, binder->schema.children[col], + &arrow_error)); + enum ArrowTimeUnit unit = bind_schema_view.time_unit; + int64_t value = + ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); + + char* tsstr; + RAISE_ADBC(ArrowTimestampToIsoString(value, unit, &tsstr, error)); + + // SQLITE_TRANSIENT ensures the value is copied during bind + status = + sqlite3_bind_text(stmt, col + 1, tsstr, strlen(tsstr), SQLITE_TRANSIENT); + free((char*)tsstr); + break; + } default: SetError(error, "Column %d has unsupported type %s", col, ArrowTypeString(binder->types[col])); @@ -255,7 +451,7 @@ void StatementReaderSetError(struct StatementReader* reader) { const char* msg = sqlite3_errmsg(reader->db); // 
Reset here so that we don't get an error again in StatementRelease (void)sqlite3_reset(reader->stmt); - strncpy(reader->error.message, msg, sizeof(reader->error.message)); + strncpy(reader->error.message, msg, sizeof(reader->error.message) - 1); reader->error.message[sizeof(reader->error.message) - 1] = '\0'; } @@ -398,7 +594,8 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out reader->done = 1; status = EIO; if (error.release) { - strncpy(reader->error.message, error.message, sizeof(reader->error.message)); + strncpy(reader->error.message, error.message, + sizeof(reader->error.message) - 1); reader->error.message[sizeof(reader->error.message) - 1] = '\0'; error.release(&error); } @@ -597,7 +794,10 @@ AdbcStatusCode StatementReaderAppendInt64ToBinary(struct ArrowBuffer* offsets, int written = 0; while (1) { written = snprintf(output, buffer_size, "%" PRId64, value); - if (written >= buffer_size) { + if (written < 0) { + SetError(error, "Encoding error when upcasting double to string"); + return ADBC_STATUS_INTERNAL; + } else if (((size_t)written) >= buffer_size) { // Truncated, resize and try again // Check for overflow - presumably this can never happen...? if (UINT_MAX - buffer_size < buffer_size) { @@ -627,7 +827,10 @@ AdbcStatusCode StatementReaderAppendDoubleToBinary(struct ArrowBuffer* offsets, int written = 0; while (1) { written = snprintf(output, buffer_size, "%e", value); - if (written >= buffer_size) { + if (written < 0) { + SetError(error, "Encoding error when upcasting double to string"); + return ADBC_STATUS_INTERNAL; + } else if (((size_t)written) >= buffer_size) { // Truncated, resize and try again // Check for overflow - presumably this can never happen...? 
if (UINT_MAX - buffer_size < buffer_size) { @@ -855,7 +1058,7 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt, if (status == ADBC_STATUS_OK && !reader->done) { int64_t num_rows = 0; - while (num_rows < batch_size) { + while (((size_t)num_rows) < batch_size) { int rc = sqlite3_step(stmt); if (rc == SQLITE_DONE) { if (!binder) { diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/types.h b/3rd_party/apache-arrow-adbc/c/driver/sqlite/types.h index cd46f4f..c9e57e3 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/types.h +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/types.h @@ -50,8 +50,10 @@ struct SqliteStatement { struct AdbcSqliteBinder binder; // -- Ingest state ---------------------------------------- + char* target_catalog; char* target_table; char append; + char temporary; // -- Query options --------------------------------------- int batch_size; diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt index dd28470..6fb51d9 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt @@ -55,13 +55,28 @@ if(ADBC_BUILD_TESTS) driver-manager SOURCES adbc_driver_manager_test.cc - ../validation/adbc_validation.cc - ../validation/adbc_validation_util.cc EXTRA_LINK_LIBS adbc_driver_common + adbc_validation nanoarrow ${TEST_LINK_LIBS}) target_compile_features(adbc-driver-manager-test PRIVATE cxx_std_17) target_include_directories(adbc-driver-manager-test SYSTEM PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) + + add_test_case(version_100_compatibility_test + PREFIX + adbc + EXTRA_LABELS + driver-manager + SOURCES + adbc_version_100.c + adbc_version_100_compatibility_test.cc + EXTRA_LINK_LIBS + adbc_validation_util + nanoarrow + ${TEST_LINK_LIBS}) + target_compile_features(adbc-version-100-compatibility-test PRIVATE cxx_std_17) + 
target_include_directories(adbc-version-100-compatibility-test SYSTEM + PRIVATE ${REPOSITORY_ROOT}/c/vendor/nanoarrow/) endif() diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc index c63560a..c28bea9 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -90,17 +92,141 @@ void SetError(struct AdbcError* error, const std::string& message) { // Driver state -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); +/// A driver DLL. +struct ManagedLibrary { + ManagedLibrary() : handle(nullptr) {} + ManagedLibrary(ManagedLibrary&& other) : handle(other.handle) { + other.handle = nullptr; + } + ManagedLibrary(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(const ManagedLibrary&) = delete; + ManagedLibrary& operator=(ManagedLibrary&& other) noexcept { + this->handle = other.handle; + other.handle = nullptr; + return *this; + } + + ~ManagedLibrary() { Release(); } + + void Release() { + // TODO(apache/arrow-adbc#204): causes tests to segfault + // Need to refcount the driver DLL; also, errors may retain a reference to + // release() from the DLL - how to handle this? 
+ } + + AdbcStatusCode Load(const char* library, struct AdbcError* error) { + std::string error_message; +#if defined(_WIN32) + HMODULE handle = LoadLibraryExA(library, NULL, 0); + if (!handle) { + error_message += library; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + + std::string full_driver_name = library; + full_driver_name += ".dll"; + handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); + if (!handle) { + error_message += '\n'; + error_message += full_driver_name; + error_message += ": LoadLibraryExA() failed: "; + GetWinError(&error_message); + } + } + if (!handle) { + SetError(error, error_message); + return ADBC_STATUS_INTERNAL; + } else { + this->handle = handle; + } +#else + static const std::string kPlatformLibraryPrefix = "lib"; +#if defined(__APPLE__) + static const std::string kPlatformLibrarySuffix = ".dylib"; +#else + static const std::string kPlatformLibrarySuffix = ".so"; +#endif // defined(__APPLE__) + + void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message = "dlopen() failed: "; + error_message += dlerror(); + + // If applicable, append the shared library prefix/extension and + // try again (this way you don't have to hardcode driver names by + // platform in the application) + const std::string driver_str = library; + + std::string full_driver_name; + if (driver_str.size() < kPlatformLibraryPrefix.size() || + driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != + 0) { + full_driver_name += kPlatformLibraryPrefix; + } + full_driver_name += library; + if (driver_str.size() < kPlatformLibrarySuffix.size() || + driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix.size(), + kPlatformLibrarySuffix) != 0) { + full_driver_name += kPlatformLibrarySuffix; + } + handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!handle) { + error_message += "\ndlopen() failed: "; + error_message += 
dlerror(); + } + } + if (handle) { + this->handle = handle; + } else { + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + return ADBC_STATUS_OK; + } + + AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) { +#if defined(_WIN32) + void* load_handle = reinterpret_cast(GetProcAddress(handle, name)); + if (!load_handle) { + std::string message = "GetProcAddress("; + message += name; + message += ") failed: "; + GetWinError(&message); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#else + void* load_handle = dlsym(handle, name); + if (!load_handle) { + std::string message = "dlsym("; + message += name; + message += ") failed: "; + message += dlerror(); + SetError(error, message); + return ADBC_STATUS_INTERNAL; + } +#endif // defined(_WIN32) + *func = load_handle; + return ADBC_STATUS_OK; + } #if defined(_WIN32) // The loaded DLL HMODULE handle; +#else + void* handle; #endif // defined(_WIN32) }; +/// Hold the driver DLL and the driver release callback in the driver struct. +struct ManagerDriverState { + // The original release callback + AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error); + + ManagedLibrary handle; +}; + /// Unload the driver DLL. 
static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) { AdbcStatusCode status = ADBC_STATUS_OK; @@ -112,35 +238,132 @@ static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* if (state->driver_release) { status = state->driver_release(driver, error); } - -#if defined(_WIN32) - // TODO(apache/arrow-adbc#204): causes tests to segfault - // if (!FreeLibrary(state->handle)) { - // std::string message = "FreeLibrary() failed: "; - // GetWinError(&message); - // SetError(error, message); - // } -#endif // defined(_WIN32) + state->handle.Release(); driver->private_manager = nullptr; delete state; return status; } +// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream + +struct ErrorArrayStream { + struct ArrowArrayStream stream; + struct AdbcDriver* private_driver; +}; + +void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return; + + auto* private_data = reinterpret_cast(stream->private_data); + private_data->stream.release(&private_data->stream); + delete private_data; + std::memset(stream, 0, sizeof(*stream)); +} + +const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_last_error(&private_data->stream); +} + +int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_next(&private_data->stream, array); +} + +int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream, + struct ArrowSchema* schema) { + if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL; + auto* 
private_data = reinterpret_cast(stream->private_data); + return private_data->stream.get_schema(&private_data->stream, schema); +} + // Default stubs +int ErrorGetDetailCount(const struct AdbcError* error) { return 0; } + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) { + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + return nullptr; +} + +void ErrorArrayStreamInit(struct ArrowArrayStream* out, + struct AdbcDriver* private_driver) { + if (!out || !out->release || + // Don't bother wrapping if driver didn't claim support + private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { + return; + } + struct ErrorArrayStream* private_data = new ErrorArrayStream; + private_data->stream = *out; + private_data->private_driver = private_driver; + out->get_last_error = ErrorArrayStreamGetLastError; + out->get_next = ErrorArrayStreamGetNext; + out->get_schema = ErrorArrayStreamGetSchema; + out->release = ErrorArrayStreamRelease; + out->private_data = private_data; +} + +AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode DatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode 
DatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, uint32_t* info_codes, - size_t info_codes_length, struct ArrowArrayStream* out, - struct AdbcError* error) { +AdbcStatusCode ConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -150,6 +373,39 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, co return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct 
AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection*, const char*, const char*, + const char*, char, struct ArrowArrayStream*, + struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, const char*, const char*, struct ArrowSchema*, struct AdbcError* error) { @@ -178,11 +434,31 @@ AdbcStatusCode ConnectionSetOption(struct AdbcConnection*, const char*, const ch return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection*, const char*, + const uint8_t*, size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection* connection, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*, struct ArrowSchema*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementCancel(struct AdbcStatement* statement, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcPartitions* partitions, @@ -191,6 +467,33 @@ AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement, return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + 
+AdbcStatusCode StatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement* statement, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + +AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement* statement, const char* key, + double* value, struct AdbcError* error) { + return ADBC_STATUS_NOT_FOUND; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement, struct ArrowSchema* schema, struct AdbcError* error) { @@ -206,6 +509,21 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement*, const char*, const char return ADBC_STATUS_NOT_IMPLEMENTED; } +AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement*, const char*, const uint8_t*, + size_t, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement* statement, const char* key, + double value, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement*, const char*, struct AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -219,20 +537,134 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*, /// Temporary state while the database is being configured. 
struct TempDatabase { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; std::string driver; - // Default name (see adbc.h) - std::string entrypoint = "AdbcDriverInit"; + std::string entrypoint; AdbcDriverInitFunc init_func = nullptr; }; /// Temporary state while the database is being configured. struct TempConnection { std::unordered_map options; + std::unordered_map bytes_options; + std::unordered_map int_options; + std::unordered_map double_options; }; + +static const char kDefaultEntrypoint[] = "AdbcDriverInit"; } // namespace +// Other helpers (intentionally not in an anonymous namespace so they can be tested) + +ADBC_EXPORT +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { + /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit + /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit + /// - proprietary_driver.dll -> AdbcProprietaryDriverInit + + // Potential path -> filename + // Treat both \ and / as directory separators on all platforms for simplicity + std::string filename; + { + size_t pos = driver.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = driver.substr(pos + 1); + } else { + filename = driver; + } + } + + // Remove all extensions + { + size_t pos = filename.find('.'); + if (pos != std::string::npos) { + filename = filename.substr(0, pos); + } + } + + // Remove lib prefix + // https://stackoverflow.com/q/1878001/262727 + if (filename.rfind("lib", 0) == 0) { + filename = filename.substr(3); + } + + // Split on underscores, hyphens + // Capitalize and join + std::string entrypoint; + entrypoint.reserve(filename.size()); + size_t pos = 0; + while (pos < filename.size()) { + size_t prev = pos; + pos = filename.find_first_of("-_", pos); + // if pos == npos this is the entire filename + std::string token = filename.substr(prev, pos - prev); + // capitalize first letter + token[0] = std::toupper(static_cast(token[0])); + 
+ entrypoint += token; + + if (pos != std::string::npos) { + pos++; + } + } + + if (entrypoint.rfind("Adbc", 0) != 0) { + entrypoint = "Adbc" + entrypoint; + } + entrypoint += "Init"; + + return entrypoint; +} + // Direct implementations of API methods +int AdbcErrorGetDetailCount(const struct AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetailCount(error); + } + return 0; +} + +struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && + error->private_driver) { + return error->private_driver->ErrorGetDetail(error, index); + } + return {nullptr, nullptr, 0}; +} + +const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, + AdbcStatusCode* status) { + if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { + return nullptr; + } + auto* private_data = reinterpret_cast(stream->private_data); + auto* error = + private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); + if (error) { + const_cast(error)->private_driver = private_data->private_driver; + } + return error; +} + +#define INIT_ERROR(ERROR, SOURCE) \ + if ((ERROR) != nullptr && \ + (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ + (ERROR)->private_driver = (SOURCE)->private_driver; \ + } + +#define WRAP_STREAM(EXPR, OUT, SOURCE) \ + if (!(OUT)) { \ + /* Happens for ExecuteQuery where out is optional */ \ + return EXPR; \ + } \ + AdbcStatusCode status_code = EXPR; \ + ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ + return status_code; + AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { // Allocate a temporary structure to store options pre-Init database->private_data = new TempDatabase(); @@ -240,9 +672,93 @@ AdbcStatusCode 
AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOption(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const std::string* result = nullptr; + if (std::strcmp(key, "driver") == 0) { + result = &args->driver; + } else if (std::strcmp(key, "entrypoint") == 0) { + result = &args->entrypoint; + } else { + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + result = &it->second; + } + + if (*length <= result->size() + 1) { + // Enough space + std::memcpy(value, result->c_str(), result->size() + 1); + } + *length = result->size() + 1; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, + error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + const std::string& result = it->second; + + if (*length <= result.size()) { + // Enough space + std::memcpy(value, result.c_str(), result.size()); + } + *length = result.size(); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionInt(database, key, value, error); + } + const 
auto* args = reinterpret_cast(database->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); + } + const auto* args = reinterpret_cast(database->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { + INIT_ERROR(error, database); return database->private_driver->DatabaseSetOption(database, key, value, error); } @@ -257,6 +773,44 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* return ADBC_STATUS_OK; } +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, + error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionInt(database, key, value, error); + } + + TempDatabase* args = 
reinterpret_cast(database->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + if (database->private_driver) { + INIT_ERROR(error, database); + return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); + } + + TempDatabase* args = reinterpret_cast(database->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; +} + AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase* database, AdbcDriverInitFunc init_func, struct AdbcError* error) { @@ -288,11 +842,14 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, + status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else { + } else if (!args->entrypoint.empty()) { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_1_0, database->private_driver, error); + } else { + status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, + database->private_driver, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease @@ -313,25 +870,49 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - for (const auto& option : args->options) { + auto options = std::move(args->options); + auto bytes_options = std::move(args->bytes_options); + auto int_options = std::move(args->int_options); + auto double_options = std::move(args->double_options); + delete args; + + INIT_ERROR(error, 
database); + for (const auto& option : options) { status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) { - delete args; - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : bytes_options) { + status = database->private_driver->DatabaseSetOptionBytes( + database, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : int_options) { + status = database->private_driver->DatabaseSetOptionInt( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + for (const auto& option : double_options) { + status = database->private_driver->DatabaseSetOptionDouble( + database, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) break; + } + + if (status != ADBC_STATUS_OK) { + // Release the database + std::ignore = database->private_driver->DatabaseRelease(database, error); + if (database->private_driver->release) { + database->private_driver->release(database->private_driver, error); } + delete database->private_driver; + database->private_driver = nullptr; + // Should be redundant, but ensure that AdbcDatabaseRelease + // below doesn't think that it contains a TempDatabase + database->private_data = nullptr; + return status; } - delete args; return database->private_driver->DatabaseInit(database, error); } @@ -346,6 +927,7 @@ 
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, database); auto status = database->private_driver->DatabaseRelease(database, error); if (database->private_driver->release) { database->private_driver->release(database->private_driver, error); @@ -356,23 +938,35 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, return status; } +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionCancel(connection, error); +} + AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, - uint32_t* info_codes, size_t info_codes_length, + const uint32_t* info_codes, size_t info_codes_length, struct ArrowArrayStream* out, struct AdbcError* error) { if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetInfo( + connection, info_codes, info_codes_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +978,132 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + 
INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_types, + column_name, stream, error), + stream, connection); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.c_str(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOption(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->bytes_options.find(key); + if (it == args->options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + if (*length >= it->second.size() + 1) { + std::memcpy(value, it->second.data(), it->second.size() + 1); + } + *length = it->second.size() + 1; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, + length, error); +} + 
+AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->int_options.find(key); + if (it == args->int_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, get the saved option + const auto* args = reinterpret_cast(connection->private_data); + const auto it = args->double_options.find(key); + if (it == args->double_options.end()) { + return ADBC_STATUS_NOT_FOUND; + } + *value = it->second; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate == 1, out, error), + 
out, connection); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetStatisticNames(connection, out, error), + out, connection); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,6 +1114,7 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionGetTableSchema( connection, catalog, db_schema, table_name, schema, error); } @@ -407,7 +1125,10 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + INIT_ERROR(error, connection); + WRAP_STREAM( + connection->private_driver->ConnectionGetTableTypes(connection, stream, error), + stream, connection); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -423,6 +1144,11 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, TempConnection* args = reinterpret_cast(connection->private_data); connection->private_data = nullptr; std::unordered_map options = std::move(args->options); + std::unordered_map bytes_options = + std::move(args->bytes_options); + std::unordered_map int_options = std::move(args->int_options); + std::unordered_map double_options = + std::move(args->double_options); delete args; auto status = database->private_driver->ConnectionNew(connection, error); @@ -434,6 +1160,24 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, connection, option.first.c_str(), option.second.c_str(), error); if (status 
!= ADBC_STATUS_OK) return status; } + for (const auto& option : bytes_options) { + status = database->private_driver->ConnectionSetOptionBytes( + connection, option.first.c_str(), + reinterpret_cast(option.second.data()), option.second.size(), + error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : int_options) { + status = database->private_driver->ConnectionSetOptionInt( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + for (const auto& option : double_options) { + status = database->private_driver->ConnectionSetOptionDouble( + connection, option.first.c_str(), option.second, error); + if (status != ADBC_STATUS_OK) return status; + } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionInit(connection, database, error); } @@ -455,8 +1199,10 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + INIT_ERROR(error, connection); + WRAP_STREAM(connection->private_driver->ConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error), + out, connection); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,6 +1216,7 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; @@ -480,6 +1227,7 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionRollback(connection, error); } @@ -495,15 
+1243,71 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } + INIT_ERROR(error, connection); return connection->private_driver->ConnectionSetOption(connection, key, value, error); } +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->bytes_options[key] = std::string(reinterpret_cast(value), length); + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, + length, error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + TempConnection* args = reinterpret_cast(connection->private_data); + args->int_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionInt(connection, key, value, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + if (!connection->private_data) { + SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); + return ADBC_STATUS_INVALID_STATE; + } + if (!connection->private_driver) { + // Init not yet called, save the option + 
TempConnection* args = reinterpret_cast(connection->private_data); + args->double_options[key] = value; + return ADBC_STATUS_OK; + } + INIT_ERROR(error, connection); + return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, + error); +} + AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, struct ArrowArray* values, struct ArrowSchema* schema, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBind(statement, values, schema, error); } @@ -513,9 +1317,19 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementBindStream(statement, stream, error); } +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementCancel(statement, error); +} + // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, ArrowSchema* schema, @@ -525,6 +1339,7 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementExecutePartitions( statement, schema, partitions, rows_affected, error); } @@ -536,8 +1351,62 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + INIT_ERROR(error, statement); + 
WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, + rows_affected, error), + out, statement); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOption(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementGetOptionDouble(statement, key, value, + error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,6 +1415,7 @@ AdbcStatusCode 
AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementGetParameterSchema(statement, schema, error); } @@ -555,6 +1425,7 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, connection); auto status = connection->private_driver->StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; @@ -565,6 +1436,7 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementPrepare(statement, error); } @@ -573,6 +1445,7 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); auto status = statement->private_driver->StatementRelease(statement, error); statement->private_driver = nullptr; return status; @@ -583,14 +1456,47 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetOption(statement, key, value, error); } +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + 
if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionInt(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + if (!statement->private_driver) { + return ADBC_STATUS_INVALID_STATE; + } + INIT_ERROR(error, statement); + return statement->private_driver->StatementSetOptionDouble(statement, key, value, + error); +} + AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, const char* query, struct AdbcError* error) { if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSqlQuery(statement, query, error); } @@ -600,39 +1506,36 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } + INIT_ERROR(error, statement); return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); } const char* AdbcStatusCodeMessage(AdbcStatusCode code) { -#define STRINGIFY(s) #s -#define STRINGIFY_VALUE(s) STRINGIFY(s) -#define CASE(CONSTANT) \ - case CONSTANT: \ - return #CONSTANT " (" STRINGIFY_VALUE(CONSTANT) ")"; +#define CASE(CONSTANT) \ + case ADBC_STATUS_##CONSTANT: \ + return #CONSTANT; switch (code) { - CASE(ADBC_STATUS_OK); - CASE(ADBC_STATUS_UNKNOWN); - CASE(ADBC_STATUS_NOT_IMPLEMENTED); - CASE(ADBC_STATUS_NOT_FOUND); - CASE(ADBC_STATUS_ALREADY_EXISTS); - CASE(ADBC_STATUS_INVALID_ARGUMENT); - CASE(ADBC_STATUS_INVALID_STATE); - CASE(ADBC_STATUS_INVALID_DATA); - CASE(ADBC_STATUS_INTEGRITY); - CASE(ADBC_STATUS_INTERNAL); - CASE(ADBC_STATUS_IO); - CASE(ADBC_STATUS_CANCELLED); - CASE(ADBC_STATUS_TIMEOUT); - CASE(ADBC_STATUS_UNAUTHENTICATED); - CASE(ADBC_STATUS_UNAUTHORIZED); + CASE(OK); + CASE(UNKNOWN); + 
CASE(NOT_IMPLEMENTED); + CASE(NOT_FOUND); + CASE(ALREADY_EXISTS); + CASE(INVALID_ARGUMENT); + CASE(INVALID_STATE); + CASE(INVALID_DATA); + CASE(INTEGRITY); + CASE(INTERNAL); + CASE(IO); + CASE(CANCELLED); + CASE(TIMEOUT); + CASE(UNAUTHENTICATED); + CASE(UNAUTHORIZED); default: return "(invalid code)"; } #undef CASE -#undef STRINGIFY_VALUE -#undef STRINGIFY } AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, @@ -640,137 +1543,80 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - auto* driver = reinterpret_cast(raw_driver); - - if (!entrypoint) { - // Default entrypoint (see adbc.h) - entrypoint = "AdbcDriverInit"; + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return ADBC_STATUS_NOT_IMPLEMENTED; } -#if defined(_WIN32) - - HMODULE handle = LoadLibraryExA(driver_name, NULL, 0); - if (!handle) { - error_message += driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = driver_name; - full_driver_name += ".lib"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; } + auto* driver = reinterpret_cast(raw_driver); - void* load_handle = reinterpret_cast(GetProcAddress(handle, entrypoint)); - init_func = reinterpret_cast(load_handle); - if (!init_func) { - std::string message = 
"GetProcAddress("; - message += entrypoint; - message += ") failed: "; - GetWinError(&message); - if (!FreeLibrary(handle)) { - message += "\nFreeLibrary() failed: "; - GetWinError(&message); - } - SetError(error, message); - return ADBC_STATUS_INTERNAL; + ManagedLibrary library; + AdbcStatusCode status = library.Load(driver_name, error); + if (status != ADBC_STATUS_OK) { + // AdbcDatabaseInit tries to call this if set + driver->release = nullptr; + return status; } -#else - -#if defined(__APPLE__) - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibraryPrefix = "lib"; - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = driver_name; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != - 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += driver_name; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); + void* load_handle = nullptr; + if (entrypoint) { + status = library.Lookup(entrypoint, &load_handle, error); + } else { + auto name = 
AdbcDriverManagerDefaultEntrypoint(driver_name); + status = library.Lookup(name.c_str(), &load_handle, error); + if (status != ADBC_STATUS_OK) { + status = library.Lookup(kDefaultEntrypoint, &load_handle, error); } } - if (!handle) { - SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return ADBC_STATUS_INTERNAL; - } - void* load_handle = dlsym(handle, entrypoint); - if (!load_handle) { - std::string message = "dlsym("; - message += entrypoint; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; + if (status != ADBC_STATUS_OK) { + library.Release(); + return status; } init_func = reinterpret_cast(load_handle); -#endif // defined(_WIN32) - - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); if (status == ADBC_STATUS_OK) { ManagerDriverState* state = new ManagerDriverState; state->driver_release = driver->release; -#if defined(_WIN32) - state->handle = handle; -#endif // defined(_WIN32) + state->handle = std::move(library); driver->release = &ReleaseDriver; driver->private_manager = state; } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); - } -#endif // defined(_WIN32) + library.Release(); } return status; } AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void* raw_driver, struct AdbcError* error) { + constexpr std::array kSupportedVersions = { + ADBC_VERSION_1_1_0, + ADBC_VERSION_1_0_0, + }; + + if (!raw_driver) { + SetError(error, "Must provide non-NULL raw_driver"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + switch (version) { + case ADBC_VERSION_1_0_0: + case ADBC_VERSION_1_1_0: + break; + default: + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); + return 
ADBC_STATUS_NOT_IMPLEMENTED; + } + #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -781,12 +1627,20 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } - auto result = init_func(version, raw_driver, error); + // Starting from the passed version, try each (older) version in + // succession with the underlying driver until we find one that's + // accepted. + AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; + for (const int try_version : kSupportedVersions) { + if (try_version > version) continue; + result = init_func(try_version, raw_driver, error); + if (result != ADBC_STATUS_NOT_IMPLEMENTED) break; + } if (result != ADBC_STATUS_OK) { return result; } - if (version == ADBC_VERSION_1_0_0) { + if (version >= ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); CHECK_REQUIRED(driver, DatabaseNew); CHECK_REQUIRED(driver, DatabaseInit); @@ -816,6 +1670,41 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers FILL_DEFAULT(driver, StatementSetSqlQuery); FILL_DEFAULT(driver, StatementSetSubstraitPlan); } + if (version >= ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + FILL_DEFAULT(driver, ErrorGetDetailCount); + FILL_DEFAULT(driver, ErrorGetDetail); + FILL_DEFAULT(driver, ErrorFromArrayStream); + + FILL_DEFAULT(driver, DatabaseGetOption); + FILL_DEFAULT(driver, DatabaseGetOptionBytes); + FILL_DEFAULT(driver, DatabaseGetOptionDouble); + FILL_DEFAULT(driver, DatabaseGetOptionInt); + FILL_DEFAULT(driver, DatabaseSetOptionBytes); + FILL_DEFAULT(driver, DatabaseSetOptionDouble); + FILL_DEFAULT(driver, DatabaseSetOptionInt); + + FILL_DEFAULT(driver, ConnectionCancel); + FILL_DEFAULT(driver, ConnectionGetOption); + FILL_DEFAULT(driver, ConnectionGetOptionBytes); + FILL_DEFAULT(driver, ConnectionGetOptionDouble); + FILL_DEFAULT(driver, ConnectionGetOptionInt); + FILL_DEFAULT(driver, ConnectionGetStatistics); + 
FILL_DEFAULT(driver, ConnectionGetStatisticNames); + FILL_DEFAULT(driver, ConnectionSetOptionBytes); + FILL_DEFAULT(driver, ConnectionSetOptionDouble); + FILL_DEFAULT(driver, ConnectionSetOptionInt); + + FILL_DEFAULT(driver, StatementCancel); + FILL_DEFAULT(driver, StatementExecuteSchema); + FILL_DEFAULT(driver, StatementGetOption); + FILL_DEFAULT(driver, StatementGetOptionBytes); + FILL_DEFAULT(driver, StatementGetOptionDouble); + FILL_DEFAULT(driver, StatementGetOptionInt); + FILL_DEFAULT(driver, StatementSetOptionBytes); + FILL_DEFAULT(driver, StatementSetOptionDouble); + FILL_DEFAULT(driver, StatementSetOptionInt); + } return ADBC_STATUS_OK; diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc index 99fa477..100feab 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc @@ -27,10 +27,13 @@ #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" +std::string AdbcDriverManagerDefaultEntrypoint(const std::string& filename); + // Tests of the SQLite example driver, except using the driver manager namespace adbc { +using adbc_validation::Handle; using adbc_validation::IsOkStatus; using adbc_validation::IsStatus; @@ -40,7 +43,7 @@ class DriverManager : public ::testing::Test { std::memset(&driver, 0, sizeof(driver)); std::memset(&error, 0, sizeof(error)); - ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_0_0, &driver, + ASSERT_THAT(AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_1_0, &driver, &error), IsOkStatus(&error)); } @@ -186,12 +189,34 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_FLOAT: case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; + case NANOARROW_TYPE_LARGE_STRING: + return NANOARROW_TYPE_STRING; default: return ingest_type; } } 
+ bool supports_bulk_ingest(const char* mode) const override { + return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || + std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; + } bool supports_concurrent_statements() const override { return true; } + bool supports_get_option() const override { return false; } + std::optional supports_get_sql_info( + uint32_t info_code) const override { + switch (info_code) { + case ADBC_INFO_DRIVER_NAME: + return "ADBC SQLite Driver"; + case ADBC_INFO_DRIVER_VERSION: + return "(unknown)"; + case ADBC_INFO_VENDOR_NAME: + return "SQLite"; + case ADBC_INFO_VENDOR_VERSION: + return "3."; + default: + return std::nullopt; + } + } }; class SqliteDatabaseTest : public ::testing::Test, public adbc_validation::DatabaseTest { @@ -205,6 +230,20 @@ class SqliteDatabaseTest : public ::testing::Test, public adbc_validation::Datab }; ADBCV_TEST_DATABASE(SqliteDatabaseTest) +TEST_F(SqliteDatabaseTest, NullError) { + Handle conn; + + ASSERT_THAT(AdbcDatabaseNew(&database, nullptr), IsOkStatus()); + ASSERT_THAT(quirks()->SetupDatabase(&database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcDatabaseInit(&database, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcConnectionNew(&conn.value, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionInit(&conn.value, &database, nullptr), IsOkStatus()); + ASSERT_THAT(AdbcConnectionRelease(&conn.value, nullptr), IsOkStatus()); + + ASSERT_THAT(AdbcDatabaseRelease(&database, nullptr), IsOkStatus()); +} + class SqliteConnectionTest : public ::testing::Test, public adbc_validation::ConnectionTest { public: @@ -226,10 +265,59 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; } void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; } + void TestSqlIngestDate32() { GTEST_SKIP() << "Cannot ingest DATE (not implemented)"; } + void TestSqlIngestTimestamp() { + GTEST_SKIP() << "Cannot ingest 
TIMESTAMP (not implemented)"; + } + void TestSqlIngestTimestampTz() { + GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)"; + } + void TestSqlIngestDuration() { + GTEST_SKIP() << "Cannot ingest DURATION (not implemented)"; + } + void TestSqlIngestInterval() { + GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; + } protected: SqliteQuirks quirks_; }; ADBCV_TEST_STATEMENT(SqliteStatementTest) +TEST(AdbcDriverManagerInternal, AdbcDriverManagerDefaultEntrypoint) { + for (const auto& driver : { + "adbc_driver_sqlite", + "adbc_driver_sqlite.dll", + "driver_sqlite", + "libadbc_driver_sqlite", + "libadbc_driver_sqlite.so", + "libadbc_driver_sqlite.so.6.0.0", + "/usr/lib/libadbc_driver_sqlite.so", + "/usr/lib/libadbc_driver_sqlite.so.6.0.0", + "C:\\System32\\adbc_driver_sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcDriverSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "adbc_sqlite", + "sqlite", + "/usr/lib/sqlite.so", + "C:\\System32\\sqlite.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcSqliteInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } + + for (const auto& driver : { + "proprietary_engine", + "libproprietary_engine.so.6.0.0", + "/usr/lib/proprietary_engine.so", + "C:\\System32\\proprietary_engine.dll", + }) { + SCOPED_TRACE(driver); + EXPECT_EQ("AdbcProprietaryEngineInit", ::AdbcDriverManagerDefaultEntrypoint(driver)); + } +} } // namespace adbc diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.c b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.c new file mode 100644 index 0000000..48114cd --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.c @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "adbc_version_100.h" + +#include + +struct Version100Database { + int dummy; +}; + +static struct Version100Database kDatabase; + +struct Version100Connection { + int dummy; +}; + +static struct Version100Connection kConnection; + +struct Version100Statement { + int dummy; +}; + +static struct Version100Statement kStatement; + +AdbcStatusCode Version100DatabaseInit(struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseNew(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = &kDatabase; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + database->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = &kConnection; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + int64_t* rows_affected, + struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; +} + +AdbcStatusCode Version100StatementNew(struct 
AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = &kStatement; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100StatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + statement->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100ConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + connection->private_data = NULL; + return ADBC_STATUS_OK; +} + +AdbcStatusCode Version100DriverInit(int version, void* raw_driver, + struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + struct AdbcDriverVersion100* driver = (struct AdbcDriverVersion100*)raw_driver; + memset(driver, 0, sizeof(struct AdbcDriverVersion100)); + + driver->DatabaseInit = &Version100DatabaseInit; + driver->DatabaseNew = &Version100DatabaseNew; + driver->DatabaseRelease = &Version100DatabaseRelease; + + driver->ConnectionInit = &Version100ConnectionInit; + driver->ConnectionNew = &Version100ConnectionNew; + driver->ConnectionRelease = &Version100ConnectionRelease; + + driver->StatementExecuteQuery = &Version100StatementExecuteQuery; + driver->StatementNew = &Version100StatementNew; + driver->StatementRelease = &Version100StatementRelease; + + return ADBC_STATUS_OK; +} diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.h b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.h new file mode 100644 index 0000000..b349f86 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A dummy version 1.0.0 ADBC driver to test compatibility. + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct AdbcErrorVersion100 { + char* message; + int32_t vendor_code; + char sqlstate[5]; + void (*release)(struct AdbcError* error); +}; + +struct AdbcDriverVersion100 { + void* private_data; + void* private_manager; + AdbcStatusCode (*release)(struct AdbcDriver* driver, struct AdbcError* error); + + AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); + + AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, + const char*, const char*, const char**, + const char*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection*, const char*, + const char*, const char*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection*, + struct ArrowArrayStream*, 
struct AdbcError*); + AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*, + struct AdbcError*); + AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection*, const uint8_t*, + size_t, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection*, struct AdbcError*); + + AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement*, struct ArrowArrayStream*, + int64_t*, struct AdbcError*); + AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcPartitions*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*, + struct AdbcError*); + AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, + size_t, struct AdbcError*); +}; + +AdbcStatusCode Version100DriverInit(int version, void* driver, struct AdbcError* error); + +#ifdef __cplusplus +} +#endif diff --git 
a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100_compatibility_test.cc b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100_compatibility_test.cc new file mode 100644 index 0000000..27e5f5d --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_version_100_compatibility_test.cc @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include + +#include "adbc.h" +#include "adbc_driver_manager.h" +#include "adbc_version_100.h" +#include "validation/adbc_validation_util.h" + +namespace adbc { + +using adbc_validation::IsOkStatus; +using adbc_validation::IsStatus; + +class AdbcVersion : public ::testing::Test { + public: + void SetUp() override { + std::memset(&driver, 0, sizeof(driver)); + std::memset(&error, 0, sizeof(error)); + } + + void TearDown() override { + if (error.release) { + error.release(&error); + } + + if (driver.release) { + ASSERT_THAT(driver.release(&driver, &error), IsOkStatus(&error)); + ASSERT_EQ(driver.private_data, nullptr); + ASSERT_EQ(driver.private_manager, nullptr); + } + } + + protected: + struct AdbcDriver driver = {}; + struct AdbcError error = {}; +}; + +TEST_F(AdbcVersion, StructSize) { + ASSERT_EQ(sizeof(AdbcErrorVersion100), ADBC_ERROR_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcError), ADBC_ERROR_1_1_0_SIZE); + + ASSERT_EQ(sizeof(AdbcDriverVersion100), ADBC_DRIVER_1_0_0_SIZE); + ASSERT_EQ(sizeof(AdbcDriver), ADBC_DRIVER_1_1_0_SIZE); +} + +// Initialize a version 1.0.0 driver with the version 1.1.0 driver struct. +TEST_F(AdbcVersion, OldDriverNewLayout) { + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_1_0, &driver, &error), + IsStatus(ADBC_STATUS_NOT_IMPLEMENTED, &error)); + + ASSERT_THAT(Version100DriverInit(ADBC_VERSION_1_0_0, &driver, &error), + IsOkStatus(&error)); +} + +// Initialize a version 1.0.0 driver with the new driver manager/new version. 
+TEST_F(AdbcVersion, OldDriverNewManager) { + ASSERT_THAT(AdbcLoadDriverFromInitFunc(&Version100DriverInit, ADBC_VERSION_1_1_0, + &driver, &error), + IsOkStatus(&error)); + + EXPECT_NE(driver.ErrorGetDetailCount, nullptr); + EXPECT_NE(driver.ErrorGetDetail, nullptr); + + EXPECT_NE(driver.DatabaseGetOption, nullptr); + EXPECT_NE(driver.DatabaseGetOptionBytes, nullptr); + EXPECT_NE(driver.DatabaseGetOptionDouble, nullptr); + EXPECT_NE(driver.DatabaseGetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionInt, nullptr); + EXPECT_NE(driver.DatabaseSetOptionDouble, nullptr); + + EXPECT_NE(driver.ConnectionCancel, nullptr); + EXPECT_NE(driver.ConnectionGetOption, nullptr); + EXPECT_NE(driver.ConnectionGetOptionBytes, nullptr); + EXPECT_NE(driver.ConnectionGetOptionDouble, nullptr); + EXPECT_NE(driver.ConnectionGetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionInt, nullptr); + EXPECT_NE(driver.ConnectionSetOptionDouble, nullptr); + + EXPECT_NE(driver.StatementCancel, nullptr); + EXPECT_NE(driver.StatementExecuteSchema, nullptr); + EXPECT_NE(driver.StatementGetOption, nullptr); + EXPECT_NE(driver.StatementGetOptionBytes, nullptr); + EXPECT_NE(driver.StatementGetOptionDouble, nullptr); + EXPECT_NE(driver.StatementGetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionInt, nullptr); + EXPECT_NE(driver.StatementSetOptionDouble, nullptr); +} + +// N.B. see postgresql_test.cc for backwards compatibility test of AdbcError +// N.B. 
see postgresql_test.cc for backwards compatibility test of AdbcDriver + +} // namespace adbc diff --git a/3rd_party/apache-arrow-adbc/c/integration/duckdb/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/integration/duckdb/CMakeLists.txt index 52fb9d0..8053713 100644 --- a/3rd_party/apache-arrow-adbc/c/integration/duckdb/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/integration/duckdb/CMakeLists.txt @@ -49,6 +49,7 @@ if(ADBC_BUILD_TESTS) CACHE INTERNAL "Disable UBSAN") # Force cmake to honor our options here in the subproject cmake_policy(SET CMP0077 NEW) + message(STATUS "Fetching DuckDB") fetchcontent_makeavailable(duckdb) include_directories(SYSTEM ${REPOSITORY_ROOT}) diff --git a/3rd_party/apache-arrow-adbc/c/integration/duckdb/duckdb_test.cc b/3rd_party/apache-arrow-adbc/c/integration/duckdb/duckdb_test.cc index b2a5c52..a6bded0 100644 --- a/3rd_party/apache-arrow-adbc/c/integration/duckdb/duckdb_test.cc +++ b/3rd_party/apache-arrow-adbc/c/integration/duckdb/duckdb_test.cc @@ -46,7 +46,7 @@ class DuckDbQuirks : public adbc_validation::DriverQuirks { std::string BindParameter(int index) const override { return "?"; } - bool supports_bulk_ingest() const override { return false; } + bool supports_bulk_ingest(const char* /*mode*/) const override { return false; } bool supports_concurrent_statements() const override { return true; } bool supports_dynamic_parameter_binding() const override { return false; } bool supports_get_sql_info() const override { return false; } @@ -75,6 +75,7 @@ class DuckDbConnectionTest : public ::testing::Test, void TestAutocommitDefault() { GTEST_SKIP(); } void TestMetadataGetTableSchema() { GTEST_SKIP(); } + void TestMetadataGetTableSchemaNotFound() { GTEST_SKIP(); } void TestMetadataGetTableTypes() { GTEST_SKIP(); } protected: @@ -94,6 +95,17 @@ class DuckDbStatementTest : public ::testing::Test, // Accepts Prepare() without any query void TestSqlPrepareErrorNoQuery() { GTEST_SKIP(); } + void TestSqlIngestTableEscaping() { 
GTEST_SKIP() << "Table escaping not implemented"; } + void TestSqlIngestColumnEscaping() { + GTEST_SKIP() << "Column escaping not implemented"; + } + + void TestSqlQueryErrors() { GTEST_SKIP() << "DuckDB does not set AdbcError.release"; } + + void TestErrorCompatibility() { + GTEST_SKIP() << "DuckDB does not set AdbcError.release"; + } + protected: DuckDbQuirks quirks_; }; diff --git a/3rd_party/apache-arrow-adbc/c/symbols.map b/3rd_party/apache-arrow-adbc/c/symbols.map index 5e965b3..c9464b2 100644 --- a/3rd_party/apache-arrow-adbc/c/symbols.map +++ b/3rd_party/apache-arrow-adbc/c/symbols.map @@ -20,6 +20,16 @@ # Only expose symbols from the ADBC API Adbc*; + # Expose driver-specific initialization routines + FlightSQLDriverInit; + PostgresqlDriverInit; + SnowflakeDriverInit; + SqliteDriverInit; + + extern "C++" { + Adbc*; + }; + local: *; }; diff --git a/3rd_party/apache-arrow-adbc/c/validation/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/validation/CMakeLists.txt index 3c83f95..bab7a63 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/validation/CMakeLists.txt @@ -15,10 +15,24 @@ # specific language governing permissions and limitations # under the License. 
-add_library(adbc_validation OBJECT adbc_validation.cc adbc_validation_util.cc) +add_library(adbc_validation_util STATIC adbc_validation_util.cc) +adbc_configure_target(adbc_validation_util) +target_compile_features(adbc_validation_util PRIVATE cxx_std_17) +target_include_directories(adbc_validation_util SYSTEM + PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" + "${REPOSITORY_ROOT}/c/vendor/") +target_link_libraries(adbc_validation_util PUBLIC adbc_driver_common nanoarrow + GTest::gtest GTest::gmock) + +add_library(adbc_validation OBJECT adbc_validation.cc) +adbc_configure_target(adbc_validation) target_compile_features(adbc_validation PRIVATE cxx_std_17) target_include_directories(adbc_validation SYSTEM PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/driver/" "${REPOSITORY_ROOT}/c/vendor/") -target_link_libraries(adbc_validation PUBLIC adbc_driver_common nanoarrow GTest::gtest - GTest::gmock) +target_link_libraries(adbc_validation + PUBLIC adbc_driver_common + adbc_validation_util + nanoarrow + GTest::gtest + GTest::gmock) diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.cc b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.cc index d73b556..d25f236 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.cc +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -33,8 +34,10 @@ #include #include #include +#include #include "adbc_validation_util.h" +#include "common/options.h" namespace adbc_validation { @@ -101,7 +104,7 @@ AdbcStatusCode DriverQuirks::EnsureSampleTable(struct AdbcConnection* connection AdbcStatusCode DriverQuirks::CreateSampleTable(struct AdbcConnection* connection, const std::string& name, struct AdbcError* error) const { - if (!supports_bulk_ingest()) { + if (!supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { return ADBC_STATUS_NOT_IMPLEMENTED; } return DoIngestSampleTable(connection, name, error); 
@@ -247,6 +250,56 @@ void ConnectionTest::TestAutocommitToggle() { //------------------------------------------------------------ // Tests of metadata +std::optional ConnectionGetOption(struct AdbcConnection* connection, + std::string_view option, + struct AdbcError* error) { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + AdbcStatusCode status = + AdbcConnectionGetOption(connection, option.data(), buffer, &buffer_size, error); + EXPECT_THAT(status, IsOkStatus(error)); + if (status != ADBC_STATUS_OK) return std::nullopt; + EXPECT_GT(buffer_size, 0); + if (buffer_size == 0) return std::nullopt; + return std::string(buffer, buffer_size - 1); +} + +void ConnectionTest::TestMetadataCurrentCatalog() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_catalog()) { + ASSERT_THAT( + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error), + ::testing::Optional(quirks()->catalog())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + +void ConnectionTest::TestMetadataCurrentDbSchema() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + if (quirks()->supports_metadata_current_db_schema()) { + ASSERT_THAT(ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + &error), + ::testing::Optional(quirks()->db_schema())); + } else { + char buffer[128]; + size_t buffer_size = sizeof(buffer); + ASSERT_THAT( + AdbcConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + buffer, &buffer_size, &error), + IsStatus(ADBC_STATUS_NOT_FOUND)); + } +} + void 
ConnectionTest::TestMetadataGetInfo() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -255,83 +308,110 @@ void ConnectionTest::TestMetadataGetInfo() { GTEST_SKIP(); } - StreamReader reader; - std::vector info = { - ADBC_INFO_DRIVER_NAME, - ADBC_INFO_DRIVER_VERSION, - ADBC_INFO_VENDOR_NAME, - ADBC_INFO_VENDOR_VERSION, - }; + for (uint32_t info_code : { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ADBC_VERSION, + ADBC_INFO_VENDOR_NAME, + ADBC_INFO_VENDOR_VERSION, + }) { + SCOPED_TRACE("info_code = " + std::to_string(info_code)); + std::optional expected = quirks()->supports_get_sql_info(info_code); - ASSERT_THAT(AdbcConnectionGetInfo(&connection, info.data(), info.size(), - &reader.stream.value, &error), - IsOkStatus(&error)); - ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - &reader.schema.value, { - {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, - {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1], - { - {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, - {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, - {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, - {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, - {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, - {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], - { - {"item", NANOARROW_TYPE_STRING, NULLABLE}, - })); - ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[5], - { - {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, - })); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0], - { - {"key", NANOARROW_TYPE_INT32, NOT_NULL}, - {"value", NANOARROW_TYPE_LIST, NULLABLE}, - })); - 
ASSERT_NO_FATAL_FAILURE( - CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], - { - {"item", NANOARROW_TYPE_INT32, NULLABLE}, - })); + if (!expected.has_value()) continue; - std::vector seen; - while (true) { - ASSERT_NO_FATAL_FAILURE(reader.Next()); - if (!reader.array->release) break; + uint32_t info[] = {info_code}; + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetInfo(&connection, info, 1, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"info_name", NANOARROW_TYPE_UINT32, NOT_NULL}, + {"info_value", NANOARROW_TYPE_DENSE_UNION, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1], + { + {"string_value", NANOARROW_TYPE_STRING, NULLABLE}, + {"bool_value", NANOARROW_TYPE_BOOL, NULLABLE}, + {"int64_value", NANOARROW_TYPE_INT64, NULLABLE}, + {"int32_bitmask", NANOARROW_TYPE_INT32, NULLABLE}, + {"string_list", NANOARROW_TYPE_LIST, NULLABLE}, + {"int32_to_int32_list_map", NANOARROW_TYPE_MAP, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE(CompareSchema(reader.schema->children[1]->children[4], + { + {"item", NANOARROW_TYPE_STRING, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5], + { + {"entries", NANOARROW_TYPE_STRUCT, NOT_NULL}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0], + { + {"key", NANOARROW_TYPE_INT32, NOT_NULL}, + {"value", NANOARROW_TYPE_LIST, NULLABLE}, + })); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(reader.schema->children[1]->children[5]->children[0]->children[1], + { + {"item", NANOARROW_TYPE_INT32, NULLABLE}, + })); + + std::vector seen; + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; - for (int64_t row = 0; row < reader.array->length; row++) { - 
ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); - const uint32_t code = - reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; - seen.push_back(code); - - switch (code) { - case ADBC_INFO_DRIVER_NAME: - case ADBC_INFO_DRIVER_VERSION: - case ADBC_INFO_VENDOR_NAME: - case ADBC_INFO_VENDOR_VERSION: - // UTF8 - ASSERT_EQ(uint8_t(0), - reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]); - default: - // Ignored - break; + for (int64_t row = 0; row < reader.array->length; row++) { + ASSERT_FALSE(ArrowArrayViewIsNull(reader.array_view->children[0], row)); + const uint32_t code = + reader.array_view->children[0]->buffer_views[1].data.as_uint32[row]; + seen.push_back(code); + if (code != info_code) { + continue; + } + + ASSERT_TRUE(expected.has_value()) << "Got unexpected info code " << code; + + uint8_t type_code = + reader.array_view->children[1]->buffer_views[0].data.as_uint8[row]; + int32_t offset = + reader.array_view->children[1]->buffer_views[1].data.as_int32[row]; + ASSERT_NO_FATAL_FAILURE(std::visit( + [&](auto&& expected_value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(2), type_code); + EXPECT_EQ(expected_value, + ArrowArrayViewGetIntUnsafe( + reader.array_view->children[1]->children[2], offset)); + } else if constexpr (std::is_same_v) { + ASSERT_EQ(uint8_t(0), type_code); + struct ArrowStringView view = ArrowArrayViewGetStringUnsafe( + reader.array_view->children[1]->children[0], offset); + EXPECT_THAT(std::string_view(static_cast(view.data), + view.size_bytes), + ::testing::HasSubstr(expected_value)); + } else { + static_assert(!sizeof(T), "not yet implemented"); + } + }, + *expected)) + << "code: " << type_code; } } + EXPECT_THAT(seen, ::testing::IsSupersetOf(info)); } - ASSERT_THAT(seen, ::testing::UnorderedElementsAreArray(info)); } void ConnectionTest::TestMetadataGetTableSchema() { - if (!quirks()->supports_bulk_ingest()) { + if 
(!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); @@ -351,6 +431,33 @@ void ConnectionTest::TestMetadataGetTableSchema() { {"strings", NANOARROW_TYPE_STRING, NULLABLE}})); } +void ConnectionTest::TestMetadataGetTableSchemaEscaping() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + Handle schema; + ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, + /*db_schema=*/nullptr, "(SELECT CURRENT_TIME)", + &schema.value, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); +}; + +void ConnectionTest::TestMetadataGetTableSchemaNotFound() { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTable(&connection, "thistabledoesnotexist", &error), + IsOkStatus(&error)); + + Handle schema; + ASSERT_THAT(AdbcConnectionGetTableSchema(&connection, /*catalog=*/nullptr, + /*db_schema=*/nullptr, "thistabledoesnotexist", + &schema.value, &error), + IsStatus(ADBC_STATUS_NOT_FOUND, &error)); +} + void ConnectionTest::TestMetadataGetTableTypes() { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); @@ -911,6 +1018,58 @@ void ConnectionTest::TestMetadataGetObjectsPrimaryKey() { ASSERT_EQ(constraint_column_name, "id"); } +void ConnectionTest::TestMetadataGetObjectsCancel() { + if (!quirks()->supports_cancel() || !quirks()->supports_get_objects()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), 
IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT( + AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_CATALOGS, nullptr, nullptr, + nullptr, nullptr, nullptr, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcConnectionCancel(&connection, &error), IsOkStatus(&error)); + + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } +} + +void ConnectionTest::TestMetadataGetStatisticNames() { + if (!quirks()->supports_statistics()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + StreamReader reader; + ASSERT_THAT(AdbcConnectionGetStatisticNames(&connection, &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, { + {"statistic_name", NANOARROW_TYPE_STRING, NOT_NULL}, + {"statistic_key", NANOARROW_TYPE_INT16, NOT_NULL}, + })); + + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) break; + } +} + //------------------------------------------------------------ // Tests of AdbcStatement @@ -965,7 +1124,7 @@ void StatementTest::TestRelease() { template void StatementTest::TestSqlIngestType(ArrowType type, const std::vector>& values) { - if (!quirks()->supports_bulk_ingest()) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1023,6 +1182,7 @@ void StatementTest::TestSqlIngestType(ArrowType type, ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(nullptr, reader.array->release); } + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } template @@ -1036,6 +1196,10 @@ void 
StatementTest::TestSqlIngestNumericType(ArrowType type) { // values. Likely a bug on our side, but for now, avoid them. values.push_back(static_cast(-1.5)); values.push_back(static_cast(1.5)); + } else if (type == ArrowType::NANOARROW_TYPE_DATE32) { + // Windows does not seem to support negative date values + values.push_back(static_cast(0)); + values.push_back(static_cast(42)); } else { values.push_back(std::numeric_limits::lowest()); values.push_back(std::numeric_limits::max()); @@ -1089,25 +1253,41 @@ void StatementTest::TestSqlIngestString() { NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"})); } +void StatementTest::TestSqlIngestLargeString() { + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"})); +} + void StatementTest::TestSqlIngestBinary() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "\xFE\xFF"})); } -void StatementTest::TestSqlIngestAppend() { - if (!quirks()->supports_bulk_ingest()) { +void StatementTest::TestSqlIngestDate32() { + ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_DATE32)); +} + +template +void StatementTest::TestSqlIngestTemporalType(const char* timezone) { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } - // Ingest ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); Handle schema; Handle array; struct ArrowError na_error; - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); - ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + const std::vector> values = {std::nullopt, -42, 0, 42}; + + // much of this code is shared with TestSqlIngestType with minor + // changes to allow for various time units to be tested + ArrowSchemaInit(&schema.value); + ArrowSchemaSetTypeStruct(&schema.value, 1); + ArrowSchemaSetTypeDateTime(schema->children[0], type, TU, timezone); + 
ArrowSchemaSetName(schema->children[0], "col"); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, values), IsOkErrno()); ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -1120,31 +1300,12 @@ void StatementTest::TestSqlIngestAppend() { int64_t rows_affected = 0; ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), IsOkStatus(&error)); - ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); - - // Now append - - // Re-initialize since Bind() should take ownership of data - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); - ASSERT_THAT( - MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), - IsOkErrno()); - - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, - "bulk_ingest", &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, - ADBC_INGEST_OPTION_MODE_APPEND, &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), - IsOkStatus(&error)); - - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), - IsOkStatus(&error)); - ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + ASSERT_THAT(rows_affected, + ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); - // Read data back - ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, + "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error), IsOkStatus(&error)); { StreamReader reader; @@ -1152,19 +1313,20 @@ void StatementTest::TestSqlIngestAppend() { &reader.rows_affected, &error), IsOkStatus(&error)); ASSERT_THAT(reader.rows_affected, - ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + ::testing::AnyOf(::testing::Eq(values.size()), 
::testing::Eq(-1))); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, - {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}})); ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_NE(nullptr, reader.array->release); - ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(values.size(), reader.array->length); ASSERT_EQ(1, reader.array->n_children); - ASSERT_NO_FATAL_FAILURE( - CompareArray(reader.array_view->children[0], {42, -42, std::nullopt})); + ValidateIngestedTemporalData(reader.array_view->children[0], type, TU, timezone); ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(nullptr, reader.array->release); @@ -1173,92 +1335,546 @@ void StatementTest::TestSqlIngestAppend() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } -void StatementTest::TestSqlIngestErrors() { - if (!quirks()->supports_bulk_ingest()) { +void StatementTest::ValidateIngestedTemporalData(struct ArrowArrayView* values, + ArrowType type, enum ArrowTimeUnit unit, + const char* timezone) { + FAIL() << "ValidateIngestedTemporalData is not implemented in the base class"; +} + +void StatementTest::TestSqlIngestDuration() { + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); +} + +void StatementTest::TestSqlIngestTimestamp() { + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + nullptr))); +} + +void 
StatementTest::TestSqlIngestTimestampTz() { + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "UTC"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "UTC"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "UTC"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "UTC"))); + + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "America/Los_Angeles"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "America/Los_Angeles"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "America/Los_Angeles"))); + ASSERT_NO_FATAL_FAILURE( + (TestSqlIngestTemporalType( + "America/Los_Angeles"))); +} + +void StatementTest::TestSqlIngestInterval() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } - // Ingest without bind - ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, - "bulk_ingest", &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_STATE, &error)); - if (error.release) error.release(&error); - ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); - // Append to nonexistent table Handle schema; Handle array; struct ArrowError na_error; + const enum ArrowType type = NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO; + // values are days, months, ns + struct ArrowInterval neg_interval; + struct ArrowInterval zero_interval; + struct ArrowInterval pos_interval; - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, - "bulk_ingest", &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, - ADBC_INGEST_OPTION_MODE_APPEND, &error), - IsOkStatus(&error)); - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); - ASSERT_THAT( - 
MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), - IsOkErrno(&na_error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), - IsOkStatus(&error)); + ArrowIntervalInit(&neg_interval, type); + ArrowIntervalInit(&zero_interval, type); + ArrowIntervalInit(&pos_interval, type); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - ::testing::Not(IsOkStatus(&error))); - if (error.release) error.release(&error); + neg_interval.months = -5; + neg_interval.days = -5; + neg_interval.ns = -42000; - // Ingest... - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, - ADBC_INGEST_OPTION_MODE_CREATE, &error), - IsOkStatus(&error)); - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); - ASSERT_THAT( - MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), - IsOkErrno(&na_error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsOkStatus(&error)); + pos_interval.months = 5; + pos_interval.days = 5; + pos_interval.ns = 42000; - // ...then try to overwrite it - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); - ASSERT_THAT( - MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), - IsOkErrno(&na_error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - ::testing::Not(IsOkStatus(&error))); - if (error.release) error.release(&error); + const std::vector> values = { + std::nullopt, &neg_interval, &zero_interval, &pos_interval}; - // ...then try to append an incompatible schema - ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}, - {"coltwo", NANOARROW_TYPE_INT64}}), + 
ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno()); + + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, values), IsOkErrno()); - ASSERT_THAT( - (MakeBatch(&schema.value, &array.value, &na_error, {}, {})), - IsOkErrno(&na_error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, - ADBC_INGEST_OPTION_MODE_APPEND, &error), + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - ::testing::Not(IsOkStatus(&error))); -} - -void StatementTest::TestSqlIngestMultipleConnections() { - if (!quirks()->supports_bulk_ingest()) { - GTEST_SKIP(); - } - ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, + ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, + "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type); + ASSERT_NO_FATAL_FAILURE( + CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}})); + + 
ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(values.size(), reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + if (round_trip_type == type) { + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], values)); + } + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTableEscaping() { + std::string name = "create_table_escaping"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestColumnEscaping() { + std::string name = "create"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, 
&error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestAppend() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_APPEND)) { + GTEST_SKIP(); + } + + // Ingest + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Now append + + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT( + MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + 
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {42, -42, std::nullopt})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestReplace() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_REPLACE)) { + GTEST_SKIP(); + } + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, 
ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(1, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[0], {42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + // Replace + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {-42, -42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + 
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(2, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {-42, -42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } +} + +void StatementTest::TestSqlIngestCreateAppend() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE_APPEND)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + // Ingest + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, 
+ ADBC_INGEST_OPTION_MODE_CREATE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(1), ::testing::Eq(-1))); + + // Append + // Re-initialize since Bind() should take ownership of data + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, {42, 42}), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(2), ::testing::Eq(-1))); + + // Read data back + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_ingest", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {42, 42, 42})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestErrors() { + if 
(!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + // Ingest without bind + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + if (error.release) error.release(&error); + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + // Append to nonexistent table + Handle schema; + Handle array; + struct ArrowError na_error; + + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT( + MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), + IsOkErrno(&na_error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + ::testing::Not(IsOkStatus(&error))); + if (error.release) error.release(&error); + + // Ingest... 
+ ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_CREATE, &error), + IsOkStatus(&error)); + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT( + MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), + IsOkErrno(&na_error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + // ...then try to overwrite it + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT( + MakeBatch(&schema.value, &array.value, &na_error, {-42, std::nullopt}), + IsOkErrno(&na_error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + ::testing::Not(IsOkStatus(&error))); + if (error.release) error.release(&error); + + // ...then try to append an incompatible schema + ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}, + {"coltwo", NANOARROW_TYPE_INT64}}), + IsOkErrno()); + ASSERT_THAT( + (MakeBatch(&schema.value, &array.value, &na_error, {}, {})), + IsOkErrno(&na_error)); + + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + ::testing::Not(IsOkStatus(&error))); +} + +void StatementTest::TestSqlIngestMultipleConnections() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); Handle schema; @@ -1269,97 +1885,630 @@ 
void StatementTest::TestSqlIngestMultipleConnections() { {42, -42, std::nullopt})), IsOkErrno()); - ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, - "bulk_ingest", &error), - IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), - IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + + { + struct AdbcConnection connection2 = {}; + ASSERT_THAT(AdbcConnectionNew(&connection2, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection2, &database, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementNew(&connection2, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT * FROM bulk_ingest ORDER BY \"int64s\" DESC NULLS LAST", + &error), + IsOkStatus(&error)); + + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema( + &reader.schema.value, {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, 
reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {42, -42, std::nullopt})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionRelease(&connection2, &error), IsOkStatus(&error)); + } +} + +void StatementTest::TestSqlIngestSample() { + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "bulk_ingest", &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "SELECT * FROM bulk_ingest ORDER BY int64s ASC NULLS FIRST", + &error), + IsOkStatus(&error)); + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, + {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}, + {"strings", NANOARROW_TYPE_STRING, NULLABLE}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(3, reader.array->length); + ASSERT_EQ(2, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE( + CompareArray(reader.array_view->children[0], {std::nullopt, -42, 42})); + ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[1], + {"", std::nullopt, "foo"})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); +} + +void StatementTest::TestSqlIngestTargetCatalog() { + if (!quirks()->supports_bulk_ingest_catalog() || + 
!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + std::string catalog = quirks()->catalog(); + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_CATALOG, + catalog.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTargetSchema() { + if (!quirks()->supports_bulk_ingest_db_schema() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + std::string db_schema = quirks()->db_schema(); + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT( + 
AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA, + db_schema.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTargetCatalogSchema() { + if (!quirks()->supports_bulk_ingest_catalog() || + !quirks()->supports_bulk_ingest_db_schema() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + std::string catalog = quirks()->catalog(); + std::string db_schema = quirks()->db_schema(); + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_CATALOG, + catalog.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT( + AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA, + db_schema.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + 
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTemporary() { + if (!quirks()->supports_bulk_ingest_temporary() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { + GTEST_SKIP(); + } + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTempTable(&connection, name, &error), IsOkStatus(&error)); + + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + 
name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTemporaryAppend() { + // Append to temp table shouldn't affect actual table and vice versa + if (!quirks()->supports_bulk_ingest_temporary() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_APPEND)) { + GTEST_SKIP(); + } + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTempTable(&connection, name, &error), IsOkStatus(&error)); + + // Create both tables with different schemas + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"strs", NANOARROW_TYPE_STRING}}), + 
IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {"foo", "bar", std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + // Append to the temporary table + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, {0, 1, 2})), + IsOkErrno()); - int64_t rows_affected = 0; - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + // Append to the normal table + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"strs", NANOARROW_TYPE_STRING}}), + IsOkErrno()); 
+ ASSERT_THAT( + (MakeBatch(&schema.value, &array.value, &na_error, {"", "a", "b"})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlIngestTemporaryReplace() { + // Replace temp table shouldn't affect actual table and vice versa + if (!quirks()->supports_bulk_ingest_temporary() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_APPEND) || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_REPLACE)) { + GTEST_SKIP(); + } + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), IsOkStatus(&error)); - ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); - ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + std::string name = "bulk_ingest"; + + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTempTable(&connection, name, &error), IsOkStatus(&error)); + + // Create both tables with different schemas { - struct AdbcConnection connection2 = {}; - ASSERT_THAT(AdbcConnectionNew(&connection2, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcConnectionInit(&connection2, &database, &error), IsOkStatus(&error)); - 
ASSERT_THAT(AdbcStatementNew(&connection2, &statement, &error), IsOkStatus(&error)); + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } - ASSERT_THAT( - AdbcStatementSetSqlQuery( - &statement, "SELECT * FROM bulk_ingest ORDER BY \"int64s\" DESC NULLS LAST", - &error), - IsOkStatus(&error)); + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"strs", NANOARROW_TYPE_STRING}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {"foo", "bar", std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } - { - StreamReader reader; - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, - &reader.rows_affected, &error), - IsOkStatus(&error)); - ASSERT_THAT(reader.rows_affected, - 
::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + // Replace both tables with different schemas + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints2", NANOARROW_TYPE_INT64}, + {"strs2", NANOARROW_TYPE_STRING}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {0, 1, std::nullopt}, + {"foo", "bar", std::nullopt})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } - ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - &reader.schema.value, {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}})); + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints3", NANOARROW_TYPE_INT64}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, {1, 2, 3})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_REPLACE, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, 
&array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } - ASSERT_NO_FATAL_FAILURE(reader.Next()); - ASSERT_NE(nullptr, reader.array->release); - ASSERT_EQ(3, reader.array->length); - ASSERT_EQ(1, reader.array->n_children); + // Now append to the replaced tables to check that the schemas are as expected + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints2", NANOARROW_TYPE_INT64}, + {"strs2", NANOARROW_TYPE_STRING}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {0, 1, std::nullopt}, + {"foo", "bar", std::nullopt})), + IsOkErrno()); - ASSERT_NO_FATAL_FAILURE( - CompareArray(reader.array_view->children[0], {42, -42, std::nullopt})); + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); - ASSERT_NO_FATAL_FAILURE(reader.Next()); - ASSERT_EQ(nullptr, reader.array->release); - } + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } - ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcConnectionRelease(&connection2, &error), IsOkStatus(&error)); + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints3", 
NANOARROW_TYPE_INT64}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, {4, 5, 6})), + IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_DISABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); } + + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); } -void StatementTest::TestSqlIngestSample() { - if (!quirks()->supports_bulk_ingest()) { +void StatementTest::TestSqlIngestTemporaryExclusive() { + // Can't set target schema/catalog with temp table + if (!quirks()->supports_bulk_ingest_temporary() || + !quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } - ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "bulk_ingest", &error), - IsOkStatus(&error)); + std::string name = "bulk_ingest"; + ASSERT_THAT(quirks()->DropTempTable(&connection, name, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); - ASSERT_THAT(AdbcStatementSetSqlQuery( - &statement, "SELECT * FROM bulk_ingest ORDER BY int64s ASC NULLS FIRST", - &error), - IsOkStatus(&error)); - StreamReader reader; - ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, - &reader.rows_affected, &error), - IsOkStatus(&error)); - ASSERT_THAT(reader.rows_affected, - ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + if (quirks()->supports_bulk_ingest_catalog()) { + Handle schema; + Handle 
array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); - ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, - {{"int64s", NANOARROW_TYPE_INT64, NULLABLE}, - {"strings", NANOARROW_TYPE_STRING, NULLABLE}})); + std::string catalog = quirks()->catalog(); - ASSERT_NO_FATAL_FAILURE(reader.Next()); - ASSERT_NE(nullptr, reader.array->release); - ASSERT_EQ(3, reader.array->length); - ASSERT_EQ(2, reader.array->n_children); + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT( + AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_CATALOG, + catalog.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } - ASSERT_NO_FATAL_FAILURE( - CompareArray(reader.array_view->children[0], {std::nullopt, -42, 42})); - ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[1], - {"", std::nullopt, "foo"})); + if (quirks()->supports_bulk_ingest_db_schema()) { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"ints", NANOARROW_TYPE_INT64}}), IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + 
{42, -42, std::nullopt})), + IsOkErrno()); - ASSERT_NO_FATAL_FAILURE(reader.Next()); - ASSERT_EQ(nullptr, reader.array->release); + std::string db_schema = quirks()->db_schema(); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TEMPORARY, + ADBC_OPTION_VALUE_ENABLED, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT( + AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA, + db_schema.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } } void StatementTest::TestSqlPartitionedInts() { @@ -1570,7 +2719,7 @@ void StatementTest::TestSqlPrepareSelectParams() { } void StatementTest::TestSqlPrepareUpdate() { - if (!quirks()->supports_bulk_ingest() || + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -1649,7 +2798,7 @@ void StatementTest::TestSqlPrepareUpdateNoParams() { } void StatementTest::TestSqlPrepareUpdateStream() { - if (!quirks()->supports_bulk_ingest() || + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE) || !quirks()->supports_dynamic_parameter_binding()) { GTEST_SKIP(); } @@ -1924,6 +3073,36 @@ void StatementTest::TestSqlQueryStrings() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlQueryCancel() { + if (!quirks()->supports_cancel()) { + GTEST_SKIP(); + } + + 
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'SaShiSuSeSo'", &error), + IsOkStatus(&error)); + + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + + ASSERT_THAT(AdbcStatementCancel(&statement, &error), IsOkStatus(&error)); + while (true) { + int err = reader.MaybeNext(); + if (err != 0) { + ASSERT_THAT(err, ::testing::AnyOf(0, IsErrno(ECANCELED, &reader.stream.value, + /*ArrowError*/ nullptr))); + } + if (!reader.array->release) break; + } + } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlQueryErrors() { // Invalid query ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -1943,6 +3122,13 @@ void StatementTest::TestTransactions() { ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_ENABLED))); + } + Handle connection2; ASSERT_THAT(AdbcConnectionNew(&connection2.value, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection2.value, &database, &error), @@ -1952,6 +3138,13 @@ void StatementTest::TestTransactions() { ADBC_OPTION_VALUE_DISABLED, &error), IsOkStatus(&error)); + if (quirks()->supports_get_option()) { + auto autocommit = + ConnectionGetOption(&connection, ADBC_CONNECTION_OPTION_AUTOCOMMIT, &error); + ASSERT_THAT(autocommit, + ::testing::Optional(::testing::StrEq(ADBC_OPTION_VALUE_DISABLED))); + } + // Uncommitted change ASSERT_NO_FATAL_FAILURE(IngestSampleTable(&connection, &error)); @@ -2027,6 +3220,86 @@ void 
StatementTest::TestTransactions() { } } +void StatementTest::TestSqlSchemaInts() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("i"), // int32 + ::testing::StrEq("l"), // int64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaFloats() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT CAST(1.5 AS FLOAT)", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("f"), // float32 + ::testing::StrEq("g"), // float64 + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaStrings() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 'hi'", &error), + IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsOkStatus(&error)); + + ASSERT_EQ(1, schema->n_children); + 
ASSERT_THAT(schema->children[0]->format, ::testing::AnyOfArray({ + ::testing::StrEq("u"), // string + ::testing::StrEq("U"), // large_string + })); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + +void StatementTest::TestSqlSchemaErrors() { + if (!quirks()->supports_execute_schema()) { + GTEST_SKIP() << "Not supported"; + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + nanoarrow::UniqueSchema schema; + ASSERT_THAT(AdbcStatementExecuteSchema(&statement, schema.get(), &error), + IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestConcurrentStatements() { Handle statement1; Handle statement2; @@ -2062,6 +3335,24 @@ void StatementTest::TestConcurrentStatements() { ASSERT_NO_FATAL_FAILURE(reader1.GetSchema()); } +// Test that an ADBC 1.0.0-sized error still works +void StatementTest::TestErrorCompatibility() { + // XXX: sketchy cast + auto* error = static_cast(malloc(ADBC_ERROR_1_0_0_SIZE)); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, error), IsOkStatus(error)); + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "SELECT * FROM thistabledoesnotexist", error), + IsOkStatus(error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, error), + ::testing::Not(IsOkStatus(error))); + error->release(error); + free(error); +} + void StatementTest::TestResultInvalidation() { // Start reading from a statement, then overwrite it ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); @@ -2080,8 +3371,8 @@ void StatementTest::TestResultInvalidation() { IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader2.GetSchema()); - // First reader should not fail, but may give no data - ASSERT_NO_FATAL_FAILURE(reader1.Next()); + // 
First reader may fail, or may succeed but give no data + reader1.MaybeNext(); } #undef NOT_NULL diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h index 4e4251b..0d936de 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -31,6 +32,8 @@ namespace adbc_validation { #define ADBCV_STRINGIFY(s) #s #define ADBCV_STRINGIFY_VALUE(s) ADBCV_STRINGIFY(s) +using SqlInfoValue = std::variant; + /// \brief Configuration for driver-specific behavior. class DriverQuirks { public: @@ -47,6 +50,13 @@ class DriverQuirks { return ADBC_STATUS_OK; } + /// \brief Drop the given temporary table. Used by tests to reset state. + virtual AdbcStatusCode DropTempTable(struct AdbcConnection* connection, + const std::string& name, + struct AdbcError* error) const { + return ADBC_STATUS_OK; + } + /// \brief Drop the given view. Used by tests to reset state. virtual AdbcStatusCode DropView(struct AdbcConnection* connection, const std::string& name, @@ -85,10 +95,31 @@ class DriverQuirks { return ingest_type; } + /// \brief Whether bulk ingest is supported + virtual bool supports_bulk_ingest(const char* mode) const { return true; } + + /// \brief Whether bulk ingest to a specific catalog is supported + virtual bool supports_bulk_ingest_catalog() const { return false; } + + /// \brief Whether bulk ingest to a specific schema is supported + virtual bool supports_bulk_ingest_db_schema() const { return false; } + + /// \brief Whether bulk ingest to a temporary table is supported + virtual bool supports_bulk_ingest_temporary() const { return false; } + + /// \brief Whether we can cancel queries. 
+ virtual bool supports_cancel() const { return false; } + /// \brief Whether two statements can be used at the same time on a /// single connection virtual bool supports_concurrent_statements() const { return false; } + /// \brief Whether AdbcStatementExecuteSchema should work + virtual bool supports_execute_schema() const { return false; } + + /// \brief Whether GetOption* should work + virtual bool supports_get_option() const { return true; } + /// \brief Whether AdbcStatementExecutePartitions should work virtual bool supports_partitioned_data() const { return false; } @@ -101,11 +132,19 @@ class DriverQuirks { /// \brief Whether GetSqlInfo is implemented virtual bool supports_get_sql_info() const { return true; } + /// \brief The expected value for a given info code + virtual std::optional supports_get_sql_info(uint32_t info_code) const { + return std::nullopt; + } + /// \brief Whether GetObjects is implemented virtual bool supports_get_objects() const { return true; } - /// \brief Whether bulk ingest is supported - virtual bool supports_bulk_ingest() const { return true; } + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_CATALOG + virtual bool supports_metadata_current_catalog() const { return false; } + + /// \brief Whether we can get ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA + virtual bool supports_metadata_current_db_schema() const { return false; } /// \brief Whether dynamic parameter bindings are supported for prepare virtual bool supports_dynamic_parameter_binding() const { return true; } @@ -113,6 +152,9 @@ class DriverQuirks { /// \brief Whether ExecuteQuery sets rows_affected appropriately virtual bool supports_rows_affected() const { return true; } + /// \brief Whether we can get statistics + virtual bool supports_statistics() const { return false; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } @@ -157,8 +199,13 @@ class ConnectionTest { void TestAutocommitToggle(); + void 
TestMetadataCurrentCatalog(); + void TestMetadataCurrentDbSchema(); + void TestMetadataGetInfo(); void TestMetadataGetTableSchema(); + void TestMetadataGetTableSchemaEscaping(); + void TestMetadataGetTableSchemaNotFound(); void TestMetadataGetTableTypes(); void TestMetadataGetObjectsCatalogs(); @@ -168,6 +215,9 @@ class ConnectionTest { void TestMetadataGetObjectsColumns(); void TestMetadataGetObjectsConstraints(); void TestMetadataGetObjectsPrimaryKey(); + void TestMetadataGetObjectsCancel(); + + void TestMetadataGetStatisticNames(); protected: struct AdbcError error; @@ -175,28 +225,38 @@ class ConnectionTest { struct AdbcConnection connection; }; -#define ADBCV_TEST_CONNECTION(FIXTURE) \ - static_assert(std::is_base_of::value, \ - ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ - TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ - TEST_F(FIXTURE, Release) { TestRelease(); } \ - TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ - TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ - TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ - TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ - TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ - TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ - TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ - TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ - TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ - TestMetadataGetObjectsTablesTypes(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ - TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ - TestMetadataGetObjectsConstraints(); \ - } \ - TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } +#define ADBCV_TEST_CONNECTION(FIXTURE) \ + 
static_assert(std::is_base_of::value, \ + ADBCV_STRINGIFY(FIXTURE) " must inherit from ConnectionTest"); \ + TEST_F(FIXTURE, NewInit) { TestNewInit(); } \ + TEST_F(FIXTURE, Release) { TestRelease(); } \ + TEST_F(FIXTURE, Concurrent) { TestConcurrent(); } \ + TEST_F(FIXTURE, AutocommitDefault) { TestAutocommitDefault(); } \ + TEST_F(FIXTURE, AutocommitToggle) { TestAutocommitToggle(); } \ + TEST_F(FIXTURE, MetadataCurrentCatalog) { TestMetadataCurrentCatalog(); } \ + TEST_F(FIXTURE, MetadataCurrentDbSchema) { TestMetadataCurrentDbSchema(); } \ + TEST_F(FIXTURE, MetadataGetInfo) { TestMetadataGetInfo(); } \ + TEST_F(FIXTURE, MetadataGetTableSchema) { TestMetadataGetTableSchema(); } \ + TEST_F(FIXTURE, MetadataGetTableSchemaEscaping) { \ + TestMetadataGetTableSchemaEscaping(); \ + } \ + TEST_F(FIXTURE, MetadataGetTableSchemaNotFound) { \ + TestMetadataGetTableSchemaNotFound(); \ + } \ + TEST_F(FIXTURE, MetadataGetTableTypes) { TestMetadataGetTableTypes(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCatalogs) { TestMetadataGetObjectsCatalogs(); } \ + TEST_F(FIXTURE, MetadataGetObjectsDbSchemas) { TestMetadataGetObjectsDbSchemas(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTables) { TestMetadataGetObjectsTables(); } \ + TEST_F(FIXTURE, MetadataGetObjectsTablesTypes) { \ + TestMetadataGetObjectsTablesTypes(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsColumns) { TestMetadataGetObjectsColumns(); } \ + TEST_F(FIXTURE, MetadataGetObjectsConstraints) { \ + TestMetadataGetObjectsConstraints(); \ + } \ + TEST_F(FIXTURE, MetadataGetObjectsPrimaryKey) { TestMetadataGetObjectsPrimaryKey(); } \ + TEST_F(FIXTURE, MetadataGetObjectsCancel) { TestMetadataGetObjectsCancel(); } \ + TEST_F(FIXTURE, MetadataGetStatisticNames) { TestMetadataGetStatisticNames(); } class StatementTest { public: @@ -227,14 +287,33 @@ class StatementTest { // Strings void TestSqlIngestString(); + void TestSqlIngestLargeString(); void TestSqlIngestBinary(); + // Temporal + void TestSqlIngestDuration(); + void 
TestSqlIngestDate32(); + void TestSqlIngestTimestamp(); + void TestSqlIngestTimestampTz(); + void TestSqlIngestInterval(); + // ---- End Type-specific tests ---------------- + void TestSqlIngestTableEscaping(); + void TestSqlIngestColumnEscaping(); void TestSqlIngestAppend(); + void TestSqlIngestReplace(); + void TestSqlIngestCreateAppend(); void TestSqlIngestErrors(); void TestSqlIngestMultipleConnections(); void TestSqlIngestSample(); + void TestSqlIngestTargetCatalog(); + void TestSqlIngestTargetSchema(); + void TestSqlIngestTargetCatalogSchema(); + void TestSqlIngestTemporary(); + void TestSqlIngestTemporaryAppend(); + void TestSqlIngestTemporaryReplace(); + void TestSqlIngestTemporaryExclusive(); void TestSqlPartitionedInts(); @@ -251,11 +330,19 @@ class StatementTest { void TestSqlQueryFloats(); void TestSqlQueryStrings(); + void TestSqlQueryCancel(); void TestSqlQueryErrors(); + void TestSqlSchemaInts(); + void TestSqlSchemaFloats(); + void TestSqlSchemaStrings(); + + void TestSqlSchemaErrors(); + void TestTransactions(); void TestConcurrentStatements(); + void TestErrorCompatibility(); void TestResultInvalidation(); protected: @@ -269,6 +356,13 @@ class StatementTest { template void TestSqlIngestNumericType(ArrowType type); + + template + void TestSqlIngestTemporalType(const char* timezone); + + virtual void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, + enum ArrowTimeUnit unit, + const char* timezone); }; #define ADBCV_TEST_STATEMENT(FIXTURE) \ @@ -287,11 +381,28 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestFloat32) { TestSqlIngestFloat32(); } \ TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \ TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \ + TEST_F(FIXTURE, SqlIngestLargeString) { TestSqlIngestLargeString(); } \ TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \ + TEST_F(FIXTURE, SqlIngestDuration) { TestSqlIngestDuration(); } \ + TEST_F(FIXTURE, SqlIngestDate32) { 
TestSqlIngestDate32(); } \ + TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \ + TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \ + TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \ + TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ + TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \ TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \ + TEST_F(FIXTURE, SqlIngestReplace) { TestSqlIngestReplace(); } \ + TEST_F(FIXTURE, SqlIngestCreateAppend) { TestSqlIngestCreateAppend(); } \ TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \ TEST_F(FIXTURE, SqlIngestMultipleConnections) { TestSqlIngestMultipleConnections(); } \ TEST_F(FIXTURE, SqlIngestSample) { TestSqlIngestSample(); } \ + TEST_F(FIXTURE, SqlIngestTargetCatalog) { TestSqlIngestTargetCatalog(); } \ + TEST_F(FIXTURE, SqlIngestTargetSchema) { TestSqlIngestTargetSchema(); } \ + TEST_F(FIXTURE, SqlIngestTargetCatalogSchema) { TestSqlIngestTargetCatalogSchema(); } \ + TEST_F(FIXTURE, SqlIngestTemporary) { TestSqlIngestTemporary(); } \ + TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend(); } \ + TEST_F(FIXTURE, SqlIngestTemporaryReplace) { TestSqlIngestTemporaryReplace(); } \ + TEST_F(FIXTURE, SqlIngestTemporaryExclusive) { TestSqlIngestTemporaryExclusive(); } \ TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); } \ TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { TestSqlPrepareGetParameterSchema(); } \ TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); } \ @@ -306,9 +417,15 @@ class StatementTest { TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \ + TEST_F(FIXTURE, SqlQueryCancel) { TestSqlQueryCancel(); } \ TEST_F(FIXTURE, SqlQueryErrors) { TestSqlQueryErrors(); } \ + TEST_F(FIXTURE, SqlSchemaInts) { 
TestSqlSchemaInts(); } \ + TEST_F(FIXTURE, SqlSchemaFloats) { TestSqlSchemaFloats(); } \ + TEST_F(FIXTURE, SqlSchemaStrings) { TestSqlSchemaStrings(); } \ + TEST_F(FIXTURE, SqlSchemaErrors) { TestSqlSchemaErrors(); } \ TEST_F(FIXTURE, Transactions) { TestTransactions(); } \ TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \ + TEST_F(FIXTURE, ErrorCompatibility) { TestErrorCompatibility(); } \ TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); } } // namespace adbc_validation diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc index 7978947..b3ca7d5 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc @@ -242,4 +242,25 @@ void CompareSchema( } } +std::string GetDriverVendorVersion(struct AdbcConnection* connection) { + const uint32_t info_code = ADBC_INFO_VENDOR_VERSION; + const uint32_t info[] = {info_code}; + + adbc_validation::StreamReader reader; + struct AdbcError error = ADBC_ERROR_INIT; + AdbcConnectionGetInfo(connection, info, 1, &reader.stream.value, &error); + reader.GetSchema(); + if (error.release) { + error.release(&error); + throw std::runtime_error("error occured calling AdbcConnectionGetInfo!"); + } + + reader.Next(); + const ArrowStringView raw_version = + ArrowArrayViewGetStringUnsafe(reader.array_view->children[1]->children[0], 0); + const std::string version(raw_version.data, raw_version.size_bytes); + + return version; +} + } // namespace adbc_validation diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h index a239e76..b637659 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h @@ -31,6 +31,7 @@ #include #include #include + #include 
"common/utils.h" namespace adbc_validation { @@ -200,7 +201,7 @@ struct StreamReader { /// \brief Read an AdbcGetInfoData struct with RAII safety struct GetObjectsReader { - explicit GetObjectsReader(struct ArrowArrayView* array_view) : array_view_(array_view) { + explicit GetObjectsReader(struct ArrowArrayView* array_view) { // TODO: this swallows any construction errors get_objects_data_ = AdbcGetObjectsDataInit(array_view); } @@ -214,7 +215,6 @@ struct GetObjectsReader { } private: - struct ArrowArrayView* array_view_; struct AdbcGetObjectsData* get_objects_data_; }; @@ -265,6 +265,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array, if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) { return errno_res; } + } else if constexpr (std::is_same::value) { + if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) { + return errno_res; + } } else { static_assert(!sizeof(T), "Not yet implemented"); return ENOTSUP; @@ -376,6 +380,15 @@ void CompareArray(struct ArrowArrayView* array, struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i); std::string str(view.data, view.size_bytes); ASSERT_EQ(*v, str); + } else if constexpr (std::is_same::value) { + ASSERT_NE(array->buffer_views[1].data.data, nullptr); + struct ArrowInterval interval; + ArrowIntervalInit(&interval, ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); + ArrowArrayViewGetIntervalUnsafe(array, i, &interval); + + ASSERT_EQ(interval.months, (*v)->months); + ASSERT_EQ(interval.days, (*v)->days); + ASSERT_EQ(interval.ns, (*v)->ns); } else { static_assert(!sizeof(T), "Not yet implemented"); } @@ -392,4 +405,7 @@ void CompareSchema( struct ArrowSchema* schema, const std::vector, ArrowType, bool>>& fields); +/// \brief Helper method to get the vendor version of a driver +std::string GetDriverVendorVersion(struct AdbcConnection* connection); + } // namespace adbc_validation diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c 
b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c index 0b8fc35..ab3e337 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c @@ -49,12 +49,21 @@ int ArrowErrorSet(struct ArrowError* error, const char* fmt, ...) { } } -const char* ArrowErrorMessage(struct ArrowError* error) { return error->message; } +const char* ArrowErrorMessage(struct ArrowError* error) { + if (error == NULL) { + return ""; + } else { + return error->message; + } +} void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_VALIDITY; - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[0] = NANOARROW_TYPE_BOOL; + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = storage_type; layout->buffer_type[2] = NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[2] = NANOARROW_TYPE_UNINITIALIZED; layout->element_size_bits[0] = 1; layout->element_size_bits[1] = 0; @@ -66,43 +75,53 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { case NANOARROW_TYPE_UNINITIALIZED: case NANOARROW_TYPE_NA: layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[0] = NANOARROW_TYPE_UNINITIALIZED; + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[1] = NANOARROW_TYPE_UNINITIALIZED; layout->element_size_bits[0] = 0; break; case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_MAP: layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT32; layout->element_size_bits[1] = 32; break; case NANOARROW_TYPE_LARGE_LIST: layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT64; layout->element_size_bits[1] = 64; break; + case NANOARROW_TYPE_STRUCT: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + layout->buffer_type[1] = 
NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[1] = NANOARROW_TYPE_UNINITIALIZED; + break; + case NANOARROW_TYPE_BOOL: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 1; break; case NANOARROW_TYPE_UINT8: case NANOARROW_TYPE_INT8: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 8; break; case NANOARROW_TYPE_UINT16: case NANOARROW_TYPE_INT16: case NANOARROW_TYPE_HALF_FLOAT: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 16; break; case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_INT32: case NANOARROW_TYPE_FLOAT: + layout->element_size_bits[1] = 32; + break; case NANOARROW_TYPE_INTERVAL_MONTHS: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT32; layout->element_size_bits[1] = 32; break; @@ -110,49 +129,61 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { case NANOARROW_TYPE_INT64: case NANOARROW_TYPE_DOUBLE: case NANOARROW_TYPE_INTERVAL_DAY_TIME: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 64; break; case NANOARROW_TYPE_DECIMAL128: case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 128; break; case NANOARROW_TYPE_DECIMAL256: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; layout->element_size_bits[1] = 256; break; case NANOARROW_TYPE_FIXED_SIZE_BINARY: - layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_BINARY; break; case NANOARROW_TYPE_DENSE_UNION: layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_TYPE_ID; + layout->buffer_data_type[0] = NANOARROW_TYPE_INT8; layout->element_size_bits[0] = 8; layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_UNION_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT32; layout->element_size_bits[1] = 32; break; case NANOARROW_TYPE_SPARSE_UNION: 
layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_TYPE_ID; + layout->buffer_data_type[0] = NANOARROW_TYPE_INT8; layout->element_size_bits[0] = 8; + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_NONE; + layout->buffer_data_type[1] = NANOARROW_TYPE_UNINITIALIZED; break; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_BINARY: layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT32; layout->element_size_bits[1] = 32; layout->buffer_type[2] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[2] = storage_type; break; case NANOARROW_TYPE_LARGE_STRING: + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT64; + layout->element_size_bits[1] = 64; + layout->buffer_type[2] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[2] = NANOARROW_TYPE_STRING; + break; case NANOARROW_TYPE_LARGE_BINARY: layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA_OFFSET; + layout->buffer_data_type[1] = NANOARROW_TYPE_INT64; layout->element_size_bits[1] = 64; layout->buffer_type[2] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[2] = NANOARROW_TYPE_BINARY; break; default: @@ -1892,24 +1923,43 @@ ArrowErrorCode ArrowArrayInitFromType(struct ArrowArray* array, return NANOARROW_OK; } -static ArrowErrorCode ArrowArrayInitFromArrayView(struct ArrowArray* array, - struct ArrowArrayView* array_view, - struct ArrowError* error) { - ArrowArrayInitFromType(array, array_view->storage_type); +ArrowErrorCode ArrowArrayInitFromArrayView(struct ArrowArray* array, + struct ArrowArrayView* array_view, + struct ArrowError* error) { + NANOARROW_RETURN_NOT_OK_WITH_ERROR( + ArrowArrayInitFromType(array, array_view->storage_type), error); + int result; + struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; + private_data->layout = array_view->layout; - int result = ArrowArrayAllocateChildren(array, array_view->n_children); - if (result != 
NANOARROW_OK) { - array->release(array); - return result; + if (array_view->n_children > 0) { + result = ArrowArrayAllocateChildren(array, array_view->n_children); + if (result != NANOARROW_OK) { + array->release(array); + return result; + } + + for (int64_t i = 0; i < array_view->n_children; i++) { + result = + ArrowArrayInitFromArrayView(array->children[i], array_view->children[i], error); + if (result != NANOARROW_OK) { + array->release(array); + return result; + } + } } - private_data->layout = array_view->layout; + if (array_view->dictionary != NULL) { + result = ArrowArrayAllocateDictionary(array); + if (result != NANOARROW_OK) { + array->release(array); + return result; + } - for (int64_t i = 0; i < array_view->n_children; i++) { - int result = - ArrowArrayInitFromArrayView(array->children[i], array_view->children[i], error); + result = + ArrowArrayInitFromArrayView(array->dictionary, array_view->dictionary, error); if (result != NANOARROW_OK) { array->release(array); return result; @@ -1955,9 +2005,7 @@ ArrowErrorCode ArrowArrayAllocateChildren(struct ArrowArray* array, int64_t n_ch return ENOMEM; } - for (int64_t i = 0; i < n_children; i++) { - array->children[i] = NULL; - } + memset(array->children, 0, n_children * sizeof(struct ArrowArray*)); for (int64_t i = 0; i < n_children; i++) { array->children[i] = (struct ArrowArray*)ArrowMalloc(sizeof(struct ArrowArray)); @@ -2025,6 +2073,16 @@ static ArrowErrorCode ArrowArrayViewInitFromArray(struct ArrowArrayView* array_v ArrowArrayViewInitFromType(array_view, private_data->storage_type); array_view->layout = private_data->layout; array_view->array = array; + array_view->length = array->length; + array_view->offset = array->offset; + array_view->null_count = array->null_count; + + array_view->buffer_views[0].data.as_uint8 = private_data->bitmap.buffer.data; + array_view->buffer_views[0].size_bytes = private_data->bitmap.buffer.size_bytes; + array_view->buffer_views[1].data.as_uint8 = 
private_data->buffers[0].data; + array_view->buffer_views[1].size_bytes = private_data->buffers[0].size_bytes; + array_view->buffer_views[2].data.as_uint8 = private_data->buffers[1].data; + array_view->buffer_views[2].size_bytes = private_data->buffers[1].size_bytes; int result = ArrowArrayViewAllocateChildren(array_view, array->n_children); if (result != NANOARROW_OK) { @@ -2040,6 +2098,20 @@ static ArrowErrorCode ArrowArrayViewInitFromArray(struct ArrowArrayView* array_v } } + if (array->dictionary != NULL) { + result = ArrowArrayViewAllocateDictionary(array_view); + if (result != NANOARROW_OK) { + ArrowArrayViewReset(array_view); + return result; + } + + result = ArrowArrayViewInitFromArray(array_view->dictionary, array->dictionary); + if (result != NANOARROW_OK) { + ArrowArrayViewReset(array_view); + return result; + } + } + return NANOARROW_OK; } @@ -2112,6 +2184,10 @@ static ArrowErrorCode ArrowArrayFinalizeBuffers(struct ArrowArray* array) { NANOARROW_RETURN_NOT_OK(ArrowArrayFinalizeBuffers(array->children[i])); } + if (array->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK(ArrowArrayFinalizeBuffers(array->dictionary)); + } + return NANOARROW_OK; } @@ -2126,39 +2202,10 @@ static void ArrowArrayFlushInternalPointers(struct ArrowArray* array) { for (int64_t i = 0; i < array->n_children; i++) { ArrowArrayFlushInternalPointers(array->children[i]); } -} - -static ArrowErrorCode ArrowArrayCheckInternalBufferSizes( - struct ArrowArray* array, struct ArrowArrayView* array_view, char set_length, - struct ArrowError* error) { - if (set_length) { - ArrowArrayViewSetLength(array_view, array->offset + array->length); - } - - for (int64_t i = 0; i < array->n_buffers; i++) { - if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_VALIDITY && - array->null_count == 0 && array->buffers[i] == NULL) { - continue; - } - - int64_t expected_size = array_view->buffer_views[i].size_bytes; - int64_t actual_size = ArrowArrayBuffer(array, i)->size_bytes; - - if (actual_size < 
expected_size) { - ArrowErrorSet( - error, - "Expected buffer %d to size >= %ld bytes but found buffer with %ld bytes", - (int)i, (long)expected_size, (long)actual_size); - return EINVAL; - } - } - for (int64_t i = 0; i < array->n_children; i++) { - NANOARROW_RETURN_NOT_OK(ArrowArrayCheckInternalBufferSizes( - array->children[i], array_view->children[i], set_length, error)); + if (array->dictionary != NULL) { + ArrowArrayFlushInternalPointers(array->dictionary); } - - return NANOARROW_OK; } ArrowErrorCode ArrowArrayFinishBuilding(struct ArrowArray* array, @@ -2168,7 +2215,7 @@ ArrowErrorCode ArrowArrayFinishBuilding(struct ArrowArray* array, // in some implementations (at least one version of Arrow C++ at the time this // was added). Only do this fix if we can assume CPU data access. if (validation_level >= NANOARROW_VALIDATION_LEVEL_DEFAULT) { - NANOARROW_RETURN_NOT_OK(ArrowArrayFinalizeBuffers(array)); + NANOARROW_RETURN_NOT_OK_WITH_ERROR(ArrowArrayFinalizeBuffers(array), error); } // Make sure the value we get with array->buffers[i] is set to the actual @@ -2179,44 +2226,11 @@ ArrowErrorCode ArrowArrayFinishBuilding(struct ArrowArray* array, return NANOARROW_OK; } - // Check buffer sizes to make sure we are not sending an ArrowArray - // into the wild that is going to segfault + // For validation, initialize an ArrowArrayView with our known buffer sizes struct ArrowArrayView array_view; - - NANOARROW_RETURN_NOT_OK(ArrowArrayViewInitFromArray(&array_view, array)); - - // Check buffer sizes once without using internal buffer data since - // ArrowArrayViewSetArray() assumes that all the buffers are long enough - // and issues invalid reads on offset buffers if they are not - int result = ArrowArrayCheckInternalBufferSizes(array, &array_view, 1, error); - if (result != NANOARROW_OK) { - ArrowArrayViewReset(&array_view); - return result; - } - - if (validation_level == NANOARROW_VALIDATION_LEVEL_MINIMAL) { - ArrowArrayViewReset(&array_view); - return NANOARROW_OK; - 
} - - result = ArrowArrayViewSetArray(&array_view, array, error); - if (result != NANOARROW_OK) { - ArrowArrayViewReset(&array_view); - return result; - } - - result = ArrowArrayCheckInternalBufferSizes(array, &array_view, 0, error); - if (result != NANOARROW_OK) { - ArrowArrayViewReset(&array_view); - return result; - } - - if (validation_level == NANOARROW_VALIDATION_LEVEL_DEFAULT) { - ArrowArrayViewReset(&array_view); - return NANOARROW_OK; - } - - result = ArrowArrayViewValidateFull(&array_view, error); + NANOARROW_RETURN_NOT_OK_WITH_ERROR(ArrowArrayViewInitFromArray(&array_view, array), + error); + int result = ArrowArrayViewValidate(&array_view, validation_level, error); ArrowArrayViewReset(&array_view); return result; } @@ -2263,6 +2277,21 @@ ArrowErrorCode ArrowArrayViewAllocateChildren(struct ArrowArrayView* array_view, return NANOARROW_OK; } +ArrowErrorCode ArrowArrayViewAllocateDictionary(struct ArrowArrayView* array_view) { + if (array_view->dictionary != NULL) { + return EINVAL; + } + + array_view->dictionary = + (struct ArrowArrayView*)ArrowMalloc(sizeof(struct ArrowArrayView)); + if (array_view->dictionary == NULL) { + return ENOMEM; + } + + ArrowArrayViewInitFromType(array_view->dictionary, NANOARROW_TYPE_UNINITIALIZED); + return NANOARROW_OK; +} + ArrowErrorCode ArrowArrayViewInitFromSchema(struct ArrowArrayView* array_view, struct ArrowSchema* schema, struct ArrowError* error) { @@ -2277,6 +2306,7 @@ ArrowErrorCode ArrowArrayViewInitFromSchema(struct ArrowArrayView* array_view, result = ArrowArrayViewAllocateChildren(array_view, schema->n_children); if (result != NANOARROW_OK) { + ArrowErrorSet(error, "ArrowArrayViewAllocateChildren() failed"); ArrowArrayViewReset(array_view); return result; } @@ -2290,6 +2320,21 @@ ArrowErrorCode ArrowArrayViewInitFromSchema(struct ArrowArrayView* array_view, } } + if (schema->dictionary != NULL) { + result = ArrowArrayViewAllocateDictionary(array_view); + if (result != NANOARROW_OK) { + 
ArrowArrayViewReset(array_view); + return result; + } + + result = + ArrowArrayViewInitFromSchema(array_view->dictionary, schema->dictionary, error); + if (result != NANOARROW_OK) { + ArrowArrayViewReset(array_view); + return result; + } + } + if (array_view->storage_type == NANOARROW_TYPE_SPARSE_UNION || array_view->storage_type == NANOARROW_TYPE_DENSE_UNION) { array_view->union_type_id_map = (int8_t*)ArrowMalloc(256 * sizeof(int8_t)); @@ -2321,6 +2366,11 @@ void ArrowArrayViewReset(struct ArrowArrayView* array_view) { ArrowFree(array_view->children); } + if (array_view->dictionary != NULL) { + ArrowArrayViewReset(array_view->dictionary); + ArrowFree(array_view->dictionary); + } + if (array_view->union_type_id_map != NULL) { ArrowFree(array_view->union_type_id_map); } @@ -2331,7 +2381,6 @@ void ArrowArrayViewReset(struct ArrowArrayView* array_view) { void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length) { for (int i = 0; i < 3; i++) { int64_t element_size_bytes = array_view->layout.element_size_bits[i] / 8; - array_view->buffer_views[i].data.data = NULL; switch (array_view->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_VALIDITY: @@ -2375,11 +2424,11 @@ void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length) } } -ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, - struct ArrowArray* array, - struct ArrowError* error) { - array_view->array = array; - +// This version recursively extracts information from the array and stores it +// in the array view, performing any checks that require the original array. 
+static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, + struct ArrowArray* array, + struct ArrowError* error) { // Check length and offset if (array->offset < 0) { ArrowErrorSet(error, "Expected array offset >= 0 but found array offset of %ld", @@ -2393,8 +2442,10 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, return EINVAL; } - // First pass setting lengths that do not depend on the data buffer - ArrowArrayViewSetLength(array_view, array->offset + array->length); + array_view->array = array; + array_view->offset = array->offset; + array_view->length = array->length; + array_view->null_count = array->null_count; int64_t buffers_required = 0; for (int i = 0; i < 3; i++) { @@ -2404,28 +2455,187 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, buffers_required++; - // If the null_count is 0, the validity buffer can be NULL - if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_VALIDITY && - array->null_count == 0 && array->buffers[i] == NULL) { + // Set buffer pointer + array_view->buffer_views[i].data.data = array->buffers[i]; + + // If non-null, set buffer size to unknown. 
+ if (array->buffers[i] == NULL) { array_view->buffer_views[i].size_bytes = 0; + } else { + array_view->buffer_views[i].size_bytes = -1; } - - array_view->buffer_views[i].data.data = array->buffers[i]; } + // Check the number of buffers if (buffers_required != array->n_buffers) { ArrowErrorSet(error, "Expected array with %d buffer(s) but found %d buffer(s)", (int)buffers_required, (int)array->n_buffers); return EINVAL; } + // Check number of children if (array_view->n_children != array->n_children) { ArrowErrorSet(error, "Expected %ld children but found %ld children", (long)array_view->n_children, (long)array->n_children); return EINVAL; } - // Check child sizes and calculate sizes that depend on data in the array buffers + // Recurse for children + for (int64_t i = 0; i < array_view->n_children; i++) { + NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArrayInternal(array_view->children[i], + array->children[i], error)); + } + + // Check dictionary + if (array->dictionary == NULL && array_view->dictionary != NULL) { + ArrowErrorSet(error, "Expected dictionary but found NULL"); + return EINVAL; + } + + if (array->dictionary != NULL && array_view->dictionary == NULL) { + ArrowErrorSet(error, "Expected NULL dictionary but found dictionary member"); + return EINVAL; + } + + if (array->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK( + ArrowArrayViewSetArrayInternal(array_view->dictionary, array->dictionary, error)); + } + + return NANOARROW_OK; +} + +static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, + struct ArrowError* error) { + // Calculate buffer sizes that do not require buffer access. If marked as + // unknown, assign the buffer size; otherwise, validate it. + int64_t offset_plus_length = array_view->offset + array_view->length; + + // Only loop over the first two buffers because the size of the third buffer + // is always data dependent for all current Arrow types. 
+ for (int i = 0; i < 2; i++) { + int64_t element_size_bytes = array_view->layout.element_size_bits[i] / 8; + // Initialize with a value that will cause an error if accidentally used uninitialized + int64_t min_buffer_size_bytes = array_view->buffer_views[i].size_bytes + 1; + + switch (array_view->layout.buffer_type[i]) { + case NANOARROW_BUFFER_TYPE_VALIDITY: + if (array_view->null_count == 0 && array_view->buffer_views[i].size_bytes == 0) { + continue; + } + + min_buffer_size_bytes = _ArrowBytesForBits(offset_plus_length); + break; + case NANOARROW_BUFFER_TYPE_DATA_OFFSET: + // Probably don't want/need to rely on the producer to have allocated an + // offsets buffer of length 1 for a zero-size array + min_buffer_size_bytes = + (offset_plus_length != 0) * element_size_bytes * (offset_plus_length + 1); + break; + case NANOARROW_BUFFER_TYPE_DATA: + min_buffer_size_bytes = + _ArrowRoundUpToMultipleOf8(array_view->layout.element_size_bits[i] * + offset_plus_length) / + 8; + break; + case NANOARROW_BUFFER_TYPE_TYPE_ID: + case NANOARROW_BUFFER_TYPE_UNION_OFFSET: + min_buffer_size_bytes = element_size_bytes * offset_plus_length; + break; + case NANOARROW_BUFFER_TYPE_NONE: + continue; + } + + // Assign or validate buffer size + if (array_view->buffer_views[i].size_bytes == -1) { + array_view->buffer_views[i].size_bytes = min_buffer_size_bytes; + } else if (array_view->buffer_views[i].size_bytes < min_buffer_size_bytes) { + ArrowErrorSet(error, + "Expected %s array buffer %d to have size >= %ld bytes but found " + "buffer with %ld bytes", + ArrowTypeString(array_view->storage_type), (int)i, + (long)min_buffer_size_bytes, + (long)array_view->buffer_views[i].size_bytes); + return EINVAL; + } + } + + // For list, fixed-size list and map views, we can validate the number of children + switch (array_view->storage_type) { + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + case NANOARROW_TYPE_MAP: + if (array_view->n_children != 
1) { + ArrowErrorSet(error, "Expected 1 child of %s array but found %ld child arrays", + ArrowTypeString(array_view->storage_type), + (long)array_view->n_children); + return EINVAL; + } + default: + break; + } + + // For struct, the sparse union, and the fixed-size list views, we can validate child + // lengths. + int64_t child_min_length; + switch (array_view->storage_type) { + case NANOARROW_TYPE_SPARSE_UNION: + case NANOARROW_TYPE_STRUCT: + child_min_length = (array_view->offset + array_view->length); + for (int64_t i = 0; i < array_view->n_children; i++) { + if (array_view->children[i]->length < child_min_length) { + ArrowErrorSet( + error, + "Expected struct child %d to have length >= %ld but found child with " + "length %ld", + (int)(i + 1), (long)(child_min_length), + (long)array_view->children[i]->length); + return EINVAL; + } + } + break; + + case NANOARROW_TYPE_FIXED_SIZE_LIST: + child_min_length = (array_view->offset + array_view->length) * + array_view->layout.child_size_elements; + if (array_view->children[0]->length < child_min_length) { + ArrowErrorSet(error, + "Expected child of fixed_size_list array to have length >= %ld but " + "found array with length %ld", + (long)child_min_length, (long)array_view->children[0]->length); + return EINVAL; + } + break; + default: + break; + } + + // Recurse for children + for (int64_t i = 0; i < array_view->n_children; i++) { + NANOARROW_RETURN_NOT_OK( + ArrowArrayViewValidateMinimal(array_view->children[i], error)); + } + + // Recurse for dictionary + if (array_view->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateMinimal(array_view->dictionary, error)); + } + + return NANOARROW_OK; +} + +static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, + struct ArrowError* error) { + // Perform minimal validation. This will validate or assign + // buffer sizes as long as buffer access is not required. 
+ NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateMinimal(array_view, error)); + + // Calculate buffer sizes or child lengths that require accessing the offsets + // buffer. Where appropriate, validate that the first offset is >= 0. + // If a buffer size is marked as unknown, assign it; otherwise, validate it. + int64_t offset_plus_length = array_view->offset + array_view->length; + int64_t first_offset; int64_t last_offset; switch (array_view->storage_type) { @@ -2439,11 +2649,22 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, return EINVAL; } - last_offset = - array_view->buffer_views[1].data.as_int32[array->offset + array->length]; - array_view->buffer_views[2].size_bytes = last_offset; + last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + + // If the data buffer size is unknown, assign it; otherwise, check it + if (array_view->buffer_views[2].size_bytes == -1) { + array_view->buffer_views[2].size_bytes = last_offset; + } else if (array_view->buffer_views[2].size_bytes < last_offset) { + ArrowErrorSet(error, + "Expected %s array buffer 2 to have size >= %ld bytes but found " + "buffer with %ld bytes", + ArrowTypeString(array_view->storage_type), (long)last_offset, + (long)array_view->buffer_views[2].size_bytes); + return EINVAL; + } } break; + case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_LARGE_BINARY: if (array_view->buffer_views[1].size_bytes != 0) { @@ -2454,34 +2675,38 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, return EINVAL; } - last_offset = - array_view->buffer_views[1].data.as_int64[array->offset + array->length]; - array_view->buffer_views[2].size_bytes = last_offset; + last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + + // If the data buffer size is unknown, assign it; otherwise, check it + if (array_view->buffer_views[2].size_bytes == -1) { + array_view->buffer_views[2].size_bytes = last_offset; + } else if 
(array_view->buffer_views[2].size_bytes < last_offset) { + ArrowErrorSet(error, + "Expected %s array buffer 2 to have size >= %ld bytes but found " + "buffer with %ld bytes", + ArrowTypeString(array_view->storage_type), (long)last_offset, + (long)array_view->buffer_views[2].size_bytes); + return EINVAL; + } } break; + case NANOARROW_TYPE_STRUCT: for (int64_t i = 0; i < array_view->n_children; i++) { - if (array->children[i]->length < (array->offset + array->length)) { + if (array_view->children[i]->length < offset_plus_length) { ArrowErrorSet( error, "Expected struct child %d to have length >= %ld but found child with " "length %ld", - (int)(i + 1), (long)(array->offset + array->length), - (long)array->children[i]->length); + (int)(i + 1), (long)offset_plus_length, + (long)array_view->children[i]->length); return EINVAL; } } break; - case NANOARROW_TYPE_LIST: - case NANOARROW_TYPE_MAP: { - const char* type_name = - array_view->storage_type == NANOARROW_TYPE_LIST ? "list" : "map"; - if (array->n_children != 1) { - ArrowErrorSet(error, "Expected 1 child of %s array but found %d child arrays", - type_name, (int)array->n_children); - return EINVAL; - } + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_MAP: if (array_view->buffer_views[1].size_bytes != 0) { first_offset = array_view->buffer_views[1].data.as_int32[0]; if (first_offset < 0) { @@ -2490,27 +2715,20 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, return EINVAL; } - last_offset = - array_view->buffer_views[1].data.as_int32[array->offset + array->length]; - if (array->children[0]->length < last_offset) { + last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + if (array_view->children[0]->length < last_offset) { ArrowErrorSet( error, - "Expected child of %s array with length >= %ld but found array with " + "Expected child of %s array to have length >= %ld but found array with " "length %ld", - type_name, (long)last_offset, (long)array->children[0]->length); + 
ArrowTypeString(array_view->storage_type), (long)last_offset, + (long)array_view->children[0]->length); return EINVAL; } } break; - } - case NANOARROW_TYPE_LARGE_LIST: - if (array->n_children != 1) { - ArrowErrorSet(error, - "Expected 1 child of large list array but found %d child arrays", - (int)array->n_children); - return EINVAL; - } + case NANOARROW_TYPE_LARGE_LIST: if (array_view->buffer_views[1].size_bytes != 0) { first_offset = array_view->buffer_views[1].data.as_int64[0]; if (first_offset < 0) { @@ -2519,49 +2737,61 @@ ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, return EINVAL; } - last_offset = - array_view->buffer_views[1].data.as_int64[array->offset + array->length]; - if (array->children[0]->length < last_offset) { + last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + if (array_view->children[0]->length < last_offset) { ArrowErrorSet( error, - "Expected child of large list array with length >= %ld but found array " + "Expected child of large list array to have length >= %ld but found array " "with length %ld", - (long)last_offset, (long)array->children[0]->length); + (long)last_offset, (long)array_view->children[0]->length); return EINVAL; } } break; - case NANOARROW_TYPE_FIXED_SIZE_LIST: - if (array->n_children != 1) { - ArrowErrorSet(error, - "Expected 1 child of fixed-size array but found %d child arrays", - (int)array->n_children); - return EINVAL; - } - - last_offset = - (array->offset + array->length) * array_view->layout.child_size_elements; - if (array->children[0]->length < last_offset) { - ArrowErrorSet( - error, - "Expected child of fixed-size list array with length >= %ld but found array " - "with length %ld", - (long)last_offset, (long)array->children[0]->length); - return EINVAL; - } - break; default: break; } + // Recurse for children for (int64_t i = 0; i < array_view->n_children; i++) { NANOARROW_RETURN_NOT_OK( - ArrowArrayViewSetArray(array_view->children[i], array->children[i], 
error)); + ArrowArrayViewValidateDefault(array_view->children[i], error)); + } + + // Recurse for dictionary + if (array_view->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateDefault(array_view->dictionary, error)); } return NANOARROW_OK; } +ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, + struct ArrowArray* array, + struct ArrowError* error) { + // Extract information from the array into the array view + NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArrayInternal(array_view, array, error)); + + // Run default validation. Because we've marked all non-NULL buffers as having unknown + // size, validation will also update the buffer sizes as it goes. + NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateDefault(array_view, error)); + + return NANOARROW_OK; +} + +ArrowErrorCode ArrowArrayViewSetArrayMinimal(struct ArrowArrayView* array_view, + struct ArrowArray* array, + struct ArrowError* error) { + // Extract information from the array into the array view + NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArrayInternal(array_view, array, error)); + + // Run default validation. Because we've marked all non-NULL buffers as having unknown + // size, validation will also update the buffer sizes as it goes. 
+ NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateMinimal(array_view, error)); + + return NANOARROW_OK; +} + static int ArrowAssertIncreasingInt32(struct ArrowBufferView view, struct ArrowError* error) { if (view.size_bytes <= (int64_t)sizeof(int32_t)) { @@ -2633,8 +2863,8 @@ static int ArrowAssertInt8In(struct ArrowBufferView view, const int8_t* values, return NANOARROW_OK; } -ArrowErrorCode ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, - struct ArrowError* error) { +static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, + struct ArrowError* error) { for (int i = 0; i < 3; i++) { switch (array_view->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_DATA_OFFSET: @@ -2653,17 +2883,18 @@ ArrowErrorCode ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, if (array_view->storage_type == NANOARROW_TYPE_DENSE_UNION || array_view->storage_type == NANOARROW_TYPE_SPARSE_UNION) { - // Check that we have valid type ids. if (array_view->union_type_id_map == NULL) { - // If the union_type_id map is NULL - // (e.g., when using ArrowArrayInitFromType() + ArrowArrayAllocateChildren() - // + ArrowArrayFinishBuilding()), we don't have enough information to validate - // this buffer (GH-178). + // If the union_type_id map is NULL (e.g., when using ArrowArrayInitFromType() + + // ArrowArrayAllocateChildren() + ArrowArrayFinishBuilding()), we don't have enough + // information to validate this buffer. 
+ ArrowErrorSet(error, + "Insufficient information provided for validation of union array"); + return EINVAL; } else if (_ArrowParsedUnionTypeIdsWillEqualChildIndices( array_view->union_type_id_map, array_view->n_children, array_view->n_children)) { - NANOARROW_RETURN_NOT_OK(ArrowAssertRangeInt8(array_view->buffer_views[0], 0, - array_view->n_children - 1, error)); + NANOARROW_RETURN_NOT_OK(ArrowAssertRangeInt8( + array_view->buffer_views[0], 0, (int8_t)(array_view->n_children - 1), error)); } else { NANOARROW_RETURN_NOT_OK(ArrowAssertInt8In(array_view->buffer_views[0], array_view->union_type_id_map + 128, @@ -2674,10 +2905,10 @@ ArrowErrorCode ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, if (array_view->storage_type == NANOARROW_TYPE_DENSE_UNION && array_view->union_type_id_map != NULL) { // Check that offsets refer to child elements that actually exist - for (int64_t i = 0; i < array_view->array->length; i++) { + for (int64_t i = 0; i < array_view->length; i++) { int8_t child_id = ArrowArrayViewUnionChildIndex(array_view, i); int64_t offset = ArrowArrayViewUnionChildOffset(array_view, i); - int64_t child_length = array_view->array->children[child_id]->length; + int64_t child_length = array_view->children[child_id]->length; if (offset < 0 || offset > child_length) { ArrowErrorSet( error, @@ -2689,12 +2920,38 @@ ArrowErrorCode ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, } } + // Recurse for children for (int64_t i = 0; i < array_view->n_children; i++) { NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateFull(array_view->children[i], error)); } + // Dictionary valiation not implemented + if (array_view->dictionary != NULL) { + ArrowErrorSet(error, "Validation for dictionary-encoded arrays is not implemented"); + return ENOTSUP; + } + return NANOARROW_OK; } + +ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view, + enum ArrowValidationLevel validation_level, + struct ArrowError* error) { + switch (validation_level) { + 
case NANOARROW_VALIDATION_LEVEL_NONE: + return NANOARROW_OK; + case NANOARROW_VALIDATION_LEVEL_MINIMAL: + return ArrowArrayViewValidateMinimal(array_view, error); + case NANOARROW_VALIDATION_LEVEL_DEFAULT: + return ArrowArrayViewValidateDefault(array_view, error); + case NANOARROW_VALIDATION_LEVEL_FULL: + NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateDefault(array_view, error)); + return ArrowArrayViewValidateFull(array_view, error); + } + + ArrowErrorSet(error, "validation_level not recognized"); + return EINVAL; +} // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h index 759c969..0131747 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h @@ -19,9 +19,9 @@ #define NANOARROW_BUILD_ID_H_INCLUDED #define NANOARROW_VERSION_MAJOR 0 -#define NANOARROW_VERSION_MINOR 2 +#define NANOARROW_VERSION_MINOR 3 #define NANOARROW_VERSION_PATCH 0 -#define NANOARROW_VERSION "0.2.0-SNAPSHOT" +#define NANOARROW_VERSION "0.3.0-SNAPSHOT" #define NANOARROW_VERSION_INT \ (NANOARROW_VERSION_MAJOR * 10000 + NANOARROW_VERSION_MINOR * 100 + \ @@ -55,6 +55,11 @@ +#if defined(NANOARROW_DEBUG) && !defined(NANOARROW_PRINT_AND_DIE) +#include +#include +#endif + #ifdef __cplusplus extern "C" { #endif @@ -194,6 +199,27 @@ static inline void ArrowArrayStreamMove(struct ArrowArrayStream* src, #define _NANOARROW_CHECK_UPPER_LIMIT(x_, max_) \ NANOARROW_RETURN_NOT_OK((x_ <= max_) ? 
NANOARROW_OK : EINVAL) +#if defined(NANOARROW_DEBUG) +#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d\n* %s:%d", EXPR_STR, \ + NAME, __FILE__, __LINE__); \ + return NAME; \ + } \ + } while (0) +#else +#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d", EXPR_STR, NAME); \ + return NAME; \ + } \ + } while (0) +#endif + /// \brief Return code for success. /// \ingroup nanoarrow-errors #define NANOARROW_OK 0 @@ -207,6 +233,47 @@ typedef int ArrowErrorCode; #define NANOARROW_RETURN_NOT_OK(EXPR) \ _NANOARROW_RETURN_NOT_OK_IMPL(_NANOARROW_MAKE_NAME(errno_status_, __COUNTER__), EXPR) +/// \brief Check the result of an expression and return it if not NANOARROW_OK, +/// adding an auto-generated message to an ArrowError. +/// \ingroup nanoarrow-errors +/// +/// This macro is used to ensure that functions that accept an ArrowError +/// as input always set its message when returning an error code (e.g., when calling +/// a nanoarrow function that does *not* accept ArrowError). 
+#define NANOARROW_RETURN_NOT_OK_WITH_ERROR(EXPR, ERROR_EXPR) \ + _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL( \ + _NANOARROW_MAKE_NAME(errno_status_, __COUNTER__), EXPR, ERROR_EXPR, #EXPR) + +#if defined(NANOARROW_DEBUG) && !defined(NANOARROW_PRINT_AND_DIE) +#define NANOARROW_PRINT_AND_DIE(VALUE, EXPR_STR) \ + do { \ + fprintf(stderr, "%s failed with errno %d\n* %s:%d\n", EXPR_STR, (int)(VALUE), \ + __FILE__, (int)__LINE__); \ + abort(); \ + } while (0) +#endif + +#if defined(NANOARROW_DEBUG) +#define _NANOARROW_ASSERT_OK_IMPL(NAME, EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) NANOARROW_PRINT_AND_DIE(NAME, EXPR_STR); \ + } while (0) + +/// \brief Assert that an expression's value is NANOARROW_OK +/// \ingroup nanoarrow-errors +/// +/// If nanoarrow was built in debug mode (i.e., defined(NANOARROW_DEBUG) is true), +/// print a message to stderr and abort. If nanoarrow was bulit in release mode, +/// this statement has no effect. You can customize fatal error behaviour +/// be defining the NANOARROW_PRINT_AND_DIE macro before including nanoarrow.h +/// This macro is provided as a convenience for users and is not used internally. 
+#define NANOARROW_ASSERT_OK(EXPR) \ + _NANOARROW_ASSERT_OK_IMPL(_NANOARROW_MAKE_NAME(errno_status_, __COUNTER__), EXPR, #EXPR) +#else +#define NANOARROW_ASSERT_OK(EXPR) EXPR +#endif + static char _ArrowIsLittleEndian(void) { uint32_t check = 1; char first_byte; @@ -266,6 +333,8 @@ enum ArrowType { /// \ingroup nanoarrow-utils /// /// Returns NULL for invalid values for type +static inline const char* ArrowTypeString(enum ArrowType type); + static inline const char* ArrowTypeString(enum ArrowType type) { switch (type) { case NANOARROW_TYPE_NA: @@ -384,6 +453,8 @@ enum ArrowValidationLevel { /// \ingroup nanoarrow-utils /// /// Returns NULL for invalid values for time_unit +static inline const char* ArrowTimeUnitString(enum ArrowTimeUnit time_unit); + static inline const char* ArrowTimeUnitString(enum ArrowTimeUnit time_unit) { switch (time_unit) { case NANOARROW_TIME_UNIT_SECOND: @@ -426,6 +497,8 @@ struct ArrowStringView { /// \brief Return a view of a const C string /// \ingroup nanoarrow-utils +static inline struct ArrowStringView ArrowCharView(const char* value); + static inline struct ArrowStringView ArrowCharView(const char* value) { struct ArrowStringView out; @@ -439,26 +512,28 @@ static inline struct ArrowStringView ArrowCharView(const char* value) { return out; } +union ArrowBufferViewData { + const void* data; + const int8_t* as_int8; + const uint8_t* as_uint8; + const int16_t* as_int16; + const uint16_t* as_uint16; + const int32_t* as_int32; + const uint32_t* as_uint32; + const int64_t* as_int64; + const uint64_t* as_uint64; + const double* as_double; + const float* as_float; + const char* as_char; +}; + /// \brief An non-owning view of a buffer /// \ingroup nanoarrow-utils struct ArrowBufferView { /// \brief A pointer to the start of the buffer /// /// If size_bytes is 0, this value may be NULL. 
- union { - const void* data; - const int8_t* as_int8; - const uint8_t* as_uint8; - const int16_t* as_int16; - const uint16_t* as_uint16; - const int32_t* as_int32; - const uint32_t* as_uint32; - const int64_t* as_int64; - const uint64_t* as_uint64; - const double* as_double; - const float* as_float; - const char* as_char; - } data; + union ArrowBufferViewData data; /// \brief The size of the buffer in bytes int64_t size_bytes; @@ -520,6 +595,9 @@ struct ArrowLayout { /// \brief The function of each buffer enum ArrowBufferType buffer_type[3]; + /// \brief The data type of each buffer + enum ArrowType buffer_data_type[3]; + /// \brief The size of an element each buffer or 0 if this size is variable or unknown int64_t element_size_bits[3]; @@ -534,12 +612,23 @@ struct ArrowLayout { /// This data structure provides access to the values contained within /// an ArrowArray with fields provided in a more readily-extractible /// form. You can re-use an ArrowArrayView for multiple ArrowArrays -/// with the same storage type, or use it to represent a hypothetical -/// ArrowArray that does not exist yet. +/// with the same storage type, use it to represent a hypothetical +/// ArrowArray that does not exist yet, or use it to validate the buffers +/// of a future ArrowArray. struct ArrowArrayView { - /// \brief The underlying ArrowArray or NULL if it has not been set + /// \brief The underlying ArrowArray or NULL if it has not been set or + /// if the buffers in this ArrowArrayView are not backed by an ArrowArray. struct ArrowArray* array; + /// \brief The number of elements from the physical start of the buffers. + int64_t offset; + + /// \brief The number of elements in this view. + int64_t length; + + /// \brief A cached null count or -1 to indicate that this value is unknown. 
+ int64_t null_count; + /// \brief The type used to store values in this array /// /// This type represents only the minimum required information to @@ -560,6 +649,9 @@ struct ArrowArrayView { /// \brief Pointers to views of this array's children struct ArrowArrayView** children; + /// \brief Pointer to a view of this array's dictionary + struct ArrowArrayView* dictionary; + /// \brief Union type id to child index mapping /// /// If storage_type is a union type, a 256-byte ArrowMalloc()ed buffer @@ -596,6 +688,29 @@ struct ArrowArrayPrivateData { int8_t union_type_id_is_child_index; }; +/// \brief A representation of an interval. +/// \ingroup nanoarrow-utils +struct ArrowInterval { + /// \brief The type of interval being used + enum ArrowType type; + /// \brief The number of months represented by the interval + int32_t months; + /// \brief The number of days represented by the interval + int32_t days; + /// \brief The number of ms represented by the interval + int32_t ms; + /// \brief The number of ns represented by the interval + int64_t ns; +}; + +/// \brief Zero initialize an Interval with a given unit +/// \ingroup nanoarrow-utils +static inline void ArrowIntervalInit(struct ArrowInterval* interval, + enum ArrowType type) { + memset(interval, 0, sizeof(struct ArrowInterval)); + interval->type = type; +} + /// \brief A representation of a fixed-precision decimal number /// \ingroup nanoarrow-utils /// @@ -779,6 +894,10 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayInitFromType) #define ArrowArrayInitFromSchema \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayInitFromSchema) +#define ArrowArrayInitFromArrayView \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayInitFromArrayView) +#define ArrowArrayInitFromArrayView \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayInitFromArrayView) #define ArrowArrayAllocateChildren \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayAllocateChildren) 
#define ArrowArrayAllocateDictionary \ @@ -797,12 +916,16 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewInitFromSchema) #define ArrowArrayViewAllocateChildren \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewAllocateChildren) +#define ArrowArrayViewAllocateDictionary \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewAllocateDictionary) #define ArrowArrayViewSetLength \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewSetLength) #define ArrowArrayViewSetArray \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewSetArray) -#define ArrowArrayViewValidateFull \ - NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewValidateFull) +#define ArrowArrayViewSetArrayMinimal \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewSetArrayMinimal) +#define ArrowArrayViewValidate \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewValidate) #define ArrowArrayViewReset NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewReset) #define ArrowBasicArrayStreamInit \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowBasicArrayStreamInit) @@ -869,7 +992,16 @@ struct ArrowBufferAllocator ArrowBufferDeallocator( /// need to communicate more verbose error information accept a pointer /// to an ArrowError. This can be stack or statically allocated. The /// content of the message is undefined unless an error code has been -/// returned. +/// returned. If a nanoarrow function is passed a non-null ArrowError pointer, the +/// ArrowError pointed to by the argument will be propagated with a +/// null-terminated error message. It is safe to pass a NULL ArrowError anywhere +/// in the nanoarrow API. +/// +/// Except where documented, it is generally not safe to continue after a +/// function has returned a non-zero ArrowErrorCode. The NANOARROW_RETURN_NOT_OK and +/// NANOARROW_ASSERT_OK macros are provided to help propagate errors. 
C++ clients can use +/// the helpers provided in the nanoarrow.hpp header to facilitate using C++ idioms +/// for memory management and error propgagtion. /// /// @{ @@ -879,10 +1011,24 @@ struct ArrowError { char message[1024]; }; -/// \brief Set the contents of an error using printf syntax +/// \brief Ensure an ArrowError is null-terminated by zeroing the first character. +/// +/// If error is NULL, this function does nothing. +static inline void ArrowErrorInit(struct ArrowError* error) { + if (error) { + error->message[0] = '\0'; + } +} + +/// \brief Set the contents of an error using printf syntax. +/// +/// If error is NULL, this function does nothing and returns NANOARROW_OK. ArrowErrorCode ArrowErrorSet(struct ArrowError* error, const char* fmt, ...); /// \brief Get the contents of an error +/// +/// If error is NULL, returns "", or returns the contents of the error message +/// otherwise. const char* ArrowErrorMessage(struct ArrowError* error); /// @} @@ -1416,6 +1562,14 @@ ArrowErrorCode ArrowArrayInitFromSchema(struct ArrowArray* array, struct ArrowSchema* schema, struct ArrowError* error); +/// \brief Initialize the contents of an ArrowArray from an ArrowArrayView +/// +/// Caller is responsible for calling the array->release callback if +/// NANOARROW_OK is returned. +ArrowErrorCode ArrowArrayInitFromArrayView(struct ArrowArray* array, + struct ArrowArrayView* array_view, + struct ArrowError* error); + /// \brief Allocate the array->children array /// /// Includes the memory for each child struct ArrowArray, @@ -1518,6 +1672,13 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array, static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array, struct ArrowStringView value); +/// \brief Append a Interval to an array +/// +/// Returns NANOARROW_OK if value can be exactly represented by +/// the underlying storage type or EINVAL otherwise. 
+static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array, + struct ArrowInterval* value); + /// \brief Append a decimal value to an array /// /// Returns NANOARROW_OK if array is a decimal array with the appropriate @@ -1573,7 +1734,7 @@ ArrowErrorCode ArrowArrayFinishBuilding(struct ArrowArray* array, /// \defgroup nanoarrow-array-view Reading arrays /// -/// These functions read and validate the contents ArrowArray structures +/// These functions read and validate the contents ArrowArray structures. /// /// @{ @@ -1593,12 +1754,15 @@ ArrowErrorCode ArrowArrayViewInitFromSchema(struct ArrowArrayView* array_view, struct ArrowSchema* schema, struct ArrowError* error); -/// \brief Allocate the schema_view->children array +/// \brief Allocate the array_view->children array /// /// Includes the memory for each child struct ArrowArrayView ArrowErrorCode ArrowArrayViewAllocateChildren(struct ArrowArrayView* array_view, int64_t n_children); +/// \brief Allocate array_view->dictionary +ArrowErrorCode ArrowArrayViewAllocateDictionary(struct ArrowArrayView* array_view); + /// \brief Set data-independent buffer sizes from length void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length); @@ -1606,9 +1770,23 @@ void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length); ArrowErrorCode ArrowArrayViewSetArray(struct ArrowArrayView* array_view, struct ArrowArray* array, struct ArrowError* error); -/// \brief Performs extra checks on the array that was set via ArrowArrayViewSetArray() -ArrowErrorCode ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, - struct ArrowError* error); +/// \brief Set buffer sizes and data pointers from an ArrowArray except for those +/// that require dereferencing buffer content. 
+ArrowErrorCode ArrowArrayViewSetArrayMinimal(struct ArrowArrayView* array_view, + struct ArrowArray* array, + struct ArrowError* error); + +/// \brief Performs checks on the content of an ArrowArrayView +/// +/// If using ArrowArrayViewSetArray() to back array_view with an ArrowArray, +/// the buffer sizes and some content (fist and last offset) have already +/// been validated at the "default" level. If setting the buffer pointers +/// and sizes otherwise, you may wish to perform checks at a different level. See +/// documentation for ArrowValidationLevel for the details of checks performed +/// at each level. +ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view, + enum ArrowValidationLevel validation_level, + struct ArrowError* error); /// \brief Reset the contents of an ArrowArrayView and frees resources void ArrowArrayViewReset(struct ArrowArrayView* array_view); @@ -1628,10 +1806,6 @@ static inline int8_t ArrowArrayViewUnionChildIndex(struct ArrowArrayView* array_ static inline int64_t ArrowArrayViewUnionChildOffset(struct ArrowArrayView* array_view, int64_t i); -/// \brief Get the index to use into the relevant list child array -static inline int64_t ArrowArrayViewListChildOffset(struct ArrowArrayView* array_view, - int64_t i); - /// \brief Get an element in an ArrowArrayView as an integer /// /// This function does not check for null values, that values are actually integers, or @@ -2019,36 +2193,37 @@ static inline int64_t ArrowBitCountSet(const uint8_t* bits, int64_t start_offset const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; + const int64_t i_last_valid = i_end - 1; const int64_t bytes_begin = i_begin / 8; - const int64_t bytes_end = i_end / 8 + 1; + const int64_t bytes_last_valid = i_last_valid / 8; - if (bytes_end == bytes_begin + 1) { + if (bytes_begin == bytes_last_valid) { // count bits within a single byte const uint8_t first_byte_mask = _ArrowkPrecedingBitmask[i_end % 8]; const uint8_t 
last_byte_mask = _ArrowkTrailingBitmask[i_begin % 8]; const uint8_t only_byte_mask = - i_end % 8 == 0 ? first_byte_mask : (uint8_t)(first_byte_mask & last_byte_mask); + i_end % 8 == 0 ? last_byte_mask : (uint8_t)(first_byte_mask & last_byte_mask); const uint8_t byte_masked = bits[bytes_begin] & only_byte_mask; return _ArrowkBytePopcount[byte_masked]; } const uint8_t first_byte_mask = _ArrowkPrecedingBitmask[i_begin % 8]; - const uint8_t last_byte_mask = _ArrowkTrailingBitmask[i_end % 8]; + const uint8_t last_byte_mask = i_end % 8 == 0 ? 0 : _ArrowkTrailingBitmask[i_end % 8]; int64_t count = 0; // first byte count += _ArrowkBytePopcount[bits[bytes_begin] & ~first_byte_mask]; // middle bytes - for (int64_t i = bytes_begin + 1; i < (bytes_end - 1); i++) { + for (int64_t i = bytes_begin + 1; i < bytes_last_valid; i++) { count += _ArrowkBytePopcount[bits[i]]; } // last byte - count += _ArrowkBytePopcount[bits[bytes_end - 1] & ~last_byte_mask]; + count += _ArrowkBytePopcount[bits[bytes_last_valid] & ~last_byte_mask]; return count; } @@ -2293,7 +2468,7 @@ static inline int8_t _ArrowParseUnionTypeIds(const char* type_ids, int8_t* out) } if (out != NULL) { - out[i] = type_id; + out[i] = (int8_t)type_id; } i++; @@ -2367,11 +2542,15 @@ static inline ArrowErrorCode ArrowArrayStartAppending(struct ArrowArray* array) } } - // Start building any child arrays + // Start building any child arrays or dictionaries for (int64_t i = 0; i < array->n_children; i++) { NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(array->children[i])); } + if (array->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(array->dictionary)); + } + return NANOARROW_OK; } @@ -2385,6 +2564,10 @@ static inline ArrowErrorCode ArrowArrayShrinkToFit(struct ArrowArray* array) { NANOARROW_RETURN_NOT_OK(ArrowArrayShrinkToFit(array->children[i])); } + if (array->dictionary != NULL) { + NANOARROW_RETURN_NOT_OK(ArrowArrayShrinkToFit(array->dictionary)); + } + return NANOARROW_OK; } @@ -2566,10 
+2749,10 @@ static inline ArrowErrorCode ArrowArrayAppendInt(struct ArrowArray* array, _NANOARROW_CHECK_RANGE(value, 0, INT64_MAX); return ArrowArrayAppendUInt(array, value); case NANOARROW_TYPE_DOUBLE: - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendDouble(data_buffer, value)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendDouble(data_buffer, (double)value)); break; case NANOARROW_TYPE_FLOAT: - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, value)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); @@ -2616,10 +2799,10 @@ static inline ArrowErrorCode ArrowArrayAppendUInt(struct ArrowArray* array, _NANOARROW_CHECK_UPPER_LIMIT(value, INT64_MAX); return ArrowArrayAppendInt(array, value); case NANOARROW_TYPE_DOUBLE: - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendDouble(data_buffer, value)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendDouble(data_buffer, (double)value)); break; case NANOARROW_TYPE_FLOAT: - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, value)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); @@ -2682,7 +2865,7 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array, return EINVAL; } - offset += value.size_bytes; + offset += (int32_t)value.size_bytes; NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(offset_buffer, &offset, sizeof(int32_t))); NANOARROW_RETURN_NOT_OK( ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); @@ -2730,12 +2913,57 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array, switch (private_data->storage_type) { case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: return ArrowArrayAppendBytes(array, buffer_view); default: return 
EINVAL; } } +static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array, + struct ArrowInterval* value) { + struct ArrowArrayPrivateData* private_data = + (struct ArrowArrayPrivateData*)array->private_data; + + struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1); + + switch (private_data->storage_type) { + case NANOARROW_TYPE_INTERVAL_MONTHS: { + if (value->type != NANOARROW_TYPE_INTERVAL_MONTHS) { + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months)); + break; + } + case NANOARROW_TYPE_INTERVAL_DAY_TIME: { + if (value->type != NANOARROW_TYPE_INTERVAL_DAY_TIME) { + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->ms)); + break; + } + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + if (value->type != NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO) { + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days)); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt64(data_buffer, value->ns)); + break; + } + default: + return EINVAL; + } + + array->length++; + return NANOARROW_OK; +} + static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array, struct ArrowDecimal* value) { struct ArrowArrayPrivateData* private_data = @@ -2871,7 +3099,7 @@ static inline void ArrowArrayViewMove(struct ArrowArrayView* src, static inline int8_t ArrowArrayViewIsNull(struct ArrowArrayView* array_view, int64_t i) { const uint8_t* validity_buffer = array_view->buffer_views[0].data.as_uint8; - i += array_view->array->offset; + i += array_view->offset; switch (array_view->storage_type) { case NANOARROW_TYPE_NA: return 0x01; @@ -2917,7 +3145,6 @@ static inline int64_t ArrowArrayViewUnionChildOffset(struct ArrowArrayView* arra } } - static inline int64_t 
ArrowArrayViewListChildOffset(struct ArrowArrayView* array_view, int64_t i) { switch (array_view->storage_type) { @@ -2933,7 +3160,7 @@ static inline int64_t ArrowArrayViewListChildOffset(struct ArrowArrayView* array static inline int64_t ArrowArrayViewGetIntUnsafe(struct ArrowArrayView* array_view, int64_t i) { struct ArrowBufferView* data_view = &array_view->buffer_views[1]; - i += array_view->array->offset; + i += array_view->offset; switch (array_view->storage_type) { case NANOARROW_TYPE_INT64: return data_view->data.as_int64[i]; @@ -2952,9 +3179,9 @@ static inline int64_t ArrowArrayViewGetIntUnsafe(struct ArrowArrayView* array_vi case NANOARROW_TYPE_UINT8: return data_view->data.as_uint8[i]; case NANOARROW_TYPE_DOUBLE: - return data_view->data.as_double[i]; + return (int64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: - return data_view->data.as_float[i]; + return (int64_t)data_view->data.as_float[i]; case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -2964,7 +3191,7 @@ static inline int64_t ArrowArrayViewGetIntUnsafe(struct ArrowArrayView* array_vi static inline uint64_t ArrowArrayViewGetUIntUnsafe(struct ArrowArrayView* array_view, int64_t i) { - i += array_view->array->offset; + i += array_view->offset; struct ArrowBufferView* data_view = &array_view->buffer_views[1]; switch (array_view->storage_type) { case NANOARROW_TYPE_INT64: @@ -2984,9 +3211,9 @@ static inline uint64_t ArrowArrayViewGetUIntUnsafe(struct ArrowArrayView* array_ case NANOARROW_TYPE_UINT8: return data_view->data.as_uint8[i]; case NANOARROW_TYPE_DOUBLE: - return data_view->data.as_double[i]; + return (uint64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: - return data_view->data.as_float[i]; + return (uint64_t)data_view->data.as_float[i]; case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -2996,13 +3223,13 @@ static inline uint64_t ArrowArrayViewGetUIntUnsafe(struct ArrowArrayView* array_ static 
inline double ArrowArrayViewGetDoubleUnsafe(struct ArrowArrayView* array_view, int64_t i) { - i += array_view->array->offset; + i += array_view->offset; struct ArrowBufferView* data_view = &array_view->buffer_views[1]; switch (array_view->storage_type) { case NANOARROW_TYPE_INT64: - return data_view->data.as_int64[i]; + return (double)data_view->data.as_int64[i]; case NANOARROW_TYPE_UINT64: - return data_view->data.as_uint64[i]; + return (double)data_view->data.as_uint64[i]; case NANOARROW_TYPE_INT32: return data_view->data.as_int32[i]; case NANOARROW_TYPE_UINT32: @@ -3028,7 +3255,7 @@ static inline double ArrowArrayViewGetDoubleUnsafe(struct ArrowArrayView* array_ static inline struct ArrowStringView ArrowArrayViewGetStringUnsafe( struct ArrowArrayView* array_view, int64_t i) { - i += array_view->array->offset; + i += array_view->offset; struct ArrowBufferView* offsets_view = &array_view->buffer_views[1]; const char* data_view = array_view->buffer_views[2].data.as_char; @@ -3061,7 +3288,7 @@ static inline struct ArrowStringView ArrowArrayViewGetStringUnsafe( static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe( struct ArrowArrayView* array_view, int64_t i) { - i += array_view->array->offset; + i += array_view->offset; struct ArrowBufferView* offsets_view = &array_view->buffer_views[1]; const uint8_t* data_view = array_view->buffer_views[2].data.as_uint8; @@ -3093,9 +3320,36 @@ static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe( return view; } +static inline void ArrowArrayViewGetIntervalUnsafe(struct ArrowArrayView* array_view, + int64_t i, struct ArrowInterval* out) { + const uint8_t* data_view = array_view->buffer_views[1].data.as_uint8; + switch (array_view->storage_type) { + case NANOARROW_TYPE_INTERVAL_MONTHS: { + const size_t size = sizeof(int32_t); + memcpy(&out->months, data_view + i * size, sizeof(int32_t)); + break; + } + case NANOARROW_TYPE_INTERVAL_DAY_TIME: { + const size_t size = sizeof(int32_t) + sizeof(int32_t); + 
memcpy(&out->days, data_view + i * size, sizeof(int32_t)); + memcpy(&out->ms, data_view + i * size + 4, sizeof(int32_t)); + break; + } + case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + const size_t size = sizeof(int32_t) + sizeof(int32_t) + sizeof(int64_t); + memcpy(&out->months, data_view + i * size, sizeof(int32_t)); + memcpy(&out->days, data_view + i * size + 4, sizeof(int32_t)); + memcpy(&out->ns, data_view + i * size + 8, sizeof(int64_t)); + break; + } + default: + break; + } +} + static inline void ArrowArrayViewGetDecimalUnsafe(struct ArrowArrayView* array_view, int64_t i, struct ArrowDecimal* out) { - i += array_view->array->offset; + i += array_view->offset; const uint8_t* data_view = array_view->buffer_views[1].data.as_uint8; switch (array_view->storage_type) { case NANOARROW_TYPE_DECIMAL128: diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp index 468e911..da54a57 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. +#include #include +#include #include "nanoarrow.h" @@ -31,6 +33,55 @@ namespace nanoarrow { +/// \defgroup nanoarrow_hpp-errors Error handling helpers +/// +/// Most functions in the C API return an ArrowErrorCode to communicate +/// possible failure. Except where documented, it is usually not safe to +/// continue after a non-zero value has been returned. While the +/// nanoarrow C++ helpers do not throw any exceptions of their own, +/// these helpers are provided to facilitate using the nanoarrow C++ helpers +/// in frameworks where this is a useful error handling idiom. 
+/// +/// @{ + +class Exception : public std::exception { + public: + Exception(const std::string& msg) : msg_(msg) {} + const char* what() const noexcept { return msg_.c_str(); } + + private: + std::string msg_; +}; + +#if defined(NANOARROW_DEBUG) +#define _NANOARROW_THROW_NOT_OK_IMPL(NAME, EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + throw nanoarrow::Exception( \ + std::string(EXPR_STR) + std::string(" failed with errno ") + \ + std::to_string(NAME) + std::string("\n * ") + std::string(__FILE__) + \ + std::string(":") + std::to_string(__LINE__) + std::string("\n")); \ + } \ + } while (0) +#else +#define _NANOARROW_THROW_NOT_OK_IMPL(NAME, EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + throw nanoarrow::Exception(std::string(EXPR_STR) + \ + std::string(" failed with errno ") + \ + std::to_string(NAME)); \ + } \ + } while (0) +#endif + +#define NANOARROW_THROW_NOT_OK(EXPR) \ + _NANOARROW_THROW_NOT_OK_IMPL(_NANOARROW_MAKE_NAME(errno_status_, __COUNTER__), EXPR, \ + #EXPR) + +/// @} + namespace internal { /// \defgroup nanoarrow_hpp-unique_base Base classes for Unique wrappers diff --git a/3rd_party/apache-arrow-adbc/c/vendor/portable-snippets/safe-math.h b/3rd_party/apache-arrow-adbc/c/vendor/portable-snippets/safe-math.h new file mode 100644 index 0000000..797404a --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/vendor/portable-snippets/safe-math.h @@ -0,0 +1,1076 @@ +/* Overflow-safe math functions + * Portable Snippets - https://github.com/nemequ/portable-snippets + * Created by Evan Nemerson + * + * To the extent possible under law, the authors have waived all + * copyright and related or neighboring rights to this code. 
For + * details, see the Creative Commons Zero 1.0 Universal license at + * https://creativecommons.org/publicdomain/zero/1.0/ + */ + +#if !defined(PSNIP_SAFE_H) +#define PSNIP_SAFE_H + +#if !defined(PSNIP_SAFE_FORCE_PORTABLE) +# if defined(__has_builtin) +# if __has_builtin(__builtin_add_overflow) && !defined(__ibmxl__) +# define PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW +# endif +# elif defined(__GNUC__) && (__GNUC__ >= 5) && !defined(__INTEL_COMPILER) +# define PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW +# endif +# if defined(__has_include) +# if __has_include() +# define PSNIP_SAFE_HAVE_INTSAFE_H +# endif +# elif defined(_WIN32) +# define PSNIP_SAFE_HAVE_INTSAFE_H +# endif +#endif /* !defined(PSNIP_SAFE_FORCE_PORTABLE) */ + +#if defined(__GNUC__) +# define PSNIP_SAFE_LIKELY(expr) __builtin_expect(!!(expr), 1) +# define PSNIP_SAFE_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#else +# define PSNIP_SAFE_LIKELY(expr) !!(expr) +# define PSNIP_SAFE_UNLIKELY(expr) !!(expr) +#endif /* defined(__GNUC__) */ + +#if !defined(PSNIP_SAFE_STATIC_INLINE) +# if defined(__GNUC__) +# define PSNIP_SAFE__COMPILER_ATTRIBUTES __attribute__((__unused__)) +# else +# define PSNIP_SAFE__COMPILER_ATTRIBUTES +# endif + +# if defined(HEDLEY_INLINE) +# define PSNIP_SAFE__INLINE HEDLEY_INLINE +# elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +# define PSNIP_SAFE__INLINE inline +# elif defined(__GNUC_STDC_INLINE__) +# define PSNIP_SAFE__INLINE __inline__ +# elif defined(_MSC_VER) && _MSC_VER >= 1200 +# define PSNIP_SAFE__INLINE __inline +# else +# define PSNIP_SAFE__INLINE +# endif + +# define PSNIP_SAFE__FUNCTION PSNIP_SAFE__COMPILER_ATTRIBUTES static PSNIP_SAFE__INLINE +#endif + +// !defined(__cplusplus) added for Solaris support +#if !defined(__cplusplus) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +# define psnip_safe_bool _Bool +#else +# define psnip_safe_bool int +#endif + +#if !defined(PSNIP_SAFE_NO_FIXED) +/* For maximum portability include the exact-int module from + 
portable snippets. */ +# if \ + !defined(psnip_int64_t) || !defined(psnip_uint64_t) || \ + !defined(psnip_int32_t) || !defined(psnip_uint32_t) || \ + !defined(psnip_int16_t) || !defined(psnip_uint16_t) || \ + !defined(psnip_int8_t) || !defined(psnip_uint8_t) +# include +# if !defined(psnip_int64_t) +# define psnip_int64_t int64_t +# endif +# if !defined(psnip_uint64_t) +# define psnip_uint64_t uint64_t +# endif +# if !defined(psnip_int32_t) +# define psnip_int32_t int32_t +# endif +# if !defined(psnip_uint32_t) +# define psnip_uint32_t uint32_t +# endif +# if !defined(psnip_int16_t) +# define psnip_int16_t int16_t +# endif +# if !defined(psnip_uint16_t) +# define psnip_uint16_t uint16_t +# endif +# if !defined(psnip_int8_t) +# define psnip_int8_t int8_t +# endif +# if !defined(psnip_uint8_t) +# define psnip_uint8_t uint8_t +# endif +# endif +#endif /* !defined(PSNIP_SAFE_NO_FIXED) */ +#include +#include + +#if !defined(PSNIP_SAFE_SIZE_MAX) +# if defined(__SIZE_MAX__) +# define PSNIP_SAFE_SIZE_MAX __SIZE_MAX__ +# elif defined(PSNIP_EXACT_INT_HAVE_STDINT) +# include +# endif +#endif + +#if defined(PSNIP_SAFE_SIZE_MAX) +# define PSNIP_SAFE__SIZE_MAX_RT PSNIP_SAFE_SIZE_MAX +#else +# define PSNIP_SAFE__SIZE_MAX_RT (~((size_t) 0)) +#endif + +#if defined(PSNIP_SAFE_HAVE_INTSAFE_H) +/* In VS 10, stdint.h and intsafe.h both define (U)INTN_MIN/MAX, which + triggers warning C4005 (level 1). */ +# if defined(_MSC_VER) && (_MSC_VER == 1600) +# pragma warning(push) +# pragma warning(disable:4005) +# endif +# include +# if defined(_MSC_VER) && (_MSC_VER == 1600) +# pragma warning(pop) +# endif +#endif /* defined(PSNIP_SAFE_HAVE_INTSAFE_H) */ + +/* If there is a type larger than the one we're concerned with it's + * likely much faster to simply promote the operands, perform the + * requested operation, verify that the result falls within the + * original type, then cast the result back to the original type. 
*/ + +#if !defined(PSNIP_SAFE_NO_PROMOTIONS) + +#define PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, op_name, op) \ + PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ + psnip_safe_larger_##name##_##op_name (T a, T b) { \ + return ((psnip_safe_##name##_larger) a) op ((psnip_safe_##name##_larger) b); \ + } + +#define PSNIP_SAFE_DEFINE_LARGER_UNARY_OP(T, name, op_name, op) \ + PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ + psnip_safe_larger_##name##_##op_name (T value) { \ + return (op ((psnip_safe_##name##_larger) value)); \ + } + +#define PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(T, name) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, add, +) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, sub, -) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mul, *) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, div, /) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mod, %) \ + PSNIP_SAFE_DEFINE_LARGER_UNARY_OP (T, name, neg, -) + +#define PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(T, name) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, add, +) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, sub, -) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mul, *) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, div, /) \ + PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, mod, %) + +#define PSNIP_SAFE_IS_LARGER(ORIG_MAX, DEST_MAX) ((DEST_MAX / ORIG_MAX) >= ORIG_MAX) + +// Using __int128 intrinsics causes compilation to fail with -Wpedantic +// which is required to pass CRAN incoming checks for R packages that use this header +#if defined(PSNIP_USE_INTRINSIC_INT128) +#if defined(__GNUC__) && ((__GNUC__ >= 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__SIZEOF_INT128__) && !defined(__ibmxl__) +#define PSNIP_SAFE_HAVE_128 +typedef __int128 psnip_safe_int128_t; +typedef unsigned __int128 psnip_safe_uint128_t; +#endif /* defined(__GNUC__) */ +#endif + +#if !defined(PSNIP_SAFE_NO_FIXED) +#define PSNIP_SAFE_HAVE_INT8_LARGER +#define PSNIP_SAFE_HAVE_UINT8_LARGER +typedef psnip_int16_t 
psnip_safe_int8_larger; +typedef psnip_uint16_t psnip_safe_uint8_larger; + +#define PSNIP_SAFE_HAVE_INT16_LARGER +typedef psnip_int32_t psnip_safe_int16_larger; +typedef psnip_uint32_t psnip_safe_uint16_larger; + +#define PSNIP_SAFE_HAVE_INT32_LARGER +typedef psnip_int64_t psnip_safe_int32_larger; +typedef psnip_uint64_t psnip_safe_uint32_larger; + +#if defined(PSNIP_SAFE_HAVE_128) +#define PSNIP_SAFE_HAVE_INT64_LARGER +typedef psnip_safe_int128_t psnip_safe_int64_larger; +typedef psnip_safe_uint128_t psnip_safe_uint64_larger; +#endif /* defined(PSNIP_SAFE_HAVE_128) */ +#endif /* !defined(PSNIP_SAFE_NO_FIXED) */ + +#define PSNIP_SAFE_HAVE_LARGER_SCHAR +#if PSNIP_SAFE_IS_LARGER(SCHAR_MAX, SHRT_MAX) +typedef short psnip_safe_schar_larger; +#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, INT_MAX) +typedef int psnip_safe_schar_larger; +#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LONG_MAX) +typedef long psnip_safe_schar_larger; +#elif PSNIP_SAFE_IS_LARGER(SCHAR_MAX, LLONG_MAX) +typedef long long psnip_safe_schar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fff) +typedef psnip_int16_t psnip_safe_schar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffLL) +typedef psnip_int32_t psnip_safe_schar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SCHAR_MAX, 0x7fffffffffffffffLL) +typedef psnip_int64_t psnip_safe_schar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (SCHAR_MAX <= 0x7fffffffffffffffLL) +typedef psnip_safe_int128_t psnip_safe_schar_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_SCHAR +#endif + +#define PSNIP_SAFE_HAVE_LARGER_UCHAR +#if PSNIP_SAFE_IS_LARGER(UCHAR_MAX, USHRT_MAX) +typedef unsigned short psnip_safe_uchar_larger; +#elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, UINT_MAX) +typedef unsigned int psnip_safe_uchar_larger; +#elif PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULONG_MAX) +typedef unsigned long psnip_safe_uchar_larger; +#elif 
PSNIP_SAFE_IS_LARGER(UCHAR_MAX, ULLONG_MAX) +typedef unsigned long long psnip_safe_uchar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffU) +typedef psnip_uint16_t psnip_safe_uchar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_uchar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UCHAR_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_uchar_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (UCHAR_MAX <= 0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_uchar_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_UCHAR +#endif + +#if CHAR_MIN == 0 && defined(PSNIP_SAFE_HAVE_LARGER_UCHAR) +#define PSNIP_SAFE_HAVE_LARGER_CHAR +typedef psnip_safe_uchar_larger psnip_safe_char_larger; +#elif CHAR_MIN < 0 && defined(PSNIP_SAFE_HAVE_LARGER_SCHAR) +#define PSNIP_SAFE_HAVE_LARGER_CHAR +typedef psnip_safe_schar_larger psnip_safe_char_larger; +#endif + +#define PSNIP_SAFE_HAVE_LARGER_SHRT +#if PSNIP_SAFE_IS_LARGER(SHRT_MAX, INT_MAX) +typedef int psnip_safe_short_larger; +#elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LONG_MAX) +typedef long psnip_safe_short_larger; +#elif PSNIP_SAFE_IS_LARGER(SHRT_MAX, LLONG_MAX) +typedef long long psnip_safe_short_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fff) +typedef psnip_int16_t psnip_safe_short_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffLL) +typedef psnip_int32_t psnip_safe_short_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(SHRT_MAX, 0x7fffffffffffffffLL) +typedef psnip_int64_t psnip_safe_short_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (SHRT_MAX <= 0x7fffffffffffffffLL) +typedef psnip_safe_int128_t psnip_safe_short_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_SHRT +#endif + +#define PSNIP_SAFE_HAVE_LARGER_USHRT 
+#if PSNIP_SAFE_IS_LARGER(USHRT_MAX, UINT_MAX) +typedef unsigned int psnip_safe_ushort_larger; +#elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULONG_MAX) +typedef unsigned long psnip_safe_ushort_larger; +#elif PSNIP_SAFE_IS_LARGER(USHRT_MAX, ULLONG_MAX) +typedef unsigned long long psnip_safe_ushort_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffff) +typedef psnip_uint16_t psnip_safe_ushort_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_ushort_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(USHRT_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_ushort_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (USHRT_MAX <= 0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_ushort_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_USHRT +#endif + +#define PSNIP_SAFE_HAVE_LARGER_INT +#if PSNIP_SAFE_IS_LARGER(INT_MAX, LONG_MAX) +typedef long psnip_safe_int_larger; +#elif PSNIP_SAFE_IS_LARGER(INT_MAX, LLONG_MAX) +typedef long long psnip_safe_int_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fff) +typedef psnip_int16_t psnip_safe_int_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffLL) +typedef psnip_int32_t psnip_safe_int_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(INT_MAX, 0x7fffffffffffffffLL) +typedef psnip_int64_t psnip_safe_int_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (INT_MAX <= 0x7fffffffffffffffLL) +typedef psnip_safe_int128_t psnip_safe_int_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_INT +#endif + +#define PSNIP_SAFE_HAVE_LARGER_UINT +#if PSNIP_SAFE_IS_LARGER(UINT_MAX, ULONG_MAX) +typedef unsigned long psnip_safe_uint_larger; +#elif PSNIP_SAFE_IS_LARGER(UINT_MAX, ULLONG_MAX) +typedef unsigned long long psnip_safe_uint_larger; +#elif 
!defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffff) +typedef psnip_uint16_t psnip_safe_uint_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_uint_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(UINT_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_uint_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (UINT_MAX <= 0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_uint_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_UINT +#endif + +#define PSNIP_SAFE_HAVE_LARGER_LONG +#if PSNIP_SAFE_IS_LARGER(LONG_MAX, LLONG_MAX) +typedef long long psnip_safe_long_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fff) +typedef psnip_int16_t psnip_safe_long_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffLL) +typedef psnip_int32_t psnip_safe_long_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LONG_MAX, 0x7fffffffffffffffLL) +typedef psnip_int64_t psnip_safe_long_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (LONG_MAX <= 0x7fffffffffffffffLL) +typedef psnip_safe_int128_t psnip_safe_long_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_LONG +#endif + +#define PSNIP_SAFE_HAVE_LARGER_ULONG +#if PSNIP_SAFE_IS_LARGER(ULONG_MAX, ULLONG_MAX) +typedef unsigned long long psnip_safe_ulong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffff) +typedef psnip_uint16_t psnip_safe_ulong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_ulong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULONG_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_ulong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (ULONG_MAX <= 
0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_ulong_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_ULONG +#endif + +#define PSNIP_SAFE_HAVE_LARGER_LLONG +#if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fff) +typedef psnip_int16_t psnip_safe_llong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffLL) +typedef psnip_int32_t psnip_safe_llong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(LLONG_MAX, 0x7fffffffffffffffLL) +typedef psnip_int64_t psnip_safe_llong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (LLONG_MAX <= 0x7fffffffffffffffLL) +typedef psnip_safe_int128_t psnip_safe_llong_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_LLONG +#endif + +#define PSNIP_SAFE_HAVE_LARGER_ULLONG +#if !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffff) +typedef psnip_uint16_t psnip_safe_ullong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_ullong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(ULLONG_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_ullong_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (ULLONG_MAX <= 0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_ullong_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_ULLONG +#endif + +#if defined(PSNIP_SAFE_SIZE_MAX) +#define PSNIP_SAFE_HAVE_LARGER_SIZE +#if PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, USHRT_MAX) +typedef unsigned short psnip_safe_size_larger; +#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, UINT_MAX) +typedef unsigned int psnip_safe_size_larger; +#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULONG_MAX) +typedef unsigned long psnip_safe_size_larger; +#elif PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, ULLONG_MAX) +typedef unsigned long long psnip_safe_size_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) 
&& PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffff) +typedef psnip_uint16_t psnip_safe_size_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffUL) +typedef psnip_uint32_t psnip_safe_size_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && PSNIP_SAFE_IS_LARGER(PSNIP_SAFE_SIZE_MAX, 0xffffffffffffffffULL) +typedef psnip_uint64_t psnip_safe_size_larger; +#elif !defined(PSNIP_SAFE_NO_FIXED) && defined(PSNIP_SAFE_HAVE_128) && (PSNIP_SAFE_SIZE_MAX <= 0xffffffffffffffffULL) +typedef psnip_safe_uint128_t psnip_safe_size_larger; +#else +#undef PSNIP_SAFE_HAVE_LARGER_SIZE +#endif +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_SCHAR) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(signed char, schar) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_UCHAR) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned char, uchar) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_CHAR) +#if CHAR_MIN == 0 +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(char, char) +#else +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(char, char) +#endif +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_SHORT) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(short, short) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_USHORT) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned short, ushort) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_INT) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(int, int) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_UINT) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned int, uint) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_LONG) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(long, long) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_ULONG) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned long, ulong) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_LLONG) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(long long, llong) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_ULLONG) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(unsigned long long, ullong) +#endif + +#if defined(PSNIP_SAFE_HAVE_LARGER_SIZE) 
+PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(size_t, size) +#endif + +#if !defined(PSNIP_SAFE_NO_FIXED) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int8_t, int8) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint8_t, uint8) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int16_t, int16) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint16_t, uint16) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int32_t, int32) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint32_t, uint32) +#if defined(PSNIP_SAFE_HAVE_128) +PSNIP_SAFE_DEFINE_LARGER_SIGNED_OPS(psnip_int64_t, int64) +PSNIP_SAFE_DEFINE_LARGER_UNSIGNED_OPS(psnip_uint64_t, uint64) +#endif +#endif + +#endif /* !defined(PSNIP_SAFE_NO_PROMOTIONS) */ + +#define PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(T, name, op_name) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_##op_name(T* res, T a, T b) { \ + return !__builtin_##op_name##_overflow(a, b, res); \ + } + +#define PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(T, name, op_name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_##op_name(T* res, T a, T b) { \ + const psnip_safe_##name##_larger r = psnip_safe_larger_##name##_##op_name(a, b); \ + *res = (T) r; \ + return (r >= min) && (r <= max); \ + } + +#define PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(T, name, op_name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_##op_name(T* res, T a, T b) { \ + const psnip_safe_##name##_larger r = psnip_safe_larger_##name##_##op_name(a, b); \ + *res = (T) r; \ + return (r <= max); \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_ADD(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_add (T* res, T a, T b) { \ + psnip_safe_bool r = !( ((b > 0) && (a > (max - b))) || \ + ((b < 0) && (a < (min - b))) ); \ + if(PSNIP_SAFE_LIKELY(r)) \ + *res = a + b; \ + return r; \ + } + +#define PSNIP_SAFE_DEFINE_UNSIGNED_ADD(T, name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_add (T* res, T a, T b) { \ + 
*res = (T) (a + b); \ + return !PSNIP_SAFE_UNLIKELY((b > 0) && (a > (max - b))); \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_SUB(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_sub (T* res, T a, T b) { \ + psnip_safe_bool r = !((b > 0 && a < (min + b)) || \ + (b < 0 && a > (max + b))); \ + if(PSNIP_SAFE_LIKELY(r)) \ + *res = a - b; \ + return r; \ + } + +#define PSNIP_SAFE_DEFINE_UNSIGNED_SUB(T, name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_sub (T* res, T a, T b) { \ + *res = a - b; \ + return !PSNIP_SAFE_UNLIKELY(b > a); \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_MUL(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_mul (T* res, T a, T b) { \ + psnip_safe_bool r = 1; \ + if (a > 0) { \ + if (b > 0) { \ + if (a > (max / b)) { \ + r = 0; \ + } \ + } else { \ + if (b < (min / a)) { \ + r = 0; \ + } \ + } \ + } else { \ + if (b > 0) { \ + if (a < (min / b)) { \ + r = 0; \ + } \ + } else { \ + if ( (a != 0) && (b < (max / a))) { \ + r = 0; \ + } \ + } \ + } \ + if(PSNIP_SAFE_LIKELY(r)) \ + *res = a * b; \ + return r; \ + } + +#define PSNIP_SAFE_DEFINE_UNSIGNED_MUL(T, name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_mul (T* res, T a, T b) { \ + *res = (T) (a * b); \ + return !PSNIP_SAFE_UNLIKELY((a > 0) && (b > 0) && (a > (max / b))); \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_DIV(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_div (T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ + *res = min; \ + return 0; \ + } else { \ + *res = (T) (a / b); \ + return 1; \ + } \ + } + +#define PSNIP_SAFE_DEFINE_UNSIGNED_DIV(T, name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_div (T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else { \ + *res = a / b; \ + 
return 1; \ + } \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_MOD(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_mod (T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else if (PSNIP_SAFE_UNLIKELY(a == min && b == -1)) { \ + *res = min; \ + return 0; \ + } else { \ + *res = (T) (a % b); \ + return 1; \ + } \ + } + +#define PSNIP_SAFE_DEFINE_UNSIGNED_MOD(T, name, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_mod (T* res, T a, T b) { \ + if (PSNIP_SAFE_UNLIKELY(b == 0)) { \ + *res = 0; \ + return 0; \ + } else { \ + *res = a % b; \ + return 1; \ + } \ + } + +#define PSNIP_SAFE_DEFINE_SIGNED_NEG(T, name, min, max) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_neg (T* res, T value) { \ + psnip_safe_bool r = value != min; \ + *res = PSNIP_SAFE_LIKELY(r) ? -value : max; \ + return r; \ + } + +#define PSNIP_SAFE_DEFINE_INTSAFE(T, name, op, isf) \ + PSNIP_SAFE__FUNCTION psnip_safe_bool \ + psnip_safe_##name##_##op (T* res, T a, T b) { \ + return isf(a, b, res) == S_OK; \ + } + +#if CHAR_MIN == 0 +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_CHAR) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, add, CHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, sub, CHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(char, char, mul, CHAR_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(char, char, CHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(char, char, CHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(char, char, CHAR_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(char, char, CHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(char, char, CHAR_MAX) +#else /* CHAR_MIN != 0 */ +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, 
char, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(char, char, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_CHAR) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, add, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, sub, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(char, char, mul, CHAR_MIN, CHAR_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(char, char, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(char, char, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(char, char, CHAR_MIN, CHAR_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(char, char, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(char, char, CHAR_MIN, CHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(char, char, CHAR_MIN, CHAR_MAX) +#endif + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(signed char, schar, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_SCHAR) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, add, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, sub, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(signed char, schar, mul, SCHAR_MIN, SCHAR_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(signed char, schar, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(signed char, schar, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(signed char, schar, SCHAR_MIN, SCHAR_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(signed char, schar, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(signed char, schar, SCHAR_MIN, SCHAR_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(signed char, schar, SCHAR_MIN, SCHAR_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, sub) 
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned char, uchar, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UCHAR) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, add, UCHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, sub, UCHAR_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned char, uchar, mul, UCHAR_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned char, uchar, UCHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned char, uchar, UCHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned char, uchar, UCHAR_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned char, uchar, UCHAR_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned char, uchar, UCHAR_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(short, short, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_SHORT) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, add, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, sub, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(short, short, mul, SHRT_MIN, SHRT_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(short, short, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(short, short, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(short, short, SHRT_MIN, SHRT_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(short, short, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(short, short, SHRT_MIN, SHRT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(short, short, SHRT_MIN, SHRT_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned short, ushort, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, add, UShortAdd) 
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, sub, UShortSub) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned short, ushort, mul, UShortMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_USHORT) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, add, USHRT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, sub, USHRT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned short, ushort, mul, USHRT_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned short, ushort, USHRT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned short, ushort, USHRT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned short, ushort, USHRT_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned short, ushort, USHRT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned short, ushort, USHRT_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(int, int, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_INT) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, add, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, sub, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(int, int, mul, INT_MIN, INT_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(int, int, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(int, int, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(int, int, INT_MIN, INT_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(int, int, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(int, int, INT_MIN, INT_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(int, int, INT_MIN, INT_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned int, uint, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, add, UIntAdd) 
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, sub, UIntSub) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned int, uint, mul, UIntMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, add, UINT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, sub, UINT_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned int, uint, mul, UINT_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned int, uint, UINT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned int, uint, UINT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned int, uint, UINT_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned int, uint, UINT_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned int, uint, UINT_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long, long, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_LONG) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, add, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, sub, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long, long, mul, LONG_MIN, LONG_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(long, long, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(long, long, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(long, long, LONG_MIN, LONG_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(long, long, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(long, long, LONG_MIN, LONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(long, long, LONG_MIN, LONG_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long, ulong, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, add, ULongAdd) 
+PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, sub, ULongSub) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned long, ulong, mul, ULongMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_ULONG) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, add, ULONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, sub, ULONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long, ulong, mul, ULONG_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned long, ulong, ULONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned long, ulong, ULONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned long, ulong, ULONG_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned long, ulong, ULONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned long, ulong, ULONG_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(long long, llong, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_LLONG) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, add, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, sub, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(long long, llong, mul, LLONG_MIN, LLONG_MAX) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(long long, llong, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_SUB(long long, llong, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MUL(long long, llong, LLONG_MIN, LLONG_MAX) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(long long, llong, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_MOD(long long, llong, LLONG_MIN, LLONG_MAX) +PSNIP_SAFE_DEFINE_SIGNED_NEG(long long, llong, LLONG_MIN, LLONG_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, ullong, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, ullong, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(unsigned long long, 
ullong, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, add, ULongLongAdd) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, sub, ULongLongSub) +PSNIP_SAFE_DEFINE_INTSAFE(unsigned long long, ullong, mul, ULongLongMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_ULLONG) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, add, ULLONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, sub, ULLONG_MAX) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(unsigned long long, ullong, mul, ULLONG_MAX) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(unsigned long long, ullong, ULLONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(unsigned long long, ullong, ULLONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(unsigned long long, ullong, ULLONG_MAX) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(unsigned long long, ullong, ULLONG_MAX) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(unsigned long long, ullong, ULLONG_MAX) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(size_t, size, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) +PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, add, SizeTAdd) +PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, sub, SizeTSub) +PSNIP_SAFE_DEFINE_INTSAFE(size_t, size, mul, SizeTMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_SIZE) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, add, PSNIP_SAFE__SIZE_MAX_RT) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, sub, PSNIP_SAFE__SIZE_MAX_RT) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(size_t, size, mul, PSNIP_SAFE__SIZE_MAX_RT) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(size_t, size, PSNIP_SAFE__SIZE_MAX_RT) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(size_t, size, PSNIP_SAFE__SIZE_MAX_RT) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(size_t, size, PSNIP_SAFE__SIZE_MAX_RT) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(size_t, size, 
PSNIP_SAFE__SIZE_MAX_RT) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(size_t, size, PSNIP_SAFE__SIZE_MAX_RT) + +#if !defined(PSNIP_SAFE_NO_FIXED) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int8_t, int8, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_INT8) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, add, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, sub, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int8_t, int8, mul, (-0x7fLL-1), 0x7f) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) +PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int8_t, int8, (-0x7fLL-1), 0x7f) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint8_t, uint8, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT8) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, add, 0xff) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, sub, 0xff) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint8_t, uint8, mul, 0xff) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint8_t, uint8, 0xff) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint8_t, uint8, 0xff) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint8_t, uint8, 0xff) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint8_t, uint8, 0xff) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint8_t, uint8, 0xff) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) 
+PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int16_t, int16, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_INT16) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, add, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, sub, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int16_t, int16, mul, (-32767-1), 0x7fff) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int16_t, int16, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int16_t, int16, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int16_t, int16, (-32767-1), 0x7fff) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int16_t, int16, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int16_t, int16, (-32767-1), 0x7fff) +PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int16_t, int16, (-32767-1), 0x7fff) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint16_t, uint16, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, add, UShortAdd) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, sub, UShortSub) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint16_t, uint16, mul, UShortMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT16) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, add, 0xffff) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, sub, 0xffff) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint16_t, uint16, mul, 0xffff) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint16_t, uint16, 0xffff) +#endif 
+PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint16_t, uint16, 0xffff) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint16_t, uint16, 0xffff) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int32_t, int32, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_INT32) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, add, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, sub, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int32_t, int32, mul, (-0x7fffffffLL-1), 0x7fffffffLL) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int32_t, int32, (-0x7fffffffLL-1), 0x7fffffffLL) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint32_t, uint32, mul) +#elif defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, add, UIntAdd) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, sub, UIntSub) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint32_t, uint32, mul, UIntMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT32) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, add, 0xffffffffUL) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, sub, 0xffffffffUL) 
+PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint32_t, uint32, mul, 0xffffffffUL) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint32_t, uint32, 0xffffffffUL) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint32_t, uint32, 0xffffffffUL) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint32_t, uint32, 0xffffffffUL) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_int64_t, int64, mul) +#elif defined(PSNIP_SAFE_HAVE_LARGER_INT64) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, add, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, sub, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_PROMOTED_SIGNED_BINARY_OP(psnip_int64_t, int64, mul, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +#else +PSNIP_SAFE_DEFINE_SIGNED_ADD(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_SUB(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_MUL(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +#endif +PSNIP_SAFE_DEFINE_SIGNED_DIV(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_MOD(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) +PSNIP_SAFE_DEFINE_SIGNED_NEG(psnip_int64_t, int64, (-0x7fffffffffffffffLL-1), 0x7fffffffffffffffLL) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, add) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, sub) +PSNIP_SAFE_DEFINE_BUILTIN_BINARY_OP(psnip_uint64_t, uint64, mul) +#elif 
defined(PSNIP_SAFE_HAVE_INTSAFE_H) && defined(_WIN32) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, add, ULongLongAdd) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, sub, ULongLongSub) +PSNIP_SAFE_DEFINE_INTSAFE(psnip_uint64_t, uint64, mul, ULongLongMult) +#elif defined(PSNIP_SAFE_HAVE_LARGER_UINT64) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, add, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, sub, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_PROMOTED_UNSIGNED_BINARY_OP(psnip_uint64_t, uint64, mul, 0xffffffffffffffffULL) +#else +PSNIP_SAFE_DEFINE_UNSIGNED_ADD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_SUB(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_MUL(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +#endif +PSNIP_SAFE_DEFINE_UNSIGNED_DIV(psnip_uint64_t, uint64, 0xffffffffffffffffULL) +PSNIP_SAFE_DEFINE_UNSIGNED_MOD(psnip_uint64_t, uint64, 0xffffffffffffffffULL) + +#endif /* !defined(PSNIP_SAFE_NO_FIXED) */ + +#define PSNIP_SAFE_C11_GENERIC_SELECTION(res, op) \ + _Generic((*res), \ + char: psnip_safe_char_##op, \ + unsigned char: psnip_safe_uchar_##op, \ + short: psnip_safe_short_##op, \ + unsigned short: psnip_safe_ushort_##op, \ + int: psnip_safe_int_##op, \ + unsigned int: psnip_safe_uint_##op, \ + long: psnip_safe_long_##op, \ + unsigned long: psnip_safe_ulong_##op, \ + long long: psnip_safe_llong_##op, \ + unsigned long long: psnip_safe_ullong_##op) + +#define PSNIP_SAFE_C11_GENERIC_BINARY_OP(op, res, a, b) \ + PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, a, b) +#define PSNIP_SAFE_C11_GENERIC_UNARY_OP(op, res, v) \ + PSNIP_SAFE_C11_GENERIC_SELECTION(res, op)(res, v) + +#if defined(PSNIP_SAFE_HAVE_BUILTIN_OVERFLOW) +#define psnip_safe_add(res, a, b) !__builtin_add_overflow(a, b, res) +#define psnip_safe_sub(res, a, b) !__builtin_sub_overflow(a, b, res) +#define psnip_safe_mul(res, a, b) !__builtin_mul_overflow(a, b, 
res) +#define psnip_safe_div(res, a, b) !__builtin_div_overflow(a, b, res) +#define psnip_safe_mod(res, a, b) !__builtin_mod_overflow(a, b, res) +#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP (neg, res, v) + +#elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) +/* The are no fixed-length or size selections because they cause an + * error about _Generic specifying two compatible types. Hopefully + * this doesn't cause problems on exotic platforms, but if it does + * please let me know and I'll try to figure something out. */ + +#define psnip_safe_add(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(add, res, a, b) +#define psnip_safe_sub(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(sub, res, a, b) +#define psnip_safe_mul(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mul, res, a, b) +#define psnip_safe_div(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(div, res, a, b) +#define psnip_safe_mod(res, a, b) PSNIP_SAFE_C11_GENERIC_BINARY_OP(mod, res, a, b) +#define psnip_safe_neg(res, v) PSNIP_SAFE_C11_GENERIC_UNARY_OP (neg, res, v) +#endif + +#if !defined(PSNIP_SAFE_HAVE_BUILTINS) && (defined(PSNIP_SAFE_EMULATE_NATIVE) || defined(PSNIP_BUILTIN_EMULATE_NATIVE)) +# define __builtin_sadd_overflow(a, b, res) (!psnip_safe_int_add(res, a, b)) +# define __builtin_saddl_overflow(a, b, res) (!psnip_safe_long_add(res, a, b)) +# define __builtin_saddll_overflow(a, b, res) (!psnip_safe_llong_add(res, a, b)) +# define __builtin_uadd_overflow(a, b, res) (!psnip_safe_uint_add(res, a, b)) +# define __builtin_uaddl_overflow(a, b, res) (!psnip_safe_ulong_add(res, a, b)) +# define __builtin_uaddll_overflow(a, b, res) (!psnip_safe_ullong_add(res, a, b)) + +# define __builtin_ssub_overflow(a, b, res) (!psnip_safe_int_sub(res, a, b)) +# define __builtin_ssubl_overflow(a, b, res) (!psnip_safe_long_sub(res, a, b)) +# define __builtin_ssubll_overflow(a, b, res) (!psnip_safe_llong_sub(res, a, b)) +# define __builtin_usub_overflow(a, b, res) (!psnip_safe_uint_sub(res, a, b)) +# 
define __builtin_usubl_overflow(a, b, res) (!psnip_safe_ulong_sub(res, a, b)) +# define __builtin_usubll_overflow(a, b, res) (!psnip_safe_ullong_sub(res, a, b)) + +# define __builtin_smul_overflow(a, b, res) (!psnip_safe_int_mul(res, a, b)) +# define __builtin_smull_overflow(a, b, res) (!psnip_safe_long_mul(res, a, b)) +# define __builtin_smulll_overflow(a, b, res) (!psnip_safe_llong_mul(res, a, b)) +# define __builtin_umul_overflow(a, b, res) (!psnip_safe_uint_mul(res, a, b)) +# define __builtin_umull_overflow(a, b, res) (!psnip_safe_ulong_mul(res, a, b)) +# define __builtin_umulll_overflow(a, b, res) (!psnip_safe_ullong_mul(res, a, b)) +#endif + +#endif /* !defined(PSNIP_SAFE_H) */ diff --git a/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh b/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh index b7da540..45aa64f 100755 --- a/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh +++ b/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh @@ -28,6 +28,7 @@ main() { local -r tarball="$SCRATCH/nanoarrow.tar.gz" wget -O "$tarball" "$repo_url/archive/$commit_sha.tar.gz" + mv nanoarrow/CMakeLists.txt CMakeLists.nanoarrow.tmp rm -rf nanoarrow mkdir -p nanoarrow tar --strip-components 1 -C "$SCRATCH" -xf "$tarball" @@ -45,6 +46,7 @@ main() { cp "$SCRATCH/dist-adbc/nanoarrow.c" nanoarrow/ cp "$SCRATCH/dist-adbc/nanoarrow.h" nanoarrow/ cp "$SCRATCH/dist-adbc/nanoarrow.hpp" nanoarrow/ + mv CMakeLists.nanoarrow.tmp nanoarrow/CMakeLists.txt } main "$@" diff --git a/3rd_party/apache-arrow-adbc/docker-compose.yml b/3rd_party/apache-arrow-adbc/docker-compose.yml index e65d935..a25fb66 100644 --- a/3rd_party/apache-arrow-adbc/docker-compose.yml +++ b/3rd_party/apache-arrow-adbc/docker-compose.yml @@ -28,20 +28,7 @@ services: volumes: - .:/adbc:delegated command: | - /bin/bash -c 'git config --global --add safe.directory /adbc && source /opt/conda/etc/profile.d/conda.sh && mamba create -y -n adbc -c conda-forge go --file /adbc/ci/conda_env_cpp.txt --file 
/adbc/ci/conda_env_docs.txt --file /adbc/ci/conda_env_python.txt && conda activate adbc && env ADBC_USE_ASAN=0 ADBC_USE_UBSAN=0 /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build && env CGO_ENABLED=1 /adbc/ci/scripts/go_build.sh /adbc /adbc/build && /adbc/ci/scripts/python_build.sh /adbc /adbc/build && /adbc/ci/scripts/docs_build.sh /adbc' - - golang-sqlite-flightsql: - image: ${REPO}:golang-${GO}-sqlite-flightsql - build: - context: . - cache_from: - - ${REPO}:golang-${GO}-sqlite-flightsql - dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile - args: - GO: ${GO} - ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} - ports: - - 8080:8080 + /bin/bash -c 'git config --global --add safe.directory /adbc && source /opt/conda/etc/profile.d/conda.sh && mamba create -y -n adbc -c conda-forge go --file /adbc/ci/conda_env_cpp.txt --file /adbc/ci/conda_env_docs.txt --file /adbc/ci/conda_env_python.txt && conda activate adbc && env ADBC_USE_ASAN=0 ADBC_USE_UBSAN=0 /adbc/ci/scripts/cpp_build.sh /adbc /adbc/build && env CGO_ENABLED=1 /adbc/ci/scripts/go_build.sh /adbc /adbc/build && /adbc/ci/scripts/python_build.sh /adbc /adbc/build && /adbc/ci/scripts/r_build.sh /adbc && /adbc/ci/scripts/docs_build.sh /adbc' ############################ Java JARs ###################################### @@ -120,15 +107,6 @@ services: ###################### Test database environments ############################ - postgres_test: - container_name: adbc_postgres_test - image: postgres:latest - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: password - ports: - - "5432:5432" - dremio: container_name: adbc-dremio image: dremio/dremio-oss:latest @@ -162,3 +140,55 @@ services: entrypoint: "/init/bootstrap.sh" volumes: - "./ci/scripts/integration/dremio:/init" + + flightsql-test: + image: ${REPO}:adbc-flightsql-test + build: + context: . 
+ cache_from: + - ${REPO}:adbc-flightsql-test + dockerfile: ci/docker/flightsql-test.dockerfile + args: + GO: ${GO} + ports: + - "41414:41414" + volumes: + - .:/adbc:delegated + command: >- + /bin/bash -c "cd /adbc/go/adbc && go run ./driver/flightsql/cmd/testserver -host 0.0.0.0 -port 41414" + + flightsql-sqlite-test: + image: ${REPO}:golang-${GO}-sqlite-flightsql + build: + context: . + cache_from: + - ${REPO}:golang-${GO}-sqlite-flightsql + dockerfile: ci/docker/golang-flightsql-sqlite.dockerfile + args: + GO: ${GO} + ARROW_MAJOR_VERSION: ${ARROW_MAJOR_VERSION} + ports: + - 8080:8080 + + mssql-test: + container_name: adbc_mssql_test + image: mcr.microsoft.com/mssql/server:2022-latest + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "Password1!" + ports: + - "1433:1433" + + postgres-test: + container_name: adbc_postgres_test + image: postgres:${POSTGRES_VERSION:-latest} + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + healthcheck: + test: ["CMD-SHELL", "pg_isready"] + interval: 10s + timeout: 5s + retries: 5 + ports: + - "5432:5432" diff --git a/3rd_party/apache-arrow-adbc/license.tpl b/3rd_party/apache-arrow-adbc/license.tpl index 00a484b..5f0a29e 100644 --- a/3rd_party/apache-arrow-adbc/license.tpl +++ b/3rd_party/apache-arrow-adbc/license.tpl @@ -213,6 +213,22 @@ License: http://www.apache.org/licenses/LICENSE-2.0 -------------------------------------------------------------------------------- +The files in c/vendor/portable-snippets/ contain code from + +https://github.com/nemequ/portable-snippets + +and have the following copyright notice: + +Each source file contains a preamble explaining the license situation +for that file, which takes priority over this file. With the +exception of some code pulled in from other repositories (such as +µnit, an MIT-licensed project which is used for testing), the code is +public domain, released using the CC0 1.0 Universal dedication (*). 
+ +(*) https://creativecommons.org/publicdomain/zero/1.0/legalcode + +-------------------------------------------------------------------------------- + The files python/*/*/_version.py and python/*/*/_static_version.py contain code from diff --git a/3rd_party/apache-arrow-adbc/pyrightconfig.json b/3rd_party/apache-arrow-adbc/pyrightconfig.json new file mode 100644 index 0000000..e78d074 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/pyrightconfig.json @@ -0,0 +1,9 @@ +{ + "exclude": [ + "**/.asv/", + "**/_version.py", + "**/setup.py", + "**/build/", + "**/benchmarks/" + ] +} diff --git a/README.md b/README.md index 47fbf08..df6f557 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,15 @@ And now you can make queries with: {:ok, _} = Adbc.Connection.query(conn, "SELECT 123") ``` +## Updating ADBC + +This library vendors ADBC C implementation inside the 3rd_party folder. +In order to update it: + + 1. Download source for [latest ADBC release](https://github.com/apache/arrow-adbc/releases/) + 2. Copy root files and c/ directory from ADBC into 3rd_party/apache-arrow-adbc + 3. 
Update the driver version in `lib/adbc_driver.ex` + ## License Copyright 2023 Cocoa Xu, José Valim diff --git a/c_src/adbc_nif.cpp b/c_src/adbc_nif.cpp index 4c90387..63da108 100644 --- a/c_src/adbc_nif.cpp +++ b/c_src/adbc_nif.cpp @@ -1260,13 +1260,9 @@ int elixir_to_arrow_type_struct(ErlNifEnv *env, ERL_NIF_TERM values, struct Arro ArrowSchemaInit(schema_out); NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema_out, n_items)); - NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema_out, "")); - NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromType(array_out, NANOARROW_TYPE_STRUCT)); NANOARROW_RETURN_NOT_OK(ArrowArrayAllocateChildren(array_out, static_cast(n_items))); - array_out->length = 1; - array_out->null_count = -1; ERL_NIF_TERM head, tail; tail = values; @@ -1342,10 +1338,9 @@ int elixir_to_arrow_type_struct(ErlNifEnv *env, ERL_NIF_TERM values, struct Arro snprintf(error_out->message, sizeof(error_out->message), "type not supported yet."); return 1; } - NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(child_i, error_out)); processed++; } - + NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(array_out, error_out)); return !(processed == n_items); } diff --git a/lib/adbc_driver.ex b/lib/adbc_driver.ex index b29fc16..1cabf57 100644 --- a/lib/adbc_driver.ex +++ b/lib/adbc_driver.ex @@ -6,7 +6,7 @@ defmodule Adbc.Driver do @official_drivers ~w(sqlite postgresql flightsql snowflake)a @official_driver_base_url "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-" - @version "0.5.1" + @version "0.7.0" def download(driver_name, opts \\ [])