diff --git a/.editorconfig b/.editorconfig index 562d6ce1ff4..3c044ad354a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,7 @@ indent_style = space indent_size = 2 indent_style = space -[ext/*.{c,cpp,h}] -indent_size = 4 +[edb_stat_statements/*.{c,h,l,y,pl,pm}] indent_style = tab +indent_size = tab +tab_width = 4 diff --git a/.github/Makefile b/.github/Makefile index 9090781f58e..e8475191706 100644 --- a/.github/Makefile +++ b/.github/Makefile @@ -13,7 +13,8 @@ all: workflows/nightly.yml \ workflows/tests-ha.yml \ workflows/tests-pg-versions.yml \ workflows/tests-patches.yml \ - workflows/tests-inplace.yml + workflows/tests-inplace.yml \ + workflows/tests-reflection.yml workflows/%.yml: workflows.src/%.tpl.yml workflows.src/%.targets.yml workflows.src/build.inc.yml workflows.src/ls-build.inc.yml $(ROOT)/workflows.src/render.py $* $*.targets.yml @@ -38,3 +39,6 @@ workflows.src/tests-patches.tpl.yml: workflows.src/tests.inc.yml workflows.src/tests-inplace.tpl.yml: workflows.src/tests.inc.yml touch $(ROOT)/workflows.src/tests-inplace.tpl.yml + +workflows.src/tests-reflection.tpl.yml: workflows.src/tests.inc.yml + touch $(ROOT)/workflows.src/tests-reflection.tpl.yml diff --git a/.github/workflows.src/tests-reflection.targets.yml b/.github/workflows.src/tests-reflection.targets.yml new file mode 100644 index 00000000000..99d4a714ac7 --- /dev/null +++ b/.github/workflows.src/tests-reflection.targets.yml @@ -0,0 +1 @@ +data: diff --git a/.github/workflows.src/tests-reflection.tpl.yml b/.github/workflows.src/tests-reflection.tpl.yml new file mode 100644 index 00000000000..3cf3a832d5d --- /dev/null +++ b/.github/workflows.src/tests-reflection.tpl.yml @@ -0,0 +1,59 @@ +<% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> +name: Tests with reflection validation + +on: + schedule: + - cron: "0 3 * * *" + workflow_dispatch: + inputs: {} + push: + branches: + - "REFL-*" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + <%- call build() -%> + 
- name: Compute cache keys + env: + GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} + run: | + << calc_cache_key()|indent >> + <%- endcall %> + + test: + needs: build + runs-on: ubuntu-latest + + steps: + <<- restore_cache() >> + + # Run the test + + - name: Test + env: + EDGEDB_TEST_REPEATS: 1 + run: | + edb test -j2 -v + + workflow-notifications: + if: failure() && github.event_name != 'pull_request' + name: Notify in Slack on failures + needs: + - build + - test + runs-on: ubuntu-latest + permissions: + actions: 'read' + steps: + - name: Slack Workflow Notification + uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 + with: + repo_token: ${{secrets.GITHUB_TOKEN}} + slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} + name: 'Workflow notifications' + icon_emoji: ':hammer:' + include_jobs: 'on-failure' diff --git a/.github/workflows.src/tests.inc.yml b/.github/workflows.src/tests.inc.yml index aba64c5b64c..126c66044cc 100644 --- a/.github/workflows.src/tests.inc.yml +++ b/.github/workflows.src/tests.inc.yml @@ -89,7 +89,7 @@ id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -123,7 +123,7 @@ steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -195,7 +195,7 @@ if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -377,7 +377,7 @@ id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ 
hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-ha.yml b/.github/workflows/tests-ha.yml index be61db8e813..5178c014516 100644 --- a/.github/workflows/tests-ha.yml +++ b/.github/workflows/tests-ha.yml @@ -135,7 +135,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -169,7 +169,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -241,7 +241,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -431,7 +431,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-inplace.yml b/.github/workflows/tests-inplace.yml index cabfd642236..6cc5a61f149 100644 --- a/.github/workflows/tests-inplace.yml +++ b/.github/workflows/tests-inplace.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: 
actions/cache@v4 @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -428,7 +428,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-managed-pg.yml b/.github/workflows/tests-managed-pg.yml index 766bc1f86a6..5c930923be2 100644 --- a/.github/workflows/tests-managed-pg.yml +++ b/.github/workflows/tests-managed-pg.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -462,7 +462,7 @@ jobs: id: 
ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -704,7 +704,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -994,7 +994,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -1250,7 +1250,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -1494,7 +1494,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-patches.yml b/.github/workflows/tests-patches.yml index bd5da952e44..9078bb212e9 100644 --- a/.github/workflows/tests-patches.yml +++ b/.github/workflows/tests-patches.yml @@ -122,7 +122,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -156,7 +156,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev 
bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -228,7 +228,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -431,7 +431,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-pg-versions.yml b/.github/workflows/tests-pg-versions.yml index 21f1cef46fe..e83d3ae547c 100644 --- a/.github/workflows/tests-pg-versions.yml +++ b/.github/workflows/tests-pg-versions.yml @@ -120,7 +120,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -154,7 +154,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -226,7 +226,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -455,7 +455,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ 
hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.github/workflows/tests-pool.yml b/.github/workflows/tests-pool.yml index 1047d76bc29..20b58139aef 100644 --- a/.github/workflows/tests-pool.yml +++ b/.github/workflows/tests-pool.yml @@ -130,7 +130,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -164,7 +164,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -236,7 +236,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: diff --git a/.github/workflows/tests-reflection.yml b/.github/workflows/tests-reflection.yml new file mode 100644 index 00000000000..f9ee8c816ac --- /dev/null +++ b/.github/workflows/tests-reflection.yml @@ -0,0 +1,519 @@ +name: Tests with reflection validation + +on: + schedule: + - cron: "0 3 * * *" + workflow_dispatch: + inputs: {} + push: + branches: + - "REFL-*" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: false + + - uses: actions/checkout@v4 + with: + fetch-depth: 50 + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + id: setup-python + with: + python-version: '3.12.2' + cache: 'pip' + cache-dependency-path: | + pyproject.toml + + # The below is technically a lie as we are technically not + # inside a 
virtual env, but there is really no reason to bother + # actually creating and activating one as below works just fine. + - name: Export $VIRTUAL_ENV + run: | + venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" + echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV + + - name: Set up uv cache + uses: actions/cache@v4 + with: + path: ~/.cache/uv + key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} + + - name: Cached requirements.txt + uses: actions/cache@v4 + id: requirements-cache + with: + path: requirements.txt + key: edb-requirements-${{ hashFiles('pyproject.toml') }} + + - name: Compute requirements.txt + if: steps.requirements-cache.outputs.cache-hit != 'true' + run: | + python -m pip install pip-tools + pip-compile --no-strip-extras --all-build-deps \ + --extra test,language-server \ + --output-file requirements.txt pyproject.toml + + - name: Install Python dependencies + run: | + python -c "import sys; print(sys.prefix)" + python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt + + - name: Compute cache keys + env: + GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} + run: | + mkdir -p shared-artifacts + if [ "$(uname)" = "Darwin" ]; then + find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt + else + find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt + fi + python setup.py -q ci_helper --type cli > shared-artifacts/edgedbcli_git_rev.txt + python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt + python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt + python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt + python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt + python setup.py -q ci_helper --type libpg_query 
>shared-artifacts/libpg_query_git_rev.txt + echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt + python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt + echo EDGEDBCLI_GIT_REV=$(cat shared-artifacts/edgedbcli_git_rev.txt) >> $GITHUB_ENV + echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV + echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV + echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV + echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV + echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV + + - name: Upload shared artifacts + uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + with: + name: shared-artifacts + path: shared-artifacts + retention-days: 1 + + # Restore binary cache + + - name: Handle cached EdgeDB CLI binaries + uses: actions/cache@v4 + id: cli-cache + with: + path: build/cli + key: edb-cli-v3-${{ env.EDGEDBCLI_GIT_REV }} + + - name: Handle cached Rust extensions + uses: actions/cache@v4 + id: rust-cache + with: + path: build/rust_extensions + key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} + restore-keys: | + edb-rust-v4- + + - name: Handle cached Cython extensions + uses: actions/cache@v4 + id: ext-cache + with: + path: build/extensions + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + + - name: Handle cached PostgreSQL build + uses: actions/cache@v4 + id: postgres-cache + with: + path: build/postgres/install + key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} + + - name: Handle cached Stolon build + uses: actions/cache@v4 + id: stolon-cache + with: + path: build/stolon/bin + key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} + + - name: Handle cached libpg_query build + uses: actions/cache@v4 + id: 
libpg-query-cache + with: + path: edb/pgsql/parser/libpg_query/libpg_query.a + key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} + + # Install system dependencies for building + + - name: Install system deps + if: | + steps.cli-cache.outputs.cache-hit != 'true' || + steps.rust-cache.outputs.cache-hit != 'true' || + steps.ext-cache.outputs.cache-hit != 'true' || + steps.stolon-cache.outputs.cache-hit != 'true' || + steps.postgres-cache.outputs.cache-hit != 'true' + run: | + sudo apt-get update + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev + + - name: Install Rust toolchain + if: | + steps.cli-cache.outputs.cache-hit != 'true' || + steps.rust-cache.outputs.cache-hit != 'true' + uses: dsherret/rust-toolchain-file@v1 + + # Build EdgeDB CLI + + - name: Handle EdgeDB CLI build cache + uses: actions/cache@v4 + if: steps.cli-cache.outputs.cache-hit != 'true' + with: + path: ${{ env.BUILD_TEMP }}/rust/cli + key: edb-cli-build-v7-${{ env.EDGEDBCLI_GIT_REV }} + restore-keys: | + edb-cli-build-v7- + + - name: Build EdgeDB CLI + env: + CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/cli/cargo_home + CACHE_HIT: ${{ steps.cli-cache.outputs.cache-hit }} + run: | + if [[ "$CACHE_HIT" == "true" ]]; then + cp -v build/cli/bin/edgedb edb/cli/edgedb + else + python setup.py -v build_cli + fi + + # Build Rust extensions + + - name: Handle Rust extensions build cache + uses: actions/cache@v4 + if: steps.rust-cache.outputs.cache-hit != 'true' + with: + path: ${{ env.BUILD_TEMP }}/rust/extensions + key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} + restore-keys: | + edb-rust-build-v1- + + - name: Build Rust extensions + env: + CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home + CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} + run: | + if [[ "$CACHE_HIT" != "true" ]]; then + rm -rf ${BUILD_LIB} + mkdir -p build/rust_extensions + rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ + python setup.py -v build_rust + rsync -av 
${BUILD_LIB}/ build/rust_extensions/ + rm -rf ${BUILD_LIB} + fi + rsync -av ./build/rust_extensions/edb/ ./edb/ + + # Build libpg_query + + - name: Build libpg_query + if: | + steps.libpg-query-cache.outputs.cache-hit != 'true' && + steps.ext-cache.outputs.cache-hit != 'true' + run: | + python setup.py build_libpg_query + + # Build extensions + + - name: Handle Cython extensions build cache + uses: actions/cache@v4 + if: steps.ext-cache.outputs.cache-hit != 'true' + with: + path: ${{ env.BUILD_TEMP }}/edb + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + + - name: Build Cython extensions + env: + CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} + BUILD_EXT_MODE: py-only + run: | + if [[ "$CACHE_HIT" != "true" ]]; then + rm -rf ${BUILD_LIB} + mkdir -p ./build/extensions + rsync -av ./build/extensions/ ${BUILD_LIB}/ + BUILD_EXT_MODE=py-only python setup.py -v build_ext + rsync -av ${BUILD_LIB}/ ./build/extensions/ + rm -rf ${BUILD_LIB} + fi + rsync -av ./build/extensions/edb/ ./edb/ + + # Build parsers + + - name: Handle compiled parsers cache + uses: actions/cache@v4 + id: parsers-cache + with: + path: build/lib + key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} + restore-keys: | + edb-parsers-v3- + + - name: Build parsers + env: + CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} + run: | + if [[ "$CACHE_HIT" != "true" ]]; then + rm -rf ${BUILD_LIB} + mkdir -p ./build/lib + rsync -av ./build/lib/ ${BUILD_LIB}/ + python setup.py -v build_parsers + rsync -av ${BUILD_LIB}/ ./build/lib/ + rm -rf ${BUILD_LIB} + fi + rsync -av ./build/lib/edb/ ./edb/ + + # Build PostgreSQL + + - name: Build PostgreSQL + env: + CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} + run: | + if [[ "$CACHE_HIT" == "true" ]]; then + cp build/postgres/install/stamp build/postgres/ + else + python setup.py build_postgres + cp build/postgres/stamp build/postgres/install/ + fi + + # Build Stolon + + - name: Set up Go + if: 
steps.stolon-cache.outputs.cache-hit != 'true' + uses: actions/setup-go@v2 + with: + go-version: 1.16 + + - uses: actions/checkout@v4 + if: steps.stolon-cache.outputs.cache-hit != 'true' + with: + repository: edgedb/stolon + path: build/stolon + ref: ${{ env.STOLON_GIT_REV }} + fetch-depth: 0 + submodules: false + + - name: Build Stolon + if: steps.stolon-cache.outputs.cache-hit != 'true' + run: | + mkdir -p build/stolon/bin/ + curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul + chmod +x build/stolon/bin/consul + cd build/stolon && make + + # Install edgedb-server and populate egg-info + + - name: Install edgedb-server + env: + BUILD_EXT_MODE: skip + run: | + # --no-build-isolation because we have explicitly installed all deps + # and don't want them to be reinstalled in an "isolated env". + pip install --no-build-isolation --no-deps -e .[test,docs] + + # Refresh the bootstrap cache + + - name: Handle bootstrap cache + uses: actions/cache@v4 + id: bootstrap-cache + with: + path: build/cache + key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} + restore-keys: | + edb-bootstrap-v2- + + - name: Bootstrap EdgeDB Server + if: steps.bootstrap-cache.outputs.cache-hit != 'true' + run: | + edb server --bootstrap-only + + test: + needs: build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: false + + - uses: actions/checkout@v4 + with: + fetch-depth: 50 + submodules: true + + - name: Set up Python + uses: actions/setup-python@v5 + id: setup-python + with: + python-version: '3.12.2' + cache: 'pip' + cache-dependency-path: | + pyproject.toml + + # The below is technically a lie as we are technically not + # inside a virtual env, but there is really no reason to bother + # actually creating and activating one as below works just fine. 
+ - name: Export $VIRTUAL_ENV + run: | + venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" + echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV + + - name: Set up uv cache + uses: actions/cache@v4 + with: + path: ~/.cache/uv + key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} + + - name: Download requirements.txt + uses: actions/cache@v4 + with: + path: requirements.txt + key: edb-requirements-${{ hashFiles('pyproject.toml') }} + + - name: Install Python dependencies + run: python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt + + # Restore the artifacts and environment variables + + - name: Download shared artifacts + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + name: shared-artifacts + path: shared-artifacts + + - name: Set environment variables + run: | + echo EDGEDBCLI_GIT_REV=$(cat shared-artifacts/edgedbcli_git_rev.txt) >> $GITHUB_ENV + echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV + echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV + echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV + echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV + + # Restore build cache + + - name: Restore cached EdgeDB CLI binaries + uses: actions/cache@v4 + id: cli-cache + with: + path: build/cli + key: edb-cli-v3-${{ env.EDGEDBCLI_GIT_REV }} + + - name: Restore cached Rust extensions + uses: actions/cache@v4 + id: rust-cache + with: + path: build/rust_extensions + key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} + + - name: Restore cached Cython extensions + uses: actions/cache@v4 + id: ext-cache + with: + path: build/extensions + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + + - name: Restore compiled parsers cache + uses: actions/cache@v4 + id: parsers-cache + with: + path: 
build/lib + key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} + + - name: Restore cached PostgreSQL build + uses: actions/cache@v4 + id: postgres-cache + with: + path: build/postgres/install + key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} + + - name: Restore cached Stolon build + uses: actions/cache@v4 + id: stolon-cache + with: + path: build/stolon/bin + key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} + + - name: Restore bootstrap cache + uses: actions/cache@v4 + id: bootstrap-cache + with: + path: build/cache + key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} + + - name: Stop if we cannot retrieve the cache + if: | + steps.cli-cache.outputs.cache-hit != 'true' || + steps.rust-cache.outputs.cache-hit != 'true' || + steps.ext-cache.outputs.cache-hit != 'true' || + steps.parsers-cache.outputs.cache-hit != 'true' || + steps.postgres-cache.outputs.cache-hit != 'true' || + steps.stolon-cache.outputs.cache-hit != 'true' || + steps.bootstrap-cache.outputs.cache-hit != 'true' + run: | + echo ::error::Cannot retrieve build cache. 
+ exit 1 + + - name: Validate cached binaries + run: | + # Validate EdgeDB CLI + ./build/cli/bin/edgedb --version || exit 1 + + # Validate Stolon + ./build/stolon/bin/stolon-sentinel --version || exit 1 + ./build/stolon/bin/stolon-keeper --version || exit 1 + ./build/stolon/bin/stolon-proxy --version || exit 1 + + # Validate PostgreSQL + ./build/postgres/install/bin/postgres --version || exit 1 + ./build/postgres/install/bin/pg_config --version || exit 1 + + - name: Restore cache into the source tree + run: | + cp -v build/cli/bin/edgedb edb/cli/edgedb + rsync -av ./build/rust_extensions/edb/ ./edb/ + rsync -av ./build/extensions/edb/ ./edb/ + rsync -av ./build/lib/edb/ ./edb/ + cp build/postgres/install/stamp build/postgres/ + + - name: Install edgedb-server + env: + BUILD_EXT_MODE: skip + run: | + # --no-build-isolation because we have explicitly installed all deps + # and don't want them to be reinstalled in an "isolated env". + pip install --no-build-isolation --no-deps -e .[test,docs] + + # Run the test + + - name: Test + env: + EDGEDB_TEST_REPEATS: 1 + run: | + edb test -j2 -v + + workflow-notifications: + if: failure() && github.event_name != 'pull_request' + name: Notify in Slack on failures + needs: + - build + - test + runs-on: ubuntu-latest + permissions: + actions: 'read' + steps: + - name: Slack Workflow Notification + uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 + with: + repo_token: ${{secrets.GITHUB_TOKEN}} + slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} + name: 'Workflow notifications' + icon_emoji: ':hammer:' + include_jobs: 'on-failure' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 450a48b7d1b..00a208bb865 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -132,7 +132,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ 
hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 @@ -166,7 +166,7 @@ jobs: steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update - sudo apt-get install -y uuid-dev libreadline-dev bison flex + sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: | @@ -238,7 +238,7 @@ jobs: if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb - key: edb-ext-build-v3-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: @@ -511,7 +511,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 @@ -691,7 +691,7 @@ jobs: id: ext-cache with: path: build/extensions - key: edb-ext-v5-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} + key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 diff --git a/.gitignore b/.gitignore index 6b75b711cce..314c1a4f16b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,9 @@ *.pyo *.o *.so -.vscode +*.dylib +.vscode/ +.zed/ *~ .#* .*.swp @@ -37,3 +39,4 @@ docs/_build /.vagga /.dmypy.json /compile_commands.json +/pyrightconfig.json diff --git a/.gitmodules b/.gitmodules index e286e4030fa..26b22531174 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ url = https://github.com/MagicStack/py-pgproto.git [submodule "edb/pgsql/parser/libpg_query"] path = edb/pgsql/parser/libpg_query - url = https://github.com/msullivan/libpg_query.git + url = https://github.com/edgedb/libpg_query.git diff --git a/Cargo.lock b/Cargo.lock index 0cde3ad83d6..70aab8897b1 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -540,7 +540,7 @@ dependencies = [ [[package]] name = "edgedb-errors" version = "0.4.2" -source = "git+https://github.com/edgedb/edgedb-rust#f9d784470af6e013051d1503882c1d88c51a5dcb" +source = "git+https://github.com/edgedb/edgedb-rust#b38fb4af07ae0017329eb3cce30ca37fe12acd29" dependencies = [ "bytes", ] @@ -548,7 +548,7 @@ dependencies = [ [[package]] name = "edgedb-protocol" version = "0.6.1" -source = "git+https://github.com/edgedb/edgedb-rust#f9d784470af6e013051d1503882c1d88c51a5dcb" +source = "git+https://github.com/edgedb/edgedb-rust#b38fb4af07ae0017329eb3cce30ca37fe12acd29" dependencies = [ "bigdecimal", "bitflags", @@ -1679,9 +1679,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +checksum = "7ebb0c0cc0de9678e53be9ccf8a2ab53045e6e3a8be03393ceccc5e7396ccb40" dependencies = [ "cfg-if", "indoc", @@ -1698,9 +1698,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +checksum = "80e3ce69c4ec34476534b490e412b871ba03a82e35604c3dfb95fcb6bfb60c09" dependencies = [ "once_cell", "target-lexicon", @@ -1708,9 +1708,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +checksum = "3b09f311c76b36dfd6dd6f7fa6f9f18e7e46a1c937110d283e80b12ba2468a75" dependencies = [ "libc", "pyo3-build-config", @@ -1718,9 +1718,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +checksum = "fd4f74086536d1e1deaff99ec0387481fb3325c82e4e48be0e75ab3d3fcb487a" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1730,9 +1730,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +checksum = "9e77dfeb76b32bbf069144a5ea0a36176ab59c8db9ce28732d0f06f096bbfbc8" dependencies = [ "heck", "proc-macro2", @@ -2234,6 +2234,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b835cb902660db3415a672d862905e791e54d306c6e8189168c7f3d9ae1c79d" dependencies = [ + "backtrace", "snafu-derive", ] @@ -2460,9 +2461,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.3" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 9ea22ec1dd9..d5bb17fc60d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ members = [ resolver = "2" [workspace.dependencies] -pyo3 = { version = "0.22.2", features = ["extension-module", "serde"] } +pyo3 = { version = "0.23.1", features = ["extension-module", "serde"] } tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" diff --git a/Makefile b/Makefile index 44f57763c8e..b301fbb8221 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,10 @@ rust: build-reqs BUILD_EXT_MODE=rust-only python setup.py build_ext --inplace +cli: build-reqs + python setup.py build_cli + + docs: build-reqs find 
docs -name '*.rst' | xargs touch $(MAKE) -C docs html SPHINXOPTS=$(SPHINXOPTS) BUILDDIR="../build" @@ -39,6 +43,10 @@ parsers: python setup.py build_parsers --inplace +libpg-query: + python setup.py build_libpg_query + + ui: build-reqs python setup.py build_ui diff --git a/docs/changelog/1_0_rc2.rst b/docs/changelog/1_0_rc2.rst index 427ba33c6be..f4a2fff8520 100644 --- a/docs/changelog/1_0_rc2.rst +++ b/docs/changelog/1_0_rc2.rst @@ -298,7 +298,7 @@ Server configuration ``EDGEDB_SERVER_SECURITY`` - ``strict == default`` - - ``insecure_dev_mode`` — disable password-based authentication and allow + - ``insecure_dev_mode`` — disable password-based authentication and allow unencrypted HTTP traffic ``EDGEDB_DOCKER_APPLY_MIGRATIONS`` (Docker only) diff --git a/docs/guides/contributing/code.rst b/docs/guides/contributing/code.rst index 931b3f24b7d..b252b607e09 100644 --- a/docs/guides/contributing/code.rst +++ b/docs/guides/contributing/code.rst @@ -38,6 +38,7 @@ Linux or macOS. Windows is not currently supported. * Libuuid dev package; * Node.js 14 or later; * Yarn 1 +* Protobuf & C bindings for Protobuf .. zlib, readline and libuuid are required to build postgres. Should be removed when custom postgres build is no longer needed. @@ -93,6 +94,8 @@ be built with the following ``shell.nix`` file. 
openssl pkg-config icu + protobuf + protobufc ]; LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.stdenv.cc.cc ]; LIBCLANG_PATH = "${llvmPackages.libclang.lib}/lib"; diff --git a/docs/reference/ddl/functions.rst b/docs/reference/ddl/functions.rst index e937ab0f2a7..eef80095c16 100644 --- a/docs/reference/ddl/functions.rst +++ b/docs/reference/ddl/functions.rst @@ -49,7 +49,7 @@ Create function # and is one of - set volatility := {'Immutable' | 'Stable' | 'Volatile'} ; + set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ; create annotation := ; using ( ) ; using ; @@ -75,7 +75,7 @@ Most sub-commands and options of this command are identical to the :ref:`SDL function declaration `, with some additional features listed below: -:eql:synopsis:`set volatility := {'Immutable' | 'Stable' | 'Volatile'}` +:eql:synopsis:`set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'}` Function volatility determines how aggressively the compiler can optimize its invocations. Other than a slight syntactical difference this is the same as the corresponding SDL declaration. @@ -141,7 +141,7 @@ Change the definition of a function. # and is one of - set volatility := {'Immutable' | 'Stable' | 'Volatile'} ; + set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ; reset volatility ; rename to ; create annotation := ; diff --git a/docs/reference/edgeql/index.rst b/docs/reference/edgeql/index.rst index 713784a4787..86ec2943871 100644 --- a/docs/reference/edgeql/index.rst +++ b/docs/reference/edgeql/index.rst @@ -88,6 +88,7 @@ Introspection command: casts functions cardinality + volatility select insert diff --git a/docs/reference/edgeql/volatility.rst b/docs/reference/edgeql/volatility.rst new file mode 100644 index 00000000000..7534329b082 --- /dev/null +++ b/docs/reference/edgeql/volatility.rst @@ -0,0 +1,134 @@ +.. 
_ref_reference_volatility: + + +Volatility +========== + +The **volatility** of an expression refers to how its value may change across +successive evaluations. + +Expressions may have one of the following volatilities, in order of increasing +volatility: + +* ``Immutable``: The expression cannot modify the database and is + guaranteed to have the same value *in all statements*. + +* ``Stable``: The expression cannot modify the database and is + guaranteed to have the same value *within a single statement*. + +* ``Volatile``: The expression cannot modify the database and can have + different values on successive evaluations. + +* ``Modifying``: The expression can modify the database and can have + different values on successive evaluations. + + +Expressions +----------- + +All :ref:`primitives `, +:ref:`ranges `, and +:ref:`multiranges ` are ``Immutable``. + +:ref:`Arrays `, :ref:`tuples `, and +:ref:`sets ` have the volatility of their most volatile +component. + +:ref:`Globals ` are always ``Stable``, even computed +globals with an immutable expression. + + +Objects and shapes +^^^^^^^^^^^^^^^^^^ + +:ref:`Objects ` are generally ``Stable`` except: + +* Objects with a :ref:`shape ` containing a more volatile + computed pointer will have the volatility of its most volatile component. + +* :ref:`Free objects ` have the volatility of + their most volatile component. They may be ``Immutable``. + +An object's non-computed pointers are ``Stable``. Its computed pointers have +the volatility of their expressions. + +Any DML (i.e., :ref:`insert `, :ref:`update `, +:ref:`delete `) is ``Modifying``. + + +Functions and operators +^^^^^^^^^^^^^^^^^^^^^^^ + +Unless explicitly specified, a :ref:`function's ` +volatility will be inferred from its body expression. + +A function call's volatility is highest of its body expression and its call +arguments. + +Given: + +.. 
code-block:: sdl + + # Immutable + function plus_primitive(x: float64) -> float64 + using (x + 1); + + # Stable + global one := 1; + function plus_global(x: float64) -> float64 + using (x + one); + + # Volatile + function plus_random(x: float64) -> float64 + using (x + random()); + + # Modifying + type One { + val := 1; + }; + function plus_insert(x: float64) -> float64 + using (x + (insert One).val); + +Some example operator and function calls: + +.. code-block:: + + 1 + 1: Immutable + 1 + global one: Stable + global one + random(): Volatile + (insert One).val: Modifying + plus_primitive(1): Immutable + plus_stable(1): Stable + plus_random(global one): Volatile + plus_insert(random()): Immutable + + +Restrictions +------------ + +Some features restrict the volatility of expressions. A lower volatility +can be used. + +:ref:`Indexes ` expressions must be ``Immutable``. +Within the index, pointers to the indexed object are treated as immutable + +:ref:`constraints ` expressions must be +``Immutable``. Within the constraint, the ``__subject__`` and its pointers are +treated as immutable. + +:ref:`Access policies ` must be ``Stable``. + +:ref:`Aliases `, :ref:`globals `, +and :ref:`computed pointers ` in the schema must be +``Stable``. + +The :ref:`cartesian product ` of a +``Volatile`` or ``Modifying`` expression is not allowed. + +.. code-block:: edgeql-repl + + db> SELECT {1, 2} + random() + QueryError: can not take cross product of volatile operation + +``Modifying`` expressions are not allowed in a non-scalar argument to a +function, except for :ref:`standard set functions `. 
diff --git a/docs/reference/index.rst b/docs/reference/index.rst index c418325fd31..527747c13db 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -24,7 +24,7 @@ Reference backend_ha configuration http - sql_support + sql_adapter protocol/index bindings/index admin/index diff --git a/docs/reference/sdl/access_policies.rst b/docs/reference/sdl/access_policies.rst index 09662200561..c3a19d73c7f 100644 --- a/docs/reference/sdl/access_policies.rst +++ b/docs/reference/sdl/access_policies.rst @@ -190,6 +190,8 @@ The access policy declaration options are as follows: depends on whether this policy flavor is :eql:synopsis:`allow` or :eql:synopsis:`deny`. + The expression must be :ref:`Stable `. + When omitted, it is assumed that this policy applies to all eligible objects of a given type. diff --git a/docs/reference/sdl/aliases.rst b/docs/reference/sdl/aliases.rst index 9737a4c3103..e0520f98d9c 100644 --- a/docs/reference/sdl/aliases.rst +++ b/docs/reference/sdl/aliases.rst @@ -47,7 +47,8 @@ This declaration defines a new alias with the following options: The name (optionally module-qualified) of an alias to be created. :eql:synopsis:`` - The aliased expression. Can be any valid EdgeQL expression. + The aliased expression. Must be a :ref:`Stable ` + EdgeQL expression. The valid SDL sub-declarations are listed below: diff --git a/docs/reference/sdl/constraints.rst b/docs/reference/sdl/constraints.rst index b6983629f33..7b9d3c706e6 100644 --- a/docs/reference/sdl/constraints.rst +++ b/docs/reference/sdl/constraints.rst @@ -113,9 +113,14 @@ This declaration defines a new constraint with the following options: :eql:synopsis:`on ( )` An optional expression defining the *subject* of the constraint. If not specified, the subject is the value of the schema item on - which the concrete constraint is defined. The expression must - refer to the original subject of the constraint as - ``__subject__``. 
Note also that ```` itself has to + which the concrete constraint is defined. + + The expression must refer to the original subject of the constraint as + ``__subject__``. The expression must be + :ref:`Immutable `, but may refer to + ``__subject__`` and its properties and links. + + Note also that ```` itself has to be parenthesized. :eql:synopsis:`except ( )` diff --git a/docs/reference/sdl/functions.rst b/docs/reference/sdl/functions.rst index edf46a71486..e9c46bb949e 100644 --- a/docs/reference/sdl/functions.rst +++ b/docs/reference/sdl/functions.rst @@ -40,7 +40,7 @@ commands `. function ([ ] [, ... ]) -> "{" [ ] - [ volatility := {'Immutable' | 'Stable' | 'Volatile'} ] + [ volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ] [ using ( ) ; ] [ using ; ] [ ... ] @@ -127,23 +127,26 @@ This declaration defines a new constraint with the following options: The valid SDL sub-declarations are listed below: -:eql:synopsis:`volatility := {'Immutable' | 'Stable' | 'Volatile'}` +:eql:synopsis:`volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'}` Function volatility determines how aggressively the compiler can optimize its invocations. - If not explicitly specified the function volatility is set to - ``Volatile`` by default. + If not explicitly specified the function volatility is + :ref:`inferred ` from the function body. - * A ``Volatile`` function can modify the database and can return - different results on successive calls with the same arguments. + * An ``Immutable`` function cannot modify the database and is + guaranteed to return the same results given the same arguments + *in all statements*. * A ``Stable`` function cannot modify the database and is guaranteed to return the same results given the same arguments *within a single statement*. - * An ``Immutable`` function cannot modify the database and is - guaranteed to return the same results given the same arguments - *forever*. 
+ * A ``Volatile`` function cannot modify the database and can return + different results on successive calls with the same arguments. + + * A ``Modifying`` function can modify the database and can return + different results on successive calls with the same arguments. :eql:synopsis:`using ( )` Specified the body of the function. :eql:synopsis:`` is an diff --git a/docs/reference/sdl/globals.rst b/docs/reference/sdl/globals.rst index 19355aaf55b..1587c6e4ac7 100644 --- a/docs/reference/sdl/globals.rst +++ b/docs/reference/sdl/globals.rst @@ -108,9 +108,12 @@ The following options are available: denoting a non-abstract scalar or a container type. :eql:synopsis:` := ` - Defines a *computed* global variable. The provided expression can be any - valid EdgeQL expression, including one referring to other global - variables. The type of a *computed* global variable is not limited to + Defines a *computed* global variable. + + The provided expression must be a :ref:`Stable ` + EdgeQL expression. It can refer to other global variables. + + The type of a *computed* global variable is not limited to scalar and container types, but also includes object types. So it is possible to use that to define a global object variable based on an another global scalar variable. diff --git a/docs/reference/sdl/indexes.rst b/docs/reference/sdl/indexes.rst index 1e0c124fa9b..00a267ad242 100644 --- a/docs/reference/sdl/indexes.rst +++ b/docs/reference/sdl/indexes.rst @@ -63,8 +63,12 @@ Description This declaration defines a new index with the following options: :sdl:synopsis:`on ( )` - The specific expression for which the index is made. Note also - that ```` itself has to be parenthesized. + The specific expression for which the index is made. + + The expression must be :ref:`Immutable ` but may + refer to the indexed object's properties and links. + + Note also that ```` itself has to be parenthesized. 
:eql:synopsis:`except ( )` An optional expression defining a condition to create exceptions diff --git a/docs/reference/sdl/links.rst b/docs/reference/sdl/links.rst index a0572474abf..f90a05ff702 100644 --- a/docs/reference/sdl/links.rst +++ b/docs/reference/sdl/links.rst @@ -300,6 +300,8 @@ The valid SDL sub-declarations are listed below: The default value is used in an ``insert`` statement if an explicit value for this link is not specified. + The expression must be :ref:`Stable `. + :eql:synopsis:`readonly := {true | false}` If ``true``, the link is considered *read-only*. Modifications of this link are prohibited once an object is created. All of the diff --git a/docs/reference/sdl/properties.rst b/docs/reference/sdl/properties.rst index a70788c9315..42817e68da4 100644 --- a/docs/reference/sdl/properties.rst +++ b/docs/reference/sdl/properties.rst @@ -270,6 +270,8 @@ The valid SDL sub-declarations are listed below: The default value is used in an ``insert`` statement if an explicit value for this property is not specified. + The expression must be :ref:`Stable `. + :eql:synopsis:`readonly := {true | false}` If ``true``, the property is considered *read-only*. Modifications of this property are prohibited once an object is diff --git a/docs/reference/sql_support.rst b/docs/reference/sql_adapter.rst similarity index 50% rename from docs/reference/sql_support.rst rename to docs/reference/sql_adapter.rst index 947252cf9a4..381c39e8cdf 100644 --- a/docs/reference/sql_support.rst +++ b/docs/reference/sql_adapter.rst @@ -1,9 +1,9 @@ .. versionadded:: 3.0 -.. _ref_sql_support: +.. _ref_sql_adapter: =========== -SQL support +SQL adapter =========== .. edb:youtube-embed:: 0KdY2MPb2oc @@ -11,12 +11,21 @@ SQL support Connecting ========== -EdgeDB supports running read-only SQL queries via the Postgres protocol to -enable connecting EdgeDB to existing BI and analytics solutions. 
Any -Postgres-compatible client can connect to your EdgeDB database by using the +EdgeDB server supports PostgreSQL connection interface. It implements PostgreSQL +wire protocol as well as SQL query language. + +As of EdgeDB 6.0, it also supports a subset of Data Modification Language, +namely INSERT, DELETE and UPDATE statements. + +It does not, however, support PostgreSQL Data Definition Language +(e.g. ``CREATE TABLE``). This means that it is not possible to use SQL +connections to EdgeDB to modify its schema. Instead, the schema should be +managed using ESDL (EdgeDB Schema Definition Language) and migration commands. + +Any Postgres-compatible client can connect to an EdgeDB database by using the same port that is used for the EdgeDB protocol and the -:versionreplace:`database;5.0:branch` name, username, and password you already -use for your database. +:versionreplace:`database;5.0:branch` name, username, and password already used +for the database. .. versionchanged:: _default @@ -52,7 +61,7 @@ use for your database. The insecure DSN returned by the CLI for EdgeDB Cloud instances will not contain the password. You will need to either :ref:`create a new role and - set the password `, using those values to connect + set the password `, using those values to connect to your SQL client, or change the password of the existing role, using that role name along with the newly created password. @@ -76,7 +85,7 @@ use for your database. ``libpq.dll``, click "Properties," and find the version on the "Details" tab. -.. _ref_sql_support_new_role: +.. _ref_sql_adapter_new_role: Creating a new role ------------------- @@ -177,24 +186,29 @@ Multi properties are in separate tables. ``source`` is the ``id`` of the Movie. SELECT source, target FROM "Movie.labels"; -When types are extended, parent object types' tables will by default contain -all objects of both the type and any types extended by it. 
The query below will +When using inheritance, parent object types' tables will by default contain +all objects of both the parent type and any child types. The query below will return all ``common::Content`` objects as well as all ``Movie`` objects. .. code-block:: sql SELECT id, title FROM common."Content"; -To omit objects of extended types, use ``ONLY``. This query will return +To omit objects of child types, use ``ONLY``. This query will return ``common::Content`` objects but not ``Movie`` objects. .. code-block:: sql SELECT id, title FROM ONLY common."Content"; -The SQL connector supports read-only statements and will throw errors if the -client attempts ``INSERT``, ``UPDATE``, ``DELETE``, or any DDL command. It -supports all SQL expressions supported by Postgres. +The SQL adapter supports a large majority of SQL language, including: + +- ``SELECT`` and all read-only constructs (``WITH``, sub-query, ``JOIN``, ...), +- ``INSERT`` / ``UPDATE`` / ``DELETE``, +- ``COPY ... FROM``, +- ``SET`` / ``RESET`` / ``SHOW``, +- transaction commands, +- ``PREPARE`` / ``EXECUTE`` / ``DEALLOCATE``. .. code-block:: sql @@ -207,8 +221,8 @@ supports all SQL expressions supported by Postgres. WHERE act.source = m.id ); -EdgeDB accomplishes this by emulating the ``information_schema`` and -``pg_catalog`` views to mimic the catalogs provided by Postgres 13. +The SQL adapter emulates the ``information_schema`` and ``pg_catalog`` views to +mimic the catalogs provided by Postgres 13. .. note:: @@ -244,7 +258,7 @@ Tested SQL tools include `XMIN Replication`_, incremental updates using "a user-defined monotonically increasing id," and full table updates. .. [2] dbt models are built and stored in the database as either tables or - views. Because the EdgeDB SQL connector does not allow writing or even + views. Because the EdgeDB SQL adapter does not allow writing or even creating schemas, view, or tables, any attempt to materialize dbt models will result in errors. 
If you want to build the models, we suggest first transferring your data to a true Postgres instance via pg_dump or Airbyte. @@ -254,3 +268,169 @@ Tested SQL tools https://www.postgresql.org/docs/current/runtime-config-replication.html .. _XMIN Replication: https://www.postgresql.org/docs/15/ddl-system-columns.html + + +ESDL to PostgreSQL +================== + +As mentioned, the SQL schema of the database is managed trough EdgeDB Schema +Definition Language (ESDL). Here is a breakdown of how each of the ESDL +construct is mapped to PostgreSQL schema: + +- Objects types are mapped into tables. + Each table has columns ``id UUID`` and ``__type__ UUID`` and one column for + each single property or link. + +- Single properties are mapped to tables columns. + +- Single links are mapped to table columns with suffix ``_id`` and are of type + ``UUID``. They contain the ids of the link's target type. + +- Multi properties are mapped to tables with two columns: + - ``source UUID``, which contains the id of the property's source object type, + - ``target``, which contains values of the property. + +- Multi links are mapped to tables with columns: + - ``source UUID``, which contains the id of the property's source object type, + - ``target UUID``, which contains the ids of the link's target object type, + - one column for each link property, using the same rules as properties on + object types. + +- Aliases are not mapped to PostgreSQL schema. + +- Globals are mapped to connection settings, prefixed with ``global ``. + For example, a ``global default::username: str`` can be set using + ``SET "global default::username" TO 'Tom'``. + +- Access policies are applied to object type tables when setting + ``apply_access_policies_sql`` is set to ``true``. + +- Mutation rewrites and triggers are applied to all DML commands. + + +DML commands +============ + +When using ``INSERT``, ``DELETE`` or ``UPDATE`` on any table, mutation rewrites +and triggers are applied. 
These commands do not have a straight-forward +translation to EdgeQL DML commands, but instead use the following mapping: + +- ``INSERT INTO "Foo"`` object table maps to ``insert Foo``, + +- ``INSERT INTO "Foo.keywords"`` link/property table maps to an + ``update Foo { keywords += ... }``, + +- ``DELETE FROM "Foo"`` object table maps to ``delete Foo``, + +- ``DELETE FROM "Foo.keywords"`` link property/table maps to + ``update Foo { keywords -= ... }``, + +- ``UPDATE "Foo"`` object table maps to ``update Foo set { ... }``, + +- ``UPDATE "Foo.keywords"`` is not supported. + + +Connection settings +=================== + +SQL adapter supports a limited subset of PostgreSQL connection settings. +There are the following additionally connection settings: + +- ``allow_user_specified_id`` (default ``false``), +- ``apply_access_policies_sql`` (default ``false``), +- settings prefixed with ``"global "`` can use used to set values of globals. + +Note that if ``allow_user_specified_id`` or ``apply_access_policies_sql`` are +unset, they default to configuration set by ``configure current database`` +EdgeQL command. + + +Example: gradual transition from ORMs to EdgeDB +=============================================== + +When a project is using Object-Relational Mappings (e.g. SQLAlchemy, Django, +Hibernate ORM, TypeORM) and is considering the migration to EdgeDB, it might +want to execute the transition gradually, as opposed to a total rewrite of the +project. + +In this case, the project can start the transition by migrating the ORM models +to EdgeDB Schema Definition Language. + +For example, such Hibernate ORM model in Java: + +.. code-block:: + + @Entity + class Movie { + @Id + @GeneratedValue(strategy = GenerationType.UUID) + UUID id; + + private String title; + + @NotNull + private Integer releaseYear; + + // ... getters and setters ... + } + +... would be translated to the following EdgeDB SDL: + +.. 
code-block:: sdl + + type Movie { + title: str; + + required releaseYear: int32; + } + +A new EdgeDB instance can now be created and migrated to the translated schema. +At this stage, EdgeDB will allow SQL connections to write into the ``"Movie"`` +table, just as it would have been created with the following DDL command: + +.. code-block:: sql + + CREATE TABLE "Movie" ( + id UUID PRIMARY KEY DEFAULT (...), + __type__ UUID NOT NULL DEFAULT (...), + title TEXT, + releaseYear INTEGER NOT NULL + ); + +When translating the old ORM model to EdgeDB SDL, one should aim to make the +SQL schema of EdgeDB match the SQL schema that the ORM expects. + +When this match is accomplished, any query that used to work with the old, plain +PostgreSQL, should now also work with the EdgeDB. For example, we can execute +the following query: + +.. code-block:: sql + + INSERT INTO "Movie" (title, releaseYear) + VALUES ("Madagascar", 2012) + RETURNING id, title, releaseYear; + +To complete the migration, the data can be exported from our old database into +an ``.sql`` file, which can be import it into EdgeDB: + +.. code-block:: bash + + $ pg_dump {your PostgreSQL connection params} \ + --data-only --inserts --no-owner --no-privileges \ + > dump.sql + + $ psql {your EdgeDB connection params} --file dump.sql + +Now, the ORM can be pointed to EdgeDB instead of the old PostgreSQL database, +which has been fully replaced. + +Arguably, the development of new features with the ORM is now more complex for +the duration of the transition, since the developer has to modify two model +definitions: the ORM and the EdgeDB schema. + +But it allows any new models to use EdgeDB schema, EdgeQL and code generators +for the client language of choice. The ORM-based code can now also be gradually +rewritten to use EdgeQL, one model at the time. + +For a detailed migration example, see repository +`edgedb/hibernate-example `_. 
diff --git a/docs/stdlib/cfg.rst b/docs/stdlib/cfg.rst index 18fbc11cef7..3e3a37d8430 100644 --- a/docs/stdlib/cfg.rst +++ b/docs/stdlib/cfg.rst @@ -439,7 +439,7 @@ Client connections - EdgeDB binary protocol * - ``cfg::ConnectionTransport.TCP_PG`` - Postgres protocol for the - :ref:`SQL query mode ` + :ref:`SQL query mode ` * - ``cfg::ConnectionTransport.HTTP`` - EdgeDB binary protocol :ref:`tunneled over HTTP ` diff --git a/edb/api/types.txt b/edb/api/types.txt index 1197b178bb4..3825a01dd6d 100644 --- a/edb/api/types.txt +++ b/edb/api/types.txt @@ -49,3 +49,9 @@ 00000000-0000-0000-0000-000000000112 std::cal::date_duration 00000000-0000-0000-0000-000000000130 cfg::memory + +00000000-0000-0000-0000-000001000001 std::pg::json +00000000-0000-0000-0000-000001000002 std::pg::timestamptz +00000000-0000-0000-0000-000001000003 std::pg::timestamp +00000000-0000-0000-0000-000001000004 std::pg::date +00000000-0000-0000-0000-000001000005 std::pg::interval diff --git a/edb/buildmeta.py b/edb/buildmeta.py index e40714fd61c..d381e336b4b 100644 --- a/edb/buildmeta.py +++ b/edb/buildmeta.py @@ -60,7 +60,7 @@ # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. 
-EDGEDB_CATALOG_VERSION = 2024_11_01_00_00 +EDGEDB_CATALOG_VERSION = 2024_11_15_00_00 EDGEDB_MAJOR_VERSION = 6 diff --git a/edb/common/assert_data_shape.py b/edb/common/assert_data_shape.py index 54c8a40cd82..22cde8812bb 100644 --- a/edb/common/assert_data_shape.py +++ b/edb/common/assert_data_shape.py @@ -20,14 +20,13 @@ from __future__ import annotations +import datetime import decimal import math import pprint import uuid import unittest -from datetime import timedelta - import edgedb @@ -280,14 +279,14 @@ def _assert_generic_shape(path, data, shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') - elif isinstance(shape, (str, int, bytes, timedelta, + elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.RelativeDuration): - if data != timedelta( + if data != datetime.timedelta( days=shape.months * 30 + shape.days, microseconds=shape.microseconds, ): @@ -295,7 +294,7 @@ def _assert_generic_shape(path, data, shape): f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.DateDuration): - if data != timedelta( + if data != datetime.timedelta( days=shape.months * 30 + shape.days, ): fail( @@ -352,7 +351,7 @@ def _assert_generic_shape(path, data, shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') - elif isinstance(shape, (str, int, bytes, timedelta, + elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( diff --git a/edb/common/debug.py b/edb/common/debug.py index 5b719181497..737d6f7a380 100644 --- a/edb/common/debug.py +++ b/edb/common/debug.py @@ -134,6 +134,9 @@ class flags(metaclass=FlagsMeta): delta_execute_ddl = Flag( doc="Output just the DDL commands as executed during migration.") + delta_validate_reflection = Flag( + doc="Whether to do expensive validation of reflection correctness.") + 
server = Flag( doc="Print server errors.") diff --git a/edb/edgeql-parser/edgeql-parser-python/src/errors.rs b/edb/edgeql-parser/edgeql-parser-python/src/errors.rs index 268fdf1619d..a62e6b48bd2 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/errors.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/errors.rs @@ -29,16 +29,17 @@ impl ParserResult { let mut buf = vec![0u8]; // type and version bincode::serialize_into(&mut buf, &rv) .map_err(|e| PyValueError::new_err(format!("Failed to pack: {e}")))?; - Ok(PyBytes::new_bound(py, buf.as_slice()).into()) + Ok(PyBytes::new(py, buf.as_slice()).into()) } } -pub fn parser_error_into_tuple(py: Python, error: Error) -> PyObject { +pub fn parser_error_into_tuple( + error: &Error, +) -> (&str, (u64, u64), Option<&String>, Option<&String>) { ( - error.message, + &error.message, (error.span.start, error.span.end), - error.hint, - error.details, + error.hint.as_ref(), + error.details.as_ref(), ) - .into_py(py) } diff --git a/edb/edgeql-parser/edgeql-parser-python/src/hash.rs b/edb/edgeql-parser/edgeql-parser-python/src/hash.rs index 9139c0d09f1..f52a2b69170 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/hash.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/hash.rs @@ -1,4 +1,4 @@ -use std::cell::RefCell; +use std::sync::RwLock; use edgeql_parser::hash; use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyString}; @@ -7,7 +7,7 @@ use crate::errors::SyntaxError; #[pyclass] pub struct Hasher { - _hasher: RefCell>, + _hasher: RwLock>, } #[pymethods] @@ -16,13 +16,13 @@ impl Hasher { fn start_migration(parent_id: &Bound) -> PyResult { let hasher = hash::Hasher::start_migration(parent_id.to_str()?); Ok(Hasher { - _hasher: RefCell::new(Some(hasher)), + _hasher: RwLock::new(Some(hasher)), }) } fn add_source(&self, py: Python, data: &Bound) -> PyResult { let text = data.to_str()?; - let mut cell = self._hasher.borrow_mut(); + let mut cell = self._hasher.write().unwrap(); let hasher = cell .as_mut() 
.ok_or_else(|| PyRuntimeError::new_err(("cannot add source after finish",)))?; @@ -36,7 +36,7 @@ impl Hasher { } fn make_migration_id(&self) -> PyResult { - let mut cell = self._hasher.borrow_mut(); + let mut cell = self._hasher.write().unwrap(); let hasher = cell .take() .ok_or_else(|| PyRuntimeError::new_err(("cannot do migration id twice",)))?; diff --git a/edb/edgeql-parser/edgeql-parser-python/src/keywords.rs b/edb/edgeql-parser/edgeql-parser-python/src/keywords.rs index 02e244b670c..7c7d5968ec0 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/keywords.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/keywords.rs @@ -1,46 +1,35 @@ -use pyo3::{ - prelude::*, - types::{PyList, PyString}, -}; +use pyo3::{prelude::*, types::PyFrozenSet}; use edgeql_parser::keywords; pub struct AllKeywords { - pub current: PyObject, - pub future: PyObject, - pub unreserved: PyObject, - pub partial: PyObject, + pub current: Py, + pub future: Py, + pub unreserved: Py, + pub partial: Py, } pub fn get_keywords(py: Python) -> PyResult { - let intern = py.import_bound("sys")?.getattr("intern")?; - let frozen = py.import_bound("builtins")?.getattr("frozenset")?; + let intern = py.import("sys")?.getattr("intern")?; - let current = prepare_keywords(py, keywords::CURRENT_RESERVED_KEYWORDS.iter(), &intern)?; - let unreserved = prepare_keywords(py, keywords::UNRESERVED_KEYWORDS.iter(), &intern)?; - let future = prepare_keywords(py, keywords::FUTURE_RESERVED_KEYWORDS.iter(), &intern)?; - let partial = prepare_keywords(py, keywords::PARTIAL_RESERVED_KEYWORDS.iter(), &intern)?; Ok(AllKeywords { - current: frozen - .call((PyList::new_bound(py, ¤t),), None)? - .into(), - unreserved: frozen - .call((PyList::new_bound(py, &unreserved),), None)? - .into(), - future: frozen.call((PyList::new_bound(py, &future),), None)?.into(), - partial: frozen - .call((PyList::new_bound(py, &partial),), None)? 
- .into(), + current: prepare_keywords(py, &keywords::CURRENT_RESERVED_KEYWORDS, &intern)?, + unreserved: prepare_keywords(py, &keywords::UNRESERVED_KEYWORDS, &intern)?, + future: prepare_keywords(py, &keywords::FUTURE_RESERVED_KEYWORDS, &intern)?, + partial: prepare_keywords(py, &keywords::PARTIAL_RESERVED_KEYWORDS, &intern)?, }) } -fn prepare_keywords<'py, I: Iterator>( +fn prepare_keywords<'a, 'py, I: IntoIterator>( py: Python<'py>, keyword_set: I, - intern: &'py Bound<'py, PyAny>, -) -> Result>, PyErr> { - keyword_set - .cloned() - .map(|s: &str| intern.call((PyString::new_bound(py, s),), None)) - .collect() + intern: &Bound<'py, PyAny>, +) -> PyResult> { + PyFrozenSet::new( + py, + keyword_set + .into_iter() + .map(|s| intern.call((&s,), None).unwrap()), + ) + .map(|o| o.unbind()) } diff --git a/edb/edgeql-parser/edgeql-parser-python/src/lib.rs b/edb/edgeql-parser/edgeql-parser-python/src/lib.rs index b36c0ddc9f6..d1fbcad8f30 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/lib.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/lib.rs @@ -14,8 +14,8 @@ use pyo3::prelude::*; /// Rust bindings to the edgeql-parser crate #[pymodule] fn _edgeql_parser(py: Python, m: &Bound) -> PyResult<()> { - m.add("SyntaxError", py.get_type_bound::())?; - m.add("ParserResult", py.get_type_bound::())?; + m.add("SyntaxError", py.get_type::())?; + m.add("ParserResult", py.get_type::())?; m.add_class::()?; @@ -36,7 +36,7 @@ fn _edgeql_parser(py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(position::offset_of_line, m)?)?; - m.add("SourcePoint", py.get_type_bound::())?; + m.add("SourcePoint", py.get_type::())?; m.add_class::()?; m.add_function(wrap_pyfunction!(tokenizer::tokenize, m)?)?; @@ -44,7 +44,7 @@ fn _edgeql_parser(py: Python, m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(unpack::unpack, m)?)?; - tokenizer::fini_module(py, m); + tokenizer::fini_module(m); Ok(()) } diff --git 
a/edb/edgeql-parser/edgeql-parser-python/src/parser.rs b/edb/edgeql-parser/edgeql-parser-python/src/parser.rs index 678612c30df..0f572ddd59b 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/parser.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/parser.rs @@ -3,10 +3,10 @@ use once_cell::sync::OnceCell; use edgeql_parser::parser; use pyo3::exceptions::{PyAssertionError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyList, PyString, PyTuple}; +use pyo3::types::{PyList, PyString}; use crate::errors::{parser_error_into_tuple, ParserResult}; -use crate::pynormalize::value_to_py_object; +use crate::pynormalize::TokenizerValue; use crate::tokenizer::OpaqueToken; #[pyfunction] @@ -14,7 +14,7 @@ pub fn parse( py: Python, start_token_name: &Bound, tokens: PyObject, -) -> PyResult<(ParserResult, PyObject)> { +) -> PyResult<(ParserResult, &'static Py)> { let start_token_name = start_token_name.to_string(); let (spec, productions) = get_spec()?; @@ -24,28 +24,22 @@ pub fn parse( let context = parser::Context::new(spec); let (cst, errors) = parser::parse(&tokens, &context); - let cst = cst.map(|c| to_py_cst(&c, py)).transpose()?; - - let errors = errors - .into_iter() - .map(|e| parser_error_into_tuple(py, e)) - .collect::>(); - let errors = PyList::new_bound(py, &errors); + let errors = PyList::new(py, errors.iter().map(|e| parser_error_into_tuple(e)))?; let res = ParserResult { - out: cst.into_py(py), + out: cst.as_ref().map(ParserCSTNode).into_pyobject(py)?.unbind(), errors: errors.into(), }; - Ok((res, productions.to_object(py))) + Ok((res, productions)) } #[pyclass] pub struct CSTNode { #[pyo3(get)] - production: PyObject, + production: Option>, #[pyo3(get)] - terminal: PyObject, + terminal: Option>, } #[pyclass] @@ -136,56 +130,55 @@ pub fn save_spec(spec_json: &Bound, dst: &Bound) -> PyResult fn load_productions(py: Python<'_>, spec: &parser::Spec) -> PyResult { let grammar_name = "edb.edgeql.parser.grammar.start"; - let grammar_mod = 
py.import_bound(grammar_name)?; + let grammar_mod = py.import(grammar_name)?; let load_productions = py - .import_bound("edb.common.parsing")? + .import("edb.common.parsing")? .getattr("load_spec_productions")?; - let production_names: Vec<_> = spec - .production_names - .iter() - .map(|(a, b)| PyTuple::new_bound(py, [a, b])) - .collect(); - - let productions = load_productions.call((production_names, grammar_mod), None)?; + let productions = load_productions.call((&spec.production_names, grammar_mod), None)?; Ok(productions.into()) } -fn to_py_cst<'a>(cst: &'a parser::CSTNode<'a>, py: Python) -> PyResult { - Ok(match cst { - parser::CSTNode::Empty => CSTNode { - production: py.None(), - terminal: py.None(), - }, - parser::CSTNode::Terminal(token) => CSTNode { - production: py.None(), - terminal: Terminal { - text: token.text.clone(), - value: if let Some(val) = &token.value { - value_to_py_object(py, val)? - } else { - py.None() - }, - start: token.span.start, - end: token.span.end, - } - .into_py(py), - }, - parser::CSTNode::Production(prod) => CSTNode { - production: Production { - id: prod.id, - args: PyList::new_bound( +/// Newtype required to define a trait for a foreign type. +struct ParserCSTNode<'a>(&'a parser::CSTNode<'a>); + +impl<'a, 'py> IntoPyObject<'py> for ParserCSTNode<'a> { + type Target = CSTNode; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult { + let res = match self.0 { + parser::CSTNode::Empty => CSTNode { + production: None, + terminal: None, + }, + parser::CSTNode::Terminal(token) => CSTNode { + production: None, + terminal: Some(Py::new( py, - prod.args - .iter() - .map(|a| to_py_cst(a, py).map(|x| x.into_py(py))) - .collect::>>()? - .as_slice(), - ) - .into(), - } - .into_py(py), - terminal: py.None(), - }, - }) + Terminal { + text: token.text.clone(), + value: (token.value.as_ref()) + .map(TokenizerValue) + .into_pyobject(py)? 
+ .unbind(), + start: token.span.start, + end: token.span.end, + }, + )?), + }, + parser::CSTNode::Production(prod) => CSTNode { + production: Some(Py::new( + py, + Production { + id: prod.id, + args: PyList::new(py, prod.args.iter().map(ParserCSTNode))?.into(), + }, + )?), + terminal: None, + }, + }; + Ok(Py::new(py, res)?.bind(py).clone()) + } } diff --git a/edb/edgeql-parser/edgeql-parser-python/src/position.rs b/edb/edgeql-parser/edgeql-parser-python/src/position.rs index 42b5bc54407..ddaada55bf1 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/position.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/position.rs @@ -1,7 +1,7 @@ use pyo3::{ exceptions::{PyIndexError, PyRuntimeError}, prelude::*, - types::PyBytes, + types::{PyBytes, PyList}, }; use edgeql_parser::position::InflatedPos; @@ -14,18 +14,20 @@ pub struct SourcePoint { #[pymethods] impl SourcePoint { #[staticmethod] - fn from_offsets(py: Python, data: &Bound, offsets: PyObject) -> PyResult { + fn from_offsets(py: Python, data: &Bound, offsets: PyObject) -> PyResult> { let mut list: Vec = offsets.extract(py)?; let data: &[u8] = data.as_bytes(); list.sort(); let result = InflatedPos::from_offsets(data, &list) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(result - .into_iter() - .map(|_position| SourcePoint { _position }) - .collect::>() - .into_py(py)) + PyList::new( + py, + result + .into_iter() + .map(|_position| SourcePoint { _position }), + ) + .map(|v| v.into()) } #[getter] diff --git a/edb/edgeql-parser/edgeql-parser-python/src/pynormalize.rs b/edb/edgeql-parser/edgeql-parser-python/src/pynormalize.rs index ce56df919a4..34b2b5d6fd9 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/pynormalize.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/pynormalize.rs @@ -8,7 +8,7 @@ use edgedb_protocol::model::{BigInt, Decimal}; use edgeql_parser::tokenizer::Value; use pyo3::exceptions::{PyAssertionError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict, 
PyFloat, PyList, PyLong, PyString}; +use pyo3::types::{PyBytes, PyDict, PyFloat, PyInt, PyList, PyString}; use crate::errors::SyntaxError; use crate::normalize::{normalize as _normalize, Error, PackedEntry, Variable}; @@ -55,20 +55,16 @@ pub struct Entry { impl Entry { pub fn new(py: Python, entry: crate::normalize::Entry) -> PyResult { - let blobs = serialize_all(py, &entry.variables).map_err(PyAssertionError::new_err)?; - let counts: Vec<_> = entry - .variables - .iter() - .map(|x| x.len().into_py(py)) - .collect(); + let blobs = serialize_all(py, &entry.variables)?; + let counts = entry.variables.iter().map(|x| x.len()); Ok(Entry { - key: PyBytes::new_bound(py, &entry.hash[..]).into(), - tokens: tokens_to_py(py, entry.tokens.clone())?, + key: PyBytes::new(py, &entry.hash[..]).into(), + tokens: tokens_to_py(py, entry.tokens.clone())?.into_any(), extra_blobs: blobs.into(), extra_named: entry.named_args, first_extra: entry.first_arg, - extra_counts: PyList::new_bound(py, &counts[..]).into(), + extra_counts: PyList::new(py, counts)?.into(), entry_pack: entry.into(), }) } @@ -77,10 +73,10 @@ impl Entry { #[pymethods] impl Entry { fn get_variables(&self, py: Python) -> PyResult { - let vars = PyDict::new_bound(py); + let vars = PyDict::new(py); let first = match self.first_extra { Some(first) => first, - None => return Ok(vars.to_object(py)), + None => return Ok(vars.into()), }; for (idx, var) in self.entry_pack.variables.iter().flatten().enumerate() { let s = if self.extra_named { @@ -88,17 +84,17 @@ impl Entry { } else { (first + idx).to_string() }; - vars.set_item(s.into_py(py), value_to_py_object(py, &var.value)?)?; + vars.set_item(s, TokenizerValue(&var.value))?; } - Ok(vars.to_object(py)) + Ok(vars.into()) } fn pack(&self, py: Python) -> PyResult { let mut buf = vec![1u8]; // type and version bincode::serialize_into(&mut buf, &self.entry_pack) .map_err(|e| PyValueError::new_err(format!("Failed to pack: {e}")))?; - Ok(PyBytes::new_bound(py, 
buf.as_slice()).into()) + Ok(PyBytes::new(py, buf.as_slice()).into()) } } @@ -167,28 +163,35 @@ pub fn serialize_extra(variables: &[Variable]) -> Result { pub fn serialize_all<'a>( py: Python<'a>, variables: &[Vec], -) -> Result, String> { +) -> PyResult> { let mut buf = Vec::with_capacity(variables.len()); for vars in variables { - let bytes = serialize_extra(vars)?; - buf.push(PyBytes::new_bound(py, &bytes)); + let bytes = serialize_extra(vars).map_err(PyAssertionError::new_err)?; + buf.push(PyBytes::new(py, &bytes)); } - Ok(PyList::new_bound(py, &buf)) + PyList::new(py, &buf) } -pub fn value_to_py_object(py: Python, val: &Value) -> PyResult { - Ok(match val { - Value::Int(v) => v.into_py(py), - Value::String(v) => v.into_py(py), - Value::Float(v) => v.into_py(py), - Value::BigInt(v) => py - .get_type_bound::() - .call((v, 16.into_py(py)), None)? - .into(), - Value::Decimal(v) => py - .get_type_bound::() - .call((v.to_string(),), None)? - .into(), - Value::Bytes(v) => PyBytes::new_bound(py, v).into(), - }) +/// Newtype required to define a trait for a foreign type. +pub struct TokenizerValue<'a>(pub &'a Value); + +impl<'py> IntoPyObject<'py> for TokenizerValue<'py> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult { + let res = match self.0 { + Value::Int(v) => v.into_pyobject(py)?.into_any(), + Value::String(v) => v.into_pyobject(py)?.into_any(), + Value::Float(v) => v.into_pyobject(py)?.into_any(), + Value::BigInt(v) => py.get_type::().call((v, 16), None)?, + Value::Decimal(v) => py + .get_type::() + .call((v.to_string(),), None)? 
+ .into_any(), + Value::Bytes(v) => PyBytes::new(py, v).into_any(), + }; + Ok(res) + } } diff --git a/edb/edgeql-parser/edgeql-parser-python/src/tokenizer.rs b/edb/edgeql-parser/edgeql-parser-python/src/tokenizer.rs index 54a2b5e0f47..618088430bb 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/tokenizer.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/tokenizer.rs @@ -10,16 +10,16 @@ use crate::errors::{parser_error_into_tuple, ParserResult}; pub fn tokenize(py: Python, s: &Bound) -> PyResult { let data = s.to_string(); - let mut token_stream = Tokenizer::new(&data[..]).validated_values().with_eof(); + let token_stream = Tokenizer::new(&data[..]).validated_values().with_eof(); - let mut tokens: Vec<_> = Vec::new(); - let mut errors: Vec<_> = Vec::new(); + let mut tokens = vec![]; + let mut errors = vec![]; - for res in &mut token_stream { + for res in token_stream.into_iter() { match res { Ok(token) => tokens.push(token), Err(e) => { - errors.push(parser_error_into_tuple(py, e)); + errors.push(parser_error_into_tuple(&e).into_pyobject(py)?); // TODO: fix tokenizer to skip bad tokens and continue break; @@ -27,14 +27,10 @@ pub fn tokenize(py: Python, s: &Bound) -> PyResult { } } - let tokens = tokens_to_py(py, tokens)?; + let out = tokens_to_py(py, tokens)?.into_pyobject(py)?.into(); + let errors = PyList::new(py, errors)?.into(); - let errors = PyList::new_bound(py, errors.as_slice()).into_py(py); - - Ok(ParserResult { - out: tokens.into_py(py), - errors, - }) + Ok(ParserResult { out, errors }) } // An opaque wrapper around [edgeql_parser::tokenizer::Token]. 
@@ -54,20 +50,18 @@ impl OpaqueToken { .map_err(|e| PyValueError::new_err(format!("Failed to reduce: {e}")))?; let tok = get_unpickle_token_fn(py); - Ok((tok, (PyBytes::new_bound(py, &data).to_object(py),))) + Ok((tok, (PyBytes::new(py, &data).into(),))) } } -pub fn tokens_to_py(py: Python<'_>, rust_tokens: Vec) -> PyResult { - let mut buf = Vec::with_capacity(rust_tokens.len()); - for tok in rust_tokens { - let py_tok = OpaqueToken { +pub fn tokens_to_py(py: Python<'_>, rust_tokens: Vec) -> PyResult> { + Ok(PyList::new( + py, + rust_tokens.into_iter().map(|tok| OpaqueToken { inner: tok.cloned(), - }; - - buf.push(py_tok.into_py(py)); - } - Ok(PyList::new_bound(py, &buf[..]).into_py(py)) + }), + )? + .unbind()) } /// To support pickle serialization of OpaqueTokens, we need to provide a @@ -79,10 +73,10 @@ pub fn tokens_to_py(py: Python<'_>, rust_tokens: Vec) -> PyResult = OnceCell::new(); -pub fn fini_module(py: Python, m: &Bound) { +pub fn fini_module(m: &Bound) { let _unpickle_token = m.getattr("unpickle_token").unwrap(); FN_UNPICKLE_TOKEN - .set(_unpickle_token.to_object(py)) + .set(_unpickle_token.unbind()) .expect("module is already initialized"); } diff --git a/edb/edgeql-parser/edgeql-parser-python/src/unpack.rs b/edb/edgeql-parser/edgeql-parser-python/src/unpack.rs index 11c178d44a4..ad17edf80e1 100644 --- a/edb/edgeql-parser/edgeql-parser-python/src/unpack.rs +++ b/edb/edgeql-parser/edgeql-parser-python/src/unpack.rs @@ -14,13 +14,13 @@ pub fn unpack(py: Python<'_>, serialized: &Bound) -> PyResult 0u8 => { let tokens: Vec = bincode::deserialize(&buf[1..]) .map_err(|e| PyValueError::new_err(format!("{e}")))?; - tokens_to_py(py, tokens) + Ok(tokens_to_py(py, tokens)?.into_any()) } 1u8 => { let pack: PackedEntry = bincode::deserialize(&buf[1..]) .map_err(|e| PyValueError::new_err(format!("Failed to unpack: {e}")))?; let entry = Entry::new(py, pack.into())?; - Ok(entry.into_py(py)) + entry.into_pyobject(py).map(|e| e.unbind().into_any()) } _ => 
Err(PyValueError::new_err(format!( "Invalid type/version byte: {}", diff --git a/edb/edgeql-parser/src/helpers/bytes.rs b/edb/edgeql-parser/src/helpers/bytes.rs index cffd954bc4e..f2b8976c86c 100644 --- a/edb/edgeql-parser/src/helpers/bytes.rs +++ b/edb/edgeql-parser/src/helpers/bytes.rs @@ -1,6 +1,6 @@ pub fn unquote_bytes(value: &str) -> Result, String> { let idx = value - .find(|c| c == '\'' || c == '"') + .find(['\'', '"']) .ok_or_else(|| "invalid bytes literal: missing quotes".to_string())?; let prefix = &value[..idx]; match prefix { diff --git a/edb/edgeql-parser/src/into_python.rs b/edb/edgeql-parser/src/into_python.rs deleted file mode 100644 index e2f4f4c4335..00000000000 --- a/edb/edgeql-parser/src/into_python.rs +++ /dev/null @@ -1,117 +0,0 @@ -#![cfg(never)] // TODO: migrate cpython-rust to pyo3 -use indexmap::IndexMap; - -use cpython::{ - PyBytes, PyDict, PyList, PyObject, PyResult, PyTuple, Python, PythonObject, ToPyObject, -}; - -/// Convert into a Python object. -/// -/// Primitives (i64, String, Option, Vec) have this trait implemented with -/// calls to [cpython]. -/// -/// Structs have this trait derived to collect all their properties into a -/// [PyDict] and call constructor of the AST node. -/// -/// Enums represent either enums, union types or child classes and thus have -/// three different derive implementations. -/// -/// See [edgeql_parser_derive] crate. 
-pub trait IntoPython: Sized { - fn into_python(self, py: Python<'_>, parent_kw_args: Option) -> PyResult; -} - -impl IntoPython for String { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - Ok(self.to_py_object(py).into_object()) - } -} - -impl IntoPython for i64 { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - Ok(self.to_py_object(py).into_object()) - } -} - -impl IntoPython for f64 { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - Ok(self.to_py_object(py).into_object()) - } -} - -impl IntoPython for bool { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - Ok(if self { py.True() } else { py.False() }.into_object()) - } -} - -impl IntoPython for Vec { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - let mut elements = Vec::new(); - for x in self { - elements.push(x.into_python(py, None)?); - } - Ok(PyList::new_bound(py, elements.as_slice()).into_object()) - } -} - -impl IntoPython for Option { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - if let Some(value) = self { - value.into_python(py, None) - } else { - Ok(py.None()) - } - } -} - -impl IntoPython for Box { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - (*self).into_python(py, None) - } -} - -impl IntoPython for (T1, T2) { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - let mut elements = Vec::new(); - elements.push(self.0.into_python(py, None)?); - elements.push(self.1.into_python(py, None)?); - Ok(PyTuple::new_bound(py, elements.as_slice()).into_object()) - } -} - -impl IntoPython for IndexMap { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - let dict = PyDict::new_bound(py); - for (key, value) in self { - let key = key.into_python(py, None)?; - let value = value.into_python(py, None)?; - dict.set_item(py, key, value)?; - } - Ok(dict.into_object()) - } -} - -impl IntoPython for Vec { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - 
Ok(PyBytes::new_bound(py, self.as_slice()).into_object()) - } -} - -impl IntoPython for () { - fn into_python(self, py: Python<'_>, _: Option) -> PyResult { - Ok(py.None().into_object()) - } -} - -pub fn init_ast_class( - py: Python, - class_name: &'static str, - kw_args: PyDict, -) -> Result { - let locals = PyDict::new_bound(py); - locals.set_item(py, "kw_args", kw_args)?; - - let code = format!("qlast.{class_name}(**kw_args)"); - py.eval(&code, None, Some(&locals)) -} diff --git a/edb/edgeql-parser/src/lib.rs b/edb/edgeql-parser/src/lib.rs index 36841764a4a..01cd1304b72 100644 --- a/edb/edgeql-parser/src/lib.rs +++ b/edb/edgeql-parser/src/lib.rs @@ -2,8 +2,6 @@ pub mod ast; pub mod expr; pub mod hash; pub mod helpers; -#[cfg(feature = "python")] -pub mod into_python; pub mod keywords; pub mod parser; pub mod position; diff --git a/edb/edgeql-parser/src/position.rs b/edb/edgeql-parser/src/position.rs index a5d8bdd02b2..2f93b529a06 100644 --- a/edb/edgeql-parser/src/position.rs +++ b/edb/edgeql-parser/src/position.rs @@ -124,7 +124,7 @@ impl InflatedPos { let prefix_s = from_utf8(prefix).map_err(InflatingError::Utf8)?; let line_offset; let line; - if let Some(loff) = prefix_s.rfind(|c| c == '\r' || c == '\n') { + if let Some(loff) = prefix_s.rfind(['\r', '\n']) { line_offset = loff + 1; let mut lines = &prefix[..loff]; if data[loff] == b'\n' && loff > 0 && data[loff - 1] == b'\r' { diff --git a/edb/edgeql-parser/src/validation.rs b/edb/edgeql-parser/src/validation.rs index fd05d869975..b1bd57f1228 100644 --- a/edb/edgeql-parser/src/validation.rs +++ b/edb/edgeql-parser/src/validation.rs @@ -167,7 +167,7 @@ pub fn parse_value(token: &Token) -> Result, String> { return Err("number is out of range for std::float64".to_string()); } if num == 0.0 { - let mend = text.find(|c| c == 'e' || c == 'E').unwrap_or(text.len()); + let mend = text.find(['e', 'E']).unwrap_or(text.len()); let mantissa = &text[..mend]; if mantissa.chars().any(|c| c != '0' && c != '.') { return 
Err("number is out of range for std::float64".to_string()); diff --git a/edb/edgeql/__init__.py b/edb/edgeql/__init__.py index dcbdd259105..11d17479b1d 100644 --- a/edb/edgeql/__init__.py +++ b/edb/edgeql/__init__.py @@ -24,3 +24,4 @@ from .codegen import generate_source # NOQA from .parser import parse_fragment, parse_block, parse_query # NOQA from .parser.grammar import keywords # NOQA +from .quote import quote_literal, quote_ident # NOQA diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index eb74f85fe5d..ec919e1a6b1 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -674,6 +674,11 @@ class DDLCommand(Command, DDLOperation): __abstract_node__ = True +class NonTransactionalDDLCommand(DDLCommand): + __abstract_node__ = True + __rust_ignore__ = True + + class AlterAddInherit(DDLOperation): position: typing.Optional[Position] = None bases: typing.List[TypeName] @@ -868,7 +873,7 @@ class BranchType(s_enum.StrEnum): TEMPLATE = 'TEMPLATE' -class DatabaseCommand(ExternalObjectCommand): +class DatabaseCommand(ExternalObjectCommand, NonTransactionalDDLCommand): __abstract_node__ = True __rust_ignore__ = True diff --git a/edb/edgeql/compiler/expr.py b/edb/edgeql/compiler/expr.py index 719add2b9df..a4bd7c0eebb 100644 --- a/edb/edgeql/compiler/expr.py +++ b/edb/edgeql/compiler/expr.py @@ -1078,7 +1078,7 @@ def compile_type_check_op( ltype = setgen.get_set_type(left, ctx=ctx) typeref = typegen.ql_typeexpr_to_ir_typeref(expr.right, ctx=ctx) - if ltype.is_object_type(): + if ltype.is_object_type() and not ltype.is_free_object_type(ctx.env.schema): left = setgen.ptr_step_set( left, expr=None, source=ltype, ptr_name='__type__', span=expr.span, ctx=ctx diff --git a/edb/edgeql/compiler/inference/cardinality.py b/edb/edgeql/compiler/inference/cardinality.py index ee7504b83c9..e0691701674 100644 --- a/edb/edgeql/compiler/inference/cardinality.py +++ b/edb/edgeql/compiler/inference/cardinality.py @@ -1205,7 +1205,7 @@ def _infer_stmt_cardinality( scope_tree: 
irast.ScopeTreeNode, ctx: inference_context.InfCtx, ) -> qltypes.Cardinality: - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_cardinality(part, scope_tree=scope_tree, ctx=ctx) result = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result @@ -1339,7 +1339,7 @@ def __infer_insert_stmt( scope_tree: irast.ScopeTreeNode, ctx: inference_context.InfCtx, ) -> qltypes.Cardinality: - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_cardinality(part, scope_tree=scope_tree, ctx=ctx) infer_cardinality( diff --git a/edb/edgeql/compiler/inference/multiplicity.py b/edb/edgeql/compiler/inference/multiplicity.py index 12126886d55..0ec86dca3c4 100644 --- a/edb/edgeql/compiler/inference/multiplicity.py +++ b/edb/edgeql/compiler/inference/multiplicity.py @@ -588,7 +588,7 @@ def _infer_stmt_multiplicity( ) -> inf_ctx.MultiplicityInfo: # WITH block bindings need to be validated; they don't have to # have multiplicity UNIQUE, but their sub-expressions must be valid. - for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx) subj = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result @@ -688,7 +688,7 @@ def __infer_insert_stmt( ) -> inf_ctx.MultiplicityInfo: # WITH block bindings need to be validated, they don't have to # have multiplicity UNIQUE, but their sub-expressions must be valid. 
- for part in (ir.bindings or []): + for part, _ in (ir.bindings or []): infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx) # INSERT will always return a proper set, but we still want to diff --git a/edb/edgeql/compiler/inference/volatility.py b/edb/edgeql/compiler/inference/volatility.py index 12593f63bf6..edd1d382f98 100644 --- a/edb/edgeql/compiler/inference/volatility.py +++ b/edb/edgeql/compiler/inference/volatility.py @@ -209,7 +209,7 @@ def __infer_set( vol, ]) - if ir.is_binding: + if ir.is_binding and ir.is_binding != irast.BindingKind.Schema: vol = IMMUTABLE return vol @@ -320,7 +320,7 @@ def __infer_select_stmt( components.append(ir.limit) if ir.bindings is not None: - components.extend(ir.bindings) + components.extend(part for part, _ in ir.bindings) return _common_volatility(components, env) diff --git a/edb/edgeql/compiler/setgen.py b/edb/edgeql/compiler/setgen.py index da88907c5cc..edd61477f54 100644 --- a/edb/edgeql/compiler/setgen.py +++ b/edb/edgeql/compiler/setgen.py @@ -1959,9 +1959,6 @@ def should_materialize( if not isinstance(ir, irast.Set): return reasons - if irtyputils.is_free_object(ir.typeref): - reasons.append(irast.MaterializeVolatile()) - typ = get_set_type(ir, ctx=ctx) assert ir.path_scope_id is not None diff --git a/edb/edgeql/compiler/stmt.py b/edb/edgeql/compiler/stmt.py index 4c92ef53dab..02bdf674d09 100644 --- a/edb/edgeql/compiler/stmt.py +++ b/edb/edgeql/compiler/stmt.py @@ -68,6 +68,7 @@ from . import context from . import config_desc from . import dispatch +from . import inference from . import pathctx from . import policies from . 
import setgen @@ -1309,7 +1310,7 @@ def process_with_block( *, ctx: context.ContextLevel, parent_ctx: context.ContextLevel, -) -> List[irast.Set]: +) -> list[tuple[irast.Set, qltypes.Volatility]]: if edgeql_tree.aliases is None: return [] @@ -1329,7 +1330,10 @@ def process_with_block( binding_kind=irast.BindingKind.With, ctx=scopectx, ) - results.append(binding) + volatility = inference.infer_volatility( + binding, ctx.env, exclude_dml=True + ) + results.append((binding, volatility)) if reason := setgen.should_materialize(binding, ctx=ctx): had_materialized = True diff --git a/edb/edgeql/compiler/stmtctx.py b/edb/edgeql/compiler/stmtctx.py index 2bc054f317f..4256ce065ff 100644 --- a/edb/edgeql/compiler/stmtctx.py +++ b/edb/edgeql/compiler/stmtctx.py @@ -816,9 +816,13 @@ def _declare_view_from_schema( view_ql = view_expr.parse() viewcls_name = viewcls.get_name(ctx.env.schema) assert isinstance(view_ql, qlast.Expr), 'expected qlast.Expr' - view_set = declare_view(view_ql, alias=viewcls_name, - binding_kind=irast.BindingKind.With, - fully_detached=True, ctx=subctx) + view_set = declare_view( + view_ql, + alias=viewcls_name, + binding_kind=irast.BindingKind.Schema, + fully_detached=True, + ctx=subctx, + ) # The view path id _itself_ should not be in the nested namespace. view_set.path_id = view_set.path_id.replace_namespace(frozenset()) view_set.is_schema_alias = True diff --git a/edb/edgeql/compiler/viewgen.py b/edb/edgeql/compiler/viewgen.py index 38f58713e48..fd4f86e2f6c 100644 --- a/edb/edgeql/compiler/viewgen.py +++ b/edb/edgeql/compiler/viewgen.py @@ -69,7 +69,6 @@ from . import context from . import dispatch from . import eta_expand -from . import inference from . import pathctx from . import schemactx from . import setgen @@ -1946,27 +1945,6 @@ def _normalize_view_ptr_expr( assert ptrcls is not None - if materialized and is_mutation and any( - x.is_binding == irast.BindingKind.With - and x.expr - # If it is a computed pointer, look just at the definition. 
- # TODO: It is a weird artifact of how our shapes are defined - # that if a shape element is defined to be some WITH-bound variable, - # that set can is both a is_binding and an irast.Pointer. It seems like - # the is_binding part should be nested inside it. - and (y := x.expr.expr if isinstance(x.expr, irast.Pointer) else x.expr) - and inference.infer_volatility( - y, ctx.env, exclude_dml=True).is_volatile() - - for reason in materialized - if isinstance(reason, irast.MaterializeVisible) - for _, x in reason.sets - ): - raise errors.QueryError( - f'cannot refer to volatile WITH bindings from DML', - span=compexpr.span if compexpr else None, - ) - if materialized and not is_mutation and ctx.qlstmt: assert ptrcls not in ctx.env.materialized_sets ctx.env.materialized_sets[ptrcls] = ctx.qlstmt, materialized @@ -2171,6 +2149,7 @@ def has_implicit_tid( return ( stype.is_object_type() + and not stype.is_free_object_type(ctx.env.schema) and not is_mutation and ctx.implicit_tid_in_shapes ) @@ -2182,6 +2161,7 @@ def has_implicit_tname( return ( stype.is_object_type() + and not stype.is_free_object_type(ctx.env.schema) and not is_mutation and ctx.implicit_tname_in_shapes ) diff --git a/edb/edgeql/declarative.py b/edb/edgeql/declarative.py index fe691bfd6a9..60e241f2ae0 100644 --- a/edb/edgeql/declarative.py +++ b/edb/edgeql/declarative.py @@ -615,6 +615,7 @@ def _trace_item_layout( obj = local_obj assert fq_name is not None + PointerType: type[qltracer.Pointer] if isinstance(node, qlast.BasedOnTuple): bases = [] diff --git a/edb/edgeql/tokenizer.py b/edb/edgeql/tokenizer.py index b721d0adce8..4a85fa23300 100644 --- a/edb/edgeql/tokenizer.py +++ b/edb/edgeql/tokenizer.py @@ -77,6 +77,12 @@ def extra_counts(self) -> Sequence[int]: def extra_blobs(self) -> Sequence[bytes]: return () + def extra_formatted_as_text(self) -> bool: + return False + + def extra_type_oids(self) -> Sequence[int]: + return () + def serialize(self) -> bytes: return self._serialized diff --git 
a/edb/graphql-rewrite/src/lib.rs b/edb/graphql-rewrite/src/lib.rs index 17cab94b874..d1de0ca4f4c 100644 --- a/edb/graphql-rewrite/src/lib.rs +++ b/edb/graphql-rewrite/src/lib.rs @@ -16,11 +16,11 @@ use pyo3::{prelude::*, types::PyString}; fn _graphql_rewrite(py: Python, m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(py_rewrite, m)?)?; m.add_class::()?; - m.add("LexingError", py.get_type_bound::())?; - m.add("SyntaxError", py.get_type_bound::())?; - m.add("NotFoundError", py.get_type_bound::())?; - m.add("AssertionError", py.get_type_bound::())?; - m.add("QueryError", py.get_type_bound::())?; + m.add("LexingError", py.get_type::())?; + m.add("SyntaxError", py.get_type::())?; + m.add("NotFoundError", py.get_type::())?; + m.add("AssertionError", py.get_type::())?; + m.add("QueryError", py.get_type::())?; Ok(()) } diff --git a/edb/graphql-rewrite/src/py_entry.rs b/edb/graphql-rewrite/src/py_entry.rs index 05e082b8e20..0910c2fd596 100644 --- a/edb/graphql-rewrite/src/py_entry.rs +++ b/edb/graphql-rewrite/src/py_entry.rs @@ -1,5 +1,5 @@ use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyLong, PyString, PyTuple, PyType}; +use pyo3::types::{PyDict, PyInt, PyList, PyString, PyTuple, PyType}; use edb_graphql_parser::position::Pos; @@ -22,24 +22,23 @@ pub struct Entry { #[pymethods] impl Entry { - fn tokens(&self, py: Python, kinds: PyObject) -> PyResult { + fn tokens<'py>(&self, py: Python<'py>, kinds: PyObject) -> PyResult> { py_token::convert_tokens(py, &self._tokens, &self._end_pos, kinds) } } pub fn convert_entry(py: Python<'_>, entry: rewrite::Entry) -> PyResult { // import decimal - let decimal_cls = PyModule::import_bound(py, "decimal")?.getattr("Decimal")?; + let decimal_cls = PyModule::import(py, "decimal")?.getattr("Decimal")?; - let vars = PyDict::new_bound(py); - let substitutions = PyDict::new_bound(py); + let vars = PyDict::new(py); + let substitutions = PyDict::new(py); for (idx, var) in entry.variables.iter().enumerate() { - let s = 
format!("_edb_arg__{}", idx).to_object(py); - - vars.set_item(s.clone_ref(py), value_to_py(py, &var.value, &decimal_cls)?)?; + let s = format!("_edb_arg__{}", idx).into_pyobject(py)?; + vars.set_item(&s, value_to_py(py, &var.value, &decimal_cls)?)?; substitutions.set_item( - s.clone_ref(py), + s, ( &var.token.value, var.token.position.map(|x| x.line), @@ -48,20 +47,13 @@ pub fn convert_entry(py: Python<'_>, entry: rewrite::Entry) -> PyResult { )?; } for (name, var) in &entry.defaults { - vars.set_item(name.into_py(py), value_to_py(py, &var.value, &decimal_cls)?)? + vars.set_item(name, value_to_py(py, &var.value, &decimal_cls)?)? } - let key_vars = PyList::new_bound( - py, - entry - .key_vars - .iter() - .map(|v| v.into_py(py)) - .collect::>(), - ); + let key_vars = PyList::new(py, entry.key_vars)?; Ok(Entry { - key: PyString::new_bound(py, &entry.key).into(), + key: PyString::new(py, &entry.key).into(), key_vars: key_vars.into(), - variables: vars.into_py(py), + variables: vars.into_pyobject(py)?.into(), substitutions: substitutions.into(), _tokens: entry.tokens, _end_pos: entry.end_pos, @@ -70,16 +62,14 @@ pub fn convert_entry(py: Python<'_>, entry: rewrite::Entry) -> PyResult { fn value_to_py(py: Python, value: &Value, decimal_cls: &Bound) -> PyResult { let v = match value { - Value::Str(ref v) => PyString::new_bound(py, v).into(), - Value::Int32(v) => v.into_py(py), - Value::Int64(v) => v.into_py(py), - Value::Decimal(v) => decimal_cls - .call(PyTuple::new_bound(py, &[v.into_py(py)]), None)? - .into(), - Value::BigInt(ref v) => PyType::new_bound::(py) - .call(PyTuple::new_bound(py, &[v.into_py(py)]), None)? - .into(), - Value::Boolean(b) => b.into_py(py), + Value::Str(ref v) => PyString::new(py, v).into_any(), + Value::Int32(v) => v.into_pyobject(py)?.into_any(), + Value::Int64(v) => v.into_pyobject(py)?.into_any(), + Value::Decimal(v) => decimal_cls.call((v.as_str(),), None)?.into_any(), + Value::BigInt(ref v) => PyType::new::(py) + .call((v.as_str(),), None)? 
+ .into_any(), + Value::Boolean(b) => b.into_pyobject(py)?.to_owned().into_any(), }; - Ok(v) + Ok(v.into()) } diff --git a/edb/graphql-rewrite/src/py_token.rs b/edb/graphql-rewrite/src/py_token.rs index 9143709fa64..edd738cfc4f 100644 --- a/edb/graphql-rewrite/src/py_token.rs +++ b/edb/graphql-rewrite/src/py_token.rs @@ -2,7 +2,7 @@ use edb_graphql_parser::common::{unquote_block_string, unquote_string}; use edb_graphql_parser::position::Pos; use edb_graphql_parser::tokenizer::Token; use pyo3::prelude::*; -use pyo3::types::{PyList, PyTuple}; +use pyo3::types::{PyList, PyString, PyTuple}; use std::borrow::Cow; use crate::py_exception::LexingError; @@ -73,42 +73,42 @@ impl PyToken { } } -pub fn convert_tokens( - py: Python, +pub fn convert_tokens<'py>( + py: Python<'py>, tokens: &[PyToken], end_pos: &Pos, kinds: PyObject, -) -> PyResult { +) -> PyResult> { use PyTokenKind as K; let sof = kinds.getattr(py, "SOF")?; let eof = kinds.getattr(py, "EOF")?; let bang = kinds.getattr(py, "BANG")?; - let bang_v: PyObject = "!".into_py(py); + let bang_v = "!".into_pyobject(py)?; let dollar = kinds.getattr(py, "DOLLAR")?; - let dollar_v: PyObject = "$".into_py(py); + let dollar_v = "$".into_pyobject(py)?; let paren_l = kinds.getattr(py, "PAREN_L")?; - let paren_l_v: PyObject = "(".into_py(py); + let paren_l_v = "(".into_pyobject(py)?; let paren_r = kinds.getattr(py, "PAREN_R")?; - let paren_r_v: PyObject = ")".into_py(py); + let paren_r_v = ")".into_pyobject(py)?; let spread = kinds.getattr(py, "SPREAD")?; - let spread_v: PyObject = "...".into_py(py); + let spread_v = "...".into_pyobject(py)?; let colon = kinds.getattr(py, "COLON")?; - let colon_v: PyObject = ":".into_py(py); + let colon_v = ":".into_pyobject(py)?; let equals = kinds.getattr(py, "EQUALS")?; - let equals_v: PyObject = "=".into_py(py); + let equals_v = "=".into_pyobject(py)?; let at = kinds.getattr(py, "AT")?; - let at_v: PyObject = "@".into_py(py); + let at_v = "@".into_pyobject(py)?; let bracket_l = 
kinds.getattr(py, "BRACKET_L")?; - let bracket_l_v: PyObject = "[".into_py(py); + let bracket_l_v = "[".into_pyobject(py)?; let bracket_r = kinds.getattr(py, "BRACKET_R")?; - let bracket_r_v: PyObject = "]".into_py(py); + let bracket_r_v = "]".into_pyobject(py)?; let brace_l = kinds.getattr(py, "BRACE_L")?; - let brace_l_v: PyObject = "{".into_py(py); + let brace_l_v = "{".into_pyobject(py)?; let pipe = kinds.getattr(py, "PIPE")?; - let pipe_v: PyObject = "|".into_py(py); + let pipe_v = "|".into_pyobject(py)?; let brace_r = kinds.getattr(py, "BRACE_R")?; - let brace_r_v: PyObject = "}".into_py(py); + let brace_r_v = "}".into_pyobject(py)?; let name = kinds.getattr(py, "NAME")?; let int = kinds.getattr(py, "INT")?; let float = kinds.getattr(py, "FLOAT")?; @@ -117,77 +117,76 @@ pub fn convert_tokens( let mut elems: Vec = Vec::with_capacity(tokens.len()); + let zero = 0u32.into_pyobject(py).unwrap(); let start_of_file = [ sof.clone_ref(py), - 0u32.into_py(py), - 0u32.into_py(py), - 0u32.into_py(py), - 0u32.into_py(py), + zero.clone().into(), + zero.clone().into(), + zero.clone().into(), + zero.clone().into(), py.None(), ]; - elems.push(PyTuple::new_bound(py, &start_of_file).into()); + elems.push(PyTuple::new(py, &start_of_file)?.into()); for token in tokens { let (kind, value) = match token.kind { K::Sof => (sof.clone_ref(py), py.None()), K::Eof => (eof.clone_ref(py), py.None()), - K::Bang => (bang.clone_ref(py), bang_v.clone_ref(py)), - K::Dollar => (dollar.clone_ref(py), dollar_v.clone_ref(py)), - K::ParenL => (paren_l.clone_ref(py), paren_l_v.clone_ref(py)), - K::ParenR => (paren_r.clone_ref(py), paren_r_v.clone_ref(py)), - K::Spread => (spread.clone_ref(py), spread_v.clone_ref(py)), - K::Colon => (colon.clone_ref(py), colon_v.clone_ref(py)), - K::Equals => (equals.clone_ref(py), equals_v.clone_ref(py)), - K::At => (at.clone_ref(py), at_v.clone_ref(py)), - K::BracketL => (bracket_l.clone_ref(py), bracket_l_v.clone_ref(py)), - K::BracketR => 
(bracket_r.clone_ref(py), bracket_r_v.clone_ref(py)), - K::BraceL => (brace_l.clone_ref(py), brace_l_v.clone_ref(py)), - K::Pipe => (pipe.clone_ref(py), pipe_v.clone_ref(py)), - K::BraceR => (brace_r.clone_ref(py), brace_r_v.clone_ref(py)), - K::Name => (name.clone_ref(py), token.value.clone().into_py(py)), - K::Int => (int.clone_ref(py), token.value.clone().into_py(py)), - K::Float => (float.clone_ref(py), token.value.clone().into_py(py)), + K::Bang => (bang.clone_ref(py), bang_v.to_owned().into()), + K::Dollar => (dollar.clone_ref(py), dollar_v.to_owned().into()), + K::ParenL => (paren_l.clone_ref(py), paren_l_v.to_owned().into()), + K::ParenR => (paren_r.clone_ref(py), paren_r_v.to_owned().into()), + K::Spread => (spread.clone_ref(py), spread_v.to_owned().into()), + K::Colon => (colon.clone_ref(py), colon_v.to_owned().into()), + K::Equals => (equals.clone_ref(py), equals_v.to_owned().into()), + K::At => (at.clone_ref(py), at_v.to_owned().into()), + K::BracketL => (bracket_l.clone_ref(py), bracket_l_v.to_owned().into()), + K::BracketR => (bracket_r.clone_ref(py), bracket_r_v.to_owned().into()), + K::BraceL => (brace_l.clone_ref(py), brace_l_v.to_owned().into()), + K::Pipe => (pipe.clone_ref(py), pipe_v.to_owned().into()), + K::BraceR => (brace_r.clone_ref(py), brace_r_v.to_owned().into()), + K::Name => (name.clone_ref(py), PyString::new(py, &token.value).into()), + K::Int => (int.clone_ref(py), PyString::new(py, &token.value).into()), + K::Float => (float.clone_ref(py), PyString::new(py, &token.value).into()), K::String => { // graphql-core 3 receives unescaped strings from the lexer let v = unquote_string(&token.value) .map_err(|e| LexingError::new_err(e.to_string()))? - .into_py(py); - (string.clone_ref(py), v) + .into_pyobject(py)?; + (string.clone_ref(py), v.to_owned().into()) } K::BlockString => { // graphql-core 3 receives unescaped strings from the lexer let v = unquote_block_string(&token.value) .map_err(|e| LexingError::new_err(e.to_string()))? 
- .into_py(py); - (block_string.clone_ref(py), v) + .into_pyobject(py)?; + (block_string.clone_ref(py), v.to_owned().into()) } }; - let token_tuple = [ + let token_tuple = ( kind, - token.position.map(|x| x.character).into_py(py), + token.position.map(|x| x.character), token .position - .map(|x| x.character + token.value.chars().count()) - .into_py(py), - token.position.map(|x| x.line).into_py(py), - token.position.map(|x| x.column).into_py(py), + .map(|x| x.character + token.value.chars().count()), + token.position.map(|x| x.line), + token.position.map(|x| x.column), value, - ]; - elems.push(PyTuple::new_bound(py, &token_tuple).into()); + ) + .into_pyobject(py)?; + elems.push(token_tuple.into()); } elems.push( - PyTuple::new_bound( - py, - &[ - eof.clone_ref(py), - end_pos.character.into_py(py), - end_pos.line.into_py(py), - end_pos.column.into_py(py), - end_pos.character.into_py(py), - py.None(), - ], + ( + eof, + end_pos.character, + end_pos.line, + end_pos.column, + end_pos.character, + py.None(), ) - .into(), + .into_pyobject(py)? 
+ .into(), ); - Ok(PyList::new_bound(py, &elems[..]).into()) + PyList::new(py, elems) } diff --git a/edb/ir/ast.py b/edb/ir/ast.py index e284fa55c75..3602b592047 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -521,6 +521,7 @@ class BindingKind(s_enum.StrEnum): With = 'With' For = 'For' Select = 'Select' + Schema = 'Schema' class TypeRoot(Expr): @@ -1143,7 +1144,7 @@ class Stmt(Expr): result: Set = DUMMY_SET parent_stmt: typing.Optional[Stmt] = None iterator_stmt: typing.Optional[Set] = None - bindings: typing.Optional[typing.List[Set]] = None + bindings: typing.Optional[list[tuple[Set, qltypes.Volatility]]] = None @property def typeref(self) -> TypeRef: diff --git a/edb/ir/staeval.py b/edb/ir/staeval.py index 9c01fd552cd..4ff93d2f90f 100644 --- a/edb/ir/staeval.py +++ b/edb/ir/staeval.py @@ -292,7 +292,7 @@ def evaluate_SliceIndirection( base, start, stop = vals - value = base[start:stop] + value = base[start:stop] # type: ignore[index] return _process_op_result( value, slice.expr.typeref, schema, span=slice.span) diff --git a/edb/ir/typeutils.py b/edb/ir/typeutils.py index 861f2a3cfbd..2ae7a2d25ba 100644 --- a/edb/ir/typeutils.py +++ b/edb/ir/typeutils.py @@ -196,6 +196,10 @@ def is_persistent_tuple(typeref: irast.TypeRef) -> bool: return False +def is_empty_typeref(typeref: irast.TypeRef) -> bool: + return typeref.union is not None and len(typeref.union) == 0 + + def needs_custom_serialization(typeref: irast.TypeRef) -> bool: # True if any component needs custom serialization return contains_predicate( diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index 939913da510..56df07eacbb 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -102,6 +102,71 @@ CREATE TYPE cfg::Auth EXTENDING cfg::ConfigObject { }; }; +CREATE SCALAR TYPE cfg::SMTPSecurity EXTENDING enum< + PlainText, + TLS, + STARTTLS, + STARTTLSOrPlainText, +>; + +CREATE ABSTRACT TYPE cfg::EmailProviderConfig EXTENDING cfg::ConfigObject { + CREATE REQUIRED PROPERTY name -> std::str { + CREATE 
CONSTRAINT std::exclusive; + CREATE ANNOTATION std::description := + "The name of the email provider."; + }; +}; + +CREATE TYPE cfg::SMTPProviderConfig EXTENDING cfg::EmailProviderConfig { + CREATE PROPERTY sender -> std::str { + CREATE ANNOTATION std::description := + "\"From\" address of system emails sent for e.g. \ + password reset, etc."; + }; + CREATE PROPERTY host -> std::str { + CREATE ANNOTATION std::description := + "Host of SMTP server to use for sending emails. \ + If not set, \"localhost\" will be used."; + }; + CREATE PROPERTY port -> std::int32 { + CREATE ANNOTATION std::description := + "Port of SMTP server to use for sending emails. \ + If not set, common defaults will be used depending on security: \ + 465 for TLS, 587 for STARTTLS, 25 otherwise."; + }; + CREATE PROPERTY username -> std::str { + CREATE ANNOTATION std::description := + "Username to login as after connected to SMTP server."; + }; + CREATE PROPERTY password -> std::str { + SET secret := true; + CREATE ANNOTATION std::description := + "Password for login after connected to SMTP server."; + }; + CREATE REQUIRED PROPERTY security -> cfg::SMTPSecurity { + SET default := cfg::SMTPSecurity.STARTTLSOrPlainText; + CREATE ANNOTATION std::description := + "Security mode of the connection to SMTP server. 
\ + By default, initiate a STARTTLS upgrade if supported by the \ + server, or fallback to PlainText."; + }; + CREATE REQUIRED PROPERTY validate_certs -> std::bool { + SET default := true; + CREATE ANNOTATION std::description := + "Determines if SMTP server certificates are validated."; + }; + CREATE REQUIRED PROPERTY timeout_per_email -> std::duration { + SET default := '60 seconds'; + CREATE ANNOTATION std::description := + "Maximum time to send an email, including retry attempts."; + }; + CREATE REQUIRED PROPERTY timeout_per_attempt -> std::duration { + SET default := '15 seconds'; + CREATE ANNOTATION std::description := + "Maximum time for each SMTP request."; + }; +}; + CREATE ABSTRACT TYPE cfg::AbstractConfig extending cfg::ConfigObject; CREATE ABSTRACT TYPE cfg::ExtensionConfig EXTENDING cfg::ConfigObject { @@ -158,6 +223,16 @@ ALTER TYPE cfg::AbstractConfig { CREATE ANNOTATION cfg::system := 'true'; }; + CREATE MULTI LINK email_providers -> cfg::EmailProviderConfig { + CREATE ANNOTATION std::description := + 'The list of email providers that can be used to send emails.'; + }; + + CREATE PROPERTY current_email_provider_name -> std::str { + CREATE ANNOTATION std::description := + 'The name of the current email provider.'; + }; + CREATE PROPERTY allow_dml_in_functions -> std::bool { SET default := false; CREATE ANNOTATION cfg::affects_compilation := 'true'; diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 4d056f49964..d1d31c08889 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -130,6 +130,10 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { create annotation std::description := "Identity provider's refresh token."; }; + create property id_token: std::str { + create annotation std::description := + "Identity provider's OpenID Connect id_token."; + }; create link identity: ext::auth::Identity { on target delete delete source; }; @@ -465,58 +469,6 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { }; }; - create scalar type 
ext::auth::SMTPSecurity extending enum; - - create type ext::auth::SMTPConfig extending cfg::ExtensionConfig { - create property sender: std::str { - create annotation std::description := - "\"From\" address of system emails sent for e.g. \ - password reset, etc."; - }; - create property host: std::str { - create annotation std::description := - "Host of SMTP server to use for sending emails. \ - If not set, \"localhost\" will be used."; - }; - create property port: std::int32 { - create annotation std::description := - "Port of SMTP server to use for sending emails. \ - If not set, common defaults will be used depending on security: \ - 465 for TLS, 587 for STARTTLS, 25 otherwise."; - }; - create property username: std::str { - create annotation std::description := - "Username to login as after connected to SMTP server."; - }; - create property password: std::str { - set secret := true; - create annotation std::description := - "Password for login after connected to SMTP server."; - }; - create required property security: ext::auth::SMTPSecurity { - set default := ext::auth::SMTPSecurity.STARTTLSOrPlainText; - create annotation std::description := - "Security mode of the connection to SMTP server. 
\ - By default, initiate a STARTTLS upgrade if supported by the \ - server, or fallback to PlainText."; - }; - create required property validate_certs: std::bool { - set default := true; - create annotation std::description := - "Determines if SMTP server certificates are validated."; - }; - create required property timeout_per_email: std::duration { - set default := '60 seconds'; - create annotation std::description := - "Maximum time to send an email, including retry attempts."; - }; - create required property timeout_per_attempt: std::duration { - set default := '15 seconds'; - create annotation std::description := - "Maximum time for each SMTP request."; - }; - }; - create function ext::auth::signing_key_exists() -> std::bool { using ( select exists cfg::Config.extensions[is ext::auth::AuthConfig] diff --git a/edb/lib/pg.edgeql b/edb/lib/pg.edgeql index d7ca8cdaf5c..a6359f62f93 100644 --- a/edb/lib/pg.edgeql +++ b/edb/lib/pg.edgeql @@ -84,3 +84,9 @@ create index match for std::cal::local_date using std::pg::brin; create index match for std::cal::local_time using std::pg::brin; create index match for std::cal::relative_duration using std::pg::brin; create index match for std::cal::date_duration using std::pg::brin; + +create scalar type std::pg::json extending std::anyscalar; +create scalar type std::pg::timestamptz extending std::anycontiguous; +create scalar type std::pg::timestamp extending std::anycontiguous; +create scalar type std::pg::date extending std::anydiscrete; +create scalar type std::pg::interval extending std::anycontiguous; diff --git a/edb/lib/schema.edgeql b/edb/lib/schema.edgeql index 7a1ad034e4a..ae5c70d5ba9 100644 --- a/edb/lib/schema.edgeql +++ b/edb/lib/schema.edgeql @@ -432,13 +432,6 @@ ALTER TYPE std::BaseObject { SET protected := True; }; }; -ALTER TYPE std::FreeObject { - # N.B: See above. 
- CREATE REQUIRED LINK __type__ -> schema::ObjectType { - SET readonly := True; - SET protected := True; - }; -}; ALTER TYPE schema::ObjectType { diff --git a/edb/lib/sys.edgeql b/edb/lib/sys.edgeql index 9b3f4b9f5c7..c8ee834233e 100644 --- a/edb/lib/sys.edgeql +++ b/edb/lib/sys.edgeql @@ -28,6 +28,14 @@ CREATE SCALAR TYPE sys::VersionStage EXTENDING enum; +CREATE SCALAR TYPE sys::QueryType + EXTENDING enum; + + +CREATE SCALAR TYPE sys::OutputFormat + EXTENDING enum; + + CREATE ABSTRACT TYPE sys::SystemObject EXTENDING schema::Object; CREATE ABSTRACT TYPE sys::ExternalObject EXTENDING sys::SystemObject; @@ -86,6 +94,141 @@ ALTER TYPE sys::Role { }; +CREATE TYPE sys::QueryStats EXTENDING sys::ExternalObject { + CREATE LINK branch -> sys::Branch { + CREATE ANNOTATION std::description := + "The branch this statistics entry was collected in."; + }; + CREATE PROPERTY query -> std::str { + CREATE ANNOTATION std::description := + "Text string of a representative query."; + }; + CREATE PROPERTY query_type -> sys::QueryType { + CREATE ANNOTATION std::description := + "Type of the query."; + }; + + CREATE PROPERTY compilation_config -> std::json; + CREATE PROPERTY protocol_version -> tuple; + CREATE PROPERTY default_namespace -> std::str; + CREATE OPTIONAL PROPERTY namespace_aliases -> std::json; + CREATE OPTIONAL PROPERTY output_format -> sys::OutputFormat; + CREATE OPTIONAL PROPERTY expect_one -> std::bool; + CREATE OPTIONAL PROPERTY implicit_limit -> std::int64; + CREATE OPTIONAL PROPERTY inline_typeids -> std::bool; + CREATE OPTIONAL PROPERTY inline_typenames -> std::bool; + CREATE OPTIONAL PROPERTY inline_objectids -> std::bool; + + CREATE PROPERTY plans -> std::int64 { + CREATE ANNOTATION std::description := + "Number of times the query was planned in the backend."; + }; + CREATE PROPERTY total_plan_time -> std::duration { + CREATE ANNOTATION std::description := + "Total time spent planning the query in the backend."; + }; + CREATE PROPERTY min_plan_time -> 
std::duration { + CREATE ANNOTATION std::description := + "Minimum time spent planning the query in the backend. " + ++ "This field will be zero if the counter has been reset " + ++ "using the `sys::reset_query_stats` function " + ++ "with the `minmax_only` parameter set to `true` " + ++ "and never been planned since."; + }; + CREATE PROPERTY max_plan_time -> std::duration { + CREATE ANNOTATION std::description := + "Maximum time spent planning the query in the backend. " + ++ "This field will be zero if the counter has been reset " + ++ "using the `sys::reset_query_stats` function " + ++ "with the `minmax_only` parameter set to `true` " + ++ "and never been planned since."; + }; + CREATE PROPERTY mean_plan_time -> std::duration { + CREATE ANNOTATION std::description := + "Mean time spent planning the query in the backend."; + }; + CREATE PROPERTY stddev_plan_time -> std::duration { + CREATE ANNOTATION std::description := + "Population standard deviation of time spent " + ++ "planning the query in the backend."; + }; + + CREATE PROPERTY calls -> std::int64 { + CREATE ANNOTATION std::description := + "Number of times the query was executed."; + }; + CREATE PROPERTY total_exec_time -> std::duration { + CREATE ANNOTATION std::description := + "Total time spent executing the query in the backend."; + }; + CREATE PROPERTY min_exec_time -> std::duration { + CREATE ANNOTATION std::description := + "Minimum time spent executing the query in the backend, " + ++ "this field will be zero until this query is executed " + ++ "first time after reset performed by the " + ++ "`sys::reset_query_stats` function with the " + ++ "`minmax_only` parameter set to `true`"; + }; + CREATE PROPERTY max_exec_time -> std::duration { + CREATE ANNOTATION std::description := + "Maximum time spent executing the query in the backend, " + ++ "this field will be zero until this query is executed " + ++ "first time after reset performed by the " + ++ "`sys::reset_query_stats` function with the " + ++ 
"`minmax_only` parameter set to `true`"; + }; + CREATE PROPERTY mean_exec_time -> std::duration { + CREATE ANNOTATION std::description := + "Mean time spent executing the query in the backend."; + }; + CREATE PROPERTY stddev_exec_time -> std::duration { + CREATE ANNOTATION std::description := + "Population standard deviation of time spent " + ++ "executing the query in the backend."; + }; + + CREATE PROPERTY rows -> std::int64 { + CREATE ANNOTATION std::description := + "Total number of rows retrieved or affected by the query."; + }; + CREATE PROPERTY stats_since -> std::datetime { + CREATE ANNOTATION std::description := + "Time at which statistics gathering started for this query."; + }; + CREATE PROPERTY minmax_stats_since -> std::datetime { + CREATE ANNOTATION std::description := + "Time at which min/max statistics gathering started " + ++ "for this query (fields `min_plan_time`, `max_plan_time`, " + ++ "`min_exec_time` and `max_exec_time`)."; + }; +}; + + +CREATE FUNCTION +sys::reset_query_stats( + named only branch_name: OPTIONAL std::str = {}, + named only id: OPTIONAL std::uuid = {}, + named only minmax_only: OPTIONAL std::bool = false, +) -> OPTIONAL std::datetime { + CREATE ANNOTATION std::description := + 'Discard query statistics gathered so far corresponding to the ' + ++ 'specified `branch_name` and `id`. If either of the ' + ++ 'parameters is not specified, the statistics that match with the ' + ++ 'other parameter will be reset. If no parameter is specified, ' + ++ 'it will discard all statistics. When `minmax_only` is `true`, ' + ++ 'only the values of minimum and maximum planning and execution ' + ++ 'time will be reset (i.e. `min_plan_time`, `max_plan_time`, ' + ++ '`min_exec_time` and `max_exec_time` fields). The default value ' + ++ 'for `minmax_only` parameter is `false`. This function returns ' + ++ 'the time of a reset. 
This time is saved to `stats_reset` or ' + ++ '`minmax_stats_since` field of `sys::QueryStats` if the ' + ++ 'corresponding reset was actually performed.'; + SET volatility := 'Volatile'; + USING SQL FUNCTION 'edgedb.reset_query_stats'; +}; + + # An intermediate function is needed because we can't # cast JSON to tuples yet. DO NOT use directly, it'll go away. CREATE FUNCTION diff --git a/edb/pgsql/ast.py b/edb/pgsql/ast.py index d97157b580d..cc87b61b442 100644 --- a/edb/pgsql/ast.py +++ b/edb/pgsql/ast.py @@ -949,7 +949,7 @@ class RangeFunction(BaseRangeVar): with_ordinality: bool = False # ROWS FROM form is_rowsfrom: bool = False - functions: typing.List[FuncCall] + functions: typing.List[BaseExpr] class JoinClause(BaseRangeVar): diff --git a/edb/pgsql/compiler/clauses.py b/edb/pgsql/compiler/clauses.py index f8528b525ea..31e531edad4 100644 --- a/edb/pgsql/compiler/clauses.py +++ b/edb/pgsql/compiler/clauses.py @@ -34,6 +34,7 @@ from . import astutils from . import context from . import dispatch +from . import dml from . import enums as pgce from . import output from . import pathctx @@ -129,47 +130,56 @@ def compile_materialized_exprs( path_id=mat_set.materialized.path_id, ctx=matctx): continue - mat_ids = set(mat_set.uses) - - # We pack optional things into arrays also, since it works. - # TODO: use NULL? 
- card = mat_set.cardinality - assert card != qltypes.Cardinality.UNKNOWN - is_singleton = card.is_single() and not card.can_be_zero() - - old_scope = matctx.path_scope - matctx.path_scope = old_scope.new_child() - for mat_id in mat_ids: - for k in old_scope: - if k.startswith(mat_id): - matctx.path_scope[k] = None - mat_qry = relgen.set_as_subquery( - mat_set.materialized, as_value=True, ctx=matctx - ) + _compile_materialized_expr(query, mat_set, ctx=matctx) + + +def _compile_materialized_expr( + query: pgast.SelectStmt, + mat_set: irast.MaterializedSet, + *, + ctx: context.CompilerContextLevel, +) -> None: + mat_ids = set(mat_set.uses) + + # We pack optional things into arrays also, since it works. + # TODO: use NULL? + card = mat_set.cardinality + assert card != qltypes.Cardinality.UNKNOWN + is_singleton = card.is_single() and not card.can_be_zero() + + old_scope = ctx.path_scope + ctx.path_scope = old_scope.new_child() + for mat_id in mat_ids: + for k in old_scope: + if k.startswith(mat_id): + ctx.path_scope[k] = None + mat_qry = relgen.set_as_subquery( + mat_set.materialized, as_value=True, ctx=ctx + ) - if not is_singleton: - mat_qry = relctx.set_to_array( - path_id=mat_set.materialized.path_id, - query=mat_qry, - ctx=matctx) + if not is_singleton: + mat_qry = relctx.set_to_array( + path_id=mat_set.materialized.path_id, + query=mat_qry, + ctx=ctx) - if not mat_qry.target_list[0].name: - mat_qry.target_list[0].name = ctx.env.aliases.get('v') + if not mat_qry.target_list[0].name: + mat_qry.target_list[0].name = ctx.env.aliases.get('v') - ref = pgast.ColumnRef( - name=[mat_qry.target_list[0].name], - is_packed_multi=not is_singleton, - ) - for mat_id in mat_ids: - pathctx.put_path_packed_output(mat_qry, mat_id, ref) - - mat_rvar = relctx.rvar_for_rel(mat_qry, lateral=True, ctx=matctx) - for mat_id in mat_ids: - relctx.include_rvar( - query, mat_rvar, path_id=mat_id, - flavor='packed', update_mask=False, pull_namespace=False, - ctx=matctx, - ) + ref = 
pgast.ColumnRef( + name=[mat_qry.target_list[0].name], + is_packed_multi=not is_singleton, + ) + for mat_id in mat_ids: + pathctx.put_path_packed_output(mat_qry, mat_id, ref) + + mat_rvar = relctx.rvar_for_rel(mat_qry, lateral=True, ctx=ctx) + for mat_id in mat_ids: + relctx.include_rvar( + query, mat_rvar, path_id=mat_id, + flavor='packed', update_mask=False, pull_namespace=False, + ctx=ctx, + ) def compile_iterator_expr( @@ -260,20 +270,103 @@ def compile_output( return val -def compile_dml_bindings( - stmt: irast.Stmt, *, - ctx: context.CompilerContextLevel) -> None: - for binding in (stmt.bindings or ()): +def compile_volatile_bindings( + stmt: irast.Stmt, + *, + ctx: context.CompilerContextLevel +) -> None: + for binding, volatility in (stmt.bindings or ()): # If something we are WITH binding contains DML, we want to # compile it *now*, in the context of its initial appearance - # and not where the variable is used. This will populate - # dml_stmts with the CTEs, which will be picked up when the - # variable is referenced. - if irutils.contains_dml(binding): + # and not where the variable is used. + # + # Similarly, if something we are WITH binding is volatile and the stmt + # contains dml, we similarly want to compile it *now*. + + # If the binding is a with binding for a DML stmt, manually construct + # the CTEs. + # + # Note: This condition is checked first, because if the binding + # *references* DML then contains_dml is true. If the binding is compiled + # normally, since the referenced DML was already compiled, the rvar will + # be retrieved, and no CTEs will be set up. + if volatility.is_volatile() and irutils.contains_dml(stmt): + _compile_volatile_binding_for_dml(stmt, binding, ctx=ctx) + + # For typical DML, just compile it. This will populate dml_stmts with + # the CTEs, which will be picked up when the variable is referenced. 
+ elif irutils.contains_dml(binding): with ctx.substmt() as bctx: dispatch.compile(binding, ctx=bctx) +def _compile_volatile_binding_for_dml( + stmt: irast.Stmt, + binding: irast.Set, + *, + ctx: context.CompilerContextLevel +) -> None: + materialized_set = None + if ( + stmt.materialized_sets + and binding.typeref.id in stmt.materialized_sets + ): + materialized_set = stmt.materialized_sets[binding.typeref.id] + assert materialized_set is not None + + last_iterator = ctx.enclosing_cte_iterator + + with ( + context.output_format(ctx, context.OutputFormat.NATIVE), + ctx.newrel() as matctx + ): + matctx.materializing |= {stmt} + matctx.expr_exposed = True + + dml.merge_iterator(last_iterator, matctx.rel, ctx=matctx) + setup_iterator_volatility(last_iterator, ctx=matctx) + + _compile_materialized_expr( + matctx.rel, materialized_set, ctx=matctx + ) + + # Add iterator identity + bind_pathid = ( + irast.PathId.new_dummy(ctx.env.aliases.get('bind_path')) + ) + with matctx.subrel() as bind_pathid_ctx: + relctx.create_iterator_identity_for_path( + bind_pathid, bind_pathid_ctx.rel, ctx=bind_pathid_ctx + ) + bind_id_rvar = relctx.rvar_for_rel( + bind_pathid_ctx.rel, lateral=True, ctx=matctx + ) + relctx.include_rvar( + matctx.rel, bind_id_rvar, path_id=bind_pathid, ctx=matctx + ) + + bind_cte = pgast.CommonTableExpr( + name=ctx.env.aliases.get('bind'), + query=matctx.rel, + materialized=False, + ) + + bind_iterator = pgast.IteratorCTE( + path_id=bind_pathid, + cte=bind_cte, + parent=last_iterator, + iterator_bond=True, + ) + ctx.toplevel_stmt.append_cte(bind_cte) + + # Merge the new iterator + ctx.path_scope = ctx.path_scope.new_child() + dml.merge_iterator(bind_iterator, ctx.rel, ctx=ctx) + setup_iterator_volatility(bind_iterator, ctx=ctx) + + ctx.enclosing_cte_iterator = bind_iterator + + def compile_filter_clause( ir_set: irast.Set, cardinality: qltypes.Cardinality, *, diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index 57da2adbe95..c443cbe433a 100644 
--- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -102,7 +102,7 @@ def init_dml_stmt( range_cte: Optional[pgast.CommonTableExpr] range_rvar: Optional[pgast.RelRangeVar] - clauses.compile_dml_bindings(ir_stmt, ctx=ctx) + clauses.compile_volatile_bindings(ir_stmt, ctx=ctx) if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)): # UPDATE and DELETE operate over a range, so generate diff --git a/edb/pgsql/compiler/group.py b/edb/pgsql/compiler/group.py index 3a225f55877..a3c8e5a0c9a 100644 --- a/edb/pgsql/compiler/group.py +++ b/edb/pgsql/compiler/group.py @@ -173,7 +173,7 @@ def _compile_group( ctx: context.CompilerContextLevel, parent_ctx: context.CompilerContextLevel) -> pgast.BaseExpr: - clauses.compile_dml_bindings(stmt, ctx=ctx) + clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index c0e0390fe8b..361de222023 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -976,7 +976,7 @@ def process_set_as_path_type_intersection( assert not rptr.expr, 'type intersection pointer with expr??' - if ir_set.typeref.union is not None and len(ir_set.typeref.union) == 0: + if irtyputils.is_empty_typeref(ir_set.typeref): # If the typeref was a type expression which resolves to no actual # types, just return an empty set. 
empty_ir = irast.Set( @@ -1148,6 +1148,29 @@ def process_set_as_path( not relctx.find_rvar(stmt, path_id=ir_source.path_id, ctx=ctx) ) + if irtyputils.is_empty_typeref(ir_source.typeref): + # If the source is an empty type intersection, just produce an empty set + + if is_primitive_ref: + aspects = [pgce.PathAspect.VALUE] + else: + aspects = [pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] + + empty_ir = irast.Set( + path_id=ir_set.path_id, + typeref=ir_set.typeref, + expr=irast.EmptySet(typeref=ir_set.typeref), + ) + empty_rvar = SetRVar( + relctx.new_empty_rvar( + cast('irast.SetE[irast.EmptySet]', empty_ir), + ctx=ctx + ), + path_id=ir_set.path_id, + aspects=aspects, + ) + return SetRVars(main=empty_rvar, new=[empty_rvar]) + main_rvar = None source_rptr = ( ir_source.expr if isinstance(ir_source.expr, irast.Pointer) else None) @@ -2823,9 +2846,12 @@ def process_set_as_enumerate( or arg_expr.limit or arg_expr.offset ) + ) and not any( + f_arg.param_typemod == qltypes.TypeModifier.SetOfType + for _, f_arg in arg_subj.args.items() ) ): - # Enumeration of a SET-returning function + # Enumeration of a non-aggregate function rvars = process_set_as_func_enumerate(ir_set, ctx=ctx) else: rvars = process_set_as_simple_enumerate(ir_set, ctx=ctx) diff --git a/edb/pgsql/compiler/stmt.py b/edb/pgsql/compiler/stmt.py index 9dd11aae42a..0b0a1b75f76 100644 --- a/edb/pgsql/compiler/stmt.py +++ b/edb/pgsql/compiler/stmt.py @@ -54,7 +54,7 @@ def compile_SelectStmt( parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common setup. 
- clauses.compile_dml_bindings(stmt, ctx=ctx) + clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index 9dde6621435..04b9051bce8 100644 --- a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -504,6 +504,10 @@ def __repr__(self) -> str: return f'' +class PLQuery(Query): + pass + + class DefaultMeta(type): def __bool__(cls): return False diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 817054691ab..d5f22d51d4e 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -2953,7 +2953,7 @@ def _alter_finalize( return schema -def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: +def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.PLQuery: if len(pg_type) == 1: types_cte = f''' SELECT @@ -2980,7 +2980,6 @@ def drop_dependant_func_cache(pg_type: Tuple[str, ...]) -> dbops.Query: )\ ''' drop_func_cache_sql = textwrap.dedent(f''' - DO $$ DECLARE qc RECORD; BEGIN @@ -3014,9 +3013,9 @@ class AS ( LOOP PERFORM edgedb_VER."_evict_query_cache"(qc.key); END LOOP; - END $$; + END; ''') - return dbops.Query(drop_func_cache_sql) + return dbops.PLQuery(drop_func_cache_sql) class DeleteScalarType(ScalarTypeMetaCommand, diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 3d0d4f9bfe9..c22030fdfaa 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -32,6 +32,7 @@ cast, ) +import json import re import edb._edgeql_parser as ql_parser @@ -57,6 +58,7 @@ from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import properties as s_props +from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import sources as s_sources from edb.schema import types as s_types @@ -105,13 +107,13 @@ class PGConnection(Protocol): async def sql_execute( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, ) -> None: ... 
async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -1497,6 +1499,83 @@ def __init__(self) -> None: ) +class RaiseNoticeFunction(trampoline.VersionedFunction): + text = ''' + BEGIN + RAISE NOTICE USING + MESSAGE = "msg", + DETAIL = COALESCE("detail", ''), + HINT = COALESCE("hint", ''), + COLUMN = COALESCE("column", ''), + CONSTRAINT = COALESCE("constraint", ''), + DATATYPE = COALESCE("datatype", ''), + TABLE = COALESCE("table", ''), + SCHEMA = COALESCE("schema", ''); + RETURN "rtype"; + END; + ''' + + def __init__(self) -> None: + super().__init__( + name=('edgedb', 'notice'), + args=[ + ('rtype', ('anyelement',)), + ('msg', ('text',), "''"), + ('detail', ('text',), "''"), + ('hint', ('text',), "''"), + ('column', ('text',), "''"), + ('constraint', ('text',), "''"), + ('datatype', ('text',), "''"), + ('table', ('text',), "''"), + ('schema', ('text',), "''"), + ], + returns=('anyelement',), + # NOTE: The main reason why we don't want this function to be + # immutable is that immutable functions can be + # pre-evaluated by the query planner once if they have + # constant arguments. This means that using this function + # as the second argument in a COALESCE will raise a + # notice regardless of whether the first argument is + # NULL or not. + volatility='stable', + language='plpgsql', + text=self.text, + ) + + +# edgedb.indirect_return() to be used to return values from +# anonymous code blocks or other contexts that have no return +# data channel. 
+class IndirectReturnFunction(trampoline.VersionedFunction): + text = """ + SELECT + edgedb_VER.notice( + NULL::text, + msg => 'edb:notice:indirect_return', + detail => "value" + ) + """ + + def __init__(self) -> None: + super().__init__( + name=('edgedb', 'indirect_return'), + args=[ + ('value', ('text',)), + ], + returns=('text',), + # NOTE: The main reason why we don't want this function to be + # immutable is that immutable functions can be + # pre-evaluated by the query planner once if they have + # constant arguments. This means that using this function + # as the second argument in a COALESCE will raise a + # notice regardless of whether the first argument is + # NULL or not. + volatility='stable', + language='sql', + text=self.text, + ) + + class RaiseExceptionFunction(trampoline.VersionedFunction): text = ''' BEGIN @@ -4456,8 +4535,13 @@ class GetPgTypeForEdgeDBTypeFunction2(trampoline.VersionedFunction): 'invalid_parameter_value', msg => ( format( - 'cannot determine OID of Gel type %L', - "typeid"::text + 'cannot determine Postgres OID of Gel %s(%L)%s', + "kind", + "typeid"::text, + (case when "elemid" is not null + then ' with element type ' || "elemid"::text + else '' + end) ) ) ) @@ -4824,6 +4908,118 @@ def __init__(self) -> None: ) +class ResetQueryStatsFunction(trampoline.VersionedFunction): + text = r""" + DECLARE + tenant_id TEXT; + other_tenant_exists BOOLEAN; + db_oid OID; + queryid bigint; + BEGIN + tenant_id := edgedb_VER.get_backend_tenant_id(); + IF id IS NULL THEN + queryid := 0; + ELSE + queryid := edgedbext.edb_stat_queryid(id); + END IF; + + SELECT EXISTS ( + SELECT 1 + FROM + pg_database dat + CROSS JOIN LATERAL ( + SELECT + edgedb_VER.shobj_metadata(dat.oid, 'pg_database') + AS description + ) AS d + WHERE + (d.description)->>'id' IS NOT NULL + AND (d.description)->>'tenant_id' != tenant_id + ) INTO other_tenant_exists; + + IF branch_name IS NULL THEN + IF other_tenant_exists THEN + RETURN edgedbext.edb_stat_statements_reset( + 0, -- 
userid + ARRAY( + SELECT + dat.oid + FROM + pg_database dat + CROSS JOIN LATERAL ( + SELECT + edgedb_VER.shobj_metadata(dat.oid, + 'pg_database') + AS description + ) AS d + WHERE + (d.description)->>'id' IS NOT NULL + AND (d.description)->>'tenant_id' = tenant_id + ), + queryid, + COALESCE(minmax_only, false) + ); + ELSE + RETURN edgedbext.edb_stat_statements_reset( + 0, -- userid + '{}', -- database oid + queryid, + COALESCE(minmax_only, false) + ); + END IF; + ELSE + SELECT + dat.oid INTO db_oid + FROM + pg_database dat + CROSS JOIN LATERAL ( + SELECT + edgedb_VER.shobj_metadata(dat.oid, 'pg_database') + AS description + ) AS d + WHERE + (d.description)->>'id' IS NOT NULL + AND (d.description)->>'tenant_id' = tenant_id + AND edgedb_VER.get_database_frontend_name(dat.datname) = + branch_name; + + IF db_oid IS NULL THEN + RETURN NULL::edgedbt.timestamptz_t; + END IF; + + RETURN edgedbext.edb_stat_statements_reset( + 0, -- userid + ARRAY[db_oid], + queryid, + COALESCE(minmax_only, false) + ); + END IF; + + RETURN now()::edgedbt.timestamptz_t; + END; + """ + + noop_text = r""" + BEGIN + RETURN NULL::edgedbt.timestamptz_t; + END; + """ + + def __init__(self, enable_stats: bool) -> None: + super().__init__( + name=('edgedb', 'reset_query_stats'), + args=[ + ('branch_name', ('text',)), + ('id', ('uuid',)), + ('minmax_only', ('bool',)), + ], + returns=('edgedbt', 'timestamptz_t'), + volatility='volatile', + language='plpgsql', + text=self.text if enable_stats else self.noop_text, + ) + + def _maybe_trampoline( cmd: dbops.Command, out: list[trampoline.Trampoline] ) -> None: @@ -4975,6 +5171,8 @@ def get_bootstrap_commands( dbops.CreateFunction(GetSharedObjectMetadata()), dbops.CreateFunction(GetDatabaseMetadataFunction()), dbops.CreateFunction(GetCurrentDatabaseFunction()), + dbops.CreateFunction(RaiseNoticeFunction()), + dbops.CreateFunction(IndirectReturnFunction()), dbops.CreateFunction(RaiseExceptionFunction()), dbops.CreateFunction(RaiseExceptionOnNullFunction()), 
dbops.CreateFunction(RaiseExceptionOnNotNullFunction()), @@ -5047,6 +5245,7 @@ def get_bootstrap_commands( dbops.CreateFunction(FTSNormalizeDocFunction()), dbops.CreateFunction(FTSToRegconfig()), dbops.CreateFunction(PadBase64StringFunction()), + dbops.CreateFunction(ResetQueryStatsFunction(False)), ] commands = dbops.CommandGroup() @@ -5068,15 +5267,19 @@ async def create_pg_extensions( commands.add_command( dbops.CreateSchema(name=ext_schema, conditional=True), ) - if ( - inst_params.existing_exts is None - or inst_params.existing_exts.get("uuid-ossp") is None - ): - commands.add_commands([ - dbops.CreateExtension( - dbops.Extension(name='uuid-ossp', schema=ext_schema), - ), - ]) + extensions = ["uuid-ossp"] + if backend_params.has_stat_statements: + extensions.append("edb_stat_statements") + for ext in extensions: + if ( + inst_params.existing_exts is None + or inst_params.existing_exts.get(ext) is None + ): + commands.add_commands([ + dbops.CreateExtension( + dbops.Extension(name=ext, schema=ext_schema), + ), + ]) block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) @@ -5773,6 +5976,109 @@ def _generate_schema_ver_views(schema: s_schema.Schema) -> List[dbops.View]: return views +def _generate_stats_views(schema: s_schema.Schema) -> List[dbops.View]: + QueryStats = schema.get( + 'sys::QueryStats', + type=s_objtypes.ObjectType, + ) + pvd = common.get_backend_name( + schema, + QueryStats + .getptr(schema, s_name.UnqualName("protocol_version")) + .get_target(schema) # type: ignore + ) + QueryType = schema.get( + 'sys::QueryType', + type=s_scalars.ScalarType, + ) + query_type_domain = common.get_backend_name(schema, QueryType) + type_mapping = { + str(v): k for k, v in defines.QueryType.__members__.items() + } + output_format_domain = common.get_backend_name( + schema, schema.get('sys::OutputFormat', type=s_scalars.ScalarType) + ) + + def float64_to_duration_t(val: str) -> str: + return f"({val} * interval '1ms')::edgedbt.duration_t" 
+ + query_stats_fields = { + 'id': "s.id", + 'name': "s.id::text", + 'name__internal': "s.queryid::text", + 'builtin': "false", + 'internal': "false", + 'computed_fields': 'ARRAY[]::text[]', + + 'compilation_config': "s.extras->'cc'", + 'protocol_version': f"ROW(s.extras->'pv'->0, s.extras->'pv'->1)::{pvd}", + 'default_namespace': "s.extras->>'dn'", + 'namespace_aliases': "s.extras->'na'", + 'output_format': f"(s.extras->>'of')::{output_format_domain}", + 'expect_one': "(s.extras->'e1')::boolean", + 'implicit_limit': "(s.extras->'il')::bigint", + 'inline_typeids': "(s.extras->'ii')::boolean", + 'inline_typenames': "(s.extras->'in')::boolean", + 'inline_objectids': "(s.extras->'io')::boolean", + + 'branch': "((d.description)->>'id')::uuid", + 'query': "s.query", + 'query_type': f"(t.mapping->>s.stmt_type::text)::{query_type_domain}", + + 'plans': 's.plans', + 'total_plan_time': float64_to_duration_t('s.total_plan_time'), + 'min_plan_time': float64_to_duration_t('s.min_plan_time'), + 'max_plan_time': float64_to_duration_t('s.max_plan_time'), + 'mean_plan_time': float64_to_duration_t('s.mean_plan_time'), + 'stddev_plan_time': float64_to_duration_t('s.stddev_plan_time'), + + 'calls': 's.calls', + 'total_exec_time': float64_to_duration_t('s.total_exec_time'), + 'min_exec_time': float64_to_duration_t('s.min_exec_time'), + 'max_exec_time': float64_to_duration_t('s.max_exec_time'), + 'mean_exec_time': float64_to_duration_t('s.mean_exec_time'), + 'stddev_exec_time': float64_to_duration_t('s.stddev_exec_time'), + + 'rows': 's.rows', + 'stats_since': 's.stats_since::edgedbt.timestamptz_t', + 'minmax_stats_since': 's.minmax_stats_since::edgedbt.timestamptz_t', + } + + query_stats_query = fr''' + SELECT + {format_fields(schema, QueryStats, query_stats_fields)} + FROM + edgedbext.edb_stat_statements AS s + INNER JOIN pg_database dat ON s.dbid = dat.oid + CROSS JOIN LATERAL ( + SELECT + edgedb_VER.shobj_metadata(dat.oid, 'pg_database') + AS description + ) AS d + CROSS JOIN 
LATERAL ( + SELECT {ql(json.dumps(type_mapping))}::jsonb AS mapping + ) AS t + WHERE + s.id IS NOT NULL + AND (d.description)->>'id' IS NOT NULL + AND (d.description)->>'tenant_id' + = edgedb_VER.get_backend_tenant_id() + AND t.mapping ? s.stmt_type::text + ''' + + objects = { + QueryStats: query_stats_query, + } + + views: list[dbops.View] = [] + for obj, query in objects.items(): + tabview = trampoline.VersionedView( + name=tabname(schema, obj), query=query) + views.append(tabview) + + return views + + def _make_json_caster( schema: s_schema.Schema, stype: s_types.Type, @@ -6278,20 +6584,6 @@ def _generate_sql_information_schema( ) ), ), - # TODO: Should we try to filter here, and fix up some stuff - # elsewhere, instead of overriding pg_get_constraintdef? - trampoline.VersionedView( - name=("edgedbsql", "pg_constraint"), - query=""" - SELECT - pc.*, - pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid - FROM pg_constraint pc - JOIN pg_namespace pn ON pc.connamespace = pn.oid - WHERE NOT (pn.nspname = 'edgedbpub' AND pc.conbin IS NOT NULL) - """ - ), - # pg_class that contains classes only for tables # This is needed so we can use it to filter pg_index to indexes only on # visible tables. 
@@ -6353,7 +6645,7 @@ def _generate_sql_information_schema( ), trampoline.VersionedView( name=("edgedbsql", "pg_index"), - query=""" + query=f""" SELECT pi.indexrelid, pi.indrelid, @@ -6363,7 +6655,7 @@ def _generate_sql_information_schema( WHEN COALESCE(is_id.t, FALSE) THEN TRUE ELSE pi.indisprimary END AS indisunique, - pi.indnullsnotdistinct, + {'pi.indnullsnotdistinct,' if backend_version.major >= 15 else ''} CASE WHEN COALESCE(is_id.t, FALSE) THEN TRUE ELSE pi.indisprimary @@ -6373,10 +6665,16 @@ def _generate_sql_information_schema( pi.indisclustered, pi.indisvalid, pi.indcheckxmin, - pi.indisready, + CASE + WHEN COALESCE(is_id.t, FALSE) THEN TRUE + ELSE FALSE -- override so pg_dump won't try to recreate them + END AS indisready, pi.indislive, pi.indisreplident, - pi.indkey, + CASE + WHEN COALESCE(is_id.t, FALSE) THEN ARRAY[1]::int2vector -- id: 1 + ELSE pi.indkey + END AS indkey, pi.indcollation, pi.indclass, pi.indoption, @@ -6606,13 +6904,10 @@ def _generate_sql_information_schema( pa.attrelid as pc_oid, pa.*, pa.tableoid, pa.xmin, pa.cmin, pa.xmax, pa.cmax, pa.ctid - FROM edgedb_VER."_SchemaPointer" sp + FROM edgedb_VER."_SchemaProperty" sp JOIN pg_class pc ON pc.relname = sp.id::TEXT JOIN pg_attribute pa ON pa.attrelid = pc.oid - -- needed for filtering out links - LEFT JOIN edgedb_VER."_SchemaLink" sl ON sl.id = sp.id - -- positions for special pointers JOIN ( VALUES ('source', 0), @@ -6620,8 +6915,7 @@ def _generate_sql_information_schema( ) spec(k, position) ON (spec.k = pa.attname) WHERE - sl.id IS NULL -- property (non-link) - AND sp.cardinality = 'Many' -- multi + sp.cardinality = 'Many' -- multi AND sp.expr IS NULL -- non-computed UNION ALL @@ -6733,6 +7027,164 @@ def _generate_sql_information_schema( WHERE FALSE """, ), + trampoline.VersionedView( + name=("edgedbsql", "pg_constraint"), + query=r""" + -- primary keys for: + -- - objects tables (that contains id) + -- - link tables (that contains source and target) + -- there exists a unique 
constraint for each of these + SELECT + pc.oid, + vt.table_name || '_pk' AS conname, + pc.connamespace, + 'p'::"char" AS contype, + pc.condeferrable, + pc.condeferred, + pc.convalidated, + pc.conrelid, + pc.contypid, + pc.conindid, + pc.conparentid, + NULL::oid AS confrelid, + NULL::"char" AS confupdtype, + NULL::"char" AS confdeltype, + NULL::"char" AS confmatchtype, + pc.conislocal, + pc.coninhcount, + pc.connoinherit, + CASE WHEN pa.attname = 'id' + THEN ARRAY[1]::int2[] -- id will always have attnum 1 + ELSE ARRAY[1, 2]::int2[] -- source and target + END AS conkey, + NULL::int2[] AS confkey, + NULL::oid[] AS conpfeqop, + NULL::oid[] AS conppeqop, + NULL::oid[] AS conffeqop, + NULL::int2[] AS confdelsetcols, + NULL::oid[] AS conexclop, + pc.conbin, + pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid + FROM pg_constraint pc + JOIN edgedbsql_VER.pg_class_tables pct ON pct.oid = pc.conrelid + JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pct.reltype + JOIN pg_attribute pa + ON (pa.attrelid = pct.oid + AND pa.attnum = ANY(conkey) + AND pa.attname IN ('id', 'source') + ) + WHERE contype = 'u' -- our ids and all links will have unique constraint + + UNION ALL + + -- foreign keys for object tables + SELECT + edgedbsql_VER.uuid_to_oid(sl.id) as oid, + vt.table_name || '_fk_' || sl.name AS conname, + edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, + 'f'::"char" AS contype, + FALSE AS condeferrable, + FALSE AS condeferred, + TRUE AS convalidated, + pc.oid AS conrelid, + 0::oid AS contypid, + 0::oid AS conindid, -- let's hope this is not needed + 0::oid AS conparentid, + pc_target.oid AS confrelid, + 'a'::"char" AS confupdtype, + 'a'::"char" AS confdeltype, + 's'::"char" AS confmatchtype, + TRUE AS conislocal, + 0::int2 AS coninhcount, + TRUE AS connoinherit, + ARRAY[pa.attnum]::int2[] AS conkey, + ARRAY[1]::int2[] AS confkey, -- id will always have attnum 1 + ARRAY['uuid_eq'::regproc]::oid[] AS conpfeqop, + ARRAY['uuid_eq'::regproc]::oid[] AS 
conppeqop, + ARRAY['uuid_eq'::regproc]::oid[] AS conffeqop, + NULL::int2[] AS confdelsetcols, + NULL::oid[] AS conexclop, + NULL::pg_node_tree AS conbin, + pa.tableoid, pa.xmin, pa.cmin, pa.xmax, pa.cmax, pa.ctid + FROM edgedbsql_VER.virtual_tables vt + JOIN pg_class pc ON pc.reltype = vt.pg_type_id + JOIN edgedb_VER."_SchemaLink" sl + ON sl.source = vt.id -- AND COALESCE(sl.cardinality = 'One', TRUE) + JOIN edgedbsql_VER.virtual_tables vt_target + ON sl.target = vt_target.id + JOIN pg_class pc_target ON pc_target.reltype = vt_target.pg_type_id + JOIN edgedbsql_VER.pg_attribute pa + ON pa.attrelid = pc.oid + AND pa.attname = sl.name || '_id' + + UNION ALL + + -- foreign keys for: + -- - multi link tables (source & target), + -- - multi property tables (source), + -- - single link with link properties (source & target), + -- these constraints do not actually exist, so we emulate it entierly + SELECT + edgedbsql_VER.uuid_to_oid(sp.id) AS oid, + vt.table_name || '_fk_' || spec.name AS conname, + edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, + 'f'::"char" AS contype, + FALSE AS condeferrable, + FALSE AS condeferred, + TRUE AS convalidated, + pc.oid AS conrelid, + pc.reltype AS contypid, + 0::oid AS conindid, -- TODO + 0::oid AS conparentid, + pcf.oid AS confrelid, + 'r'::"char" AS confupdtype, + 'r'::"char" AS confdeltype, + 's'::"char" AS confmatchtype, + TRUE AS conislocal, + 0::int2 AS coninhcount, + TRUE AS connoinherit, + ARRAY[spec.attnum]::int2[] AS conkey, + ARRAY[1]::int2[] AS confkey, -- id will have attnum 1 + ARRAY['uuid_eq'::regproc]::oid[] AS conpfeqop, + ARRAY['uuid_eq'::regproc]::oid[] AS conppeqop, + ARRAY['uuid_eq'::regproc]::oid[] AS conffeqop, + NULL::int2[] AS confdelsetcols, + NULL::oid[] AS conexclop, + pc.relpartbound AS conbin, + pc.tableoid, + pc.xmin, + pc.cmin, + pc.xmax, + pc.cmax, + pc.ctid + FROM edgedb_VER."_SchemaPointer" sp + + -- find links with link properties + LEFT JOIN LATERAL ( + SELECT sl.id + FROM 
edgedb_VER."_SchemaLink" sl + LEFT JOIN edgedb_VER."_SchemaProperty" AS slp ON slp.source = sl.id + GROUP BY sl.id + HAVING COUNT(*) > 2 + ) link_props ON link_props.id = sp.id + + JOIN pg_class pc ON pc.relname = sp.id::TEXT + JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pc.reltype + + -- duplicate each row for source and target + JOIN LATERAL (VALUES + ('source', 1::int2, sp.source), + ('target', 2::int2, sp.target) + ) spec(name, attnum, foreign_id) ON TRUE + JOIN edgedbsql_VER.virtual_tables vtf ON vtf.id = spec.foreign_id + JOIN pg_class pcf ON pcf.reltype = vtf.pg_type_id + + WHERE + sp.cardinality = 'Many' OR link_props.id IS NOT NULL + AND sp.computable IS NOT TRUE + AND sp.internal IS NOT TRUE + """ + ), trampoline.VersionedView( name=("edgedbsql", "pg_statistic"), query=""" @@ -6909,6 +7361,18 @@ def _generate_sql_information_schema( WHERE c.relkind = 'v'::"char" """, ), + # Omit all descriptions (comments), becase all non-system comments + # are our internal implementation details. 
+ trampoline.VersionedView( + name=("edgedbsql", "pg_description"), + query=""" + SELECT + *, + tableoid, xmin, cmin, xmax, cmax, ctid + FROM pg_description + WHERE FALSE + """, + ), ] # We expose most of the views as empty tables, just to prevent errors when @@ -6957,6 +7421,7 @@ def _generate_sql_information_schema( 'pg_subscription', 'pg_tables', 'pg_views', + 'pg_description', } PG_TABLES_WITH_SYSTEM_COLS = { @@ -6977,7 +7442,6 @@ def _generate_sql_information_schema( 'pg_db_role_setting', 'pg_default_acl', 'pg_depend', - 'pg_description', 'pg_enum', 'pg_event_trigger', 'pg_extension', @@ -7268,7 +7732,51 @@ def construct_pg_view( WHERE t.oid = typeoid ''', - ) + ), + trampoline.VersionedFunction( + name=("edgedbsql", "pg_get_constraintdef"), + args=[ + ('conid', ('oid',)), + ], + returns=('text',), + volatility='stable', + text=r""" + SELECT CASE + WHEN contype = 'p' THEN + 'PRIMARY KEY(' || ( + SELECT string_agg('"' || attname || '"', ', ') + FROM edgedbsql_VER.pg_attribute + WHERE attrelid = conrelid AND attnum = ANY(conkey) + ) || ')' + WHEN contype = 'f' THEN + 'FOREIGN KEY ("' || ( + SELECT attname + FROM edgedbsql_VER.pg_attribute + WHERE attrelid = conrelid AND attnum = ANY(conkey) + LIMIT 1 + ) || '")' || ' REFERENCES "' + || pn.nspname || '"."' || pc.relname || '"(id)' + ELSE '' + END + FROM edgedbsql_VER.pg_constraint con + LEFT JOIN edgedbsql_VER.pg_class_tables pc ON pc.oid = confrelid + LEFT JOIN edgedbsql_VER.pg_namespace pn + ON pc.relnamespace = pn.oid + WHERE con.oid = conid + """ + ), + trampoline.VersionedFunction( + name=("edgedbsql", "pg_get_constraintdef"), + args=[ + ('conid', ('oid',)), + ('pretty', ('bool',)), + ], + returns=('text',), + volatility='stable', + text=r""" + SELECT pg_get_constraintdef(conid) + """ + ), ] return ( @@ -7400,6 +7908,15 @@ def get_synthetic_type_views( for verview in _generate_schema_ver_views(schema): commands.add_command(dbops.CreateView(verview, or_replace=True)) + if backend_params.has_stat_statements: + 
for stats_view in _generate_stats_views(schema): + commands.add_command(dbops.CreateView(stats_view, or_replace=True)) + commands.add_command( + dbops.CreateFunction( + ResetQueryStatsFunction(True), or_replace=True + ) + ) + return commands diff --git a/edb/pgsql/params.py b/edb/pgsql/params.py index a4df60054a4..7d9d64097fc 100644 --- a/edb/pgsql/params.py +++ b/edb/pgsql/params.py @@ -40,6 +40,8 @@ class BackendCapabilities(enum.IntFlag): CREATE_ROLE = 1 << 3 #: Whether CREATE DATABASE is allowed CREATE_DATABASE = 1 << 4 + #: Whether extension "edb_stat_statements" is available + STAT_STATEMENTS = 1 << 5 ALL_BACKEND_CAPABILITIES = ( @@ -48,6 +50,7 @@ class BackendCapabilities(enum.IntFlag): | BackendCapabilities.C_UTF8_LOCALE | BackendCapabilities.CREATE_ROLE | BackendCapabilities.CREATE_DATABASE + | BackendCapabilities.STAT_STATEMENTS ) @@ -111,6 +114,13 @@ def has_create_database(self) -> bool: & BackendCapabilities.CREATE_DATABASE ) + @property + def has_stat_statements(self) -> bool: + return self.has_superuser_access and bool( + self.instance_params.capabilities + & BackendCapabilities.STAT_STATEMENTS + ) + @functools.lru_cache def get_default_runtime_params( diff --git a/edb/pgsql/parser/.gitignore b/edb/pgsql/parser/.gitignore index 064a8d8ef55..b5a455aa984 100644 --- a/edb/pgsql/parser/.gitignore +++ b/edb/pgsql/parser/.gitignore @@ -1 +1 @@ -*.c +/*.c diff --git a/edb/pgsql/parser/__init__.py b/edb/pgsql/parser/__init__.py index db3565bbdd0..063e51f2c6c 100644 --- a/edb/pgsql/parser/__init__.py +++ b/edb/pgsql/parser/__init__.py @@ -16,19 +16,40 @@ # limitations under the License. # -from typing import List +from __future__ import annotations + +from typing import ( + List, +) import json from edb.pgsql import ast as pgast -from .parser import pg_parse -from .ast_builder import build_stmts +from . import ast_builder +from . 
import parser +from .parser import ( + Source, + NormalizedSource, + deserialize, +) + + +__all__ = ( + "parse", + "Source", + "NormalizedSource", + "deserialize" +) def parse( sql_query: str, propagate_spans: bool = False ) -> List[pgast.Query | pgast.Statement]: - ast_json = pg_parse(bytes(sql_query, encoding="UTF8")) + ast_json = parser.pg_parse(bytes(sql_query, encoding="UTF8")) - return build_stmts(json.loads(ast_json), sql_query, propagate_spans) + return ast_builder.build_stmts( + json.loads(ast_json), + sql_query, + propagate_spans, + ) diff --git a/edb/pgsql/parser/ast_builder.py b/edb/pgsql/parser/ast_builder.py index ea4a23d3826..1b62b424cc4 100644 --- a/edb/pgsql/parser/ast_builder.py +++ b/edb/pgsql/parser/ast_builder.py @@ -36,7 +36,7 @@ from edb.pgsql import ast as pgast from edb.edgeql import ast as qlast -from edb.pgsql.parser.exceptions import PSqlUnsupportedError +from edb.pgsql.parser.exceptions import PSqlUnsupportedError, get_node_name @dataclasses.dataclass(kw_only=True) @@ -134,7 +134,9 @@ def _enum( if outer_fallback: return None # type: ignore - raise PSqlUnsupportedError(node, ", ".join(node.keys())) + raise PSqlUnsupportedError( + node, ", ".join(get_node_name(k) for k in node.keys()) + ) finally: ctx.has_fallback = outer_fallback @@ -905,7 +907,7 @@ def _build_range_function(n: Node, c: Context) -> pgast.RangeFunction: with_ordinality=_bool_or_false(n, "ordinality"), is_rowsfrom=_bool_or_false(n, "is_rowsfrom"), functions=[ - _build_func_call(fn, c) + _build_base_expr(fn, c) for fn in n["functions"][0]["List"]["items"] if len(fn) > 0 ], diff --git a/edb/pgsql/parser/exceptions.py b/edb/pgsql/parser/exceptions.py index 207f485f577..82053a8f364 100644 --- a/edb/pgsql/parser/exceptions.py +++ b/edb/pgsql/parser/exceptions.py @@ -16,7 +16,7 @@ # limitations under the License. 
# - +import re from typing import Any, Optional @@ -34,9 +34,20 @@ class PSqlUnsupportedError(Exception): def __init__(self, node: Optional[Any] = None, feat: Optional[str] = None): self.node = node self.location = None - self.message = "unsupported SQL feature" + self.message = "not supported" if feat: - self.message += f" `{feat}`" + self.message += f": {feat}" def __str__(self): return self.message + + +def get_node_name(name: str) -> str: + """ + Given a node name (CreateTableStmt), this function tries to guess the SQL + command text (CREATE TABLE). + """ + + name = name.removesuffix('Stmt').removesuffix('Expr') + name = re.sub(r'(? str: @@ -49,3 +98,282 @@ def pg_parse(query) -> str: result_utf8 = result.parse_tree.decode('utf8') pg_query_free_parse_result(result) return result_utf8 + + +class LiteralTokenType(enum.StrEnum): + FCONST = "FCONST" + SCONST = "SCONST" + BCONST = "BCONST" + XCONST = "XCONST" + ICONST = "ICONST" + TRUE_P = "TRUE_P" + FALSE_P = "FALSE_P" + + +class PgLiteralTypeOID(enum.IntEnum): + BOOL = 16 + INT4 = 23 + TEXT = 25 + VARBIT = 1562 + NUMERIC = 1700 + + +class NormalizedQuery(NamedTuple): + text: str + highest_extern_param_id: int + extracted_constants: list[tuple[int, LiteralTokenType, bytes]] + + +def pg_normalize(query: str) -> NormalizedQuery: + cdef: + PgQueryNormalizeResult result + PgQueryNormalizeConstLocation loc + const ProtobufCEnumValue *token + int i + bytes queryb + bytes const + + queryb = query.encode("utf-8") + result = pg_query_normalize(queryb) + + try: + if result.error: + error = PSqlParseError( + result.error.message.decode('utf8'), + result.error.lineno, result.error.cursorpos + ) + raise error + + normalized_query = result.normalized_query.decode('utf8') + consts = [] + for i in range(result.clocations_count): + loc = result.clocations[i] + if loc.length != -1: + if loc.param_id < 0: + # Negative param_id means *relative* to highest explicit + # param id (after taking the absolute value). 
+ param_id = ( + abs(loc.param_id) + + result.highest_extern_param_id + ) + else: + # Otherwise it's the absolute param id. + param_id = loc.param_id + if loc.val != NULL: + token = protobuf_c_enum_descriptor_get_value( + &pg_query__token__descriptor, loc.token) + if token == NULL: + raise RuntimeError( + f"could not lookup pg_query enum descriptor " + f"for token value {loc.token}" + ) + consts.append(( + param_id, + LiteralTokenType(bytes(token.name).decode("ascii")), + bytes(loc.val), + )) + + return NormalizedQuery( + text=normalized_query, + highest_extern_param_id=result.highest_extern_param_id, + extracted_constants=consts, + ) + finally: + pg_query_free_normalize_result(result) + + +cdef ReadBuffer _init_deserializer(serialized: bytes, tag: uint8_t, cls: str): + cdef ReadBuffer buf + + buf = ReadBuffer.new_message_parser(serialized) + + if buf.read_byte() != tag: + raise ValueError(f"malformed {cls} serialization") + + return buf + + +cdef class Source: + def __init__( + self, + text: str, + serialized: Optional[bytes] = None, + ) -> None: + self._text = text + if serialized is not None: + self._serialized = serialized + else: + self._serialized = b'' + self._cache_key = b'' + + @classmethod + def _tag(self) -> int: + return 0 + + cdef WriteBuffer _serialize(self): + cdef WriteBuffer buf = WriteBuffer.new() + buf.write_byte(self._tag()) + buf.write_len_prefixed_utf8(self._text) + return buf + + def serialize(self) -> bytes: + if not self._serialized: + self._serialized = bytes(self._serialize()) + return self._serialized + + @classmethod + def from_serialized(cls, serialized: bytes) -> NormalizedSource: + cdef ReadBuffer buf + + buf = _init_deserializer(serialized, cls._tag(), cls.__name__) + text = buf.read_len_prefixed_utf8() + + return Source(text, serialized) + + def text(self) -> str: + return self._text + + def cache_key(self) -> bytes: + if not self._cache_key: + self._cache_key = hashlib.blake2b(self.serialize()).digest() + return self._cache_key + 
+ def variables(self) -> dict[str, Any]: + return {} + + def first_extra(self) -> Optional[int]: + return None + + def extra_counts(self) -> Sequence[int]: + return [] + + def extra_blobs(self) -> Sequence[bytes]: + return () + + def extra_formatted_as_text(self) -> bool: + return True + + def extra_type_oids(self) -> Sequence[int]: + return () + + @classmethod + def from_string(cls, text: str) -> Source: + return Source(text) + + +cdef class NormalizedSource(Source): + def __init__( + self, + normalized: NormalizedQuery, + orig_text: str, + serialized: Optional[bytes] = None, + ) -> None: + super().__init__(text=normalized.text, serialized=serialized) + self._extracted_constants = list( + sorted(normalized.extracted_constants, key=lambda i: i[0]), + ) + self._highest_extern_param_id = normalized.highest_extern_param_id + self._orig_text = orig_text + + @classmethod + def _tag(cls) -> int: + return 1 + + cdef WriteBuffer _serialize(self): + cdef WriteBuffer buf + + buf = Source._serialize(self) + buf.write_len_prefixed_utf8(self._orig_text) + buf.write_int32(self._highest_extern_param_id) + buf.write_int32(len(self._extracted_constants)) + for param_id, token, val in self._extracted_constants: + buf.write_int32(param_id) + buf.write_len_prefixed_utf8(token.value) + buf.write_len_prefixed_bytes(val) + + return buf + + def variables(self) -> dict[str, bytes]: + return {f"${n}": v[1] for n, _, v in self._extracted_constants} + + def first_extra(self) -> Optional[int]: + return ( + self._highest_extern_param_id + if self._extracted_constants + else None + ) + + def extra_counts(self) -> Sequence[int]: + return [len(self._extracted_constants)] + + def extra_blobs(self) -> list[bytes]: + cdef WriteBuffer buf + buf = WriteBuffer.new() + for _, _, v in self._extracted_constants: + buf.write_len_prefixed_bytes(v) + + return [bytes(buf)] + + def extra_type_oids(self) -> Sequence[int]: + oids = [] + for _, token, _ in self._extracted_constants: + if token is 
LiteralTokenType.FCONST: + oids.append(PgLiteralTypeOID.NUMERIC) + elif token is LiteralTokenType.ICONST: + oids.append(PgLiteralTypeOID.INT4) + elif ( + token is LiteralTokenType.FALSE_P + or token is LiteralTokenType.TRUE_P + ): + oids.append(PgLiteralTypeOID.BOOL) + elif token is LiteralTokenType.SCONST: + oids.append(PgLiteralTypeOID.TEXT) + elif ( + token is LiteralTokenType.XCONST + or token is LiteralTokenType.BCONST + ): + oids.append(PgLiteralTypeOID.VARBIT) + else: + raise AssertionError(f"unexpected literal token type: {token}") + + return oids + + @classmethod + def from_string(cls, text: str) -> NormalizedSource: + normalized = pg_normalize(text) + return NormalizedSource(normalized, text) + + @classmethod + def from_serialized(cls, serialized: bytes) -> NormalizedSource: + cdef ReadBuffer buf + + buf = _init_deserializer(serialized, cls._tag(), cls.__name__) + text = buf.read_len_prefixed_utf8() + orig_text = buf.read_len_prefixed_utf8() + highest_extern_param_id = buf.read_int32() + n_constants = buf.read_int32() + consts = [] + for _ in range(n_constants): + param_id = buf.read_int32() + token = buf.read_len_prefixed_utf8() + val = buf.read_len_prefixed_bytes() + consts.append((param_id, LiteralTokenType(token), val)) + + return NormalizedSource( + NormalizedQuery( + text=text, + highest_extern_param_id=highest_extern_param_id, + extracted_constants=consts, + ), + orig_text, + serialized, + ) + + +def deserialize(serialized: bytes) -> Source: + if serialized[0] == 0: + return Source.from_serialized(serialized) + elif serialized[0] == 1: + return NormalizedSource.from_serialized(serialized) + + raise ValueError(f"Invalid type/version byte: {serialized[0]}") diff --git a/edb/pgsql/resolver/__init__.py b/edb/pgsql/resolver/__init__.py index ddf37165c15..1dc3c2e8aa1 100644 --- a/edb/pgsql/resolver/__init__.py +++ b/edb/pgsql/resolver/__init__.py @@ -18,6 +18,8 @@ from __future__ import annotations from typing import Optional, List + +import copy import 
dataclasses from edb.common import debug @@ -41,6 +43,10 @@ class ResolvedSQL: # AST representing the query that can be sent to PostgreSQL ast: pgast.Base + # Optionally, AST representing the query returning data in EdgeQL + # format (i.e. single-column output). + edgeql_output_format_ast: Optional[pgast.Base] + # Special behavior for "tag" of "CommandComplete" message of this query. command_complete_tag: Optional[dbstate.CommandCompleteTag] @@ -56,6 +62,7 @@ def resolve( if debug.flags.sql_input: debug.header('SQL Input') + debug_sql_text = pgcodegen.generate_source( query, reordered=True, pretty=True ) @@ -108,8 +115,25 @@ def resolve( ) debug.dump_code(debug_sql_text, lexer='sql') + if options.include_edgeql_io_format_alternative: + edgeql_output_format_ast = copy.copy(resolved) + if isinstance(edgeql_output_format_ast, pgast.SelectStmt): + edgeql_output_format_ast.target_list = [ + pgast.ResTarget( + val=pgast.RowExpr( + args=[ + rt.val + for rt in edgeql_output_format_ast.target_list + ] + ) + ) + ] + else: + edgeql_output_format_ast = None + return ResolvedSQL( ast=resolved, + edgeql_output_format_ast=edgeql_output_format_ast, command_complete_tag=command_complete_tag, params=ctx.query_params, ) diff --git a/edb/pgsql/resolver/command.py b/edb/pgsql/resolver/command.py index f1e9a9d427a..e4c6fe1d1a9 100644 --- a/edb/pgsql/resolver/command.py +++ b/edb/pgsql/resolver/command.py @@ -220,11 +220,7 @@ def _uncompile_dml_stmt(stmt: pgast.DMLQuery, *, ctx: Context): - ptr-s are (usually) pointers on the subject. 
""" - raise errors.QueryError( - f'{stmt.__class__.__name__} are not supported', - span=stmt.span, - pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, - ) + raise dispatch._raise_unsupported(stmt) def _uncompile_dml_subject( @@ -1978,8 +1974,11 @@ def _resolve_returning_rows( ) returning_table = context.Table() + names: Set[str] = set() for t in returning_list: - targets, columns = pg_res_expr.resolve_ResTarget(t, ctx=sctx) + targets, columns = pg_res_expr.resolve_ResTarget( + t, existing_names=names, ctx=sctx + ) returning_query.target_list.extend(targets) returning_table.columns.extend(columns) return returning_query, returning_table diff --git a/edb/pgsql/resolver/context.py b/edb/pgsql/resolver/context.py index 592cee58cf9..c228a57132f 100644 --- a/edb/pgsql/resolver/context.py +++ b/edb/pgsql/resolver/context.py @@ -52,6 +52,12 @@ class Options: # apply access policies to select & dml statements apply_access_policies: bool + # whether to generate an EdgeQL-compatible single-column output variant. + include_edgeql_io_format_alternative: Optional[bool] + + # makes sure that output does not contain duplicated column names + disambiguate_column_names: bool + @dataclass(kw_only=True) class Scope: diff --git a/edb/pgsql/resolver/dispatch.py b/edb/pgsql/resolver/dispatch.py index 8ca24edbc18..d3ec583fcf5 100644 --- a/edb/pgsql/resolver/dispatch.py +++ b/edb/pgsql/resolver/dispatch.py @@ -21,8 +21,11 @@ import functools import typing +import re +from edb.server.pgcon import errors as pgerror from edb.pgsql import ast as pgast +from edb import errors from . 
import context @@ -34,7 +37,8 @@ def _resolve( expr: pgast.Base, *, ctx: context.ResolverContextLevel ) -> pgast.Base: - raise ValueError(f'no SQL resolve handler for {expr.__class__}') + expr.dump() + _raise_unsupported(expr) def resolve(expr: Base_T, *, ctx: context.ResolverContextLevel) -> Base_T: @@ -85,7 +89,7 @@ def _resolve_relation( include_inherited: bool, ctx: context.ResolverContextLevel, ) -> typing.Tuple[pgast.BaseRelation, context.Table]: - raise ValueError(f'no SQL resolve handler for {rel.__class__}') + _raise_unsupported(rel) @_resolve.register @@ -96,3 +100,16 @@ def _resolve_BaseRelation( rel, _ = resolve_relation(rel, ctx=ctx) return rel + + +def _raise_unsupported(expr: pgast.Base) -> typing.Never: + pretty_name = expr.__class__.__name__ + pretty_name = pretty_name.removesuffix('Stmt') + # title case to spaces + pretty_name = re.sub(r'(? Optional[str]: def resolve_ResTarget( res_target: pgast.ResTarget, *, - existing_names: Optional[Set[str]] = None, + existing_names: Set[str], + ctx: Context, +) -> Tuple[Sequence[pgast.ResTarget], Sequence[context.Column]]: + targets, columns = _resolve_ResTarget( + res_target, existing_names=existing_names, ctx=ctx + ) + + return (targets, columns) + + +def _resolve_ResTarget( + res_target: pgast.ResTarget, + *, + existing_names: Set[str], ctx: Context, ) -> Tuple[Sequence[pgast.ResTarget], Sequence[context.Column]]: alias = infer_alias(res_target) @@ -86,11 +100,38 @@ def resolve_ResTarget( res = [] columns = [] for table, column in col_res: - columns.append(column) + val = resolve_column_kind(table, column.kind, ctx=ctx) + + # make sure name is not duplicated + # this behavior is technically different then Postgres, but EdgeDB + # protocol does not support duplicate names. And we doubt that + # anyone is depending on original behavior. 
+ nam: str = column.name + if nam in existing_names: + # prefix with table name + rel_var_name = table.alias or table.name + if rel_var_name: + nam = rel_var_name + '_' + nam + if nam in existing_names: + if ctx.options.disambiguate_column_names: + raise errors.QueryError( + f'duplicate column name: `{nam}`', + span=res_target.span, + pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE, + ) + existing_names.add(nam) + res.append( pgast.ResTarget( - val=resolve_column_kind(table, column.kind, ctx=ctx), - name=column.name, + name=nam, + val=val, + ) + ) + columns.append( + context.Column( + name=nam, + hidden=column.hidden, + kind=column.kind, ) ) return (res, columns) @@ -106,14 +147,29 @@ def resolve_ResTarget( ): alias = static.name_in_pg_catalog(res_target.val.name) - if not res_target.name and existing_names and alias in existing_names: - # when a name already exists, don't infer the same name - # this behavior is technically different than Postgres, but it is also - # not documented and users should not be relying on it. - # It does help us in some cases (passing `SELECT a.id, b.id` into DML). - alias = None + if alias in existing_names: + # duplicate name + + if res_target.name: + # explicit duplicate name: error out + if ctx.options.disambiguate_column_names: + raise errors.QueryError( + f'duplicate column name: `{alias}`', + span=res_target.span, + pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE, + ) + else: + # inferred duplicate name: use generated alias instead + + # this behavior is technically different than Postgres, but it is + # also not documented and users should not be relying on it. + # It does help us in some cases + # (passing `SELECT a.id, b.id` into DML). 
+ alias = None name: str = alias or ctx.alias_generator.get('col') + existing_names.add(name) + col = context.Column( name=name, kind=context.ColumnByName(reference_as=name) ) @@ -260,7 +316,9 @@ def _lookup_column( if not matched_columns: raise errors.QueryError( - f'cannot find column `{col_name}`', span=column_ref.span + f'cannot find column `{col_name}`', + span=column_ref.span, + pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE, ) # apply precedence @@ -454,6 +512,14 @@ def resolve_SortBy( common.versioned_schema('edgedbsql'), '_format_type', ), + ('pg_catalog', 'pg_get_constraintdef'): ( + common.versioned_schema('edgedbsql'), + 'pg_get_constraintdef', + ), + ('pg_get_constraintdef',): ( + common.versioned_schema('edgedbsql'), + 'pg_get_constraintdef', + ), } diff --git a/edb/pgsql/resolver/range_var.py b/edb/pgsql/resolver/range_var.py index 3ac53c09bda..63b1eacd4a8 100644 --- a/edb/pgsql/resolver/range_var.py +++ b/edb/pgsql/resolver/range_var.py @@ -298,19 +298,21 @@ def _resolve_RangeFunction( ) -> Tuple[pgast.BaseRangeVar, context.Table]: with ctx.lateral() if range_var.lateral else ctx.child() as subctx: - functions = [] + functions: List[pgast.BaseExpr] = [] col_names = [] for function in range_var.functions: - - name = function.name[len(function.name) - 1] - if name in range_functions.COLUMNS: - col_names.extend(range_functions.COLUMNS[name]) - elif name == 'unnest': - col_names.extend('unnest' for _ in function.args) - else: - col_names.append(name) - - functions.append(dispatch.resolve(function, ctx=subctx)) + match function: + case pgast.FuncCall(): + name = function.name[len(function.name) - 1] + if name in range_functions.COLUMNS: + col_names.extend(range_functions.COLUMNS[name]) + elif name == 'unnest': + col_names.extend('unnest' for _ in function.args) + else: + col_names.append(name) + functions.append(dispatch.resolve(function, ctx=subctx)) + case _: + functions.append(dispatch.resolve(function, ctx=subctx)) inferred_columns = [ 
context.Column( diff --git a/edb/pgsql/resolver/relation.py b/edb/pgsql/resolver/relation.py index b0249191e81..d40b843b170 100644 --- a/edb/pgsql/resolver/relation.py +++ b/edb/pgsql/resolver/relation.py @@ -133,6 +133,7 @@ def resolve_SelectStmt( targets, columns = expr.resolve_ResTarget( t, existing_names=names, ctx=ctx ) + target_list.extend(targets) table.columns.extend(columns) names.update(c.name for c in columns) diff --git a/edb/pgsql/resolver/static.py b/edb/pgsql/resolver/static.py index e221bc52c0f..ebfe711951e 100644 --- a/edb/pgsql/resolver/static.py +++ b/edb/pgsql/resolver/static.py @@ -312,7 +312,9 @@ def eval_FuncCall( return value raise errors.QueryError( - "function set_config is not supported", span=expr.span + "function set_config is not supported", + span=expr.span, + pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) if fn_name == 'current_setting': @@ -329,6 +331,7 @@ def eval_FuncCall( raise errors.QueryError( f"function pg_catalog.{fn_name} is not supported", span=expr.span, + pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) if fn_name == "pg_get_serial_sequence": diff --git a/edb/pgsql/types.py b/edb/pgsql/types.py index 9a2319c2e1a..f49033cf67c 100644 --- a/edb/pgsql/types.py +++ b/edb/pgsql/types.py @@ -68,6 +68,12 @@ ('edgedbt', 'date_duration_t'), s_obj.get_known_type_id('cfg::memory'): ('edgedbt', 'memory_t'), + + s_obj.get_known_type_id('std::pg::json'): ('json',), + s_obj.get_known_type_id('std::pg::timestamptz'): ('timestamptz',), + s_obj.get_known_type_id('std::pg::timestamp'): ('timestamp',), + s_obj.get_known_type_id('std::pg::date'): ('date',), + s_obj.get_known_type_id('std::pg::interval'): ('interval',), } type_to_range_name_map = { @@ -85,6 +91,9 @@ # custom range is a big hassle, and daterange already has the # correct canonicalization function ('edgedbt', 'date_t'): ('daterange',), + ('timestamptz',): ('tstzrange',), + ('timestamp',): ('tsrange',), + ('date',): ('daterange',), } # Construct a multirange map based on 
type_to_range_name_map by replacing @@ -143,6 +152,8 @@ 'edgedbt.memory_t': sn.QualName('cfg', 'memory'), 'memory_t': sn.QualName('cfg', 'memory'), + + 'json': sn.QualName('std::pg', 'json'), } pg_tsvector_typeref = irast.TypeRef( diff --git a/edb/protocol/messages.py b/edb/protocol/messages.py index 784fc2e7165..5985e5f02c5 100644 --- a/edb/protocol/messages.py +++ b/edb/protocol/messages.py @@ -488,6 +488,12 @@ def dump(self) -> bytes: ############################################################################### +class InputLanguage(enum.Enum): + + EDGEQL = 0x45 # b'E' + SQL = 0x53 # b'S' + + class OutputFormat(enum.Enum): BINARY = 0x62 @@ -789,6 +795,7 @@ class Parse(ClientMessage): compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') + input_language = EnumOf(UInt8, InputLanguage, 'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') @@ -807,6 +814,7 @@ class Execute(ClientMessage): compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') + input_language = EnumOf(UInt8, InputLanguage, 'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') diff --git a/edb/protocol/protocol.pyx b/edb/protocol/protocol.pyx index b91da984b0d..866e240e558 100644 --- a/edb/protocol/protocol.pyx +++ b/edb/protocol/protocol.pyx @@ -63,6 +63,7 @@ cdef class Connection: messages.Execute( annotations=[], command_text=query, + input_language=messages.InputLanguage.EDGEQL, output_format=messages.OutputFormat.NONE, expected_cardinality=messages.Cardinality.MANY, allowed_capabilities=messages.Capability.ALL, @@ -173,6 +174,7 @@ async def new_connection( 
tls_ca=tls_ca, tls_ca_file=tls_ca_file, tls_security=tls_security, + tls_server_name=None, wait_until_available=timeout, credentials=credentials, credentials_file=credentials_file, diff --git a/edb/schema/_types.py b/edb/schema/_types.py index d96435e99bf..bf4af583475 100644 --- a/edb/schema/_types.py +++ b/edb/schema/_types.py @@ -66,4 +66,14 @@ UUID('00000000-0000-0000-0000-000000000112'), sn.name_from_string('cfg::memory'): UUID('00000000-0000-0000-0000-000000000130'), + sn.name_from_string('std::pg::json'): + UUID('00000000-0000-0000-0000-000001000001'), + sn.name_from_string('std::pg::timestamptz'): + UUID('00000000-0000-0000-0000-000001000002'), + sn.name_from_string('std::pg::timestamp'): + UUID('00000000-0000-0000-0000-000001000003'), + sn.name_from_string('std::pg::date'): + UUID('00000000-0000-0000-0000-000001000004'), + sn.name_from_string('std::pg::interval'): + UUID('00000000-0000-0000-0000-000001000005'), } diff --git a/edb/schema/delta.py b/edb/schema/delta.py index af0f87d2eeb..220515671f0 100644 --- a/edb/schema/delta.py +++ b/edb/schema/delta.py @@ -3758,7 +3758,7 @@ def _delete_finalize( orig_schema = ctx.original_schema if refs: for ref in refs: - if (not context.is_deleting(ref) + if (not self._is_deleting_ref(schema, context, ref) and ref.is_blocking_ref(orig_schema, self.scls)): ref_strs.append( ref.get_verbosename(orig_schema, with_parent=True)) @@ -3781,6 +3781,21 @@ def _delete_finalize( return schema + def _is_deleting_ref( + self, + schema: s_schema.Schema, + context: CommandContext, + ref: so.Object, + ) -> bool: + if context.is_deleting(ref): + return True + + for op in self.get_prerequisites(): + if isinstance(op, DeleteObject) and op.scls == ref: + return True + + return False + def _has_outside_references( self, schema: s_schema.Schema, diff --git a/edb/schema/links.py b/edb/schema/links.py index b455df2820c..2a825e5f9c9 100644 --- a/edb/schema/links.py +++ b/edb/schema/links.py @@ -415,7 +415,7 @@ def _apply_field_ast( 
op.new_value.resolve(schema) if isinstance(op.new_value, so.ObjectShell) else op.new_value) - + assert isinstance(new_type, s_types.Type) new_type_ast = utils.typeref_to_ast(schema, op.new_value) cast_expr = None # If the type isn't assignment castable, generate a diff --git a/edb/schema/objects.py b/edb/schema/objects.py index 8b5585ad32c..455af155156 100644 --- a/edb/schema/objects.py +++ b/edb/schema/objects.py @@ -2401,10 +2401,9 @@ def __reduce__(self) -> Tuple[ else: typeargs = types[0] if len(types) == 1 else types attrs = {k: getattr(self, k) for k in self.__slots__ if k != '_ids'} - # Mypy fails to resolve typeargs properly return ( cls.__restore__, - (typeargs, tuple(self._ids), attrs) # type: ignore + (typeargs, tuple(self._ids), attrs) ) @classmethod @@ -3222,7 +3221,7 @@ def get_explicit_local_field_value( def allow_ref_propagation( self, schema: s_schema.Schema, - constext: sd.CommandContext, + context: sd.CommandContext, refdict: RefDict, ) -> bool: return True diff --git a/edb/schema/objtypes.py b/edb/schema/objtypes.py index 61bb8f9cdc4..439f6c3ca18 100644 --- a/edb/schema/objtypes.py +++ b/edb/schema/objtypes.py @@ -324,7 +324,7 @@ def _issubclass( def allow_ref_propagation( self, schema: s_schema.Schema, - constext: sd.CommandContext, + context: sd.CommandContext, refdict: so.RefDict, ) -> bool: return not self.is_view(schema) or refdict.attr == 'pointers' @@ -333,9 +333,7 @@ def as_type_delete_if_unused( self, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[ObjectType]]: - if not schema.get_by_id(self.id, default=None): - # this type was already deleted by some other op - # (probably alias types cleanup) + if not self._is_deletable(schema): return None # References to aliases can only occur inside other aliases, diff --git a/edb/schema/pointers.py b/edb/schema/pointers.py index 5fb90f7111e..eb9056d0eac 100644 --- a/edb/schema/pointers.py +++ b/edb/schema/pointers.py @@ -857,7 +857,7 @@ def has_user_defined_properties(self, schema: 
s_schema.Schema) -> bool: def allow_ref_propagation( self, schema: s_schema.Schema, - constext: sd.CommandContext, + context: sd.CommandContext, refdict: so.RefDict, ) -> bool: object_type = self.get_source(schema) @@ -3257,4 +3257,37 @@ def get_or_create_intersection_pointer( transient=transient, ) + # We want to transform all the computables in the list of the + # components to their respective owned computables. This is to + # ensure that mixing multiple inherited copies of the same + # computable is actually allowed. + comp_set = set() + for c in components: + if c.is_pure_computable(schema): + comp_set.add(_get_nearest_owned(schema, c)) + else: + comp_set.add(c) + components = list(comp_set) + + if ( + any(p.is_pure_computable(schema) for p in components) + and len(components) > 1 + and ptrname.name not in ('__tname__', '__tid__') + ): + p = components[0] + raise errors.SchemaError( + f'it is illegal to create a type intersection that causes ' + f'a computed {p.get_verbosename(schema)} to mix ' + f'with other versions of the same {p.get_verbosename(schema)}', + ) + + if len({p.get_cardinality(schema) for p in components}) > 1: + p = components[0] + raise errors.SchemaError( + f'it is illegal to create a type intersection that causes ' + f'a {p.get_verbosename(schema)} to mix ' + f'with other versions of {p.get_verbosename(schema)} ' + f'which have a different cardinality', + ) + return schema, result diff --git a/edb/schema/properties.py b/edb/schema/properties.py index 5b1dd2a105c..f3cf6b90b0f 100644 --- a/edb/schema/properties.py +++ b/edb/schema/properties.py @@ -146,7 +146,7 @@ def is_link_property(self, schema: s_schema.Schema) -> bool: def allow_ref_propagation( self, schema: s_schema.Schema, - constext: sd.CommandContext, + context: sd.CommandContext, refdict: so.RefDict, ) -> bool: source = self.get_source(schema) diff --git a/edb/schema/referencing.py b/edb/schema/referencing.py index 1dfb6e87233..c65b251a15b 100644 --- a/edb/schema/referencing.py +++ 
b/edb/schema/referencing.py @@ -336,9 +336,9 @@ def derive_ref( cmdcls = sd.AlterObject if existing is not None else sd.CreateObject cmd: sd.ObjectCommand[ReferencedInheritingObjectT] = ( - sd.get_object_delta_command( + sd.get_object_delta_command( # type: ignore[type-var, assignment] objtype=type(self), - cmdtype=cmdcls, # type: ignore[arg-type] + cmdtype=cmdcls, schema=schema, name=derived_name, ) diff --git a/edb/schema/scalars.py b/edb/schema/scalars.py index 81d422c9923..ec6a622c8b5 100644 --- a/edb/schema/scalars.py +++ b/edb/schema/scalars.py @@ -595,6 +595,26 @@ def _cmd_tree_from_ast( return cmd + def _create_begin( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + schema = super()._create_begin(schema, context) + if ( + not context.canonical + and not self.scls.get_abstract(schema) + and not self.scls.get_transient(schema) + ): + # Create an array type for this scalar eagerly. + # We mostly do this so that we know the `backend_id` + # of the array type when running translation of SQL + # involving arrays of scalars. 
+ schema2, arr_t = s_types.Array.from_subtypes(schema, [self.scls]) + self.add_caused(arr_t.as_shell(schema2).as_create_delta(schema2)) + + return schema + def validate_create( self, schema: s_schema.Schema, @@ -819,3 +839,19 @@ def _get_ast( return None else: return super()._get_ast(schema, context, parent_node=parent_node) + + def _delete_begin( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + if not context.canonical: + schema2, arr_typ = s_types.Array.from_subtypes(schema, [self.scls]) + arr_op = arr_typ.init_delta_command( + schema2, + sd.DeleteObject, + if_exists=True, + ) + self.add_prerequisite(arr_op) + + return super()._delete_begin(schema, context) diff --git a/edb/schema/types.py b/edb/schema/types.py index 1c1abfd9a31..0ac5926e6c4 100644 --- a/edb/schema/types.py +++ b/edb/schema/types.py @@ -533,6 +533,14 @@ def as_type_delete_if_unused( return None + def _is_deletable( + self, + schema: s_schema.Schema, + ) -> bool: + # this type was already deleted by some other op + # (probably alias types cleanup) + return schema.get_by_id(self.id, default=None) is not None + class QualifiedType(so.QualifiedObject, Type): pass @@ -988,7 +996,14 @@ def as_delete_delta( assert isinstance(delta, sd.DeleteObject) if not isinstance(self, CollectionExprAlias): delta.if_exists = True - delta.if_unused = True + if not ( + isinstance(self, Array) + and self.get_element_type(schema).is_scalar() + ): + # Arrays of scalars are special, because we create them + # implicitly and overload reference checks to never + # delete them unless the scalar is also deleted. 
+ delta.if_unused = True return delta @classmethod @@ -1140,9 +1155,7 @@ def as_type_delete_if_unused( self: CollectionTypeT, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[CollectionTypeT]]: - if not schema.get_by_id(self.id, default=None): - # this type was already deleted by some other op - # (probably alias types cleanup) + if not self._is_deletable(schema): return None return self.init_delta_command( @@ -1202,9 +1215,7 @@ def as_type_delete_if_unused( self: CollectionExprAliasT, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[CollectionExprAliasT]]: - if not schema.get_by_id(self.id, default=None): - # this type was already deleted by some other op - # (probably alias types cleanup) + if not self._is_deletable(schema): return None cmd = self.init_delta_command(schema, sd.DeleteObject, if_exists=True) @@ -3365,7 +3376,21 @@ class DeleteTupleExprAlias(DeleteCollectionExprAlias[TupleExprAlias]): class DeleteArray(DeleteCollectionType[Array]): - pass + # Prevent array types from getting deleted unless the element + # type is being deleted too. 
+ def _has_outside_references( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> bool: + if super()._has_outside_references(schema, context): + return True + + el_type = self.scls.get_element_type(schema) + if el_type.is_scalar() and not context.is_deleting(el_type): + return True + + return False class DeleteArrayExprAlias(DeleteCollectionExprAlias[ArrayExprAlias]): diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index a2e94db3a33..153a225251f 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -173,7 +173,7 @@ async def _retry_conn_errors( return result - async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: + async def sql_execute(self, sql: bytes) -> None: async def _task() -> None: assert self._conn is not None await self._conn.sql_execute(sql) @@ -181,7 +181,7 @@ async def _task() -> None: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[tuple[bytes, ...]]: @@ -634,8 +634,8 @@ def compile_single_query( ) -> str: ql_source = edgeql.Source.from_string(eql) units = edbcompiler.compile(ctx=compilerctx, source=ql_source).units - assert len(units) == 1 and len(units[0].sql) == 1 - return units[0].sql[0].decode() + assert len(units) == 1 + return units[0].sql.decode() def _get_all_subcommands( @@ -687,7 +687,7 @@ def prepare_repair_patch( schema_class_layout: s_refl.SchemaClassLayout, backend_params: params.BackendRuntimeParams, config: Any, -) -> tuple[bytes, ...]: +) -> bytes: compiler = edbcompiler.new_compiler( std_schema=stdschema, reflection_schema=reflschema, @@ -701,7 +701,7 @@ def prepare_repair_patch( ) res = edbcompiler.repair_schema(compilerctx) if not res: - return () + return b"" sql, _, _ = res return sql @@ -2026,45 +2026,65 @@ def compile_sys_queries( # The code below re-syncs backend_id properties of Gel builtin # types with the actual OIDs in the DB. 
backend_id_fixup_edgeql = ''' - WITH - _ := ( - UPDATE {schema::ScalarType, schema::Tuple} - FILTER - NOT (.abstract ?? False) - AND NOT (.transient ?? False) - SET { - backend_id := sys::_get_pg_type_for_edgedb_type( - .id, - .__type__.name, - {}, - [is schema::ScalarType].sql_type ?? ( - select [is schema::ScalarType] - .bases[is schema::ScalarType] limit 1 - ).sql_type, - ) - } - ), - _ := ( - UPDATE {schema::Array, schema::Range, schema::MultiRange} - FILTER - NOT (.abstract ?? False) - AND NOT (.transient ?? False) - SET { - backend_id := sys::_get_pg_type_for_edgedb_type( - .id, - .__type__.name, - .element_type.id, - {}, - ) - } - ), - SELECT 1; + UPDATE schema::ScalarType + FILTER + NOT (.abstract ?? False) + AND NOT (.transient ?? False) + SET { + backend_id := sys::_get_pg_type_for_edgedb_type( + .id, + .__type__.name, + {}, + [is schema::ScalarType].sql_type ?? ( + select [is schema::ScalarType] + .bases[is schema::ScalarType] limit 1 + ).sql_type, + ) + }; + UPDATE schema::Tuple + FILTER + NOT (.abstract ?? False) + AND NOT (.transient ?? False) + SET { + backend_id := sys::_get_pg_type_for_edgedb_type( + .id, + .__type__.name, + {}, + [is schema::ScalarType].sql_type ?? ( + select [is schema::ScalarType] + .bases[is schema::ScalarType] limit 1 + ).sql_type, + ) + }; + UPDATE {schema::Range, schema::MultiRange} + FILTER + NOT (.abstract ?? False) + AND NOT (.transient ?? False) + SET { + backend_id := sys::_get_pg_type_for_edgedb_type( + .id, + .__type__.name, + .element_type.id, + {}, + ) + }; + UPDATE schema::Array + FILTER + NOT (.abstract ?? False) + AND NOT (.transient ?? 
False) + SET { + backend_id := sys::_get_pg_type_for_edgedb_type( + .id, + .__type__.name, + .element_type.id, + {}, + ) + }; ''' _, sql = compile_bootstrap_script( compiler, schema, backend_id_fixup_edgeql, - expected_cardinality_one=True, ) queries['backend_id_fixup'] = sql @@ -2091,10 +2111,10 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_2_0 = units[0].out_type_id + units[0].out_type_data - queries['report_configs'] = units[0].sql[0].decode() + queries['report_configs'] = units[0].sql.decode() units = edbcompiler.compile( ctx=edbcompiler.new_compiler_context( @@ -2108,7 +2128,7 @@ def compile_sys_queries( ), source=edgeql.Source.from_string(report_configs_query), ).units - assert len(units) == 1 and len(units[0].sql) == 1 + assert len(units) == 1 report_configs_typedesc_1_0 = units[0].out_type_id + units[0].out_type_data return ( diff --git a/edb/server/compiler/__init__.py b/edb/server/compiler/__init__.py index 77be86bae03..3931622c080 100644 --- a/edb/server/compiler/__init__.py +++ b/edb/server/compiler/__init__.py @@ -27,7 +27,7 @@ from .compiler import maybe_force_database_error from .dbstate import QueryUnit, QueryUnitGroup from .enums import Capability, Cardinality -from .enums import InputFormat, OutputFormat +from .enums import InputFormat, OutputFormat, InputLanguage from .explain import analyze_explain_output from .ddl import repair_schema from .rpc import CompilationRequest @@ -43,6 +43,7 @@ 'QueryUnitGroup', 'Capability', 'InputFormat', + 'InputLanguage', 'OutputFormat', 'analyze_explain_output', 'compile_edgeql_script', diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 4f922a5c5ce..b5bd9234341 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -60,6 +60,7 @@ from edb.common import uuidgen from edb.edgeql import ast as qlast +from 
edb.edgeql import codegen as qlcodegen from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes @@ -118,7 +119,7 @@ class CompilerDatabaseState: cached_reflection: immutables.Map[str, Tuple[str, ...]] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class CompileContext: compiler_state: CompilerState @@ -145,6 +146,8 @@ class CompileContext: log_ddl_as_migrations: bool = True dump_restore_mode: bool = False notebook: bool = False + branch_name: Optional[str] = None + role_name: Optional[str] = None cache_key: Optional[uuid.UUID] = None def get_cache_mode(self) -> config.QueryCacheMode: @@ -429,7 +432,7 @@ def _try_compile_rollback( sql = b'ROLLBACK;' unit = dbstate.QueryUnit( status=b'ROLLBACK', - sql=(sql,), + sql=sql, tx_rollback=True, cacheable=False) @@ -437,7 +440,7 @@ def _try_compile_rollback( sql = f'ROLLBACK TO {pg_common.quote_ident(stmt.name)};'.encode() unit = dbstate.QueryUnit( status=b'ROLLBACK TO SAVEPOINT', - sql=(sql,), + sql=sql, tx_savepoint_rollback=True, sp_name=stmt.name, cacheable=False) @@ -563,6 +566,9 @@ def compile_sql( current_user=current_user, allow_user_specified_id=allow_user_specified_id, apply_access_policies_sql=apply_access_policies_sql, + disambiguate_column_names=False, + backend_runtime_params=self.state.backend_runtime_params, + protocol_version=(-3, 0), # emulated PG binary protocol version ) def compile_serialized_request( @@ -583,7 +589,7 @@ def compile_serialized_request( self.state.compilation_config_serializer, ) - units, cstate = self.compile( + return self.compile( user_schema=user_schema, global_schema=global_schema, reflection_cache=reflection_cache, @@ -591,7 +597,6 @@ def compile_serialized_request( system_config=system_config, request=request, ) - return units, cstate def compile( self, @@ -641,10 +646,21 @@ def compile( json_parameters=request.input_format is enums.InputFormat.JSON, source=request.source, protocol_version=request.protocol_version, + 
role_name=request.role_name, + branch_name=request.branch_name, cache_key=request.get_cache_key(), ) - unit_group = compile(ctx=ctx, source=request.source) + match request.input_language: + case enums.InputLanguage.EDGEQL: + unit_group = compile(ctx=ctx, source=request.source) + case enums.InputLanguage.SQL: + unit_group = compile_sql_as_unit_group( + ctx=ctx, source=request.source) + case _: + raise NotImplementedError( + f"unnsupported input language: {request.input_language}") + tx_started = False for unit in unit_group: if unit.tx_id: @@ -727,8 +743,17 @@ def compile_in_tx( cache_key=request.get_cache_key(), ) - units = compile(ctx=ctx, source=request.source) - return units, ctx.state + match request.input_language: + case enums.InputLanguage.EDGEQL: + unit_group = compile(ctx=ctx, source=request.source) + case enums.InputLanguage.SQL: + unit_group = compile_sql_as_unit_group( + ctx=ctx, source=request.source) + case _: + raise NotImplementedError( + f"unnsupported input language: {request.input_language}") + + return unit_group, ctx.state def interpret_backend_error( self, @@ -905,6 +930,62 @@ def describe_database_dump( blocks=descriptors, ) + def _reprocess_restore_config( + self, + stmts: list[qlast.Base], + ) -> list[qlast.Base]: + '''Do any rewrites to the restore script needed. + + This is intended to patch over certain backwards incompatible + changes to config. We try not to do that too much, but when we + do, dumps still need to work. + ''' + + new_stmts = [] + smtp_config = {} + + for stmt in stmts: + # ext::auth::SMTPConfig got removed and moved into a cfg + # object, so intercept those and rewrite them. 
+ if ( + isinstance(stmt, qlast.ConfigSet) + and stmt.name.module == 'ext::auth::SMTPConfig' + ): + smtp_config[stmt.name.name] = stmt.expr + else: + new_stmts.append(stmt) + + if smtp_config: + # Do the rewrite of SMTPConfig + smtp_config['name'] = qlast.Constant.string('_default') + + new_stmts.append( + qlast.ConfigInsert( + scope=qltypes.ConfigScope.DATABASE, + name=qlast.ObjectRef( + module='cfg', name='SMTPProviderConfig' + ), + shape=[ + qlast.ShapeElement( + expr=qlast.Path(steps=[qlast.Ptr(name=name)]), + compexpr=expr, + ) + for name, expr in smtp_config.items() + ], + ) + ) + new_stmts.append( + qlast.ConfigSet( + scope=qltypes.ConfigScope.DATABASE, + name=qlast.ObjectRef( + name='current_email_provider_name' + ), + expr=qlast.Constant.string('_default'), + ) + ) + + return new_stmts + def describe_database_restore( self, user_schema_pickle: bytes, @@ -984,7 +1065,11 @@ def describe_database_restore( # The state serializer generated below is somehow inappropriate, # so it's simply ignored here and the I/O process will do it on its own - units = compile(ctx=ctx, source=ddl_source).units + statements = edgeql.parse_block(ddl_source) + statements = self._reprocess_restore_config(statements) + units = _try_compile_ast( + ctx=ctx, source=ddl_source, statements=statements + ).units _check_force_database_error(ctx, scope='restore') @@ -1170,6 +1255,30 @@ def analyze_explain_output( return explain.analyze_explain_output( query_asts_pickled, data, self.state.std_schema) + def validate_schema_equivalence( + self, + schema_a: bytes, + schema_b: bytes, + global_schema: bytes, + conn_state_pickle: Any, + ) -> None: + if conn_state_pickle: + conn_state = pickle.loads(conn_state_pickle) + if ( + conn_state + and ( + conn_state.current_tx().get_migration_state() + or conn_state.current_tx().get_migration_rewrite_state() + ) + ): + return + ddl.validate_schema_equivalence( + self.state, + pickle.loads(schema_a), + pickle.loads(schema_b), + pickle.loads(global_schema), + 
) + def compile_schema_storage_in_delta( ctx: CompileContext, @@ -1292,12 +1401,11 @@ def _compile_schema_storage_stmt( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) if len(sql_stmts) > 1: raise errors.InternalServerError( @@ -1332,12 +1440,11 @@ def _compile_ql_script( sql_stmts = [] for u in unit_group: - for stmt in u.sql: - stmt = stmt.strip() - if not stmt.endswith(b';'): - stmt += b';' + stmt = u.sql.strip() + if not stmt.endswith(b';'): + stmt += b';' - sql_stmts.append(stmt) + sql_stmts.append(stmt) return b'\n'.join(sql_stmts).decode() @@ -1461,7 +1568,7 @@ def _compile_ql_explain( span=ql.span, ) - assert len(query.sql) == 1, query.sql + assert query.sql out_type_data, out_type_id = sertypes.describe( schema, @@ -1469,7 +1576,7 @@ def _compile_ql_explain( protocol_version=ctx.protocol_version, ) - sql_bytes = exp_command.encode('utf-8') + query.sql[0] + sql_bytes = exp_command.encode('utf-8') + query.sql sql_hash = _hash_sql( sql_bytes, mode=str(ctx.output_format).encode(), @@ -1479,9 +1586,9 @@ def _compile_ql_explain( return dataclasses.replace( query, is_explain=True, - append_rollback=args['execute'], + run_and_rollback=args['execute'], cacheable=False, - sql=(sql_bytes,), + sql=sql_bytes, sql_hash=sql_hash, cardinality=enums.Cardinality.ONE, out_type_data=out_type_data, @@ -1507,7 +1614,7 @@ def _compile_ql_administer( span=ql.expr.span, ) - return dbstate.MaintenanceQuery(sql=(b'ANALYZE',)) + return dbstate.MaintenanceQuery(sql=b'ANALYZE') elif ql.expr.func == 'schema_repair': return ddl.administer_repair_schema(ctx, ql) elif ql.expr.func == 'reindex': @@ -1537,6 +1644,52 @@ def _compile_ql_query( is_explain = explain_data is not None current_tx = ctx.state.current_tx() + sql_info: Dict[str, Any] = {} + if ( + not ctx.bootstrap_mode + and 
ctx.backend_runtime_params.has_stat_statements + and not ctx.schema_reflection_mode + ): + spec = ctx.compiler_state.config_spec + cconfig = config.to_json_obj( + spec, + { + **current_tx.get_system_config(), + **current_tx.get_database_config(), + **current_tx.get_session_config(), + }, + setting_filter=lambda v: v.name in spec + and spec[v.name].affects_compilation, + include_source=False, + ) + extras: Dict[str, Any] = { + 'cc': dict(sorted(cconfig.items())), # compilation_config + 'pv': ctx.protocol_version, # protocol_version + 'of': ctx.output_format, # output_format + 'e1': ctx.expected_cardinality_one, # expect_one + 'il': ctx.implicit_limit, # implicit_limit + 'ii': ctx.inline_typeids, # inline_typeids + 'in': ctx.inline_typenames, # inline_typenames + 'io': ctx.inline_objectids, # inline_objectids + } + modaliases = dict(current_tx.get_modaliases()) + # dn: default_namespace + extras['dn'] = modaliases.pop(None, defines.DEFAULT_MODULE_ALIAS) + if modaliases: + # na: namespace_aliases + extras['na'] = dict(sorted(modaliases.items())) + + sql_info.update({ + 'query': qlcodegen.generate_source(ql), + 'type': defines.QueryType.EdgeQL, + 'extras': json.dumps(extras), + }) + id_hash = hashlib.blake2b(digest_size=16) + id_hash.update( + json.dumps(sql_info).encode(defines.EDGEDB_ENCODING) + ) + sql_info['id'] = str(uuidgen.from_bytes(id_hash.digest())) + base_schema = ( ctx.compiler_state.std_schema if not _get_config_val(ctx, '__internal_query_reflschema') @@ -1602,12 +1755,11 @@ def _compile_ql_query( # If requested, embed the EdgeQL text in the SQL. 
if debug.flags.edgeql_text_in_sql and source: - sql_debug_obj = dict(edgeql=source.text()) - sql_debug_prefix = '-- ' + json.dumps(sql_debug_obj) + '\n' - sql_text = sql_debug_prefix + sql_text - if func_call_sql is not None: - func_call_sql = sql_debug_prefix + func_call_sql - sql_bytes = sql_text.encode(defines.EDGEDB_ENCODING) + sql_info['edgeql'] = source.text() + if sql_info: + sql_info_prefix = '-- ' + json.dumps(sql_info) + '\n' + else: + sql_info_prefix = '' globals = None if ir.globals: @@ -1639,21 +1791,23 @@ def _compile_ql_query( ) sql_hash = _hash_sql( - sql_bytes, + sql_text.encode(defines.EDGEDB_ENCODING), mode=str(ctx.output_format).encode(), intype=in_type_id.bytes, outtype=out_type_id.bytes) cache_func_call = None if func_call_sql is not None: - func_call_sql_bytes = func_call_sql.encode(defines.EDGEDB_ENCODING) func_call_sql_hash = _hash_sql( - func_call_sql_bytes, + func_call_sql.encode(defines.EDGEDB_ENCODING), mode=str(ctx.output_format).encode(), intype=in_type_id.bytes, outtype=out_type_id.bytes, ) - cache_func_call = (func_call_sql_bytes, func_call_sql_hash) + cache_func_call = ( + (sql_info_prefix + func_call_sql).encode(defines.EDGEDB_ENCODING), + func_call_sql_hash, + ) if is_explain: if isinstance(ir.schema, s_schema.ChainedSchema): @@ -1668,7 +1822,7 @@ def _compile_ql_query( query_asts = None return dbstate.Query( - sql=(sql_bytes,), + sql=(sql_info_prefix + sql_text).encode(defines.EDGEDB_ENCODING), sql_hash=sql_hash, cache_sql=cache_sql, cache_func_call=cache_func_call, @@ -1861,7 +2015,7 @@ def _compile_ql_transaction( if ql.deferrable is not None: sqls += f' {ql.deferrable.value}' sqls += ';' - sql = (sqls.encode(),) + sql = sqls.encode() action = dbstate.TxAction.START cacheable = False @@ -1877,7 +2031,7 @@ def _compile_ql_transaction( new_state = ctx.state.commit_tx() modaliases = new_state.modaliases - sql = (b'COMMIT',) + sql = b'COMMIT' cacheable = False action = dbstate.TxAction.COMMIT @@ -1885,7 +2039,7 @@ def 
_compile_ql_transaction( new_state = ctx.state.rollback_tx() modaliases = new_state.modaliases - sql = (b'ROLLBACK',) + sql = b'ROLLBACK' cacheable = False action = dbstate.TxAction.ROLLBACK @@ -1894,7 +2048,7 @@ def _compile_ql_transaction( sp_id = tx.declare_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'SAVEPOINT {pgname}'.encode(),) + sql = f'SAVEPOINT {pgname}'.encode() cacheable = False action = dbstate.TxAction.DECLARE_SAVEPOINT @@ -1904,7 +2058,7 @@ def _compile_ql_transaction( elif isinstance(ql, qlast.ReleaseSavepoint): ctx.state.current_tx().release_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) - sql = (f'RELEASE SAVEPOINT {pgname}'.encode(),) + sql = f'RELEASE SAVEPOINT {pgname}'.encode() action = dbstate.TxAction.RELEASE_SAVEPOINT elif isinstance(ql, qlast.RollbackToSavepoint): @@ -1913,7 +2067,7 @@ def _compile_ql_transaction( modaliases = new_state.modaliases pgname = pg_common.quote_ident(ql.name) - sql = (f'ROLLBACK TO SAVEPOINT {pgname};'.encode(),) + sql = f'ROLLBACK TO SAVEPOINT {pgname};'.encode() cacheable = False action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT sp_name = ql.name @@ -1972,9 +2126,7 @@ def _compile_ql_sess_state( ctx.state.current_tx().update_modaliases(aliases) - return dbstate.SessionStateQuery( - sql=(), - ) + return dbstate.SessionStateQuery() def _get_config_spec( @@ -2146,9 +2298,7 @@ def _compile_ql_config_op( if pretty: debug.dump_code(sql_text, lexer='sql') - sql: tuple[bytes, ...] 
= ( - sql_text.encode(), - ) + sql = sql_text.encode() in_type_args, in_type_data, in_type_id = describe_params( ctx, ir, sql_res.argmap, None @@ -2325,6 +2475,132 @@ def compile( raise original_err +def compile_sql_as_unit_group( + *, + ctx: CompileContext, + source: edgeql.Source, +) -> dbstate.QueryUnitGroup: + + setting = _get_config_val(ctx, 'allow_user_specified_id') + allow_user_specified_id = None + if setting: + allow_user_specified_id = sql.is_setting_truthy(setting) + + apply_access_policies_sql = None + setting = _get_config_val(ctx, 'apply_access_policies_sql') + if setting: + apply_access_policies_sql = sql.is_setting_truthy(setting) + + tx_state = ctx.state.current_tx() + schema = tx_state.get_schema(ctx.compiler_state.std_schema) + + settings = dbstate.DEFAULT_SQL_FE_SETTINGS + sql_tx_state = dbstate.SQLTransactionState( + in_tx=not tx_state.is_implicit(), + settings=settings, + in_tx_settings=settings, + in_tx_local_settings=settings, + savepoints=[ + (not_none(tx.name), settings, settings) + for tx in tx_state._savepoints.values() + ], + ) + + sql_units = sql.compile_sql( + source.text(), + schema=schema, + tx_state=sql_tx_state, + prepared_stmt_map={}, + current_database=ctx.branch_name or "", + current_user=ctx.role_name or "", + allow_user_specified_id=allow_user_specified_id, + apply_access_policies_sql=apply_access_policies_sql, + include_edgeql_io_format_alternative=True, + allow_prepared_statements=False, + disambiguate_column_names=True, + backend_runtime_params=ctx.backend_runtime_params, + protocol_version=ctx.protocol_version, + ) + + qug = dbstate.QueryUnitGroup( + cardinality=sql_units[-1].cardinality, + cacheable=False, + ) + + for sql_unit in sql_units: + if sql_unit.eql_format_query is not None: + value_sql = sql_unit.eql_format_query.encode("utf-8") + intro_sql = sql_unit.query.encode("utf-8") + else: + value_sql = sql_unit.query.encode("utf-8") + intro_sql = None + if isinstance(sql_unit.command_complete_tag, dbstate.TagPlain): + 
status = sql_unit.command_complete_tag.tag + elif isinstance( + sql_unit.command_complete_tag, + (dbstate.TagCountMessages, dbstate.TagUnpackRow), + ): + status = sql_unit.command_complete_tag.prefix.encode("utf-8") + elif sql_unit.command_complete_tag is None: + status = b"SELECT" # XXX + else: + raise AssertionError( + f"unexpected SQLQueryUnit.command_complete_tag type: " + f"{sql_unit.command_complete_tag}" + ) + unit = dbstate.QueryUnit( + sql=value_sql, + introspection_sql=intro_sql, + status=status, + cardinality=sql_unit.cardinality, + capabilities=sql_unit.capabilities, + globals=[ + (str(sp.global_name), False) for sp in sql_unit.params + if isinstance(sp, dbstate.SQLParamGlobal) + ] if sql_unit.params else [], + output_format=( + enums.OutputFormat.NONE + if sql_unit.cardinality is enums.Cardinality.NO_RESULT + else enums.OutputFormat.BINARY + ), + ) + match sql_unit.tx_action: + case dbstate.TxAction.START: + ctx.state.start_tx() + tx_state = ctx.state.current_tx() + unit.tx_id = tx_state.id + case dbstate.TxAction.COMMIT: + ctx.state.commit_tx() + unit.tx_commit = True + case dbstate.TxAction.ROLLBACK: + ctx.state.rollback_tx() + unit.tx_rollback = True + case dbstate.TxAction.DECLARE_SAVEPOINT: + assert sql_unit.sp_name is not None + unit.tx_savepoint_declare = True + unit.sp_id = tx_state.declare_savepoint(sql_unit.sp_name) + unit.sp_name = sql_unit.sp_name + case dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: + assert sql_unit.sp_name is not None + tx_state.rollback_to_savepoint(sql_unit.sp_name) + unit.tx_savepoint_rollback = True + unit.sp_name = sql_unit.sp_name + case dbstate.TxAction.RELEASE_SAVEPOINT: + assert sql_unit.sp_name is not None + tx_state.release_savepoint(sql_unit.sp_name) + unit.sp_name = sql_unit.sp_name + case None: + pass + case _: + raise AssertionError( + f"unexpected SQLQueryUnit.tx_action: {sql_unit.tx_action}" + ) + + qug.append(unit) + + return qug + + def _try_compile( *, ctx: CompileContext, @@ -2339,8 +2615,25 @@ def 
_try_compile( if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) - default_cardinality = enums.Cardinality.NO_RESULT statements = edgeql.parse_block(source) + return _try_compile_ast(statements=statements, source=source, ctx=ctx) + + +def _try_compile_ast( + *, + ctx: CompileContext, + statements: list[qlast.Base], + source: edgeql.Source, +) -> dbstate.QueryUnitGroup: + if _get_config_val(ctx, '__internal_testmode'): + # This is a bad but simple way to emulate a slow compilation for tests. + # Ideally, we should have a testmode function that is hooked to sleep + # as `simple_special_case`, or wait for a notification from the test. + sentinel = "# EDGEDB_TEST_COMPILER_SLEEP = " + text = source.text() + if text.startswith(sentinel): + time.sleep(float(text[len(sentinel):text.index("\n")])) + statements_len = len(statements) if not len(statements): # pragma: no cover @@ -2375,15 +2668,6 @@ def _try_compile( _check_force_database_error(stmt_ctx, stmt) - # Initialize user_schema_version with the version this query is - # going to be compiled upon. This can be overwritten later by DDLs. 
- try: - schema_version = _get_schema_version( - stmt_ctx.state.current_tx().get_user_schema() - ) - except errors.InvalidReferenceError: - schema_version = None - comp, capabilities = _compile_dispatch_ql( stmt_ctx, stmt, @@ -2392,234 +2676,21 @@ def _try_compile( in_script=is_script, ) - unit = dbstate.QueryUnit( - sql=(), - status=status.get_status(stmt), - cardinality=default_cardinality, + unit, user_schema = _make_query_unit( + ctx=ctx, + stmt_ctx=stmt_ctx, + stmt=stmt, + is_script=is_script, + is_trailing_stmt=is_trailing_stmt, + comp=comp, capabilities=capabilities, - output_format=stmt_ctx.output_format, - cache_key=ctx.cache_key, - user_schema_version=schema_version, - warnings=comp.warnings, ) - if not comp.is_transactional: - if is_script: - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'with other commands in one block', - span=stmt.span, - ) - - if not ctx.state.current_tx().is_implicit(): - raise errors.QueryError( - f'cannot execute {status.get_status(stmt).decode()} ' - f'in a transaction', - span=stmt.span, - ) - - unit.is_transactional = False - - if isinstance(comp, dbstate.Query): - unit.sql = comp.sql - unit.cache_sql = comp.cache_sql - unit.cache_func_call = comp.cache_func_call - unit.globals = comp.globals - unit.in_type_args = comp.in_type_args - - unit.sql_hash = comp.sql_hash - - unit.out_type_data = comp.out_type_data - unit.out_type_id = comp.out_type_id - unit.in_type_data = comp.in_type_data - unit.in_type_id = comp.in_type_id - - unit.cacheable = comp.cacheable - - if comp.is_explain: - unit.is_explain = True - unit.query_asts = comp.query_asts - - if comp.append_rollback: - unit.append_rollback = True - - if is_trailing_stmt: - unit.cardinality = comp.cardinality - - elif isinstance(comp, dbstate.SimpleQuery): - unit.sql = comp.sql - unit.in_type_args = comp.in_type_args - - elif isinstance(comp, dbstate.DDLQuery): - unit.sql = comp.sql - unit.create_db = comp.create_db - unit.drop_db = 
comp.drop_db - unit.drop_db_reset_connections = comp.drop_db_reset_connections - unit.create_db_template = comp.create_db_template - unit.create_db_mode = comp.create_db_mode - unit.ddl_stmt_id = comp.ddl_stmt_id - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - unit.config_ops.extend(comp.config_ops) - - elif isinstance(comp, dbstate.TxControlQuery): - if is_script: - raise errors.QueryError( - "Explicit transaction control commands cannot be executed " - "in an implicit transaction block" - ) - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - unit.feature_used_metrics = comp.feature_used_metrics - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - if comp.global_schema is not None: - unit.global_schema = pickle.dumps(comp.global_schema, -1) - unit.roles = _extract_roles(comp.global_schema) - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 
'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: - unit.tx_savepoint_rollback = True - unit.sp_name = comp.sp_name - elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: - unit.tx_savepoint_declare = True - unit.sp_name = comp.sp_name - unit.sp_id = comp.sp_id - - elif isinstance(comp, dbstate.MigrationControlQuery): - unit.sql = comp.sql - unit.cacheable = comp.cacheable - - if not ctx.dump_restore_mode: - if comp.user_schema is not None: - final_user_schema = comp.user_schema - unit.user_schema = pickle.dumps(comp.user_schema, -1) - unit.user_schema_version = ( - _get_schema_version(comp.user_schema) - ) - unit.extensions, unit.ext_config_settings = ( - _extract_extensions(ctx, comp.user_schema) - ) - if comp.cached_reflection is not None: - unit.cached_reflection = \ - pickle.dumps(comp.cached_reflection, -1) - unit.ddl_stmt_id = comp.ddl_stmt_id - - if comp.modaliases is not None: - unit.modaliases = comp.modaliases - - if comp.tx_action == dbstate.TxAction.START: - if unit.tx_id is not None: - raise errors.InternalServerError( - 'already in transaction') - unit.tx_id = ctx.state.current_tx().id - elif comp.tx_action == dbstate.TxAction.COMMIT: - unit.tx_commit = True - elif comp.tx_action == dbstate.TxAction.ROLLBACK: - unit.tx_rollback = True - elif comp.action == dbstate.MigrationAction.ABORT: - unit.tx_abort_migration = True - - elif isinstance(comp, dbstate.SessionStateQuery): - unit.sql = comp.sql - unit.globals = comp.globals - - if comp.config_scope is qltypes.ConfigScope.INSTANCE: - if (not ctx.state.current_tx().is_implicit() or - statements_len > 1): - raise errors.QueryError( - 'CONFIGURE INSTANCE cannot be executed in a ' - 'transaction block') - - unit.system_config = True - elif comp.config_scope is 
qltypes.ConfigScope.GLOBAL: - unit.needs_readback = True - - elif comp.config_scope is qltypes.ConfigScope.DATABASE: - unit.database_config = True - unit.needs_readback = True - - if comp.is_backend_setting: - unit.backend_config = True - if comp.requires_restart: - unit.config_requires_restart = True - if comp.is_system_config: - unit.is_system_config = True - - unit.modaliases = ctx.state.current_tx().get_modaliases() - - if comp.config_op is not None: - unit.config_ops.append(comp.config_op) - - if comp.in_type_args: - unit.in_type_args = comp.in_type_args - if comp.in_type_data: - unit.in_type_data = comp.in_type_data - if comp.in_type_id: - unit.in_type_id = comp.in_type_id - - unit.has_set = True - - elif isinstance(comp, dbstate.MaintenanceQuery): - unit.sql = comp.sql - - elif isinstance(comp, dbstate.NullQuery): - pass - - else: # pragma: no cover - raise errors.InternalServerError('unknown compile state') - - if unit.in_type_args: - unit.in_type_args_real_count = sum( - len(p.sub_params[0]) if p.sub_params else 1 - for p in unit.in_type_args - ) - - if unit.warnings: - for warning in unit.warnings: - warning.__traceback__ = None - rv.append(unit) + if user_schema is not None: + final_user_schema = user_schema + if script_info: if ctx.state.current_tx().is_implicit(): if ctx.state.current_tx().get_migration_state(): @@ -2666,7 +2737,6 @@ def _try_compile( f'QueryUnit {unit!r} is cacheable but has config/aliases') if not na_cardinality and ( - len(unit.sql) > 1 or unit.tx_commit or unit.tx_rollback or unit.tx_savepoint_rollback or @@ -2691,6 +2761,260 @@ def _try_compile( return rv +def _make_query_unit( + *, + ctx: CompileContext, + stmt_ctx: CompileContext, + stmt: qlast.Base, + is_script: bool, + is_trailing_stmt: bool, + comp: dbstate.BaseQuery, + capabilities: enums.Capability, +) -> tuple[dbstate.QueryUnit, Optional[s_schema.Schema]]: + + # Initialize user_schema_version with the version this query is + # going to be compiled upon. 
This can be overwritten later by DDLs. + try: + schema_version = _get_schema_version( + stmt_ctx.state.current_tx().get_user_schema() + ) + except errors.InvalidReferenceError: + schema_version = None + + unit = dbstate.QueryUnit( + sql=b"", + status=status.get_status(stmt), + cardinality=enums.Cardinality.NO_RESULT, + capabilities=capabilities, + output_format=stmt_ctx.output_format, + cache_key=ctx.cache_key, + user_schema_version=schema_version, + warnings=comp.warnings, + ) + + if not comp.is_transactional: + if is_script: + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'with other commands in one block', + span=stmt.span, + ) + + if not ctx.state.current_tx().is_implicit(): + raise errors.QueryError( + f'cannot execute {status.get_status(stmt).decode()} ' + f'in a transaction', + span=stmt.span, + ) + + unit.is_transactional = False + + final_user_schema: Optional[s_schema.Schema] = None + + if isinstance(comp, dbstate.Query): + unit.sql = comp.sql + unit.cache_sql = comp.cache_sql + unit.cache_func_call = comp.cache_func_call + unit.globals = comp.globals + unit.in_type_args = comp.in_type_args + + unit.sql_hash = comp.sql_hash + + unit.out_type_data = comp.out_type_data + unit.out_type_id = comp.out_type_id + unit.in_type_data = comp.in_type_data + unit.in_type_id = comp.in_type_id + + unit.cacheable = comp.cacheable + + if comp.is_explain: + unit.is_explain = True + unit.query_asts = comp.query_asts + + if comp.run_and_rollback: + unit.run_and_rollback = True + + if is_trailing_stmt: + unit.cardinality = comp.cardinality + + elif isinstance(comp, dbstate.SimpleQuery): + unit.sql = comp.sql + unit.in_type_args = comp.in_type_args + + elif isinstance(comp, dbstate.DDLQuery): + unit.sql = comp.sql + unit.db_op_trailer = comp.db_op_trailer + unit.create_db = comp.create_db + unit.drop_db = comp.drop_db + unit.drop_db_reset_connections = comp.drop_db_reset_connections + unit.create_db_template = comp.create_db_template + 
unit.create_db_mode = comp.create_db_mode + unit.ddl_stmt_id = comp.ddl_stmt_id + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = _extract_roles(comp.global_schema) + + unit.config_ops.extend(comp.config_ops) + + elif isinstance(comp, dbstate.TxControlQuery): + if is_script: + raise errors.QueryError( + "Explicit transaction control commands cannot be executed " + "in an implicit transaction block" + ) + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + unit.feature_used_metrics = comp.feature_used_metrics + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + if comp.global_schema is not None: + unit.global_schema = pickle.dumps(comp.global_schema, -1) + unit.roles = _extract_roles(comp.global_schema) + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.action == dbstate.TxAction.START: + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.action == dbstate.TxAction.COMMIT: + unit.tx_commit = 
True + elif comp.action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: + unit.tx_savepoint_rollback = True + unit.sp_name = comp.sp_name + elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: + unit.tx_savepoint_declare = True + unit.sp_name = comp.sp_name + unit.sp_id = comp.sp_id + + elif isinstance(comp, dbstate.MigrationControlQuery): + unit.sql = comp.sql + unit.cacheable = comp.cacheable + + if not ctx.dump_restore_mode: + if comp.user_schema is not None: + final_user_schema = comp.user_schema + unit.user_schema = pickle.dumps(comp.user_schema, -1) + unit.user_schema_version = ( + _get_schema_version(comp.user_schema) + ) + unit.extensions, unit.ext_config_settings = ( + _extract_extensions(ctx, comp.user_schema) + ) + if comp.cached_reflection is not None: + unit.cached_reflection = \ + pickle.dumps(comp.cached_reflection, -1) + unit.ddl_stmt_id = comp.ddl_stmt_id + + if comp.modaliases is not None: + unit.modaliases = comp.modaliases + + if comp.tx_action == dbstate.TxAction.START: + if unit.tx_id is not None: + raise errors.InternalServerError( + 'already in transaction') + unit.tx_id = ctx.state.current_tx().id + elif comp.tx_action == dbstate.TxAction.COMMIT: + unit.tx_commit = True + unit.append_tx_op = True + elif comp.tx_action == dbstate.TxAction.ROLLBACK: + unit.tx_rollback = True + unit.append_tx_op = True + elif comp.action == dbstate.MigrationAction.ABORT: + unit.tx_abort_migration = True + + elif isinstance(comp, dbstate.SessionStateQuery): + unit.sql = comp.sql + unit.globals = comp.globals + + if comp.config_scope is qltypes.ConfigScope.INSTANCE: + if not ctx.state.current_tx().is_implicit() or is_script: + raise errors.QueryError( + 'CONFIGURE INSTANCE cannot be executed in a ' + 'transaction block') + + unit.system_config = True + elif comp.config_scope is qltypes.ConfigScope.GLOBAL: + unit.needs_readback = True + + elif comp.config_scope is 
qltypes.ConfigScope.DATABASE: + unit.database_config = True + unit.needs_readback = True + + if comp.is_backend_setting: + unit.backend_config = True + if comp.requires_restart: + unit.config_requires_restart = True + if comp.is_system_config: + unit.is_system_config = True + + unit.modaliases = ctx.state.current_tx().get_modaliases() + + if comp.config_op is not None: + unit.config_ops.append(comp.config_op) + + if comp.in_type_args: + unit.in_type_args = comp.in_type_args + if comp.in_type_data: + unit.in_type_data = comp.in_type_data + if comp.in_type_id: + unit.in_type_id = comp.in_type_id + + unit.has_set = True + unit.output_format = enums.OutputFormat.NONE + + elif isinstance(comp, dbstate.MaintenanceQuery): + unit.sql = comp.sql + + elif isinstance(comp, dbstate.NullQuery): + pass + + else: # pragma: no cover + raise errors.InternalServerError('unknown compile state') + + if unit.in_type_args: + unit.in_type_args_real_count = sum( + len(p.sub_params[0]) if p.sub_params else 1 + for p in unit.in_type_args + ) + + if unit.warnings: + for warning in unit.warnings: + warning.__traceback__ = None + + return unit, final_user_schema + + def _extract_params( params: List[irast.Param], *, diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 73a1f753b70..89cc42e2506 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -60,7 +60,6 @@ class TxAction(enum.IntEnum): - START = 1 COMMIT = 2 ROLLBACK = 3 @@ -71,7 +70,6 @@ class TxAction(enum.IntEnum): class MigrationAction(enum.IntEnum): - START = 1 POPULATE = 2 DESCRIBE = 3 @@ -80,10 +78,11 @@ class MigrationAction(enum.IntEnum): REJECT_PROPOSED = 6 -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class BaseQuery: - - sql: Tuple[bytes, ...] 
+ sql: bytes + is_transactional: bool = True + has_dml: bool = False cache_sql: Optional[Tuple[bytes, bytes]] = dataclasses.field( kw_only=True, default=None ) # (persist, evict) @@ -94,22 +93,14 @@ class BaseQuery: kw_only=True, default=() ) - @property - def is_transactional(self) -> bool: - return True - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class NullQuery(BaseQuery): - - sql: Tuple[bytes, ...] = tuple() - is_transactional: bool = True - has_dml: bool = False + sql: bytes = b"" -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class Query(BaseQuery): - sql_hash: bytes cardinality: enums.Cardinality @@ -122,27 +113,21 @@ class Query(BaseQuery): globals: Optional[list[tuple[str, bool]]] = None - is_transactional: bool = True - has_dml: bool = False cacheable: bool = True is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SimpleQuery(BaseQuery): - - sql: Tuple[bytes, ...] 
- is_transactional: bool = True - has_dml: bool = False # XXX: Temporary hack, since SimpleQuery will die in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class SessionStateQuery(BaseQuery): - + sql: bytes = b"" config_scope: Optional[qltypes.ConfigScope] = None is_backend_setting: bool = False requires_restart: bool = False @@ -156,9 +141,8 @@ class SessionStateQuery(BaseQuery): in_type_args: Optional[List[Param]] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class DDLQuery(BaseQuery): - user_schema: Optional[s_schema.FlatSchema] feature_used_metrics: Optional[dict[str, float]] global_schema: Optional[s_schema.FlatSchema] = None @@ -169,18 +153,17 @@ class DDLQuery(BaseQuery): drop_db_reset_connections: bool = False create_db_template: Optional[str] = None create_db_mode: Optional[qlast.BranchType] = None + db_op_trailer: tuple[bytes, ...] = () ddl_stmt_id: Optional[str] = None config_ops: List[config.Operation] = dataclasses.field(default_factory=list) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class TxControlQuery(BaseQuery): - action: TxAction cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.Schema] = None global_schema: Optional[s_schema.Schema] = None @@ -191,25 +174,22 @@ class TxControlQuery(BaseQuery): sp_id: Optional[int] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class MigrationControlQuery(BaseQuery): - action: MigrationAction tx_action: Optional[TxAction] cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] - is_transactional: bool = True user_schema: Optional[s_schema.FlatSchema] = None cached_reflection: Any = None ddl_stmt_id: Optional[str] = None -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, 
kw_only=True) class MaintenanceQuery(BaseQuery): - - is_transactional: bool = True + pass @dataclasses.dataclass(frozen=True) @@ -224,10 +204,11 @@ class Param: ############################# -@dataclasses.dataclass +@dataclasses.dataclass(kw_only=True) class QueryUnit: + sql: bytes - sql: Tuple[bytes, ...] + introspection_sql: Optional[bytes] = None # Status-line for the compiled command; returned to front-end # in a CommandComplete protocol message if the command is @@ -244,9 +225,9 @@ class QueryUnit: # Set only for units that contain queries that can be cached # as prepared statements in Postgres. - sql_hash: bytes = b'' + sql_hash: bytes = b"" - # True if all statments in *sql* can be executed inside a transaction. + # True if all statements in *sql* can be executed inside a transaction. # If False, they will be executed separately. is_transactional: bool = True @@ -298,6 +279,10 @@ class QueryUnit: create_db_template: Optional[str] = None create_db_mode: Optional[str] = None + # If a branch command needs extra SQL commands to be performed, + # those would end up here. + db_op_trailer: tuple[bytes, ...] = () + # If non-None, the DDL statement will emit data packets marked # with the indicated ID. 
ddl_stmt_id: Optional[str] = None @@ -357,7 +342,8 @@ class QueryUnit: is_explain: bool = False query_asts: Any = None - append_rollback: bool = False + run_and_rollback: bool = False + append_tx_op: bool = False @property def has_ddl(self) -> bool: @@ -390,13 +376,12 @@ def deserialize(cls, data: bytes) -> Self: def maybe_use_func_cache(self) -> None: if self.cache_func_call is not None: sql, sql_hash = self.cache_func_call - self.sql = (sql,) + self.sql = sql self.sql_hash = sql_hash @dataclasses.dataclass class QueryUnitGroup: - # All capabilities used by any query units in this group capabilities: enums.Capability = enums.Capability(0) @@ -531,11 +516,23 @@ class SQLQueryUnit: query: str = dataclasses.field(repr=False) """Translated query text.""" + prefix_len: int = 0 + translation_data: Optional[pgcodegen.TranslationData] = None + """Translation source map.""" + + eql_format_query: Optional[str] = dataclasses.field( + repr=False, default=None) + """Translated query text returning data in single-column format.""" + + eql_format_translation_data: Optional[pgcodegen.TranslationData] = None + """Translation source map for single-column format query.""" + orig_query: str = dataclasses.field(repr=False) """Original query text before translation.""" - translation_data: Optional[pgcodegen.TranslationData] = None - """Translation source map.""" + cardinality: enums.Cardinality = enums.Cardinality.NO_RESULT + + capabilities: enums.Capability = enums.Capability.NONE fe_settings: SQLSettings """Frontend-only settings effective during translation of this unit.""" @@ -568,29 +565,29 @@ class SQLQueryUnit: class CommandCompleteTag: - '''Dictates the tag of CommandComplete message that concludes this query.''' + """Dictates the tag of CommandComplete message that concludes this query.""" @dataclasses.dataclass(kw_only=True) class TagPlain(CommandCompleteTag): - '''Set the tag verbatim''' + """Set the tag verbatim""" tag: bytes @dataclasses.dataclass(kw_only=True) class 
TagCountMessages(CommandCompleteTag): - '''Count DataRow messages in the response and set the tag to - f'{prefix} {count_of_messages}'.''' + """Count DataRow messages in the response and set the tag to + f'{prefix} {count_of_messages}'.""" prefix: str @dataclasses.dataclass(kw_only=True) class TagUnpackRow(CommandCompleteTag): - '''Intercept a single DataRow message with a single column which represents + """Intercept a single DataRow message with a single column which represents the number of modified rows. - Sets the CommandComplete tag to f'{prefix} {modified_rows}'.''' + Sets the CommandComplete tag to f'{prefix} {modified_rows}'.""" prefix: str @@ -634,13 +631,15 @@ class ParsedDatabase: SQLSetting = tuple[str | int | float, ...] SQLSettings = immutables.Map[Optional[str], Optional[SQLSetting]] DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map() -DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map({ - "search_path": ("public",), - "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), - "server_version_num": cast( - SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) - ), -}) +DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map( + { + "search_path": ("public",), + "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), + "server_version_num": cast( + SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) + ), + } +) @dataclasses.dataclass @@ -680,11 +679,16 @@ def apply(self, query_unit: SQLQueryUnit) -> None: self.in_tx_local_settings = None self.savepoints.clear() elif query_unit.tx_action == TxAction.DECLARE_SAVEPOINT: - self.savepoints.append(( - query_unit.sp_name, - self.in_tx_settings, - self.in_tx_local_settings, - )) # type: ignore + assert query_unit.sp_name is not None + assert self.in_tx_settings is not None + assert self.in_tx_local_settings is not None + self.savepoints.append( + ( + query_unit.sp_name, + self.in_tx_settings, + self.in_tx_local_settings, + ) + ) elif query_unit.tx_action == 
TxAction.ROLLBACK_TO_SAVEPOINT: while self.savepoints: sp_name, settings, local_settings = self.savepoints[-1] @@ -735,7 +739,6 @@ def _set(attr_name: str) -> None: class ProposedMigrationStep(NamedTuple): - statements: Tuple[str, ...] confidence: float prompt: str @@ -748,17 +751,16 @@ class ProposedMigrationStep(NamedTuple): def to_json(self) -> Dict[str, Any]: return { - 'statements': [{'text': stmt} for stmt in self.statements], - 'confidence': self.confidence, - 'prompt': self.prompt, - 'prompt_id': self.prompt_id, - 'data_safe': self.data_safe, - 'required_user_input': list(self.required_user_input), + "statements": [{"text": stmt} for stmt in self.statements], + "confidence": self.confidence, + "prompt": self.prompt, + "prompt_id": self.prompt_id, + "data_safe": self.data_safe, + "required_user_input": list(self.required_user_input), } class MigrationState(NamedTuple): - parent_migration: Optional[s_migrations.Migration] initial_schema: s_schema.Schema initial_savepoint: Optional[str] @@ -769,14 +771,12 @@ class MigrationState(NamedTuple): class MigrationRewriteState(NamedTuple): - initial_savepoint: Optional[str] target_schema: s_schema.Schema accepted_migrations: Tuple[qlast.CreateMigration, ...] 
class TransactionState(NamedTuple): - id: int name: Optional[str] local_user_schema: s_schema.FlatSchema | None @@ -799,7 +799,6 @@ def user_schema(self) -> s_schema.FlatSchema: class Transaction: - _savepoints: Dict[int, TransactionState] _constate: CompilerConnectionState @@ -816,7 +815,6 @@ def __init__( cached_reflection: immutables.Map[str, Tuple[str, ...]], implicit: bool = True, ) -> None: - assert not isinstance(user_schema, s_schema.ChainedSchema) self._constate = constate @@ -857,12 +855,12 @@ def make_explicit(self) -> None: if self._implicit: self._implicit = False else: - raise errors.TransactionError('already in explicit transaction') + raise errors.TransactionError("already in explicit transaction") def declare_savepoint(self, name: str) -> int: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._declare_savepoint(name) @@ -882,7 +880,7 @@ def _declare_savepoint(self, name: str) -> int: def rollback_to_savepoint(self, name: str) -> TransactionState: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) return self._rollback_to_savepoint(name) @@ -899,7 +897,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: sp_ids_to_erase.append(sp.id) else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -909,7 +907,7 @@ def _rollback_to_savepoint(self, name: str) -> TransactionState: def release_savepoint(self, name: str) -> None: if self.is_implicit(): raise errors.TransactionError( - 'savepoints can only be used in transaction blocks' + "savepoints can only be used in transaction blocks" ) self._release_savepoint(name) @@ -925,7 +923,7 @@ def _release_savepoint(self, 
name: str) -> None: if sp.name == name: break else: - raise errors.TransactionError(f'there is no {name!r} savepoint') + raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) @@ -1030,8 +1028,7 @@ def update_migration_rewrite_state( class CompilerConnectionState: - - __slots__ = ('_savepoints_log', '_current_tx', '_tx_count', '_user_schema') + __slots__ = ("_savepoints_log", "_current_tx", "_tx_count", "_user_schema") _savepoints_log: Dict[int, TransactionState] _user_schema: Optional[s_schema.FlatSchema] @@ -1111,7 +1108,7 @@ def sync_to_savepoint(self, spid: int) -> None: """Synchronize the compiler state with the current DB state.""" if not self.can_sync_to_savepoint(spid): - raise RuntimeError(f'failed to lookup savepoint with id={spid}') + raise RuntimeError(f"failed to lookup savepoint with id={spid}") sp = self._savepoints_log[spid] self._current_tx = sp.tx @@ -1137,7 +1134,7 @@ def start_tx(self) -> None: if self._current_tx.is_implicit(): self._current_tx.make_explicit() else: - raise errors.TransactionError('already in transaction') + raise errors.TransactionError("already in transaction") def rollback_tx(self) -> TransactionState: # Note that we might not be in a transaction as we allow @@ -1159,7 +1156,7 @@ def rollback_tx(self) -> TransactionState: def commit_tx(self) -> TransactionState: if self._current_tx.is_implicit(): - raise errors.TransactionError('cannot commit: not in transaction') + raise errors.TransactionError("cannot commit: not in transaction") latest_state = self._current_tx._current @@ -1184,5 +1181,5 @@ def sync_tx(self, txid: int) -> None: return raise errors.InternalServerError( - f'failed to lookup transaction or savepoint with id={txid}' + f"failed to lookup transaction or savepoint with id={txid}" ) # pragma: no cover diff --git a/edb/server/compiler/ddl.py b/edb/server/compiler/ddl.py index 6a1d7f20b80..76d443dbe26 100644 --- a/edb/server/compiler/ddl.py +++ 
b/edb/server/compiler/ddl.py @@ -64,17 +64,28 @@ from edb.pgsql import common as pg_common from edb.pgsql import delta as pg_delta from edb.pgsql import dbops as pg_dbops -from edb.pgsql import trampoline from . import dbstate from . import compiler +NIL_QUERY = b"SELECT LIMIT 0" + + def compile_and_apply_ddl_stmt( ctx: compiler.CompileContext, - stmt: qlast.DDLOperation, + stmt: qlast.DDLCommand, source: Optional[edgeql.Source] = None, ) -> dbstate.DDLQuery: + query, _ = _compile_and_apply_ddl_stmt(ctx, stmt, source) + return query + + +def _compile_and_apply_ddl_stmt( + ctx: compiler.CompileContext, + stmt: qlast.DDLCommand, + source: Optional[edgeql.Source] = None, +) -> tuple[dbstate.DDLQuery, Optional[pg_dbops.SQLBlock]]: if isinstance(stmt, qlast.GlobalObjectCommand): ctx._assert_not_in_migration_block(stmt) @@ -127,7 +138,7 @@ def compile_and_apply_ddl_stmt( ) ], ) - return compile_and_apply_ddl_stmt(ctx, cm) + return _compile_and_apply_ddl_stmt(ctx, cm) assert isinstance(stmt, qlast.DDLCommand) new_schema, delta = s_ddl.delta_and_schema_from_ddl( @@ -175,14 +186,16 @@ def compile_and_apply_ddl_stmt( current_tx.update_migration_state(mstate) current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + store_migration_sdl = compiler._get_config_val(ctx, 'store_migration_sdl') if ( isinstance(stmt, qlast.CreateMigration) @@ -207,26 +220,30 @@ def compile_and_apply_ddl_stmt( current_tx.update_schema(new_schema) - return dbstate.DDLQuery( - sql=(b'SELECT LIMIT 0',), + query = dbstate.DDLQuery( + sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) + return query, None + # Apply and adapt delta, build native delta plan, which # will also update the schema. 
block, new_types, config_ops = _process_delta(ctx, delta) ddl_stmt_id: Optional[str] = None - is_transactional = block.is_transactional() if not is_transactional: - sql = tuple(stmt.encode('utf-8') for stmt in block.get_statements()) + if not isinstance(stmt, qlast.DatabaseCommand): + raise AssertionError( + f"unexpected non-transaction DDL command type: {stmt}") + sql_stmts = block.get_statements() + sql = sql_stmts[0].encode("utf-8") + db_op_trailer = tuple(stmt.encode("utf-8") for stmt in sql_stmts[1:]) else: - sql = (block.to_string().encode('utf-8'),) - if new_types: # Inject a query returning backend OIDs for the newly # created types. @@ -234,10 +251,10 @@ def compile_and_apply_ddl_stmt( new_type_ids = [ f'{pg_common.quote_literal(tid)}::uuid' for tid in new_types ] - sql = sql + ( - trampoline.fixup_query(textwrap.dedent( - f'''\ - SELECT + # Return newly-added type id mapping via the indirect + # return channel (see PGConnection.last_indirect_return) + new_types_sql = textwrap.dedent(f"""\ + PERFORM edgedb.indirect_return( json_build_object( 'ddl_stmt_id', {pg_common.quote_literal(ddl_stmt_id)}, @@ -245,7 +262,7 @@ def compile_and_apply_ddl_stmt( (SELECT json_object_agg( "id"::text, - "backend_id" + json_build_array("backend_id", "name") ) FROM edgedb_VER."_SchemaType" @@ -254,11 +271,15 @@ def compile_and_apply_ddl_stmt( {', '.join(new_type_ids)} ]) ) - )::text; - ''' - )).encode('utf-8'), + )::text + )""" ) + block.add_command(pg_dbops.Query(text=new_types_sql).code()) + + sql = block.to_string().encode('utf-8') + db_op_trailer = () + create_db = None drop_db = None drop_db_reset_connections = False @@ -286,10 +307,10 @@ def compile_and_apply_ddl_stmt( debug.dump_code(code, lexer='sql') if debug.flags.delta_execute: debug.header('Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql + b"\n".join(db_op_trailer), lexer='sql') new_user_schema = current_tx.get_user_schema_if_updated() - return dbstate.DDLQuery( + query = 
dbstate.DDLQuery( sql=sql, is_transactional=is_transactional, create_db=create_db, @@ -297,6 +318,7 @@ def compile_and_apply_ddl_stmt( drop_db_reset_connections=drop_db_reset_connections, create_db_template=create_db_template, create_db_mode=create_db_mode, + db_op_trailer=db_op_trailer, ddl_stmt_id=ddl_stmt_id, user_schema=new_user_schema, cached_reflection=current_tx.get_cached_reflection_if_updated(), @@ -309,6 +331,8 @@ def compile_and_apply_ddl_stmt( ), ) + return query, block + def _new_delta_context( ctx: compiler.CompileContext, args: Any = None @@ -464,7 +488,7 @@ def _start_migration( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -573,7 +597,7 @@ def _populate_migration( current_tx.update_schema(schema) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.POPULATE, cacheable=False, @@ -801,7 +825,7 @@ def _alter_current_migration_reject_proposed( current_tx.update_migration_state(mstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.REJECT_PROPOSED, cacheable=False, @@ -876,7 +900,7 @@ def _commit_migration( current_tx.update_migration_rewrite_state(mrstate) return dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.COMMIT, tx_action=None, cacheable=False, @@ -893,16 +917,12 @@ def _commit_migration( if mstate.initial_savepoint: current_tx.commit_migration(mstate.initial_savepoint) - sql = ddl_query.sql tx_action = None else: - tx_cmd = qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sql = ddl_query.sql + tx_query.sql - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=sql, + 
sql=ddl_query.sql, ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, @@ -923,7 +943,7 @@ def _abort_migration( if mstate.initial_savepoint: current_tx.abort_migration(mstate.initial_savepoint) - sql: Tuple[bytes, ...] = (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -967,7 +987,7 @@ def _start_migration_rewrite( else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( - sql=(b'SELECT LIMIT 0',), + sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, @@ -1052,25 +1072,24 @@ def _commit_migration_rewrite( for cm in cmds: cm.dump_edgeql() - sqls: List[bytes] = [] + block = pg_dbops.PLTopBlock() for cmd in cmds: - ddl_query = compile_and_apply_ddl_stmt(ctx, cmd) + _, ddl_block = _compile_and_apply_ddl_stmt(ctx, cmd) + assert isinstance(ddl_block, pg_dbops.PLBlock) # We know nothing serious can be in that query # except for the SQL, so it's fine to just discard # it all. - sqls.extend(ddl_query.sql) + for stmt in ddl_block.get_statements(): + block.add_command(stmt) if mrstate.initial_savepoint: current_tx.commit_migration(mrstate.initial_savepoint) tx_action = None else: - tx_cmd = qlast.CommitTransaction() - tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) - sqls.extend(tx_query.sql) - tx_action = tx_query.action + tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=block.to_string().encode("utf-8"), action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, cacheable=False, @@ -1090,7 +1109,7 @@ def _abort_migration_rewrite( if mrstate.initial_savepoint: current_tx.abort_migration(mrstate.initial_savepoint) - sql: Tuple[bytes, ...] 
= (b'SELECT LIMIT 0',) + sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() @@ -1146,8 +1165,6 @@ def _reset_schema( current_schema=empty_schema, ) - sqls: List[bytes] = [] - # diff and create migration that drops all objects diff = s_ddl.delta_schemas(schema, empty_schema) new_ddl: Tuple[qlast.DDLCommand, ...] = tuple( @@ -1156,8 +1173,8 @@ def _reset_schema( create_mig = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock(commands=tuple(new_ddl)), # type: ignore ) - ddl_query = compile_and_apply_ddl_stmt(ctx, create_mig) - sqls.extend(ddl_query.sql) + ddl_query, ddl_block = _compile_and_apply_ddl_stmt(ctx, create_mig) + assert ddl_block is not None # delete all migrations schema = current_tx.get_schema(ctx.compiler_state.std_schema) @@ -1170,11 +1187,13 @@ def _reset_schema( drop_mig = qlast.DropMigration( # type: ignore name=qlast.ObjectRef(name=mig.get_name(schema).name), ) - ddl_query = compile_and_apply_ddl_stmt(ctx, drop_mig) - sqls.extend(ddl_query.sql) + _, mig_block = _compile_and_apply_ddl_stmt(ctx, drop_mig) + assert isinstance(mig_block, pg_dbops.PLBlock) + for stmt in mig_block.get_statements(): + ddl_block.add_command(stmt) return dbstate.MigrationControlQuery( - sql=tuple(sqls), + sql=ddl_block.to_string().encode("utf-8"), ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=None, @@ -1278,7 +1297,7 @@ def _track(key: str) -> None: def repair_schema( ctx: compiler.CompileContext, -) -> Optional[tuple[tuple[bytes, ...], s_schema.Schema, Any]]: +) -> Optional[tuple[bytes, s_schema.Schema, Any]]: """Repair inconsistencies in the schema caused by bug fixes Works by comparing the actual current schema to the schema we get @@ -1340,11 +1359,11 @@ def repair_schema( is_transactional = block.is_transactional() assert not new_types assert is_transactional - sql = (block.to_string().encode('utf-8'),) + sql = block.to_string().encode('utf-8') if debug.flags.delta_execute: debug.header('Repair 
Delta Script') - debug.dump_code(b'\n'.join(sql), lexer='sql') + debug.dump_code(sql, lexer='sql') return sql, reloaded_schema, config_ops @@ -1363,7 +1382,7 @@ def administer_repair_schema( res = repair_schema(ctx) if not res: - return dbstate.MaintenanceQuery(sql=(b'',)) + return dbstate.MaintenanceQuery(sql=b"") sql, new_schema, config_ops = res current_tx.update_schema(new_schema) @@ -1511,9 +1530,11 @@ def administer_reindex( for pindex in pindexes ] - return dbstate.MaintenanceQuery( - sql=tuple(q.encode('utf-8') for q in commands) - ) + block = pg_dbops.PLTopBlock() + for command in commands: + block.add_command(command) + + return dbstate.MaintenanceQuery(sql=block.to_string().encode("utf-8")) def administer_vacuum( @@ -1663,7 +1684,7 @@ def administer_vacuum( command = f'VACUUM {options} ' + ', '.join(tables_and_columns) return dbstate.MaintenanceQuery( - sql=(command.encode('utf-8'),), + sql=command.encode('utf-8'), is_transactional=False, ) @@ -1700,3 +1721,31 @@ def administer_prepare_upgrade( cacheable=False, migration_block_query=True, ) + + +def validate_schema_equivalence( + state: compiler.CompilerState, + schema_a: s_schema.FlatSchema, + schema_b: s_schema.FlatSchema, + global_schema: s_schema.FlatSchema, +) -> None: + schema_a_full = s_schema.ChainedSchema( + state.std_schema, + schema_a, + global_schema, + ) + schema_b_full = s_schema.ChainedSchema( + state.std_schema, + schema_b, + global_schema, + ) + + diff = s_ddl.delta_schemas(schema_a_full, schema_b_full) + complete = not bool(diff.get_subcommands()) + if not complete: + if debug.flags.delta_plan: + debug.header('COMPARE SCHEMAS MISMATCH') + debug.dump(diff) + raise AssertionError( + f'schemas did not match after introspection:\n{debug.dumps(diff)}' + ) diff --git a/edb/server/compiler/enums.py b/edb/server/compiler/enums.py index 0e1fb3ff4fc..15231ba8808 100644 --- a/edb/server/compiler/enums.py +++ b/edb/server/compiler/enums.py @@ -85,6 +85,11 @@ class InputFormat(strenum.StrEnum): JSON 
= 'JSON' +class InputLanguage(strenum.StrEnum): + EDGEQL = 'EDGEQL' + SQL = 'SQL' + + def cardinality_from_ir_value(card: ir.Cardinality) -> Cardinality: if card is ir.Cardinality.AT_MOST_ONE: return Cardinality.AT_MOST_ONE diff --git a/edb/server/compiler/rpc.pxd b/edb/server/compiler/rpc.pxd index afb0c01ec77..f8aca79144d 100644 --- a/edb/server/compiler/rpc.pxd +++ b/edb/server/compiler/rpc.pxd @@ -20,14 +20,17 @@ cimport cython cdef char serialize_output_format(val) cdef deserialize_output_format(char mode) +cdef char serialize_input_language(val) +cdef deserialize_input_language(char mode) @cython.final cdef class CompilationRequest: cdef: - object _serializer + object serializer readonly object source readonly object protocol_version + readonly object input_language readonly object output_format readonly object input_format readonly bint expect_one @@ -35,6 +38,8 @@ cdef class CompilationRequest: readonly bint inline_typeids readonly bint inline_typenames readonly bint inline_objectids + readonly str role_name + readonly str branch_name readonly object modaliases readonly object session_config diff --git a/edb/server/compiler/rpc.pyi b/edb/server/compiler/rpc.pyi index 47b6d2fe464..baada6b7d71 100644 --- a/edb/server/compiler/rpc.pyi +++ b/edb/server/compiler/rpc.pyi @@ -28,6 +28,7 @@ from edb.server.compiler import sertypes, enums class CompilationRequest: source: edgeql.Source protocol_version: defines.ProtocolVersion + input_language: enums.InputLanguage output_format: enums.OutputFormat input_format: enums.InputFormat expect_one: bool @@ -35,6 +36,8 @@ class CompilationRequest: inline_typeids: bool inline_typenames: bool inline_objectids: bool + role_name: str + branch_name: str modaliases: immutables.Map[str | None, str] | None session_config: immutables.Map[str, config.SettingValue] | None @@ -46,6 +49,7 @@ class CompilationRequest: protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: 
sertypes.CompilationConfigSerializer, + input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = enums.OutputFormat.BINARY, input_format: enums.InputFormat = enums.InputFormat.BINARY, expect_one: bool = False, @@ -57,6 +61,8 @@ class CompilationRequest: session_config: typing.Mapping[str, config.SettingValue] | None = None, database_config: typing.Mapping[str, config.SettingValue] | None = None, system_config: typing.Mapping[str, config.SettingValue] | None = None, + role_name: str = defines.EDGEDB_SUPERUSER, + branch_name: str = defines.EDGEDB_SUPERUSER_DB, ): ... diff --git a/edb/server/compiler/rpc.pyx b/edb/server/compiler/rpc.pyx index d510e131774..856fde06462 100644 --- a/edb/server/compiler/rpc.pyx +++ b/edb/server/compiler/rpc.pyx @@ -32,6 +32,7 @@ from edb.edgeql import qltypes from edb.edgeql import tokenizer from edb.server import config, defines from edb.server.pgproto.pgproto cimport WriteBuffer, ReadBuffer +from edb.pgsql import parser as pgparser from . 
import enums, sertypes @@ -43,6 +44,9 @@ cdef object OUT_FMT_NONE = enums.OutputFormat.NONE cdef object IN_FMT_BINARY = enums.InputFormat.BINARY cdef object IN_FMT_JSON = enums.InputFormat.JSON +cdef object IN_LANG_EDGEQL = enums.InputLanguage.EDGEQL +cdef object IN_LANG_SQL = enums.InputLanguage.SQL + cdef char MASK_JSON_PARAMETERS = 1 << 0 cdef char MASK_EXPECT_ONE = 1 << 1 cdef char MASK_INLINE_TYPEIDS = 1 << 2 @@ -74,7 +78,26 @@ cdef deserialize_output_format(char mode): return OUT_FMT_NONE else: raise errors.BinaryProtocolError( - f'unknown output mode "{repr(mode)[2:-1]}"') + f'unknown output format {mode.to_bytes(1, "big")!r}') + + +cdef char serialize_input_language(val): + if val is IN_LANG_EDGEQL: + return b'E' + elif val is IN_LANG_SQL: + return b'S' + else: + raise AssertionError("unreachable") + + +cdef deserialize_input_language(char lang): + if lang == b'E': + return IN_LANG_EDGEQL + elif lang == b'S': + return IN_LANG_SQL + else: + raise errors.BinaryProtocolError( + f'unknown input language {lang.to_bytes(1, "big")!r}') @cython.final @@ -86,6 +109,7 @@ cdef class CompilationRequest: protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: sertypes.CompilationConfigSerializer, + input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = OUT_FMT_BINARY, input_format: enums.InputFormat = IN_FMT_BINARY, expect_one: bint = False, @@ -97,10 +121,13 @@ cdef class CompilationRequest: session_config: Mapping[str, config.SettingValue] | None = None, database_config: Mapping[str, config.SettingValue] | None = None, system_config: Mapping[str, config.SettingValue] | None = None, + role_name: str = defines.EDGEDB_SUPERUSER, + branch_name: str = defines.EDGEDB_SUPERUSER_DB, ): - self._serializer = compilation_config_serializer + self.serializer = compilation_config_serializer self.source = source self.protocol_version = protocol_version + self.input_language = input_language 
self.output_format = output_format self.input_format = input_format self.expect_one = expect_one @@ -113,6 +140,8 @@ cdef class CompilationRequest: self.session_config = session_config self.database_config = database_config self.system_config = system_config + self.role_name = role_name + self.branch_name = branch_name self.serialized_cache = None self.cache_key = None @@ -124,7 +153,8 @@ cdef class CompilationRequest: source=self.source, protocol_version=self.protocol_version, schema_version=self.schema_version, - compilation_config_serializer=self._serializer, + compilation_config_serializer=self.serializer, + input_language=self.input_language, output_format=self.output_format, input_format=self.input_format, expect_one=self.expect_one, @@ -136,6 +166,8 @@ cdef class CompilationRequest: session_config=self.session_config, database_config=self.database_config, system_config=self.system_config, + role_name=self.role_name, + branch_name=self.branch_name, ) rv.serialized_cache = self.serialized_cache rv.cache_key = self.cache_key @@ -178,15 +210,8 @@ cdef class CompilationRequest: query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ) -> CompilationRequest: - buf = ReadBuffer.new_message_parser(data) - - if data[0] == 0: - return _deserialize_comp_req_v0( - buf, query_text, compilation_config_serializer) - else: - raise errors.UnsupportedProtocolVersionError( - f"unsupported compile cache: version {data[0]}" - ) + return _deserialize_comp_req( + data, query_text, compilation_config_serializer) def serialize(self) -> bytes: if self.serialized_cache is None: @@ -199,8 +224,12 @@ cdef class CompilationRequest: return self.cache_key cdef _serialize(self): - cache_key, buf = _serialize_comp_req_v0(self) - self.cache_key = cache_key + cdef WriteBuffer buf + + hash_obj, buf = _serialize_comp_req(self) + cache_key = hash_obj.digest() + buf.write_bytes(cache_key) + self.cache_key = uuidgen.from_bytes(cache_key) self.serialized_cache = 
bytes(buf) def __hash__(self): @@ -210,17 +239,44 @@ cdef class CompilationRequest: return ( self.source.cache_key() == other.source.cache_key() and self.protocol_version == other.protocol_version and + self.input_language == other.input_language and self.output_format == other.output_format and self.input_format == other.input_format and self.expect_one == other.expect_one and self.implicit_limit == other.implicit_limit and self.inline_typeids == other.inline_typeids and self.inline_typenames == other.inline_typenames and - self.inline_objectids == other.inline_objectids + self.inline_objectids == other.inline_objectids and + self.role_name == other.role_name and + self.branch_name == other.branch_name + ) + + +cdef CompilationRequest _deserialize_comp_req( + data: bytes, + query_text: str, + compilation_config_serializer: sertypes.CompilationConfigSerializer, +): + cdef: + ReadBuffer buf = ReadBuffer.new_message_parser(data) + CompilationRequest req + + if data[0] == 1: + req = _deserialize_comp_req_v1( + buf, query_text, compilation_config_serializer) + else: + raise errors.UnsupportedProtocolVersionError( + f"unsupported compile cache: version {data[0]}" ) + # Cache key is always trailing regardless of the version. 
+ req.cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + req.serialized_cache = data + + return req -cdef _deserialize_comp_req_v0( + +cdef _deserialize_comp_req_v1( buf: ReadBuffer, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, @@ -249,21 +305,14 @@ cdef _deserialize_comp_req_v0( # * Session config: int32-length-prefixed serialized data # * Serialized Source or NormalizedSource without the original query # string - # * 16-byte cache key = BLAKE-2b hash of: - # * All above serialized, - # * Except that the source is replaced with Source.cache_key(), and - # * Except that the serialized session config is replaced by - # serialized combined config (session -> database -> system) - # that only affects compilation. - # * The schema version - # * OPTIONALLY, the schema version. We wanted to bump the protocol - # version to include this, but 5.x hard crashes when it reads a - # persistent cache with entries it doesn't understand, so instead - # we stick it on the end where it will be ignored by old versions. + # * The schema version ID. 
+ # * 1 byte input language (the same as in the binary protocol) + # * role_name as a UTF-8 encoded string + # * branch_name as a UTF-8 encoded string cdef char flags - assert buf.read_byte() == 0 # version + assert buf.read_byte() == 1 # version flags = buf.read_byte() if flags & MASK_JSON_PARAMETERS > 0: @@ -318,18 +367,29 @@ cdef _deserialize_comp_req_v0( else: session_config = None - source = tokenizer.deserialize( - buf.read_len_prefixed_bytes(), query_text - ) - - cache_key = uuidgen.from_bytes(buf.read_bytes(16)) + serialized_source = buf.read_len_prefixed_bytes() schema_version = uuidgen.from_bytes(buf.read_bytes(16)) + input_language = deserialize_input_language(buf.read_byte()) + role_name = buf.read_len_prefixed_utf8() + branch_name = buf.read_len_prefixed_utf8() + + if input_language is enums.InputLanguage.EDGEQL: + source = tokenizer.deserialize(serialized_source, query_text) + elif input_language is enums.InputLanguage.SQL: + source = pgparser.deserialize(serialized_source) + else: + raise AssertionError( + f"unexpected source language in serialized " + f"CompilationRequest: {input_language}" + ) + req = CompilationRequest( source=source, protocol_version=protocol_version, schema_version=schema_version, compilation_config_serializer=serializer, + input_language=input_language, output_format=output_format, input_format=input_format, expect_one=expect_one, @@ -339,19 +399,18 @@ cdef _deserialize_comp_req_v0( inline_objectids=inline_objectids, modaliases=modaliases, session_config=session_config, + role_name=role_name, + branch_name=branch_name, ) - req.serialized_cache = data - req.cache_key = cache_key - return req -cdef _serialize_comp_req_v0(req: CompilationRequest): - # Please see _deserialize_v0 for the format doc +cdef _serialize_comp_req(req: CompilationRequest): + # Please see _deserialize_comp_req_v1 for the format doc cdef: - char version = 0, flags + char version = 1, flags WriteBuffer out = WriteBuffer.new() out.write_byte(version) @@ -385,7 
+444,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): out.write_str(k, "utf-8") out.write_str(v, "utf-8") - type_id, desc = req._serializer.describe() + type_id, desc = req.serializer.describe() out.write_bytes(type_id.bytes) out.write_len_prefixed_bytes(desc) @@ -395,7 +454,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): if req.session_config is None: session_config = b"" else: - session_config = req._serializer.encode_configs( + session_config = req.serializer.encode_configs( req.session_config ) out.write_len_prefixed_bytes(session_config) @@ -403,7 +462,7 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): # Build config that affects compilation: session -> database -> system. # This is only used for calculating cache_key, while session # config itreq is separately stored above in the serialized format. - serialized_comp_config = req._serializer.encode_configs( + serialized_comp_config = req.serializer.encode_configs( req.system_config, req.database_config, req.session_config ) hash_obj.update(serialized_comp_config) @@ -412,11 +471,18 @@ cdef _serialize_comp_req_v0(req: CompilationRequest): assert req.schema_version is not None hash_obj.update(req.schema_version.bytes) - cache_key_bytes = hash_obj.digest() - cache_key = uuidgen.from_bytes(cache_key_bytes) - out.write_len_prefixed_bytes(req.source.serialize()) - out.write_bytes(cache_key_bytes) out.write_bytes(req.schema_version.bytes) - return cache_key, out + out.write_byte(serialize_input_language(req.input_language)) + hash_obj.update(req.input_language.value.encode("utf-8")) + + role_name = req.role_name.encode("utf-8") + out.write_len_prefixed_bytes(role_name) + hash_obj.update(role_name) + + branch_name = req.branch_name.encode("utf-8") + out.write_len_prefixed_bytes(branch_name) + hash_obj.update(branch_name) + + return hash_obj, out diff --git a/edb/server/compiler/sertypes.py b/edb/server/compiler/sertypes.py index 0e562c235ff..fd440218730 100644 --- 
a/edb/server/compiler/sertypes.py +++ b/edb/server/compiler/sertypes.py @@ -374,7 +374,7 @@ def _describe_tuple(t: s_types.Tuple, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -420,7 +420,7 @@ def _describe_array(t: s_types.Array, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -459,7 +459,7 @@ def _describe_range(t: s_types.Range, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) @@ -494,7 +494,7 @@ def _describe_multirange(t: s_types.MultiRange, *, ctx: Context) -> uuid.UUID: # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined - buf.append(_bool_packer(True)) + buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index 445abd21356..1f0ce8a52e3 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -18,23 +18,28 @@ from __future__ import annotations -from typing import Tuple, Mapping, Sequence, List, TYPE_CHECKING, Optional +from typing import Mapping, Sequence, List, TYPE_CHECKING, Optional import dataclasses import functools import hashlib import immutables +import json from edb import errors +from edb.common import uuidgen +from edb.server import defines from edb.schema import schema as s_schema from edb.pgsql import ast as pgast from edb.pgsql import common as 
pg_common from edb.pgsql import codegen as pg_codegen +from edb.pgsql import params as pg_params from edb.pgsql import parser as pg_parser from . import dbstate +from . import enums if TYPE_CHECKING: from edb.pgsql import resolver as pg_resolver @@ -62,6 +67,11 @@ def compile_sql( current_user: str, allow_user_specified_id: Optional[bool], apply_access_policies_sql: Optional[bool], + include_edgeql_io_format_alternative: bool = False, + allow_prepared_statements: bool = True, + disambiguate_column_names: bool, + backend_runtime_params: pg_params.BackendRuntimeParams, + protocol_version: defines.ProtocolVersion, ) -> List[dbstate.SQLQueryUnit]: opts = ResolverOptionsPartial( query_str=query_str, @@ -69,6 +79,10 @@ def compile_sql( current_user=current_user, allow_user_specified_id=allow_user_specified_id, apply_access_policies_sql=apply_access_policies_sql, + include_edgeql_io_format_alternative=( + include_edgeql_io_format_alternative + ), + disambiguate_column_names=disambiguate_column_names, ) stmts = pg_parser.parse(query_str, propagate_spans=True) @@ -76,6 +90,7 @@ def compile_sql( for stmt in stmts: orig_text = pg_codegen.generate_source(stmt) fe_settings = tx_state.current_fe_settings() + track_stats = False unit = dbstate.SQLQueryUnit( orig_query=orig_text, @@ -118,6 +133,8 @@ def compile_sql( unit.set_vars = {stmt.name: value} unit.is_local = stmt.scope == pgast.OptionsScope.TRANSACTION + if not unit.is_local: + unit.capabilities |= enums.Capability.SESSION_CONFIG elif isinstance(stmt, pgast.VariableShowStmt): unit.get_var = stmt.name @@ -143,33 +160,46 @@ def compile_sql( elif isinstance(stmt, (pgast.BeginStmt, pgast.StartStmt)): unit.tx_action = dbstate.TxAction.START + unit.command_complete_tag = dbstate.TagPlain( + tag=b"START TRANSACTION" + ) elif isinstance(stmt, pgast.CommitStmt): unit.tx_action = dbstate.TxAction.COMMIT unit.tx_chain = stmt.chain or False + unit.command_complete_tag = dbstate.TagPlain(tag=b"COMMIT") elif isinstance(stmt, 
pgast.RollbackStmt): unit.tx_action = dbstate.TxAction.ROLLBACK unit.tx_chain = stmt.chain or False + unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.SavepointStmt): unit.tx_action = dbstate.TxAction.DECLARE_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"SAVEPOINT") elif isinstance(stmt, pgast.ReleaseStmt): unit.tx_action = dbstate.TxAction.RELEASE_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"RELEASE") elif isinstance(stmt, pgast.RollbackToStmt): unit.tx_action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT unit.sp_name = stmt.savepoint_name + unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.TwoPhaseTransactionStmt): raise NotImplementedError( "two-phase transactions are not supported" ) elif isinstance(stmt, pgast.PrepareStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) + # Translate the underlying query. 
- stmt_resolved, stmt_source = resolve_query( + stmt_resolved, stmt_source, _ = resolve_query( stmt.query, schema, tx_state, opts ) if stmt.argtypes: @@ -200,8 +230,14 @@ def compile_sql( translation_data=stmt_source.translation_data, ) unit.command_complete_tag = dbstate.TagPlain(tag=b"PREPARE") + track_stats = True elif isinstance(stmt, pgast.ExecuteStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) + orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: @@ -216,7 +252,14 @@ def compile_sql( stmt_name=orig_name, be_stmt_name=mangled_name.encode("utf-8"), ) + unit.cardinality = enums.Cardinality.MANY + track_stats = True + elif isinstance(stmt, pgast.DeallocateStmt): + if not allow_prepared_statements: + raise errors.UnsupportedFeatureError( + "SQL prepared statements are not supported" + ) orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: @@ -238,19 +281,80 @@ def compile_sql( raise NotImplementedError("exclusive lock is not supported") # just ignore unit.query = "DO $$ BEGIN END $$;" - else: - assert isinstance(stmt, (pgast.Query, pgast.CopyStmt)) - stmt_resolved, stmt_source = resolve_query( + elif isinstance(stmt, (pgast.Query, pgast.CopyStmt)): + stmt_resolved, stmt_source, edgeql_fmt_src = resolve_query( stmt, schema, tx_state, opts ) - unit.query = stmt_source.text unit.translation_data = stmt_source.translation_data + if edgeql_fmt_src is not None: + unit.eql_format_query = edgeql_fmt_src.text + unit.eql_format_translation_data = ( + edgeql_fmt_src.translation_data + ) unit.command_complete_tag = stmt_resolved.command_complete_tag unit.params = stmt_resolved.params + if isinstance(stmt, pgast.DMLQuery) and not stmt.returning_list: + unit.cardinality = enums.Cardinality.NO_RESULT + else: + unit.cardinality = enums.Cardinality.MANY + track_stats = True + else: + from edb.pgsql import resolver as 
pg_resolver + pg_resolver.dispatch._raise_unsupported(stmt) unit.stmt_name = compute_stmt_name(unit.query, tx_state).encode("utf-8") + if track_stats and backend_runtime_params.has_stat_statements: + cconfig: dict[str, dbstate.SQLSetting] = { + k: v for k, v in fe_settings.items() + if k is not None and v is not None and k in FE_SETTINGS_MUTABLE + } + cconfig.pop('server_version', None) + cconfig.pop('server_version_num', None) + if allow_user_specified_id is not None: + cconfig.setdefault( + 'allow_user_specified_id', + ('true' if allow_user_specified_id else 'false',), + ) + if apply_access_policies_sql is not None: + cconfig.setdefault( + 'apply_access_policies_sql', + ('true' if apply_access_policies_sql else 'false',), + ) + search_path = parse_search_path(cconfig.pop("search_path", ("",))) + cconfig = dict(sorted((k, v) for k, v in cconfig.items())) + extras = { + 'cc': cconfig, # compilation_config + 'pv': protocol_version, # protocol_version + 'dn': ', '.join(search_path), # default_namespace + } + sql_info = { + 'query': orig_text, + 'type': defines.QueryType.SQL, + 'extras': json.dumps(extras), + } + id_hash = hashlib.blake2b(digest_size=16) + id_hash.update( + json.dumps(sql_info).encode(defines.EDGEDB_ENCODING) + ) + sql_info['id'] = str(uuidgen.from_bytes(id_hash.digest())) + prefix = ''.join([ + '-- ', + json.dumps(sql_info), + '\n', + ]) + unit.prefix_len = len(prefix) + unit.query = prefix + unit.query + if unit.eql_format_query is not None: + unit.eql_format_query = prefix + unit.eql_format_query + + if isinstance(stmt, pgast.DMLQuery): + unit.capabilities |= enums.Capability.MODIFICATIONS + + if unit.tx_action is not None: + unit.capabilities |= enums.Capability.TRANSACTION + tx_state.apply(unit) sql_units.append(unit) @@ -274,6 +378,8 @@ class ResolverOptionsPartial: query_str: str allow_user_specified_id: Optional[bool] apply_access_policies_sql: Optional[bool] + include_edgeql_io_format_alternative: Optional[bool] + disambiguate_column_names: 
bool def resolve_query( @@ -281,7 +387,11 @@ def resolve_query( schema: s_schema.Schema, tx_state: dbstate.SQLTransactionState, opts: ResolverOptionsPartial, -) -> Tuple[pg_resolver.ResolvedSQL, pg_codegen.SQLSource]: +) -> tuple[ + pg_resolver.ResolvedSQL, + pg_codegen.SQLSource, + Optional[pg_codegen.SQLSource], +]: from edb.pgsql import resolver as pg_resolver search_path: Sequence[str] = ("public",) @@ -314,10 +424,21 @@ def resolve_query( search_path=search_path, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, + include_edgeql_io_format_alternative=( + opts.include_edgeql_io_format_alternative + ), + disambiguate_column_names=opts.disambiguate_column_names, ) resolved = pg_resolver.resolve(stmt, schema, options) source = pg_codegen.generate(resolved.ast, with_translation_data=True) - return resolved, source + if resolved.edgeql_output_format_ast is not None: + edgeql_format_source = pg_codegen.generate( + resolved.edgeql_output_format_ast, + with_translation_data=True, + ) + else: + edgeql_format_source = None + return resolved, source, edgeql_format_source def lookup_bool_setting( diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py index 95c7f290c25..97820b3fe1e 100644 --- a/edb/server/compiler_pool/pool.py +++ b/edb/server/compiler_pool/pool.py @@ -603,6 +603,10 @@ async def analyze_explain_output(self, *args, **kwargs): return await self._simple_call( 'analyze_explain_output', *args, **kwargs) + async def validate_schema_equivalence(self, *args, **kwargs): + return await self._simple_call( + 'validate_schema_equivalence', *args, **kwargs) + def get_debug_info(self): return {} diff --git a/edb/server/config/__init__.py b/edb/server/config/__init__.py index af629499294..6ddd9fdd2a4 100644 --- a/edb/server/config/__init__.py +++ b/edb/server/config/__init__.py @@ -26,7 +26,9 @@ from edb.edgeql.qltypes import ConfigScope from .ops import OpCode, Operation, SettingValue -from .ops import 
spec_to_json, to_json, from_json, set_value, to_edgeql +from .ops import ( + spec_to_json, to_json_obj, to_json, from_json, set_value, to_edgeql +) from .ops import value_from_json, value_to_json_value, coerce_single_value from .spec import ( Spec, FlatSpec, ChainedSpec, Setting, @@ -40,8 +42,8 @@ __all__ = ( 'lookup', 'Spec', 'FlatSpec', 'ChainedSpec', 'Setting', 'SettingValue', - 'spec_to_json', 'to_json', 'to_edgeql', 'from_json', 'set_value', - 'value_from_json', 'value_to_json_value', + 'spec_to_json', 'to_json_obj', 'to_json', 'to_edgeql', 'from_json', + 'set_value', 'value_from_json', 'value_to_json_value', 'ConfigScope', 'OpCode', 'Operation', 'ConfigType', 'CompositeConfigType', 'load_spec_from_schema', 'load_ext_spec_from_schema', diff --git a/edb/server/config/ops.py b/edb/server/config/ops.py index 5244acba9c9..57194a5b52b 100644 --- a/edb/server/config/ops.py +++ b/edb/server/config/ops.py @@ -24,6 +24,7 @@ from typing import ( Any, Callable, + Dict, Optional, TypeVar, Union, @@ -429,17 +430,17 @@ def value_to_edgeql_const( return qlcodegen.generate_source(ql) -def to_json( +def to_json_obj( spec: spec.Spec, storage: Mapping[str, SettingValue], *, setting_filter: Optional[Callable[[SettingValue], bool]] = None, include_source: bool = True, -) -> str: +) -> Dict[str, Any]: dct = {} for name, value in storage.items(): - setting = spec[name] if setting_filter is None or setting_filter(value): + setting = spec[name] val = value_to_json_value(setting, value.value) if include_source: dct[name] = { @@ -450,6 +451,22 @@ def to_json( } else: dct[name] = val + return dct + + +def to_json( + spec: spec.Spec, + storage: Mapping[str, SettingValue], + *, + setting_filter: Optional[Callable[[SettingValue], bool]] = None, + include_source: bool = True, +) -> str: + dct = to_json_obj( + spec, + storage, + setting_filter=setting_filter, + include_source=include_source, + ) return json.dumps(dct) diff --git a/edb/server/conn_pool/src/python.rs 
b/edb/server/conn_pool/src/python.rs index 4cdd6f40e89..2d05b73e25f 100644 --- a/edb/server/conn_pool/src/python.rs +++ b/edb/server/conn_pool/src/python.rs @@ -14,6 +14,7 @@ use std::{ os::fd::IntoRawFd, pin::Pin, rc::Rc, + sync::Mutex, thread, time::{Duration, Instant}, }; @@ -37,21 +38,22 @@ enum RustToPythonMessage { Metrics(Vec), } -impl ToPyObject for RustToPythonMessage { - fn to_object(&self, py: Python<'_>) -> PyObject { +impl RustToPythonMessage { + fn to_object(&self, py: Python<'_>) -> PyResult { use RustToPythonMessage::*; match self { - Acquired(a, b) => (0, a, b.0).to_object(py), - PerformConnect(conn, s) => (1, conn.0, s).to_object(py), - PerformDisconnect(conn) => (2, conn.0).to_object(py), - PerformReconnect(conn, s) => (3, conn.0, s).to_object(py), - Pruned(conn) => (4, conn).to_object(py), - Failed(conn, error) => (5, conn, error.0).to_object(py), + Acquired(a, b) => (0, a, b.0).into_pyobject(py), + PerformConnect(conn, s) => (1, conn.0, s).into_pyobject(py), + PerformDisconnect(conn) => (2, conn.0).into_pyobject(py), + PerformReconnect(conn, s) => (3, conn.0, s).into_pyobject(py), + Pruned(conn) => (4, conn).into_pyobject(py), + Failed(conn, error) => (5, conn, error.0).into_pyobject(py), Metrics(metrics) => { // This is not really fast but it should not be happening very often - (6, PyByteArray::new_bound(py, &metrics)).to_object(py) + (6, PyByteArray::new(py, &metrics)).into_pyobject(py) } } + .map(|e| e.into()) } } @@ -175,7 +177,7 @@ impl Connector for Rc { #[pyclass] struct ConnPool { python_to_rust: tokio::sync::mpsc::UnboundedSender, - rust_to_python: std::sync::mpsc::Receiver, + rust_to_python: Mutex>, notify_fd: u64, } @@ -328,7 +330,7 @@ impl ConnPool { let notify_fd = rxfd.recv().unwrap(); ConnPool { python_to_rust: txpr, - rust_to_python: rxrp, + rust_to_python: Mutex::new(rxrp), notify_fd, } } @@ -374,19 +376,35 @@ impl ConnPool { .map_err(|_| internal_error("In shutdown")) } - fn _read(&self, py: Python<'_>) -> Py { - let Ok(msg) 
= self.rust_to_python.recv() else { - return py.None(); + fn _read(&self, py: Python<'_>) -> PyResult> { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .try_recv() + else { + return Ok(py.None()); }; msg.to_object(py) } - fn _try_read(&self, py: Python<'_>) -> Py { - let Ok(msg) = self.rust_to_python.try_recv() else { - return py.None(); + fn _try_read(&self, py: Python<'_>) -> PyResult> { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .try_recv() + else { + return Ok(py.None()); }; msg.to_object(py) } + + fn _close_pipe(&mut self) { + // Replace the channel with a dummy, closed one which will also + // signal the other side to exit. + self.rust_to_python = Mutex::new(std::sync::mpsc::channel().1); + } } /// Ensure that logging does not outlive the Python runtime. @@ -400,13 +418,12 @@ struct LoggingGuard { impl LoggingGuard { #[new] fn init_logging(py: Python) -> PyResult { - let logging = py.import_bound("logging")?; + let logging = py.import("logging")?; let logger = logging.getattr("getLogger")?.call(("edb.server",), None)?; let level = logger .getattr("getEffectiveLevel")? .call((), None)? 
.extract::()?; - let logger = logger.to_object(py); struct PythonSubscriber { logger: Py, @@ -469,7 +486,9 @@ impl LoggingGuard { tracing_subscriber::filter::LevelFilter::OFF }; - let subscriber = PythonSubscriber { logger }; + let subscriber = PythonSubscriber { + logger: logger.into(), + }; let guard = tracing_subscriber::registry() .with(level) .with(subscriber) @@ -484,7 +503,7 @@ impl LoggingGuard { fn _conn_pool(py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add("InternalError", py.get_type_bound::())?; + m.add("InternalError", py.get_type::())?; // Add each metric variant as a constant for variant in MetricVariant::iter() { diff --git a/edb/server/connpool/pool2.py b/edb/server/connpool/pool2.py index bf965b85e9a..f5b12bdfad5 100644 --- a/edb/server/connpool/pool2.py +++ b/edb/server/connpool/pool2.py @@ -20,11 +20,11 @@ import time import typing import dataclasses -import os import pickle from . import config from .config import logger +from edb.server import rust_async_channel guard = edb.server._conn_pool.LoggingGuard() @@ -101,27 +101,31 @@ class Pool(typing.Generic[C]): _errors: dict[int, BaseException] _conns_held: dict[C, int] _loop: asyncio.AbstractEventLoop - _skip_reads: int _counts: typing.Any _stats_collector: typing.Optional[StatsCollector] - def __init__(self, *, connect: Connector[C], - disconnect: Disconnector[C], - max_capacity: int, - stats_collector: typing.Optional[StatsCollector]=None, - min_idle_time_before_gc: float = config.MIN_IDLE_TIME_BEFORE_GC - ) -> None: + def __init__( + self, + *, + connect: Connector[C], + disconnect: Disconnector[C], + max_capacity: int, + stats_collector: typing.Optional[StatsCollector] = None, + min_idle_time_before_gc: float = config.MIN_IDLE_TIME_BEFORE_GC, + ) -> None: # Re-load the logger if it's been mocked for testing global logger logger = config.logger - logger.info(f'Creating a connection pool with \ - max_capacity={max_capacity}') + logger.info( + f'Creating a 
connection pool with \ + max_capacity={max_capacity}' + ) self._connect = connect self._disconnect = disconnect - self._pool = edb.server._conn_pool.ConnPool(max_capacity, - min_idle_time_before_gc, - config.STATS_COLLECT_INTERVAL) + self._pool = edb.server._conn_pool.ConnPool( + max_capacity, min_idle_time_before_gc, config.STATS_COLLECT_INTERVAL + ) self._max_capacity = max_capacity self._cur_capacity = 0 self._next_conn_id = 0 @@ -130,10 +134,14 @@ def __init__(self, *, connect: Connector[C], self._errors = {} self._conns_held = {} self._prunes = {} - self._skip_reads = 0 self._loop = asyncio.get_running_loop() - self._task = self._loop.create_task(self._boot(self._loop)) + self._channel = rust_async_channel.RustAsyncChannel( + self._pool, + self._process_message, + ) + + self._task = self._loop.create_task(self._boot(self._channel)) self._failed_connects = 0 self._failed_disconnects = 0 @@ -170,34 +178,19 @@ async def close(self) -> None: self._pool = None logger.info("Closed connection pool") - async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: + async def _boot( + self, + channel: rust_async_channel.RustAsyncChannel, + ) -> None: logger.info("Python-side connection pool booted") - reader = asyncio.StreamReader(loop=loop) - reader_protocol = asyncio.StreamReaderProtocol(reader) - fd = os.fdopen(self._pool._fd, 'rb') - transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) try: - while len(await reader.read(1)) == 1: - if not self._pool or not self._task: - break - if self._skip_reads > 0: - self._skip_reads -= 1 - continue - msg = self._pool._read() - if not msg: - break - self._process_message(msg) - + await channel.run() finally: - transport.close() + channel.close() - # Allow readers to skip the self-pipe for performing reads which may reduce - # latency a small degree. We'll still need to eventually pick up a self-pipe - # read but we increment a counter to skip at that point. 
def _try_read(self) -> None: - while msg := self._pool._try_read(): - self._skip_reads += 1 - self._process_message(msg) + if self._channel: + self._channel.read_hint() def _process_message(self, msg: typing.Any) -> None: # If we're closing, don't dispatch any operations @@ -228,7 +221,9 @@ def _process_message(self, msg: typing.Any) -> None: # Pickled metrics self._counts = pickle.loads(msg[1]) if self._stats_collector: - self._stats_collector(self._build_snapshot(now=time.monotonic())) + self._stats_collector( + self._build_snapshot(now=time.monotonic()) + ) else: logger.critical(f'Unexpected message: {msg}') @@ -310,12 +305,16 @@ async def acquire(self, dbname: str) -> C: # Allow the final exception to escape if i == config.CONNECT_FAILURE_RETRIES: - logger.exception('Failed to acquire connection, will not ' - f'retry {dbname} ({self._cur_capacity}' - 'active)') + logger.exception( + 'Failed to acquire connection, will not ' + f'retry {dbname} ({self._cur_capacity}' + 'active)' + ) raise - logger.exception('Failed to acquire connection, will retry: ' - f'{dbname} ({self._cur_capacity} active)') + logger.exception( + 'Failed to acquire connection, will retry: ' + f'{dbname} ({self._cur_capacity} active)' + ) raise AssertionError("Unreachable end of loop") def release(self, dbname: str, conn: C, discard: bool = False) -> None: @@ -363,10 +362,10 @@ def _build_snapshot(self, *, now: float) -> Snapshot: dbname=dbname, nconns=v[edb.server._conn_pool.METRIC_ACTIVE], nwaiters_avg=v[edb.server._conn_pool.METRIC_WAITING], - npending=v[edb.server._conn_pool.METRIC_CONNECTING] + - v[edb.server._conn_pool.METRIC_RECONNECTING], + npending=v[edb.server._conn_pool.METRIC_CONNECTING] + + v[edb.server._conn_pool.METRIC_RECONNECTING], nwaiters=v[edb.server._conn_pool.METRIC_WAITING], - quota=stats['target'] + quota=stats['target'], ) blocks.append(block_snapshot) pass @@ -376,7 +375,6 @@ def _build_snapshot(self, *, now: float) -> Snapshot: blocks=blocks, 
capacity=self._cur_capacity, log=[], - failed_connects=self._failed_connects, failed_disconnects=self._failed_disconnects, successful_connects=self._successful_connects, diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 45598caefb7..4e3342912cb 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -40,6 +40,8 @@ cdef class CompiledQuery: cdef public object first_extra # Optional[int] cdef public object extra_counts cdef public object extra_blobs + cdef public bint extra_formatted_as_text + cdef public object extra_type_oids cdef public object request cdef public object recompiled_cache cdef public bint use_pending_func_cache @@ -90,6 +92,7 @@ cdef class Database: readonly bytes user_schema_pickle readonly object reflection_cache readonly object backend_ids + readonly object backend_id_to_name readonly object extensions readonly object _feature_used_metrics diff --git a/edb/server/dbview/dbview.pyi b/edb/server/dbview/dbview.pyi index 6a0d59fefd4..710eb09db3f 100644 --- a/edb/server/dbview/dbview.pyi +++ b/edb/server/dbview/dbview.pyi @@ -182,7 +182,7 @@ class DatabaseIndex: schema_version: Optional[uuid.UUID], db_config: Optional[Config], reflection_cache: Optional[Mapping[str, tuple[str, ...]]], - backend_ids: Optional[Mapping[str, int]], + backend_ids: Optional[Mapping[str, tuple[int, str]]], extensions: Optional[set[str]], ext_config_settings: Optional[list[config.Setting]], early: bool = False, diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index f91f34763b1..0a7d8ea620e 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -40,7 +40,8 @@ from edb.common import debug, lru, uuidgen, asyncutil from edb import edgeql from edb.edgeql import qltypes from edb.schema import schema as s_schema -from edb.server import compiler, defines, config, metrics +from edb.schema import name as s_name +from edb.server import compiler, defines, config, metrics, pgcon from 
edb.server.compiler import dbstate, enums, sertypes from edb.server.protocol import execute from edb.pgsql import dbops @@ -96,6 +97,8 @@ cdef class CompiledQuery: first_extra: Optional[int]=None, extra_counts=(), extra_blobs=(), + extra_formatted_as_text: bool = False, + extra_type_oids: Sequence[int] = (), request=None, recompiled_cache=None, use_pending_func_cache=False, @@ -104,6 +107,8 @@ cdef class CompiledQuery: self.first_extra = first_extra self.extra_counts = extra_counts self.extra_blobs = extra_blobs + self.extra_formatted_as_text = extra_formatted_as_text + self.extra_type_oids = tuple(extra_type_oids) self.request = request self.recompiled_cache = recompiled_cache self.use_pending_func_cache = use_pending_func_cache @@ -160,6 +165,12 @@ cdef class Database: self.user_config_spec = config.FlatSpec(*ext_config_settings) self.reflection_cache = reflection_cache self.backend_ids = backend_ids + if backend_ids is not None: + self.backend_id_to_name = { + v[0]: v[1] for k, v in backend_ids.items() + } + else: + self.backend_id_to_name = {} self.extensions = set() self._set_extensions(extensions) self._observe_auth_ext_config() @@ -367,6 +378,9 @@ cdef class Database: if backend_ids is not None: self.backend_ids = backend_ids + self.backend_id_to_name = { + v[0]: v[1] for k, v in backend_ids.items() + } if reflection_cache is not None: self.reflection_cache = reflection_cache if db_config is not None: @@ -402,6 +416,9 @@ cdef class Database: cdef _update_backend_ids(self, new_types): self.backend_ids.update(new_types) + self.backend_id_to_name.update({ + v[0]: v[1] for k, v in new_types.items() + }) cdef _invalidate_caches(self): self._sql_to_compiled.clear() @@ -727,15 +744,18 @@ cdef class DatabaseConnectionView: if self._in_tx: try: - return int(self._in_tx_new_types[type_id]) + tinfo = self._in_tx_new_types[type_id] except KeyError: pass + else: + return int(tinfo[0]) - tid = self._db.backend_ids.get(type_id) - if tid is None: + tinfo = 
self._db.backend_ids.get(type_id) + if tinfo is None: raise RuntimeError( f'cannot resolve backend OID for type {type_id}') - return tid + + return int(tinfo[0]) cdef bytes serialize_state(self): cdef list state @@ -1220,9 +1240,10 @@ cdef class DatabaseConnectionView: async def parse( self, query_req: rpc.CompilationRequest, - cached_globally=False, - bint use_metrics=True, - uint64_t allow_capabilities = compiler.Capability.ALL, + cached_globally: bint = False, + use_metrics: bint = True, + allow_capabilities: uint64_t = compiler.Capability.ALL, + pgcon: pgcon.PGConnection | None = None, ) -> CompiledQuery: query_unit_group = None if self._query_cache_enabled: @@ -1307,6 +1328,18 @@ cdef class DatabaseConnectionView: ) self._check_in_tx_error(query_unit_group) + if query_req.input_language is enums.InputLanguage.SQL: + if pgcon is None: + raise errors.InternalServerError( + "a valid backend connection is required to fully " + "compile a query in SQL mode", + ) + await self._amend_typedesc_in_sql( + query_req, + query_unit_group, + pgcon, + ) + if self._query_cache_enabled and query_unit_group.cacheable: if cached_globally: self.server.system_compile_cache[query_req] = ( @@ -1366,10 +1399,112 @@ cdef class DatabaseConnectionView: first_extra=source.first_extra(), extra_counts=source.extra_counts(), extra_blobs=source.extra_blobs(), + extra_formatted_as_text=source.extra_formatted_as_text(), + extra_type_oids=source.extra_type_oids(), request=query_req, recompiled_cache=recompiled_cache, ) + async def _amend_typedesc_in_sql( + self, + query_req: rpc.CompilationRequest, + qug: dbstate.QueryUnitGroup, + pgcon: pgcon.PGConnection, + ) -> None: + # The SQL QueryUnitGroup as initially returned from the compiler + # is missing the input/output type descriptors because we currently + # don't run static SQL type inference. 
To mend that we ask Postgres + # to infer the the result types (as an OID tuple) and then use + # our OID -> scalar type mapping to construct an EdgeQL free shape with + # corresponding properties which we then send to the compiler to + # compute the type descriptors. + to_describe = [] + + desc_map = {} + source = query_req.source + first_extra = source.first_extra() + if first_extra is not None: + all_type_oids = [0] * first_extra + source.extra_type_oids() + else: + all_type_oids = [] + + for i, query_unit in enumerate(qug): + if query_unit.cardinality is enums.Cardinality.NO_RESULT: + continue + + intro_sql = query_unit.introspection_sql + if intro_sql is None: + intro_sql = query_unit.sql[0] + param_desc, result_desc = await pgcon.sql_describe( + intro_sql, all_type_oids) + result_types = [] + for col, toid in result_desc: + edb_type_expr = self._db.backend_id_to_name.get(toid) + if edb_type_expr is None: + raise errors.UnsupportedFeatureError( + f"unsupported SQL type in column \"{col}\" " + f"with type OID {toid}" + ) + + result_types.append( + f"{edgeql.quote_ident(col)} := <{edb_type_expr}>{{}}" + ) + if first_extra is not None: + param_desc = param_desc[:first_extra] + params = [] + for pi, toid in enumerate(param_desc): + edb_type_expr = self._db.backend_id_to_name.get(toid) + if edb_type_expr is None: + raise errors.UnsupportedFeatureError( + f"unsupported type in SQL parameter ${pi} " + f"with type OID {toid}" + ) + + params.append( + f"_p{pi} := <{edb_type_expr}>${pi}" + ) + + intro_qry = "" + if params: + intro_qry += "with _p := {" + ", ".join(params) + "} " + + if result_types: + intro_qry += "select {" + ", ".join(result_types) + "}" + else: + # No direct syntactic way of constructing an empty shape, + # so we have to do it this way. 
+ intro_qry += "select {foo := 1}{}" + to_describe.append(intro_qry) + + desc_map[len(to_describe) - 1] = i + + if to_describe: + desc_req = rpc.CompilationRequest( + source=edgeql.Source.from_string(";\n".join(to_describe)), + protocol_version=query_req.protocol_version, + schema_version=query_req.schema_version, + compilation_config_serializer=query_req.serializer, + ) + + desc_qug = await self._compile(desc_req) + + for i, desc_qu in enumerate(desc_qug): + qu_i = desc_map[i] + qug[qu_i].out_type_data = desc_qu.out_type_data + qug[qu_i].out_type_id = desc_qu.out_type_id + qug[qu_i].in_type_data = desc_qu.in_type_data + qug[qu_i].in_type_id = desc_qu.in_type_id + qug[qu_i].in_type_args = desc_qu.in_type_args + qug[qu_i].in_type_args_real_count = ( + desc_qu.in_type_args_real_count) + + qug.out_type_data = desc_qug.out_type_data + qug.out_type_id = desc_qug.out_type_id + qug.in_type_data = desc_qug.in_type_data + qug.in_type_id = desc_qug.in_type_id + qug.in_type_args = desc_qug.in_type_args + qug.in_type_args_real_count = desc_qug.in_type_args_real_count + cdef inline _check_in_tx_error(self, query_unit_group): if self.in_tx_error(): # The current transaction is aborted, so we must fail @@ -1404,6 +1539,8 @@ cdef class DatabaseConnectionView: first_extra=source.first_extra(), extra_counts=source.extra_counts(), extra_blobs=source.extra_blobs(), + extra_formatted_as_text=source.extra_formatted_as_text(), + extra_type_oids=source.extra_type_oids(), use_pending_func_cache=use_pending_func_cache, ) diff --git a/edb/server/defines.py b/edb/server/defines.py index bbfa973bc74..c612865cf1c 100644 --- a/edb/server/defines.py +++ b/edb/server/defines.py @@ -20,6 +20,8 @@ from __future__ import annotations from typing import TypeAlias +import enum + from edb import buildmeta from edb.common import enum as s_enum @@ -81,7 +83,7 @@ ProtocolVersion: TypeAlias = tuple[int, int] MIN_PROTOCOL: ProtocolVersion = (1, 0) -CURRENT_PROTOCOL: ProtocolVersion = (2, 0) +CURRENT_PROTOCOL: 
ProtocolVersion = (3, 0) MIN_SUGGESTED_CLIENT_POOL_SIZE = 10 MAX_SUGGESTED_CLIENT_POOL_SIZE = 100 @@ -100,3 +102,10 @@ class TxIsolationLevel(s_enum.StrEnum): RepeatableRead = 'REPEATABLE READ' Serializable = 'SERIALIZABLE' + + +# Mapping to the backend `edb_stat_statements.stmt_type` values, +# as well as `sys::QueryType` in edb/lib/sys.edgeql +class QueryType(enum.IntEnum): + EdgeQL = 1 + SQL = 2 diff --git a/edb/server/http.py b/edb/server/http.py index d5b962dd3ec..ae37b148af2 100644 --- a/edb/server/http.py +++ b/edb/server/http.py @@ -26,18 +26,19 @@ Union, Self, Callable, + List, ) import asyncio import dataclasses import logging -import os import json as json_lib import urllib.parse import time from http import HTTPStatus as HTTPStatus from edb.server._http import Http +from . import rust_async_channel logger = logging.getLogger("edb.server") HeaderType = Optional[Union[list[tuple[str, str]], dict[str, str]]] @@ -80,29 +81,26 @@ def __init__( self._stat_callback = stat_callback def __del__(self) -> None: - self.close() - - def __enter__(self) -> HttpClient: - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.close() + if not self.closed(): + logger.error(f"HttpClient {id(self)} was not closed") def close(self) -> None: - if self._task is not None: - self._task.cancel() - self._task = None + if not self.closed(): + if self._task is not None: + self._task.cancel() + self._task = None self._loop = None self._client = None + def closed(self) -> bool: + return self._task is None and self._loop is None + def _ensure_task(self): - if self._loop is None: + if self.closed(): raise Exception("HttpClient was closed") if self._task is None: self._client = Http(self._limit) - self._task = self._loop.create_task( - self._boot(self._loop, self._client._fd) - ) + self._task = self._loop.create_task(self._boot(self._client)) def _ensure_client(self): if self._client is None: @@ -129,7 +127,6 @@ def _process_headers(self, headers: HeaderType) 
-> list[tuple[str, str]]: return [(k, v) for k, v in headers.items()] if isinstance(headers, list): return headers - print(headers) raise ValueError(f"Invalid headers type: {type(headers)}") def _process_content( @@ -192,7 +189,7 @@ async def request( id = self._next_id self._next_id += 1 self._requests[id] = asyncio.Future() - start_time = time.time() + start_time = time.monotonic() try: self._ensure_client()._request(id, path, method, data, headers_list) resp = await self._requests[id] @@ -200,7 +197,9 @@ async def request( status_code, body, headers = resp self._stat_callback( HttpStat( - response_time_ms=int((time.time() - start_time) * 1000), + response_time_ms=int( + (time.monotonic() - start_time) * 1000 + ), error_code=status_code, response_body_size=len(body), response_content_type=dict(headers_list).get( @@ -255,7 +254,7 @@ async def stream_sse( id = self._next_id self._next_id += 1 self._requests[id] = asyncio.Future() - start_time = time.time() + start_time = time.monotonic() try: self._ensure_client()._request_sse( id, path, method, data, headers_list @@ -269,7 +268,9 @@ async def stream_sse( status_code, body, headers = resp self._stat_callback( HttpStat( - response_time_ms=int((time.time() - start_time) * 1000), + response_time_ms=int( + (time.monotonic() - start_time) * 1000 + ), error_code=status_code, response_body_size=len(body), response_content_type=dict(headers_list).get( @@ -295,28 +296,21 @@ async def stream_sse( finally: del self._requests[id] - async def _boot(self, loop: asyncio.AbstractEventLoop, fd: int) -> None: + async def _boot(self, client) -> None: logger.info(f"HTTP client initialized, user_agent={self._user_agent}") - reader = asyncio.StreamReader(loop=loop) - reader_protocol = asyncio.StreamReaderProtocol(reader) - transport, _ = await loop.connect_read_pipe( - lambda: reader_protocol, os.fdopen(fd, 'rb') - ) try: - while len(await reader.read(1)) == 1: - if not self._client or not self._task: - break - if self._skip_reads > 0: 
- self._skip_reads -= 1 - continue - msg = self._client._read() - if not msg: - break - self._process_message(msg) - finally: - transport.close() + channel = rust_async_channel.RustAsyncChannel( + client, self._process_message + ) + try: + await channel.run() + finally: + channel.close() + except Exception: + logger.error(f"Error in HTTP client", exc_info=True) + raise - def _process_message(self, msg): + def _process_message(self, msg: Tuple[Any, ...]) -> None: try: msg_type, id, data = msg if msg_type == 0: # Error @@ -347,7 +341,7 @@ async def __aenter__(self) -> Self: return self async def __aexit__(self, *args) -> None: # type: ignore - pass + self.close() class HttpClientContext(HttpClient): @@ -358,12 +352,27 @@ def __init__( headers: HeaderType = None, base_url: str | None = None, ): - self._task = None self.url_munger = url_munger self.http_client = http_client self.base_url = base_url self.headers = super()._process_headers(headers) + # HttpClientContext does not need to be closed + def __del__(self): + pass + + def closed(self): + return super().closed() + + def close(self): + pass + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args) -> None: # type: ignore + pass + def _process_headers(self, headers): headers = super()._process_headers(headers) headers += self.headers @@ -395,12 +404,6 @@ async def stream_sse( path, method=method, headers=headers, data=data, json=json ) - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, *args) -> None: # type: ignore - pass - class CaseInsensitiveDict(dict): def __init__(self, data: Optional[list[Tuple[str, str]]] = None): @@ -468,6 +471,7 @@ class ResponseSSE: _stream: asyncio.Queue = dataclasses.field(repr=False) _cancel: Callable[[], None] = dataclasses.field(repr=False) _ack: Callable[[], None] = dataclasses.field(repr=False) + _closed: List[bool] = dataclasses.field(default_factory=lambda: [False]) is_streaming: bool = True @classmethod @@ -492,18 
+496,27 @@ def json(self): return json_lib.loads(self.data) def close(self): - self._cancel() + if not self.closed(): + self._closed[0] = True + self._cancel() + + def closed(self) -> bool: + return self._closed[0] def __del__(self): - self.close() + if not self.closed(): + logger.error(f"ResponseSSE {id(self)} was not closed") def __aiter__(self): return self async def __anext__(self): + if self.closed(): + raise StopAsyncIteration next = await self._stream.get() try: if next is None: + self.close() raise StopAsyncIteration id, data, event = next return self.SSEEvent(event, data, id) diff --git a/edb/server/http/src/python.rs b/edb/server/http/src/python.rs index 695454b9cf3..83e113afd76 100644 --- a/edb/server/http/src/python.rs +++ b/edb/server/http/src/python.rs @@ -1,7 +1,7 @@ use eventsource_stream::Eventsource; use futures::{future::poll_fn, TryStreamExt}; use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray}; -use reqwest::{header::HeaderValue, Method}; +use reqwest::Method; use scopeguard::{defer, guard, ScopeGuard}; use std::{ cell::RefCell, @@ -35,24 +35,22 @@ enum RustToPythonMessage { SSEEnd(PythonConnId), Error(PythonConnId, String), } - -impl ToPyObject for RustToPythonMessage { - fn to_object(&self, py: Python<'_>) -> PyObject { +impl RustToPythonMessage { + fn to_object(&self, py: Python<'_>) -> PyResult { use RustToPythonMessage::*; + trace!("Read: {self:?}"); match self { - Error(conn, error) => (0, *conn, error).to_object(py), - Response(conn, (status, body, headers)) => ( - 1, - conn, - (*status, PyByteArray::new_bound(py, &body), headers), - ) - .to_object(py), - SSEStart(conn, (status, headers)) => (2, conn, (status, headers)).to_object(py), + Error(conn, error) => (0, conn, error).into_pyobject(py), + Response(conn, (status, body, headers)) => { + (1, conn, (status, PyByteArray::new(py, body), headers)).into_pyobject(py) + } + SSEStart(conn, (status, headers)) => (2, conn, (status, headers)).into_pyobject(py), SSEEvent(conn, 
message) => { - (3, conn, (&message.id, &message.data, &message.event)).to_object(py) + (3, conn, (&message.id, &message.data, &message.event)).into_pyobject(py) } - SSEEnd(conn) => (4, conn, ()).to_object(py), + SSEEnd(conn) => (4, conn, ()).into_pyobject(py), } + .map(|e| e.into()) } } @@ -86,6 +84,7 @@ impl std::fmt::Debug for RpcPipe { impl RpcPipe { async fn write(&self, msg: RustToPythonMessage) -> Result<(), String> { + trace!("Rust -> Python: {msg:?}"); self.rust_to_python.send(msg).map_err(|_| "Shutdown")?; // If we're shutting down, this may fail (but that's OK) poll_fn(|cx| { @@ -102,7 +101,7 @@ impl RpcPipe { #[pyclass] struct Http { python_to_rust: tokio::sync::mpsc::UnboundedSender, - rust_to_python: std::sync::mpsc::Receiver, + rust_to_python: Mutex>, notify_fd: u64, } @@ -200,6 +199,7 @@ async fn request_sse( )) .await; + trace!("Exiting SSE due to non-SSE response"); ScopeGuard::into_inner(guard); return Ok(()); } @@ -222,10 +222,15 @@ async fn request_sse( return Err(format!("Failed to read response body: {e:?}")); } }; + + // Note that we use semaphores here in a strange way, but basically we + // want to have per-stream backpressure to avoid buffering messages + // indefinitely. 
let Ok(permit) = backpressure.acquire().await else { break; }; permit.forget(); + if rpc_pipe .write(RustToPythonMessage::SSEEvent(id, chunk)) .await @@ -235,6 +240,7 @@ async fn request_sse( } } + trace!("Exiting SSE"); ScopeGuard::into_inner(guard); Ok(()) } @@ -381,14 +387,27 @@ struct HttpTask { async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { let rpc_pipe = Rc::new(rpc_pipe); + const CONNECT_TIMEOUT: Duration = Duration::from_secs(30); + const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30); + const STANDARD_READ_TIMEOUT: Duration = Duration::from_secs(10); + const STANDARD_TOTAL_TIMEOUT: Duration = Duration::from_secs(120); + const SSE_READ_TIMEOUT: Duration = Duration::from_secs(60 * 60); // 1 hour + // Set some reasonable defaults for timeouts let client = reqwest::Client::builder() - .connect_timeout(Duration::from_secs(30)) - .timeout(Duration::from_secs(120)) - .read_timeout(Duration::from_secs(10)) - .pool_idle_timeout(Duration::from_secs(30)); + .connect_timeout(CONNECT_TIMEOUT) + .timeout(STANDARD_TOTAL_TIMEOUT) + .read_timeout(STANDARD_READ_TIMEOUT) + .pool_idle_timeout(POOL_IDLE_TIMEOUT); let client = client.build().unwrap(); + // SSE requests should have a very long read timeout and no general timeout + let client_sse = reqwest::Client::builder() + .connect_timeout(CONNECT_TIMEOUT) + .read_timeout(SSE_READ_TIMEOUT) + .pool_idle_timeout(POOL_IDLE_TIMEOUT); + let client_sse = client_sse.build().unwrap(); + let permit_manager = Rc::new(PermitManager::new(capacity)); let tasks = Arc::new(Mutex::new(HashMap::::new())); @@ -399,6 +418,7 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { break; }; let client = client.clone(); + let client_sse = client_sse.clone(); trace!("Received RPC: {rpc:?}"); let rpc_pipe = rpc_pipe.clone(); // Allocate a task ID and backpressure object if we're initiating a @@ -417,6 +437,7 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { rpc, permit_manager.clone(), client, + client_sse, 
rpc_pipe, )); if let (Some(id), Some(backpressure)) = (id, backpressure) { @@ -435,6 +456,7 @@ async fn execute( rpc: PythonToRustMessage, permit_manager: Rc, client: reqwest::Client, + client_sse: reqwest::Client, rpc_pipe: Rc, ) { // If a request task was booted by this request, remove it from the list of @@ -468,13 +490,18 @@ async fn execute( drop(permit); } RequestSse(id, url, method, body, headers) => { - // Ensure we send the end message whenever this block exits - defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id))); + // Ensure we send the end message whenever this block exits (though + // we need to spawn a task to do so) + defer!({ + let rpc_pipe = rpc_pipe.clone(); + let future = async move { rpc_pipe.write(RustToPythonMessage::SSEEnd(id)).await }; + tokio::task::spawn_local(future); + }); let Ok(permit) = permit_manager.acquire().await else { return; }; match request_sse( - client, + client_sse, id, backpressure.unwrap(), url, @@ -509,6 +536,15 @@ async fn execute( } } +impl Http { + fn send(&self, msg: PythonToRustMessage) -> PyResult<()> { + trace!("Python -> Rust: {msg:?}"); + self.python_to_rust + .send(msg) + .map_err(|_| internal_error("In shutdown")) + } +} + #[pymethods] impl Http { /// Create the HTTP pool and automatically boot a tokio runtime on a @@ -549,7 +585,7 @@ impl Http { let notify_fd = rxfd.recv().unwrap(); Http { python_to_rust: txpr, - rust_to_python: rxrp, + rust_to_python: Mutex::new(rxrp), notify_fd, } } @@ -567,9 +603,7 @@ impl Http { body: Vec, headers: Vec<(String, String)>, ) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Request(id, url, method, body, headers)) - .map_err(|_| internal_error("In shutdown")) + self.send(PythonToRustMessage::Request(id, url, method, body, headers)) } fn _request_sse( @@ -580,50 +614,59 @@ impl Http { body: Vec, headers: Vec<(String, String)>, ) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::RequestSse( - id, url, method, body, headers, - )) - 
.map_err(|_| internal_error("In shutdown")) + self.send(PythonToRustMessage::RequestSse( + id, url, method, body, headers, + )) } fn _close(&self, id: PythonConnId) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Close(id)) - .map_err(|_| internal_error("In shutdown")) + self.send(PythonToRustMessage::Close(id)) } fn _ack_sse(&self, id: PythonConnId) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Ack(id)) - .map_err(|_| internal_error("In shutdown")) + self.send(PythonToRustMessage::Ack(id)) } fn _update_limit(&self, limit: usize) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::UpdateLimit(limit)) - .map_err(|_| internal_error("In shutdown")) + self.send(PythonToRustMessage::UpdateLimit(limit)) } - fn _read(&self, py: Python<'_>) -> Py { - let Ok(msg) = self.rust_to_python.recv() else { - return py.None(); + fn _read(&self, py: Python<'_>) -> PyResult { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .recv() + else { + return Ok(py.None()); }; msg.to_object(py) } - fn _try_read(&self, py: Python<'_>) -> Py { - let Ok(msg) = self.rust_to_python.try_recv() else { - return py.None(); + fn _try_read(&self, py: Python<'_>) -> PyResult { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .try_recv() + else { + return Ok(py.None()); }; msg.to_object(py) } + + fn _close_pipe(&mut self) { + trace!("Closing pipe"); + // Replace the channel with a dummy, closed one which will also + // signal the other side to exit. 
+ self.rust_to_python = Mutex::new(std::sync::mpsc::channel().1); + } } #[pymodule] fn _http(py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; - m.add("InternalError", py.get_type_bound::())?; + m.add("InternalError", py.get_type::())?; Ok(()) } diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index 784eea38724..d49673e9aae 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -51,6 +51,9 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: ) http_client._update_limit(http_max_connections) try: + # TODO: I think this TaskGroup approach might not be the right + # approach here. It is fragile to failures and means that slow + # queries can cause things to wait on them. async with (asyncio.TaskGroup() as g,): for db in list(tenant.iter_dbs()): if db.name == defines.EDGEDB_SYSTEM_DB: @@ -60,32 +63,47 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: # Don't run the net_worker if the database is not # connectable, e.g. 
being dropped continue - json_bytes = await execute.parse_execute_json( - db, - """ - with - PENDING_REQUESTS := ( - select std::net::http::ScheduledRequest - filter .state = std::net::RequestState.Pending - ), - UPDATED := ( - update PENDING_REQUESTS - set { - state := std::net::RequestState.InProgress, - updated_at := datetime_of_statement(), - } - ), - select UPDATED { - id, - method, - url, - body, - headers, - } - """, - cached_globally=True, - tx_isolation=defines.TxIsolationLevel.RepeatableRead, - ) + try: + json_bytes = await execute.parse_execute_json( + db, + """ + with + PENDING_REQUESTS := ( + select std::net::http::ScheduledRequest + filter .state = std::net::RequestState.Pending + ), + UPDATED := ( + update PENDING_REQUESTS + set { + state := std::net::RequestState.InProgress, + updated_at := datetime_of_statement(), + } + ), + select UPDATED { + id, + method, + url, + body, + headers, + } + """, + cached_globally=True, + tx_isolation=defines.TxIsolationLevel.RepeatableRead, + ) + except Exception as ex: + # If the query fails (because the database branch + # has been racily deleted, maybe), ignore an keep + # going. We don't want the failure to bubble up + # and cause tasks in the task group to die. + logger.debug( + "HTTP net_worker query failed " + "(instance: %s, branch: %s)", + tenant.get_instance_name(), + db, + exc_info=ex, + ) + continue + pending_requests: list[dict] = json.loads(json_bytes) for pending_request in pending_requests: request = ScheduledRequest(**pending_request) diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py index 03272dfe382..393b126bc4a 100644 --- a/edb/server/pgcluster.py +++ b/edb/server/pgcluster.py @@ -79,6 +79,10 @@ EDGEDB_SERVER_SETTINGS = { 'client_encoding': 'utf-8', + # DO NOT raise client_min_messages above NOTICE level + # because server indirect block return machinery relies + # on NoticeResponse as the data channel. 
+ 'client_min_messages': 'NOTICE', 'search_path': 'edgedb', 'timezone': 'UTC', 'intervalstyle': 'iso_8601', @@ -560,6 +564,12 @@ async def start( # `max_connections` scenarios. 'max_locks_per_transaction': 1024, 'max_pred_locks_per_transaction': 1024, + "shared_preload_libraries": ",".join( + [ + "edb_stat_statements", + ] + ), + "edb_stat_statements.track_planning": "true", } if os.getenv('EDGEDB_DEBUG_PGSERVER'): @@ -568,7 +578,7 @@ async def start( else: log_level_map = { 'd': 'INFO', - 'i': 'NOTICE', + 'i': 'WARNING', # NOTICE in Postgres is quite noisy 'w': 'WARNING', 'e': 'ERROR', 's': 'PANIC', @@ -1164,6 +1174,13 @@ async def _detect_capabilities( if roles['rolcreatedb']: caps |= pgparams.BackendCapabilities.CREATE_DATABASE + stats_ver = await conn.sql_fetch_val(b""" + SELECT default_version FROM pg_available_extensions + WHERE name = 'edb_stat_statements'; + """) + if stats_ver in (b"1.0",): + caps |= pgparams.BackendCapabilities.STAT_STATEMENTS + return caps async def _get_pg_settings( diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 1f2dc9dfd10..8f95af2e3be 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -139,6 +139,8 @@ cdef class PGConnection: object last_state + str last_indirect_return + cdef before_command(self) cdef write(self, buf) @@ -154,6 +156,7 @@ cdef class PGConnection: cdef bint before_prepare( self, bytes stmt_name, int dbver, WriteBuffer outbuf) cdef write_sync(self, WriteBuffer outbuf) + cdef send_sync(self) cdef make_clean_stmt_message(self, bytes stmt_name) cdef send_query_unit_group( @@ -186,4 +189,4 @@ cdef class PGConnection: cdef inline str get_tenant_label(self) cpdef set_stmt_cache_size(self, int maxsize) -cdef str setting_to_sql(self) +cdef setting_to_sql(setting) diff --git a/edb/server/pgcon/pgcon.pyi b/edb/server/pgcon/pgcon.pyi index ee35a9fad88..7c6821bf028 100644 --- a/edb/server/pgcon/pgcon.pyi +++ b/edb/server/pgcon/pgcon.pyi @@ -53,7 +53,7 @@ class 
PGConnection(asyncio.Protocol): async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None: ... async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, @@ -75,6 +75,11 @@ class PGConnection(asyncio.Protocol): use_prep_stmt: bool = False, state: Optional[bytes] = None, ) -> list[bytes]: ... + async def sql_describe( + self, + sql: bytes, + param_type_oids: list[int] | None = None, + ) -> tuple[list[int], list[tuple[str, int]]]: ... def terminate(self) -> None: ... def add_log_listener(self, cb: Callable[[str, str], None]) -> None: ... def get_server_parameter_status(self, parameter: str) -> Optional[str]: ... diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 2654fa3dc60..7f1426d8300 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -35,7 +35,6 @@ import sys import struct import textwrap import time -from collections import deque cimport cython cimport cpython @@ -101,6 +100,7 @@ cdef dict POSTGRES_SHUTDOWN_ERR_CODES = { } cdef object EMPTY_SQL_STATE = b"{}" +cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args() cdef object logger = logging.getLogger('edb.server') @@ -240,6 +240,8 @@ cdef class PGConnection: self.last_parse_prep_stmts = [] self.debug = debug.flags.server_proto + self.last_indirect_return = None + self.log_listeners = [] self.server = None @@ -451,6 +453,10 @@ cdef class PGConnection: outbuf.write_bytes(_SYNC_MESSAGE) self.waiting_for_sync += 1 + cdef send_sync(self): + self.write(_SYNC_MESSAGE) + self.waiting_for_sync += 1 + def _build_apply_state_req(self, bytes serstate, WriteBuffer out): cdef: WriteBuffer buf @@ -591,6 +597,8 @@ cdef class PGConnection: WriteBuffer bind_data bytes stmt_name ssize_t idx = start + bytes sql + tuple sqls out = WriteBuffer.new() parsed = set() @@ -608,7 +616,6 @@ cdef class PGConnection: ) stmt_name = query_unit.sql_hash if stmt_name: - assert len(query_unit.sql) 
== 1 # The same EdgeQL query may show up twice in the same script. # We just need to know and skip if we've already parsed the # same query within current send batch, because self.prep_stmts @@ -625,15 +632,16 @@ cdef class PGConnection: for query_unit, bind_data in zip( query_unit_group.units[start:end], bind_datas): stmt_name = query_unit.sql_hash + sql = query_unit.sql if stmt_name: if parse_array[idx]: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) - buf.write_bytestring(query_unit.sql[0]) + buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( - len(query_unit.sql[0]), + len(sql), self.get_tenant_label(), 'compiled', ) @@ -650,26 +658,25 @@ cdef class PGConnection: out.write_buffer(buf.end_message()) else: - for sql in query_unit.sql: - buf = WriteBuffer.new_message(b'P') - buf.write_bytestring(b'') # statement name - buf.write_bytestring(sql) - buf.write_int16(0) - out.write_buffer(buf.end_message()) - metrics.query_size.observe( - len(sql), self.get_tenant_label(), 'compiled' - ) + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') # statement name + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + metrics.query_size.observe( + len(sql), self.get_tenant_label(), 'compiled' + ) - buf = WriteBuffer.new_message(b'B') - buf.write_bytestring(b'') # portal name - buf.write_bytestring(b'') # statement name - buf.write_buffer(bind_data) - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_buffer(bind_data) + out.write_buffer(buf.end_message()) - buf = WriteBuffer.new_message(b'E') - buf.write_bytestring(b'') # portal name - buf.write_int32(0) # limit: 0 - return all rows - out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - 
return all rows + out.write_buffer(buf.end_message()) idx += 1 @@ -687,7 +694,7 @@ cdef class PGConnection: out = WriteBuffer.new() buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(b'???') + buf.write_bytestring(b'') buf.write_int16(0) # Then do a sync to get everything executed and lined back up @@ -797,7 +804,6 @@ cdef class PGConnection: elif mtype == b'I': ## result # EmptyQueryResponse self.buffer.discard_message() - return result else: self.fallthrough() @@ -805,6 +811,152 @@ cdef class PGConnection: finally: self.buffer.finish_message() + async def _describe( + self, + query: bytes, + param_type_oids: Optional[list[int]], + ): + cdef: + WriteBuffer out + + out = WriteBuffer.new() + + buf = WriteBuffer.new_message(b"P") # Parse + buf.write_bytestring(b"") + buf.write_bytestring(query) + if param_type_oids: + buf.write_int16(len(param_type_oids)) + for oid in param_type_oids: + buf.write_int32(oid) + else: + buf.write_int16(0) + out.write_buffer(buf.end_message()) + + buf = WriteBuffer.new_message(b"D") # Describe + buf.write_byte(b"S") + buf.write_bytestring(b"") + out.write_buffer(buf.end_message()) + + out.write_bytes(FLUSH_MESSAGE) + + self.write(out) + + param_desc = None + result_desc = None + + try: + buf = None + while True: + if not self.buffer.take_message(): + await self.wait_for_message() + mtype = self.buffer.get_message_type() + + try: + if mtype == b'1': + # ParseComplete + self.buffer.discard_message() + + elif mtype == b't': + # ParameterDescription + param_desc = self._decode_param_desc(self.buffer) + + elif mtype == b'T': + # RowDescription + result_desc = self._decode_row_desc(self.buffer) + break + + elif mtype == b'n': + # NoData + self.buffer.discard_message() + param_desc = [] + result_desc = [] + break + + elif mtype == b'E': ## result + # ErrorResponse + er_cls, er_fields = self.parse_error_message() + raise er_cls(fields=er_fields) + + else: + self.fallthrough() + + finally: + 
self.buffer.finish_message() + except Exception: + self.send_sync() + await self.wait_for_sync() + raise + + if param_desc is None: + raise RuntimeError( + "did not receive ParameterDescription from backend " + "in response to Describe" + ) + + if result_desc is None: + raise RuntimeError( + "did not receive RowDescription from backend " + "in response to Describe" + ) + + return param_desc, result_desc + + def _decode_param_desc(self, buf: ReadBuffer): + cdef: + int16_t nparams + uint32_t p_oid + list result = [] + + nparams = buf.read_int16() + + for _ in range(nparams): + p_oid = buf.read_int32() + result.append(p_oid) + + return result + + def _decode_row_desc(self, buf: ReadBuffer): + cdef: + int16_t nfields + + bytes f_name + uint32_t f_table_oid + int16_t f_column_num + uint32_t f_dt_oid + int16_t f_dt_size + int32_t f_dt_mod + int16_t f_format + + list result + + nfields = buf.read_int16() + + result = [] + for _ in range(nfields): + f_name = buf.read_null_str() + f_table_oid = buf.read_int32() + f_column_num = buf.read_int16() + f_dt_oid = buf.read_int32() + f_dt_size = buf.read_int16() + f_dt_mod = buf.read_int32() + f_format = buf.read_int16() + + result.append((f_name.decode("utf-8"), f_dt_oid)) + + return result + + async def sql_describe( + self, + query: bytes, + param_type_oids: Optional[list[int]] = None, + ) -> tuple[list[int], list[tuple[str, int]]]: + self.before_command() + started_at = time.monotonic() + try: + return await self._describe(query, param_type_oids) + finally: + await self.after_command() + async def _parse_execute( self, query, @@ -815,11 +967,16 @@ cdef class PGConnection: int dbver, bint use_pending_func_cache, tx_isolation, + list param_data_types, ): cdef: WriteBuffer out WriteBuffer buf bytes stmt_name + bytes sql + tuple sqls + bytes prologue_sql + bytes epilogue_sql int32_t dat_len @@ -834,14 +991,6 @@ cdef class PGConnection: uint64_t msgs_executed = 0 uint64_t i - if use_pending_func_cache and query.cache_func_call: - 
sql, stmt_name = query.cache_func_call - sqls = (sql,) - else: - sqls = query.sql - stmt_name = query.sql_hash - msgs_num = (len(sqls)) - out = WriteBuffer.new() if state is not None: @@ -849,7 +998,7 @@ cdef class PGConnection: if ( query.tx_id or not query.is_transactional - or query.append_rollback + or query.run_and_rollback or tx_isolation is not None ): # This query has START TRANSACTION or non-transactional command @@ -861,22 +1010,22 @@ cdef class PGConnection: state_sync = 1 self.write_sync(out) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: if self.in_tx(): sp_name = f'_edb_{time.monotonic_ns()}' - sql = f'SAVEPOINT {sp_name}'.encode('utf-8') + prologue_sql = f'SAVEPOINT {sp_name}'.encode('utf-8') else: sp_name = None - sql = b'START TRANSACTION' + prologue_sql = b'START TRANSACTION' if tx_isolation is not None: - sql += ( + prologue_sql += ( f' ISOLATION LEVEL {tx_isolation._value_}' .encode('utf-8') ) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') - buf.write_bytestring(sql) + buf.write_bytestring(prologue_sql) buf.write_int16(0) out.write_buffer(buf.end_message()) @@ -896,9 +1045,17 @@ cdef class PGConnection: # Insert a SYNC as a boundary of the parsing logic later self.write_sync(out) + if use_pending_func_cache and query.cache_func_call: + sql, stmt_name = query.cache_func_call + sqls = (sql,) + else: + sqls = (query.sql,) + query.db_op_trailer + stmt_name = query.sql_hash + + msgs_num = (len(sqls)) + if use_prep_stmt: - parse = self.before_prepare( - stmt_name, dbver, out) + parse = self.before_prepare(stmt_name, dbver, out) else: stmt_name = b'' @@ -932,7 +1089,12 @@ cdef class PGConnection: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) buf.write_bytestring(sqls[0]) - buf.write_int16(0) + if param_data_types: + buf.write_int16(len(param_data_types)) + for oid in param_data_types: + buf.write_int32(oid) + else: + buf.write_int16(0) 
out.write_buffer(buf.end_message()) metrics.query_size.observe( len(sqls[0]), self.get_tenant_label(), 'compiled' @@ -963,8 +1125,8 @@ cdef class PGConnection: buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) - if query.append_rollback or tx_isolation is not None: - if query.append_rollback: + if query.run_and_rollback or tx_isolation is not None: + if query.run_and_rollback: if sp_name: sql = f'ROLLBACK TO SAVEPOINT {sp_name}'.encode('utf-8') else: @@ -986,6 +1148,35 @@ cdef class PGConnection: buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') + buf.write_bytestring(b'') # portal name + buf.write_int32(0) # limit: 0 - return all rows + out.write_buffer(buf.end_message()) + elif query.append_tx_op: + if query.tx_commit: + sql = b'COMMIT' + elif query.tx_rollback: + sql = b'ROLLBACK' + else: + raise errors.InternalServerError( + "QueryUnit.append_tx_op is set but none of the " + "Query.tx_ properties are" + ) + + buf = WriteBuffer.new_message(b'P') + buf.write_bytestring(b'') + buf.write_bytestring(sql) + buf.write_int16(0) + out.write_buffer(buf.end_message()) + + buf = WriteBuffer.new_message(b'B') + buf.write_bytestring(b'') # portal name + buf.write_bytestring(b'') # statement name + buf.write_int16(0) # number of format codes + buf.write_int16(0) # number of parameters + buf.write_int16(0) # number of result columns + out.write_buffer(buf.end_message()) + buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows @@ -1000,7 +1191,7 @@ cdef class PGConnection: if state is not None: await self.wait_for_state_resp(state, state_sync) - if query.append_rollback or tx_isolation is not None: + if query.run_and_rollback or tx_isolation is not None: await self.wait_for_sync() buf = None @@ -1099,7 +1290,8 @@ cdef class PGConnection: self, *, query, - WriteBuffer bind_data, + WriteBuffer bind_data = 
NO_ARGS, + list param_data_types = None, frontend.AbstractFrontendConnection fe_conn = None, bint use_prep_stmt = False, bytes state = None, @@ -1119,6 +1311,7 @@ cdef class PGConnection: dbver, use_pending_func_cache, tx_isolation, + param_data_types, ) finally: metrics.backend_query_duration.observe( @@ -1128,29 +1321,21 @@ cdef class PGConnection: async def sql_fetch( self, - sql: bytes | tuple[bytes, ...], + sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, ) -> list[tuple[bytes, ...]]: - cdef tuple sql_tuple - - if not isinstance(sql, tuple): - sql_tuple = (sql,) - else: - sql_tuple = sql - if use_prep_stmt: sql_digest = hashlib.sha1() - for stmt in sql_tuple: - sql_digest.update(stmt) + sql_digest.update(sql) sql_hash = sql_digest.hexdigest().encode('latin1') else: sql_hash = None query = compiler.QueryUnit( - sql=sql_tuple, + sql=sql, sql_hash=sql_hash, status=b"", ) @@ -1306,7 +1491,7 @@ cdef class PGConnection: async def sql_extended_query( self, - actions: list[PGMessage], + actions, fe_conn: frontend.AbstractFrontendConnection, dbver: int, dbv: pg_ext.ConnectionView, @@ -1335,7 +1520,7 @@ cdef class PGConnection: def _write_sql_extended_query( self, - actions: list[PGMessage], + actions, dbver: int, dbv: pg_ext.ConnectionView, ) -> bytes: @@ -1545,7 +1730,7 @@ cdef class PGConnection: async def _parse_sql_extended_query( self, - actions: list[PGMessage], + actions, fe_conn: frontend.AbstractFrontendConnection, dbver: int, dbv: pg_ext.ConnectionView, @@ -1989,13 +2174,19 @@ cdef class PGConnection: while True: field_type = self.buffer.read_byte() if field_type == b'P': # Position - qu = (action.query_unit.translation_data - if action.query_unit else None) + if action.query_unit is None: + translation_data = None + offset = 0 + else: + qu = action.query_unit + translation_data = qu.translation_data + offset = -qu.prefix_len self._write_error_position( msg_buf, action.args[0], 
self.buffer.read_null_str(), - qu + translation_data, + offset, ) continue else: @@ -2041,6 +2232,7 @@ cdef class PGConnection: else: offset = 0 translation_data = qu.translation_data + offset -= qu.prefix_len else: query_text = b"" translation_data = None @@ -2068,42 +2260,24 @@ cdef class PGConnection: msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) - async def run_ddl( - self, - object query_unit, - bytes state=None - ): - data = await self.sql_fetch(query_unit.sql, state=state) - if query_unit.ddl_stmt_id is None: - return - else: - return self.load_ddl_return(query_unit, data) - - def load_ddl_return(self, object query_unit, data): + def load_last_ddl_return(self, object query_unit): if query_unit.ddl_stmt_id: + data = self.last_indirect_return if data: - ret = json.loads(data[0][0]) + ret = json.loads(data) if ret['ddl_stmt_id'] != query_unit.ddl_stmt_id: raise RuntimeError( - 'unrecognized data packet after a DDL command: ' - 'data_stmt_id do not match' + 'unrecognized data notice after a DDL command: ' + 'data_stmt_id do not match: expected ' + f'{query_unit.ddl_stmt_id!r}, got ' + f'{ret["ddl_stmt_id"]!r}' ) return ret else: raise RuntimeError( - 'missing the required data packet after a DDL command' + 'missing the required data notice after a DDL command' ) - async def handle_ddl_in_script( - self, object query_unit, bint parse, int dbver - ): - data = None - for sql in query_unit.sql: - data = await self.wait_for_command( - query_unit, parse, dbver, ignore_data=bool(data) - ) or data - return self.load_ddl_return(query_unit, data) - async def _dump(self, block, output_queue, fragment_suggested_size): cdef: WriteBuffer buf @@ -2528,6 +2702,7 @@ cdef class PGConnection: 'previous one') self.idle = False + self.last_indirect_return = None async def after_command(self): if self.idle: @@ -2676,14 +2851,18 @@ cdef class PGConnection: elif mtype == b'N': # NoticeResponse - if self.log_listeners: - _, fields = self.parse_error_message() - 
severity = fields.get('V') - message = fields.get('M') + _, fields = self.parse_error_message() + severity = fields.get('V') + message = fields.get('M') + detail = fields.get('D') + if ( + severity == "NOTICE" + and message.startswith("edb:notice:indirect_return") + ): + self.last_indirect_return = detail + elif self.log_listeners: for listener in self.log_listeners: self.loop.call_soon(listener, severity, message) - else: - self.buffer.discard_message() return True return False @@ -2830,7 +3009,7 @@ cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) cdef EdegDBCodecContext DEFAULT_CODEC_CONTEXT = EdegDBCodecContext() -cdef str setting_to_sql(setting: tuple[str | int | float, ...]): +cdef setting_to_sql(setting): return ', '.join(setting_val_to_sql(v) for v in setting) diff --git a/edb/server/pgproto b/edb/server/pgproto index 780c5228fc4..9f415b2c834 160000 --- a/edb/server/pgproto +++ b/edb/server/pgproto @@ -1 +1 @@ -Subproject commit 780c5228fc40632bbacceaf7b98264eded5f2fc3 +Subproject commit 9f415b2c834df119422c011e5163e21064bff6ad diff --git a/edb/server/pgrust/src/protocol/gen.rs b/edb/server/pgrust/src/protocol/gen.rs index dd6cb83875e..85b88b78226 100644 --- a/edb/server/pgrust/src/protocol/gen.rs +++ b/edb/server/pgrust/src/protocol/gen.rs @@ -313,7 +313,7 @@ macro_rules! protocol_builder { let Ok(val) = $crate::protocol::FieldAccess::<$type>::extract(buf.split_at(offset).1) else { return false; }; - if val != $value as usize as _ { return false; } + if val as usize != $value as usize { return false; } )? 
offset += std::mem::size_of::<$type>(); )* diff --git a/edb/server/pgrust/src/python.rs b/edb/server/pgrust/src/python.rs index 076f3d45959..2931a212b23 100644 --- a/edb/server/pgrust/src/python.rs +++ b/edb/server/pgrust/src/python.rs @@ -18,11 +18,11 @@ use pyo3::{ exceptions::{PyException, PyRuntimeError}, prelude::*, pymodule, - types::{PyAnyMethods, PyByteArray, PyBytes, PyMemoryView, PyModule, PyModuleMethods, PyNone}, + types::{PyAnyMethods, PyByteArray, PyBytes, PyMemoryView, PyModule, PyModuleMethods}, Bound, PyAny, PyResult, Python, }; use std::collections::HashMap; -use std::path::Path; +use std::{borrow::Cow, path::Path}; #[derive(Clone, Copy, PartialEq, Eq)] #[pyclass(eq, eq_int)] @@ -116,7 +116,7 @@ impl PyConnectionParams { ResolvedTarget::SocketAddr(addr) => { resolved_hosts.push(( if addr.ip().is_ipv4() { "v4" } else { "v6" }, - addr.ip().to_string().into_py(py), + addr.ip().to_string().into_pyobject(py)?.into(), hostname.clone(), addr.port(), )); @@ -126,7 +126,7 @@ impl PyConnectionParams { if let Some(path) = path.as_pathname() { resolved_hosts.push(( "unix", - path.to_string_lossy().into_py(py), + path.to_string_lossy().into_pyobject(py)?.into(), hostname.clone(), port, )); @@ -141,7 +141,7 @@ impl PyConnectionParams { name.insert(0, 0); resolved_hosts.push(( "unix", - PyBytes::new_bound(py, &name).as_any().clone().unbind(), + PyBytes::new(py, &name).as_any().clone().unbind(), hostname.clone(), port, )); @@ -200,7 +200,7 @@ impl PyConnectionParams { } pub fn resolve(&self, py: Python, username: String, home_dir: String) -> PyResult { - let os = py.import_bound("os")?; + let os = py.import("os")?; let environ = os.getattr("environ")?; let mut params = self.inner.clone(); @@ -215,7 +215,7 @@ impl PyConnectionParams { ¶ms.database, ¶ms.user, )? 
{ - let warnings = py.import_bound("warnings")?; + let warnings = py.import("warnings")?; warnings.call_method1("warn", (warning.to_string(),))?; } @@ -249,8 +249,8 @@ impl PyConnectionParams { repr } - pub fn __getitem__(&self, py: Python, name: &str) -> Py { - self.inner.get_by_name(name).to_object(py) + pub fn __getitem__(&self, name: &str) -> Option> { + self.inner.get_by_name(name) } pub fn __setitem__(&mut self, name: &str, value: &str) -> PyResult<()> { @@ -286,7 +286,7 @@ impl PyConnectionState { username: String, home_dir: String, ) -> PyResult { - let os = py.import_bound("os")?; + let os = py.import("os")?; let environ = os.getattr("environ")?; let mut params = dsn.inner.clone(); @@ -301,7 +301,7 @@ impl PyConnectionState { ¶ms.database, ¶ms.user, )? { - let warnings = py.import_bound("warnings")?; + let warnings = py.import("warnings")?; warnings.call_method1("warn", (warning.to_string(),))?; } @@ -325,15 +325,15 @@ impl PyConnectionState { inner: ConnectionState::new(credentials, ssl_mode), parsed_dsn: Py::new(py, PyConnectionParams { inner: params })?, update: PyConnectionStateUpdate { - py_update: PyNone::get_bound(py).to_object(py), + py_update: py.None(), }, message_buffer: Default::default(), }) } #[setter] - fn update(&mut self, py: Python, update: &Bound) { - self.update.py_update = update.to_object(py); + fn update(&mut self, update: &Bound) { + self.update.py_update = update.clone().unbind(); } fn is_ready(&self) -> bool { @@ -351,7 +351,7 @@ impl PyConnectionState { } fn drive_message(&mut self, py: Python, data: &Bound) -> PyResult<()> { - let buffer = PyBuffer::::get_bound(data)?; + let buffer = PyBuffer::::get(data)?; if self.inner.read_ssl_response() { // SSL responses are always one character let response = [buffer.as_slice(py).unwrap().get(0).unwrap().get()]; @@ -408,7 +408,7 @@ impl ConnectionStateSend for PyConnectionStateUpdate { message: crate::protocol::definition::InitialBuilder, ) -> Result<(), std::io::Error> { 
Python::with_gil(|py| { - let bytes = PyByteArray::new_bound(py, &message.to_vec()); + let bytes = PyByteArray::new(py, &message.to_vec()); if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { eprintln!("Error in send_initial: {:?}", e); e.print(py); @@ -422,7 +422,7 @@ impl ConnectionStateSend for PyConnectionStateUpdate { message: crate::protocol::definition::FrontendBuilder, ) -> Result<(), std::io::Error> { Python::with_gil(|py| { - let bytes = PyBytes::new_bound(py, &message.to_vec()); + let bytes = PyBytes::new(py, &message.to_vec()); if let Err(e) = self.py_update.call_method1(py, "send", (bytes,)) { eprintln!("Error in send: {:?}", e); e.print(py); diff --git a/edb/server/protocol/args_ser.pxd b/edb/server/protocol/args_ser.pxd index 75550076d62..04bb2c3ad47 100644 --- a/edb/server/protocol/args_ser.pxd +++ b/edb/server/protocol/args_ser.pxd @@ -26,6 +26,7 @@ cdef WriteBuffer recode_bind_args( dbview.CompiledQuery compiled, bytes bind_args, list positions = ?, + list data_types = ?, ) diff --git a/edb/server/protocol/args_ser.pyx b/edb/server/protocol/args_ser.pyx index c5efda663f7..d95b2ebac61 100644 --- a/edb/server/protocol/args_ser.pyx +++ b/edb/server/protocol/args_ser.pyx @@ -102,6 +102,7 @@ cdef WriteBuffer recode_bind_args( bytes bind_args, # XXX do something better?!? 
list positions = None, + list data_types = None, ): cdef: FRBuffer in_buf @@ -121,10 +122,6 @@ cdef WriteBuffer recode_bind_args( cpython.PyBytes_AS_STRING(bind_args), cpython.Py_SIZE(bind_args)) - # all parameters are in binary - if live: - out_buf.write_int32(0x00010001) - # number of elements in the tuple # for empty tuple it's okay to send zero-length arguments qug = compiled.query_unit_group @@ -155,11 +152,36 @@ cdef WriteBuffer recode_bind_args( f"argument count mismatch {recv_args} != {compiled.first_extra}" num_args += compiled.extra_counts[0] - num_args += _count_globals(qug) + num_globals = _count_globals(qug) + num_args += num_globals if live: + if not compiled.extra_formatted_as_text: + # all parameter values are in binary + out_buf.write_int32(0x00010001) + elif not recv_args and not num_globals: + # all parameter values are in text (i.e extracted SQL constants) + out_buf.write_int16(0x0000) + else: + # got a mix of binary and text, spell them out explicitly + out_buf.write_int16(num_args) + # explicit args are in binary + for _ in range(recv_args): + out_buf.write_int16(0x0001) + # and extracted SQL constants are in text + for _ in range(compiled.extra_counts[0]): + out_buf.write_int16(0x0000) + # and injected globals are binary again + for _ in range(num_globals): + out_buf.write_int16(0x0001) + out_buf.write_int16(num_args) + if data_types is not None and compiled.extra_type_oids: + data_types.extend([0] * recv_args) + data_types.extend(compiled.extra_type_oids) + data_types.extend([0] * num_globals) + if qug.in_type_args: for param in qug.in_type_args: if positions is not None: diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 59e957cc5a4..d43a492e4a6 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -173,6 +173,7 @@ async def fetch_user_info( name=payload.get("name"), email=payload.get("email"), picture=payload.get("picture"), + source_id_token=id_token, ) 
async def _get_oidc_config(self) -> data.OpenIDConfig: diff --git a/edb/server/protocol/auth_ext/data.py b/edb/server/protocol/auth_ext/data.py index 6464a70673b..035a0c657af 100644 --- a/edb/server/protocol/auth_ext/data.py +++ b/edb/server/protocol/auth_ext/data.py @@ -51,6 +51,7 @@ class UserInfo: phone_number_verified: Optional[bool] = None address: Optional[dict[str, str]] = None updated_at: Optional[float] = None + source_id_token: Optional[str] = None def __str__(self) -> str: return self.sub diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index 826ae5bb111..c5879d45a8d 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -3,9 +3,9 @@ import random from typing import Any, Coroutine -from edb.server import tenant +from edb.server import tenant, smtp -from . import util, ui, smtp +from . import util, ui async def send_password_reset_email( @@ -15,7 +15,6 @@ async def send_password_reset_email( reset_url: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -27,16 +26,13 @@ async def send_password_reset_email( brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_email( - from_addr=from_addr, to_addr=to_addr, reset_url=reset_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -51,7 +47,6 @@ async def send_verification_email( provider: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( { @@ -71,16 +66,13 @@ async def send_verification_email( brand_color=app_details_config.brand_color, ) 
msg = ui.render_verification_email( - from_addr=from_addr, to_addr=to_addr, verify_url=verify_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -93,7 +85,6 @@ async def send_magic_link_email( link: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -105,16 +96,13 @@ async def send_magic_link_email( brand_color=app_details_config.brand_color, ) msg = ui.render_magic_link_email( - from_addr=from_addr, to_addr=to_addr, link=link, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) diff --git a/edb/server/protocol/auth_ext/email_password.py b/edb/server/protocol/auth_ext/email_password.py index 3e6d93231ef..f71b32194ac 100644 --- a/edb/server/protocol/auth_ext/email_password.py +++ b/edb/server/protocol/auth_ext/email_password.py @@ -94,12 +94,9 @@ async def register(self, input: dict[str, Any]) -> data.EmailFactor: assert len(result_json) == 1 return data.EmailFactor(**result_json[0]) - async def authenticate(self, input: dict[str, Any]) -> data.LocalIdentity: - if 'email' not in input or 'password' not in input: - raise errors.InvalidData("Missing 'email' or 'password' in data") - - password = input["password"] - email = input["email"] + async def authenticate( + self, email: str, password: str + ) -> data.LocalIdentity: r = await execute.parse_execute_json( db=self.db, query="""\ @@ -151,12 +148,8 @@ async def authenticate(self, input: dict[str, Any]) -> data.LocalIdentity: async def get_email_factor_and_secret( self, - input: dict[str, Any], + email: str, ) -> tuple[data.EmailFactor, 
str]: - if 'email' not in input: - raise errors.InvalidData("Missing 'email' in data") - - email = input["email"] r = await execute.parse_execute_json( db=self.db, query=""" @@ -215,13 +208,8 @@ async def validate_reset_secret( return local_identity if secret == current_secret else None async def update_password( - self, identity_id: str, secret: str, input: dict[str, Any] + self, identity_id: str, secret: str, password: str ) -> data.LocalIdentity: - if 'password' not in input: - raise errors.InvalidData("Missing 'password' in data") - - password = input["password"] - local_identity = await self.validate_reset_secret(identity_id, secret) if local_identity is None: diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 9f9450c0ef9..3ee26c32bd1 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -30,6 +30,7 @@ import os import mimetypes import uuid +import dataclasses from typing import ( Any, @@ -128,6 +129,10 @@ async def handle_request( else None ) + logger.info( + f"Handling incoming HTTP request: /ext/auth/{'/'.join(args)}" + ) + try: match args: # API routes @@ -281,20 +286,12 @@ async def handle_authorize( request.url.query.decode("ascii") if request.url.query else "" ) provider_name = _get_search_param(query, "provider") - redirect_to = _get_search_param(query, "redirect_to") - redirect_to_on_signup = _maybe_get_search_param( - query, "redirect_to_on_signup" + allowed_redirect_to = self._make_allowed_url( + _get_search_param(query, "redirect_to") + ) + allowed_redirect_to_on_signup = self._maybe_make_allowed_url( + _maybe_get_search_param(query, "redirect_to_on_signup") ) - if not self._is_url_allowed(redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - if redirect_to_on_signup and not self._is_url_allowed( - redirect_to_on_signup - ): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) challenge = 
_get_search_param( query, "challenge", fallback_keys=["code_challenge"] ) @@ -308,11 +305,19 @@ async def handle_authorize( authorize_url = await oauth_client.get_authorize_url( redirect_uri=self._get_callback_url(), state=self._make_state_claims( - provider_name, redirect_to, redirect_to_on_signup, challenge + provider_name, + allowed_redirect_to.url, + ( + allowed_redirect_to_on_signup.url + if allowed_redirect_to_on_signup + else None + ), + challenge, ), ) - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = authorize_url + # n.b. Explicitly allow authorization URL to be outside of allowed + # URLs because it is a trusted URL from the identity provider. + self._do_redirect(response, AllowedUrl(authorize_url)) async def handle_callback( self, @@ -357,21 +362,19 @@ async def handle_callback( except Exception: raise errors.InvalidData("Invalid state token") - if not self._is_url_allowed(redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - params = { "error": error, } + error_str = error if error_description is not None: params["error_description"] = error_description - response.custom_headers["Location"] = util.join_url_params( - redirect_to, params + error_str += f": {error_description}" + + logger.debug(f"OAuth provider returned an error: {error_str}") + return self._try_redirect( + response, + util.join_url_params(redirect_to, params), ) - response.status = http.HTTPStatus.FOUND - return if code is None: raise errors.InvalidData( @@ -381,20 +384,12 @@ async def handle_callback( try: claims = self._verify_and_extract_claims(state) provider_name = cast(str, claims["provider"]) - redirect_to = cast(str, claims["redirect_to"]) - redirect_to_on_signup = cast( - Optional[str], claims.get("redirect_to_on_signup") + allowed_redirect_to = self._make_allowed_url( + cast(str, claims["redirect_to"]) + ) + allowed_redirect_to_on_signup = self._maybe_make_allowed_url( + cast(Optional[str], 
claims.get("redirect_to_on_signup")) ) - if not self._is_url_allowed(redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - if redirect_to_on_signup and not self._is_url_allowed( - redirect_to_on_signup - ): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) challenge = cast(str, claims["challenge"]) except Exception: raise errors.InvalidData("Invalid state token") @@ -409,6 +404,7 @@ async def handle_callback( new_identity, auth_token, refresh_token, + id_token, ) = await oauth_client.handle_callback(code, self._get_callback_url()) pkce_code = await pkce.link_identity_challenge( self.db, identity.id, challenge @@ -419,17 +415,22 @@ async def handle_callback( id=pkce_code, auth_token=auth_token, refresh_token=refresh_token, + id_token=id_token, + ) + new_url = ( + (allowed_redirect_to_on_signup or allowed_redirect_to) + if new_identity + else allowed_redirect_to + ).map( + lambda u: util.join_url_params( + u, {"code": pkce_code, "provider": provider_name} ) - new_url = util.join_url_params( - ( - (redirect_to_on_signup or redirect_to) - if new_identity - else redirect_to - ), - {"code": pkce_code, "provider": provider_name}, ) - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = new_url + logger.info( + "OAuth callback successful: " + f"identity_id={identity.id}, new_identity={new_identity}" + ) + self._do_redirect(response, new_url) async def handle_token( self, @@ -470,22 +471,25 @@ async def handle_token( if base64_url_encoded_verifier.decode() == pkce_object.challenge: await pkce.delete(self.db, code) + identity_id = pkce_object.identity_id await self._maybe_send_webhook( webhook.IdentityAuthenticated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), - identity_id=pkce_object.identity_id, + identity_id=identity_id, ) ) - session_token = self._make_session_token(pkce_object.identity_id) + session_token = 
self._make_session_token(identity_id) + logger.info(f"Token exchange successful: identity_id={identity_id}") response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps( { "auth_token": session_token, - "identity_id": pkce_object.identity_id, + "identity_id": identity_id, "provider_token": pkce_object.auth_token, "provider_refresh_token": pkce_object.refresh_token, + "provider_id_token": pkce_object.id_token, } ).encode() else: @@ -498,11 +502,10 @@ async def handle_register( ) -> None: data = self._get_data_from_request(request) - maybe_redirect_to = cast(Optional[str], data.get("redirect_to")) - if maybe_redirect_to and not self._is_url_allowed(maybe_redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + allowed_redirect_to = self._maybe_make_allowed_url( + cast(Optional[str], data.get("redirect_to")) + ) + maybe_challenge = cast(Optional[str], data.get("challenge")) register_provider_name = cast(Optional[str], data.get("provider")) if register_provider_name is None: @@ -521,7 +524,9 @@ async def handle_register( identity_id=identity.id, verify_url=verify_url, maybe_challenge=maybe_challenge, - maybe_redirect_to=maybe_redirect_to, + maybe_redirect_to=( + allowed_redirect_to.url if allowed_redirect_to else None + ), ) await self._maybe_send_webhook( @@ -566,48 +571,55 @@ async def handle_register( verify_url=verify_url, ) - now_iso8601 = datetime.datetime.now( - datetime.timezone.utc - ).isoformat() - if maybe_redirect_to is not None: - response.status = http.HTTPStatus.FOUND - redirect_params = ( - {"verification_email_sent_at": now_iso8601} - if require_verification - else { - "code": cast(str, pkce_code), - "provider": register_provider_name, - } - ) - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, redirect_params + if require_verification: + response_dict = { + "verification_email_sent_at": datetime.datetime.now( + datetime.timezone.utc 
+ ).isoformat() + } + else: + if pkce_code is None: + raise errors.PKCECreationFailed + response_dict = { + "code": pkce_code, + "provider": register_provider_name, + } + + logger.info( + f"Identity created: identity_id={identity.id}, " + f"pkce_id={pkce_code!r}" + ) + + if allowed_redirect_to is not None: + self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, response_dict) + ), ) else: response.status = http.HTTPStatus.CREATED response.content_type = b"application/json" - if require_verification: - response.body = json.dumps( - {"verification_email_sent_at": (now_iso8601)} - ).encode() - else: - if pkce_code is None: - raise errors.PKCECreationFailed - response.body = json.dumps( - {"code": pkce_code, "provider": register_provider_name} - ).encode() + response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to + "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + email = data.get("email", "") + logger.error( + f"Error creating identity: error={error_message}, " + f"email={email}" + ) + error_redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + "email": email, + }, ) + return self._try_redirect(response, error_redirect_url) else: raise ex @@ -618,23 +630,22 @@ async def handle_authenticate( ) -> None: data = self._get_data_from_request(request) - authenticate_provider_name = data.get("provider") - if authenticate_provider_name is None: - raise errors.InvalidData('Missing "provider" in register request') - maybe_challenge = data.get("challenge") - if maybe_challenge is None: - raise errors.InvalidData('Missing 
"challenge" in register request') - await pkce.create(self.db, maybe_challenge) + _check_keyset(data, {"provider", "challenge", "email", "password"}) + challenge = data["challenge"] + email = data["email"] + password = data["password"] - maybe_redirect_to = data.get("redirect_to") - if maybe_redirect_to and not self._is_url_allowed(maybe_redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + await pkce.create(self.db, challenge) + + allowed_redirect_to = self._maybe_make_allowed_url( + cast(Optional[str], data.get("redirect_to")) + ) email_password_client = email_password.Client(db=self.db) try: - local_identity = await email_password_client.authenticate(data) + local_identity = await email_password_client.authenticate( + email, password + ) verified_at = ( await email_password_client.get_verified_by_identity_id( identity_id=local_identity.id @@ -647,41 +658,43 @@ async def handle_authenticate( raise errors.VerificationRequired() pkce_code = await pkce.link_identity_challenge( - self.db, local_identity.id, maybe_challenge - ) - if maybe_redirect_to: - response.status = http.HTTPStatus.FOUND - redirect_params = { - "code": pkce_code, - } - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, redirect_params + self.db, local_identity.id, challenge + ) + response_dict = {"code": pkce_code} + logger.info( + f"Authentication successful: identity_id={local_identity.id}, " + f"pkce_id={pkce_code}" + ) + if allowed_redirect_to: + self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, response_dict) + ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" - response.body = json.dumps( - { - "code": pkce_code, - } - ).encode() + response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to + "redirect_on_failure", data.get("redirect_to") ) if 
redirect_on_failure is not None: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + email = data.get("email", "") + logger.error( + f"Error authenticating: error={error_message}, " + f"email={email}" ) + error_redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + "email": email, + }, + ) + return self._try_redirect(response, error_redirect_url) else: raise ex @@ -719,25 +732,27 @@ async def handle_verify( except errors.VerificationTokenExpired: response.status = http.HTTPStatus.FORBIDDEN response.content_type = b"application/json" - response.body = json.dumps( - { - "message": ( - "The 'iat' claim in verification token is older" - " than 24 hours" - ) - } - ).encode() + error_message = ( + "The 'iat' claim in verification token is older than 24 hours" + ) + logger.error(f"Verification token expired: {error_message}") + response.body = json.dumps({"message": error_message}).encode() return + logger.info( + f"Email verified: identity_id={identity_id}, " + f"email_factor_id={email_factor.id}, " + f"email={email_factor.email}" + ) match (maybe_challenge, maybe_redirect_to): case (str(challenge), str(redirect_to)): await pkce.create(self.db, challenge) code = await pkce.link_identity_challenge( self.db, identity_id, challenge ) - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = _with_appended_qs( - redirect_to, {"code": [code]} + return self._try_redirect( + response, + util.join_url_params(redirect_to, {"code": code}), ) case (str(challenge), _): await pkce.create(self.db, challenge) @@ -747,11 +762,12 @@ async def handle_verify( response.status = http.HTTPStatus.OK 
response.content_type = b"application/json" response.body = json.dumps({"code": code}).encode() + return case (_, str(redirect_to)): - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = redirect_to + return self._try_redirect(response, redirect_to) case (_, _): response.status = http.HTTPStatus.NO_CONTENT + return async def handle_resend_verification_email( self, @@ -819,8 +835,27 @@ async def handle_resend_verification_email( ) if email_factor is None: + match local_client: + case webauthn.Client(): + logger.debug( + f"Failed to find email factor for resend verification " + f"email: provider={provider_name}, " + f"webauthn_credential_id={request_data.get('credential_id')}" + ) + case email_password.Client(): + logger.debug( + f"Failed to find email factor for resend verification " + f"email: provider={provider_name}, " + f"email={request_data.get('email')}" + ) await auth_emails.send_fake_email(self.tenant) else: + logger.info( + f"Resending verification email: provider={provider_name}, " + f"identity_id={email_factor.identity.id}, " + f"email_factor_id={email_factor.id}, " + f"email={email_factor.email}" + ) verification_token = self._make_verification_token( identity_id=email_factor.identity.id, verify_url=verify_url, @@ -853,16 +888,15 @@ async def handle_send_reset_email( data = self._get_data_from_request(request) _check_keyset(data, {"provider", "email", "reset_url", "challenge"}) + email = data["email"] email_password_client = email_password.Client(db=self.db) if not self._is_url_allowed(data["reset_url"]): raise errors.InvalidData( "Redirect URL does not match any allowed URLs.", ) - maybe_redirect_to = data.get("redirect_to") - if maybe_redirect_to and not self._is_url_allowed(maybe_redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + allowed_redirect_to = self._maybe_make_allowed_url( + data.get("redirect_to") + ) try: try: @@ -870,7 +904,7 @@ async def handle_send_reset_email( 
email_factor, secret, ) = await email_password_client.get_email_factor_and_secret( - data + email ) identity_id = email_factor.identity.id @@ -898,21 +932,27 @@ async def handle_send_reset_email( await auth_emails.send_password_reset_email( db=self.db, tenant=self.tenant, - to_addr=data["email"], + to_addr=email, reset_url=reset_url, test_mode=self.test_mode, ) except errors.NoIdentityFound: + logger.debug( + f"Failed to find identity for send reset email: " + f"email={email}" + ) await auth_emails.send_fake_email(self.tenant) return_data = { - "email_sent": data['email'], + "email_sent": email, } - if maybe_redirect_to: - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, return_data + if allowed_redirect_to: + return self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, return_data) + ), ) else: response.status = http.HTTPStatus.OK @@ -927,20 +967,24 @@ async def handle_send_reset_email( except Exception as ex: redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to + "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + logger.error( + f"Error sending reset email: error={error_message}, " + f"email={email}" + ) + redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + "email": email, + }, + ) + return self._try_redirect( + response, + redirect_url, ) else: raise ex @@ -952,56 +996,62 @@ async def handle_reset_password( ) -> None: data = self._get_data_from_request(request) - 
_check_keyset(data, {"provider", "reset_token"}) + _check_keyset(data, {"provider", "reset_token", "password"}) + reset_token = data['reset_token'] + password = data['password'] email_password_client = email_password.Client(db=self.db) - maybe_redirect_to = data.get("redirect_to") - if maybe_redirect_to and not self._is_url_allowed(maybe_redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + allowed_redirect_to = self._maybe_make_allowed_url( + data.get("redirect_to") + ) try: - reset_token = data['reset_token'] identity_id, secret, challenge = self._get_data_from_reset_token( reset_token ) await email_password_client.update_password( - identity_id, secret, data + identity_id, secret, password ) await pkce.create(self.db, challenge) code = await pkce.link_identity_challenge( self.db, identity_id, challenge ) + response_dict = {"code": code} + logger.info( + f"Reset password: identity_id={identity_id}, pkce_id={code}" + ) - if maybe_redirect_to: - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, {"code": code} + if allowed_redirect_to: + return self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, response_dict) + ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" - response.body = json.dumps({"code": code}).encode() + response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to + "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "reset_token": data.get('reset_token', ''), - } - response.custom_headers["Location"] = 
util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + logger.error( + f"Error resetting password: error={error_message}, " + f"reset_token={reset_token}" ) + redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + "reset_token": reset_token, + }, + ) + return self._try_redirect(response, redirect_url) else: raise ex @@ -1026,20 +1076,19 @@ async def handle_magic_link_register( email = data["email"] challenge = data["challenge"] callback_url = data["callback_url"] - redirect_on_failure = data["redirect_on_failure"] if not self._is_url_allowed(callback_url): raise errors.InvalidData( "Callback URL does not match any allowed URLs.", ) - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Error redirect URL does not match any allowed URLs.", - ) - maybe_redirect_to = data.get("redirect_to") - if maybe_redirect_to and not self._is_url_allowed(maybe_redirect_to): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + + allowed_redirect_on_failure = self._make_allowed_url( + data["redirect_on_failure"] + ) + + allowed_redirect_to = self._maybe_make_allowed_url( + data.get("redirect_to") + ) + magic_link_client = magic_link.Client( db=self.db, issuer=self.base_path, @@ -1049,7 +1098,7 @@ async def handle_magic_link_register( request_accepts_json: bool = request.accept == b"application/json" - if not request_accepts_json and not maybe_redirect_to: + if not request_accepts_json and not allowed_redirect_to: raise errors.InvalidData( "Request must accept JSON or provide a redirect URL." 
) @@ -1087,10 +1136,14 @@ async def handle_magic_link_register( magic_link_token=magic_link_token, ) ) + logger.info( + f"Sending magic link: identity_id={email_factor.identity.id}, " + f"email={email}" + ) await magic_link_client.send_magic_link( email=email, link_url=f"{self.base_path}/magic-link/authenticate", - redirect_on_failure=redirect_on_failure, + redirect_on_failure=allowed_redirect_on_failure.url, token=magic_link_token, ) @@ -1102,10 +1155,12 @@ async def handle_magic_link_register( response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(return_data).encode() - elif maybe_redirect_to: - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, return_data + elif allowed_redirect_to: + return self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, return_data) + ), ) else: # This should not happen since we check earlier for this case @@ -1117,14 +1172,21 @@ async def handle_magic_link_register( if request_accepts_json: raise ex - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + logger.error( + f"Error sending magic link: error={error_message}, " + f"email={email}" + ) + redirect_url = allowed_redirect_on_failure.map( + lambda u: util.join_url_params( + u, + { + "error": error_message, + "email": email, + }, + ) ) + return self._do_redirect(response, redirect_url) async def handle_magic_link_email( self, @@ -1133,8 +1195,6 @@ async def handle_magic_link_email( ) -> None: data = self._get_data_from_request(request) - maybe_redirect_to = data.get("redirect_to") - try: _check_keyset( data, @@ -1160,12 +1220,10 @@ async def handle_magic_link_email( "Error redirect URL does not match any allowed URLs.", ) - if 
maybe_redirect_to and not self._is_url_allowed( - maybe_redirect_to - ): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) + allowed_redirect_to = self._maybe_make_allowed_url( + data.get("redirect_to") + ) + magic_link_client = magic_link.Client( db=self.db, issuer=self.base_path, @@ -1176,10 +1234,15 @@ async def handle_magic_link_email( email ) if email_factor is None: + logger.error( + f"Cannot send magic link email: no email factor found for " + f"email={email}" + ) await auth_emails.send_fake_email(self.tenant) else: + identity_id = email_factor.identity.id magic_link_token = magic_link_client.make_magic_link_token( - identity_id=email_factor.identity.id, + identity_id=identity_id, callback_url=callback_url, challenge=challenge, ) @@ -1187,7 +1250,7 @@ async def handle_magic_link_email( webhook.MagicLinkRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), - identity_id=email_factor.identity.id, + identity_id=identity_id, email_factor_id=email_factor.id, magic_link_token=magic_link_token, ) @@ -1198,15 +1261,21 @@ async def handle_magic_link_email( link_url=f"{self.base_path}/magic-link/authenticate", redirect_on_failure=redirect_on_failure, ) + logger.info( + "Sent magic link email: " + f"identity_id={identity_id}, email={email}" + ) return_data = { "email_sent": email, } - if maybe_redirect_to: - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = util.join_url_params( - maybe_redirect_to, return_data + if allowed_redirect_to: + return self._do_redirect( + response, + allowed_redirect_to.map( + lambda u: util.join_url_params(u, return_data) + ), ) else: response.status = http.HTTPStatus.OK @@ -1214,23 +1283,24 @@ async def handle_magic_link_email( response.body = json.dumps(return_data).encode() except Exception as ex: redirect_on_failure = data.get( - "redirect_on_failure", maybe_redirect_to + "redirect_on_failure", data.get("redirect_to") ) if 
redirect_on_failure is None: raise ex else: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - "email": data.get('email', ''), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + logger.error( + f"Error sending magic link email: error={error_message}, " + f"email={email}" + ) + error_redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + "email": email, + }, ) + self._try_redirect(response, error_redirect_url) async def handle_magic_link_authenticate( self, @@ -1260,10 +1330,11 @@ async def handle_magic_link_authenticate( identity_id, datetime.datetime.now(datetime.timezone.utc) ) - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = util.join_url_params( - callback_url, {"code": code} + return self._try_redirect( + response, + util.join_url_params(callback_url, {"code": code}), ) + except Exception as ex: redirect_on_failure = _maybe_get_search_param( query, "redirect_on_failure" @@ -1271,17 +1342,18 @@ async def handle_magic_link_authenticate( if redirect_on_failure is None: raise ex else: - if not self._is_url_allowed(redirect_on_failure): - raise errors.InvalidData( - "Redirect URL does not match any allowed URLs.", - ) - response.status = http.HTTPStatus.FOUND - redirect_params = { - "error": str(ex), - } - response.custom_headers["Location"] = util.join_url_params( - redirect_on_failure, redirect_params + error_message = str(ex) + logger.error( + f"Error authenticating magic link: error={error_message}, " + f"token={token}" + ) + redirect_url = util.join_url_params( + redirect_on_failure, + { + "error": error_message, + }, ) + return self._try_redirect(response, redirect_url) async def handle_webauthn_register_options( self, @@ -1418,12 +1490,21 @@ 
async def handle_webauthn_register( response.body = json.dumps( {"verification_email_sent_at": (now_iso8601)} ).encode() + logger.info( + f"Sent verification email: identity_id={identity_id}, " + f"email={email}" + ) else: if pkce_code is None: raise errors.PKCECreationFailed response.body = json.dumps( {"code": pkce_code, "provider": provider_name} ).encode() + logger.info( + f"WebAuthn registration successful: identity_id={identity_id}, " + f"email={email}, " + f"pkce_id={pkce_code}" + ) async def handle_webauthn_authenticate_options( self, @@ -1487,6 +1568,12 @@ async def handle_webauthn_authenticate( self.db, identity.id, pkce_challenge ) + logger.info( + f"WebAuthn authentication successful: identity_id={identity.id}, " + f"email={email}, " + f"pkce_id={code}" + ) + response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps( @@ -1777,10 +1864,10 @@ async def handle_ui_verify( redirect_to = maybe_redirect_to or redirect_to redirect_to = ( - _with_appended_qs( + util.join_url_params( redirect_to, { - "code": [maybe_pkce_code], + "code": maybe_pkce_code, }, ) if maybe_pkce_code @@ -1806,9 +1893,7 @@ async def handle_ui_verify( # Only redirect back if verification succeeds if is_valid: - response.status = http.HTTPStatus.FOUND - response.custom_headers["Location"] = redirect_to - return + return self._try_redirect(response, redirect_to) app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK @@ -1938,7 +2023,8 @@ async def _maybe_send_webhook(self, event: webhook.Event) -> None: event=event, ) logger.info( - f"Sent webhook request {request_id} for event {event!r}" + f"Sent webhook request {request_id} " + f"to {webhook_config.url} for event {event!r}" ) def _get_callback_url(self) -> str: @@ -2338,6 +2424,36 @@ def _is_url_allowed(self, url: str) -> bool: return False + def _do_redirect( + self, response: protocol.HttpResponse, allowed_url: AllowedUrl + ) -> None: + response.status = 
http.HTTPStatus.FOUND + response.custom_headers["Location"] = allowed_url.url + + def _try_redirect(self, response: protocol.HttpResponse, url: str) -> None: + allowed_url = self._make_allowed_url(url) + self._do_redirect(response, allowed_url) + + def _make_allowed_url(self, url: str) -> AllowedUrl: + if not self._is_url_allowed(url): + raise errors.InvalidData( + "Redirect URL does not match any allowed URLs.", + ) + return AllowedUrl(url) + + def _maybe_make_allowed_url( + self, url: Optional[str] + ) -> Optional[AllowedUrl]: + return self._make_allowed_url(url) if url else None + + +@dataclasses.dataclass +class AllowedUrl: + url: str + + def map(self, f: Callable[[str], str]) -> "AllowedUrl": + return AllowedUrl(f(self.url)) + def _fail_with_error( *, @@ -2350,6 +2466,7 @@ def _fail_with_error( "type": str(ex.__class__.__name__), } + logger.error(f"Failed to handle HTTP request: {err_dct!r}") response.body = json.dumps({"error": err_dct}).encode() response.status = status @@ -2426,15 +2543,6 @@ def _set_cookie( response.custom_headers["Set-Cookie"] = val.OutputString() -def _with_appended_qs(url: str, query: dict[str, list[str]]) -> str: - url_parts = list(urllib.parse.urlparse(url)) - existing_query = urllib.parse.parse_qs(url_parts[4]) - existing_query.update(query) - - url_parts[4] = urllib.parse.urlencode(existing_query, doseq=True) - return urllib.parse.urlunparse(url_parts) - - def _check_keyset(candidate: dict[str, Any], keyset: set[str]) -> None: missing_fields = [field for field in keyset if field not in candidate] if missing_fields: diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index 1ed14ff1ccd..a1333ee2e81 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -104,16 +104,18 @@ async def get_authorize_url(self, state: str, redirect_uri: str) -> str: async def handle_callback( self, code: str, redirect_uri: str - ) -> tuple[data.Identity, bool, str | None, str | 
None]: + ) -> tuple[data.Identity, bool, str | None, str | None, str | None]: response = await self.provider.exchange_code(code, redirect_uri) user_info = await self.provider.fetch_user_info(response) auth_token = response.access_token refresh_token = response.refresh_token + source_id_token = user_info.source_id_token return ( *(await self._handle_identity(user_info)), auth_token, refresh_token, + source_id_token, ) async def _handle_identity( diff --git a/edb/server/protocol/auth_ext/pkce.py b/edb/server/protocol/auth_ext/pkce.py index b4d1c5ddde8..755310c9df0 100644 --- a/edb/server/protocol/auth_ext/pkce.py +++ b/edb/server/protocol/auth_ext/pkce.py @@ -46,6 +46,7 @@ class PKCEChallenge: challenge: str auth_token: str | None refresh_token: str | None + id_token: str | None identity_id: str | None @@ -95,6 +96,7 @@ async def add_provider_tokens( id: str, auth_token: str | None, refresh_token: str | None, + id_token: str | None, ) -> str: r = await execute.parse_execute_json( db, @@ -104,12 +106,14 @@ async def add_provider_tokens( set { auth_token := $auth_token, refresh_token := $refresh_token, + id_token := $id_token, } """, variables={ "id": id, "auth_token": auth_token, "refresh_token": refresh_token, + "id_token": id_token, }, cached_globally=True, ) @@ -129,6 +133,7 @@ async def get_by_id(db: edbtenant.dbview.Database, id: str) -> PKCEChallenge: challenge, auth_token, refresh_token, + id_token, identity_id := .identity.id } filter .id = $id diff --git a/edb/server/protocol/auth_ext/smtp.py b/edb/server/protocol/auth_ext/smtp.py deleted file mode 100644 index 90546fe354c..00000000000 --- a/edb/server/protocol/auth_ext/smtp.py +++ /dev/null @@ -1,171 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2023-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any, Optional, Union, Sequence - -import asyncio -import email -import email.message -import os -import pickle -import hashlib - -import aiosmtplib - -from edb.common import retryloop -from edb.ir import statypes - -from . import util - - -_semaphore: asyncio.BoundedSemaphore | None = None - - -async def send_email( - db: Any, - message: Union[ - email.message.EmailMessage, - email.message.Message, - str, - bytes, - ], - sender: Optional[str] = None, - recipients: Optional[Union[str, Sequence[str]]] = None, - test_mode: bool = False, -) -> None: - global _semaphore - if _semaphore is None: - _semaphore = asyncio.BoundedSemaphore( - int(os.environ.get("EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", 5)) - ) - - host = ( - util.maybe_get_config( - db, - "ext::auth::SMTPConfig::host", - ) - or "localhost" - ) - port = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::port", - expected_type=int, - ) - username = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::username", - ) - password = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::password", - ) - timeout_per_attempt = util.get_config( - db, - "ext::auth::SMTPConfig::timeout_per_attempt", - expected_type=statypes.Duration, - ) - req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 - timeout_per_email = util.get_config( - db, - "ext::auth::SMTPConfig::timeout_per_email", - expected_type=statypes.Duration, - ) - validate_certs = util.get_config( - db, - "ext::auth::SMTPConfig::validate_certs", - expected_type=bool, - ) - security = util.get_config( - db, - 
"ext::auth::SMTPConfig::security", - ) - start_tls: bool | None - match security: - case "PlainText": - use_tls = False - start_tls = False - - case "TLS": - use_tls = True - start_tls = False - - case "STARTTLS": - use_tls = False - start_tls = True - - case "STARTTLSOrPlainText": - use_tls = False - start_tls = None - - case _: - raise NotImplementedError - - rloop = retryloop.RetryLoop( - timeout=timeout_per_email.to_microseconds() / 1_000_000.0, - backoff=retryloop.exp_backoff(), - ignore=( - aiosmtplib.SMTPConnectError, - aiosmtplib.SMTPHeloError, - aiosmtplib.SMTPServerDisconnected, - aiosmtplib.SMTPConnectTimeoutError, - aiosmtplib.SMTPConnectResponseError, - ), - ) - async for iteration in rloop: - async with iteration: - async with _semaphore: - # Currently we are not reusing SMTP connections, but ideally we - # should replace this with a pool of connections, and drop idle - # connections after configured time. - args = dict( - message=message, - sender=sender, - recipients=recipients, - hostname=host, - port=port, - username=username, - password=password, - timeout=req_timeout, - use_tls=use_tls, - start_tls=start_tls, - validate_certs=validate_certs, - ) - if test_mode: - recipients_list: list[str] - if isinstance(recipients, str): - recipients_list = [recipients] - elif recipients is None: - recipients_list = [] - else: - recipients_list = list(recipients) - - hash_input = f"{sender}{','.join(recipients_list)}" - file_name_hash = hashlib.sha256( - hash_input.encode() - ).hexdigest() - file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" - test_file = os.environ.get( - "EDGEDB_TEST_EMAIL_FILE", - file_name, - ) - if os.path.exists(test_file): - os.unlink(test_file) - with open(test_file, "wb") as f: - pickle.dump(args, f) - else: - await aiosmtplib.send(**args) # type: ignore diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index d066a0c0984..975f3393e45 100644 --- 
a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -20,9 +20,7 @@ from typing import cast, Optional import html - -from email.mime import multipart -from email.mime import text as mime_text +import email.message from edb.server.protocol.auth_ext import config as auth_config @@ -701,21 +699,17 @@ def render_magic_link_sent_page( def render_password_reset_email( *, - from_addr: str, to_addr: str, reset_url: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Reset password" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Somebody requested a new password for the {app_name or ''} account associated with {to_addr}. @@ -723,13 +717,8 @@ def render_password_reset_email( email address: {reset_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Verify your email{f' for {app_name}' if app_name else ''}" ) - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Congratulations, you're registered{f' at {app_name}' if app_name else ''}! 
Please paste the following URL into your browser address bar to verify your email address: {verify_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" - -""" - - html_msg = mime_text.MIMEText( + """ + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg def render_magic_link_email( *, - from_addr: str, to_addr: str, link: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Sign in link" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Please paste the following URL into your browser address bar to be signed into your account: {link} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - content = f""" + """ + html_content = f""" """ - html_msg = mime_text.MIMEText( + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg diff --git a/edb/server/protocol/auth_ext/util.py b/edb/server/protocol/auth_ext/util.py index 60440224d8a..57b6977386b 100644 --- a/edb/server/protocol/auth_ext/util.py +++ b/edb/server/protocol/auth_ext/util.py @@ -22,6 +22,8 @@ import base64 import urllib.parse import datetime +import html + from cryptography.hazmat.primitives 
import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand from cryptography.hazmat.backends import default_backend @@ -101,6 +103,17 @@ def get_config_typename(config_value: edb_config.SettingValue) -> str: return config_value._tspec.name # type: ignore +def escape_and_truncate(input_str: str | None, max_len: int) -> str | None: + if input_str is None: + return None + trunc = ( + f"{input_str[:max_len]}..." + if len(input_str) > max_len + else input_str + ) + return html.escape(trunc) + + def get_app_details_config(db: Any) -> config.AppDetailsConfig: ui_config = cast( Optional[config.UIConfig], @@ -108,21 +121,25 @@ def get_app_details_config(db: Any) -> config.AppDetailsConfig: ) return config.AppDetailsConfig( - app_name=( + app_name=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::app_name") - or (ui_config.app_name if ui_config else None) + or (ui_config.app_name if ui_config else None), + 100, ), - logo_url=( + logo_url=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::logo_url") - or (ui_config.logo_url if ui_config else None) + or (ui_config.logo_url if ui_config else None), + 2000, ), - dark_logo_url=( + dark_logo_url=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::dark_logo_url") - or (ui_config.dark_logo_url if ui_config else None) + or (ui_config.dark_logo_url if ui_config else None), + 2000, ), - brand_color=( + brand_color=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::brand_color") - or (ui_config.brand_color if ui_config else None) + or (ui_config.brand_color if ui_config else None), + 8, ), ) diff --git a/edb/server/protocol/auth_ext/webauthn.py b/edb/server/protocol/auth_ext/webauthn.py index 1bfeba03da0..e0d79972d4e 100644 --- a/edb/server/protocol/auth_ext/webauthn.py +++ b/edb/server/protocol/auth_ext/webauthn.py @@ -75,7 +75,8 @@ def _get_provider(self) -> config.WebAuthnProvider: ) def _get_app_name(self) -> Optional[str]: - return 
util.maybe_get_config(self.db, "ext::auth::AuthConfig::app_name") + app_config = util.get_app_details_config(self.db) + return app_config.app_name async def create_registration_options_for_email( self, email: str, diff --git a/edb/server/protocol/auth_ext/webhook.py b/edb/server/protocol/auth_ext/webhook.py index 02c2de9a8bc..1e0fc977e65 100644 --- a/edb/server/protocol/auth_ext/webhook.py +++ b/edb/server/protocol/auth_ext/webhook.py @@ -42,8 +42,9 @@ class Event(abc.ABC): def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" - f"event_id={self.event_id!r}, " - f"timestamp={self.timestamp!r})" + f"timestamp={self.timestamp!r}, " + f"event_id={self.event_id!r}" + ")" ) @@ -64,6 +65,15 @@ class IdentityCreated(Event, HasIdentity): init=False, ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}" + ")" + ) + @dataclasses.dataclass class IdentityAuthenticated(Event, HasIdentity): @@ -72,6 +82,15 @@ class IdentityAuthenticated(Event, HasIdentity): init=False, ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}" + ")" + ) + @dataclasses.dataclass class EmailFactorCreated(Event, HasIdentity, HasEmailFactor): @@ -80,6 +99,16 @@ class EmailFactorCreated(Event, HasIdentity, HasEmailFactor): init=False, ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}, " + f"email_factor_id={self.email_factor_id}" + ")" + ) + @dataclasses.dataclass class EmailVerificationRequested(Event, HasIdentity, HasEmailFactor): @@ -91,6 +120,16 @@ class EmailVerificationRequested(Event, HasIdentity, HasEmailFactor): ) verification_token: str + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + 
f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}, " + f"email_factor_id={self.email_factor_id}" + ")" + ) + @dataclasses.dataclass class EmailVerified(Event, HasIdentity, HasEmailFactor): @@ -99,6 +138,16 @@ class EmailVerified(Event, HasIdentity, HasEmailFactor): init=False, ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}, " + f"email_factor_id={self.email_factor_id}" + ")" + ) + @dataclasses.dataclass class PasswordResetRequested(Event, HasIdentity, HasEmailFactor): @@ -108,6 +157,16 @@ class PasswordResetRequested(Event, HasIdentity, HasEmailFactor): ) reset_token: str + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}, " + f"email_factor_id={self.email_factor_id}" + ")" + ) + @dataclasses.dataclass class MagicLinkRequested(Event, HasIdentity, HasEmailFactor): @@ -117,6 +176,16 @@ class MagicLinkRequested(Event, HasIdentity, HasEmailFactor): ) magic_link_token: str + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"timestamp={self.timestamp}, " + f"event_id={self.event_id}, " + f"identity_id={self.identity_id}, " + f"email_factor_id={self.email_factor_id}" + ")" + ) + class DateTimeEncoder(json.JSONEncoder): def default(self, obj: typing.Any) -> typing.Any: diff --git a/edb/server/protocol/auth_helpers.pxd b/edb/server/protocol/auth_helpers.pxd index e43d6c4f504..9100d670051 100644 --- a/edb/server/protocol/auth_helpers.pxd +++ b/edb/server/protocol/auth_helpers.pxd @@ -18,7 +18,7 @@ cdef extract_token_from_auth_data(bytes auth_data) -cdef auth_jwt(tenant, str prefixed_token, str user, str dbname) +cdef auth_jwt(tenant, prefixed_token, str user, str dbname) cdef _check_jwt_authz(tenant, claims, token_version, str user, str dbname) cdef 
_get_jwt_edb_scope(claims, claim) cdef scram_get_verifier(tenant, str user) diff --git a/edb/server/protocol/auth_helpers.pyx b/edb/server/protocol/auth_helpers.pyx index ec8f1aeb52f..f0b22bc764f 100644 --- a/edb/server/protocol/auth_helpers.pyx +++ b/edb/server/protocol/auth_helpers.pyx @@ -39,7 +39,7 @@ cdef extract_token_from_auth_data(auth_data: bytes): return scheme.lower(), payload.strip() -cdef auth_jwt(tenant, prefixed_token: str, user: str, dbname: str): +cdef auth_jwt(tenant, prefixed_token: str | None, user: str, dbname: str): if not prefixed_token: raise errors.AuthenticationError( 'authentication failed: no authorization data provided') diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 72a8f265d9b..6f23400bdeb 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -43,6 +43,8 @@ from edb import buildmeta from edb import edgeql from edb.edgeql import qltypes +from edb.pgsql import parser as pgparser + from edb.server.pgproto cimport hton from edb.server.pgproto.pgproto cimport ( WriteBuffer, @@ -94,6 +96,10 @@ cdef object CARD_AT_MOST_ONE = compiler.Cardinality.AT_MOST_ONE cdef object CARD_MANY = compiler.Cardinality.MANY cdef object FMT_NONE = compiler.OutputFormat.NONE +cdef object FMT_BINARY = compiler.OutputFormat.BINARY + +cdef object LANG_EDGEQL = compiler.InputLanguage.EDGEQL +cdef object LANG_SQL = compiler.InputLanguage.SQL cdef tuple DUMP_VER_MIN = (0, 7) cdef tuple DUMP_VER_MAX = edbdef.CURRENT_PROTOCOL @@ -486,12 +492,25 @@ cdef class EdgeConnection(frontend.FrontendConnection): fe_conn=self, ) - def _tokenize(self, eql: bytes) -> edgeql.Source: + def _tokenize( + self, + eql: bytes, + lang: enums.InputLanguage, + ) -> edgeql.Source: text = eql.decode('utf-8') - if debug.flags.edgeql_disable_normalization: - return edgeql.Source.from_string(text) + if lang is LANG_EDGEQL: + if debug.flags.edgeql_disable_normalization: + return edgeql.Source.from_string(text) + else: + return 
edgeql.NormalizedSource.from_string(text) + elif lang is LANG_SQL: + if debug.flags.edgeql_disable_normalization: + return pgparser.Source.from_string(text) + else: + return pgparser.NormalizedSource.from_string(text) else: - return edgeql.NormalizedSource.from_string(text) + raise errors.UnsupportedFeatureError( + f"unsupported input language: {lang}") async def _suppress_tx_timeout(self): async with self.with_pgcon() as conn: @@ -529,6 +548,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): 'Cache key', source.cache_key(), f"protocol_version={query_req.protocol_version}", + f"input_language={query_req.input_language}", f"output_format={query_req.output_format}", f"expect_one={query_req.expect_one}", f"implicit_limit={query_req.implicit_limit}", @@ -550,9 +570,18 @@ cdef class EdgeConnection(frontend.FrontendConnection): if suppress_timeout: await self._suppress_tx_timeout() try: - return await dbv.parse( - query_req, allow_capabilities=allow_capabilities - ) + if query_req.input_language is LANG_SQL: + async with self.with_pgcon() as pg_conn: + return await dbv.parse( + query_req, + allow_capabilities=allow_capabilities, + pgcon=pg_conn, + ) + else: + return await dbv.parse( + query_req, + allow_capabilities=allow_capabilities, + ) finally: if suppress_timeout: await self._restore_tx_timeout(dbv) @@ -752,6 +781,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): bint inline_typenames = False bint inline_typeids = False bint inline_objectids = False + object cardinality object output_format bint expect_one = False bytes query @@ -779,10 +809,29 @@ cdef class EdgeConnection(frontend.FrontendConnection): & messages.CompilationFlag.INJECT_OUTPUT_OBJECT_IDS ) + if self.protocol_version >= (3, 0): + lang = rpc.deserialize_input_language(self.buffer.read_byte()) + else: + lang = LANG_EDGEQL + output_format = rpc.deserialize_output_format(self.buffer.read_byte()) - expect_one = ( - self.parse_cardinality(self.buffer.read_byte()) is CARD_AT_MOST_ONE - ) 
+ if ( + lang is LANG_SQL + and output_format is not FMT_NONE + and output_format is not FMT_BINARY + ): + raise errors.UnsupportedFeatureError( + "non-binary output format is not supported with " + "SQL as the input language" + ) + + cardinality = self.parse_cardinality(self.buffer.read_byte()) + expect_one = cardinality is CARD_AT_MOST_ONE + if lang is LANG_SQL and cardinality is not CARD_MANY: + raise errors.UnsupportedFeatureError( + "output cardinality assertions are not supported with " + "SQL as the input language" + ) query = self.buffer.read_len_prefixed_bytes() if not query: @@ -803,10 +852,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): cfg_ser = self.server.compilation_config_serializer rv = rpc.CompilationRequest( - source=self._tokenize(query), + source=self._tokenize(query, lang), protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, + input_language=lang, output_format=output_format, expect_one=expect_one, implicit_limit=implicit_limit, @@ -817,6 +867,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): session_config=_dbview.get_session_config(), database_config=_dbview.get_database_config(), system_config=_dbview.get_compilation_system_config(), + role_name=self.username, + branch_name=self.dbname, ) return rv, allow_capabilities @@ -893,6 +945,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): else: compiled = _dbview.as_compiled(query_req, query_unit_group) + if query_req.input_language is LANG_SQL and len(query_unit_group) > 1: + raise errors.UnsupportedFeatureError( + "multi-statement SQL scripts are not supported yet") + self._query_count += 1 # Clear the _last_anon_compiled so that the next Execute - if @@ -1438,6 +1494,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, + role_name=self.username, + 
branch_name=self.dbname, ) compiled = await _dbview.parse(query_req) @@ -1582,7 +1640,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): if query_unit.sql: if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) + await pgcon.parse_execute(query=query_unit) + ddl_ret = pgcon.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: @@ -1837,6 +1896,8 @@ async def run_script( schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, output_format=FMT_NONE, + role_name=user, + branch_name=database, ), ) if len(compiled.query_unit_group) > 1: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 4760e4bb341..43f5c896b29 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -94,19 +94,14 @@ cdef class ExecutionGroup: if state is not None: await be_conn.wait_for_state_resp(state, state_sync=0) for i, unit in enumerate(self.group): - if unit.output_format == FMT_NONE and unit.ddl_stmt_id is None: - for sql in unit.sql: - await be_conn.wait_for_command( - unit, parse_array[i], dbver, ignore_data=True - ) - rv = None - else: - for sql in unit.sql: - rv = await be_conn.wait_for_command( - unit, parse_array[i], dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + ignore_data = unit.output_format == FMT_NONE + rv = await be_conn.wait_for_command( + unit, + parse_array[i], + dbver, + ignore_data=ignore_data, + fe_conn=None if ignore_data else fe_conn, + ) return rv @@ -135,13 +130,11 @@ cpdef ExecutionGroup build_cache_persistence_units( assert serialized_result is not None if evict: - group.append(compiler.QueryUnit(sql=(evict,), status=b'')) + group.append(compiler.QueryUnit(sql=evict, status=b'')) if persist: - group.append(compiler.QueryUnit(sql=(persist,), status=b'')) + group.append(compiler.QueryUnit(sql=persist, status=b'')) group.append( - compiler.QueryUnit( - sql=(insert_sql,), sql_hash=sql_hash, status=b'', - 
), + compiler.QueryUnit(sql=insert_sql, sql_hash=sql_hash, status=b''), args_ser.combine_raw_args(( query_unit.cache_key.bytes, query_unit.user_schema_version.bytes, @@ -276,12 +269,15 @@ async def execute( if query_unit.sql: if query_unit.user_schema: - ddl_ret = await be_conn.run_ddl(query_unit, state) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] + await be_conn.parse_execute(query=query_unit, state=state) + if query_unit.ddl_stmt_id is not None: + ddl_ret = be_conn.load_last_ddl_return(query_unit) + if ddl_ret and ddl_ret['new_types']: + new_types = ddl_ret['new_types'] else: + data_types = [] bound_args_buf = args_ser.recode_bind_args( - dbv, compiled, bind_args) + dbv, compiled, bind_args, None, data_types) assert not (query_unit.database_config and query_unit.needs_readback), ( @@ -294,6 +290,7 @@ async def execute( query=query_unit, fe_conn=fe_conn if not read_data else None, bind_data=bound_args_buf, + param_data_types=data_types, use_prep_stmt=use_prep_stmt, state=state, dbver=dbv.dbver, @@ -358,6 +355,20 @@ async def execute( if config_ops: await dbv.apply_config_ops(be_conn, config_ops) + if query_unit.user_schema and debug.flags.delta_validate_reflection: + global_schema = ( + query_unit.global_schema or dbv.get_global_schema_pickle()) + new_user_schema = await dbv.tenant._debug_introspect( + be_conn, global_schema) + compiler_pool = dbv.server.get_compiler_pool() + await compiler_pool.validate_schema_equivalence( + query_unit.user_schema, + new_user_schema, + global_schema, + dbv._last_comp_state, + ) + query_unit.user_schema = new_user_schema + except Exception as ex: # If we made schema changes, include the new schema in the # exception so that it can be used when interpreting. 
@@ -505,35 +516,29 @@ async def execute_script( if query_unit.sql: parse = parse_array[idx] + fe_output = query_unit.output_format != FMT_NONE + ignore_data = ( + not fe_output + and not query_unit.needs_readback + ) + data = await conn.wait_for_command( + query_unit, + parse, + dbver, + ignore_data=ignore_data, + fe_conn=fe_conn if fe_output else None, + ) + if query_unit.ddl_stmt_id: - ddl_ret = await conn.handle_ddl_in_script( - query_unit, parse, dbver - ) + ddl_ret = conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] - elif query_unit.needs_readback: - config_data = [] - for sql in query_unit.sql: - config_data = await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=False - ) - if config_data: - config_ops = [ - config.Operation.from_json(r[0][1:]) - for r in config_data - ] - elif query_unit.output_format == FMT_NONE: - for sql in query_unit.sql: - await conn.wait_for_command( - query_unit, parse, dbver, ignore_data=True - ) - else: - for sql in query_unit.sql: - data = await conn.wait_for_command( - query_unit, parse, dbver, - ignore_data=False, - fe_conn=fe_conn, - ) + + if query_unit.needs_readback and data: + config_ops = [ + config.Operation.from_json(r[0][1:]) + for r in data + ] if config_ops: await dbv.apply_config_ops(conn, config_ops) @@ -571,6 +576,22 @@ async def execute_script( raise else: + updated_user_schema = False + if user_schema and debug.flags.delta_validate_reflection: + cur_global_schema = ( + global_schema or dbv.get_global_schema_pickle()) + new_user_schema = await dbv.tenant._debug_introspect( + conn, cur_global_schema) + compiler_pool = dbv.server.get_compiler_pool() + await compiler_pool.validate_schema_equivalence( + user_schema, + new_user_schema, + cur_global_schema, + dbv._last_comp_state, + ) + user_schema = new_user_schema + updated_user_schema = True + if not in_tx: side_effects = dbv.commit_implicit_tx( user_schema, @@ -586,6 +607,9 @@ async def 
execute_script( state = dbv.serialize_state() if state is not orig_state: conn.last_state = state + elif updated_user_schema: + dbv._in_tx_user_schema_pickle = user_schema + if unit_group.state_serializer is not None: dbv.set_state_serializer(unit_group.state_serializer) @@ -600,7 +624,7 @@ async def execute_system_config( conn: pgcon.PGConnection, dbv: dbview.DatabaseConnectionView, query_unit: compiler.QueryUnit, - state: bytes, + state: bytes | None, ): if query_unit.is_system_config: dbv.server.before_alter_system_config() @@ -609,12 +633,7 @@ async def execute_system_config( await conn.sql_fetch(b'select 1', state=state) if query_unit.sql: - if len(query_unit.sql) > 1: - raise errors.InternalServerError( - "unexpected multiple SQL statements in CONFIGURE INSTANCE " - "compilation product" - ) - data = await conn.sql_fetch_col(query_unit.sql[0]) + data = await conn.sql_fetch_col(query_unit.sql) else: data = None diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index 7b4c6b1e50b..c6f098dbf46 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -1240,7 +1240,7 @@ cdef class PgConnection(frontend.FrontendConnection): dbv: ConnectionView, stmt_name: str, local_stmts: set[str], - actions: list[PGMessage], + actions ) -> PGMessage: """Make sure given *stmt_name* is known by Postgres diff --git a/edb/server/protocol/server_info.py b/edb/server/protocol/server_info.py index 65d65a7baff..921d3f9f344 100644 --- a/edb/server/protocol/server_info.py +++ b/edb/server/protocol/server_info.py @@ -44,7 +44,7 @@ def default(self, obj: Any) -> Any: return list(obj) if isinstance(obj, immutables.Map): return dict(obj.items()) - if dataclasses.is_dataclass(obj): + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): return dataclasses.asdict(obj) if isinstance(obj, statypes.Duration): return obj.to_iso8601() diff --git a/edb/server/rust_async_channel.py b/edb/server/rust_async_channel.py new file mode 100644 
index 00000000000..9de2644a863 --- /dev/null +++ b/edb/server/rust_async_channel.py @@ -0,0 +1,104 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import io +import logging + + +from typing import Protocol, Optional, Tuple, Any, Callable + +logger = logging.getLogger("edb.server") + +MAX_BATCH_SIZE = 16 + + +class RustPipeProtocol(Protocol): + def _read(self) -> Tuple[Any, ...]: ... + + def _try_read(self) -> Optional[Tuple[Any, ...]]: ... + + def _close_pipe(self) -> None: ... 
+ + _fd: int + + +class RustAsyncChannel: + _buffered_reader: io.BufferedReader + _skip_reads: int + _closed: asyncio.Event + + def __init__( + self, + pipe: RustPipeProtocol, + callback: Callable[[Tuple[Any, ...]], None], + ) -> None: + fd = pipe._fd + self._buffered_reader = io.BufferedReader( + io.FileIO(fd), buffer_size=MAX_BATCH_SIZE + ) + self._fd = fd + self._pipe = pipe + self._callback = callback + self._skip_reads = 0 + self._closed = asyncio.Event() + + def __del__(self): + if not self._closed.is_set(): + logger.error(f"RustAsyncChannel {id(self)} was not closed") + + async def run(self): + loop = asyncio.get_running_loop() + loop.add_reader(self._fd, self._channel_read) + try: + await self._closed.wait() + finally: + loop.remove_reader(self._fd) + + def close(self): + if not self._closed.is_set(): + self._pipe._close_pipe() + self._buffered_reader.close() + self._closed.set() + + def read_hint(self): + while msg := self._pipe._try_read(): + self._skip_reads += 1 + self._callback(msg) + + def _channel_read(self) -> None: + try: + n = len(self._buffered_reader.read1(MAX_BATCH_SIZE)) + if n == 0: + return + if self._skip_reads > n: + self._skip_reads -= n + return + n -= self._skip_reads + self._skip_reads = 0 + for _ in range(n): + msg = self._pipe._read() + if msg is None: + self.close() + return + self._callback(msg) + except Exception: + logger.error( + f"Error reading from Rust async channel", exc_info=True + ) + self.close() diff --git a/edb/server/server.py b/edb/server/server.py index 557fa39eaf0..73540b198ed 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -1274,6 +1274,7 @@ async def init(self) -> None: await self._load_instance_data() await self._maybe_patch() await self._tenant.init() + self._load_sidechannel_configs() await super().init() def get_default_tenant(self) -> edbtenant.Tenant: @@ -1282,6 +1283,20 @@ def get_default_tenant(self) -> edbtenant.Tenant: def iter_tenants(self) -> Iterator[edbtenant.Tenant]: yield 
self._tenant + def _load_sidechannel_configs(self) -> None: + # TODO(fantix): Do something like this for multitenant + magic_smtp = os.getenv('EDGEDB_MAGIC_SMTP_CONFIG') + if magic_smtp: + email_type = self._config_settings['email_providers'].type + assert not isinstance(email_type, type) + configs = [ + config.CompositeConfigType.from_json_value( + entry, tspec=email_type, spec=self._config_settings + ) + for entry in json.loads(magic_smtp) + ] + self._tenant.set_sidechannel_configs(configs) + async def _get_patch_log( self, conn: pgcon.PGConnection, idx: int ) -> Optional[bootstrap.PatchEntry]: @@ -1433,7 +1448,7 @@ async def _maybe_apply_patches( db_config = self._parse_db_config(config_json, user_schema) try: logger.info("repairing database '%s'", dbname) - sql += bootstrap.prepare_repair_patch( + rep_sql = bootstrap.prepare_repair_patch( self._std_schema, self._refl_schema, user_schema, @@ -1442,6 +1457,7 @@ async def _maybe_apply_patches( self._tenant.get_backend_runtime_params(), db_config, ) + sql += (rep_sql,) except errors.EdgeDBError as e: if isinstance(e, errors.InternalServerError): raise @@ -1454,7 +1470,7 @@ async def _maybe_apply_patches( ) from e if sql: - await conn.sql_fetch(sql) + await conn.sql_execute(sql) logger.info( "finished applying patch %d to database '%s'", num, dbname) diff --git a/edb/server/smtp.py b/edb/server/smtp.py new file mode 100644 index 00000000000..28bf59b86ca --- /dev/null +++ b/edb/server/smtp.py @@ -0,0 +1,231 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import dataclasses +import email.message +import asyncio +import logging +import os +import hashlib +import pickle +import aiosmtplib + +from typing import Optional + +from edb.common import retryloop +from edb.ir import statypes +from edb import errors +from . import dbview + + +_semaphore: asyncio.BoundedSemaphore | None = None + +logger = logging.getLogger('edb.server.smtp') + + +@dataclasses.dataclass +class SMTPProviderConfig: + name: str + sender: Optional[str] + host: Optional[str] + port: Optional[int] + username: Optional[str] + password: Optional[str] + security: str + validate_certs: bool + timeout_per_email: statypes.Duration + timeout_per_attempt: statypes.Duration + + +class SMTP: + def __init__(self, db: dbview.Database): + current_provider = _get_current_email_provider(db) + self.sender = current_provider.sender or "noreply@example.com" + default_port = ( + 465 + if current_provider.security == "TLS" + else 587 if current_provider.security == "STARTTLS" else 25 + ) + use_tls: bool + start_tls: bool | None + match current_provider.security: + case "PlainText": + use_tls = False + start_tls = False + + case "TLS": + use_tls = True + start_tls = False + + case "STARTTLS": + use_tls = False + start_tls = True + + case "STARTTLSOrPlainText": + use_tls = False + start_tls = None + + case _: + raise NotImplementedError + + host = current_provider.host or "localhost" + port = current_provider.port or default_port + username = current_provider.username + password = current_provider.password + validate_certs 
= current_provider.validate_certs + timeout_per_attempt = current_provider.timeout_per_attempt + + req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 + self.timeout_per_email = ( + current_provider.timeout_per_email.to_microseconds() / 1_000_000.0 + ) + self.client = aiosmtplib.SMTP( + hostname=host, + port=port, + username=username, + password=password, + timeout=req_timeout, + use_tls=use_tls, + start_tls=start_tls, + validate_certs=validate_certs, + ) + + async def send( + self, + message: email.message.Message, + *, + test_mode: bool = False, + ) -> None: + global _semaphore + if _semaphore is None: + _semaphore = asyncio.BoundedSemaphore( + int( + os.environ.get( + "EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", + os.environ.get("EDGEDB_SERVER_SMTP_CONCURRENCY", 5), + ) + ) + ) + + # n.b. When constructing EmailMessage objects, we don't set the "From" + # header since that is configured in the SmtpProviderConfig. However, + # the EmailMessage will have the correct "To" header. + message["From"] = self.sender + rloop = retryloop.RetryLoop( + timeout=self.timeout_per_email, + backoff=retryloop.exp_backoff(), + ignore=( + aiosmtplib.SMTPConnectError, + aiosmtplib.SMTPHeloError, + aiosmtplib.SMTPServerDisconnected, + aiosmtplib.SMTPConnectTimeoutError, + aiosmtplib.SMTPConnectResponseError, + ), + ) + async for iteration in rloop: + async with iteration: + async with _semaphore: + # Currently we are not reusing SMTP connections, but + # ideally we should replace this with a pool of + # connections, and drop idle connections after configured + # time. 
+ if test_mode: + self._send_test_mode_email(message) + else: + logger.info( + "Sending SMTP message to " + f"{self.client.hostname}:{self.client.port}" + ) + + async with self.client: + errors, response = await self.client.send_message( + message + ) + if errors: + logger.error( + f"SMTP server returned errors: {errors}" + ) + else: + logger.info( + f"SMTP message sent successfully: {response}" + ) + + def _send_test_mode_email(self, message: email.message.Message): + sender = message["From"] + recipients = message["To"] + recipients_list: list[str] + if isinstance(recipients, str): + recipients_list = [recipients] + elif recipients is None: + recipients_list = [] + else: + recipients_list = list(recipients) + + hash_input = f"{sender}{','.join(recipients_list)}" + file_name_hash = hashlib.sha256(hash_input.encode()).hexdigest() + file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" + test_file = os.environ.get( + "EDGEDB_TEST_EMAIL_FILE", + file_name, + ) + if os.path.exists(test_file): + os.unlink(test_file) + with open(test_file, "wb") as f: + logger.info(f"Dumping SMTP message to {test_file}") + args = dict( + message=message, + sender=sender, + recipients=recipients, + hostname=self.client.hostname, + port=self.client.port, + username=self.client._login_username, + password=self.client._login_password, + timeout=self.client.timeout, + use_tls=self.client.use_tls, + start_tls=self.client._start_tls_on_connect, + validate_certs=self.client.validate_certs, + ) + pickle.dump(args, f) + + +def _get_current_email_provider( + db: dbview.Database, +) -> SMTPProviderConfig: + current_provider_name = db.lookup_config("current_email_provider_name") + if current_provider_name is None: + raise errors.ConfigurationError("No email provider configured") + + found = None + objs = ( + list(db.lookup_config("email_providers")) + + db.tenant._sidechannel_email_configs + ) + for obj in objs: + if obj.name == current_provider_name: + as_json = obj.to_json_value() + 
as_json.pop('_tname', None) + found = SMTPProviderConfig(**as_json) + break + + if found is None: + raise errors.ConfigurationError( + f"No email provider named {current_provider_name!r}" + ) + return found diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 7b7b0a9a9a1..64a346b8359 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -138,6 +138,8 @@ class Tenant(ha_base.ClusterProtocol): _http_client: HttpClient | None + _sidechannel_email_configs: list[Any] + def __init__( self, cluster: pgcluster.BaseCluster, @@ -161,6 +163,7 @@ def __init__( self._accept_new_tasks = False self._file_watch_finalizers = [] self._introspection_locks = weakref.WeakValueDictionary() + self._sidechannel_email_configs = [] self._extensions_dirs = extensions_dir @@ -246,6 +249,9 @@ def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() + def set_sidechannel_configs(self, configs: list[Any]) -> None: + self._sidechannel_email_configs = configs + def get_http_client(self, *, originator: str) -> HttpClient: if self._http_client is None: http_max_connections = self._server.config_lookup( @@ -1101,6 +1107,21 @@ async def _introspect_extensions( return extensions + async def _debug_introspect( + self, + conn: pgcon.PGConnection, + global_schema_pickle, + ) -> Any: + user_schema_json = ( + await self._server.introspect_user_schema_json(conn) + ) + db_config_json = await self._server.introspect_db_config(conn) + + compiler_pool = self._server.get_compiler_pool() + return (await compiler_pool.parse_user_schema_db_config( + user_schema_json, db_config_json, global_schema_pickle, + )).user_schema_pickle + async def introspect_db( self, dbname: str, @@ -1194,7 +1215,7 @@ async def _introspect_db( SELECT json_object_agg( "id"::text, - "backend_id" + json_build_array("backend_id", "name") )::text FROM edgedb_VER."_SchemaType" diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 
73fca8ec430..3da73069a89 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -57,6 +57,9 @@ def raise_first_warning(warnings, res): raise warnings[0] +InputLanguage = protocol.InputLanguage + + class BaseTransaction(abc.ABC): ID_COUNTER = 0 @@ -419,6 +422,7 @@ async def _fetchall( self, query: str, *args, + __language__: protocol.InputLanguage = protocol.InputLanguage.EDGEQL, __limit__: int = 0, __typeids__: bool = False, __typenames__: bool = False, @@ -436,6 +440,7 @@ async def _fetchall( implicit_limit=__limit__, inline_typeids=__typeids__, inline_typenames=__typenames__, + input_language=__language__, output_format=protocol.OutputFormat.BINARY, allow_capabilities=__allow_capabilities__, ) @@ -457,6 +462,7 @@ async def _fetchall_json( qc=self._query_cache.query_cache, implicit_limit=__limit__, inline_typenames=False, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, ) ) @@ -469,6 +475,7 @@ async def _fetchall_json_elements(self, query: str, *args, **kwargs): kwargs=kwargs, reg=self._query_cache.codecs_registry, qc=self._query_cache.query_cache, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON_ELEMENTS, allow_capabilities=edgedb_enums.Capability.EXECUTE, # type: ignore ) @@ -499,6 +506,7 @@ def is_closed(self): async def connect(self, single_attempt=False): self._params, client_config = con_utils.parse_connect_arguments( **self._connect_args, + tls_server_name=None, command_timeout=None, server_settings=None, ) diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 2054a63d536..a234a6662fb 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -27,6 +27,7 @@ Type, Union, Iterable, + Literal, Sequence, Dict, List, @@ -572,6 +573,7 @@ def http_con_binary_request( compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=query, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, 
expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, @@ -1111,40 +1113,51 @@ def assert_data_shape(self, data, shape, message=message, rel_tol=rel_tol, abs_tol=abs_tol, ) - async def assert_query_result(self, query, - exp_result_json, - exp_result_binary=..., - *, - always_typenames=False, - msg=None, sort=None, implicit_limit=0, - variables=None, json_only=False, - rel_tol=None, abs_tol=None): + async def assert_query_result( + self, + query, + exp_result_json, + exp_result_binary=..., + *, + always_typenames=False, + msg=None, + sort=None, + implicit_limit=0, + variables=None, + json_only=False, + binary_only=False, + rel_tol=None, + abs_tol=None, + language: Literal["sql", "edgeql"] = "edgeql", + ): fetch_args = variables if isinstance(variables, tuple) else () fetch_kw = variables if isinstance(variables, dict) else {} - try: - tx = self.con.transaction() - await tx.start() - try: - res = await self.con._fetchall_json( - query, - *fetch_args, - __limit__=implicit_limit, - **fetch_kw) - finally: - await tx.rollback() - res = json.loads(res) - if sort is not None: - assert_data_shape.sort_results(res, sort) - assert_data_shape.assert_data_shape( - res, exp_result_json, self.fail, - message=msg, rel_tol=rel_tol, abs_tol=abs_tol, - ) - except Exception: - self.add_fail_notes(serialization='json') - if msg: - self.add_fail_notes(msg=msg) - raise + if not binary_only and language != "sql": + try: + tx = self.con.transaction() + await tx.start() + try: + res = await self.con._fetchall_json( + query, + *fetch_args, + __limit__=implicit_limit, + **fetch_kw) + finally: + await tx.rollback() + + res = json.loads(res) + if sort is not None: + assert_data_shape.sort_results(res, sort) + assert_data_shape.assert_data_shape( + res, exp_result_json, self.fail, + message=msg, rel_tol=rel_tol, abs_tol=abs_tol, + ) + except Exception: + self.add_fail_notes(serialization='json') + if msg: + self.add_fail_notes(msg=msg) + raise if json_only: return 
@@ -1163,14 +1176,22 @@ async def assert_query_result(self, query, __typenames__=typenames, __typeids__=typeids, __limit__=implicit_limit, + __language__=( + tconn.InputLanguage.SQL if language == "sql" + else tconn.InputLanguage.EDGEQL + ), **fetch_kw ) res = serutils.serialize(res) if sort is not None: assert_data_shape.sort_results(res, sort) assert_data_shape.assert_data_shape( - res, exp_result_binary, self.fail, - message=msg, rel_tol=rel_tol, abs_tol=abs_tol, + res, + exp_result_binary, + self.fail, + message=msg, + rel_tol=rel_tol, + abs_tol=abs_tol, ) except Exception: self.add_fail_notes( @@ -1181,6 +1202,28 @@ async def assert_query_result(self, query, self.add_fail_notes(msg=msg) raise + async def assert_sql_query_result( + self, + query, + exp_result, + *, + msg=None, + sort=None, + variables=None, + rel_tol=None, + abs_tol=None, + ): + await self.assert_query_result( + query, + exp_result, + msg=msg, + sort=sort, + variables=variables, + rel_tol=rel_tol, + abs_tol=abs_tol, + language="sql", + ) + async def assert_index_use(self, query, *args, index_type): def look(obj): if ( diff --git a/edb/tools/edb.py b/edb/tools/edb.py index 0f399f258ba..711a24adc91 100644 --- a/edb/tools/edb.py +++ b/edb/tools/edb.py @@ -87,5 +87,6 @@ def load_ext(args: tuple[str, ...]): from . import ast_inheritance_graph # noqa from . import parser_demo # noqa from . import ls_forbidden_functions # noqa +from . 
import redo_metaschema # noqa from .profiling import cli as prof_cli # noqa from .experimental_interpreter import edb_entry # noqa diff --git a/edb/tools/gen_types.py b/edb/tools/gen_types.py index f9bb0ae0058..b16713e3966 100644 --- a/edb/tools/gen_types.py +++ b/edb/tools/gen_types.py @@ -63,7 +63,7 @@ def main(*, stdout: bool): f'\n\n\n' f'from __future__ import annotations' f'\n' - f'from typing import * # NoQA' + f'from typing import Type' f'\n\n\n' f'import uuid' f'\n\n' diff --git a/edb/tools/mypy/plugin.py b/edb/tools/mypy/plugin.py index bcbd30962c8..49338d1e4df 100644 --- a/edb/tools/mypy/plugin.py +++ b/edb/tools/mypy/plugin.py @@ -26,8 +26,10 @@ import mypy.plugin as mypy_plugin from mypy import mro from mypy import nodes +from mypy import options as mypy_options from mypy import types -from mypy import semanal +from mypy import typevars as mypy_typevars +from mypy import semanal_shared as mypy_semanal from mypy.plugins import common as mypy_helpers from mypy.server import trigger as mypy_trigger @@ -71,12 +73,14 @@ def handle_schema_class(self, ctx: mypy_plugin.ClassDefContext): transformers.append( SchemaClassTransformer( ctx, + self.options, field_makers={'edb.schema.objects.SchemaField'}, ) ) transformers.append( StructTransformer( ctx, + self.options, field_makers={'edb.schema.objects.Field'}, ) ) @@ -85,6 +89,7 @@ def handle_schema_class(self, ctx: mypy_plugin.ClassDefContext): transformers.append( StructTransformer( ctx, + self.options, field_makers={'edb.common.struct.Field'}, ) ) @@ -93,6 +98,7 @@ def handle_schema_class(self, ctx: mypy_plugin.ClassDefContext): transformers.append( ASTClassTransformer( ctx, + self.options, ) ) @@ -206,8 +212,10 @@ class BaseTransformer: def __init__( self, ctx: mypy_plugin.ClassDefContext, + options: mypy_options.Options, ) -> None: self._ctx = ctx + self._options = options def transform(self): ctx = self._ctx @@ -344,8 +352,8 @@ def _synthesize_init(self, fields: List[Field]) -> None: # var bounds), defer. 
If we skip deferring and stick something # in our symbol table anyway, we'll get in trouble. (Arguably # plugins.common ought to help us with this, but oh well.) - self_type = mypy_helpers.fill_typevars(cls_info) - if semanal.has_placeholder(self_type): + self_type = mypy_typevars.fill_typevars(cls_info) + if mypy_semanal.has_placeholder(self_type): raise DeferException if ( @@ -368,9 +376,10 @@ class BaseStructTransformer(BaseTransformer): def __init__( self, ctx: mypy_plugin.ClassDefContext, + options: mypy_options.Options, field_makers: AbstractSet[str], ) -> None: - super().__init__(ctx) + super().__init__(ctx, options) self._field_makers = field_makers def _field_from_field_def( @@ -417,7 +426,10 @@ def _field_from_field_def( if ftype is None: try: - un_type = exprtotype.expr_to_unanalyzed_type(type_arg) + un_type = exprtotype.expr_to_unanalyzed_type( + type_arg, + options=self._options, + ) except exprtotype.TypeTranslationError: ctx.api.fail('Cannot resolve schema field type', type_arg) else: diff --git a/edb/tools/redo_metaschema.py b/edb/tools/redo_metaschema.py new file mode 100644 index 00000000000..374e71b96b0 --- /dev/null +++ b/edb/tools/redo_metaschema.py @@ -0,0 +1,52 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2021-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from edb.tools.edb import edbcommands + + +@edbcommands.command("redo-metaschema-sql") +def run(): + """ + Generates DDL to recreate metaschema for sql introspection. + Can be used to apply changes to metaschema to an existing database. + + edb redo-metaschema-sql | ./build/postgres/install/bin/psql \ + "postgresql://postgres@/E_main?host=$(pwd)/tmp/devdatadir&port=5432" \ + -v ON_ERROR_STOP=ON + """ + + from edb.common import devmode + devmode.enable_dev_mode() + + from edb.pgsql import dbops, metaschema + from edb import buildmeta + + version = buildmeta.get_pg_version() + commands = metaschema._generate_sql_information_schema(version) + + for command in commands: + block = dbops.PLTopBlock() + + if isinstance(command, dbops.CreateFunction): + command.or_replace = True + if isinstance(command, dbops.CreateView): + command.or_replace = True + + command.generate(block) + + print(block.to_string()) diff --git a/edb/tools/test/runner.py b/edb/tools/test/runner.py index 9fb7127982e..2bfdfb39854 100644 --- a/edb/tools/test/runner.py +++ b/edb/tools/test/runner.py @@ -285,6 +285,17 @@ def monitor_thread(queue, result): method(*args, **kwargs) +def status_thread_func( + result: ParallelTextTestResult, + stop_event: threading.Event, +) -> None: + while True: + result.report_still_running() + time.sleep(1) + if stop_event.is_set(): + break + + class ParallelTestSuite(unittest.TestSuite): def __init__( self, tests, server_conn, num_workers, backend_dsn, init_worker @@ -310,10 +321,22 @@ def run(self, result): worker_param_queue.put((self.server_conn, self.backend_dsn)) result_thread = threading.Thread( - name='test-monitor', target=monitor_thread, - args=(result_queue, result), daemon=True) + name='test-monitor', + target=monitor_thread, + args=(result_queue, result), + daemon=True, + ) result_thread.start() + status_thread_stop_event = threading.Event() + status_thread = threading.Thread( + name='test-status', + target=status_thread_func, + args=(result, 
status_thread_stop_event), + daemon=True, + ) + status_thread.start() + initargs = ( status_queue, worker_param_queue, result_queue, self.init_worker ) @@ -357,12 +380,13 @@ def run(self, result): # Post the terminal message to the queue so that # test-monitor can stop. result_queue.put((None, None, None)) + status_thread_stop_event.set() - # Give the test-monitor thread some time to - # process the queue messages. If something - # goes wrong, the thread will be forcibly + # Give the test-monitor and test-status threads some time to process the + # queue messages. If something goes wrong, the thread will be forcibly # joined by a timeout. result_thread.join(timeout=3) + status_thread.join(timeout=3) return result @@ -450,6 +474,9 @@ def report(self, test, marker, description=None, *, currently_running): def report_start(self, test, *, currently_running): return + def report_still_running(self, still_running: dict[str, float]): + return + class SimpleRenderer(BaseRenderer): def report(self, test, marker, description=None, *, currently_running): @@ -480,6 +507,10 @@ def report(self, test, marker, description=None, *, currently_running): click.echo(style(self._render_test(test, marker, description)), file=self.stream) + def report_still_running(self, still_running: dict[str, float]) -> None: + items = [f"{t} for {d:.02f}s" for t, d in still_running.items()] + click.echo(f"still running:\n {'\n '.join(items)}") + class MultiLineRenderer(BaseRenderer): @@ -521,6 +552,10 @@ def report(self, test, marker, description=None, *, currently_running): def report_start(self, test, *, currently_running): self._render(currently_running) + def report_still_running(self, still_running: dict[str, float]): + # Still-running tests are already reported in normal repert + return + def _render_modname(self, name): return name.replace('.', '/') + '.py' @@ -590,10 +625,13 @@ def _render_test_list(label, max_lines, tests, style): # Prevent the rendered output from "jumping" up/down when we # 
render 2 lines worth of running tests just after we rendered # 3 lines. - for _ in range(self.max_label_lines_rendered[label] - tests_lines): + lkey = label.split(':')[0] + # ^- We can't just use `label`, as we append extra information + # to the "Running: (..)" label, so strip that + for _ in range(self.max_label_lines_rendered[lkey] - tests_lines): lines.append(' ' * cols) - self.max_label_lines_rendered[label] = max( - self.max_label_lines_rendered[label], + self.max_label_lines_rendered[lkey] = max( + self.max_label_lines_rendered[lkey], tests_lines ) @@ -724,6 +762,16 @@ def report_progress(self, test, marker, description=None): currently_running=list(self.currently_running), ) + def report_still_running(self): + now = time.monotonic() + still_running = {} + for test, start in self.currently_running.items(): + running_for = now - start + if running_for > 5.0: + still_running[test] = running_for + if still_running: + self.ren.report_still_running(still_running) + def record_test_stats(self, test, stats): self.test_stats.append((test, stats)) @@ -742,7 +790,7 @@ def getDescription(self, test): def startTest(self, test): super().startTest(test) - self.currently_running[test] = True + self.currently_running[test] = time.monotonic() self.ren.report_start( test, currently_running=list(self.currently_running)) diff --git a/edb_stat_statements/.gitignore b/edb_stat_statements/.gitignore new file mode 100644 index 00000000000..0df4d7f493e --- /dev/null +++ b/edb_stat_statements/.gitignore @@ -0,0 +1,9 @@ +# Generated subdirectories +/log/ +/results/ +/tmp_check/ +/expected/dml.out +/expected/level_tracking.out +/expected/parallel.out +/expected/utility.out +/expected/wal.out diff --git a/edb_stat_statements/Makefile b/edb_stat_statements/Makefile new file mode 100644 index 00000000000..f53c59c36f3 --- /dev/null +++ b/edb_stat_statements/Makefile @@ -0,0 +1,69 @@ +MODULE_big = edb_stat_statements +OBJS = \ + $(WIN32RES) \ + edb_stat_statements.o + +EXTENSION = 
edb_stat_statements +DATA = edb_stat_statements--1.0.sql +PGFILEDESC = "edb_stat_statements - execution statistics of EdgeDB queries" + +LDFLAGS_SL += $(filter -lm, $(LIBS)) + +REGRESS = select dml cursors utility level_tracking planning \ + user_activity wal entry_timestamp privileges \ + parallel cleanup oldextversions + +TAP_TESTS = 1 +PG_MAJOR = $(shell $(PG_CONFIG) --version | grep -oE '[0-9]+' | head -1) + +ifeq ($(shell test $(PG_MAJOR) -ge 18 && echo true), true) + REGRESS += extended +endif + +all: + +expected/dml.out: + if [ $(PG_MAJOR) -ge 18 ]; then \ + cp expected/dml.out.18 expected/dml.out; \ + else \ + cp expected/dml.out.17 expected/dml.out; \ + fi + +expected/level_tracking.out: + if [ $(PG_MAJOR) -ge 18 ]; then \ + cp expected/level_tracking.out.18 expected/level_tracking.out; \ + else \ + cp expected/level_tracking.out.17 expected/level_tracking.out; \ + fi + +expected/parallel.out: + if [ $(PG_MAJOR) -ge 18 ]; then \ + cp expected/parallel.out.18 expected/parallel.out; \ + else \ + cp expected/parallel.out.17 expected/parallel.out; \ + fi + +expected/utility.out: + if [ $(PG_MAJOR) -ge 17 ]; then \ + cp expected/utility.out.17 expected/utility.out; \ + else \ + cp expected/utility.out.16 expected/utility.out; \ + fi + +expected/wal.out: + if [ $(PG_MAJOR) -ge 18 ]; then \ + cp expected/wal.out.18 expected/wal.out; \ + else \ + cp expected/wal.out.17 expected/wal.out; \ + fi + +installcheck: \ + expected/dml.out \ + expected/level_tracking.out \ + expected/parallel.out \ + expected/utility.out \ + expected/wal.out + +PG_CONFIG = pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) diff --git a/edb_stat_statements/edb_stat_statements--1.0.sql b/edb_stat_statements/edb_stat_statements--1.0.sql new file mode 100644 index 00000000000..dd08d8045f1 --- /dev/null +++ b/edb_stat_statements/edb_stat_statements--1.0.sql @@ -0,0 +1,99 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "CREATE EXTENSION 
edb_stat_statements" to load this file. \quit + +-- Register functions. +CREATE FUNCTION edb_stat_statements_reset(IN userid Oid DEFAULT 0, + IN dbids Oid[] DEFAULT '{}', + IN queryid bigint DEFAULT 0, + IN minmax_only boolean DEFAULT false +) +RETURNS timestamp with time zone +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT PARALLEL SAFE; + +CREATE FUNCTION edb_stat_queryid(IN id uuid) +RETURNS bigint +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT PARALLEL SAFE; + +CREATE FUNCTION edb_stat_statements(IN showtext boolean, + OUT userid oid, + OUT dbid oid, + OUT toplevel bool, + OUT queryid bigint, + OUT query text, + OUT extras jsonb, + OUT id uuid, + OUT stmt_type int2, + OUT plans int8, + OUT total_plan_time float8, + OUT min_plan_time float8, + OUT max_plan_time float8, + OUT mean_plan_time float8, + OUT stddev_plan_time float8, + OUT calls int8, + OUT total_exec_time float8, + OUT min_exec_time float8, + OUT max_exec_time float8, + OUT mean_exec_time float8, + OUT stddev_exec_time float8, + OUT rows int8, + OUT shared_blks_hit int8, + OUT shared_blks_read int8, + OUT shared_blks_dirtied int8, + OUT shared_blks_written int8, + OUT local_blks_hit int8, + OUT local_blks_read int8, + OUT local_blks_dirtied int8, + OUT local_blks_written int8, + OUT temp_blks_read int8, + OUT temp_blks_written int8, + OUT shared_blk_read_time float8, + OUT shared_blk_write_time float8, + OUT local_blk_read_time float8, + OUT local_blk_write_time float8, + OUT temp_blk_read_time float8, + OUT temp_blk_write_time float8, + OUT wal_records int8, + OUT wal_fpi int8, + OUT wal_bytes numeric, + OUT jit_functions int8, + OUT jit_generation_time float8, + OUT jit_inlining_count int8, + OUT jit_inlining_time float8, + OUT jit_optimization_count int8, + OUT jit_optimization_time float8, + OUT jit_emission_count int8, + OUT jit_emission_time float8, + OUT jit_deform_count int8, + OUT jit_deform_time float8, + OUT parallel_workers_to_launch int8, + OUT parallel_workers_launched int8, + OUT stats_since 
timestamp with time zone, + OUT minmax_stats_since timestamp with time zone +) +RETURNS SETOF record +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT VOLATILE PARALLEL SAFE; + +CREATE FUNCTION edb_stat_statements_info( + OUT dealloc bigint, + OUT stats_reset timestamp with time zone +) +RETURNS record +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT VOLATILE PARALLEL SAFE; + +-- Register views on the functions for ease of use. +CREATE VIEW edb_stat_statements AS + SELECT * FROM edb_stat_statements(true); + +GRANT SELECT ON edb_stat_statements TO PUBLIC; + +CREATE VIEW edb_stat_statements_info AS + SELECT * FROM edb_stat_statements_info(); + +GRANT SELECT ON edb_stat_statements_info TO PUBLIC; + +-- Don't want this to be available to non-superusers. +REVOKE ALL ON FUNCTION edb_stat_statements_reset(Oid, Oid[], bigint, boolean) FROM PUBLIC; diff --git a/edb_stat_statements/edb_stat_statements.c b/edb_stat_statements/edb_stat_statements.c new file mode 100644 index 00000000000..6b8b322d33d --- /dev/null +++ b/edb_stat_statements/edb_stat_statements.c @@ -0,0 +1,3297 @@ +/*------------------------------------------------------------------------- + * + * edb_stat_statements.c + * Track statement planning and execution times as well as resource + * usage across a whole database cluster. + * + * Execution costs are totaled for each distinct source query, and kept in + * a shared hashtable. (We track only as many distinct queries as will fit + * in the designated amount of shared memory.) + * + * Starting in Postgres 9.2, this module normalized query entries. As of + * Postgres 14, the normalization is done by the core if compute_query_id is + * enabled, or optionally by third-party modules. + * + * To facilitate presenting entries to users, we create "representative" query + * strings in which constants are replaced with parameter symbols ($n), to + * make it clearer what a normalized entry can represent. 
To save on shared + * memory, and to avoid having to truncate oversized query strings, we store + * these strings in a temporary external query-texts file. Offsets into this + * file are kept in shared memory. + * + * Note about locking issues: to create or delete an entry in the shared + * hashtable, one must hold pgss->lock exclusively. Modifying any field + * in an entry except the counters requires the same. To look up an entry, + * one must hold the lock shared. To read or update the counters within + * an entry, one must hold the lock shared or exclusive (so the entry doesn't + * disappear!) and also take the entry's mutex spinlock. + * The shared state variable pgss->extent (the next free spot in the external + * query-text file) should be accessed only while holding either the + * pgss->mutex spinlock, or exclusive lock on pgss->lock. We use the mutex to + * allow reserving file space while holding only shared lock on pgss->lock. + * Rewriting the entire external query-text file, eg for garbage collection, + * requires holding pgss->lock exclusively; this allows individual entries + * in the file to be read or written while holding only shared lock. + * + * + * Copyright (c) 2008-2024, PostgreSQL Global Development Group + * Copyright 2024-present MagicStack Inc. and the EdgeDB authors. 
+ * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include +#include +#include + +#include "access/parallel.h" +#include "catalog/pg_authid.h" +#include "common/hashfn.h" +#include "common/int.h" +#include "common/jsonapi.h" +#include "executor/instrument.h" +#include "funcapi.h" +#include "jit/jit.h" +#include "mb/pg_wchar.h" +#include "miscadmin.h" +#include "nodes/queryjumble.h" +#include "optimizer/planner.h" +#include "parser/analyze.h" +#include "parser/parsetree.h" +#include "parser/scanner.h" +#include "parser/scansup.h" +#include "pgstat.h" +#include "storage/fd.h" +#include "storage/ipc.h" +#include "storage/lwlock.h" +#include "storage/shmem.h" +#include "storage/spin.h" +#include "tcop/utility.h" +#include "utils/acl.h" +#include "utils/builtins.h" +#include "utils/jsonb.h" +#include "utils/memutils.h" +#include "utils/timestamp.h" +#include "utils/uuid.h" + +PG_MODULE_MAGIC; + +#define EDB_STMT_MAGIC_PREFIX "-- {" + +/* Location of permanent stats file (valid when database is shut down) */ +#define PGSS_DUMP_FILE PGSTAT_STAT_PERMANENT_DIRECTORY "/edb_stat_statements.stat" + +/* + * Location of external query text file. + */ +#define PGSS_TEXT_FILE PG_STAT_TMP_DIR "/edbss_query_texts.stat" + +/* Magic number identifying the stats file format */ +static const uint32 PGSS_FILE_HEADER = 0x20241021; + +/* PostgreSQL major version number, changes in which invalidate all entries */ +static const uint32 PGSS_PG_MAJOR_VERSION = PG_VERSION_NUM / 100; + +/* XXX: Should USAGE_EXEC reflect execution time and/or buffer usage? 
*/
+#define USAGE_EXEC(duration) (1.0)
+#define USAGE_INIT (1.0) /* including initial planning */
+#define ASSUMED_MEDIAN_INIT (10.0) /* initial assumed median usage */
+#define ASSUMED_LENGTH_INIT 1024 /* initial assumed mean query length */
+#define USAGE_DECREASE_FACTOR (0.99) /* decreased every entry_dealloc */
+#define STICKY_DECREASE_FACTOR (0.50) /* factor for sticky entries */
+#define USAGE_DEALLOC_PERCENT 5 /* free this % of entries at once */
+#define IS_STICKY(c) ((c.calls[PGSS_PLAN] + c.calls[PGSS_EXEC]) == 0)
+
+/*
+ * Extension version number, for supporting older extension versions' objects
+ */
+typedef enum pgssVersion
+{
+	PGSS_V1_0 = 0,
+} pgssVersion;
+
+typedef enum pgssStoreKind
+{
+	PGSS_INVALID = -1,
+
+	/*
+	 * PGSS_PLAN and PGSS_EXEC must be respectively 0 and 1 as they're used to
+	 * reference the underlying values in the arrays in the Counters struct,
+	 * and this order is required in edb_stat_statements_internal().
+	 */
+	PGSS_PLAN = 0,
+	PGSS_EXEC,
+} pgssStoreKind;
+
+#define PGSS_NUMKIND (PGSS_EXEC + 1)
+
+typedef enum EdbStmtType {
+	EDB_EDGEQL = 1,
+	EDB_SQL = 2,
+} EdbStmtType;
+
+/*
+ * Internal states parsing the info JSON.
+ */
+typedef enum EdbStmtInfoParseState {
+	EDB_STMT_INFO_PARSE_NOOP = 0,
+	EDB_STMT_INFO_PARSE_QUERY = 1 << 0,
+	EDB_STMT_INFO_PARSE_ID = 1 << 1,
+	EDB_STMT_INFO_PARSE_TYPE = 1 << 2,
+	EDB_STMT_INFO_PARSE_EXTRAS = 1 << 3,
+} EdbStmtInfoParseState;
+
+/*
+ * The info JSON parsing is only considered a success
+ * if all the fields listed below are found.
+ *
+ * NOTE: the flags are distinct bits, so the required mask must be their
+ * bitwise OR.  AND-ing them together would yield 0 and make the
+ * "(found & REQUIRED) == REQUIRED" test in edbss_extract_stmt_info()
+ * pass vacuously, accepting info JSONs that are missing required fields.
+ */
+#define EDB_STMT_INFO_PARSE_REQUIRED \
+	(EDB_STMT_INFO_PARSE_QUERY \
+	 | EDB_STMT_INFO_PARSE_ID \
+	 | EDB_STMT_INFO_PARSE_TYPE \
+	)
+
+/*
+ * The result of parsing the info JSON by edbss_extract_stmt_info().
+ */ +typedef struct EdbStmtInfo { + union { + pg_uuid_t uuid; + uint64 query_id; + } id; + const char *query; + int query_len; + EdbStmtType stmt_type; + Jsonb *extras; +} EdbStmtInfo; + +/* + * The custom "semantic state" structure for info JSON parsing. + * This is used internally as the `semstate` pointer of the parser, + * keeping track of: + * - level of nested JSON objects + * - known object keys we've found + * - current key/state we're parsing + * - pointer to the parse result struct + */ +typedef struct EdbStmtInfoSemState { + int nested_level; + uint found; + EdbStmtInfoParseState state; + EdbStmtInfo *info; +} EdbStmtInfoSemState; + +/* + * Hashtable key that defines the identity of a hashtable entry. We separate + * queries by user and by database even if they are otherwise identical. + * + * If you add a new key to this struct, make sure to teach pgss_store() to + * zero the padding bytes. Otherwise, things will break, because pgss_hash is + * created using HASH_BLOBS, and thus tag_hash is used to hash this. + + */ +typedef struct pgssHashKey +{ + Oid userid; /* user OID */ + Oid dbid; /* database OID */ + uint64 queryid; /* query identifier */ + bool toplevel; /* query executed at top level */ +} pgssHashKey; + +/* + * The actual stats counters kept within pgssEntry. 
+ */ +typedef struct Counters +{ + int64 calls[PGSS_NUMKIND]; /* # of times planned/executed */ + double total_time[PGSS_NUMKIND]; /* total planning/execution time, + * in msec */ + double min_time[PGSS_NUMKIND]; /* minimum planning/execution time in + * msec since min/max reset */ + double max_time[PGSS_NUMKIND]; /* maximum planning/execution time in + * msec since min/max reset */ + double mean_time[PGSS_NUMKIND]; /* mean planning/execution time in + * msec */ + double sum_var_time[PGSS_NUMKIND]; /* sum of variances in + * planning/execution time in msec */ + int64 rows; /* total # of retrieved or affected rows */ + int64 shared_blks_hit; /* # of shared buffer hits */ + int64 shared_blks_read; /* # of shared disk blocks read */ + int64 shared_blks_dirtied; /* # of shared disk blocks dirtied */ + int64 shared_blks_written; /* # of shared disk blocks written */ + int64 local_blks_hit; /* # of local buffer hits */ + int64 local_blks_read; /* # of local disk blocks read */ + int64 local_blks_dirtied; /* # of local disk blocks dirtied */ + int64 local_blks_written; /* # of local disk blocks written */ + int64 temp_blks_read; /* # of temp blocks read */ + int64 temp_blks_written; /* # of temp blocks written */ + double shared_blk_read_time; /* time spent reading shared blocks, + * in msec */ + double shared_blk_write_time; /* time spent writing shared blocks, + * in msec */ + double local_blk_read_time; /* time spent reading local blocks, in + * msec */ + double local_blk_write_time; /* time spent writing local blocks, in + * msec */ + double temp_blk_read_time; /* time spent reading temp blocks, in msec */ + double temp_blk_write_time; /* time spent writing temp blocks, in + * msec */ + double usage; /* usage factor */ + int64 wal_records; /* # of WAL records generated */ + int64 wal_fpi; /* # of WAL full page images generated */ + uint64 wal_bytes; /* total amount of WAL generated in bytes */ + int64 jit_functions; /* total number of JIT functions emitted */ + double 
jit_generation_time; /* total time to generate jit code */ + int64 jit_inlining_count; /* number of times inlining time has been + * > 0 */ + double jit_deform_time; /* total time to deform tuples in jit code */ + int64 jit_deform_count; /* number of times deform time has been > + * 0 */ + + double jit_inlining_time; /* total time to inline jit code */ + int64 jit_optimization_count; /* number of times optimization time + * has been > 0 */ + double jit_optimization_time; /* total time to optimize jit code */ + int64 jit_emission_count; /* number of times emission time has been + * > 0 */ + double jit_emission_time; /* total time to emit jit code */ + int64 parallel_workers_to_launch; /* # of parallel workers planned + * to be launched */ + int64 parallel_workers_launched; /* # of parallel workers actually + * launched */ +} Counters; + +/* + * Global statistics for edb_stat_statements + */ +typedef struct pgssGlobalStats +{ + int64 dealloc; /* # of times entries were deallocated */ + TimestampTz stats_reset; /* timestamp with all stats reset */ +} pgssGlobalStats; + +/* + * Statistics per statement + * + * Note: in event of a failure in garbage collection of the query text file, + * we reset query_offset to zero and query_len to -1. This will be seen as + * an invalid state by qtext_fetch(). 
+ */ +typedef struct pgssEntry +{ + pgssHashKey key; /* hash key of entry - MUST BE FIRST */ + Counters counters; /* the statistics for this query */ + Size query_offset; /* query text offset in external file */ + int query_len; /* # of valid bytes in query string, or -1 */ + int encoding; /* query text encoding */ + TimestampTz stats_since; /* timestamp of entry allocation */ + TimestampTz minmax_stats_since; /* timestamp of last min/max values reset */ + slock_t mutex; /* protects the counters only */ + + pg_uuid_t id; /* Full 16-bytes query ID as UUID */ + EdbStmtType stmt_type; /* Type of the EdgeDB query */ + int extras_len; /* # of valid bytes in extras jsonb, or -1 */ +} pgssEntry; + +/* + * Global shared state + */ +typedef struct pgssSharedState +{ + LWLock *lock; /* protects hashtable search/modification */ + double cur_median_usage; /* current median usage in hashtable */ + Size mean_query_len; /* current mean entry text length */ + slock_t mutex; /* protects following fields only: */ + Size extent; /* current extent of query file */ + int n_writers; /* number of active writers to query file */ + int gc_count; /* query file garbage collection cycle count */ + pgssGlobalStats stats; /* global statistics for pgss */ +} pgssSharedState; + +/*---- Local variables ----*/ + +static pg_uuid_t zero_uuid = { 0 }; + +/* Current nesting depth of planner/ExecutorRun/ProcessUtility calls */ +static int nesting_level = 0; + +/* Saved hook values in case of unload */ +static shmem_request_hook_type prev_shmem_request_hook = NULL; +static shmem_startup_hook_type prev_shmem_startup_hook = NULL; +static post_parse_analyze_hook_type prev_post_parse_analyze_hook = NULL; +static planner_hook_type prev_planner_hook = NULL; +static ExecutorStart_hook_type prev_ExecutorStart = NULL; +static ExecutorRun_hook_type prev_ExecutorRun = NULL; +static ExecutorFinish_hook_type prev_ExecutorFinish = NULL; +static ExecutorEnd_hook_type prev_ExecutorEnd = NULL; +static 
ProcessUtility_hook_type prev_ProcessUtility = NULL; + +/* Links to shared memory state */ +static pgssSharedState *pgss = NULL; +static HTAB *pgss_hash = NULL; + +/*---- GUC variables ----*/ + +typedef enum +{ + PGSS_TRACK_NONE, /* track no statements */ + PGSS_TRACK_TOP, /* only top level statements */ + PGSS_TRACK_ALL, /* all statements, including nested ones */ +} PGSSTrackLevel; + +static const struct config_enum_entry track_options[] = +{ + {"none", PGSS_TRACK_NONE, false}, + {"top", PGSS_TRACK_TOP, false}, + {"all", PGSS_TRACK_ALL, false}, + {NULL, 0, false} +}; + +static int pgss_max = 5000; /* max # statements to track */ +static int pgss_track = PGSS_TRACK_TOP; /* tracking level */ +static bool pgss_track_utility = true; /* whether to track utility commands */ +static bool pgss_track_planning = false; /* whether to track planning + * duration */ +static bool pgss_save = true; /* whether to save stats across shutdown */ +static bool edbss_track_unrecognized = false; /* whether to track unrecognized statements as-is */ + + +#define pgss_enabled(level) \ + (!IsParallelWorker() && \ + (pgss_track == PGSS_TRACK_ALL || \ + (pgss_track == PGSS_TRACK_TOP && (level) == 0))) + +#define record_gc_qtexts() \ + do { \ + SpinLockAcquire(&pgss->mutex); \ + pgss->gc_count++; \ + SpinLockRelease(&pgss->mutex); \ + } while(0) + +/*---- Function declarations ----*/ + +PG_FUNCTION_INFO_V1(edb_stat_statements_reset); +PG_FUNCTION_INFO_V1(edb_stat_statements); +PG_FUNCTION_INFO_V1(edb_stat_statements_info); +PG_FUNCTION_INFO_V1(edb_stat_queryid); + +const char * +edbss_extract_info_line(const char *s, int* len); +EdbStmtInfo * +edbss_extract_stmt_info(const char* query_str, int query_len); +static inline void +edbss_free_stmt_info(EdbStmtInfo *info); +static JsonParseErrorType +edbss_json_struct_start(void *semstate); +static JsonParseErrorType +edbss_json_struct_end(void *semstate); +static JsonParseErrorType +edbss_json_ofield_start(void *semstate, char *fname, bool isnull); 
+static JsonParseErrorType +edbss_json_scalar(void *semstate, char *token, JsonTokenType tokenType); + +static void pgss_shmem_request(void); +static void pgss_shmem_startup(void); +static void pgss_shmem_shutdown(int code, Datum arg); +static void pgss_post_parse_analyze(ParseState *pstate, Query *query, + JumbleState *jstate); +static PlannedStmt *pgss_planner(Query *parse, + const char *query_string, + int cursorOptions, + ParamListInfo boundParams); +static void pgss_ExecutorStart(QueryDesc *queryDesc, int eflags); +static void pgss_ExecutorRun(QueryDesc *queryDesc, + ScanDirection direction, + uint64 count, bool execute_once); +static void pgss_ExecutorFinish(QueryDesc *queryDesc); +static void pgss_ExecutorEnd(QueryDesc *queryDesc); +static void pgss_ProcessUtility(PlannedStmt *pstmt, const char *queryString, + bool readOnlyTree, + ProcessUtilityContext context, ParamListInfo params, + QueryEnvironment *queryEnv, + DestReceiver *dest, QueryCompletion *qc); +static void pgss_store(const char *query, uint64 queryId, + int query_location, int query_len, + pgssStoreKind kind, + double total_time, uint64 rows, + const BufferUsage *bufusage, + const WalUsage *walusage, + const struct JitInstrumentation *jitusage, + JumbleState *jstate, + bool edb_extracted, + pg_uuid_t *id, + EdbStmtType stmt_type, + const Jsonb *extras, + int parallel_workers_to_launch, + int parallel_workers_launched); +static void edb_stat_statements_internal(FunctionCallInfo fcinfo, + pgssVersion api_version, + bool showtext); +static Size pgss_memsize(void); +static pgssEntry *entry_alloc(pgssHashKey *key, Size query_offset, int query_len, + int encoding, bool sticky, pg_uuid_t *id, + EdbStmtType stmt_type, int extras_len); +static void entry_dealloc(void); +static bool qtext_store(const char *query, int query_len, + const Jsonb *extras, int extras_len, + Size *query_offset, int *gc_count); +static char *qtext_load_file(Size *buffer_size); +static char *qtext_fetch(Size query_offset, int 
query_len, + char *buffer, Size buffer_size); +static bool need_gc_qtexts(void); +static void gc_qtexts(void); +static TimestampTz entry_reset(Oid userid, const Datum *dbids, int dbids_len, uint64 queryid, bool minmax_only); +static char *generate_normalized_query(JumbleState *jstate, const char *query, + int query_loc, int *query_len_p); +static void fill_in_constant_lengths(JumbleState *jstate, const char *query, + int query_loc); +static int comp_location(const void *a, const void *b); + + +/* + * Module load callback + */ +void +_PG_init(void) +{ + /* + * In order to create our shared memory area, we have to be loaded via + * shared_preload_libraries. If not, fall out without hooking into any of + * the main system. (We don't throw error here because it seems useful to + * allow the edb_stat_statements functions to be created even when the + * module isn't active. The functions must protect themselves against + * being called then, however.) + */ + if (!process_shared_preload_libraries_in_progress) + return; + + /* + * Inform the postmaster that we want to enable query_id calculation if + * compute_query_id is set to auto. + */ + EnableQueryId(); + + /* + * Define (or redefine) custom GUC variables. 
+ */ + DefineCustomIntVariable("edb_stat_statements.max", + "Sets the maximum number of statements tracked by edb_stat_statements.", + NULL, + &pgss_max, + 5000, + 100, + INT_MAX / 2, + PGC_POSTMASTER, + 0, + NULL, + NULL, + NULL); + + DefineCustomEnumVariable("edb_stat_statements.track", + "Selects which statements are tracked by edb_stat_statements.", + NULL, + &pgss_track, + PGSS_TRACK_TOP, + track_options, + PGC_SUSET, + 0, + NULL, + NULL, + NULL); + + DefineCustomBoolVariable("edb_stat_statements.track_utility", + "Selects whether utility commands are tracked by edb_stat_statements.", + NULL, + &pgss_track_utility, + true, + PGC_SUSET, + 0, + NULL, + NULL, + NULL); + + DefineCustomBoolVariable("edb_stat_statements.track_planning", + "Selects whether planning duration is tracked by edb_stat_statements.", + NULL, + &pgss_track_planning, + false, + PGC_SUSET, + 0, + NULL, + NULL, + NULL); + + DefineCustomBoolVariable("edb_stat_statements.save", + "Save edb_stat_statements statistics across server shutdowns.", + NULL, + &pgss_save, + true, + PGC_SIGHUP, + 0, + NULL, + NULL, + NULL); + + DefineCustomBoolVariable("edb_stat_statements.track_unrecognized", + "Selects whether unrecognized SQL statements are tracked as-is.", + NULL, + &edbss_track_unrecognized, + false, + PGC_SIGHUP, + 0, + NULL, + NULL, + NULL); + + MarkGUCPrefixReserved("edb_stat_statements"); + + /* + * Install hooks. 
+ */ + prev_shmem_request_hook = shmem_request_hook; + shmem_request_hook = pgss_shmem_request; + prev_shmem_startup_hook = shmem_startup_hook; + shmem_startup_hook = pgss_shmem_startup; + prev_post_parse_analyze_hook = post_parse_analyze_hook; + post_parse_analyze_hook = pgss_post_parse_analyze; + prev_planner_hook = planner_hook; + planner_hook = pgss_planner; + prev_ExecutorStart = ExecutorStart_hook; + ExecutorStart_hook = pgss_ExecutorStart; + prev_ExecutorRun = ExecutorRun_hook; + ExecutorRun_hook = pgss_ExecutorRun; + prev_ExecutorFinish = ExecutorFinish_hook; + ExecutorFinish_hook = pgss_ExecutorFinish; + prev_ExecutorEnd = ExecutorEnd_hook; + ExecutorEnd_hook = pgss_ExecutorEnd; + prev_ProcessUtility = ProcessUtility_hook; + ProcessUtility_hook = pgss_ProcessUtility; +} + +const char * +edbss_extract_info_line(const char *s, int *len) { + int prefix_len = strlen(EDB_STMT_MAGIC_PREFIX); + if (*len > prefix_len && strncmp(s, EDB_STMT_MAGIC_PREFIX, prefix_len) == 0) { + const char *rv = s + 3; // skip "-- " + int remaining_len = *len - prefix_len; + int rv_len = 0; + while (rv_len < remaining_len && rv[rv_len] != '\n') + rv_len++; + if (rv_len > 0) { + *len = rv_len; + return rv; + } + } + return NULL; +} + +/* + * Extract EdgeDB query info from the JSON in the leading comments. + * If success, returns a palloc-ed EdbStmtInfo which must be freed + * after usage with edbss_free_stmt_info(). + * + * The query info JSON comments must be at the beginning of the + * query_str. Each line must start with `-- {` and end with `\n`, + * with a single valid JSON string. The JSON string itself must + * not contain any `\n`, or it'll be treated as a bad JSON. + * + * This function scans over all such lines and records known + * values progressively. Malformed JSONs may be partially read, + * this function won't bail just because of that; it'll continue + * with the next line. 
If the same key exists more than once, + * only the first occurrence is effective, later ones are ignored. + * This function returns successfully as soon as all required + * fields (EDB_STMT_INFO_PARSE_REQUIRED) are found AND the current + * JSON is in good form, ignoring remaining lines. For example: + * + * -- {"a": 1} + * -- {"a": 11, "d": 4, "nested": {"b": 22}} + * -- {"b": 2, "unknown": "skipped", + * -- {"c": 3} + * -- {"e": 5} + * SELECT .... + * + * If the required fields are {a, b, c}, while {d, e} are known + * but not required, the extracted info will be: + * + * {"a": 1, "b": 2, "c": 3, "d": 4} + * + */ +EdbStmtInfo * +edbss_extract_stmt_info(const char* query_str, int query_len) { + int info_len = query_len; + const char *info_str = edbss_extract_info_line(query_str, &info_len); + + if (info_str) { + EdbStmtInfo *info = (EdbStmtInfo *) palloc0(sizeof(EdbStmtInfo)); + EdbStmtInfoSemState state = { + .info = info, + .state = EDB_STMT_INFO_PARSE_NOOP, + }; + JsonSemAction sem = { + .semstate = (void *) &state, + .object_start = edbss_json_struct_start, + .object_end = edbss_json_struct_end, + .array_start = edbss_json_struct_start, + .array_end = edbss_json_struct_end, + .object_field_start = edbss_json_ofield_start, + .scalar = edbss_json_scalar, + }; + + while (info_str) { + JsonLexContext *lex = makeJsonLexContextCstringLen( +#if PG_VERSION_NUM >= 170000 + NULL, + info_str, +#else + (char *) info_str, // not actually mutating +#endif + info_len, + PG_UTF8, + true); + JsonParseErrorType parse_rv = pg_parse_json(lex, &sem); + freeJsonLexContext(lex); + + if (parse_rv == JSON_SUCCESS) + if ((state.found & EDB_STMT_INFO_PARSE_REQUIRED) == EDB_STMT_INFO_PARSE_REQUIRED) + return info->id.query_id != UINT64CONST(0) ? 
info : NULL; + + info_str += info_len + 1; + info_len = query_len - (int)(info_str - query_str); + info_str = edbss_extract_info_line(info_str, &info_len); + state.nested_level = 0; + state.state = EDB_STMT_INFO_PARSE_NOOP; + } + edbss_free_stmt_info(info); + } + + return NULL; +} + +/* + * Frees the given EdbStmtInfo struct as well as + * its owning sub-fields (query). + */ +static inline void +edbss_free_stmt_info(EdbStmtInfo *info) { + Assert(info != NULL); + pfree((void *) info->query); + pfree(info); +} + +static JsonParseErrorType +edbss_json_struct_start(void *semstate) { + EdbStmtInfoSemState *state = (EdbStmtInfoSemState *) semstate; + state->nested_level += 1; + return JSON_SUCCESS; +} + +static JsonParseErrorType +edbss_json_struct_end(void *semstate) { + EdbStmtInfoSemState *state = (EdbStmtInfoSemState *) semstate; + state->nested_level -= 1; + return JSON_SUCCESS; +} + +static JsonParseErrorType +edbss_json_ofield_start(void *semstate, char *fname, bool isnull) { + EdbStmtInfoSemState *state = (EdbStmtInfoSemState *) semstate; + Assert(fname != NULL); + if (state->nested_level == 1) { + if (strcmp(fname, "query") == 0) { + state->state = EDB_STMT_INFO_PARSE_QUERY; + } else if (strcmp(fname, "id") == 0) { + state->state = EDB_STMT_INFO_PARSE_ID; + } else if (strcmp(fname, "type") == 0) { + state->state = EDB_STMT_INFO_PARSE_TYPE; + } else if (strcmp(fname, "extras") == 0) { + state->state = EDB_STMT_INFO_PARSE_EXTRAS; + } + } + pfree(fname); /* must not use object_field_end */ + return JSON_SUCCESS; +} + +static JsonParseErrorType +edbss_json_scalar(void *semstate, char *token, JsonTokenType tokenType) { + EdbStmtInfoSemState *state = (EdbStmtInfoSemState *) semstate; + Assert(token != NULL); + + if (state->found & state->state) { + pfree(token); + state->state = EDB_STMT_INFO_PARSE_NOOP; + return JSON_SUCCESS; + } + + switch (state->state) { + case EDB_STMT_INFO_PARSE_QUERY: + if (tokenType == JSON_TOKEN_STRING) { + state->info->query = token; + 
state->info->query_len = (int) strlen(token); + break; + } + goto fail; + + case EDB_STMT_INFO_PARSE_ID: + if (tokenType == JSON_TOKEN_STRING) { + Datum id_datum = DirectFunctionCall1(uuid_in, CStringGetDatum(token)); + pg_uuid_t *id_ptr = DatumGetUUIDP(id_datum); + state->info->id.uuid = *id_ptr; + pfree(id_ptr); + pfree(token); + break; + } + goto fail; + + case EDB_STMT_INFO_PARSE_TYPE: + if (tokenType == JSON_TOKEN_NUMBER) { + char *endptr; + long type_val = strtol(token, &endptr, 10); + if (*endptr == '\0' && type_val != LONG_MAX) { + if (type_val == EDB_EDGEQL || type_val == EDB_SQL) { + state->info->stmt_type = type_val; + pfree(token); + break; + } + } + } + goto fail; + + case EDB_STMT_INFO_PARSE_EXTRAS: + if (tokenType == JSON_TOKEN_STRING) { + Datum extras_jsonb = DirectFunctionCall1(jsonb_in, CStringGetDatum(token)); + state->info->extras = DatumGetJsonbP(extras_jsonb); + pfree(token); + break; + } + goto fail; + + case EDB_STMT_INFO_PARSE_NOOP: + pfree(token); + return JSON_SUCCESS; + } + state->found |= state->state; + state->state = EDB_STMT_INFO_PARSE_NOOP; + return JSON_SUCCESS; + +fail: + pfree(token); + return JSON_SEM_ACTION_FAILED; +} + +/* + * shmem_request hook: request additional shared resources. We'll allocate or + * attach to the shared resources in pgss_shmem_startup(). + */ +static void +pgss_shmem_request(void) +{ + if (prev_shmem_request_hook) + prev_shmem_request_hook(); + + RequestAddinShmemSpace(pgss_memsize()); + RequestNamedLWLockTranche("edb_stat_statements", 1); +} + +/* + * shmem_startup hook: allocate or attach to shared memory, + * then load any pre-existing statistics from file. + * Also create and load the query-texts file, which is expected to exist + * (even if empty) while the module is enabled. 
+ */ +static void +pgss_shmem_startup(void) +{ + bool found; + HASHCTL info; + FILE *file = NULL; + FILE *qfile = NULL; + uint32 header; + int32 num; + int32 pgver; + int32 i; + int buffer_size; + char *buffer = NULL; + + if (prev_shmem_startup_hook) + prev_shmem_startup_hook(); + + /* reset in case this is a restart within the postmaster */ + pgss = NULL; + pgss_hash = NULL; + + /* + * Create or attach to the shared memory state, including hash table + */ + LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); + + pgss = ShmemInitStruct("edb_stat_statements", + sizeof(pgssSharedState), + &found); + + if (!found) + { + /* First time through ... */ + pgss->lock = &(GetNamedLWLockTranche("edb_stat_statements"))->lock; + pgss->cur_median_usage = ASSUMED_MEDIAN_INIT; + pgss->mean_query_len = ASSUMED_LENGTH_INIT; + SpinLockInit(&pgss->mutex); + pgss->extent = 0; + pgss->n_writers = 0; + pgss->gc_count = 0; + pgss->stats.dealloc = 0; + pgss->stats.stats_reset = GetCurrentTimestamp(); + } + + info.keysize = sizeof(pgssHashKey); + info.entrysize = sizeof(pgssEntry); + pgss_hash = ShmemInitHash("edb_stat_statements hash", + pgss_max, pgss_max, + &info, + HASH_ELEM | HASH_BLOBS); + + LWLockRelease(AddinShmemInitLock); + + /* + * If we're in the postmaster (or a standalone backend...), set up a shmem + * exit hook to dump the statistics to disk. + */ + if (!IsUnderPostmaster) + on_shmem_exit(pgss_shmem_shutdown, (Datum) 0); + + /* + * Done if some other process already completed our initialization. + */ + if (found) + return; + + /* + * Note: we don't bother with locks here, because there should be no other + * processes running when this code is reached. + */ + + /* Unlink query text file possibly left over from crash */ + unlink(PGSS_TEXT_FILE); + + /* Allocate new query text temp file */ + qfile = AllocateFile(PGSS_TEXT_FILE, PG_BINARY_W); + if (qfile == NULL) + goto write_error; + + /* + * If we were told not to load old statistics, we're done. 
(Note we do + * not try to unlink any old dump file in this case. This seems a bit + * questionable but it's the historical behavior.) + */ + if (!pgss_save) + { + FreeFile(qfile); + return; + } + + /* + * Attempt to load old statistics from the dump file. + */ + file = AllocateFile(PGSS_DUMP_FILE, PG_BINARY_R); + if (file == NULL) + { + if (errno != ENOENT) + goto read_error; + /* No existing persisted stats file, so we're done */ + FreeFile(qfile); + return; + } + + buffer_size = 2048; + buffer = (char *) palloc(buffer_size); + + if (fread(&header, sizeof(uint32), 1, file) != 1 || + fread(&pgver, sizeof(uint32), 1, file) != 1 || + fread(&num, sizeof(int32), 1, file) != 1) + goto read_error; + + if (header != PGSS_FILE_HEADER || + pgver != PGSS_PG_MAJOR_VERSION) + goto data_error; + + for (i = 0; i < num; i++) + { + pgssEntry temp; + pgssEntry *entry; + Size query_offset; + int len; + + if (fread(&temp, sizeof(pgssEntry), 1, file) != 1) + goto read_error; + + /* Encoding is the only field we can easily sanity-check */ + if (!PG_VALID_BE_ENCODING(temp.encoding)) + goto data_error; + + len = temp.query_len + temp.extras_len; + + /* Resize buffer as needed */ + if (len >= buffer_size) + { + buffer_size = Max(buffer_size * 2, len + 1); + buffer = repalloc(buffer, buffer_size); + } + + if (fread(buffer, 1, len + 1, file) != len + 1) + goto read_error; + + /* Should have a trailing null, but let's make sure */ + buffer[len] = '\0'; + + /* Skip loading "sticky" entries */ + if (IS_STICKY(temp.counters)) + continue; + + /* Store the query text */ + query_offset = pgss->extent; + if (fwrite(buffer, 1, len + 1, qfile) != len + 1) + goto write_error; + pgss->extent += len + 1; + + /* make the hashtable entry (discards old entries if too many) */ + entry = entry_alloc(&temp.key, query_offset, temp.query_len, + temp.encoding, + false, NULL, 0, temp.extras_len); + + /* copy in the actual stats */ + entry->counters = temp.counters; + entry->stats_since = temp.stats_since; + 
entry->minmax_stats_since = temp.minmax_stats_since; + entry->id = temp.id; + entry->stmt_type = temp.stmt_type; + } + + /* Read global statistics for edb_stat_statements */ + if (fread(&pgss->stats, sizeof(pgssGlobalStats), 1, file) != 1) + goto read_error; + + pfree(buffer); + FreeFile(file); + FreeFile(qfile); + + /* + * Remove the persisted stats file so it's not included in + * backups/replication standbys, etc. A new file will be written on next + * shutdown. + * + * Note: it's okay if the PGSS_TEXT_FILE is included in a basebackup, + * because we remove that file on startup; it acts inversely to + * PGSS_DUMP_FILE, in that it is only supposed to be around when the + * server is running, whereas PGSS_DUMP_FILE is only supposed to be around + * when the server is not running. Leaving the file creates no danger of + * a newly restored database having a spurious record of execution costs, + * which is what we're really concerned about here. + */ + unlink(PGSS_DUMP_FILE); + + return; + +read_error: + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not read file \"%s\": %m", + PGSS_DUMP_FILE))); + goto fail; +data_error: + ereport(LOG, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("ignoring invalid data in file \"%s\"", + PGSS_DUMP_FILE))); + goto fail; +write_error: + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_TEXT_FILE))); +fail: + if (buffer) + pfree(buffer); + if (file) + FreeFile(file); + if (qfile) + FreeFile(qfile); + /* If possible, throw away the bogus file; ignore any error */ + unlink(PGSS_DUMP_FILE); + + /* + * Don't unlink PGSS_TEXT_FILE here; it should always be around while the + * server is running with edb_stat_statements enabled + */ +} + +/* + * shmem_shutdown hook: Dump statistics into file. + * + * Note: we don't bother with acquiring lock, because there should be no + * other processes running when this is called. 
+ */ +static void +pgss_shmem_shutdown(int code, Datum arg) +{ + FILE *file; + char *qbuffer = NULL; + Size qbuffer_size = 0; + HASH_SEQ_STATUS hash_seq; + int32 num_entries; + pgssEntry *entry; + + /* Don't try to dump during a crash. */ + if (code) + return; + + /* Safety check ... shouldn't get here unless shmem is set up. */ + if (!pgss || !pgss_hash) + return; + + /* Don't dump if told not to. */ + if (!pgss_save) + return; + + file = AllocateFile(PGSS_DUMP_FILE ".tmp", PG_BINARY_W); + if (file == NULL) + goto error; + + if (fwrite(&PGSS_FILE_HEADER, sizeof(uint32), 1, file) != 1) + goto error; + if (fwrite(&PGSS_PG_MAJOR_VERSION, sizeof(uint32), 1, file) != 1) + goto error; + num_entries = hash_get_num_entries(pgss_hash); + if (fwrite(&num_entries, sizeof(int32), 1, file) != 1) + goto error; + + qbuffer = qtext_load_file(&qbuffer_size); + if (qbuffer == NULL) + goto error; + + /* + * When serializing to disk, we store query texts immediately after their + * entry data. Any orphaned query texts are thereby excluded. + */ + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + int len = entry->query_len + entry->extras_len; + char *qstr = qtext_fetch(entry->query_offset, len, + qbuffer, qbuffer_size); + + if (qstr == NULL) + continue; /* Ignore any entries with bogus texts */ + + if (fwrite(entry, sizeof(pgssEntry), 1, file) != 1 || + fwrite(qstr, 1, len + 1, file) != len + 1) + { + /* note: we assume hash_seq_term won't change errno */ + hash_seq_term(&hash_seq); + goto error; + } + } + + /* Dump global statistics for edb_stat_statements */ + if (fwrite(&pgss->stats, sizeof(pgssGlobalStats), 1, file) != 1) + goto error; + + free(qbuffer); + qbuffer = NULL; + + if (FreeFile(file)) + { + file = NULL; + goto error; + } + + /* + * Rename file into place, so we atomically replace any old one. 
+ */ + (void) durable_rename(PGSS_DUMP_FILE ".tmp", PGSS_DUMP_FILE, LOG); + + /* Unlink query-texts file; it's not needed while shutdown */ + unlink(PGSS_TEXT_FILE); + + return; + +error: + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_DUMP_FILE ".tmp"))); + free(qbuffer); + if (file) + FreeFile(file); + unlink(PGSS_DUMP_FILE ".tmp"); + unlink(PGSS_TEXT_FILE); +} + +/* + * Post-parse-analysis hook: mark query with a queryId + */ +static void +pgss_post_parse_analyze(ParseState *pstate, Query *query, JumbleState *jstate) +{ + EdbStmtInfo *info; + const char *query_str; + int query_location, query_len; + + if (prev_post_parse_analyze_hook) + prev_post_parse_analyze_hook(pstate, query, jstate); + + /* Safety check... */ + if (!pgss || !pgss_hash || !pgss_enabled(nesting_level)) + return; + + /* + * If it's EXECUTE, clear the queryId so that stats will accumulate for + * the underlying PREPARE. But don't do this if we're not tracking + * utility statements, to avoid messing up another extension that might be + * tracking them. + */ + if (query->utilityStmt) + { + if (pgss_track_utility && IsA(query->utilityStmt, ExecuteStmt)) + { + query->queryId = UINT64CONST(0); + return; + } + } + + /* Parse EdgeDB query info JSON and overwrite query->queryId */ + query_location = query->stmt_location; + query_len = query->stmt_len; + query_str = CleanQuerytext(pstate->p_sourcetext, &query_location, &query_len); + if ((info = edbss_extract_stmt_info(query_str, query_len)) != NULL) { + query->queryId = info->id.query_id; + + /* We immediately create a hash table entry for the query, + * so that we don't need to parse the query info JSON later + * again for the query with the same queryId. 
+ */ + pgss_store(info->query, + info->id.query_id, + 0, + info->query_len, + PGSS_INVALID, + 0, + 0, + NULL, + NULL, + NULL, + NULL, + true, + &info->id.uuid, + info->stmt_type, + info->extras, + 0, + 0); + edbss_free_stmt_info(info); + } else if (!edbss_track_unrecognized) { + query->queryId = UINT64CONST(0); + } else if (jstate && jstate->clocations_count > 0) + /* + * If query jumbling were able to identify any ignorable constants, we + * immediately create a hash table entry for the query, so that we can + * record the normalized form of the query string. If there were no such + * constants, the normalized string would be the same as the query text + * anyway, so there's no need for an early entry. + */ + pgss_store(pstate->p_sourcetext, + query->queryId, + query->stmt_location, + query->stmt_len, + PGSS_INVALID, + 0, + 0, + NULL, + NULL, + NULL, + jstate, + true, + NULL, + 0, + NULL, + 0, + 0); +} + +/* + * Planner hook: forward to regular planner, but measure planning time + * if needed. + */ +static PlannedStmt * +pgss_planner(Query *parse, + const char *query_string, + int cursorOptions, + ParamListInfo boundParams) +{ + PlannedStmt *result; + + /* + * We can't process the query if no query_string is provided, as + * pgss_store needs it. We also ignore query without queryid, as it would + * be treated as a utility statement, which may not be the case. + */ + if (pgss_enabled(nesting_level) + && pgss_track_planning && query_string + && parse->queryId != UINT64CONST(0)) + { + instr_time start; + instr_time duration; + BufferUsage bufusage_start, + bufusage; + WalUsage walusage_start, + walusage; + + /* We need to track buffer usage as the planner can access them. */ + bufusage_start = pgBufferUsage; + + /* + * Similarly the planner could write some WAL records in some cases + * (e.g. 
setting a hint bit with those being WAL-logged) + */ + walusage_start = pgWalUsage; + INSTR_TIME_SET_CURRENT(start); + + nesting_level++; + PG_TRY(); + { + if (prev_planner_hook) + result = prev_planner_hook(parse, query_string, cursorOptions, + boundParams); + else + result = standard_planner(parse, query_string, cursorOptions, + boundParams); + } + PG_FINALLY(); + { + nesting_level--; + } + PG_END_TRY(); + + INSTR_TIME_SET_CURRENT(duration); + INSTR_TIME_SUBTRACT(duration, start); + + /* calc differences of buffer counters. */ + memset(&bufusage, 0, sizeof(BufferUsage)); + BufferUsageAccumDiff(&bufusage, &pgBufferUsage, &bufusage_start); + + /* calc differences of WAL counters. */ + memset(&walusage, 0, sizeof(WalUsage)); + WalUsageAccumDiff(&walusage, &pgWalUsage, &walusage_start); + + pgss_store(query_string, + parse->queryId, + parse->stmt_location, + parse->stmt_len, + PGSS_PLAN, + INSTR_TIME_GET_MILLISEC(duration), + 0, + &bufusage, + &walusage, + NULL, + NULL, + false, + NULL, + 0, + NULL, + 0, + 0); + } + else + { + /* + * Even though we're not tracking plan time for this statement, we + * must still increment the nesting level, to ensure that functions + * evaluated during planning are not seen as top-level calls. + */ + nesting_level++; + PG_TRY(); + { + if (prev_planner_hook) + result = prev_planner_hook(parse, query_string, cursorOptions, + boundParams); + else + result = standard_planner(parse, query_string, cursorOptions, + boundParams); + } + PG_FINALLY(); + { + nesting_level--; + } + PG_END_TRY(); + } + + return result; +} + +/* + * ExecutorStart hook: start up tracking if needed + */ +static void +pgss_ExecutorStart(QueryDesc *queryDesc, int eflags) +{ + if (prev_ExecutorStart) + prev_ExecutorStart(queryDesc, eflags); + else + standard_ExecutorStart(queryDesc, eflags); + + /* + * If query has queryId zero, don't track it. This prevents double + * counting of optimizable statements that are directly contained in + * utility statements. 
+ */ + if (pgss_enabled(nesting_level) && queryDesc->plannedstmt->queryId != UINT64CONST(0)) + { + /* + * Set up to track total elapsed time in ExecutorRun. Make sure the + * space is allocated in the per-query context so it will go away at + * ExecutorEnd. + */ + if (queryDesc->totaltime == NULL) + { + MemoryContext oldcxt; + + oldcxt = MemoryContextSwitchTo(queryDesc->estate->es_query_cxt); + queryDesc->totaltime = InstrAlloc(1, INSTRUMENT_ALL, false); + MemoryContextSwitchTo(oldcxt); + } + } +} + +/* + * ExecutorRun hook: all we need do is track nesting depth + */ +static void +pgss_ExecutorRun(QueryDesc *queryDesc, ScanDirection direction, uint64 count, + bool execute_once) +{ + nesting_level++; + PG_TRY(); + { + if (prev_ExecutorRun) + prev_ExecutorRun(queryDesc, direction, count, execute_once); + else + standard_ExecutorRun(queryDesc, direction, count, execute_once); + } + PG_FINALLY(); + { + nesting_level--; + } + PG_END_TRY(); +} + +/* + * ExecutorFinish hook: all we need do is track nesting depth + */ +static void +pgss_ExecutorFinish(QueryDesc *queryDesc) +{ + nesting_level++; + PG_TRY(); + { + if (prev_ExecutorFinish) + prev_ExecutorFinish(queryDesc); + else + standard_ExecutorFinish(queryDesc); + } + PG_FINALLY(); + { + nesting_level--; + } + PG_END_TRY(); +} + +/* + * ExecutorEnd hook: store results if needed + */ +static void +pgss_ExecutorEnd(QueryDesc *queryDesc) +{ + uint64 queryId = queryDesc->plannedstmt->queryId; + + if (queryId != UINT64CONST(0) && queryDesc->totaltime && + pgss_enabled(nesting_level)) + { + /* + * Make sure stats accumulation is done. (Note: it's okay if several + * levels of hook all do this.) 
+ */ + InstrEndLoop(queryDesc->totaltime); + + pgss_store(queryDesc->sourceText, + queryId, + queryDesc->plannedstmt->stmt_location, + queryDesc->plannedstmt->stmt_len, + PGSS_EXEC, + queryDesc->totaltime->total * 1000.0, /* convert to msec */ + queryDesc->estate->es_total_processed, + &queryDesc->totaltime->bufusage, + &queryDesc->totaltime->walusage, + queryDesc->estate->es_jit ? &queryDesc->estate->es_jit->instr : NULL, + NULL, + false, + NULL, + 0, + NULL, +#if PG_VERSION_NUM >= 180000 + queryDesc->estate->es_parallel_workers_to_launch, + queryDesc->estate->es_parallel_workers_launched +#else + 0, + 0 +#endif + ); + } + + if (prev_ExecutorEnd) + prev_ExecutorEnd(queryDesc); + else + standard_ExecutorEnd(queryDesc); +} + +/* + * ProcessUtility hook + */ +static void +pgss_ProcessUtility(PlannedStmt *pstmt, const char *queryString, + bool readOnlyTree, + ProcessUtilityContext context, + ParamListInfo params, QueryEnvironment *queryEnv, + DestReceiver *dest, QueryCompletion *qc) +{ + Node *parsetree = pstmt->utilityStmt; + uint64 saved_queryId = pstmt->queryId; + int saved_stmt_location = pstmt->stmt_location; + int saved_stmt_len = pstmt->stmt_len; + bool enabled = pgss_track_utility && pgss_enabled(nesting_level); + + /* + * Force utility statements to get queryId zero. We do this even in cases + * where the statement contains an optimizable statement for which a + * queryId could be derived (such as EXPLAIN or DECLARE CURSOR). For such + * cases, runtime control will first go through ProcessUtility and then + * the executor, and we don't want the executor hooks to do anything, + * since we are already measuring the statement's costs at the utility + * level. + * + * Note that this is only done if edb_stat_statements is enabled and + * configured to track utility statements, in the unlikely possibility + * that user configured another extension to handle utility statements + * only. 
+ */ + if (enabled) + pstmt->queryId = UINT64CONST(0); + + /* + * If it's an EXECUTE statement, we don't track it and don't increment the + * nesting level. This allows the cycles to be charged to the underlying + * PREPARE instead (by the Executor hooks), which is much more useful. + * + * We also don't track execution of PREPARE. If we did, we would get one + * hash table entry for the PREPARE (with hash calculated from the query + * string), and then a different one with the same query string (but hash + * calculated from the query tree) would be used to accumulate costs of + * ensuing EXECUTEs. This would be confusing. Since PREPARE doesn't + * actually run the planner (only parse+rewrite), its costs are generally + * pretty negligible and it seems okay to just ignore it. + */ + if (enabled && + !IsA(parsetree, ExecuteStmt) && + !IsA(parsetree, PrepareStmt)) + { + instr_time start; + instr_time duration; + uint64 rows; + BufferUsage bufusage_start, + bufusage; + WalUsage walusage_start, + walusage; + + bufusage_start = pgBufferUsage; + walusage_start = pgWalUsage; + INSTR_TIME_SET_CURRENT(start); + + nesting_level++; + PG_TRY(); + { + if (prev_ProcessUtility) + prev_ProcessUtility(pstmt, queryString, readOnlyTree, + context, params, queryEnv, + dest, qc); + else + standard_ProcessUtility(pstmt, queryString, readOnlyTree, + context, params, queryEnv, + dest, qc); + } + PG_FINALLY(); + { + nesting_level--; + } + PG_END_TRY(); + + /* + * CAUTION: do not access the *pstmt data structure again below here. + * If it was a ROLLBACK or similar, that data structure may have been + * freed. We must copy everything we still need into local variables, + * which we did above. + * + * For the same reason, we can't risk restoring pstmt->queryId to its + * former value, which'd otherwise be a good idea. 
+ */ + + INSTR_TIME_SET_CURRENT(duration); + INSTR_TIME_SUBTRACT(duration, start); + + /* + * Track the total number of rows retrieved or affected by the utility + * statements of COPY, FETCH, CREATE TABLE AS, CREATE MATERIALIZED + * VIEW, REFRESH MATERIALIZED VIEW and SELECT INTO. + */ + rows = (qc && (qc->commandTag == CMDTAG_COPY || + qc->commandTag == CMDTAG_FETCH || + qc->commandTag == CMDTAG_SELECT || + qc->commandTag == CMDTAG_REFRESH_MATERIALIZED_VIEW)) ? + qc->nprocessed : 0; + + /* calc differences of buffer counters. */ + memset(&bufusage, 0, sizeof(BufferUsage)); + BufferUsageAccumDiff(&bufusage, &pgBufferUsage, &bufusage_start); + + /* calc differences of WAL counters. */ + memset(&walusage, 0, sizeof(WalUsage)); + WalUsageAccumDiff(&walusage, &pgWalUsage, &walusage_start); + + pgss_store(queryString, + saved_queryId, + saved_stmt_location, + saved_stmt_len, + PGSS_EXEC, + INSTR_TIME_GET_MILLISEC(duration), + rows, + &bufusage, + &walusage, + NULL, + NULL, + false, + NULL, + 0, + NULL, + 0, + 0); + } + else + { + /* + * Even though we're not tracking execution time for this statement, + * we must still increment the nesting level, to ensure that functions + * evaluated within it are not seen as top-level calls. But don't do + * so for EXECUTE; that way, when control reaches pgss_planner or + * pgss_ExecutorStart, we will treat the costs as top-level if + * appropriate. Likewise, don't bump for PREPARE, so that parse + * analysis will treat the statement as top-level if appropriate. + * + * To be absolutely certain we don't mess up the nesting level, + * evaluate the bump_level condition just once. 
+ */ + bool bump_level = + !IsA(parsetree, ExecuteStmt) && + !IsA(parsetree, PrepareStmt); + + if (bump_level) + nesting_level++; + PG_TRY(); + { + if (prev_ProcessUtility) + prev_ProcessUtility(pstmt, queryString, readOnlyTree, + context, params, queryEnv, + dest, qc); + else + standard_ProcessUtility(pstmt, queryString, readOnlyTree, + context, params, queryEnv, + dest, qc); + } + PG_FINALLY(); + { + if (bump_level) + nesting_level--; + } + PG_END_TRY(); + } +} + +/* + * Store some statistics for a statement. + * + * If jstate is not NULL then we're trying to create an entry for which + * we have no statistics as yet; we just want to record the normalized + * query string. total_time, rows, bufusage and walusage are ignored in this + * case. + * + * If kind is PGSS_PLAN or PGSS_EXEC, its value is used as the array position + * for the arrays in the Counters field. + */ +static void +pgss_store(const char *query, uint64 queryId, + int query_location, int query_len, + pgssStoreKind kind, + double total_time, uint64 rows, + const BufferUsage *bufusage, + const WalUsage *walusage, + const struct JitInstrumentation *jitusage, + JumbleState *jstate, + bool edb_extracted, + pg_uuid_t *id, + EdbStmtType stmt_type, + const Jsonb *extras, + int parallel_workers_to_launch, + int parallel_workers_launched) +{ + pgssHashKey key; + pgssEntry *entry; + char *norm_query = NULL; + int encoding = GetDatabaseEncoding(); + EdbStmtInfo *info = NULL; + + Assert(query != NULL); + + /* Safety check... */ + if (!pgss || !pgss_hash) + return; + + /* + * Nothing to do if compute_query_id isn't enabled and no other module + * computed a query identifier. + */ + if (queryId == UINT64CONST(0)) + return; + + /* + * Confine our attention to the relevant part of the string, if the query + * is a portion of a multi-statement source string, and update query + * location and length if needed. 
+ */ + query = CleanQuerytext(query, &query_location, &query_len); + + /* Set up key for hashtable search */ + + /* clear padding */ + memset(&key, 0, sizeof(pgssHashKey)); + + key.userid = GetUserId(); + key.dbid = MyDatabaseId; + key.queryid = queryId; + key.toplevel = (nesting_level == 0); + + /* Lookup the hash table entry with shared lock. */ + LWLockAcquire(pgss->lock, LW_SHARED); + + entry = (pgssEntry *) hash_search(pgss_hash, &key, HASH_FIND, NULL); + + /* Create new entry, if not present */ + if (!entry) + { + Size query_offset; + int gc_count; + bool stored; + bool do_gc; + bool sticky = true; + int extras_len; + + if (!edb_extracted) { + /* Try extract from the context of plan/execute. + * This is usually happening after a stats reset. + */ + if ((info = edbss_extract_stmt_info(query, query_len)) != NULL) { + /* We should just get the same queryId again + * as we extracted before the reset in post_parse. + */ + if (info->id.query_id != queryId) + goto done; + query = info->query; + query_len = info->query_len; + id = &info->id.uuid; + stmt_type = info->stmt_type; + extras = info->extras; + } else if (!edbss_track_unrecognized) { + /* skip unrecognized statements unless we're told not to */ + goto done; + } else { + sticky = jstate != NULL; + } + } + + /* + * Create a new, normalized query string if caller asked. We don't + * need to hold the lock while doing this work. (Note: in any case, + * it's possible that someone else creates a duplicate hashtable entry + * in the interval where we don't hold the lock below. That case is + * handled by entry_alloc.) + */ + if (jstate) + { + LWLockRelease(pgss->lock); + norm_query = generate_normalized_query(jstate, query, + query_location, + &query_len); + LWLockAcquire(pgss->lock, LW_SHARED); + } + + extras_len = extras == NULL ? 0 : VARSIZE(JsonbPGetDatum(extras)); + + /* Append new query text to file with only shared lock held */ + stored = qtext_store(norm_query ? 
norm_query : query, query_len, extras, extras_len, + &query_offset, &gc_count); + + /* + * Determine whether we need to garbage collect external query texts + * while the shared lock is still held. This micro-optimization + * avoids taking the time to decide this while holding exclusive lock. + */ + do_gc = need_gc_qtexts(); + + /* Need exclusive lock to make a new hashtable entry - promote */ + LWLockRelease(pgss->lock); + LWLockAcquire(pgss->lock, LW_EXCLUSIVE); + + /* + * A garbage collection may have occurred while we weren't holding the + * lock. In the unlikely event that this happens, the query text we + * stored above will have been garbage collected, so write it again. + * This should be infrequent enough that doing it while holding + * exclusive lock isn't a performance problem. + */ + if (!stored || pgss->gc_count != gc_count) + stored = qtext_store(norm_query ? norm_query : query, query_len, + extras, extras_len, + &query_offset, NULL); + + /* If we failed to write to the text file, give up */ + if (!stored) + goto done; + + /* OK to create a new hashtable entry */ + entry = entry_alloc(&key, query_offset, query_len, encoding, + sticky, id, stmt_type, extras_len); + + /* If needed, perform garbage collection while exclusive lock held */ + if (do_gc) + gc_qtexts(); + } + + /* Increment the counts, except when jstate is not NULL */ + if (!edb_extracted) + { + Assert(kind == PGSS_PLAN || kind == PGSS_EXEC); + + /* + * Grab the spinlock while updating the counters (see comment about + * locking rules at the head of the file) + */ + SpinLockAcquire(&entry->mutex); + + /* "Unstick" entry if it was previously sticky */ + if (IS_STICKY(entry->counters)) + entry->counters.usage = USAGE_INIT; + + entry->counters.calls[kind] += 1; + entry->counters.total_time[kind] += total_time; + + if (entry->counters.calls[kind] == 1) + { + entry->counters.min_time[kind] = total_time; + entry->counters.max_time[kind] = total_time; + entry->counters.mean_time[kind] = total_time; 
+			}
+			else
+			{
+				/*
+				 * Welford's method for accurately computing variance. See
+				 * <http://www.johndcook.com/blog/standard_deviation/>
+				 */
+				double		old_mean = entry->counters.mean_time[kind];
+
+				entry->counters.mean_time[kind] +=
+					(total_time - old_mean) / entry->counters.calls[kind];
+				entry->counters.sum_var_time[kind] +=
+					(total_time - old_mean) * (total_time - entry->counters.mean_time[kind]);
+
+				/*
+				 * Calculate min and max time. min = 0 and max = 0 means that the
+				 * min/max statistics were reset
+				 */
+				if (entry->counters.min_time[kind] == 0
+					&& entry->counters.max_time[kind] == 0)
+				{
+					entry->counters.min_time[kind] = total_time;
+					entry->counters.max_time[kind] = total_time;
+				}
+				else
+				{
+					if (entry->counters.min_time[kind] > total_time)
+						entry->counters.min_time[kind] = total_time;
+					if (entry->counters.max_time[kind] < total_time)
+						entry->counters.max_time[kind] = total_time;
+				}
+			}
+			entry->counters.rows += rows;
+			entry->counters.shared_blks_hit += bufusage->shared_blks_hit;
+			entry->counters.shared_blks_read += bufusage->shared_blks_read;
+			entry->counters.shared_blks_dirtied += bufusage->shared_blks_dirtied;
+			entry->counters.shared_blks_written += bufusage->shared_blks_written;
+			entry->counters.local_blks_hit += bufusage->local_blks_hit;
+			entry->counters.local_blks_read += bufusage->local_blks_read;
+			entry->counters.local_blks_dirtied += bufusage->local_blks_dirtied;
+			entry->counters.local_blks_written += bufusage->local_blks_written;
+			entry->counters.temp_blks_read += bufusage->temp_blks_read;
+			entry->counters.temp_blks_written += bufusage->temp_blks_written;
+#if PG_VERSION_NUM >= 170000
+			entry->counters.shared_blk_read_time += INSTR_TIME_GET_MILLISEC(bufusage->shared_blk_read_time);
+			entry->counters.shared_blk_write_time += INSTR_TIME_GET_MILLISEC(bufusage->shared_blk_write_time);
+			entry->counters.local_blk_read_time += INSTR_TIME_GET_MILLISEC(bufusage->local_blk_read_time);
+			entry->counters.local_blk_write_time += INSTR_TIME_GET_MILLISEC(bufusage->local_blk_write_time);
+#else + entry->counters.shared_blk_read_time += INSTR_TIME_GET_MILLISEC(bufusage->blk_read_time); + entry->counters.shared_blk_write_time += INSTR_TIME_GET_MILLISEC(bufusage->blk_write_time); +#endif + entry->counters.temp_blk_read_time += INSTR_TIME_GET_MILLISEC(bufusage->temp_blk_read_time); + entry->counters.temp_blk_write_time += INSTR_TIME_GET_MILLISEC(bufusage->temp_blk_write_time); + entry->counters.usage += USAGE_EXEC(total_time); + entry->counters.wal_records += walusage->wal_records; + entry->counters.wal_fpi += walusage->wal_fpi; + entry->counters.wal_bytes += walusage->wal_bytes; + if (jitusage) + { + entry->counters.jit_functions += jitusage->created_functions; + entry->counters.jit_generation_time += INSTR_TIME_GET_MILLISEC(jitusage->generation_counter); + +#if PG_VERSION_NUM >= 170000 + if (INSTR_TIME_GET_MILLISEC(jitusage->deform_counter)) + entry->counters.jit_deform_count++; + entry->counters.jit_deform_time += INSTR_TIME_GET_MILLISEC(jitusage->deform_counter); +#endif + + if (INSTR_TIME_GET_MILLISEC(jitusage->inlining_counter)) + entry->counters.jit_inlining_count++; + entry->counters.jit_inlining_time += INSTR_TIME_GET_MILLISEC(jitusage->inlining_counter); + + if (INSTR_TIME_GET_MILLISEC(jitusage->optimization_counter)) + entry->counters.jit_optimization_count++; + entry->counters.jit_optimization_time += INSTR_TIME_GET_MILLISEC(jitusage->optimization_counter); + + if (INSTR_TIME_GET_MILLISEC(jitusage->emission_counter)) + entry->counters.jit_emission_count++; + entry->counters.jit_emission_time += INSTR_TIME_GET_MILLISEC(jitusage->emission_counter); + } + + /* parallel worker counters */ + entry->counters.parallel_workers_to_launch += parallel_workers_to_launch; + entry->counters.parallel_workers_launched += parallel_workers_launched; + + SpinLockRelease(&entry->mutex); + } + +done: + LWLockRelease(pgss->lock); + + /* We postpone this clean-up until we're out of the lock */ + if (norm_query) + pfree(norm_query); + + if (info) + 
edbss_free_stmt_info(info); +} + +/* + * Reset statement statistics corresponding to userid, dbid, and queryid. + */ + +Datum +edb_stat_statements_reset(PG_FUNCTION_ARGS) +{ + Oid userid; + ArrayType *dbids_array; + Datum *dbids; + int dbids_len; + uint64 queryid; + bool minmax_only; + + userid = PG_GETARG_OID(0); + dbids_array = PG_GETARG_ARRAYTYPE_P(1); + queryid = (uint64) PG_GETARG_INT64(2); + minmax_only = PG_GETARG_BOOL(3); + + deconstruct_array_builtin(dbids_array, OIDOID, &dbids, NULL, &dbids_len); + + PG_RETURN_TIMESTAMPTZ(entry_reset(userid, dbids, dbids_len, queryid, minmax_only)); +} + +/* Number of output arguments (columns) for various API versions */ +#define PG_STAT_STATEMENTS_COLS_V1_0 54 +#define PG_STAT_STATEMENTS_COLS 54 /* maximum of above */ + +/* + * Retrieve statement statistics. + * + * The SQL API of this function has changed multiple times, and will likely + * do so again in future. To support the case where a newer version of this + * loadable module is being used with an old SQL declaration of the function, + * we continue to support the older API versions. For 1.2 and later, the + * expected API version is identified by embedding it in the C name of the + * function. Unfortunately we weren't bright enough to do that for 1.1. 
+ */ +Datum +edb_stat_statements(PG_FUNCTION_ARGS) +{ + bool showtext = PG_GETARG_BOOL(0); + + edb_stat_statements_internal(fcinfo, PGSS_V1_0, showtext); + + return (Datum) 0; +} + +/* Common code for all versions of edb_stat_statements() */ +static void +edb_stat_statements_internal(FunctionCallInfo fcinfo, + pgssVersion api_version, + bool showtext) +{ + ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; + Oid userid = GetUserId(); + bool is_allowed_role = false; + char *qbuffer = NULL; + Size qbuffer_size = 0; + Size extent = 0; + int gc_count = 0; + HASH_SEQ_STATUS hash_seq; + pgssEntry *entry; + + /* + * Superusers or roles with the privileges of pg_read_all_stats members + * are allowed + */ + is_allowed_role = has_privs_of_role(userid, ROLE_PG_READ_ALL_STATS); + + /* hash table must exist already */ + if (!pgss || !pgss_hash) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("edb_stat_statements must be loaded via \"shared_preload_libraries\""))); + + InitMaterializedSRF(fcinfo, 0); + + /* + * Check we have the expected number of output arguments. Aside from + * being a good safety check, we need a kluge here to detect API version + * 1.1, which was wedged into the code in an ill-considered way. + */ + switch (rsinfo->setDesc->natts) + { + case PG_STAT_STATEMENTS_COLS_V1_0: + if (api_version != PGSS_V1_0) + elog(ERROR, "incorrect number of output arguments"); + break; + default: + elog(ERROR, "incorrect number of output arguments"); + } + + /* + * We'd like to load the query text file (if needed) while not holding any + * lock on pgss->lock. In the worst case we'll have to do this again + * after we have the lock, but it's unlikely enough to make this a win + * despite occasional duplicated work. We need to reload if anybody + * writes to the file (either a retail qtext_store(), or a garbage + * collection) between this point and where we've gotten shared lock. 
If + * a qtext_store is actually in progress when we look, we might as well + * skip the speculative load entirely. + */ + if (showtext) + { + int n_writers; + + /* Take the mutex so we can examine variables */ + SpinLockAcquire(&pgss->mutex); + extent = pgss->extent; + n_writers = pgss->n_writers; + gc_count = pgss->gc_count; + SpinLockRelease(&pgss->mutex); + + /* No point in loading file now if there are active writers */ + if (n_writers == 0) + qbuffer = qtext_load_file(&qbuffer_size); + } + + /* + * Get shared lock, load or reload the query text file if we must, and + * iterate over the hashtable entries. + * + * With a large hash table, we might be holding the lock rather longer + * than one could wish. However, this only blocks creation of new hash + * table entries, and the larger the hash table the less likely that is to + * be needed. So we can hope this is okay. Perhaps someday we'll decide + * we need to partition the hash table to limit the time spent holding any + * one lock. + */ + LWLockAcquire(pgss->lock, LW_SHARED); + + if (showtext) + { + /* + * Here it is safe to examine extent and gc_count without taking the + * mutex. Note that although other processes might change + * pgss->extent just after we look at it, the strings they then write + * into the file cannot yet be referenced in the hashtable, so we + * don't care whether we see them or not. + * + * If qtext_load_file fails, we just press on; we'll return NULL for + * every query text. 
+ */ + if (qbuffer == NULL || + pgss->extent != extent || + pgss->gc_count != gc_count) + { + free(qbuffer); + qbuffer = qtext_load_file(&qbuffer_size); + } + } + + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + Datum values[PG_STAT_STATEMENTS_COLS]; + bool nulls[PG_STAT_STATEMENTS_COLS]; + int i = 0; + Counters tmp; + double stddev; + int64 queryid = entry->key.queryid; + TimestampTz stats_since; + TimestampTz minmax_stats_since; + + memset(values, 0, sizeof(values)); + memset(nulls, 0, sizeof(nulls)); + + values[i++] = ObjectIdGetDatum(entry->key.userid); + values[i++] = ObjectIdGetDatum(entry->key.dbid); + values[i++] = BoolGetDatum(entry->key.toplevel); + + if (is_allowed_role || entry->key.userid == userid) + { + values[i++] = Int64GetDatumFast(queryid); + + if (showtext) + { + char *qstr = qtext_fetch(entry->query_offset, + entry->query_len + entry->extras_len, + qbuffer, + qbuffer_size); + + if (qstr) + { + char *enc; + + enc = pg_any_to_server(qstr + entry->extras_len, + entry->query_len, + entry->encoding); + + values[i++] = CStringGetTextDatum(enc); + + // The "extras" Jsonb varlena datum + if (entry->extras_len > 0) + values[i++] = PointerGetDatum(qstr); + else + nulls[i++] = true; + + if (enc != qstr + entry->extras_len) + pfree(enc); + } + else + { + /* Just return a null if we fail to find the text */ + nulls[i++] = true; + + /* null extras */ + nulls[i++] = true; + } + } + else + { + /* Query text not requested */ + nulls[i++] = true; + + /* null extras */ + nulls[i++] = true; + } + } + else + { + /* Don't show queryid */ + nulls[i++] = true; + + /* + * Don't show query text, but hint as to the reason for not doing + * so if it was requested + */ + if (showtext) + values[i++] = CStringGetTextDatum(""); + else + nulls[i++] = true; + + /* null extras */ + nulls[i++] = true; + } + + if (memcmp(&entry->id, &zero_uuid, sizeof(zero_uuid)) == 0) + nulls[i++] = true; + else + values[i++] = 
UUIDPGetDatum(&entry->id); + + if (entry->stmt_type == 0) + nulls[i++] = true; + else + values[i++] = Int16GetDatum(entry->stmt_type); + + /* copy counters to a local variable to keep locking time short */ + SpinLockAcquire(&entry->mutex); + tmp = entry->counters; + stats_since = entry->stats_since; + minmax_stats_since = entry->minmax_stats_since; + SpinLockRelease(&entry->mutex); + + /* Skip entry if unexecuted (ie, it's a pending "sticky" entry) */ + if (IS_STICKY(tmp)) + continue; + + /* Note that we rely on PGSS_PLAN being 0 and PGSS_EXEC being 1. */ + for (int kind = 0; kind < PGSS_NUMKIND; kind++) + { + values[i++] = Int64GetDatumFast(tmp.calls[kind]); + values[i++] = Float8GetDatumFast(tmp.total_time[kind]); + values[i++] = Float8GetDatumFast(tmp.min_time[kind]); + values[i++] = Float8GetDatumFast(tmp.max_time[kind]); + values[i++] = Float8GetDatumFast(tmp.mean_time[kind]); + + /* + * Note we are calculating the population variance here, not + * the sample variance, as we have data for the whole + * population, so Bessel's correction is not used, and we + * don't divide by tmp.calls - 1. 
+ */ + if (tmp.calls[kind] > 1) + stddev = sqrt(tmp.sum_var_time[kind] / tmp.calls[kind]); + else + stddev = 0.0; + values[i++] = Float8GetDatumFast(stddev); + } + values[i++] = Int64GetDatumFast(tmp.rows); + values[i++] = Int64GetDatumFast(tmp.shared_blks_hit); + values[i++] = Int64GetDatumFast(tmp.shared_blks_read); + values[i++] = Int64GetDatumFast(tmp.shared_blks_dirtied); + values[i++] = Int64GetDatumFast(tmp.shared_blks_written); + values[i++] = Int64GetDatumFast(tmp.local_blks_hit); + values[i++] = Int64GetDatumFast(tmp.local_blks_read); + values[i++] = Int64GetDatumFast(tmp.local_blks_dirtied); + values[i++] = Int64GetDatumFast(tmp.local_blks_written); + values[i++] = Int64GetDatumFast(tmp.temp_blks_read); + values[i++] = Int64GetDatumFast(tmp.temp_blks_written); + values[i++] = Float8GetDatumFast(tmp.shared_blk_read_time); + values[i++] = Float8GetDatumFast(tmp.shared_blk_write_time); + values[i++] = Float8GetDatumFast(tmp.local_blk_read_time); + values[i++] = Float8GetDatumFast(tmp.local_blk_write_time); + values[i++] = Float8GetDatumFast(tmp.temp_blk_read_time); + values[i++] = Float8GetDatumFast(tmp.temp_blk_write_time); + { + char buf[256]; + Datum wal_bytes; + + values[i++] = Int64GetDatumFast(tmp.wal_records); + values[i++] = Int64GetDatumFast(tmp.wal_fpi); + + snprintf(buf, sizeof buf, UINT64_FORMAT, tmp.wal_bytes); + + /* Convert to numeric. 
*/ + wal_bytes = DirectFunctionCall3(numeric_in, + CStringGetDatum(buf), + ObjectIdGetDatum(0), + Int32GetDatum(-1)); + values[i++] = wal_bytes; + } + values[i++] = Int64GetDatumFast(tmp.jit_functions); + values[i++] = Float8GetDatumFast(tmp.jit_generation_time); + values[i++] = Int64GetDatumFast(tmp.jit_inlining_count); + values[i++] = Float8GetDatumFast(tmp.jit_inlining_time); + values[i++] = Int64GetDatumFast(tmp.jit_optimization_count); + values[i++] = Float8GetDatumFast(tmp.jit_optimization_time); + values[i++] = Int64GetDatumFast(tmp.jit_emission_count); + values[i++] = Float8GetDatumFast(tmp.jit_emission_time); + values[i++] = Int64GetDatumFast(tmp.jit_deform_count); + values[i++] = Float8GetDatumFast(tmp.jit_deform_time); + values[i++] = Int64GetDatumFast(tmp.parallel_workers_to_launch); + values[i++] = Int64GetDatumFast(tmp.parallel_workers_launched); + values[i++] = TimestampTzGetDatum(stats_since); + values[i++] = TimestampTzGetDatum(minmax_stats_since); + + Assert(i == (api_version == PGSS_V1_0 ? PG_STAT_STATEMENTS_COLS_V1_0 : + -1 /* fail if you forget to update this assert */ )); + + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); + } + + LWLockRelease(pgss->lock); + + free(qbuffer); +} + +/* Number of output arguments (columns) for edb_stat_statements_info */ +#define PG_STAT_STATEMENTS_INFO_COLS 2 + +/* + * Return statistics of edb_stat_statements. 
+ */ +Datum +edb_stat_statements_info(PG_FUNCTION_ARGS) +{ + pgssGlobalStats stats; + TupleDesc tupdesc; + Datum values[PG_STAT_STATEMENTS_INFO_COLS] = {0}; + bool nulls[PG_STAT_STATEMENTS_INFO_COLS] = {0}; + + if (!pgss || !pgss_hash) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("edb_stat_statements must be loaded via \"shared_preload_libraries\""))); + + /* Build a tuple descriptor for our result type */ + if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE) + elog(ERROR, "return type must be a row type"); + + /* Read global statistics for edb_stat_statements */ + SpinLockAcquire(&pgss->mutex); + stats = pgss->stats; + SpinLockRelease(&pgss->mutex); + + values[0] = Int64GetDatum(stats.dealloc); + values[1] = TimestampTzGetDatum(stats.stats_reset); + + PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls))); +} + +/* + * Convert uuid to bigint as queryid. + */ +Datum +edb_stat_queryid(PG_FUNCTION_ARGS) { + union { + pg_uuid_t uuid; + uint64 id; + } id; + id.uuid = *PG_GETARG_UUID_P(0); + return UInt64GetDatum(id.id); +} + +/* + * Estimate shared memory space needed. + */ +static Size +pgss_memsize(void) +{ + Size size; + + size = MAXALIGN(sizeof(pgssSharedState)); + size = add_size(size, hash_estimate_size(pgss_max, sizeof(pgssEntry))); + + return size; +} + +/* + * Allocate a new hashtable entry. + * caller must hold an exclusive lock on pgss->lock + * + * "query" need not be null-terminated; we rely on query_len instead + * + * If "sticky" is true, make the new entry artificially sticky so that it will + * probably still be there when the query finishes execution. We do this by + * giving it a median usage value rather than the normal value. (Strictly + * speaking, query strings are normalized on a best effort basis, though it + * would be difficult to demonstrate this even under artificial conditions.) 
+ * + * Note: despite needing exclusive lock, it's not an error for the target + * entry to already exist. This is because pgss_store releases and + * reacquires lock after failing to find a match; so someone else could + * have made the entry while we waited to get exclusive lock. + */ +static pgssEntry * +entry_alloc(pgssHashKey *key, Size query_offset, int query_len, int encoding, + bool sticky, pg_uuid_t *id, EdbStmtType stmt_type, int extras_len) +{ + pgssEntry *entry; + bool found; + + /* Make space if needed */ + while (hash_get_num_entries(pgss_hash) >= pgss_max) + entry_dealloc(); + + /* Find or create an entry with desired hash code */ + entry = (pgssEntry *) hash_search(pgss_hash, key, HASH_ENTER, &found); + + if (!found) + { + /* New entry, initialize it */ + + /* reset the statistics */ + memset(&entry->counters, 0, sizeof(Counters)); + /* set the appropriate initial usage count */ + entry->counters.usage = sticky ? pgss->cur_median_usage : USAGE_INIT; + /* re-initialize the mutex each time ... we assume no one using it */ + SpinLockInit(&entry->mutex); + /* ... and don't forget the query text metadata */ + Assert(query_len >= 0); + entry->query_offset = query_offset; + entry->query_len = query_len; + entry->encoding = encoding; + entry->stats_since = GetCurrentTimestamp(); + entry->minmax_stats_since = entry->stats_since; + if (id != NULL) + entry->id = *id; + entry->stmt_type = stmt_type; + entry->extras_len = extras_len; + } + + return entry; +} + +/* + * qsort comparator for sorting into increasing usage order + */ +static int +entry_cmp(const void *lhs, const void *rhs) +{ + double l_usage = (*(pgssEntry *const *) lhs)->counters.usage; + double r_usage = (*(pgssEntry *const *) rhs)->counters.usage; + + if (l_usage < r_usage) + return -1; + else if (l_usage > r_usage) + return +1; + else + return 0; +} + +/* + * Deallocate least-used entries. + * + * Caller must hold an exclusive lock on pgss->lock. 
+ */ +static void +entry_dealloc(void) +{ + HASH_SEQ_STATUS hash_seq; + pgssEntry **entries; + pgssEntry *entry; + int nvictims; + int i; + Size tottextlen; + int nvalidtexts; + + /* + * Sort entries by usage and deallocate USAGE_DEALLOC_PERCENT of them. + * While we're scanning the table, apply the decay factor to the usage + * values, and update the mean query length. + * + * Note that the mean query length is almost immediately obsolete, since + * we compute it before not after discarding the least-used entries. + * Hopefully, that doesn't affect the mean too much; it doesn't seem worth + * making two passes to get a more current result. Likewise, the new + * cur_median_usage includes the entries we're about to zap. + */ + + entries = palloc(hash_get_num_entries(pgss_hash) * sizeof(pgssEntry *)); + + i = 0; + tottextlen = 0; + nvalidtexts = 0; + + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + entries[i++] = entry; + /* "Sticky" entries get a different usage decay rate. */ + if (IS_STICKY(entry->counters)) + entry->counters.usage *= STICKY_DECREASE_FACTOR; + else + entry->counters.usage *= USAGE_DECREASE_FACTOR; + /* In the mean length computation, ignore dropped texts. 
*/ + if (entry->query_len >= 0) + { + tottextlen += entry->query_len + 1; + nvalidtexts++; + } + } + + /* Sort into increasing order by usage */ + qsort(entries, i, sizeof(pgssEntry *), entry_cmp); + + /* Record the (approximate) median usage */ + if (i > 0) + pgss->cur_median_usage = entries[i / 2]->counters.usage; + /* Record the mean query length */ + if (nvalidtexts > 0) + pgss->mean_query_len = tottextlen / nvalidtexts; + else + pgss->mean_query_len = ASSUMED_LENGTH_INIT; + + /* Now zap an appropriate fraction of lowest-usage entries */ + nvictims = Max(10, i * USAGE_DEALLOC_PERCENT / 100); + nvictims = Min(nvictims, i); + + for (i = 0; i < nvictims; i++) + { + hash_search(pgss_hash, &entries[i]->key, HASH_REMOVE, NULL); + } + + pfree(entries); + + /* Increment the number of times entries are deallocated */ + SpinLockAcquire(&pgss->mutex); + pgss->stats.dealloc += 1; + SpinLockRelease(&pgss->mutex); +} + +/* + * Given a query string (not necessarily null-terminated), allocate a new + * entry in the external query text file and store the string there. + * + * If successful, returns true, and stores the new entry's offset in the file + * into *query_offset. Also, if gc_count isn't NULL, *gc_count is set to the + * number of garbage collections that have occurred so far. + * + * On failure, returns false. + * + * At least a shared lock on pgss->lock must be held by the caller, so as + * to prevent a concurrent garbage collection. Share-lock-holding callers + * should pass a gc_count pointer to obtain the number of garbage collections, + * so that they can recheck the count after obtaining exclusive lock to + * detect whether a garbage collection occurred (and removed this entry). 
+ */ +static bool +qtext_store(const char *query, int query_len, + const Jsonb *extras, int extras_len, + Size *query_offset, int *gc_count) +{ + Size off; + int fd; + + /* + * We use a spinlock to protect extent/n_writers/gc_count, so that + * multiple processes may execute this function concurrently. + */ + SpinLockAcquire(&pgss->mutex); + off = pgss->extent; + pgss->extent += query_len + extras_len + 1; + pgss->n_writers++; + if (gc_count) + *gc_count = pgss->gc_count; + SpinLockRelease(&pgss->mutex); + + *query_offset = off; + + /* + * Don't allow the file to grow larger than what qtext_load_file can + * (theoretically) handle. This has been seen to be reachable on 32-bit + * platforms. + */ + if (unlikely(query_len + extras_len >= MaxAllocHugeSize - off)) + { + errno = EFBIG; /* not quite right, but it'll do */ + fd = -1; + goto error; + } + + /* Now write the data into the successfully-reserved part of the file */ + fd = OpenTransientFile(PGSS_TEXT_FILE, O_RDWR | O_CREAT | PG_BINARY); + if (fd < 0) + goto error; + + if (extras_len > 0 && pg_pwrite(fd, extras, extras_len, off) != extras_len) + goto error; + if (pg_pwrite(fd, query, query_len, off + extras_len) != query_len) + goto error; + if (pg_pwrite(fd, "\0", 1, off + extras_len + query_len) != 1) + goto error; + + CloseTransientFile(fd); + + /* Mark our write complete */ + SpinLockAcquire(&pgss->mutex); + pgss->n_writers--; + SpinLockRelease(&pgss->mutex); + + return true; + +error: + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_TEXT_FILE))); + + if (fd >= 0) + CloseTransientFile(fd); + + /* Mark our write complete */ + SpinLockAcquire(&pgss->mutex); + pgss->n_writers--; + SpinLockRelease(&pgss->mutex); + + return false; +} + +/* + * Read the external query text file into a malloc'd buffer. + * + * Returns NULL (without throwing an error) if unable to read, eg + * file not there or insufficient memory. 
+ * + * On success, the buffer size is also returned into *buffer_size. + * + * This can be called without any lock on pgss->lock, but in that case + * the caller is responsible for verifying that the result is sane. + */ +static char * +qtext_load_file(Size *buffer_size) +{ + char *buf; + int fd; + struct stat stat; + Size nread; + + fd = OpenTransientFile(PGSS_TEXT_FILE, O_RDONLY | PG_BINARY); + if (fd < 0) + { + if (errno != ENOENT) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not read file \"%s\": %m", + PGSS_TEXT_FILE))); + return NULL; + } + + /* Get file length */ + if (fstat(fd, &stat)) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not stat file \"%s\": %m", + PGSS_TEXT_FILE))); + CloseTransientFile(fd); + return NULL; + } + + /* Allocate buffer; beware that off_t might be wider than size_t */ + if (stat.st_size <= MaxAllocHugeSize) + buf = (char *) malloc(stat.st_size); + else + buf = NULL; + if (buf == NULL) + { + ereport(LOG, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("out of memory"), + errdetail("Could not allocate enough memory to read file \"%s\".", + PGSS_TEXT_FILE))); + CloseTransientFile(fd); + return NULL; + } + + /* + * OK, slurp in the file. Windows fails if we try to read more than + * INT_MAX bytes at once, and other platforms might not like that either, + * so read a very large file in 1GB segments. + */ + nread = 0; + while (nread < stat.st_size) + { + int toread = Min(1024 * 1024 * 1024, stat.st_size - nread); + + /* + * If we get a short read and errno doesn't get set, the reason is + * probably that garbage collection truncated the file since we did + * the fstat(), so we don't log a complaint --- but we don't return + * the data, either, since it's most likely corrupt due to concurrent + * writes from garbage collection. 
+ */ + errno = 0; + if (read(fd, buf + nread, toread) != toread) + { + if (errno) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not read file \"%s\": %m", + PGSS_TEXT_FILE))); + free(buf); + CloseTransientFile(fd); + return NULL; + } + nread += toread; + } + + if (CloseTransientFile(fd) != 0) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not close file \"%s\": %m", PGSS_TEXT_FILE))); + + *buffer_size = nread; + return buf; +} + +/* + * Locate a query text in the file image previously read by qtext_load_file(). + * + * We validate the given offset/length, and return NULL if bogus. Otherwise, + * the result points to a null-terminated string within the buffer. + */ +static char * +qtext_fetch(Size query_offset, int query_len, + char *buffer, Size buffer_size) +{ + /* File read failed? */ + if (buffer == NULL) + return NULL; + /* Bogus offset/length? */ + if (query_len < 0 || + query_offset + query_len >= buffer_size) + return NULL; + /* As a further sanity check, make sure there's a trailing null */ + if (buffer[query_offset + query_len] != '\0') + return NULL; + /* Looks OK */ + return buffer + query_offset; +} + +/* + * Do we need to garbage-collect the external query text file? + * + * Caller should hold at least a shared lock on pgss->lock. + */ +static bool +need_gc_qtexts(void) +{ + Size extent; + + /* Read shared extent pointer */ + SpinLockAcquire(&pgss->mutex); + extent = pgss->extent; + SpinLockRelease(&pgss->mutex); + + /* + * Don't proceed if file does not exceed 512 bytes per possible entry. + * + * Here and in the next test, 32-bit machines have overflow hazards if + * pgss_max and/or mean_query_len are large. Force the multiplications + * and comparisons to be done in uint64 arithmetic to forestall trouble. + */ + if ((uint64) extent < (uint64) 512 * pgss_max) + return false; + + /* + * Don't proceed if file is less than about 50% bloat. 
Nothing can or + * should be done in the event of unusually large query texts accounting + * for file's large size. We go to the trouble of maintaining the mean + * query length in order to prevent garbage collection from thrashing + * uselessly. + */ + if ((uint64) extent < (uint64) pgss->mean_query_len * pgss_max * 2) + return false; + + return true; +} + +/* + * Garbage-collect orphaned query texts in external file. + * + * This won't be called often in the typical case, since it's likely that + * there won't be too much churn, and besides, a similar compaction process + * occurs when serializing to disk at shutdown or as part of resetting. + * Despite this, it seems prudent to plan for the edge case where the file + * becomes unreasonably large, with no other method of compaction likely to + * occur in the foreseeable future. + * + * The caller must hold an exclusive lock on pgss->lock. + * + * At the first sign of trouble we unlink the query text file to get a clean + * slate (although existing statistics are retained), rather than risk + * thrashing by allowing the same problem case to recur indefinitely. + */ +static void +gc_qtexts(void) +{ + char *qbuffer; + Size qbuffer_size; + FILE *qfile = NULL; + HASH_SEQ_STATUS hash_seq; + pgssEntry *entry; + Size extent; + int nentries; + + /* + * When called from pgss_store, some other session might have proceeded + * with garbage collection in the no-lock-held interim of lock strength + * escalation. Check once more that this is actually necessary. + */ + if (!need_gc_qtexts()) + return; + + /* + * Load the old texts file. If we fail (out of memory, for instance), + * invalidate query texts. Hopefully this is rare. It might seem better + * to leave things alone on an OOM failure, but the problem is that the + * file is only going to get bigger; hoping for a future non-OOM result is + * risky and can easily lead to complete denial of service. 
+ */ + qbuffer = qtext_load_file(&qbuffer_size); + if (qbuffer == NULL) + goto gc_fail; + + /* + * We overwrite the query texts file in place, so as to reduce the risk of + * an out-of-disk-space failure. Since the file is guaranteed not to get + * larger, this should always work on traditional filesystems; though we + * could still lose on copy-on-write filesystems. + */ + qfile = AllocateFile(PGSS_TEXT_FILE, PG_BINARY_W); + if (qfile == NULL) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_TEXT_FILE))); + goto gc_fail; + } + + extent = 0; + nentries = 0; + + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + int query_len = entry->query_len + entry->extras_len; + char *qry = qtext_fetch(entry->query_offset, + query_len, + qbuffer, + qbuffer_size); + + if (qry == NULL) + { + /* Trouble ... drop the text */ + entry->query_offset = 0; + entry->query_len = -1; + entry->extras_len = 0; + /* entry will not be counted in mean query length computation */ + continue; + } + + if (fwrite(qry, 1, query_len + 1, qfile) != query_len + 1) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_TEXT_FILE))); + hash_seq_term(&hash_seq); + goto gc_fail; + } + + entry->query_offset = extent; + extent += query_len + 1; + nentries++; + } + + /* + * Truncate away any now-unused space. If this fails for some odd reason, + * we log it, but there's no need to fail. 
+ */ + if (ftruncate(fileno(qfile), extent) != 0) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not truncate file \"%s\": %m", + PGSS_TEXT_FILE))); + + if (FreeFile(qfile)) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not write file \"%s\": %m", + PGSS_TEXT_FILE))); + qfile = NULL; + goto gc_fail; + } + + elog(DEBUG1, "pgss gc of queries file shrunk size from %zu to %zu", + pgss->extent, extent); + + /* Reset the shared extent pointer */ + pgss->extent = extent; + + /* + * Also update the mean query length, to be sure that need_gc_qtexts() + * won't still think we have a problem. + */ + if (nentries > 0) + pgss->mean_query_len = extent / nentries; + else + pgss->mean_query_len = ASSUMED_LENGTH_INIT; + + free(qbuffer); + + /* + * OK, count a garbage collection cycle. (Note: even though we have + * exclusive lock on pgss->lock, we must take pgss->mutex for this, since + * other processes may examine gc_count while holding only the mutex. + * Also, we have to advance the count *after* we've rewritten the file, + * else other processes might not realize they read a stale file.) + */ + record_gc_qtexts(); + + return; + +gc_fail: + /* clean up resources */ + if (qfile) + FreeFile(qfile); + free(qbuffer); + + /* + * Since the contents of the external file are now uncertain, mark all + * hashtable entries as having invalid texts. 
+ */ + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + entry->query_offset = 0; + entry->query_len = -1; + entry->extras_len = 0; + } + + /* + * Destroy the query text file and create a new, empty one + */ + (void) unlink(PGSS_TEXT_FILE); + qfile = AllocateFile(PGSS_TEXT_FILE, PG_BINARY_W); + if (qfile == NULL) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not recreate file \"%s\": %m", + PGSS_TEXT_FILE))); + else + FreeFile(qfile); + + /* Reset the shared extent pointer */ + pgss->extent = 0; + + /* Reset mean_query_len to match the new state */ + pgss->mean_query_len = ASSUMED_LENGTH_INIT; + + /* + * Bump the GC count even though we failed. + * + * This is needed to make concurrent readers of file without any lock on + * pgss->lock notice existence of new version of file. Once readers + * subsequently observe a change in GC count with pgss->lock held, that + * forces a safe reopen of file. Writers also require that we bump here, + * of course. (As required by locking protocol, readers and writers don't + * trust earlier file contents until gc_count is found unchanged after + * pgss->lock acquired in shared or exclusive mode respectively.) + */ + record_gc_qtexts(); +} + +#define SINGLE_ENTRY_RESET(e) \ +if (e) { \ + if (minmax_only) { \ + /* When requested reset only min/max statistics of an entry */ \ + for (int kind = 0; kind < PGSS_NUMKIND; kind++) \ + { \ + e->counters.max_time[kind] = 0; \ + e->counters.min_time[kind] = 0; \ + } \ + e->minmax_stats_since = stats_reset; \ + } \ + else \ + { \ + /* Remove the key otherwise */ \ + hash_search(pgss_hash, &e->key, HASH_REMOVE, NULL); \ + num_remove++; \ + } \ +} + +/* + * Reset entries corresponding to parameters passed. 
+ */ +static TimestampTz +entry_reset(Oid userid, const Datum *dbids, int dbids_len, uint64 queryid, bool minmax_only) +{ + HASH_SEQ_STATUS hash_seq; + pgssEntry *entry; + FILE *qfile; + long num_entries; + long num_remove = 0; + pgssHashKey key; + TimestampTz stats_reset; + + if (!pgss || !pgss_hash) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("edb_stat_statements must be loaded via \"shared_preload_libraries\""))); + + LWLockAcquire(pgss->lock, LW_EXCLUSIVE); + num_entries = hash_get_num_entries(pgss_hash); + + stats_reset = GetCurrentTimestamp(); + + if (userid != 0 && dbids_len == 1 && queryid != UINT64CONST(0)) + { + /* If all the parameters are available, use the fast path. */ + memset(&key, 0, sizeof(pgssHashKey)); + key.userid = userid; + key.dbid = DatumGetObjectId(dbids[0]); + key.queryid = queryid; + + /* + * Reset the entry if it exists, starting with the non-top-level + * entry. + */ + key.toplevel = false; + entry = (pgssEntry *) hash_search(pgss_hash, &key, HASH_FIND, NULL); + + SINGLE_ENTRY_RESET(entry); + + /* Also reset the top-level entry if it exists. */ + key.toplevel = true; + entry = (pgssEntry *) hash_search(pgss_hash, &key, HASH_FIND, NULL); + + SINGLE_ENTRY_RESET(entry); + } + else if (userid != 0 || dbids_len > 0 || queryid != UINT64CONST(0)) + { + /* Reset entries corresponding to valid parameters. 
*/ + hash_seq_init(&hash_seq, pgss_hash); + if (dbids_len > 0) { + while ((entry = hash_seq_search(&hash_seq)) != NULL) { + for (int i = 0; i < dbids_len; i++) { + Oid dbid = DatumGetObjectId(dbids[i]); + if ((!userid || entry->key.userid == userid) && + (entry->key.dbid == dbid) && + (!queryid || entry->key.queryid == queryid)) { + SINGLE_ENTRY_RESET(entry); + break; + } + } + } + } else { + while ((entry = hash_seq_search(&hash_seq)) != NULL) { + if ((!userid || entry->key.userid == userid) && + (!queryid || entry->key.queryid == queryid)) { + SINGLE_ENTRY_RESET(entry); + } + } + } + } + else + { + /* Reset all entries. */ + hash_seq_init(&hash_seq, pgss_hash); + while ((entry = hash_seq_search(&hash_seq)) != NULL) + { + SINGLE_ENTRY_RESET(entry); + } + } + + /* All entries are removed? */ + if (num_entries != num_remove) + goto release_lock; + + /* + * Reset global statistics for edb_stat_statements since all entries are + * removed. + */ + SpinLockAcquire(&pgss->mutex); + pgss->stats.dealloc = 0; + pgss->stats.stats_reset = stats_reset; + SpinLockRelease(&pgss->mutex); + + /* + * Write new empty query file, perhaps even creating a new one to recover + * if the file was missing. + */ + qfile = AllocateFile(PGSS_TEXT_FILE, PG_BINARY_W); + if (qfile == NULL) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not create file \"%s\": %m", + PGSS_TEXT_FILE))); + goto done; + } + + /* If ftruncate fails, log it, but it's not a fatal problem */ + if (ftruncate(fileno(qfile), 0) != 0) + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not truncate file \"%s\": %m", + PGSS_TEXT_FILE))); + + FreeFile(qfile); + +done: + pgss->extent = 0; + /* This counts as a query text garbage collection for our purposes */ + record_gc_qtexts(); + +release_lock: + LWLockRelease(pgss->lock); + + return stats_reset; +} + +/* + * Generate a normalized version of the query string that will be used to + * represent all similar queries. 
+ * + * Note that the normalized representation may well vary depending on + * just which "equivalent" query is used to create the hashtable entry. + * We assume this is OK. + * + * If query_loc > 0, then "query" has been advanced by that much compared to + * the original string start, so we need to translate the provided locations + * to compensate. (This lets us avoid re-scanning statements before the one + * of interest, so it's worth doing.) + * + * *query_len_p contains the input string length, and is updated with + * the result string length on exit. The resulting string might be longer + * or shorter depending on what happens with replacement of constants. + * + * Returns a palloc'd string. + */ +static char * +generate_normalized_query(JumbleState *jstate, const char *query, + int query_loc, int *query_len_p) +{ + char *norm_query; + int query_len = *query_len_p; + int i, + norm_query_buflen, /* Space allowed for norm_query */ + len_to_wrt, /* Length (in bytes) to write */ + quer_loc = 0, /* Source query byte location */ + n_quer_loc = 0, /* Normalized query byte location */ + last_off = 0, /* Offset from start for previous tok */ + last_tok_len = 0; /* Length (in bytes) of that tok */ + + /* + * Get constants' lengths (core system only gives us locations). Note + * this also ensures the items are sorted by location. + */ + fill_in_constant_lengths(jstate, query, query_loc); + + /* + * Allow for $n symbols to be longer than the constants they replace. + * Constants must take at least one byte in text form, while a $n symbol + * certainly isn't more than 11 bytes, even if n reaches INT_MAX. We + * could refine that limit based on the max value of n for the current + * query, but it hardly seems worth any extra effort to do so. 
+ */ + norm_query_buflen = query_len + jstate->clocations_count * 10; + + /* Allocate result buffer */ + norm_query = palloc(norm_query_buflen + 1); + + for (i = 0; i < jstate->clocations_count; i++) + { + int off, /* Offset from start for cur tok */ + tok_len; /* Length (in bytes) of that tok */ + + off = jstate->clocations[i].location; + /* Adjust recorded location if we're dealing with partial string */ + off -= query_loc; + + tok_len = jstate->clocations[i].length; + + if (tok_len < 0) + continue; /* ignore any duplicates */ + + /* Copy next chunk (what precedes the next constant) */ + len_to_wrt = off - last_off; + len_to_wrt -= last_tok_len; + + Assert(len_to_wrt >= 0); + memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt); + n_quer_loc += len_to_wrt; + + /* And insert a param symbol in place of the constant token */ + n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d", + i + 1 + jstate->highest_extern_param_id); + + quer_loc = off + tok_len; + last_off = off; + last_tok_len = tok_len; + } + + /* + * We've copied up until the last ignorable constant. Copy over the + * remaining bytes of the original query string. + */ + len_to_wrt = query_len - quer_loc; + + Assert(len_to_wrt >= 0); + memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt); + n_quer_loc += len_to_wrt; + + Assert(n_quer_loc <= norm_query_buflen); + norm_query[n_quer_loc] = '\0'; + + *query_len_p = n_quer_loc; + return norm_query; +} + +/* + * Given a valid SQL string and an array of constant-location records, + * fill in the textual lengths of those constants. + * + * The constants may use any allowed constant syntax, such as float literals, + * bit-strings, single-quoted strings and dollar-quoted strings. This is + * accomplished by using the public API for the core scanner. + * + * It is the caller's job to ensure that the string is a valid SQL statement + * with constants at the indicated locations. 
Since in practice the string + * has already been parsed, and the locations that the caller provides will + * have originated from within the authoritative parser, this should not be + * a problem. + * + * Duplicate constant pointers are possible, and will have their lengths + * marked as '-1', so that they are later ignored. (Actually, we assume the + * lengths were initialized as -1 to start with, and don't change them here.) + * + * If query_loc > 0, then "query" has been advanced by that much compared to + * the original string start, so we need to translate the provided locations + * to compensate. (This lets us avoid re-scanning statements before the one + * of interest, so it's worth doing.) + * + * N.B. There is an assumption that a '-' character at a Const location begins + * a negative numeric constant. This precludes there ever being another + * reason for a constant to start with a '-'. + */ +static void +fill_in_constant_lengths(JumbleState *jstate, const char *query, + int query_loc) +{ + LocationLen *locs; + core_yyscan_t yyscanner; + core_yy_extra_type yyextra; + core_YYSTYPE yylval; + YYLTYPE yylloc; + int last_loc = -1; + int i; + + /* + * Sort the records by location so that we can process them in order while + * scanning the query text. 
+ */ + if (jstate->clocations_count > 1) + qsort(jstate->clocations, jstate->clocations_count, + sizeof(LocationLen), comp_location); + locs = jstate->clocations; + + /* initialize the flex scanner --- should match raw_parser() */ + yyscanner = scanner_init(query, + &yyextra, + &ScanKeywords, + ScanKeywordTokens); + + /* we don't want to re-emit any escape string warnings */ + yyextra.escape_string_warning = false; + + /* Search for each constant, in sequence */ + for (i = 0; i < jstate->clocations_count; i++) + { + int loc = locs[i].location; + int tok; + + /* Adjust recorded location if we're dealing with partial string */ + loc -= query_loc; + + Assert(loc >= 0); + + if (loc <= last_loc) + continue; /* Duplicate constant, ignore */ + + /* Lex tokens until we find the desired constant */ + for (;;) + { + tok = core_yylex(&yylval, &yylloc, yyscanner); + + /* We should not hit end-of-string, but if we do, behave sanely */ + if (tok == 0) + break; /* out of inner for-loop */ + + /* + * We should find the token position exactly, but if we somehow + * run past it, work with that. + */ + if (yylloc >= loc) + { + if (query[loc] == '-') + { + /* + * It's a negative value - this is the one and only case + * where we replace more than a single token. + * + * Do not compensate for the core system's special-case + * adjustment of location to that of the leading '-' + * operator in the event of a negative constant. It is + * also useful for our purposes to start from the minus + * symbol. In this way, queries like "select * from foo + * where bar = 1" and "select * from foo where bar = -2" + * will have identical normalized query strings. + */ + tok = core_yylex(&yylval, &yylloc, yyscanner); + if (tok == 0) + break; /* out of inner for-loop */ + } + + /* + * We now rely on the assumption that flex has placed a zero + * byte after the text of the current token in scanbuf. 
+ */ + locs[i].length = strlen(yyextra.scanbuf + loc); + break; /* out of inner for-loop */ + } + } + + /* If we hit end-of-string, give up, leaving remaining lengths -1 */ + if (tok == 0) + break; + + last_loc = loc; + } + + scanner_finish(yyscanner); +} + +/* + * comp_location: comparator for qsorting LocationLen structs by location + */ +static int +comp_location(const void *a, const void *b) +{ + int l = ((const LocationLen *) a)->location; + int r = ((const LocationLen *) b)->location; + +#if PG_VERSION_NUM >= 170000 + return pg_cmp_s32(l, r); +#else + if (l < r) + return -1; + else if (l > r) + return +1; + else + return 0; +#endif +} diff --git a/edb_stat_statements/edb_stat_statements.control b/edb_stat_statements/edb_stat_statements.control new file mode 100644 index 00000000000..4884103987f --- /dev/null +++ b/edb_stat_statements/edb_stat_statements.control @@ -0,0 +1,5 @@ +# edb_stat_statements extension +comment = 'track planning and execution statistics of all EdgeDB queries executed' +default_version = '1.0' +module_pathname = '$libdir/edb_stat_statements' +relocatable = true diff --git a/edb_stat_statements/expected/cleanup.out b/edb_stat_statements/expected/cleanup.out new file mode 100644 index 00000000000..03e40380b87 --- /dev/null +++ b/edb_stat_statements/expected/cleanup.out @@ -0,0 +1 @@ +DROP EXTENSION edb_stat_statements; diff --git a/edb_stat_statements/expected/cursors.out b/edb_stat_statements/expected/cursors.out new file mode 100644 index 00000000000..677ab63a506 --- /dev/null +++ b/edb_stat_statements/expected/cursors.out @@ -0,0 +1,70 @@ +-- +-- Cursors +-- +-- These tests require track_utility to be enabled. +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- DECLARE +-- SELECT is normalized. 
+DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 1; +CLOSE cursor_stats_1; +DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 2; +CLOSE cursor_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------- + 2 | 0 | CLOSE cursor_stats_1 + 2 | 0 | DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT $1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- FETCH +BEGIN; +DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 2; +DECLARE cursor_stats_2 CURSOR WITH HOLD FOR SELECT 3; +FETCH 1 IN cursor_stats_1; + ?column? +---------- + 2 +(1 row) + +FETCH 1 IN cursor_stats_2; + ?column? +---------- + 3 +(1 row) + +CLOSE cursor_stats_1; +CLOSE cursor_stats_2; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------- + 1 | 0 | BEGIN + 1 | 0 | CLOSE cursor_stats_1 + 1 | 0 | CLOSE cursor_stats_2 + 1 | 0 | COMMIT + 1 | 0 | DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT $1 + 1 | 0 | DECLARE cursor_stats_2 CURSOR WITH HOLD FOR SELECT $1 + 1 | 1 | FETCH 1 IN cursor_stats_1 + 1 | 1 | FETCH 1 IN cursor_stats_2 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(9 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/dml.out.17 b/edb_stat_statements/expected/dml.out.17 new file mode 100644 index 00000000000..c79f836350e --- /dev/null +++ b/edb_stat_statements/expected/dml.out.17 @@ -0,0 +1,174 @@ +-- +-- DMLs on test table +-- +SET edb_stat_statements.track_utility = FALSE; +CREATE TEMP TABLE pgss_dml_tab (a int, b char(20)); +INSERT INTO pgss_dml_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_dml_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM 
pgss_dml_tab WHERE a > 9; +-- explicit transaction +BEGIN; +UPDATE pgss_dml_tab SET b = '111' WHERE a = 1 ; +COMMIT; +BEGIN \; +UPDATE pgss_dml_tab SET b = '222' WHERE a = 2 \; +COMMIT ; +UPDATE pgss_dml_tab SET b = '333' WHERE a = 3 \; +UPDATE pgss_dml_tab SET b = '444' WHERE a = 4 ; +BEGIN \; +UPDATE pgss_dml_tab SET b = '555' WHERE a = 5 \; +UPDATE pgss_dml_tab SET b = '666' WHERE a = 6 \; +COMMIT ; +-- many INSERT values +INSERT INTO pgss_dml_tab (a, b) VALUES (1, 'a'), (2, 'b'), (3, 'c'); +-- SELECT with constants +SELECT * FROM pgss_dml_tab WHERE a > 5 ORDER BY a ; + a | b +---+---------------------- + 6 | 666 + 7 | aaa + 8 | bbb + 9 | bbb +(4 rows) + +SELECT * + FROM pgss_dml_tab + WHERE a > 9 + ORDER BY a ; + a | b +---+--- +(0 rows) + +-- these two need to be done on a different table +-- SELECT without constants +SELECT * FROM pgss_dml_tab ORDER BY a; + a | b +---+---------------------- + 1 | a + 1 | 111 + 2 | b + 2 | 222 + 3 | c + 3 | 333 + 4 | 444 + 5 | 555 + 6 | 666 + 7 | aaa + 8 | bbb + 9 | bbb +(12 rows) + +-- SELECT with IN clause +SELECT * FROM pgss_dml_tab WHERE a IN (1, 2, 3, 4, 5); + a | b +---+---------------------- + 1 | 111 + 2 | 222 + 3 | 333 + 4 | 444 + 5 | 555 + 1 | a + 2 | b + 3 | c +(8 rows) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------- + 1 | 1 | DELETE FROM pgss_dml_tab WHERE a > $1 + 1 | 3 | INSERT INTO pgss_dml_tab (a, b) VALUES ($1, $2), ($3, $4), ($5, $6) + 1 | 10 | INSERT INTO pgss_dml_tab VALUES(generate_series($1, $2), $3) + 1 | 12 | SELECT * FROM pgss_dml_tab ORDER BY a + 2 | 4 | SELECT * FROM pgss_dml_tab WHERE a > $1 ORDER BY a + 1 | 8 | SELECT * FROM pgss_dml_tab WHERE a IN ($1, $2, $3, $4, $5) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET edb_stat_statements.track_utility = FALSE + 6 | 6 | UPDATE pgss_dml_tab SET b = $1 WHERE a = $2 + 1 | 3 | UPDATE 
pgss_dml_tab SET b = $1 WHERE a > $2 +(10 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- MERGE +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = st.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED AND length(st.b) > 1 THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a, b) VALUES (0, NULL); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT VALUES (0, NULL); -- same as above +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (b, a) VALUES (NULL, 0); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a) VALUES (0); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DELETE; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DO NOTHING; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN NOT MATCHED THEN DO NOTHING; +DROP TABLE pgss_dml_tab; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------------------------------------------- + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED AND length(st.b) > $2 THEN UPDATE SET b = pgss_dml_tab.b || st.a::text + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND 
st.a >= $1)+ + | | WHEN MATCHED THEN DELETE + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN DO NOTHING + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN UPDATE SET b = pgss_dml_tab.b || st.a::text + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN UPDATE SET b = st.b || st.a::text + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN NOT MATCHED THEN DO NOTHING + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (a) VALUES ($1) + 2 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (a, b) VALUES ($1, $2) + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (b, a) VALUES ($1, $2) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(10 rows) + +-- check that [temp] table relation extensions are tracked as writes +CREATE TABLE pgss_extend_tab (a int, b text); +CREATE TEMP TABLE pgss_extend_temp_tab (a int, b text); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +INSERT INTO pgss_extend_tab (a, b) SELECT generate_series(1, 1000), 'something'; +INSERT INTO pgss_extend_temp_tab (a, b) SELECT generate_series(1, 1000), 'something'; +WITH sizes AS ( + SELECT + pg_relation_size('pgss_extend_tab') / current_setting('block_size')::int8 AS rel_size, + pg_relation_size('pgss_extend_temp_tab') / current_setting('block_size')::int8 AS temp_rel_size +) +SELECT + SUM(local_blks_written) >= (SELECT temp_rel_size FROM sizes) AS temp_written_ok, + SUM(local_blks_dirtied) >= (SELECT temp_rel_size FROM sizes) AS temp_dirtied_ok, + SUM(shared_blks_written) >= (SELECT rel_size FROM 
sizes) AS written_ok, + SUM(shared_blks_dirtied) >= (SELECT rel_size FROM sizes) AS dirtied_ok +FROM edb_stat_statements; + temp_written_ok | temp_dirtied_ok | written_ok | dirtied_ok +-----------------+-----------------+------------+------------ + t | t | t | t +(1 row) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/dml.out.18 b/edb_stat_statements/expected/dml.out.18 new file mode 100644 index 00000000000..68ffd4bd459 --- /dev/null +++ b/edb_stat_statements/expected/dml.out.18 @@ -0,0 +1,174 @@ +-- +-- DMLs on test table +-- +SET edb_stat_statements.track_utility = FALSE; +CREATE TEMP TABLE pgss_dml_tab (a int, b char(20)); +INSERT INTO pgss_dml_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_dml_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM pgss_dml_tab WHERE a > 9; +-- explicit transaction +BEGIN; +UPDATE pgss_dml_tab SET b = '111' WHERE a = 1 ; +COMMIT; +BEGIN \; +UPDATE pgss_dml_tab SET b = '222' WHERE a = 2 \; +COMMIT ; +UPDATE pgss_dml_tab SET b = '333' WHERE a = 3 \; +UPDATE pgss_dml_tab SET b = '444' WHERE a = 4 ; +BEGIN \; +UPDATE pgss_dml_tab SET b = '555' WHERE a = 5 \; +UPDATE pgss_dml_tab SET b = '666' WHERE a = 6 \; +COMMIT ; +-- many INSERT values +INSERT INTO pgss_dml_tab (a, b) VALUES (1, 'a'), (2, 'b'), (3, 'c'); +-- SELECT with constants +SELECT * FROM pgss_dml_tab WHERE a > 5 ORDER BY a ; + a | b +---+---------------------- + 6 | 666 + 7 | aaa + 8 | bbb + 9 | bbb +(4 rows) + +SELECT * + FROM pgss_dml_tab + WHERE a > 9 + ORDER BY a ; + a | b +---+--- +(0 rows) + +-- these two need to be done on a different table +-- SELECT without constants +SELECT * FROM pgss_dml_tab ORDER BY a; + a | b +---+---------------------- + 1 | a + 1 | 111 + 2 | b + 2 | 222 + 3 | c + 3 | 333 + 4 | 444 + 5 | 555 + 6 | 666 + 7 | aaa + 8 | bbb + 9 | bbb +(12 rows) + +-- SELECT with IN clause +SELECT * FROM pgss_dml_tab WHERE a IN (1, 2, 3, 4, 5); + a | b +---+---------------------- + 1 | 111 
+ 2 | 222 + 3 | 333 + 4 | 444 + 5 | 555 + 1 | a + 2 | b + 3 | c +(8 rows) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------- + 1 | 1 | DELETE FROM pgss_dml_tab WHERE a > $1 + 1 | 3 | INSERT INTO pgss_dml_tab (a, b) VALUES ($1, $2), ($3, $4), ($5, $6) + 1 | 10 | INSERT INTO pgss_dml_tab VALUES(generate_series($1, $2), $3) + 1 | 12 | SELECT * FROM pgss_dml_tab ORDER BY a + 2 | 4 | SELECT * FROM pgss_dml_tab WHERE a > $1 ORDER BY a + 1 | 8 | SELECT * FROM pgss_dml_tab WHERE a IN ($1, $2, $3, $4, $5) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET edb_stat_statements.track_utility = $1 + 6 | 6 | UPDATE pgss_dml_tab SET b = $1 WHERE a = $2 + 1 | 3 | UPDATE pgss_dml_tab SET b = $1 WHERE a > $2 +(10 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- MERGE +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = st.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED AND length(st.b) > 1 THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a, b) VALUES (0, NULL); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT VALUES (0, NULL); -- same as above +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (b, a) VALUES (NULL, 0); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a) VALUES (0); +MERGE INTO pgss_dml_tab USING 
pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DELETE; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DO NOTHING; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN NOT MATCHED THEN DO NOTHING; +DROP TABLE pgss_dml_tab; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------------------------------------------- + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED AND length(st.b) > $2 THEN UPDATE SET b = pgss_dml_tab.b || st.a::text + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN DELETE + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN DO NOTHING + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN UPDATE SET b = pgss_dml_tab.b || st.a::text + 1 | 6 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN MATCHED THEN UPDATE SET b = st.b || st.a::text + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= $1)+ + | | WHEN NOT MATCHED THEN DO NOTHING + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (a) VALUES ($1) + 2 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (a, b) VALUES ($1, $2) + 1 | 0 | MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + + | | WHEN NOT MATCHED THEN INSERT (b, a) VALUES ($1, $2) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(10 rows) + +-- check that [temp] 
table relation extensions are tracked as writes +CREATE TABLE pgss_extend_tab (a int, b text); +CREATE TEMP TABLE pgss_extend_temp_tab (a int, b text); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +INSERT INTO pgss_extend_tab (a, b) SELECT generate_series(1, 1000), 'something'; +INSERT INTO pgss_extend_temp_tab (a, b) SELECT generate_series(1, 1000), 'something'; +WITH sizes AS ( + SELECT + pg_relation_size('pgss_extend_tab') / current_setting('block_size')::int8 AS rel_size, + pg_relation_size('pgss_extend_temp_tab') / current_setting('block_size')::int8 AS temp_rel_size +) +SELECT + SUM(local_blks_written) >= (SELECT temp_rel_size FROM sizes) AS temp_written_ok, + SUM(local_blks_dirtied) >= (SELECT temp_rel_size FROM sizes) AS temp_dirtied_ok, + SUM(shared_blks_written) >= (SELECT rel_size FROM sizes) AS written_ok, + SUM(shared_blks_dirtied) >= (SELECT rel_size FROM sizes) AS dirtied_ok +FROM edb_stat_statements; + temp_written_ok | temp_dirtied_ok | written_ok | dirtied_ok +-----------------+-----------------+------------+------------ + t | t | t | t +(1 row) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/entry_timestamp.out b/edb_stat_statements/expected/entry_timestamp.out new file mode 100644 index 00000000000..278dede997d --- /dev/null +++ b/edb_stat_statements/expected/entry_timestamp.out @@ -0,0 +1,159 @@ +-- +-- statement timestamps +-- +-- planning time is needed during tests +SET edb_stat_statements.track_planning = TRUE; +SELECT 1 AS "STMTTS1"; + STMTTS1 +--------- + 1 +(1 row) + +SELECT now() AS ref_ts \gset +SELECT 1,2 AS "STMTTS2"; + ?column? | STMTTS2 +----------+--------- + 1 | 2 +(1 row) + +SELECT stats_since >= :'ref_ts', count(*) FROM edb_stat_statements +WHERE query LIKE '%STMTTS%' +GROUP BY stats_since >= :'ref_ts' +ORDER BY stats_since >= :'ref_ts'; + ?column? 
| count +----------+------- + f | 1 + t | 1 +(2 rows) + +SELECT now() AS ref_ts \gset +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_stats_since_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + total | minmax_plan_zero | minmax_exec_zero | minmax_stats_since_after_ref | stats_since_after_ref +-------+------------------+------------------+------------------------------+----------------------- + 2 | 0 | 0 | 0 | 0 +(1 row) + +-- Perform single min/max reset +SELECT edb_stat_statements_reset(0, '{}', queryid, true) AS minmax_reset_ts +FROM edb_stat_statements +WHERE query LIKE '%STMTTS1%' \gset +-- check +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_stats_since_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + total | minmax_plan_zero | minmax_exec_zero | minmax_stats_since_after_ref | stats_since_after_ref +-------+------------------+------------------+------------------------------+----------------------- + 2 | 1 | 1 | 1 | 0 +(1 row) + +-- check minmax reset timestamps +SELECT +query, minmax_stats_since = :'minmax_reset_ts' AS reset_ts_match +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%' +ORDER BY query COLLATE "C"; + query | reset_ts_match +---------------------------+---------------- + SELECT $1 AS "STMTTS1" | t + SELECT $1,$2 AS "STMTTS2" | f +(2 rows) + +-- check that minmax reset does not set 
stats_reset +SELECT +stats_reset = :'minmax_reset_ts' AS stats_reset_ts_match +FROM edb_stat_statements_info; + stats_reset_ts_match +---------------------- + f +(1 row) + +-- Perform common min/max reset +SELECT edb_stat_statements_reset(0, '{}', 0, true) AS minmax_reset_ts \gset +-- check again +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_ts_after_ref, + count(*) FILTER ( + WHERE minmax_stats_since = :'minmax_reset_ts' + ) as minmax_ts_match, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + total | minmax_plan_zero | minmax_exec_zero | minmax_ts_after_ref | minmax_ts_match | stats_since_after_ref +-------+------------------+------------------+---------------------+-----------------+----------------------- + 2 | 2 | 2 | 2 | 2 | 0 +(1 row) + +-- Execute first query once more to check stats update +SELECT 1 AS "STMTTS1"; + STMTTS1 +--------- + 1 +(1 row) + +-- check +-- we don't check planing times here to be independent of +-- plan caching approach +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_ts_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + total | minmax_exec_zero | minmax_ts_after_ref | stats_since_after_ref +-------+------------------+---------------------+----------------------- + 2 | 1 | 2 | 0 +(1 row) + +-- Cleanup +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/extended.out b/edb_stat_statements/expected/extended.out new file 
mode 100644 index 00000000000..febe7c6f5aa --- /dev/null +++ b/edb_stat_statements/expected/extended.out @@ -0,0 +1,70 @@ +-- Tests with extended query protocol +SET edb_stat_statements.track_utility = FALSE; +-- This test checks that an execute message sets a query ID. +SELECT query_id IS NOT NULL AS query_id_set + FROM pg_stat_activity WHERE pid = pg_backend_pid() \bind \g + query_id_set +-------------- + t +(1 row) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT $1 \parse stmt1 +SELECT $1, $2 \parse stmt2 +SELECT $1, $2, $3 \parse stmt3 +SELECT $1 \bind 'unnamed_val1' \g + ?column? +-------------- + unnamed_val1 +(1 row) + +\bind_named stmt1 'stmt1_val1' \g + ?column? +------------ + stmt1_val1 +(1 row) + +\bind_named stmt2 'stmt2_val1' 'stmt2_val2' \g + ?column? | ?column? +------------+------------ + stmt2_val1 | stmt2_val2 +(1 row) + +\bind_named stmt3 'stmt3_val1' 'stmt3_val2' 'stmt3_val3' \g + ?column? | ?column? | ?column? +------------+------------+------------ + stmt3_val1 | stmt3_val2 | stmt3_val3 +(1 row) + +\bind_named stmt3 'stmt3_val4' 'stmt3_val5' 'stmt3_val6' \g + ?column? | ?column? | ?column? +------------+------------+------------ + stmt3_val4 | stmt3_val5 | stmt3_val6 +(1 row) + +\bind_named stmt2 'stmt2_val3' 'stmt2_val4' \g + ?column? | ?column? +------------+------------ + stmt2_val3 | stmt2_val4 +(1 row) + +\bind_named stmt1 'stmt1_val1' \g + ?column? 
+------------ + stmt1_val1 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 3 | 3 | SELECT $1 + 2 | 2 | SELECT $1, $2 + 2 | 2 | SELECT $1, $2, $3 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(4 rows) + diff --git a/edb_stat_statements/expected/level_tracking.out.17 b/edb_stat_statements/expected/level_tracking.out.17 new file mode 100644 index 00000000000..c5de894cb6d --- /dev/null +++ b/edb_stat_statements/expected/level_tracking.out.17 @@ -0,0 +1,363 @@ +-- +-- Statement level tracking +-- +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- DO block - top-level tracking. +CREATE TABLE stats_track_tab (x int); +SET edb_stat_statements.track = 'top'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; +$$ LANGUAGE plpgsql; +SELECT toplevel, calls, query FROM edb_stat_statements + WHERE query LIKE '%DELETE%' ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+-------------------------------- + t | 1 | DELETE FROM stats_track_tab + t | 1 | DO $$ + + | | BEGIN + + | | DELETE FROM stats_track_tab;+ + | | END; + + | | $$ LANGUAGE plpgsql +(2 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- DO block - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + f | 1 | DELETE FROM stats_track_tab + t | 1 | DELETE FROM stats_track_tab + t | 1 | DO $$ + + | | BEGIN + + | | DELETE FROM stats_track_tab; + + | | END; $$ + t | 1 | DO LANGUAGE plpgsql $$ + + | | BEGIN + + | | -- this is a SELECT + + | | PERFORM 'hello world'::TEXT; + + | | END; $$ + f | 1 | SELECT $1::TEXT + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + t | 1 | SET edb_stat_statements.track = 'all' +(7 rows) + +-- Procedure with multiple utility statements. +CREATE OR REPLACE PROCEDURE proc_with_utility_stmt() +LANGUAGE SQL +AS $$ + SHOW edb_stat_statements.track; + show edb_stat_statements.track; + SHOW edb_stat_statements.track_utility; +$$; +SET edb_stat_statements.track_utility = TRUE; +-- all-level tracking. +SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | CALL proc_with_utility_stmt() + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + f | 2 | SHOW edb_stat_statements.track + f | 1 | SHOW edb_stat_statements.track_utility +(4 rows) + +-- top-level tracking. 
+SET edb_stat_statements.track = 'top'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | CALL proc_with_utility_stmt() + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(2 rows) + +-- DO block - top-level tracking without utility. +SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | DELETE FROM stats_track_tab + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(2 rows) + +-- DO block - all-level tracking without utility. +SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + f | 1 | DELETE FROM stats_track_tab + t | 1 | DELETE FROM stats_track_tab + f | 1 | SELECT $1::TEXT + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(4 rows) + +-- PL/pgSQL function - top-level tracking. 
+SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; +SELECT PLUS_TWO(3); + plus_two +---------- + 5 +(1 row) + +SELECT PLUS_TWO(7); + plus_two +---------- + 9 +(1 row) + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; +SELECT PLUS_ONE(8); + plus_one +---------- + 9 +(1 row) + +SELECT PLUS_ONE(10); + plus_one +---------- + 11 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 2 | SELECT PLUS_ONE($1) + 2 | 2 | SELECT PLUS_TWO($1) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; +SELECT PLUS_THREE(8); + plus_three +------------ + 11 +(1 row) + +SELECT PLUS_THREE(10); + plus_three +------------ + 13 +(1 row) + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + toplevel | calls | rows | query +----------+-------+------+------------------------------------------------------------------------------- + t | 2 | 2 | SELECT PLUS_ONE($1) + t | 2 | 2 | SELECT PLUS_THREE($1) + t | 2 | 2 | SELECT PLUS_TWO($1) + t | 1 | 3 | SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C" + t | 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +-- PL/pgSQL function - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- we drop and recreate the functions to avoid any caching funnies +DROP FUNCTION PLUS_ONE(INTEGER); +DROP FUNCTION PLUS_TWO(INTEGER); +DROP FUNCTION PLUS_THREE(INTEGER); +-- PL/pgSQL function +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; +SELECT PLUS_TWO(-1); + plus_two +---------- + 1 +(1 row) + +SELECT PLUS_TWO(2); + plus_two +---------- + 4 +(1 row) + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; +SELECT PLUS_ONE(3); + plus_one +---------- + 4 +(1 row) + +SELECT PLUS_ONE(1); + plus_one +---------- + 2 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 2 | SELECT (i + $2 + $3)::INTEGER + 2 | 2 | SELECT (i + $2)::INTEGER LIMIT $3 + 2 | 2 | SELECT PLUS_ONE($1) + 2 | 2 | SELECT PLUS_TWO($1) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; +SELECT PLUS_THREE(8); + plus_three +------------ + 11 +(1 row) + +SELECT PLUS_THREE(10); + plus_three +------------ + 13 +(1 row) + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + toplevel | calls | rows | query +----------+-------+------+------------------------------------------------------------------------------- + f | 2 | 2 | SELECT (i + $2 + $3)::INTEGER + f | 2 | 2 | SELECT (i + $2)::INTEGER LIMIT $3 + t | 2 | 2 | SELECT PLUS_ONE($1) + t | 2 | 2 | SELECT PLUS_THREE($1) + t | 2 | 2 | SELECT 
PLUS_TWO($1) + t | 1 | 5 | SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C" + t | 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + f | 2 | 2 | SELECT i + $2 LIMIT $3 +(8 rows) + +-- +-- edb_stat_statements.track = none +-- +SET edb_stat_statements.track = 'none'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT 1 AS "one"; + one +----- + 1 +(1 row) + +SELECT 1 + 1 AS "two"; + two +----- + 2 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------- +(0 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/level_tracking.out.18 b/edb_stat_statements/expected/level_tracking.out.18 new file mode 100644 index 00000000000..876f0c1aa1e --- /dev/null +++ b/edb_stat_statements/expected/level_tracking.out.18 @@ -0,0 +1,363 @@ +-- +-- Statement level tracking +-- +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- DO block - top-level tracking. +CREATE TABLE stats_track_tab (x int); +SET edb_stat_statements.track = 'top'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; +$$ LANGUAGE plpgsql; +SELECT toplevel, calls, query FROM edb_stat_statements + WHERE query LIKE '%DELETE%' ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+-------------------------------- + t | 1 | DELETE FROM stats_track_tab + t | 1 | DO $$ + + | | BEGIN + + | | DELETE FROM stats_track_tab;+ + | | END; + + | | $$ LANGUAGE plpgsql +(2 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- DO block - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + f | 1 | DELETE FROM stats_track_tab + t | 1 | DELETE FROM stats_track_tab + t | 1 | DO $$ + + | | BEGIN + + | | DELETE FROM stats_track_tab; + + | | END; $$ + t | 1 | DO LANGUAGE plpgsql $$ + + | | BEGIN + + | | -- this is a SELECT + + | | PERFORM 'hello world'::TEXT; + + | | END; $$ + f | 1 | SELECT $1::TEXT + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + t | 1 | SET edb_stat_statements.track = $1 +(7 rows) + +-- Procedure with multiple utility statements. +CREATE OR REPLACE PROCEDURE proc_with_utility_stmt() +LANGUAGE SQL +AS $$ + SHOW edb_stat_statements.track; + show edb_stat_statements.track; + SHOW edb_stat_statements.track_utility; +$$; +SET edb_stat_statements.track_utility = TRUE; +-- all-level tracking. +SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | CALL proc_with_utility_stmt() + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + f | 2 | SHOW edb_stat_statements.track + f | 1 | SHOW edb_stat_statements.track_utility +(4 rows) + +-- top-level tracking. 
+SET edb_stat_statements.track = 'top'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | CALL proc_with_utility_stmt() + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(2 rows) + +-- DO block - top-level tracking without utility. +SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + t | 1 | DELETE FROM stats_track_tab + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(2 rows) + +-- DO block - all-level tracking without utility. +SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + toplevel | calls | query +----------+-------+----------------------------------------------------- + f | 1 | DELETE FROM stats_track_tab + t | 1 | DELETE FROM stats_track_tab + f | 1 | SELECT $1::TEXT + t | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(4 rows) + +-- PL/pgSQL function - top-level tracking. 
+SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; +SELECT PLUS_TWO(3); + plus_two +---------- + 5 +(1 row) + +SELECT PLUS_TWO(7); + plus_two +---------- + 9 +(1 row) + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; +SELECT PLUS_ONE(8); + plus_one +---------- + 9 +(1 row) + +SELECT PLUS_ONE(10); + plus_one +---------- + 11 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 2 | SELECT PLUS_ONE($1) + 2 | 2 | SELECT PLUS_TWO($1) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; +SELECT PLUS_THREE(8); + plus_three +------------ + 11 +(1 row) + +SELECT PLUS_THREE(10); + plus_three +------------ + 13 +(1 row) + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + toplevel | calls | rows | query +----------+-------+------+------------------------------------------------------------------------------- + t | 2 | 2 | SELECT PLUS_ONE($1) + t | 2 | 2 | SELECT PLUS_THREE($1) + t | 2 | 2 | SELECT PLUS_TWO($1) + t | 1 | 3 | SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C" + t | 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +-- PL/pgSQL function - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- we drop and recreate the functions to avoid any caching funnies +DROP FUNCTION PLUS_ONE(INTEGER); +DROP FUNCTION PLUS_TWO(INTEGER); +DROP FUNCTION PLUS_THREE(INTEGER); +-- PL/pgSQL function +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; +SELECT PLUS_TWO(-1); + plus_two +---------- + 1 +(1 row) + +SELECT PLUS_TWO(2); + plus_two +---------- + 4 +(1 row) + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; +SELECT PLUS_ONE(3); + plus_one +---------- + 4 +(1 row) + +SELECT PLUS_ONE(1); + plus_one +---------- + 2 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 2 | SELECT (i + $2 + $3)::INTEGER + 2 | 2 | SELECT (i + $2)::INTEGER LIMIT $3 + 2 | 2 | SELECT PLUS_ONE($1) + 2 | 2 | SELECT PLUS_TWO($1) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; +SELECT PLUS_THREE(8); + plus_three +------------ + 11 +(1 row) + +SELECT PLUS_THREE(10); + plus_three +------------ + 13 +(1 row) + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + toplevel | calls | rows | query +----------+-------+------+------------------------------------------------------------------------------- + f | 2 | 2 | SELECT (i + $2 + $3)::INTEGER + f | 2 | 2 | SELECT (i + $2)::INTEGER LIMIT $3 + t | 2 | 2 | SELECT PLUS_ONE($1) + t | 2 | 2 | SELECT PLUS_THREE($1) + t | 2 | 2 | SELECT 
PLUS_TWO($1) + t | 1 | 5 | SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C" + t | 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + f | 2 | 2 | SELECT i + $2 LIMIT $3 +(8 rows) + +-- +-- edb_stat_statements.track = none +-- +SET edb_stat_statements.track = 'none'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT 1 AS "one"; + one +----- + 1 +(1 row) + +SELECT 1 + 1 AS "two"; + two +----- + 2 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------- +(0 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/oldextversions.out b/edb_stat_statements/expected/oldextversions.out new file mode 100644 index 00000000000..6c17cff894a --- /dev/null +++ b/edb_stat_statements/expected/oldextversions.out @@ -0,0 +1,96 @@ +-- test old extension version entry points +CREATE EXTENSION edb_stat_statements WITH VERSION '1.0'; +SELECT pg_get_functiondef('edb_stat_statements_info'::regproc); + pg_get_functiondef +-------------------------------------------------------------------------------------------------------------------------- + CREATE OR REPLACE FUNCTION public.edb_stat_statements_info(OUT dealloc bigint, OUT stats_reset timestamp with time zone)+ + RETURNS record + + LANGUAGE c + + PARALLEL SAFE STRICT + + AS '$libdir/edb_stat_statements', $function$edb_stat_statements_info$function$ + + +(1 row) + +SELECT pg_get_functiondef('edb_stat_statements_reset'::regproc); + pg_get_functiondef +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + CREATE OR REPLACE FUNCTION public.edb_stat_statements_reset(userid oid DEFAULT 0, dbids oid[] DEFAULT '{}'::oid[], queryid bigint DEFAULT 0, minmax_only boolean DEFAULT false)+ + RETURNS 
timestamp with time zone + + LANGUAGE c + + PARALLEL SAFE STRICT + + AS '$libdir/edb_stat_statements', $function$edb_stat_statements_reset$function$ + + +(1 row) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +\d edb_stat_statements + View "public.edb_stat_statements" + Column | Type | Collation | Nullable | Default +----------------------------+--------------------------+-----------+----------+--------- + userid | oid | | | + dbid | oid | | | + toplevel | boolean | | | + queryid | bigint | | | + query | text | | | + extras | jsonb | | | + id | uuid | | | + stmt_type | smallint | | | + plans | bigint | | | + total_plan_time | double precision | | | + min_plan_time | double precision | | | + max_plan_time | double precision | | | + mean_plan_time | double precision | | | + stddev_plan_time | double precision | | | + calls | bigint | | | + total_exec_time | double precision | | | + min_exec_time | double precision | | | + max_exec_time | double precision | | | + mean_exec_time | double precision | | | + stddev_exec_time | double precision | | | + rows | bigint | | | + shared_blks_hit | bigint | | | + shared_blks_read | bigint | | | + shared_blks_dirtied | bigint | | | + shared_blks_written | bigint | | | + local_blks_hit | bigint | | | + local_blks_read | bigint | | | + local_blks_dirtied | bigint | | | + local_blks_written | bigint | | | + temp_blks_read | bigint | | | + temp_blks_written | bigint | | | + shared_blk_read_time | double precision | | | + shared_blk_write_time | double precision | | | + local_blk_read_time | double precision | | | + local_blk_write_time | double precision | | | + temp_blk_read_time | double precision | | | + temp_blk_write_time | double precision | | | + wal_records | bigint | | | + wal_fpi | bigint | | | + wal_bytes | numeric | | | + jit_functions | bigint | | | + jit_generation_time | double precision | | | + jit_inlining_count | bigint | | | + jit_inlining_time | double precision | | | + 
jit_optimization_count | bigint | | | + jit_optimization_time | double precision | | | + jit_emission_count | bigint | | | + jit_emission_time | double precision | | | + jit_deform_count | bigint | | | + jit_deform_time | double precision | | | + parallel_workers_to_launch | bigint | | | + parallel_workers_launched | bigint | | | + stats_since | timestamp with time zone | | | + minmax_stats_since | timestamp with time zone | | | + +SELECT count(*) > 0 AS has_data FROM edb_stat_statements; + has_data +---------- + t +(1 row) + +DROP EXTENSION edb_stat_statements; diff --git a/edb_stat_statements/expected/parallel.out.17 b/edb_stat_statements/expected/parallel.out.17 new file mode 100644 index 00000000000..1d4643a3982 --- /dev/null +++ b/edb_stat_statements/expected/parallel.out.17 @@ -0,0 +1,34 @@ +-- +-- Tests for parallel statistics +-- +SET edb_stat_statements.track_utility = FALSE; +-- encourage use of parallel plans +SET parallel_setup_cost = 0; +SET parallel_tuple_cost = 0; +SET min_parallel_table_scan_size = 0; +SET max_parallel_workers_per_gather = 2; +CREATE TABLE pgss_parallel_tab (a int); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT count(*) FROM pgss_parallel_tab; + count +------- + 0 +(1 row) + +SELECT query, + parallel_workers_to_launch > 0 AS has_workers_to_launch, + parallel_workers_launched > 0 AS has_workers_launched + FROM edb_stat_statements + WHERE query ~ 'SELECT count' + ORDER BY query COLLATE "C"; + query | has_workers_to_launch | has_workers_launched +----------------------------------------+-----------------------+---------------------- + SELECT count(*) FROM pgss_parallel_tab | f | f +(1 row) + +DROP TABLE pgss_parallel_tab; diff --git a/edb_stat_statements/expected/parallel.out.18 b/edb_stat_statements/expected/parallel.out.18 new file mode 100644 index 00000000000..aebe94728f1 --- /dev/null +++ b/edb_stat_statements/expected/parallel.out.18 @@ -0,0 +1,34 @@ +-- +-- Tests for parallel statistics 
+-- +SET edb_stat_statements.track_utility = FALSE; +-- encourage use of parallel plans +SET parallel_setup_cost = 0; +SET parallel_tuple_cost = 0; +SET min_parallel_table_scan_size = 0; +SET max_parallel_workers_per_gather = 2; +CREATE TABLE pgss_parallel_tab (a int); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT count(*) FROM pgss_parallel_tab; + count +------- + 0 +(1 row) + +SELECT query, + parallel_workers_to_launch > 0 AS has_workers_to_launch, + parallel_workers_launched > 0 AS has_workers_launched + FROM edb_stat_statements + WHERE query ~ 'SELECT count' + ORDER BY query COLLATE "C"; + query | has_workers_to_launch | has_workers_launched +----------------------------------------+-----------------------+---------------------- + SELECT count(*) FROM pgss_parallel_tab | t | t +(1 row) + +DROP TABLE pgss_parallel_tab; diff --git a/edb_stat_statements/expected/planning.out b/edb_stat_statements/expected/planning.out new file mode 100644 index 00000000000..758d6daa387 --- /dev/null +++ b/edb_stat_statements/expected/planning.out @@ -0,0 +1,88 @@ +-- +-- Information related to planning +-- +-- These tests require track_planning to be enabled. +SET edb_stat_statements.track_planning = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- +-- [re]plan counting +-- +CREATE TABLE stats_plan_test (); +PREPARE prep1 AS SELECT COUNT(*) FROM stats_plan_test; +EXECUTE prep1; + count +------- + 0 +(1 row) + +EXECUTE prep1; + count +------- + 0 +(1 row) + +EXECUTE prep1; + count +------- + 0 +(1 row) + +ALTER TABLE stats_plan_test ADD COLUMN x int; +EXECUTE prep1; + count +------- + 0 +(1 row) + +SELECT 42; + ?column? +---------- + 42 +(1 row) + +SELECT 42; + ?column? +---------- + 42 +(1 row) + +SELECT 42; + ?column? 
+---------- + 42 +(1 row) + +SELECT plans, calls, rows, query FROM edb_stat_statements + WHERE query NOT LIKE 'PREPARE%' ORDER BY query COLLATE "C"; + plans | calls | rows | query +-------+-------+------+----------------------------------------------------------- + 0 | 1 | 0 | ALTER TABLE stats_plan_test ADD COLUMN x int + 0 | 1 | 0 | CREATE TABLE stats_plan_test () + 3 | 3 | 3 | SELECT $1 + 0 | 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | 0 | SELECT plans, calls, rows, query FROM edb_stat_statements+ + | | | WHERE query NOT LIKE $1 ORDER BY query COLLATE "C" +(5 rows) + +-- for the prepared statement we expect at least one replan, but cache +-- invalidations could force more +SELECT plans >= 2 AND plans <= calls AS plans_ok, calls, rows, query FROM edb_stat_statements + WHERE query LIKE 'PREPARE%' ORDER BY query COLLATE "C"; + plans_ok | calls | rows | query +----------+-------+------+------------------------------------------------------- + t | 4 | 4 | PREPARE prep1 AS SELECT COUNT(*) FROM stats_plan_test +(1 row) + +-- Cleanup +DROP TABLE stats_plan_test; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/privileges.out b/edb_stat_statements/expected/privileges.out new file mode 100644 index 00000000000..deb9b9b5cdf --- /dev/null +++ b/edb_stat_statements/expected/privileges.out @@ -0,0 +1,97 @@ +-- +-- Only superusers and roles with privileges of the pg_read_all_stats role +-- are allowed to see the SQL text and queryid of queries executed by +-- other users. Other users can see the statistics. 
+-- +SET edb_stat_statements.track_utility = FALSE; +CREATE ROLE regress_stats_superuser SUPERUSER; +CREATE ROLE regress_stats_user1; +CREATE ROLE regress_stats_user2; +GRANT pg_read_all_stats TO regress_stats_user2; +SET ROLE regress_stats_superuser; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT 1 AS "ONE"; + ONE +----- + 1 +(1 row) + +SET ROLE regress_stats_user1; +SELECT 1+1 AS "TWO"; + TWO +----- + 2 +(1 row) + +-- +-- A superuser can read all columns of queries executed by others, +-- including query text and queryid. +-- +SET ROLE regress_stats_superuser; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + rolname | queryid_bool | query | calls | rows +-------------------------+--------------+-----------------------------------------------------+-------+------ + regress_stats_superuser | t | SELECT $1 AS "ONE" | 1 | 1 + regress_stats_superuser | t | SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + regress_stats_user1 | t | SELECT $1+$2 AS "TWO" | 1 | 1 +(3 rows) + +-- +-- regress_stats_user1 has no privileges to read the query text or +-- queryid of queries executed by others but can see statistics +-- like calls and rows. 
+-- +SET ROLE regress_stats_user1; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + rolname | queryid_bool | query | calls | rows +-------------------------+--------------+--------------------------+-------+------ + regress_stats_superuser | | | 1 | 1 + regress_stats_superuser | | | 1 | 1 + regress_stats_superuser | | | 1 | 3 + regress_stats_user1 | t | SELECT $1+$2 AS "TWO" | 1 | 1 +(4 rows) + +-- +-- regress_stats_user2, with pg_read_all_stats role privileges, can +-- read all columns, including query text and queryid, of queries +-- executed by others. +-- +SET ROLE regress_stats_user2; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + rolname | queryid_bool | query | calls | rows +-------------------------+--------------+---------------------------------------------------------------------------------+-------+------ + regress_stats_superuser | t | SELECT $1 AS "ONE" | 1 | 1 + regress_stats_superuser | t | SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + regress_stats_superuser | t | SELECT r.rolname, ss.queryid <> $1 AS queryid_bool, ss.query, ss.calls, ss.rows+| 1 | 3 + | | FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid +| | + | | ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows | | + regress_stats_user1 | t | SELECT $1+$2 AS "TWO" | 1 | 1 + regress_stats_user1 | t | SELECT r.rolname, ss.queryid <> $1 AS queryid_bool, ss.query, ss.calls, ss.rows+| 1 | 4 + | | FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid +| | + | | ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows | | +(5 rows) + +-- +-- cleanup +-- +RESET ROLE; +DROP ROLE regress_stats_superuser; +DROP ROLE 
regress_stats_user1; +DROP ROLE regress_stats_user2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/select.out b/edb_stat_statements/expected/select.out new file mode 100644 index 00000000000..fd92de07082 --- /dev/null +++ b/edb_stat_statements/expected/select.out @@ -0,0 +1,414 @@ +-- +-- SELECT statements +-- +CREATE EXTENSION edb_stat_statements; +SET edb_stat_statements.track_utility = FALSE; +SET edb_stat_statements.track_planning = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- +-- simple and compound statements +-- +SELECT 1 AS "int"; + int +----- + 1 +(1 row) + +SELECT 'hello' + -- multiline + AS "text"; + text +------- + hello +(1 row) + +SELECT 'world' AS "text"; + text +------- + world +(1 row) + +-- transaction +BEGIN; +SELECT 1 AS "int"; + int +----- + 1 +(1 row) + +SELECT 'hello' AS "text"; + text +------- + hello +(1 row) + +COMMIT; +-- compound transaction +BEGIN \; +SELECT 2.0 AS "float" \; +SELECT 'world' AS "text" \; +COMMIT; + float +------- + 2.0 +(1 row) + + text +------- + world +(1 row) + +-- compound with empty statements and spurious leading spacing +\;\; SELECT 3 + 3 \;\;\; SELECT ' ' || ' !' \;\; SELECT 1 + 4 \;; + ?column? +---------- + 6 +(1 row) + + ?column? +---------- + ! +(1 row) + + ?column? +---------- + 5 +(1 row) + +-- non ;-terminated statements +SELECT 1 + 1 + 1 AS "add" \gset +SELECT :add + 1 + 1 AS "add" \; +SELECT :add + 1 + 1 AS "add" \gset + add +----- + 5 +(1 row) + +-- set operator +SELECT 1 AS i UNION SELECT 2 ORDER BY i; + i +--- + 1 + 2 +(2 rows) + +-- ? operator +select '{"a":1, "b":2}'::jsonb ? 'b'; + ?column? +---------- + t +(1 row) + +-- cte +WITH t(f) AS ( + VALUES (1.0), (2.0) +) + SELECT f FROM t ORDER BY f; + f +----- + 1.0 + 2.0 +(2 rows) + +-- prepared statement with parameter +PREPARE pgss_test (int) AS SELECT $1, 'test' LIMIT 1; +EXECUTE pgss_test(1); + ?column? | ?column? 
+----------+---------- + 1 | test +(1 row) + +DEALLOCATE pgss_test; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------------------- + 1 | 1 | PREPARE pgss_test (int) AS SELECT $1, $2 LIMIT $3 + 4 | 4 | SELECT $1 + + | | -- multiline + + | | AS "text" + 2 | 2 | SELECT $1 + $2 + 3 | 3 | SELECT $1 + $2 + $3 AS "add" + 1 | 1 | SELECT $1 AS "float" + 2 | 2 | SELECT $1 AS "int" + 1 | 2 | SELECT $1 AS i UNION SELECT $2 ORDER BY i + 1 | 1 | SELECT $1 || $2 + 0 | 0 | SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C" + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 2 | WITH t(f) AS ( + + | | VALUES ($1), ($2) + + | | ) + + | | SELECT f FROM t ORDER BY f + 1 | 1 | select $1::jsonb ? $2 +(12 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- +-- queries with locking clauses +-- +CREATE TABLE pgss_a (id integer PRIMARY KEY); +CREATE TABLE pgss_b (id integer PRIMARY KEY, a_id integer REFERENCES pgss_a); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- control query +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id; + id | id | a_id +----+----+------ +(0 rows) + +-- test range tables +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_a; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_a, pgss_b; -- matches plain "FOR UPDATE" + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b, pgss_a; + id | id 
| a_id +----+----+------ +(0 rows) + +-- test strengths +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR NO KEY UPDATE; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR SHARE; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR KEY SHARE; + id | id | a_id +----+----+------ +(0 rows) + +-- test wait policies +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE NOWAIT; + id | id | a_id +----+----+------ +(0 rows) + +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE SKIP LOCKED; + id | id | a_id +----+----+------ +(0 rows) + +SELECT calls, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | query +-------+------------------------------------------------------------------------------------------ + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR KEY SHARE + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR NO KEY UPDATE + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR SHARE + 2 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE NOWAIT + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_a + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b, pgss_a + 1 | SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE SKIP LOCKED + 0 | SELECT calls, query FROM edb_stat_statements ORDER BY query COLLATE "C" + 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(12 rows) + +DROP TABLE pgss_a, pgss_b CASCADE; +-- +-- access to edb_stat_statements_info view +-- +SELECT edb_stat_statements_reset() IS 
NOT NULL AS t; + t +--- + t +(1 row) + +SELECT dealloc FROM edb_stat_statements_info; + dealloc +--------- + 0 +(1 row) + +-- FROM [ONLY] +CREATE TABLE tbl_inh(id integer); +CREATE TABLE tbl_inh_1() INHERITS (tbl_inh); +INSERT INTO tbl_inh_1 SELECT 1; +SELECT * FROM tbl_inh; + id +---- + 1 +(1 row) + +SELECT * FROM ONLY tbl_inh; + id +---- +(0 rows) + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%FROM%tbl_inh%'; + count +------- + 2 +(1 row) + +-- WITH TIES +CREATE TABLE limitoption AS SELECT 0 AS val FROM generate_series(1, 10); +SELECT * +FROM limitoption +WHERE val < 2 +ORDER BY val +FETCH FIRST 2 ROWS WITH TIES; + val +----- + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 +(10 rows) + +SELECT * +FROM limitoption +WHERE val < 2 +ORDER BY val +FETCH FIRST 2 ROW ONLY; + val +----- + 0 + 0 +(2 rows) + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%FETCH FIRST%'; + count +------- + 2 +(1 row) + +-- GROUP BY [DISTINCT] +SELECT a, b, c +FROM (VALUES (1, 2, 3), (4, NULL, 6), (7, 8, 9)) AS t (a, b, c) +GROUP BY ROLLUP(a, b), rollup(a, c) +ORDER BY a, b, c; + a | b | c +---+---+--- + 1 | 2 | 3 + 1 | 2 | + 1 | 2 | + 1 | | 3 + 1 | | 3 + 1 | | + 1 | | + 1 | | + 4 | | 6 + 4 | | 6 + 4 | | 6 + 4 | | + 4 | | + 4 | | + 4 | | + 4 | | + 7 | 8 | 9 + 7 | 8 | + 7 | 8 | + 7 | | 9 + 7 | | 9 + 7 | | + 7 | | + 7 | | + | | +(25 rows) + +SELECT a, b, c +FROM (VALUES (1, 2, 3), (4, NULL, 6), (7, 8, 9)) AS t (a, b, c) +GROUP BY DISTINCT ROLLUP(a, b), rollup(a, c) +ORDER BY a, b, c; + a | b | c +---+---+--- + 1 | 2 | 3 + 1 | 2 | + 1 | | 3 + 1 | | + 4 | | 6 + 4 | | 6 + 4 | | + 4 | | + 7 | 8 | 9 + 7 | 8 | + 7 | | 9 + 7 | | + | | +(13 rows) + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%GROUP BY%ROLLUP%'; + count +------- + 2 +(1 row) + +-- GROUPING SET agglevelsup +SELECT ( + SELECT ( + SELECT GROUPING(a,b) FROM (VALUES (1)) v2(c) + ) FROM (VALUES (1,2)) v1(a,b) GROUP BY (a,b) +) FROM (VALUES(6,7)) v3(e,f) GROUP BY ROLLUP(e,f); + grouping +---------- + 0 + 0 
+ 0 +(3 rows) + +SELECT ( + SELECT ( + SELECT GROUPING(e,f) FROM (VALUES (1)) v2(c) + ) FROM (VALUES (1,2)) v1(a,b) GROUP BY (a,b) +) FROM (VALUES(6,7)) v3(e,f) GROUP BY ROLLUP(e,f); + grouping +---------- + 3 + 0 + 1 +(3 rows) + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%SELECT GROUPING%'; + count +------- + 2 +(1 row) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/user_activity.out b/edb_stat_statements/expected/user_activity.out new file mode 100644 index 00000000000..8004961034e --- /dev/null +++ b/edb_stat_statements/expected/user_activity.out @@ -0,0 +1,209 @@ +-- +-- Track user activity and reset them +-- +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CREATE ROLE regress_stats_user1; +CREATE ROLE regress_stats_user2; +SET ROLE regress_stats_user1; +SELECT 1 AS "ONE"; + ONE +----- + 1 +(1 row) + +SELECT 1+1 AS "TWO"; + TWO +----- + 2 +(1 row) + +RESET ROLE; +SET ROLE regress_stats_user2; +SELECT 1 AS "ONE"; + ONE +----- + 1 +(1 row) + +SELECT 1+1 AS "TWO"; + TWO +----- + 2 +(1 row) + +RESET ROLE; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows +-----------------------------------------------------+-------+------ + CREATE ROLE regress_stats_user1 | 1 | 0 + CREATE ROLE regress_stats_user2 | 1 | 0 + RESET ROLE | 2 | 0 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + SET ROLE regress_stats_user1 | 1 | 0 + SET ROLE regress_stats_user2 | 1 | 0 +(10 rows) + +-- +-- Don't reset anything if any of the parameter is NULL +-- +SELECT edb_stat_statements_reset(NULL) IS NOT NULL AS t; + t +--- + f +(1 row) + +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows 
+-------------------------------------------------------------------------------+-------+------ + CREATE ROLE regress_stats_user1 | 1 | 0 + CREATE ROLE regress_stats_user2 | 1 | 0 + RESET ROLE | 2 | 0 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT edb_stat_statements_reset($1) IS NOT NULL AS t | 1 | 1 + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C" | 1 | 10 + SET ROLE regress_stats_user1 | 1 | 0 + SET ROLE regress_stats_user2 | 1 | 0 +(12 rows) + +-- +-- remove query ('SELECT $1+$2 AS "TWO"') executed by regress_stats_user2 +-- in the current_database +-- +SELECT edb_stat_statements_reset( + (SELECT r.oid FROM pg_roles AS r WHERE r.rolname = 'regress_stats_user2'), + ARRAY(SELECT d.oid FROM pg_database As d where datname = current_database()), + (SELECT s.queryid FROM edb_stat_statements AS s + WHERE s.query = 'SELECT $1+$2 AS "TWO"' LIMIT 1)) + IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows +---------------------------------------------------------------------------------------+-------+------ + CREATE ROLE regress_stats_user1 | 1 | 0 + CREATE ROLE regress_stats_user2 | 1 | 0 + RESET ROLE | 2 | 0 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1 AS "ONE" | 1 | 1 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT edb_stat_statements_reset( +| 1 | 1 + (SELECT r.oid FROM pg_roles AS r WHERE r.rolname = $1), +| | + ARRAY(SELECT d.oid FROM pg_database As d where datname = current_database()),+| | + (SELECT s.queryid FROM edb_stat_statements AS s +| | + WHERE s.query = $2 LIMIT $3)) +| | + IS NOT NULL AS t | | + SELECT edb_stat_statements_reset($1) IS NOT NULL AS t | 1 | 1 + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C" | 2 | 
22 + SET ROLE regress_stats_user1 | 1 | 0 + SET ROLE regress_stats_user2 | 1 | 0 +(12 rows) + +-- +-- remove query ('SELECT $1 AS "ONE"') executed by two users +-- +SELECT edb_stat_statements_reset(0,'{}',s.queryid) IS NOT NULL AS t + FROM edb_stat_statements AS s WHERE s.query = 'SELECT $1 AS "ONE"'; + t +--- + t + t +(2 rows) + +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows +---------------------------------------------------------------------------------------+-------+------ + CREATE ROLE regress_stats_user1 | 1 | 0 + CREATE ROLE regress_stats_user2 | 1 | 0 + RESET ROLE | 2 | 0 + SELECT $1+$2 AS "TWO" | 1 | 1 + SELECT edb_stat_statements_reset( +| 1 | 1 + (SELECT r.oid FROM pg_roles AS r WHERE r.rolname = $1), +| | + ARRAY(SELECT d.oid FROM pg_database As d where datname = current_database()),+| | + (SELECT s.queryid FROM edb_stat_statements AS s +| | + WHERE s.query = $2 LIMIT $3)) +| | + IS NOT NULL AS t | | + SELECT edb_stat_statements_reset($1) IS NOT NULL AS t | 1 | 1 + SELECT edb_stat_statements_reset($1,$2,s.queryid) IS NOT NULL AS t +| 1 | 2 + FROM edb_stat_statements AS s WHERE s.query = $3 | | + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C" | 3 | 34 + SET ROLE regress_stats_user1 | 1 | 0 + SET ROLE regress_stats_user2 | 1 | 0 +(11 rows) + +-- +-- remove query of a user (regress_stats_user1) +-- +SELECT edb_stat_statements_reset(r.oid) IS NOT NULL AS t + FROM pg_roles AS r WHERE r.rolname = 'regress_stats_user1'; + t +--- + t +(1 row) + +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows +---------------------------------------------------------------------------------------+-------+------ + CREATE ROLE regress_stats_user1 | 1 | 0 + CREATE ROLE regress_stats_user2 | 1 | 0 + RESET ROLE | 2 | 0 + SELECT edb_stat_statements_reset( +| 1 | 1 + (SELECT r.oid 
FROM pg_roles AS r WHERE r.rolname = $1), +| | + ARRAY(SELECT d.oid FROM pg_database As d where datname = current_database()),+| | + (SELECT s.queryid FROM edb_stat_statements AS s +| | + WHERE s.query = $2 LIMIT $3)) +| | + IS NOT NULL AS t | | + SELECT edb_stat_statements_reset($1) IS NOT NULL AS t | 1 | 1 + SELECT edb_stat_statements_reset($1,$2,s.queryid) IS NOT NULL AS t +| 1 | 2 + FROM edb_stat_statements AS s WHERE s.query = $3 | | + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 + SELECT edb_stat_statements_reset(r.oid) IS NOT NULL AS t +| 1 | 1 + FROM pg_roles AS r WHERE r.rolname = $1 | | + SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C" | 4 | 45 + SET ROLE regress_stats_user2 | 1 | 0 +(10 rows) + +-- +-- reset all +-- +SELECT edb_stat_statements_reset(0,'{}',0) IS NOT NULL AS t; + t +--- + t +(1 row) + +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows +-------------------------------------------------------------+-------+------ + SELECT edb_stat_statements_reset(0,'{}',0) IS NOT NULL AS t | 1 | 1 +(1 row) + +-- +-- cleanup +-- +DROP ROLE regress_stats_user1; +DROP ROLE regress_stats_user2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/utility.out.16 b/edb_stat_statements/expected/utility.out.16 new file mode 100644 index 00000000000..bd297986023 --- /dev/null +++ b/edb_stat_statements/expected/utility.out.16 @@ -0,0 +1,756 @@ +-- +-- Utility commands +-- +-- These tests require track_utility to be enabled. 
+SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Tables, indexes, triggers +CREATE TEMP TABLE tab_stats (a int, b char(20)); +CREATE INDEX index_stats ON tab_stats(b, (b || 'data1'), (b || 'data2')) WHERE a > 0; +ALTER TABLE tab_stats ALTER COLUMN b set default 'a'; +ALTER TABLE tab_stats ALTER COLUMN b TYPE text USING 'data' || b; +ALTER TABLE tab_stats ADD CONSTRAINT a_nonzero CHECK (a <> 0); +DROP TABLE tab_stats \; +DROP TABLE IF EXISTS tab_stats \; +-- This DROP query uses two different strings, still they count as one entry. +DROP TABLE IF EXISTS tab_stats \; +Drop Table If Exists tab_stats \; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +NOTICE: table "tab_stats" does not exist, skipping +NOTICE: table "tab_stats" does not exist, skipping +NOTICE: table "tab_stats" does not exist, skipping + calls | rows | query +-------+------+-------------------------------------------------------------------------------------- + 1 | 0 | ALTER TABLE tab_stats ADD CONSTRAINT a_nonzero CHECK (a <> 0) + 1 | 0 | ALTER TABLE tab_stats ALTER COLUMN b TYPE text USING 'data' || b + 1 | 0 | ALTER TABLE tab_stats ALTER COLUMN b set default 'a' + 1 | 0 | CREATE INDEX index_stats ON tab_stats(b, (b || 'data1'), (b || 'data2')) WHERE a > 0 + 1 | 0 | CREATE TEMP TABLE tab_stats (a int, b char(20)) + 3 | 0 | DROP TABLE IF EXISTS tab_stats + 1 | 0 | DROP TABLE tab_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Partitions +CREATE TABLE pt_stats (a int, b int) PARTITION BY range (a); +CREATE TABLE pt_stats1 (a int, b int); +ALTER TABLE pt_stats ATTACH PARTITION pt_stats1 FOR VALUES FROM (0) TO (100); +CREATE TABLE pt_stats2 PARTITION OF pt_stats FOR VALUES FROM (100) TO (200); +CREATE INDEX pt_stats_index ON ONLY pt_stats (a); +CREATE INDEX pt_stats2_index ON ONLY 
pt_stats2 (a); +ALTER INDEX pt_stats_index ATTACH PARTITION pt_stats2_index; +DROP TABLE pt_stats; +-- Views +CREATE VIEW view_stats AS SELECT 1::int AS a, 2::int AS b; +ALTER VIEW view_stats ALTER COLUMN a SET DEFAULT 2; +DROP VIEW view_stats; +-- Foreign tables +CREATE FOREIGN DATA WRAPPER wrapper_stats; +CREATE SERVER server_stats FOREIGN DATA WRAPPER wrapper_stats; +CREATE FOREIGN TABLE foreign_stats (a int) SERVER server_stats; +ALTER FOREIGN TABLE foreign_stats ADD COLUMN b integer DEFAULT 1; +ALTER FOREIGN TABLE foreign_stats ADD CONSTRAINT b_nonzero CHECK (b <> 0); +DROP FOREIGN TABLE foreign_stats; +DROP SERVER server_stats; +DROP FOREIGN DATA WRAPPER wrapper_stats; +-- Functions +CREATE FUNCTION func_stats(a text DEFAULT 'a_data', b text DEFAULT lower('b_data')) + RETURNS text AS $$ SELECT $1::text || '_' || $2::text; $$ LANGUAGE SQL + SET work_mem = '256kB'; +DROP FUNCTION func_stats; +-- Rules +CREATE TABLE tab_rule_stats (a int, b int); +CREATE TABLE tab_rule_stats_2 (a int, b int, c int, d int); +CREATE RULE rules_stats AS ON INSERT TO tab_rule_stats DO INSTEAD + INSERT INTO tab_rule_stats_2 VALUES(new.*, 1, 2); +DROP RULE rules_stats ON tab_rule_stats; +DROP TABLE tab_rule_stats, tab_rule_stats_2; +-- Types +CREATE TYPE stats_type as (f1 numeric(35, 6), f2 numeric(35, 2)); +DROP TYPE stats_type; +-- Triggers +CREATE TABLE trigger_tab_stats (a int, b int); +CREATE FUNCTION trigger_func_stats () RETURNS trigger LANGUAGE plpgsql + AS $$ BEGIN return OLD; end; $$; +CREATE TRIGGER trigger_tab_stats + AFTER UPDATE ON trigger_tab_stats + FOR EACH ROW WHEN (OLD.a < 0 AND OLD.b < 1 AND true) + EXECUTE FUNCTION trigger_func_stats(); +DROP TABLE trigger_tab_stats; +-- Policies +CREATE TABLE tab_policy_stats (a int, b int); +CREATE POLICY policy_stats ON tab_policy_stats USING (a = 5) WITH CHECK (b < 5); +DROP TABLE tab_policy_stats; +-- Statistics +CREATE TABLE tab_expr_stats (a int, b int); +CREATE STATISTICS tab_expr_stats_1 (mcv) ON a, (2*a), (3*b) FROM 
tab_expr_stats; +DROP TABLE tab_expr_stats; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------------------------- + 1 | 0 | ALTER FOREIGN TABLE foreign_stats ADD COLUMN b integer DEFAULT 1 + 1 | 0 | ALTER FOREIGN TABLE foreign_stats ADD CONSTRAINT b_nonzero CHECK (b <> 0) + 1 | 0 | ALTER INDEX pt_stats_index ATTACH PARTITION pt_stats2_index + 1 | 0 | ALTER TABLE pt_stats ATTACH PARTITION pt_stats1 FOR VALUES FROM (0) TO (100) + 1 | 0 | ALTER VIEW view_stats ALTER COLUMN a SET DEFAULT 2 + 1 | 0 | CREATE FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | CREATE FOREIGN TABLE foreign_stats (a int) SERVER server_stats + 1 | 0 | CREATE FUNCTION func_stats(a text DEFAULT 'a_data', b text DEFAULT lower('b_data'))+ + | | RETURNS text AS $$ SELECT $1::text || '_' || $2::text; $$ LANGUAGE SQL + + | | SET work_mem = '256kB' + 1 | 0 | CREATE FUNCTION trigger_func_stats () RETURNS trigger LANGUAGE plpgsql + + | | AS $$ BEGIN return OLD; end; $$ + 1 | 0 | CREATE INDEX pt_stats2_index ON ONLY pt_stats2 (a) + 1 | 0 | CREATE INDEX pt_stats_index ON ONLY pt_stats (a) + 1 | 0 | CREATE POLICY policy_stats ON tab_policy_stats USING (a = 5) WITH CHECK (b < 5) + 1 | 0 | CREATE RULE rules_stats AS ON INSERT TO tab_rule_stats DO INSTEAD + + | | INSERT INTO tab_rule_stats_2 VALUES(new.*, 1, 2) + 1 | 0 | CREATE SERVER server_stats FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | CREATE STATISTICS tab_expr_stats_1 (mcv) ON a, (2*a), (3*b) FROM tab_expr_stats + 1 | 0 | CREATE TABLE pt_stats (a int, b int) PARTITION BY range (a) + 1 | 0 | CREATE TABLE pt_stats1 (a int, b int) + 1 | 0 | CREATE TABLE pt_stats2 PARTITION OF pt_stats FOR VALUES FROM (100) TO (200) + 1 | 0 | CREATE TABLE tab_expr_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_policy_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_rule_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_rule_stats_2 (a int, b int, 
c int, d int) + 1 | 0 | CREATE TABLE trigger_tab_stats (a int, b int) + 1 | 0 | CREATE TRIGGER trigger_tab_stats + + | | AFTER UPDATE ON trigger_tab_stats + + | | FOR EACH ROW WHEN (OLD.a < 0 AND OLD.b < 1 AND true) + + | | EXECUTE FUNCTION trigger_func_stats() + 1 | 0 | CREATE TYPE stats_type as (f1 numeric(35, 6), f2 numeric(35, 2)) + 1 | 0 | CREATE VIEW view_stats AS SELECT 1::int AS a, 2::int AS b + 1 | 0 | DROP FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | DROP FOREIGN TABLE foreign_stats + 1 | 0 | DROP FUNCTION func_stats + 1 | 0 | DROP RULE rules_stats ON tab_rule_stats + 1 | 0 | DROP SERVER server_stats + 1 | 0 | DROP TABLE pt_stats + 1 | 0 | DROP TABLE tab_expr_stats + 1 | 0 | DROP TABLE tab_policy_stats + 1 | 0 | DROP TABLE tab_rule_stats, tab_rule_stats_2 + 1 | 0 | DROP TABLE trigger_tab_stats + 1 | 0 | DROP TYPE stats_type + 1 | 0 | DROP VIEW view_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(39 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Transaction statements +BEGIN; +ABORT; +BEGIN; +ROLLBACK; +-- WORK +BEGIN WORK; +COMMIT WORK; +BEGIN WORK; +ABORT WORK; +-- TRANSACTION +BEGIN TRANSACTION; +COMMIT TRANSACTION; +BEGIN TRANSACTION; +ABORT TRANSACTION; +-- More isolation levels +BEGIN TRANSACTION DEFERRABLE; +COMMIT TRANSACTION AND NO CHAIN; +BEGIN ISOLATION LEVEL SERIALIZABLE; +COMMIT; +BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- List of A_Const nodes, same lists. 
+BEGIN TRANSACTION READ ONLY, READ WRITE, DEFERRABLE, NOT DEFERRABLE; +COMMIT; +BEGIN TRANSACTION NOT DEFERRABLE, READ ONLY, READ WRITE, DEFERRABLE; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------- + 4 | 0 | ABORT + 6 | 0 | BEGIN + 2 | 0 | BEGIN ISOLATION LEVEL SERIALIZABLE + 1 | 0 | BEGIN TRANSACTION DEFERRABLE + 1 | 0 | BEGIN TRANSACTION NOT DEFERRABLE, READ ONLY, READ WRITE, DEFERRABLE + 1 | 0 | BEGIN TRANSACTION READ ONLY, READ WRITE, DEFERRABLE, NOT DEFERRABLE + 7 | 0 | COMMIT WORK + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Two-phase transactions +BEGIN; +PREPARE TRANSACTION 'stat_trans1'; +COMMIT PREPARED 'stat_trans1'; +BEGIN; +PREPARE TRANSACTION 'stat_trans2'; +ROLLBACK PREPARED 'stat_trans2'; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 0 | BEGIN + 1 | 0 | COMMIT PREPARED 'stat_trans1' + 1 | 0 | PREPARE TRANSACTION 'stat_trans1' + 1 | 0 | PREPARE TRANSACTION 'stat_trans2' + 1 | 0 | ROLLBACK PREPARED 'stat_trans2' + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(6 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Savepoints +BEGIN; +SAVEPOINT sp1; +SAVEPOINT sp2; +SAVEPOINT sp3; +SAVEPOINT sp4; +ROLLBACK TO sp4; +ROLLBACK TO SAVEPOINT sp4; +ROLLBACK TRANSACTION TO SAVEPOINT sp3; +RELEASE sp3; +RELEASE SAVEPOINT sp2; +ROLLBACK TO sp1; +RELEASE SAVEPOINT sp1; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 1 | 0 | BEGIN + 1 | 0 | COMMIT + 1 | 0 | RELEASE SAVEPOINT sp1 + 1 | 0 | 
RELEASE SAVEPOINT sp2 + 1 | 0 | RELEASE sp3 + 1 | 0 | ROLLBACK TO sp1 + 2 | 0 | ROLLBACK TO sp4 + 1 | 0 | ROLLBACK TRANSACTION TO SAVEPOINT sp3 + 1 | 0 | SAVEPOINT sp1 + 1 | 0 | SAVEPOINT sp2 + 1 | 0 | SAVEPOINT sp3 + 1 | 0 | SAVEPOINT sp4 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(13 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- EXPLAIN statements +-- A Query is used, normalized by the query jumbling. +EXPLAIN (costs off) SELECT 1; + QUERY PLAN +------------ + Result +(1 row) + +EXPLAIN (costs off) SELECT 2; + QUERY PLAN +------------ + Result +(1 row) + +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 3; + QUERY PLAN +-------------------------------------- + Function Scan on generate_series tab + Filter: (a = 3) +(2 rows) + +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 7; + QUERY PLAN +-------------------------------------- + Function Scan on generate_series tab + Filter: (a = 7) +(2 rows) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------------------- + 2 | 0 | EXPLAIN (costs off) SELECT $1 + 2 | 0 | EXPLAIN (costs off) SELECT a FROM generate_series($1,$2) AS tab(a) WHERE a = $3 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +-- CALL +CREATE OR REPLACE PROCEDURE sum_one(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE sum_two(i int, j int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + j)::int INTO r; +END; $$ LANGUAGE plpgsql; +-- Overloaded functions. 
+CREATE OR REPLACE PROCEDURE overload(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE overload(i text) AS $$ +DECLARE + r text; +BEGIN + SELECT i::text INTO r; +END; $$ LANGUAGE plpgsql; +-- Mix of IN/OUT parameters. +CREATE OR REPLACE PROCEDURE in_out(i int, i2 OUT int, i3 INOUT int) AS $$ +DECLARE + r int; +BEGIN + i2 := i; + i3 := i3 + i; +END; $$ LANGUAGE plpgsql; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL sum_one(3); +CALL sum_one(199); +CALL sum_two(1,1); +CALL sum_two(1,2); +CALL overload(1); +CALL overload('A'); +CALL in_out(1, NULL, 1); + i2 | i3 +----+---- + 1 | 2 +(1 row) + +CALL in_out(2, 1, 2); + i2 | i3 +----+---- + 2 | 4 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 1 | 0 | CALL in_out(1, NULL, 1) + 1 | 0 | CALL in_out(2, 1, 2) + 1 | 0 | CALL overload('A') + 1 | 0 | CALL overload(1) + 1 | 0 | CALL sum_one(199) + 1 | 0 | CALL sum_one(3) + 1 | 0 | CALL sum_two(1,1) + 1 | 0 | CALL sum_two(1,2) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(9 rows) + +-- COPY +CREATE TABLE copy_stats (a int, b int); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Some queries with A_Const nodes. 
+COPY (SELECT 1) TO STDOUT; +1 +COPY (SELECT 2) TO STDOUT; +2 +COPY (INSERT INTO copy_stats VALUES (1, 1) RETURNING *) TO STDOUT; +1 1 +COPY (INSERT INTO copy_stats VALUES (2, 2) RETURNING *) TO STDOUT; +2 2 +COPY (UPDATE copy_stats SET b = b + 1 RETURNING *) TO STDOUT; +1 2 +2 3 +COPY (UPDATE copy_stats SET b = b + 2 RETURNING *) TO STDOUT; +1 4 +2 5 +COPY (DELETE FROM copy_stats WHERE a = 1 RETURNING *) TO STDOUT; +1 4 +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------- + 1 | 1 | COPY (DELETE FROM copy_stats WHERE a = 1 RETURNING *) TO STDOUT + 1 | 1 | COPY (INSERT INTO copy_stats VALUES (1, 1) RETURNING *) TO STDOUT + 1 | 1 | COPY (INSERT INTO copy_stats VALUES (2, 2) RETURNING *) TO STDOUT + 1 | 1 | COPY (SELECT 1) TO STDOUT + 1 | 1 | COPY (SELECT 2) TO STDOUT + 1 | 2 | COPY (UPDATE copy_stats SET b = b + 1 RETURNING *) TO STDOUT + 1 | 2 | COPY (UPDATE copy_stats SET b = b + 2 RETURNING *) TO STDOUT + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +DROP TABLE copy_stats; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE TABLE AS +-- SELECT queries are normalized, creating matching query IDs. 
+CREATE TABLE ctas_stats_1 AS SELECT 1 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_1 AS SELECT 2 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP TABLE ctas_stats_2; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 1; +DROP TABLE ctas_stats_2; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+-------------------------------------------------------------------- + 2 | 2 | CREATE TABLE ctas_stats_1 AS SELECT $1 AS a + 2 | 4 | CREATE TABLE ctas_stats_2 AS + + | | SELECT a AS col1, $1::int AS col2 + + | | FROM generate_series($2, $3) AS tab(a) WHERE a < $4 AND a > $5 + 2 | 0 | DROP TABLE ctas_stats_1 + 2 | 0 | DROP TABLE ctas_stats_2 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE MATERIALIZED VIEW +-- SELECT queries are normalized, creating matching query IDs. 
+CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP MATERIALIZED VIEW matview_stats_1; +CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP MATERIALIZED VIEW matview_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+-------------------------------------------------------------------- + 2 | 2 | CREATE MATERIALIZED VIEW matview_stats_1 AS + + | | SELECT a AS col1, $1::int AS col2 + + | | FROM generate_series($2, $3) AS tab(a) WHERE a < $4 AND a > $5 + 2 | 0 | DROP MATERIALIZED VIEW matview_stats_1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE VIEW +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP VIEW view_stats_1; +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP VIEW view_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------------------- + 1 | 0 | CREATE VIEW view_stats_1 AS + + | | SELECT a AS col1, 2::int AS col2 + + | | FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2 + 1 | 0 | CREATE VIEW view_stats_1 AS + + | | SELECT a AS col1, 4::int AS col2 + + | | FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3 + 2 | 0 | DROP VIEW view_stats_1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(4 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Domains +CREATE DOMAIN domain_stats AS int CHECK (VALUE > 0); +ALTER DOMAIN domain_stats 
SET DEFAULT '3'; +ALTER DOMAIN domain_stats ADD CONSTRAINT higher_than_one CHECK (VALUE > 1); +DROP DOMAIN domain_stats; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+---------------------------------------------------------------------------- + 1 | 0 | ALTER DOMAIN domain_stats ADD CONSTRAINT higher_than_one CHECK (VALUE > 1) + 1 | 0 | ALTER DOMAIN domain_stats SET DEFAULT '3' + 1 | 0 | CREATE DOMAIN domain_stats AS int CHECK (VALUE > 0) + 1 | 0 | DROP DOMAIN domain_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Execution statements +SELECT 1 as a; + a +--- + 1 +(1 row) + +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (1); + a +--- + 1 +(1 row) + +DEALLOCATE stat_select; +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (2); + a +--- + 2 +(1 row) + +DEALLOCATE PREPARE stat_select; +DEALLOCATE ALL; +DEALLOCATE PREPARE ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 0 | DEALLOCATE ALL + 2 | 0 | DEALLOCATE stat_select + 2 | 2 | PREPARE stat_select AS SELECT $1 AS a + 1 | 1 | SELECT $1 as a + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- SET statements. +-- These use two different strings, still they count as one entry. 
+CREATE ROLE regress_stat_set_1; +CREATE ROLE regress_stat_set_2; +SET work_mem = '1MB'; +Set work_mem = '1MB'; +SET work_mem = '2MB'; +SET work_mem = DEFAULT; +SET work_mem TO DEFAULT; +SET work_mem FROM CURRENT; +BEGIN; +SET LOCAL work_mem = '128kB'; +SET LOCAL work_mem = '256kB'; +SET LOCAL work_mem = DEFAULT; +SET LOCAL work_mem TO DEFAULT; +SET LOCAL work_mem FROM CURRENT; +COMMIT; +RESET work_mem; +SET enable_seqscan = off; +SET enable_seqscan = on; +SET SESSION work_mem = '300kB'; +SET SESSION work_mem = '400kB'; +RESET enable_seqscan; +-- SET TRANSACTION ISOLATION +BEGIN; +SET TRANSACTION ISOLATION LEVEL READ COMMITTED; +SET TRANSACTION ISOLATION LEVEL REPEATABLE READ; +SET TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- SET SESSION AUTHORIZATION +SET SESSION SESSION AUTHORIZATION DEFAULT; +SET SESSION AUTHORIZATION 'regress_stat_set_1'; +SET SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +BEGIN; +SET LOCAL SESSION AUTHORIZATION DEFAULT; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_1'; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +COMMIT; +-- SET SESSION CHARACTERISTICS +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE; +-- SET XML OPTION +SET XML OPTION DOCUMENT; +SET XML OPTION CONTENT; +-- SET TIME ZONE +SET TIME ZONE 'America/New_York'; +SET TIME ZONE 'Asia/Tokyo'; +SET TIME ZONE DEFAULT; +SET TIME ZONE LOCAL; +SET TIME ZONE 'CST7CDT,M4.1.0,M10.5.0'; +RESET TIME ZONE; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------ + 3 | 0 | BEGIN + 3 | 0 | COMMIT + 1 | 0 | CREATE ROLE regress_stat_set_1 + 1 | 0 | CREATE ROLE regress_stat_set_2 + 2 | 0 | RESET SESSION AUTHORIZATION + 1 | 0 | RESET TIME ZONE + 1 | 0 | 
RESET enable_seqscan + 1 | 0 | RESET work_mem + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_1' + 1 | 0 | SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_2' + 1 | 0 | SET LOCAL SESSION AUTHORIZATION DEFAULT + 1 | 0 | SET LOCAL work_mem = '128kB' + 1 | 0 | SET LOCAL work_mem = '256kB' + 2 | 0 | SET LOCAL work_mem = DEFAULT + 1 | 0 | SET LOCAL work_mem FROM CURRENT + 1 | 0 | SET SESSION AUTHORIZATION 'regress_stat_set_1' + 1 | 0 | SET SESSION AUTHORIZATION 'regress_stat_set_2' + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ ONLY + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE + 1 | 0 | SET SESSION SESSION AUTHORIZATION DEFAULT + 1 | 0 | SET SESSION work_mem = '300kB' + 1 | 0 | SET SESSION work_mem = '400kB' + 1 | 0 | SET TIME ZONE 'America/New_York' + 1 | 0 | SET TIME ZONE 'Asia/Tokyo' + 1 | 0 | SET TIME ZONE 'CST7CDT,M4.1.0,M10.5.0' + 2 | 0 | SET TIME ZONE DEFAULT + 1 | 0 | SET TRANSACTION ISOLATION LEVEL READ COMMITTED + 1 | 0 | SET TRANSACTION ISOLATION LEVEL REPEATABLE READ + 1 | 0 | SET TRANSACTION ISOLATION LEVEL SERIALIZABLE + 1 | 0 | SET XML OPTION CONTENT + 1 | 0 | SET XML OPTION DOCUMENT + 1 | 0 | SET enable_seqscan = off + 1 | 0 | SET enable_seqscan = on + 2 | 0 | SET work_mem = '1MB' + 1 | 0 | SET work_mem = '2MB' + 2 | 0 | SET work_mem = DEFAULT + 1 | 0 | SET work_mem FROM CURRENT +(39 rows) + +DROP ROLE regress_stat_set_1; +DROP ROLE regress_stat_set_2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- +-- Track the total number of rows retrieved or affected by the utility +-- commands of COPY, FETCH, CREATE TABLE AS, CREATE MATERIALIZED VIEW, +-- REFRESH MATERIALIZED VIEW and SELECT INTO +-- +CREATE TABLE pgss_ctas AS SELECT a, 'ctas' b FROM generate_series(1, 10) a; +SELECT generate_series(1, 10) c INTO pgss_select_into; +COPY 
pgss_ctas (a, b) FROM STDIN; +CREATE MATERIALIZED VIEW pgss_matv AS SELECT * FROM pgss_ctas; +REFRESH MATERIALIZED VIEW pgss_matv; +BEGIN; +DECLARE pgss_cursor CURSOR FOR SELECT * FROM pgss_matv; +FETCH NEXT pgss_cursor; + a | b +---+------ + 1 | ctas +(1 row) + +FETCH FORWARD 5 pgss_cursor; + a | b +---+------ + 2 | ctas + 3 | ctas + 4 | ctas + 5 | ctas + 6 | ctas +(5 rows) + +FETCH FORWARD ALL pgss_cursor; + a | b +----+------ + 7 | ctas + 8 | ctas + 9 | ctas + 10 | ctas + 11 | copy + 12 | copy + 13 | copy +(7 rows) + +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------------- + 1 | 0 | BEGIN + 1 | 0 | COMMIT + 1 | 3 | COPY pgss_ctas (a, b) FROM STDIN + 1 | 13 | CREATE MATERIALIZED VIEW pgss_matv AS SELECT * FROM pgss_ctas + 1 | 10 | CREATE TABLE pgss_ctas AS SELECT a, $1 b FROM generate_series($2, $3) a + 1 | 0 | DECLARE pgss_cursor CURSOR FOR SELECT * FROM pgss_matv + 1 | 5 | FETCH FORWARD 5 pgss_cursor + 1 | 7 | FETCH FORWARD ALL pgss_cursor + 1 | 1 | FETCH NEXT pgss_cursor + 1 | 13 | REFRESH MATERIALIZED VIEW pgss_matv + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 10 | SELECT generate_series($1, $2) c INTO pgss_select_into +(12 rows) + +DROP MATERIALIZED VIEW pgss_matv; +DROP TABLE pgss_ctas; +DROP TABLE pgss_select_into; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Special cases. Keep these ones at the end to avoid conflicts. 
+SET SCHEMA 'foo'; +SET SCHEMA 'public'; +RESET ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 1 | 0 | RESET ALL + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET SCHEMA 'foo' + 1 | 0 | SET SCHEMA 'public' +(4 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/utility.out.17 b/edb_stat_statements/expected/utility.out.17 new file mode 100644 index 00000000000..fc581c6d82b --- /dev/null +++ b/edb_stat_statements/expected/utility.out.17 @@ -0,0 +1,745 @@ +-- +-- Utility commands +-- +-- These tests require track_utility to be enabled. +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Tables, indexes, triggers +CREATE TEMP TABLE tab_stats (a int, b char(20)); +CREATE INDEX index_stats ON tab_stats(b, (b || 'data1'), (b || 'data2')) WHERE a > 0; +ALTER TABLE tab_stats ALTER COLUMN b set default 'a'; +ALTER TABLE tab_stats ALTER COLUMN b TYPE text USING 'data' || b; +ALTER TABLE tab_stats ADD CONSTRAINT a_nonzero CHECK (a <> 0); +DROP TABLE tab_stats \; +DROP TABLE IF EXISTS tab_stats \; +-- This DROP query uses two different strings, still they count as one entry. 
+DROP TABLE IF EXISTS tab_stats \; +Drop Table If Exists tab_stats \; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +NOTICE: table "tab_stats" does not exist, skipping +NOTICE: table "tab_stats" does not exist, skipping +NOTICE: table "tab_stats" does not exist, skipping + calls | rows | query +-------+------+-------------------------------------------------------------------------------------- + 1 | 0 | ALTER TABLE tab_stats ADD CONSTRAINT a_nonzero CHECK (a <> 0) + 1 | 0 | ALTER TABLE tab_stats ALTER COLUMN b TYPE text USING 'data' || b + 1 | 0 | ALTER TABLE tab_stats ALTER COLUMN b set default 'a' + 1 | 0 | CREATE INDEX index_stats ON tab_stats(b, (b || 'data1'), (b || 'data2')) WHERE a > 0 + 1 | 0 | CREATE TEMP TABLE tab_stats (a int, b char(20)) + 3 | 0 | DROP TABLE IF EXISTS tab_stats + 1 | 0 | DROP TABLE tab_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Partitions +CREATE TABLE pt_stats (a int, b int) PARTITION BY range (a); +CREATE TABLE pt_stats1 (a int, b int); +ALTER TABLE pt_stats ATTACH PARTITION pt_stats1 FOR VALUES FROM (0) TO (100); +CREATE TABLE pt_stats2 PARTITION OF pt_stats FOR VALUES FROM (100) TO (200); +CREATE INDEX pt_stats_index ON ONLY pt_stats (a); +CREATE INDEX pt_stats2_index ON ONLY pt_stats2 (a); +ALTER INDEX pt_stats_index ATTACH PARTITION pt_stats2_index; +DROP TABLE pt_stats; +-- Views +CREATE VIEW view_stats AS SELECT 1::int AS a, 2::int AS b; +ALTER VIEW view_stats ALTER COLUMN a SET DEFAULT 2; +DROP VIEW view_stats; +-- Foreign tables +CREATE FOREIGN DATA WRAPPER wrapper_stats; +CREATE SERVER server_stats FOREIGN DATA WRAPPER wrapper_stats; +CREATE FOREIGN TABLE foreign_stats (a int) SERVER server_stats; +ALTER FOREIGN TABLE foreign_stats ADD COLUMN b integer DEFAULT 1; +ALTER FOREIGN TABLE foreign_stats ADD CONSTRAINT b_nonzero CHECK (b <> 0); +DROP FOREIGN TABLE foreign_stats; 
+DROP SERVER server_stats; +DROP FOREIGN DATA WRAPPER wrapper_stats; +-- Functions +CREATE FUNCTION func_stats(a text DEFAULT 'a_data', b text DEFAULT lower('b_data')) + RETURNS text AS $$ SELECT $1::text || '_' || $2::text; $$ LANGUAGE SQL + SET work_mem = '256kB'; +DROP FUNCTION func_stats; +-- Rules +CREATE TABLE tab_rule_stats (a int, b int); +CREATE TABLE tab_rule_stats_2 (a int, b int, c int, d int); +CREATE RULE rules_stats AS ON INSERT TO tab_rule_stats DO INSTEAD + INSERT INTO tab_rule_stats_2 VALUES(new.*, 1, 2); +DROP RULE rules_stats ON tab_rule_stats; +DROP TABLE tab_rule_stats, tab_rule_stats_2; +-- Types +CREATE TYPE stats_type as (f1 numeric(35, 6), f2 numeric(35, 2)); +DROP TYPE stats_type; +-- Triggers +CREATE TABLE trigger_tab_stats (a int, b int); +CREATE FUNCTION trigger_func_stats () RETURNS trigger LANGUAGE plpgsql + AS $$ BEGIN return OLD; end; $$; +CREATE TRIGGER trigger_tab_stats + AFTER UPDATE ON trigger_tab_stats + FOR EACH ROW WHEN (OLD.a < 0 AND OLD.b < 1 AND true) + EXECUTE FUNCTION trigger_func_stats(); +DROP TABLE trigger_tab_stats; +-- Policies +CREATE TABLE tab_policy_stats (a int, b int); +CREATE POLICY policy_stats ON tab_policy_stats USING (a = 5) WITH CHECK (b < 5); +DROP TABLE tab_policy_stats; +-- Statistics +CREATE TABLE tab_expr_stats (a int, b int); +CREATE STATISTICS tab_expr_stats_1 (mcv) ON a, (2*a), (3*b) FROM tab_expr_stats; +DROP TABLE tab_expr_stats; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------------------------- + 1 | 0 | ALTER FOREIGN TABLE foreign_stats ADD COLUMN b integer DEFAULT 1 + 1 | 0 | ALTER FOREIGN TABLE foreign_stats ADD CONSTRAINT b_nonzero CHECK (b <> 0) + 1 | 0 | ALTER INDEX pt_stats_index ATTACH PARTITION pt_stats2_index + 1 | 0 | ALTER TABLE pt_stats ATTACH PARTITION pt_stats1 FOR VALUES FROM (0) TO (100) + 1 | 0 | ALTER VIEW view_stats ALTER COLUMN a SET 
DEFAULT 2 + 1 | 0 | CREATE FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | CREATE FOREIGN TABLE foreign_stats (a int) SERVER server_stats + 1 | 0 | CREATE FUNCTION func_stats(a text DEFAULT 'a_data', b text DEFAULT lower('b_data'))+ + | | RETURNS text AS $$ SELECT $1::text || '_' || $2::text; $$ LANGUAGE SQL + + | | SET work_mem = '256kB' + 1 | 0 | CREATE FUNCTION trigger_func_stats () RETURNS trigger LANGUAGE plpgsql + + | | AS $$ BEGIN return OLD; end; $$ + 1 | 0 | CREATE INDEX pt_stats2_index ON ONLY pt_stats2 (a) + 1 | 0 | CREATE INDEX pt_stats_index ON ONLY pt_stats (a) + 1 | 0 | CREATE POLICY policy_stats ON tab_policy_stats USING (a = 5) WITH CHECK (b < 5) + 1 | 0 | CREATE RULE rules_stats AS ON INSERT TO tab_rule_stats DO INSTEAD + + | | INSERT INTO tab_rule_stats_2 VALUES(new.*, 1, 2) + 1 | 0 | CREATE SERVER server_stats FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | CREATE STATISTICS tab_expr_stats_1 (mcv) ON a, (2*a), (3*b) FROM tab_expr_stats + 1 | 0 | CREATE TABLE pt_stats (a int, b int) PARTITION BY range (a) + 1 | 0 | CREATE TABLE pt_stats1 (a int, b int) + 1 | 0 | CREATE TABLE pt_stats2 PARTITION OF pt_stats FOR VALUES FROM (100) TO (200) + 1 | 0 | CREATE TABLE tab_expr_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_policy_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_rule_stats (a int, b int) + 1 | 0 | CREATE TABLE tab_rule_stats_2 (a int, b int, c int, d int) + 1 | 0 | CREATE TABLE trigger_tab_stats (a int, b int) + 1 | 0 | CREATE TRIGGER trigger_tab_stats + + | | AFTER UPDATE ON trigger_tab_stats + + | | FOR EACH ROW WHEN (OLD.a < 0 AND OLD.b < 1 AND true) + + | | EXECUTE FUNCTION trigger_func_stats() + 1 | 0 | CREATE TYPE stats_type as (f1 numeric(35, 6), f2 numeric(35, 2)) + 1 | 0 | CREATE VIEW view_stats AS SELECT 1::int AS a, 2::int AS b + 1 | 0 | DROP FOREIGN DATA WRAPPER wrapper_stats + 1 | 0 | DROP FOREIGN TABLE foreign_stats + 1 | 0 | DROP FUNCTION func_stats + 1 | 0 | DROP RULE rules_stats ON tab_rule_stats + 1 | 0 | DROP SERVER 
server_stats + 1 | 0 | DROP TABLE pt_stats + 1 | 0 | DROP TABLE tab_expr_stats + 1 | 0 | DROP TABLE tab_policy_stats + 1 | 0 | DROP TABLE tab_rule_stats, tab_rule_stats_2 + 1 | 0 | DROP TABLE trigger_tab_stats + 1 | 0 | DROP TYPE stats_type + 1 | 0 | DROP VIEW view_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(39 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Transaction statements +BEGIN; +ABORT; +BEGIN; +ROLLBACK; +-- WORK +BEGIN WORK; +COMMIT WORK; +BEGIN WORK; +ABORT WORK; +-- TRANSACTION +BEGIN TRANSACTION; +COMMIT TRANSACTION; +BEGIN TRANSACTION; +ABORT TRANSACTION; +-- More isolation levels +BEGIN TRANSACTION DEFERRABLE; +COMMIT TRANSACTION AND NO CHAIN; +BEGIN ISOLATION LEVEL SERIALIZABLE; +COMMIT; +BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- List of A_Const nodes, same lists. +BEGIN TRANSACTION READ ONLY, READ WRITE, DEFERRABLE, NOT DEFERRABLE; +COMMIT; +BEGIN TRANSACTION NOT DEFERRABLE, READ ONLY, READ WRITE, DEFERRABLE; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------- + 4 | 0 | ABORT + 6 | 0 | BEGIN + 2 | 0 | BEGIN ISOLATION LEVEL SERIALIZABLE + 1 | 0 | BEGIN TRANSACTION DEFERRABLE + 1 | 0 | BEGIN TRANSACTION NOT DEFERRABLE, READ ONLY, READ WRITE, DEFERRABLE + 1 | 0 | BEGIN TRANSACTION READ ONLY, READ WRITE, DEFERRABLE, NOT DEFERRABLE + 7 | 0 | COMMIT WORK + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Two-phase transactions +BEGIN; +PREPARE TRANSACTION 'stat_trans1'; +COMMIT PREPARED 'stat_trans1'; +BEGIN; +PREPARE TRANSACTION 'stat_trans2'; +ROLLBACK PREPARED 'stat_trans2'; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query 
+-------+------+----------------------------------------------------- + 2 | 0 | BEGIN + 1 | 0 | COMMIT PREPARED $1 + 2 | 0 | PREPARE TRANSACTION $1 + 1 | 0 | ROLLBACK PREPARED $1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Savepoints +BEGIN; +SAVEPOINT sp1; +SAVEPOINT sp2; +SAVEPOINT sp3; +SAVEPOINT sp4; +ROLLBACK TO sp4; +ROLLBACK TO SAVEPOINT sp4; +ROLLBACK TRANSACTION TO SAVEPOINT sp3; +RELEASE sp3; +RELEASE SAVEPOINT sp2; +ROLLBACK TO sp1; +RELEASE SAVEPOINT sp1; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 1 | 0 | BEGIN + 1 | 0 | COMMIT + 3 | 0 | RELEASE $1 + 4 | 0 | ROLLBACK TO $1 + 4 | 0 | SAVEPOINT $1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(6 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- EXPLAIN statements +-- A Query is used, normalized by the query jumbling. 
+EXPLAIN (costs off) SELECT 1; + QUERY PLAN +------------ + Result +(1 row) + +EXPLAIN (costs off) SELECT 2; + QUERY PLAN +------------ + Result +(1 row) + +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 3; + QUERY PLAN +-------------------------------------- + Function Scan on generate_series tab + Filter: (a = 3) +(2 rows) + +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 7; + QUERY PLAN +-------------------------------------- + Function Scan on generate_series tab + Filter: (a = 7) +(2 rows) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+--------------------------------------------------------------------------------- + 2 | 0 | EXPLAIN (costs off) SELECT $1 + 2 | 0 | EXPLAIN (costs off) SELECT a FROM generate_series($1,$2) AS tab(a) WHERE a = $3 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +-- CALL +CREATE OR REPLACE PROCEDURE sum_one(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE sum_two(i int, j int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + j)::int INTO r; +END; $$ LANGUAGE plpgsql; +-- Overloaded functions. +CREATE OR REPLACE PROCEDURE overload(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE overload(i text) AS $$ +DECLARE + r text; +BEGIN + SELECT i::text INTO r; +END; $$ LANGUAGE plpgsql; +-- Mix of IN/OUT parameters. 
+CREATE OR REPLACE PROCEDURE in_out(i int, i2 OUT int, i3 INOUT int) AS $$ +DECLARE + r int; +BEGIN + i2 := i; + i3 := i3 + i; +END; $$ LANGUAGE plpgsql; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +CALL sum_one(3); +CALL sum_one(199); +CALL sum_two(1,1); +CALL sum_two(1,2); +CALL overload(1); +CALL overload('A'); +CALL in_out(1, NULL, 1); + i2 | i3 +----+---- + 1 | 2 +(1 row) + +CALL in_out(2, 1, 2); + i2 | i3 +----+---- + 2 | 4 +(1 row) + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 0 | CALL in_out($1, $2, $3) + 1 | 0 | CALL overload($1) + 1 | 0 | CALL overload($1) + 2 | 0 | CALL sum_one($1) + 2 | 0 | CALL sum_two($1,$2) + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(6 rows) + +-- COPY +CREATE TABLE copy_stats (a int, b int); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Some queries with A_Const nodes. 
+COPY (SELECT 1) TO STDOUT; +1 +COPY (SELECT 2) TO STDOUT; +2 +COPY (INSERT INTO copy_stats VALUES (1, 1) RETURNING *) TO STDOUT; +1 1 +COPY (INSERT INTO copy_stats VALUES (2, 2) RETURNING *) TO STDOUT; +2 2 +COPY (UPDATE copy_stats SET b = b + 1 RETURNING *) TO STDOUT; +1 2 +2 3 +COPY (UPDATE copy_stats SET b = b + 2 RETURNING *) TO STDOUT; +1 4 +2 5 +COPY (DELETE FROM copy_stats WHERE a = 1 RETURNING *) TO STDOUT; +1 4 +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------- + 1 | 1 | COPY (DELETE FROM copy_stats WHERE a = 1 RETURNING *) TO STDOUT + 1 | 1 | COPY (INSERT INTO copy_stats VALUES (1, 1) RETURNING *) TO STDOUT + 1 | 1 | COPY (INSERT INTO copy_stats VALUES (2, 2) RETURNING *) TO STDOUT + 1 | 1 | COPY (SELECT 1) TO STDOUT + 1 | 1 | COPY (SELECT 2) TO STDOUT + 1 | 2 | COPY (UPDATE copy_stats SET b = b + 1 RETURNING *) TO STDOUT + 1 | 2 | COPY (UPDATE copy_stats SET b = b + 2 RETURNING *) TO STDOUT + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(8 rows) + +DROP TABLE copy_stats; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE TABLE AS +-- SELECT queries are normalized, creating matching query IDs. 
+CREATE TABLE ctas_stats_1 AS SELECT 1 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_1 AS SELECT 2 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP TABLE ctas_stats_2; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 1; +DROP TABLE ctas_stats_2; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+-------------------------------------------------------------------- + 2 | 2 | CREATE TABLE ctas_stats_1 AS SELECT $1 AS a + 2 | 4 | CREATE TABLE ctas_stats_2 AS + + | | SELECT a AS col1, $1::int AS col2 + + | | FROM generate_series($2, $3) AS tab(a) WHERE a < $4 AND a > $5 + 2 | 0 | DROP TABLE ctas_stats_1 + 2 | 0 | DROP TABLE ctas_stats_2 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE MATERIALIZED VIEW +-- SELECT queries are normalized, creating matching query IDs. 
+CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP MATERIALIZED VIEW matview_stats_1; +CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP MATERIALIZED VIEW matview_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+-------------------------------------------------------------------- + 2 | 2 | CREATE MATERIALIZED VIEW matview_stats_1 AS + + | | SELECT a AS col1, $1::int AS col2 + + | | FROM generate_series($2, $3) AS tab(a) WHERE a < $4 AND a > $5 + 2 | 0 | DROP MATERIALIZED VIEW matview_stats_1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(3 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- CREATE VIEW +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP VIEW view_stats_1; +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP VIEW view_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------------------- + 1 | 0 | CREATE VIEW view_stats_1 AS + + | | SELECT a AS col1, 2::int AS col2 + + | | FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2 + 1 | 0 | CREATE VIEW view_stats_1 AS + + | | SELECT a AS col1, 4::int AS col2 + + | | FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3 + 2 | 0 | DROP VIEW view_stats_1 + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(4 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Domains +CREATE DOMAIN domain_stats AS int CHECK (VALUE > 0); +ALTER DOMAIN domain_stats 
SET DEFAULT '3'; +ALTER DOMAIN domain_stats ADD CONSTRAINT higher_than_one CHECK (VALUE > 1); +DROP DOMAIN domain_stats; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+---------------------------------------------------------------------------- + 1 | 0 | ALTER DOMAIN domain_stats ADD CONSTRAINT higher_than_one CHECK (VALUE > 1) + 1 | 0 | ALTER DOMAIN domain_stats SET DEFAULT '3' + 1 | 0 | CREATE DOMAIN domain_stats AS int CHECK (VALUE > 0) + 1 | 0 | DROP DOMAIN domain_stats + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Execution statements +SELECT 1 as a; + a +--- + 1 +(1 row) + +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (1); + a +--- + 1 +(1 row) + +DEALLOCATE stat_select; +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (2); + a +--- + 2 +(1 row) + +DEALLOCATE PREPARE stat_select; +DEALLOCATE ALL; +DEALLOCATE PREPARE ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 2 | 0 | DEALLOCATE $1 + 2 | 0 | DEALLOCATE ALL + 2 | 2 | PREPARE stat_select AS SELECT $1 AS a + 1 | 1 | SELECT $1 as a + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- SET statements. +-- These use two different strings, still they count as one entry. 
+CREATE ROLE regress_stat_set_1; +CREATE ROLE regress_stat_set_2; +SET work_mem = '1MB'; +Set work_mem = '1MB'; +SET work_mem = '2MB'; +SET work_mem = DEFAULT; +SET work_mem TO DEFAULT; +SET work_mem FROM CURRENT; +BEGIN; +SET LOCAL work_mem = '128kB'; +SET LOCAL work_mem = '256kB'; +SET LOCAL work_mem = DEFAULT; +SET LOCAL work_mem TO DEFAULT; +SET LOCAL work_mem FROM CURRENT; +COMMIT; +RESET work_mem; +SET enable_seqscan = off; +SET enable_seqscan = on; +SET SESSION work_mem = '300kB'; +SET SESSION work_mem = '400kB'; +RESET enable_seqscan; +-- SET TRANSACTION ISOLATION +BEGIN; +SET TRANSACTION ISOLATION LEVEL READ COMMITTED; +SET TRANSACTION ISOLATION LEVEL REPEATABLE READ; +SET TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- SET SESSION AUTHORIZATION +SET SESSION SESSION AUTHORIZATION DEFAULT; +SET SESSION AUTHORIZATION 'regress_stat_set_1'; +SET SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +BEGIN; +SET LOCAL SESSION AUTHORIZATION DEFAULT; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_1'; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +COMMIT; +-- SET SESSION CHARACTERISTICS +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE; +-- SET XML OPTION +SET XML OPTION DOCUMENT; +SET XML OPTION CONTENT; +-- SET TIME ZONE +SET TIME ZONE 'America/New_York'; +SET TIME ZONE 'Asia/Tokyo'; +SET TIME ZONE DEFAULT; +SET TIME ZONE LOCAL; +SET TIME ZONE 'CST7CDT,M4.1.0,M10.5.0'; +RESET TIME ZONE; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------ + 3 | 0 | BEGIN + 3 | 0 | COMMIT + 1 | 0 | CREATE ROLE regress_stat_set_1 + 1 | 0 | CREATE ROLE regress_stat_set_2 + 2 | 0 | RESET SESSION AUTHORIZATION + 1 | 0 | RESET TIME ZONE + 1 | 0 | 
RESET enable_seqscan + 1 | 0 | RESET work_mem + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_1' + 1 | 0 | SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_2' + 1 | 0 | SET LOCAL SESSION AUTHORIZATION DEFAULT + 1 | 0 | SET LOCAL work_mem = '128kB' + 1 | 0 | SET LOCAL work_mem = '256kB' + 2 | 0 | SET LOCAL work_mem = DEFAULT + 1 | 0 | SET LOCAL work_mem FROM CURRENT + 1 | 0 | SET SESSION AUTHORIZATION 'regress_stat_set_1' + 1 | 0 | SET SESSION AUTHORIZATION 'regress_stat_set_2' + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ ONLY + 1 | 0 | SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE + 1 | 0 | SET SESSION SESSION AUTHORIZATION DEFAULT + 1 | 0 | SET SESSION work_mem = '300kB' + 1 | 0 | SET SESSION work_mem = '400kB' + 1 | 0 | SET TIME ZONE 'America/New_York' + 1 | 0 | SET TIME ZONE 'Asia/Tokyo' + 1 | 0 | SET TIME ZONE 'CST7CDT,M4.1.0,M10.5.0' + 2 | 0 | SET TIME ZONE DEFAULT + 1 | 0 | SET TRANSACTION ISOLATION LEVEL READ COMMITTED + 1 | 0 | SET TRANSACTION ISOLATION LEVEL REPEATABLE READ + 1 | 0 | SET TRANSACTION ISOLATION LEVEL SERIALIZABLE + 1 | 0 | SET XML OPTION CONTENT + 1 | 0 | SET XML OPTION DOCUMENT + 1 | 0 | SET enable_seqscan = off + 1 | 0 | SET enable_seqscan = on + 2 | 0 | SET work_mem = '1MB' + 1 | 0 | SET work_mem = '2MB' + 2 | 0 | SET work_mem = DEFAULT + 1 | 0 | SET work_mem FROM CURRENT +(39 rows) + +DROP ROLE regress_stat_set_1; +DROP ROLE regress_stat_set_2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- +-- Track the total number of rows retrieved or affected by the utility +-- commands of COPY, FETCH, CREATE TABLE AS, CREATE MATERIALIZED VIEW, +-- REFRESH MATERIALIZED VIEW and SELECT INTO +-- +CREATE TABLE pgss_ctas AS SELECT a, 'ctas' b FROM generate_series(1, 10) a; +SELECT generate_series(1, 10) c INTO pgss_select_into; +COPY 
pgss_ctas (a, b) FROM STDIN; +CREATE MATERIALIZED VIEW pgss_matv AS SELECT * FROM pgss_ctas; +REFRESH MATERIALIZED VIEW pgss_matv; +BEGIN; +DECLARE pgss_cursor CURSOR FOR SELECT * FROM pgss_matv; +FETCH NEXT pgss_cursor; + a | b +---+------ + 1 | ctas +(1 row) + +FETCH FORWARD 5 pgss_cursor; + a | b +---+------ + 2 | ctas + 3 | ctas + 4 | ctas + 5 | ctas + 6 | ctas +(5 rows) + +FETCH FORWARD ALL pgss_cursor; + a | b +----+------ + 7 | ctas + 8 | ctas + 9 | ctas + 10 | ctas + 11 | copy + 12 | copy + 13 | copy +(7 rows) + +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+------------------------------------------------------------------------- + 1 | 0 | BEGIN + 1 | 0 | COMMIT + 1 | 3 | COPY pgss_ctas (a, b) FROM STDIN + 1 | 13 | CREATE MATERIALIZED VIEW pgss_matv AS SELECT * FROM pgss_ctas + 1 | 10 | CREATE TABLE pgss_ctas AS SELECT a, $1 b FROM generate_series($2, $3) a + 1 | 0 | DECLARE pgss_cursor CURSOR FOR SELECT * FROM pgss_matv + 1 | 5 | FETCH FORWARD 5 pgss_cursor + 1 | 7 | FETCH FORWARD ALL pgss_cursor + 1 | 1 | FETCH NEXT pgss_cursor + 1 | 13 | REFRESH MATERIALIZED VIEW pgss_matv + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 10 | SELECT generate_series($1, $2) c INTO pgss_select_into +(12 rows) + +DROP MATERIALIZED VIEW pgss_matv; +DROP TABLE pgss_ctas; +DROP TABLE pgss_select_into; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + +-- Special cases. Keep these ones at the end to avoid conflicts. 
+SET SCHEMA 'foo'; +SET SCHEMA 'public'; +RESET ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + calls | rows | query +-------+------+----------------------------------------------------- + 1 | 0 | RESET ALL + 1 | 1 | SELECT edb_stat_statements_reset() IS NOT NULL AS t + 1 | 0 | SET SCHEMA 'foo' + 1 | 0 | SET SCHEMA 'public' +(4 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/wal.out.17 b/edb_stat_statements/expected/wal.out.17 new file mode 100644 index 00000000000..c12598ef707 --- /dev/null +++ b/edb_stat_statements/expected/wal.out.17 @@ -0,0 +1,30 @@ +-- +-- Validate WAL generation metrics +-- +SET edb_stat_statements.track_utility = FALSE; +CREATE TABLE pgss_wal_tab (a int, b char(20)); +INSERT INTO pgss_wal_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_wal_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM pgss_wal_tab WHERE a > 9; +DROP TABLE pgss_wal_tab; +-- Check WAL is generated for the above statements +SELECT query, calls, rows, +wal_bytes > 0 as wal_bytes_generated, +wal_records > 0 as wal_records_generated, +wal_records >= rows as wal_records_ge_rows +FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows | wal_bytes_generated | wal_records_generated | wal_records_ge_rows +--------------------------------------------------------------+-------+------+---------------------+-----------------------+--------------------- + DELETE FROM pgss_wal_tab WHERE a > $1 | 1 | 1 | t | t | t + INSERT INTO pgss_wal_tab VALUES(generate_series($1, $2), $3) | 1 | 10 | t | t | t + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 | f | f | f + SET edb_stat_statements.track_utility = FALSE | 1 | 0 | f | f | t + UPDATE pgss_wal_tab SET b = $1 WHERE a > $2 | 1 | 3 | t | t | t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/expected/wal.out.18 
b/edb_stat_statements/expected/wal.out.18 new file mode 100644 index 00000000000..b4c255cdf78 --- /dev/null +++ b/edb_stat_statements/expected/wal.out.18 @@ -0,0 +1,30 @@ +-- +-- Validate WAL generation metrics +-- +SET edb_stat_statements.track_utility = FALSE; +CREATE TABLE pgss_wal_tab (a int, b char(20)); +INSERT INTO pgss_wal_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_wal_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM pgss_wal_tab WHERE a > 9; +DROP TABLE pgss_wal_tab; +-- Check WAL is generated for the above statements +SELECT query, calls, rows, +wal_bytes > 0 as wal_bytes_generated, +wal_records > 0 as wal_records_generated, +wal_records >= rows as wal_records_ge_rows +FROM edb_stat_statements ORDER BY query COLLATE "C"; + query | calls | rows | wal_bytes_generated | wal_records_generated | wal_records_ge_rows +--------------------------------------------------------------+-------+------+---------------------+-----------------------+--------------------- + DELETE FROM pgss_wal_tab WHERE a > $1 | 1 | 1 | t | t | t + INSERT INTO pgss_wal_tab VALUES(generate_series($1, $2), $3) | 1 | 10 | t | t | t + SELECT edb_stat_statements_reset() IS NOT NULL AS t | 1 | 1 | f | f | f + SET edb_stat_statements.track_utility = $1 | 1 | 0 | f | f | t + UPDATE pgss_wal_tab SET b = $1 WHERE a > $2 | 1 | 3 | t | t | t +(5 rows) + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + t +--- + t +(1 row) + diff --git a/edb_stat_statements/sql/cleanup.sql b/edb_stat_statements/sql/cleanup.sql new file mode 100644 index 00000000000..03e40380b87 --- /dev/null +++ b/edb_stat_statements/sql/cleanup.sql @@ -0,0 +1 @@ +DROP EXTENSION edb_stat_statements; diff --git a/edb_stat_statements/sql/cursors.sql b/edb_stat_statements/sql/cursors.sql new file mode 100644 index 00000000000..2c0b637d488 --- /dev/null +++ b/edb_stat_statements/sql/cursors.sql @@ -0,0 +1,30 @@ +-- +-- Cursors +-- + +-- These tests require track_utility to be enabled. 
+SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- DECLARE +-- SELECT is normalized. +DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 1; +CLOSE cursor_stats_1; +DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 2; +CLOSE cursor_stats_1; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- FETCH +BEGIN; +DECLARE cursor_stats_1 CURSOR WITH HOLD FOR SELECT 2; +DECLARE cursor_stats_2 CURSOR WITH HOLD FOR SELECT 3; +FETCH 1 IN cursor_stats_1; +FETCH 1 IN cursor_stats_2; +CLOSE cursor_stats_1; +CLOSE cursor_stats_2; +COMMIT; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/dml.sql b/edb_stat_statements/sql/dml.sql new file mode 100644 index 00000000000..bf413e0e00b --- /dev/null +++ b/edb_stat_statements/sql/dml.sql @@ -0,0 +1,95 @@ +-- +-- DMLs on test table +-- + +SET edb_stat_statements.track_utility = FALSE; + +CREATE TEMP TABLE pgss_dml_tab (a int, b char(20)); + +INSERT INTO pgss_dml_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_dml_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM pgss_dml_tab WHERE a > 9; + +-- explicit transaction +BEGIN; +UPDATE pgss_dml_tab SET b = '111' WHERE a = 1 ; +COMMIT; + +BEGIN \; +UPDATE pgss_dml_tab SET b = '222' WHERE a = 2 \; +COMMIT ; + +UPDATE pgss_dml_tab SET b = '333' WHERE a = 3 \; +UPDATE pgss_dml_tab SET b = '444' WHERE a = 4 ; + +BEGIN \; +UPDATE pgss_dml_tab SET b = '555' WHERE a = 5 \; +UPDATE pgss_dml_tab SET b = '666' WHERE a = 6 \; +COMMIT ; + +-- many INSERT values +INSERT INTO pgss_dml_tab (a, b) VALUES (1, 'a'), (2, 'b'), (3, 'c'); + +-- SELECT with constants +SELECT * FROM pgss_dml_tab WHERE a > 5 ORDER BY a ; + +SELECT * + FROM pgss_dml_tab + WHERE a > 9 + ORDER BY a ; + +-- these two need to be done on a different table +-- SELECT without 
constants +SELECT * FROM pgss_dml_tab ORDER BY a; + +-- SELECT with IN clause +SELECT * FROM pgss_dml_tab WHERE a IN (1, 2, 3, 4, 5); + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- MERGE +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = st.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED AND length(st.b) > 1 THEN UPDATE SET b = pgss_dml_tab.b || st.a::text; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a, b) VALUES (0, NULL); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT VALUES (0, NULL); -- same as above +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (b, a) VALUES (NULL, 0); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a) + WHEN NOT MATCHED THEN INSERT (a) VALUES (0); +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DELETE; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN MATCHED THEN DO NOTHING; +MERGE INTO pgss_dml_tab USING pgss_dml_tab st ON (st.a = pgss_dml_tab.a AND st.a >= 4) + WHEN NOT MATCHED THEN DO NOTHING; + +DROP TABLE pgss_dml_tab; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- check that [temp] table relation extensions are tracked as writes +CREATE TABLE pgss_extend_tab (a int, b text); +CREATE TEMP TABLE pgss_extend_temp_tab (a int, b text); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +INSERT INTO 
pgss_extend_tab (a, b) SELECT generate_series(1, 1000), 'something'; +INSERT INTO pgss_extend_temp_tab (a, b) SELECT generate_series(1, 1000), 'something'; +WITH sizes AS ( + SELECT + pg_relation_size('pgss_extend_tab') / current_setting('block_size')::int8 AS rel_size, + pg_relation_size('pgss_extend_temp_tab') / current_setting('block_size')::int8 AS temp_rel_size +) +SELECT + SUM(local_blks_written) >= (SELECT temp_rel_size FROM sizes) AS temp_written_ok, + SUM(local_blks_dirtied) >= (SELECT temp_rel_size FROM sizes) AS temp_dirtied_ok, + SUM(shared_blks_written) >= (SELECT rel_size FROM sizes) AS written_ok, + SUM(shared_blks_dirtied) >= (SELECT rel_size FROM sizes) AS dirtied_ok +FROM edb_stat_statements; + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/entry_timestamp.sql b/edb_stat_statements/sql/entry_timestamp.sql new file mode 100644 index 00000000000..2e2b096f413 --- /dev/null +++ b/edb_stat_statements/sql/entry_timestamp.sql @@ -0,0 +1,114 @@ +-- +-- statement timestamps +-- + +-- planning time is needed during tests +SET edb_stat_statements.track_planning = TRUE; + +SELECT 1 AS "STMTTS1"; +SELECT now() AS ref_ts \gset +SELECT 1,2 AS "STMTTS2"; +SELECT stats_since >= :'ref_ts', count(*) FROM edb_stat_statements +WHERE query LIKE '%STMTTS%' +GROUP BY stats_since >= :'ref_ts' +ORDER BY stats_since >= :'ref_ts'; + +SELECT now() AS ref_ts \gset +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_stats_since_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + +-- Perform single min/max reset +SELECT edb_stat_statements_reset(0, '{}', queryid, true) AS minmax_reset_ts +FROM edb_stat_statements 
+WHERE query LIKE '%STMTTS1%' \gset + +-- check +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_stats_since_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + +-- check minmax reset timestamps +SELECT +query, minmax_stats_since = :'minmax_reset_ts' AS reset_ts_match +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%' +ORDER BY query COLLATE "C"; + +-- check that minmax reset does not set stats_reset +SELECT +stats_reset = :'minmax_reset_ts' AS stats_reset_ts_match +FROM edb_stat_statements_info; + +-- Perform common min/max reset +SELECT edb_stat_statements_reset(0, '{}', 0, true) AS minmax_reset_ts \gset + +-- check again +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_plan_time + max_plan_time = 0 + ) as minmax_plan_zero, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_ts_after_ref, + count(*) FILTER ( + WHERE minmax_stats_since = :'minmax_reset_ts' + ) as minmax_ts_match, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE '%STMTTS%'; + +-- Execute first query once more to check stats update +SELECT 1 AS "STMTTS1"; + +-- check +-- we don't check planing times here to be independent of +-- plan caching approach +SELECT + count(*) as total, + count(*) FILTER ( + WHERE min_exec_time + max_exec_time = 0 + ) as minmax_exec_zero, + count(*) FILTER ( + WHERE minmax_stats_since >= :'ref_ts' + ) as minmax_ts_after_ref, + count(*) FILTER ( + WHERE stats_since >= :'ref_ts' + ) as stats_since_after_ref +FROM edb_stat_statements +WHERE query LIKE 
'%STMTTS%'; + +-- Cleanup +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/extended.sql b/edb_stat_statements/sql/extended.sql new file mode 100644 index 00000000000..83ca79fa7d5 --- /dev/null +++ b/edb_stat_statements/sql/extended.sql @@ -0,0 +1,21 @@ +-- Tests with extended query protocol + +SET edb_stat_statements.track_utility = FALSE; + +-- This test checks that an execute message sets a query ID. +SELECT query_id IS NOT NULL AS query_id_set + FROM pg_stat_activity WHERE pid = pg_backend_pid() \bind \g + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +SELECT $1 \parse stmt1 +SELECT $1, $2 \parse stmt2 +SELECT $1, $2, $3 \parse stmt3 +SELECT $1 \bind 'unnamed_val1' \g +\bind_named stmt1 'stmt1_val1' \g +\bind_named stmt2 'stmt2_val1' 'stmt2_val2' \g +\bind_named stmt3 'stmt3_val1' 'stmt3_val2' 'stmt3_val3' \g +\bind_named stmt3 'stmt3_val4' 'stmt3_val5' 'stmt3_val6' \g +\bind_named stmt2 'stmt2_val3' 'stmt2_val4' \g +\bind_named stmt1 'stmt1_val1' \g + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; diff --git a/edb_stat_statements/sql/level_tracking.sql b/edb_stat_statements/sql/level_tracking.sql new file mode 100644 index 00000000000..9cb852e8e20 --- /dev/null +++ b/edb_stat_statements/sql/level_tracking.sql @@ -0,0 +1,173 @@ +-- +-- Statement level tracking +-- + +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- DO block - top-level tracking. +CREATE TABLE stats_track_tab (x int); +SET edb_stat_statements.track = 'top'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; +$$ LANGUAGE plpgsql; +SELECT toplevel, calls, query FROM edb_stat_statements + WHERE query LIKE '%DELETE%' ORDER BY query COLLATE "C", toplevel; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- DO block - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + +-- Procedure with multiple utility statements. +CREATE OR REPLACE PROCEDURE proc_with_utility_stmt() +LANGUAGE SQL +AS $$ + SHOW edb_stat_statements.track; + show edb_stat_statements.track; + SHOW edb_stat_statements.track_utility; +$$; +SET edb_stat_statements.track_utility = TRUE; +-- all-level tracking. +SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; +-- top-level tracking. +SET edb_stat_statements.track = 'top'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +CALL proc_with_utility_stmt(); +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + +-- DO block - top-level tracking without utility. +SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + +-- DO block - all-level tracking without utility. 
+SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +DELETE FROM stats_track_tab; +DO $$ +BEGIN + DELETE FROM stats_track_tab; +END; $$; +DO LANGUAGE plpgsql $$ +BEGIN + -- this is a SELECT + PERFORM 'hello world'::TEXT; +END; $$; +SELECT toplevel, calls, query FROM edb_stat_statements + ORDER BY query COLLATE "C", toplevel; + +-- PL/pgSQL function - top-level tracking. +SET edb_stat_statements.track = 'top'; +SET edb_stat_statements.track_utility = FALSE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; + +SELECT PLUS_TWO(3); +SELECT PLUS_TWO(7); + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; + +SELECT PLUS_ONE(8); +SELECT PLUS_ONE(10); + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; + +SELECT PLUS_THREE(8); +SELECT PLUS_THREE(10); + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- PL/pgSQL function - all-level tracking. 
+SET edb_stat_statements.track = 'all'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- we drop and recreate the functions to avoid any caching funnies +DROP FUNCTION PLUS_ONE(INTEGER); +DROP FUNCTION PLUS_TWO(INTEGER); +DROP FUNCTION PLUS_THREE(INTEGER); + +-- PL/pgSQL function +CREATE FUNCTION PLUS_TWO(i INTEGER) RETURNS INTEGER AS $$ +DECLARE + r INTEGER; +BEGIN + SELECT (i + 1 + 1.0)::INTEGER INTO r; + RETURN r; +END; $$ LANGUAGE plpgsql; + +SELECT PLUS_TWO(-1); +SELECT PLUS_TWO(2); + +-- SQL function --- use LIMIT to keep it from being inlined +CREATE FUNCTION PLUS_ONE(i INTEGER) RETURNS INTEGER AS +$$ SELECT (i + 1.0)::INTEGER LIMIT 1 $$ LANGUAGE SQL; + +SELECT PLUS_ONE(3); +SELECT PLUS_ONE(1); + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- immutable SQL function --- can be executed at plan time +CREATE FUNCTION PLUS_THREE(i INTEGER) RETURNS INTEGER AS +$$ SELECT i + 3 LIMIT 1 $$ IMMUTABLE LANGUAGE SQL; + +SELECT PLUS_THREE(8); +SELECT PLUS_THREE(10); + +SELECT toplevel, calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- edb_stat_statements.track = none +-- +SET edb_stat_statements.track = 'none'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +SELECT 1 AS "one"; +SELECT 1 + 1 AS "two"; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/oldextversions.sql b/edb_stat_statements/sql/oldextversions.sql new file mode 100644 index 00000000000..078101ab8ee --- /dev/null +++ b/edb_stat_statements/sql/oldextversions.sql @@ -0,0 +1,13 @@ +-- test old extension version entry points + +CREATE EXTENSION edb_stat_statements WITH VERSION '1.0'; + +SELECT pg_get_functiondef('edb_stat_statements_info'::regproc); + +SELECT pg_get_functiondef('edb_stat_statements_reset'::regproc); + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +\d edb_stat_statements 
+SELECT count(*) > 0 AS has_data FROM edb_stat_statements; + +DROP EXTENSION edb_stat_statements; diff --git a/edb_stat_statements/sql/parallel.sql b/edb_stat_statements/sql/parallel.sql new file mode 100644 index 00000000000..f4592e147ba --- /dev/null +++ b/edb_stat_statements/sql/parallel.sql @@ -0,0 +1,26 @@ +-- +-- Tests for parallel statistics +-- + +SET edb_stat_statements.track_utility = FALSE; + +-- encourage use of parallel plans +SET parallel_setup_cost = 0; +SET parallel_tuple_cost = 0; +SET min_parallel_table_scan_size = 0; +SET max_parallel_workers_per_gather = 2; + +CREATE TABLE pgss_parallel_tab (a int); + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +SELECT count(*) FROM pgss_parallel_tab; + +SELECT query, + parallel_workers_to_launch > 0 AS has_workers_to_launch, + parallel_workers_launched > 0 AS has_workers_launched + FROM edb_stat_statements + WHERE query ~ 'SELECT count' + ORDER BY query COLLATE "C"; + +DROP TABLE pgss_parallel_tab; diff --git a/edb_stat_statements/sql/planning.sql b/edb_stat_statements/sql/planning.sql new file mode 100644 index 00000000000..618baadb18d --- /dev/null +++ b/edb_stat_statements/sql/planning.sql @@ -0,0 +1,31 @@ +-- +-- Information related to planning +-- + +-- These tests require track_planning to be enabled. 
+SET edb_stat_statements.track_planning = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- +-- [re]plan counting +-- +CREATE TABLE stats_plan_test (); +PREPARE prep1 AS SELECT COUNT(*) FROM stats_plan_test; +EXECUTE prep1; +EXECUTE prep1; +EXECUTE prep1; +ALTER TABLE stats_plan_test ADD COLUMN x int; +EXECUTE prep1; +SELECT 42; +SELECT 42; +SELECT 42; +SELECT plans, calls, rows, query FROM edb_stat_statements + WHERE query NOT LIKE 'PREPARE%' ORDER BY query COLLATE "C"; +-- for the prepared statement we expect at least one replan, but cache +-- invalidations could force more +SELECT plans >= 2 AND plans <= calls AS plans_ok, calls, rows, query FROM edb_stat_statements + WHERE query LIKE 'PREPARE%' ORDER BY query COLLATE "C"; + +-- Cleanup +DROP TABLE stats_plan_test; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/privileges.sql b/edb_stat_statements/sql/privileges.sql new file mode 100644 index 00000000000..1b5459e3067 --- /dev/null +++ b/edb_stat_statements/sql/privileges.sql @@ -0,0 +1,60 @@ +-- +-- Only superusers and roles with privileges of the pg_read_all_stats role +-- are allowed to see the SQL text and queryid of queries executed by +-- other users. Other users can see the statistics. +-- + +SET edb_stat_statements.track_utility = FALSE; +CREATE ROLE regress_stats_superuser SUPERUSER; +CREATE ROLE regress_stats_user1; +CREATE ROLE regress_stats_user2; +GRANT pg_read_all_stats TO regress_stats_user2; + +SET ROLE regress_stats_superuser; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +SELECT 1 AS "ONE"; + +SET ROLE regress_stats_user1; +SELECT 1+1 AS "TWO"; + +-- +-- A superuser can read all columns of queries executed by others, +-- including query text and queryid. 
+-- + +SET ROLE regress_stats_superuser; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + +-- +-- regress_stats_user1 has no privileges to read the query text or +-- queryid of queries executed by others but can see statistics +-- like calls and rows. +-- + +SET ROLE regress_stats_user1; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + +-- +-- regress_stats_user2, with pg_read_all_stats role privileges, can +-- read all columns, including query text and queryid, of queries +-- executed by others. +-- + +SET ROLE regress_stats_user2; +SELECT r.rolname, ss.queryid <> 0 AS queryid_bool, ss.query, ss.calls, ss.rows + FROM edb_stat_statements ss JOIN pg_roles r ON ss.userid = r.oid + ORDER BY r.rolname, ss.query COLLATE "C", ss.calls, ss.rows; + +-- +-- cleanup +-- + +RESET ROLE; +DROP ROLE regress_stats_superuser; +DROP ROLE regress_stats_user1; +DROP ROLE regress_stats_user2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/select.sql b/edb_stat_statements/sql/select.sql new file mode 100644 index 00000000000..6847a198161 --- /dev/null +++ b/edb_stat_statements/sql/select.sql @@ -0,0 +1,149 @@ +-- +-- SELECT statements +-- + +CREATE EXTENSION edb_stat_statements; +SET edb_stat_statements.track_utility = FALSE; +SET edb_stat_statements.track_planning = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- +-- simple and compound statements +-- +SELECT 1 AS "int"; + +SELECT 'hello' + -- multiline + AS "text"; + +SELECT 'world' AS "text"; + +-- transaction +BEGIN; +SELECT 1 AS "int"; +SELECT 'hello' AS "text"; +COMMIT; + +-- compound transaction +BEGIN \; +SELECT 2.0 AS "float" \; +SELECT 'world' 
AS "text" \; +COMMIT; + +-- compound with empty statements and spurious leading spacing +\;\; SELECT 3 + 3 \;\;\; SELECT ' ' || ' !' \;\; SELECT 1 + 4 \;; + +-- non ;-terminated statements +SELECT 1 + 1 + 1 AS "add" \gset +SELECT :add + 1 + 1 AS "add" \; +SELECT :add + 1 + 1 AS "add" \gset + +-- set operator +SELECT 1 AS i UNION SELECT 2 ORDER BY i; + +-- ? operator +select '{"a":1, "b":2}'::jsonb ? 'b'; + +-- cte +WITH t(f) AS ( + VALUES (1.0), (2.0) +) + SELECT f FROM t ORDER BY f; + +-- prepared statement with parameter +PREPARE pgss_test (int) AS SELECT $1, 'test' LIMIT 1; +EXECUTE pgss_test(1); +DEALLOCATE pgss_test; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- +-- queries with locking clauses +-- +CREATE TABLE pgss_a (id integer PRIMARY KEY); +CREATE TABLE pgss_b (id integer PRIMARY KEY, a_id integer REFERENCES pgss_a); + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- control query +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id; + +-- test range tables +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_a; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_a, pgss_b; -- matches plain "FOR UPDATE" +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE OF pgss_b, pgss_a; + +-- test strengths +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR NO KEY UPDATE; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR SHARE; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR KEY SHARE; + +-- test wait policies +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE NOWAIT; +SELECT * FROM pgss_a JOIN pgss_b ON pgss_b.a_id = pgss_a.id FOR UPDATE SKIP LOCKED; + +SELECT 
calls, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +DROP TABLE pgss_a, pgss_b CASCADE; + +-- +-- access to edb_stat_statements_info view +-- +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +SELECT dealloc FROM edb_stat_statements_info; + +-- FROM [ONLY] +CREATE TABLE tbl_inh(id integer); +CREATE TABLE tbl_inh_1() INHERITS (tbl_inh); +INSERT INTO tbl_inh_1 SELECT 1; + +SELECT * FROM tbl_inh; +SELECT * FROM ONLY tbl_inh; + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%FROM%tbl_inh%'; + +-- WITH TIES +CREATE TABLE limitoption AS SELECT 0 AS val FROM generate_series(1, 10); +SELECT * +FROM limitoption +WHERE val < 2 +ORDER BY val +FETCH FIRST 2 ROWS WITH TIES; + +SELECT * +FROM limitoption +WHERE val < 2 +ORDER BY val +FETCH FIRST 2 ROW ONLY; + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%FETCH FIRST%'; + +-- GROUP BY [DISTINCT] +SELECT a, b, c +FROM (VALUES (1, 2, 3), (4, NULL, 6), (7, 8, 9)) AS t (a, b, c) +GROUP BY ROLLUP(a, b), rollup(a, c) +ORDER BY a, b, c; +SELECT a, b, c +FROM (VALUES (1, 2, 3), (4, NULL, 6), (7, 8, 9)) AS t (a, b, c) +GROUP BY DISTINCT ROLLUP(a, b), rollup(a, c) +ORDER BY a, b, c; + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%GROUP BY%ROLLUP%'; + +-- GROUPING SET agglevelsup +SELECT ( + SELECT ( + SELECT GROUPING(a,b) FROM (VALUES (1)) v2(c) + ) FROM (VALUES (1,2)) v1(a,b) GROUP BY (a,b) +) FROM (VALUES(6,7)) v3(e,f) GROUP BY ROLLUP(e,f); +SELECT ( + SELECT ( + SELECT GROUPING(e,f) FROM (VALUES (1)) v2(c) + ) FROM (VALUES (1,2)) v1(a,b) GROUP BY (a,b) +) FROM (VALUES(6,7)) v3(e,f) GROUP BY ROLLUP(e,f); + +SELECT COUNT(*) FROM edb_stat_statements WHERE query LIKE '%SELECT GROUPING%'; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/user_activity.sql b/edb_stat_statements/sql/user_activity.sql new file mode 100644 index 00000000000..47c0e0639fa --- /dev/null +++ b/edb_stat_statements/sql/user_activity.sql @@ -0,0 +1,67 @@ +-- +-- 
Track user activity and reset them +-- + +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +CREATE ROLE regress_stats_user1; +CREATE ROLE regress_stats_user2; + +SET ROLE regress_stats_user1; + +SELECT 1 AS "ONE"; +SELECT 1+1 AS "TWO"; + +RESET ROLE; +SET ROLE regress_stats_user2; + +SELECT 1 AS "ONE"; +SELECT 1+1 AS "TWO"; + +RESET ROLE; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- Don't reset anything if any of the parameter is NULL +-- +SELECT edb_stat_statements_reset(NULL) IS NOT NULL AS t; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- remove query ('SELECT $1+$2 AS "TWO"') executed by regress_stats_user2 +-- in the current_database +-- +SELECT edb_stat_statements_reset( + (SELECT r.oid FROM pg_roles AS r WHERE r.rolname = 'regress_stats_user2'), + ARRAY(SELECT d.oid FROM pg_database As d where datname = current_database()), + (SELECT s.queryid FROM edb_stat_statements AS s + WHERE s.query = 'SELECT $1+$2 AS "TWO"' LIMIT 1)) + IS NOT NULL AS t; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- remove query ('SELECT $1 AS "ONE"') executed by two users +-- +SELECT edb_stat_statements_reset(0,'{}',s.queryid) IS NOT NULL AS t + FROM edb_stat_statements AS s WHERE s.query = 'SELECT $1 AS "ONE"'; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- remove query of a user (regress_stats_user1) +-- +SELECT edb_stat_statements_reset(r.oid) IS NOT NULL AS t + FROM pg_roles AS r WHERE r.rolname = 'regress_stats_user1'; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- reset all +-- +SELECT edb_stat_statements_reset(0,'{}',0) IS NOT NULL AS t; +SELECT query, calls, rows FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- +-- cleanup +-- +DROP ROLE regress_stats_user1; +DROP ROLE regress_stats_user2; +SELECT 
edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/utility.sql b/edb_stat_statements/sql/utility.sql new file mode 100644 index 00000000000..1d1c3961491 --- /dev/null +++ b/edb_stat_statements/sql/utility.sql @@ -0,0 +1,374 @@ +-- +-- Utility commands +-- + +-- These tests require track_utility to be enabled. +SET edb_stat_statements.track_utility = TRUE; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Tables, indexes, triggers +CREATE TEMP TABLE tab_stats (a int, b char(20)); +CREATE INDEX index_stats ON tab_stats(b, (b || 'data1'), (b || 'data2')) WHERE a > 0; +ALTER TABLE tab_stats ALTER COLUMN b set default 'a'; +ALTER TABLE tab_stats ALTER COLUMN b TYPE text USING 'data' || b; +ALTER TABLE tab_stats ADD CONSTRAINT a_nonzero CHECK (a <> 0); +DROP TABLE tab_stats \; +DROP TABLE IF EXISTS tab_stats \; +-- This DROP query uses two different strings, still they count as one entry. +DROP TABLE IF EXISTS tab_stats \; +Drop Table If Exists tab_stats \; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Partitions +CREATE TABLE pt_stats (a int, b int) PARTITION BY range (a); +CREATE TABLE pt_stats1 (a int, b int); +ALTER TABLE pt_stats ATTACH PARTITION pt_stats1 FOR VALUES FROM (0) TO (100); +CREATE TABLE pt_stats2 PARTITION OF pt_stats FOR VALUES FROM (100) TO (200); +CREATE INDEX pt_stats_index ON ONLY pt_stats (a); +CREATE INDEX pt_stats2_index ON ONLY pt_stats2 (a); +ALTER INDEX pt_stats_index ATTACH PARTITION pt_stats2_index; +DROP TABLE pt_stats; + +-- Views +CREATE VIEW view_stats AS SELECT 1::int AS a, 2::int AS b; +ALTER VIEW view_stats ALTER COLUMN a SET DEFAULT 2; +DROP VIEW view_stats; + +-- Foreign tables +CREATE FOREIGN DATA WRAPPER wrapper_stats; +CREATE SERVER server_stats FOREIGN DATA WRAPPER wrapper_stats; +CREATE FOREIGN TABLE foreign_stats (a int) SERVER server_stats; +ALTER FOREIGN TABLE foreign_stats ADD COLUMN b 
integer DEFAULT 1; +ALTER FOREIGN TABLE foreign_stats ADD CONSTRAINT b_nonzero CHECK (b <> 0); +DROP FOREIGN TABLE foreign_stats; +DROP SERVER server_stats; +DROP FOREIGN DATA WRAPPER wrapper_stats; + +-- Functions +CREATE FUNCTION func_stats(a text DEFAULT 'a_data', b text DEFAULT lower('b_data')) + RETURNS text AS $$ SELECT $1::text || '_' || $2::text; $$ LANGUAGE SQL + SET work_mem = '256kB'; +DROP FUNCTION func_stats; + +-- Rules +CREATE TABLE tab_rule_stats (a int, b int); +CREATE TABLE tab_rule_stats_2 (a int, b int, c int, d int); +CREATE RULE rules_stats AS ON INSERT TO tab_rule_stats DO INSTEAD + INSERT INTO tab_rule_stats_2 VALUES(new.*, 1, 2); +DROP RULE rules_stats ON tab_rule_stats; +DROP TABLE tab_rule_stats, tab_rule_stats_2; + +-- Types +CREATE TYPE stats_type as (f1 numeric(35, 6), f2 numeric(35, 2)); +DROP TYPE stats_type; + +-- Triggers +CREATE TABLE trigger_tab_stats (a int, b int); +CREATE FUNCTION trigger_func_stats () RETURNS trigger LANGUAGE plpgsql + AS $$ BEGIN return OLD; end; $$; +CREATE TRIGGER trigger_tab_stats + AFTER UPDATE ON trigger_tab_stats + FOR EACH ROW WHEN (OLD.a < 0 AND OLD.b < 1 AND true) + EXECUTE FUNCTION trigger_func_stats(); +DROP TABLE trigger_tab_stats; + +-- Policies +CREATE TABLE tab_policy_stats (a int, b int); +CREATE POLICY policy_stats ON tab_policy_stats USING (a = 5) WITH CHECK (b < 5); +DROP TABLE tab_policy_stats; + +-- Statistics +CREATE TABLE tab_expr_stats (a int, b int); +CREATE STATISTICS tab_expr_stats_1 (mcv) ON a, (2*a), (3*b) FROM tab_expr_stats; +DROP TABLE tab_expr_stats; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Transaction statements +BEGIN; +ABORT; +BEGIN; +ROLLBACK; +-- WORK +BEGIN WORK; +COMMIT WORK; +BEGIN WORK; +ABORT WORK; +-- TRANSACTION +BEGIN TRANSACTION; +COMMIT TRANSACTION; +BEGIN TRANSACTION; +ABORT TRANSACTION; +-- More isolation levels +BEGIN TRANSACTION DEFERRABLE; +COMMIT 
TRANSACTION AND NO CHAIN; +BEGIN ISOLATION LEVEL SERIALIZABLE; +COMMIT; +BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- List of A_Const nodes, same lists. +BEGIN TRANSACTION READ ONLY, READ WRITE, DEFERRABLE, NOT DEFERRABLE; +COMMIT; +BEGIN TRANSACTION NOT DEFERRABLE, READ ONLY, READ WRITE, DEFERRABLE; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Two-phase transactions +BEGIN; +PREPARE TRANSACTION 'stat_trans1'; +COMMIT PREPARED 'stat_trans1'; +BEGIN; +PREPARE TRANSACTION 'stat_trans2'; +ROLLBACK PREPARED 'stat_trans2'; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Savepoints +BEGIN; +SAVEPOINT sp1; +SAVEPOINT sp2; +SAVEPOINT sp3; +SAVEPOINT sp4; +ROLLBACK TO sp4; +ROLLBACK TO SAVEPOINT sp4; +ROLLBACK TRANSACTION TO SAVEPOINT sp3; +RELEASE sp3; +RELEASE SAVEPOINT sp2; +ROLLBACK TO sp1; +RELEASE SAVEPOINT sp1; +COMMIT; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- EXPLAIN statements +-- A Query is used, normalized by the query jumbling. +EXPLAIN (costs off) SELECT 1; +EXPLAIN (costs off) SELECT 2; +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 3; +EXPLAIN (costs off) SELECT a FROM generate_series(1,10) AS tab(a) WHERE a = 7; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- CALL +CREATE OR REPLACE PROCEDURE sum_one(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE sum_two(i int, j int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + j)::int INTO r; +END; $$ LANGUAGE plpgsql; +-- Overloaded functions. 
+CREATE OR REPLACE PROCEDURE overload(i int) AS $$ +DECLARE + r int; +BEGIN + SELECT (i + i)::int INTO r; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE PROCEDURE overload(i text) AS $$ +DECLARE + r text; +BEGIN + SELECT i::text INTO r; +END; $$ LANGUAGE plpgsql; +-- Mix of IN/OUT parameters. +CREATE OR REPLACE PROCEDURE in_out(i int, i2 OUT int, i3 INOUT int) AS $$ +DECLARE + r int; +BEGIN + i2 := i; + i3 := i3 + i; +END; $$ LANGUAGE plpgsql; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +CALL sum_one(3); +CALL sum_one(199); +CALL sum_two(1,1); +CALL sum_two(1,2); +CALL overload(1); +CALL overload('A'); +CALL in_out(1, NULL, 1); +CALL in_out(2, 1, 2); +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +-- COPY +CREATE TABLE copy_stats (a int, b int); +SELECT edb_stat_statements_reset() IS NOT NULL AS t; +-- Some queries with A_Const nodes. +COPY (SELECT 1) TO STDOUT; +COPY (SELECT 2) TO STDOUT; +COPY (INSERT INTO copy_stats VALUES (1, 1) RETURNING *) TO STDOUT; +COPY (INSERT INTO copy_stats VALUES (2, 2) RETURNING *) TO STDOUT; +COPY (UPDATE copy_stats SET b = b + 1 RETURNING *) TO STDOUT; +COPY (UPDATE copy_stats SET b = b + 2 RETURNING *) TO STDOUT; +COPY (DELETE FROM copy_stats WHERE a = 1 RETURNING *) TO STDOUT; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +DROP TABLE copy_stats; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- CREATE TABLE AS +-- SELECT queries are normalized, creating matching query IDs. 
+CREATE TABLE ctas_stats_1 AS SELECT 1 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_1 AS SELECT 2 AS a; +DROP TABLE ctas_stats_1; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP TABLE ctas_stats_2; +CREATE TABLE ctas_stats_2 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 1; +DROP TABLE ctas_stats_2; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- CREATE MATERIALIZED VIEW +-- SELECT queries are normalized, creating matching query IDs. +CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP MATERIALIZED VIEW matview_stats_1; +CREATE MATERIALIZED VIEW matview_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP MATERIALIZED VIEW matview_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- CREATE VIEW +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 2::int AS col2 + FROM generate_series(1, 10) AS tab(a) WHERE a < 5 AND a > 2; +DROP VIEW view_stats_1; +CREATE VIEW view_stats_1 AS + SELECT a AS col1, 4::int AS col2 + FROM generate_series(1, 5) AS tab(a) WHERE a < 4 AND a > 3; +DROP VIEW view_stats_1; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Domains +CREATE DOMAIN domain_stats AS int CHECK (VALUE > 0); +ALTER DOMAIN domain_stats SET DEFAULT '3'; +ALTER DOMAIN domain_stats ADD CONSTRAINT higher_than_one CHECK (VALUE > 1); +DROP DOMAIN domain_stats; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Execution 
statements +SELECT 1 as a; +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (1); +DEALLOCATE stat_select; +PREPARE stat_select AS SELECT $1 AS a; +EXECUTE stat_select (2); +DEALLOCATE PREPARE stat_select; +DEALLOCATE ALL; +DEALLOCATE PREPARE ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- SET statements. +-- These use two different strings, still they count as one entry. +CREATE ROLE regress_stat_set_1; +CREATE ROLE regress_stat_set_2; +SET work_mem = '1MB'; +Set work_mem = '1MB'; +SET work_mem = '2MB'; +SET work_mem = DEFAULT; +SET work_mem TO DEFAULT; +SET work_mem FROM CURRENT; +BEGIN; +SET LOCAL work_mem = '128kB'; +SET LOCAL work_mem = '256kB'; +SET LOCAL work_mem = DEFAULT; +SET LOCAL work_mem TO DEFAULT; +SET LOCAL work_mem FROM CURRENT; +COMMIT; +RESET work_mem; +SET enable_seqscan = off; +SET enable_seqscan = on; +SET SESSION work_mem = '300kB'; +SET SESSION work_mem = '400kB'; +RESET enable_seqscan; +-- SET TRANSACTION ISOLATION +BEGIN; +SET TRANSACTION ISOLATION LEVEL READ COMMITTED; +SET TRANSACTION ISOLATION LEVEL REPEATABLE READ; +SET TRANSACTION ISOLATION LEVEL SERIALIZABLE; +COMMIT; +-- SET SESSION AUTHORIZATION +SET SESSION SESSION AUTHORIZATION DEFAULT; +SET SESSION AUTHORIZATION 'regress_stat_set_1'; +SET SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +BEGIN; +SET LOCAL SESSION AUTHORIZATION DEFAULT; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_1'; +SET LOCAL SESSION AUTHORIZATION 'regress_stat_set_2'; +RESET SESSION AUTHORIZATION; +COMMIT; +-- SET SESSION CHARACTERISTICS +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ ONLY; +SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY, READ WRITE; +-- SET XML OPTION +SET XML OPTION DOCUMENT; +SET XML OPTION CONTENT; +-- SET TIME ZONE +SET TIME ZONE 'America/New_York'; +SET TIME ZONE 'Asia/Tokyo'; 
+SET TIME ZONE DEFAULT; +SET TIME ZONE LOCAL; +SET TIME ZONE 'CST7CDT,M4.1.0,M10.5.0'; +RESET TIME ZONE; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; +DROP ROLE regress_stat_set_1; +DROP ROLE regress_stat_set_2; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- +-- Track the total number of rows retrieved or affected by the utility +-- commands of COPY, FETCH, CREATE TABLE AS, CREATE MATERIALIZED VIEW, +-- REFRESH MATERIALIZED VIEW and SELECT INTO +-- +CREATE TABLE pgss_ctas AS SELECT a, 'ctas' b FROM generate_series(1, 10) a; +SELECT generate_series(1, 10) c INTO pgss_select_into; +COPY pgss_ctas (a, b) FROM STDIN; +11 copy +12 copy +13 copy +\. +CREATE MATERIALIZED VIEW pgss_matv AS SELECT * FROM pgss_ctas; +REFRESH MATERIALIZED VIEW pgss_matv; +BEGIN; +DECLARE pgss_cursor CURSOR FOR SELECT * FROM pgss_matv; +FETCH NEXT pgss_cursor; +FETCH FORWARD 5 pgss_cursor; +FETCH FORWARD ALL pgss_cursor; +COMMIT; + +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +DROP MATERIALIZED VIEW pgss_matv; +DROP TABLE pgss_ctas; +DROP TABLE pgss_select_into; + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; + +-- Special cases. Keep these ones at the end to avoid conflicts. 
+SET SCHEMA 'foo'; +SET SCHEMA 'public'; +RESET ALL; +SELECT calls, rows, query FROM edb_stat_statements ORDER BY query COLLATE "C"; + +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/sql/wal.sql b/edb_stat_statements/sql/wal.sql new file mode 100644 index 00000000000..2555460514a --- /dev/null +++ b/edb_stat_statements/sql/wal.sql @@ -0,0 +1,20 @@ +-- +-- Validate WAL generation metrics +-- + +SET edb_stat_statements.track_utility = FALSE; + +CREATE TABLE pgss_wal_tab (a int, b char(20)); + +INSERT INTO pgss_wal_tab VALUES(generate_series(1, 10), 'aaa'); +UPDATE pgss_wal_tab SET b = 'bbb' WHERE a > 7; +DELETE FROM pgss_wal_tab WHERE a > 9; +DROP TABLE pgss_wal_tab; + +-- Check WAL is generated for the above statements +SELECT query, calls, rows, +wal_bytes > 0 as wal_bytes_generated, +wal_records > 0 as wal_records_generated, +wal_records >= rows as wal_records_ge_rows +FROM edb_stat_statements ORDER BY query COLLATE "C"; +SELECT edb_stat_statements_reset() IS NOT NULL AS t; diff --git a/edb_stat_statements/t/010_restart.pl b/edb_stat_statements/t/010_restart.pl new file mode 100644 index 00000000000..3a5dab06a6a --- /dev/null +++ b/edb_stat_statements/t/010_restart.pl @@ -0,0 +1,55 @@ +# Copyright (c) 2023-2024, PostgreSQL Global Development Group + +# Tests for checking that edb_stat_statements contents are preserved +# across restarts. 
+ +use strict; +use warnings FATAL => 'all'; +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; + +my $node = PostgreSQL::Test::Cluster->new('main'); +$node->init; +$node->append_conf('postgresql.conf', + "shared_preload_libraries = 'edb_stat_statements'"); +$node->append_conf('postgresql.conf', + "edb_stat_statements.track_unrecognized = true"); +$node->start; + +$node->safe_psql('postgres', 'CREATE EXTENSION edb_stat_statements'); + +$node->safe_psql('postgres', 'CREATE TABLE t1 (a int)'); +$node->safe_psql('postgres', 'SELECT a FROM t1'); + +is( $node->safe_psql( + 'postgres', + "SELECT query FROM edb_stat_statements WHERE query NOT LIKE '%edb_stat_statements%' ORDER BY query" + ), + "CREATE TABLE t1 (a int)\nSELECT a FROM t1", + 'edb_stat_statements populated'); + +$node->restart; + +is( $node->safe_psql( + 'postgres', + "SELECT query FROM edb_stat_statements WHERE query NOT LIKE '%edb_stat_statements%' ORDER BY query" + ), + "CREATE TABLE t1 (a int)\nSELECT a FROM t1", + 'edb_stat_statements data kept across restart'); + +$node->append_conf('postgresql.conf', "edb_stat_statements.save = false"); +$node->reload; + +$node->restart; + +is( $node->safe_psql( + 'postgres', + "SELECT count(*) FROM edb_stat_statements WHERE query NOT LIKE '%edb_stat_statements%'" + ), + '0', + 'edb_stat_statements data not kept across restart with .save=false'); + +$node->stop; + +done_testing(); diff --git a/pyproject.toml b/pyproject.toml index 7903fbb4ba5..875884969f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,12 +4,12 @@ description = "Gel Server" requires-python = '>=3.12.0' dynamic = ["version"] dependencies = [ - 'edgedb~=2.1.0', + 'edgedb==3.0.0b2', 'httptools>=0.6.0', 'immutables>=0.18', 'parsing~=2.0', - 'uvloop~=0.19.0', + 'uvloop~=0.21.0', 'click~=8.0', 'cryptography~=42.0', @@ -21,7 +21,7 @@ dependencies = [ 'hishel==0.0.24', 'webauthn~=2.0.0', 'argon2-cffi~=23.1.0', - 'aiosmtplib~=2.0', + 'aiosmtplib~=3.0', 'tiktoken~=0.7.0', 
'mistral_common~=1.2.1', ] @@ -44,7 +44,7 @@ test = [ 'black~=24.2.0', 'coverage~=7.4', 'ruff==0.3.7', - 'asyncpg~=0.29.0', + 'asyncpg~=0.30.0', # Needed for testing asyncutil 'async_solipsism==0.5.0', @@ -57,7 +57,7 @@ test = [ 'MarkupSafe~=1.1', 'PyYAML~=6.0', - 'mypy~=1.10.0', + 'mypy[faster-cache] ~= 1.13.0', # mypy stub packages; when updating, you can use mypy --install-types # to install stub packages and then pip freeze to read out the specifier 'types-docutils~=0.17.0,<0.17.6', # incomplete nodes.document.__init__ @@ -96,14 +96,14 @@ language-server = ['pygls~=1.3.1'] [build-system] requires = [ - "Cython (>=0.29.32, <0.30.0)", + "Cython(>=3.0.11,<3.1.0)", "packaging >= 21.0", "setuptools >= 67", "setuptools-rust ~= 1.8", "wheel", "parsing ~= 2.0", - 'edgedb~=2.1.0', + "edgedb==3.0.0b2", ] # Custom backend needed to set up build-time sys.path because # setup.py needs to import `edb.buildmeta`. @@ -268,3 +268,4 @@ lint.flake8-bugbear.extend-immutable-calls = [ [tool.pyright] # Pyright has no idea about metaclass-generated getters for schema fields. 
reportAttributeAccessIssue = false +typeCheckingMode = "off" diff --git a/setup.py b/setup.py index 22f78142163..a3353112526 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ EXT_LIB_DIRS = [ (ROOT_PATH / 'edb' / 'pgsql' / 'parser' / 'libpg_query').as_posix() ] +EDBSS_DIR = ROOT_PATH / 'edb_stat_statements' if platform.uname().system != 'Windows': @@ -195,7 +196,7 @@ def _get_env_with_openssl_flags(): return env -def _compile_postgres(build_base, *, +def _compile_postgres(build_base, build_temp, *, force_build=False, fresh_build=True, run_configure=True, build_contrib=True, produce_compile_commands_json=False): @@ -246,12 +247,21 @@ def _compile_postgres(build_base, *, if run_configure or fresh_build or is_outdated: env = _get_env_with_openssl_flags() - subprocess.run([ + cmd = [ str(postgres_src / 'configure'), '--prefix=' + str(postgres_build / 'install'), '--with-openssl', '--with-uuid=' + uuidlib, - ], check=True, cwd=str(build_dir), env=env) + ] + if os.environ.get('EDGEDB_DEBUG'): + cmd += [ + '--enable-tap-tests', + '--enable-debug', + ] + cflags = os.environ.get("CFLAGS", "") + cflags = f"{cflags} -O0" + env['CFLAGS'] = cflags + subprocess.run(cmd, check=True, cwd=str(build_dir), env=env) if produce_compile_commands_json: make = ['bear', '--', 'make'] @@ -279,6 +289,12 @@ def _compile_postgres(build_base, *, ['make', '-C', 'contrib', 'MAKELEVEL=0', 'install'], cwd=str(build_dir), check=True) + pg_config = ( + build_base / 'postgres' / 'install' / 'bin' / 'pg_config' + ).resolve() + _compile_pgvector(pg_config, build_temp) + _compile_edb_stat_statements(pg_config, build_temp) + with open(postgres_build_stamp, 'w') as f: f.write(source_stamp) @@ -289,7 +305,7 @@ def _compile_postgres(build_base, *, ) -def _compile_pgvector(build_base, build_temp): +def _compile_pgvector(pg_config, build_temp): git_rev = _get_git_rev(PGVECTOR_REPO, PGVECTOR_COMMIT) pgv_root = (build_temp / 'pgvector').resolve() @@ -317,10 +333,6 @@ def _compile_pgvector(build_base, 
build_temp): cwd=pgv_root, ) - pg_config = ( - build_base / 'postgres' / 'install' / 'bin' / 'pg_config' - ).resolve() - cflags = os.environ.get("CFLAGS", "") cflags = f"{cflags} {' '.join(SAFE_EXT_CFLAGS)} -std=gnu99" @@ -344,6 +356,27 @@ def _compile_pgvector(build_base, build_temp): ) +def _compile_edb_stat_statements(pg_config, build_temp): + subprocess.run( + [ + 'make', + f'PG_CONFIG={pg_config}', + ], + cwd=EDBSS_DIR, + check=True, + ) + + subprocess.run( + [ + 'make', + 'install', + f'PG_CONFIG={pg_config}', + ], + cwd=EDBSS_DIR, + check=True, + ) + + def _compile_libpg_query(): dir = (ROOT_PATH / 'edb' / 'pgsql' / 'parser' / 'libpg_query').resolve() @@ -387,13 +420,27 @@ def _get_git_rev(repo, ref): def _get_pg_source_stamp(): + from edb.buildmeta import hash_dirs + output = subprocess.check_output( ['git', 'submodule', 'status', '--cached', 'postgres'], universal_newlines=True, cwd=ROOT_PATH, ) revision, _, _ = output[1:].partition(' ') - source_stamp = revision + '+' + PGVECTOR_COMMIT + edbss_dir = EDBSS_DIR.as_posix() + edbss_hash = hash_dirs( + [(edbss_dir, '.c'), (edbss_dir, '.sql')], + extra_files=[ + EDBSS_DIR / 'Makefile', + EDBSS_DIR / 'edb_stat_statements.control', + ], + ) + edbss = binascii.hexlify(edbss_hash).decode() + stamp_list = [revision, PGVECTOR_COMMIT, edbss] + if os.environ.get('EDGEDB_DEBUG'): + stamp_list += ['debug'] + source_stamp = '+'.join(stamp_list) return source_stamp.strip() @@ -413,21 +460,30 @@ def _compile_cli(build_base, build_temp): env = dict(os.environ) env['CARGO_TARGET_DIR'] = str(build_temp / 'rust' / 'cli') env['PSQL_DEFAULT_PATH'] = build_base / 'postgres' / 'install' / 'bin' - git_ref = env.get("EDGEDBCLI_GIT_REV") or EDGEDBCLI_COMMIT - git_rev = _get_git_rev(EDGEDBCLI_REPO, git_ref) - - subprocess.run( - [ - 'cargo', 'install', - '--verbose', '--verbose', + path = env.get("EDGEDBCLI_PATH") + args = [ + 'cargo', 'install', + '--verbose', '--verbose', + '--bin', 'edgedb', + '--root', rust_root, + 
'--features=dev_mode', + '--locked', + '--debug', + ] + if path: + args.extend([ + '--path', path, + ]) + else: + git_ref = env.get("EDGEDBCLI_GIT_REV") or EDGEDBCLI_COMMIT + git_rev = _get_git_rev(EDGEDBCLI_REPO, git_ref) + args.extend([ '--git', EDGEDBCLI_REPO, '--rev', git_rev, - '--bin', 'edgedb', - '--root', rust_root, - '--features=dev_mode', - '--locked', - '--debug', - ], + ]) + + subprocess.run( + args, env=env, check=True, ) @@ -647,16 +703,13 @@ def run(self, *args, **kwargs): build = self.get_finalized_command('build') _compile_postgres( pathlib.Path(build.build_base).resolve(), + pathlib.Path(build.build_temp).resolve(), force_build=True, fresh_build=self.fresh_build, run_configure=self.configure, build_contrib=self.build_contrib, produce_compile_commands_json=self.compile_commands, ) - _compile_pgvector( - pathlib.Path(build.build_base).resolve(), - pathlib.Path(build.build_temp).resolve(), - ) class build_libpg_query(setuptools.Command): @@ -682,8 +735,8 @@ class build_ext(setuptools_build_ext.build_ext): user_options = setuptools_build_ext.build_ext.user_options + [ ('cython-annotate', None, 'Produce a colorized HTML version of the Cython source.'), - ('cython-directives=', None, - 'Cython compiler directives'), + ('cython-extra-directives=', None, + 'Extra Cython compiler directives'), ] def initialize_options(self): @@ -698,17 +751,17 @@ def initialize_options(self): if os.environ.get('EDGEDB_DEBUG'): self.cython_always = True self.cython_annotate = True - self.cython_directives = "linetrace=True" + self.cython_extra_directives = "linetrace=True" self.define = 'PG_DEBUG,CYTHON_TRACE,CYTHON_TRACE_NOGIL' self.debug = True else: self.cython_always = False self.cython_annotate = None - self.cython_directives = None + self.cython_extra_directives = None self.debug = False self.build_mode = os.environ.get('BUILD_EXT_MODE', 'both') - def finalize_options(self): + def finalize_options(self) -> None: # finalize_options() may be called multiple times on the 
# same command object, so make sure not to override previously # set options. @@ -722,12 +775,12 @@ def finalize_options(self): super(build_ext, self).finalize_options() return - directives = { + directives: dict[str, str | bool] = { 'language_level': '3' } - if self.cython_directives: - for directive in self.cython_directives.split(','): + if self.cython_extra_directives: + for directive in self.cython_extra_directives.split(','): k, _, v = directive.partition('=') if v.lower() == 'false': v = False diff --git a/tests/schemas/advtypes.esdl b/tests/schemas/advtypes.esdl index f8e367f3f93..1feeb3155df 100644 --- a/tests/schemas/advtypes.esdl +++ b/tests/schemas/advtypes.esdl @@ -98,3 +98,146 @@ type XBb { type XBc { required property bc -> float64; } + +# Objects which all have a `numbers` property and `siblings` link + +# non-computed single + +type SoloNonCompSinglePropA { + single property numbers -> int64; +} +type SoloNonCompSinglePropB { + single property numbers -> int64; +} +type SoloNonCompSingleLinkA { + single link siblings -> SoloNonCompSingleLinkA; +} +type SoloNonCompSingleLinkB { + single link siblings -> SoloNonCompSingleLinkB; +} + +# non-computed multi + +type SoloNonCompMultiPropA { + multi property numbers -> int64; +} +type SoloNonCompMultiPropB { + multi property numbers -> int64; +} +type SoloNonCompMultiLinkA { + multi link siblings -> SoloNonCompMultiLinkA; +} +type SoloNonCompMultiLinkB { + multi link siblings -> SoloNonCompMultiLinkB; +} + +# computed single + +type SoloCompSinglePropA { + single property numbers := 1; +} +type SoloCompSinglePropB { + single property numbers := 1; +} +type SoloCompSingleLinkA { + single link siblings := (select detached SoloCompSingleLinkA limit 1); +} +type SoloCompSingleLinkB { + single link siblings := (select detached SoloCompSingleLinkB limit 1); +} + +# computed multi + +type SoloCompMultiPropA { + multi property numbers := {1, 2, 3}; +} +type SoloCompMultiPropB { + multi property numbers := {1, 2, 
3}; +} +type SoloCompMultiLinkA { + multi link siblings := (select detached SoloCompMultiLinkA); +} +type SoloCompMultiLinkB { + multi link siblings := (select detached SoloCompMultiLinkB); +} + +# non-computed single from base class + +abstract type BaseNonCompSingleProp { + single property numbers -> int64; +} +type DerivedNonCompSinglePropA extending BaseNonCompSingleProp; +type DerivedNonCompSinglePropB extending BaseNonCompSingleProp; + +abstract type BaseNonCompSingleLink { + single link siblings -> BaseNonCompSingleLink; +} +type DerivedNonCompSingleLinkA extending BaseNonCompSingleLink; +type DerivedNonCompSingleLinkB extending BaseNonCompSingleLink; + +# non-computed multi from base class + +abstract type BaseNonCompMultiProp { + multi property numbers -> int64; +} +type DerivedNonCompMultiPropA extending BaseNonCompMultiProp; +type DerivedNonCompMultiPropB extending BaseNonCompMultiProp; + +abstract type BaseNonCompMultiLink { + multi link siblings -> BaseNonCompMultiLink; +} +type DerivedNonCompMultiLinkA extending BaseNonCompMultiLink; +type DerivedNonCompMultiLinkB extending BaseNonCompMultiLink; + +# computed single from base class + +abstract type BaseCompSingleProp { + single property numbers := 1; +} +type DerivedCompSinglePropA extending BaseCompSingleProp; +type DerivedCompSinglePropB extending BaseCompSingleProp; + +abstract type BaseCompSingleLink { + single link siblings := (select detached BaseCompSingleLink limit 1); +} +type DerivedCompSingleLinkA extending BaseCompSingleLink; +type DerivedCompSingleLinkB extending BaseCompSingleLink; + +# computed multi from base class + +abstract type BaseCompMultiProp { + multi property numbers := {1, 2, 3}; +} +type DerivedCompMultiPropA extending BaseCompMultiProp; +type DerivedCompMultiPropB extending BaseCompMultiProp; + +abstract type BaseCompMultiLink { + multi link siblings := (select detached BaseCompMultiLink); +} +type DerivedCompMultiLinkA extending BaseCompMultiLink; +type 
DerivedCompMultiLinkB extending BaseCompMultiLink; + +# Objects with links to a target type + +type Destination { + required property name -> str; +} + +# independent types with compatible pointers + +type SoloOriginA { + single link dest -> Destination; +} +type SoloOriginB { + single link dest -> Destination; +} + +# independent types with compatible pointers and common derived type + +type BaseOriginA { + single link dest -> Destination; +} +type BaseOriginB { + single link dest -> Destination; +} +type DerivedOriginC extending BaseOriginA, BaseOriginB; diff --git a/tests/schemas/cards.esdl b/tests/schemas/cards.esdl index 8fcfbba4b22..da140791e9e 100644 --- a/tests/schemas/cards.esdl +++ b/tests/schemas/cards.esdl @@ -159,3 +159,14 @@ alias UserAlias := ( alias SpecialCardAlias := SpecialCard { el_cost := (.element, .cost) }; + +alias AliasOne := 1; +global GlobalOne := 1; + +global HighestCost := ( + SELECT max(Card.cost) +); + +global CardsWithText := ( + SELECT Card FILTER exists(.text) +); diff --git a/tests/schemas/dump02_setup.edgeql b/tests/schemas/dump02_setup.edgeql index f7700acfc3e..cce4641f57d 100644 --- a/tests/schemas/dump02_setup.edgeql +++ b/tests/schemas/dump02_setup.edgeql @@ -19,8 +19,8 @@ SET MODULE default; -CREATE MIGRATION m12hdldnmvzj5weaevxsmizppnl2poo6nconx2hcfkklbwcghqsmaq -ONTO m1vyvlra26tef6oe6yu37m7lfw7i3ef3n62m6om353dvnbm3mynqqa { +CREATE MIGRATION m1t2phsw6j2rgl4ieihm6mnvoln3ssayxncjzl2kwkxmunn2f6aqha +ONTO m1iej6dr3hk33wykqwqgg4xxo3tivpiznpb2mto7qsw2zgipsbfihq { CREATE TYPE default::Migrated; create type default::Migrated2 {}; }; diff --git a/tests/schemas/dump_v4_setup.edgeql b/tests/schemas/dump_v4_setup.edgeql index 496d71dd0e3..7334687a4e1 100644 --- a/tests/schemas/dump_v4_setup.edgeql +++ b/tests/schemas/dump_v4_setup.edgeql @@ -55,8 +55,22 @@ ext::auth::AuthConfig::auth_signing_key := 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'; CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::token_time_to_live := '24 hours'; +# N.B: This 
CONFIGURE command was the original one, but then we +# removed that flag. We kept it working in dumps, though, so old +# dumps still work and behave as if they had the next two statements +# instead. +# +# CONFIGURE CURRENT DATABASE SET +# ext::auth::SMTPConfig::sender := 'noreply@example.com'; + +CONFIGURE CURRENT DATABASE INSERT cfg::SMTPProviderConfig { + name := "_default", + sender := 'noreply@example.com', +}; + CONFIGURE CURRENT DATABASE SET -ext::auth::SMTPConfig::sender := 'noreply@example.com'; +cfg::current_email_provider_name := "_default"; + CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::allowed_redirect_urls := { diff --git a/tests/test_dump_v4.py b/tests/test_dump_v4.py index 638925596b7..6722e37fcf8 100644 --- a/tests/test_dump_v4.py +++ b/tests/test_dump_v4.py @@ -131,6 +131,39 @@ async def _ensure_schema_data_integrity(self, include_secrets): }] ) + # We didn't specify include_secrets in the dumps we made for + # 4.0, but the way that smtp config was done then, it got + # dumped anyway. (The secret wasn't specified.) + has_smtp = ( + include_secrets + or self._testMethodName == 'test_dumpv4_restore_compatibility_4_0' + ) + + # N.B: This is not what it looked like in the original + # dumps. We patched it up during restore starting with 6.0. 
+ if has_smtp: + await self.assert_query_result( + ''' + select cfg::Config { + email_providers[is cfg::SMTPProviderConfig]: { + name, sender + }, + current_email_provider_name, + }; + ''', + [ + { + "email_providers": [ + { + "name": "_default", + "sender": "noreply@example.com", + } + ], + "current_email_provider_name": "_default" + } + ], + ) + class TestDumpV4(tb.StableDumpTestCase, DumpTestCaseMixin): EXTENSIONS = ["pgvector", "_conf", "pgcrypto", "auth"] diff --git a/tests/test_edgeql_advtypes.py b/tests/test_edgeql_advtypes.py index 137429b7282..c43a2835b5e 100644 --- a/tests/test_edgeql_advtypes.py +++ b/tests/test_edgeql_advtypes.py @@ -19,6 +19,8 @@ import os.path +import edgedb + from edb.testbase import server as tb @@ -1650,3 +1652,120 @@ async def test_edgeql_advtypes_delete_complex_type_08(self): {'tn': 'default::CBc', 'ba': None, 'bb': None, 'bc': 1.5}, ], ) + + async def test_edgeql_advtypes_intersection_pointers_01(self): + # Type intersections with incompatible pointers should produce errors. 
+ + type_roots = [ + "SoloNonCompSingle", + "SoloNonCompMulti", + "SoloCompSingle", + "SoloCompMulti", + "DerivedNonCompSingle", + "DerivedNonCompMulti", + "DerivedCompSingle", + "DerivedCompMulti", + ] + + for type_root_a in type_roots: + for type_root_b in type_roots: + for type_suffix, ptr_name in ( + ("Prop", "numbers"), + ("Link", "siblings"), + ): + if ( + # Either type has computed pointer + ( + "NonComp" not in type_root_a + or "NonComp" not in type_root_b + ) + # but the pointer doesn't come from a common base + and not ( + "Derived" in type_root_a + and type_root_a == type_root_b + ) + ): + async with self.assertRaisesRegexTx( + edgedb.SchemaError, + r"it is illegal to create a type intersection " + r"that causes a computed .* to mix " + r"with other versions of the same .*" + ): + await self.con.execute(f""" + select {type_root_a}{type_suffix}A {{ + x := ( + [is {type_root_b}{type_suffix}B] + .{ptr_name} + ) + }}; + """) + + elif ( + # differing pointer cardinalities + ("Single" in type_root_a) != ("Single" in type_root_b) + ): + async with self.assertRaisesRegexTx( + edgedb.SchemaError, + r"it is illegal to create a type intersection " + r"that causes a .* to mix " + r"with other versions of .* " + r"which have a different cardinality" + ): + await self.con.execute(f""" + select {type_root_a}{type_suffix}A {{ + x := ( + [is {type_root_b}{type_suffix}B] + .{ptr_name} + ) + }}; + """) + + else: + await self.con.execute(f""" + select {type_root_a}{type_suffix}A {{ + x := ( + [is {type_root_b}{type_suffix}B] + .{ptr_name} + ) + }}; + """) + + async def test_edgeql_advtypes_intersection_pointers_02(self): + # Intersection pointer should return nothing if they types are + # unrelated. 
+ + await self.con.execute(""" + INSERT SoloOriginA { dest := (INSERT Destination{ name := "A" }) }; + INSERT SoloOriginB { dest := (INSERT Destination{ name := "B" }) }; + """) + + await self.assert_query_result( + r""" + SELECT SoloOriginA { + x := [is SoloOriginB].dest.name + } + """, + [{'x': None}], + ) + + async def test_edgeql_advtypes_intersection_pointers_03(self): + # Intersection pointer should return the correct values if the type + # intersection is not empty. + + await self.con.execute(""" + INSERT BaseOriginA { dest := (INSERT Destination{ name := "A" }) }; + INSERT BaseOriginB { dest := (INSERT Destination{ name := "B" }) }; + INSERT DerivedOriginC { + dest := (INSERT Destination{ name := "C" }) + }; + """) + + await self.assert_query_result( + r""" + SELECT BaseOriginA { + x := [is BaseOriginB].dest.name + } + ORDER BY .x + """, + [{'x': None}, {'x': 'C'}], + ) diff --git a/tests/test_edgeql_ddl.py b/tests/test_edgeql_ddl.py index 5b0061fd029..b6ea703f32f 100644 --- a/tests/test_edgeql_ddl.py +++ b/tests/test_edgeql_ddl.py @@ -9412,6 +9412,13 @@ async def test_edgeql_ddl_extension_02(self): algo = ext::auth::JWTAlgo.RS256 ); }; + + create type ext::auth::Config extending std::BaseObject { + create property supported_algos: + array; + create multi property algo_config: + tuple; + }; } """) @@ -10452,6 +10459,43 @@ async def test_edgeql_ddl_alias_13(self): [True] ) + async def test_edgeql_ddl_alias_14(self): + # Issue #8003 + await self.con.execute(r""" + create global One := 1; + create alias MyAlias := global One; + """) + + async with self.assertRaisesRegexTx( + edgedb.SchemaDefinitionError, "index expressions must be immutable" + ): + await self.con.execute( + r""" + create type Foo { create index on (MyAlias) }; + """ + ) + + async def test_edgeql_ddl_alias_15(self): + # Issue #8003 + await self.con.execute( + r""" + create global One := 1; + create alias MyAlias := 1; + create type Foo { create index on (MyAlias) }; + """ + ) + + async with 
self.assertRaisesRegexTx( + edgedb.SchemaDefinitionError, + "cannot alter alias 'default::MyAlias' because this affects " + "expression of index of object type 'default::Foo'" + ): + await self.con.execute( + r""" + alter alias MyAlias {using (global One)}; + """ + ) + async def test_edgeql_ddl_inheritance_alter_01(self): await self.con.execute(r""" CREATE TYPE InhTest01 { @@ -14823,7 +14867,7 @@ async def test_edgeql_ddl_drop_multi_prop_01(self): """) async def test_edgeql_ddl_collection_cleanup_01(self): - count_query = "SELECT count(schema::Array);" + count_query = "SELECT count(schema::Tuple);" orig_count = await self.con.query_single(count_query) await self.con.execute(r""" @@ -14832,9 +14876,9 @@ async def test_edgeql_ddl_collection_cleanup_01(self): CREATE SCALAR TYPE b extending str; CREATE SCALAR TYPE c extending str; - CREATE TYPE TestArrays { - CREATE PROPERTY x -> array; - CREATE PROPERTY y -> array; + CREATE TYPE TestTuples { + CREATE PROPERTY x -> tuple; + CREATE PROPERTY y -> tuple; }; """) @@ -14844,7 +14888,7 @@ async def test_edgeql_ddl_collection_cleanup_01(self): ) await self.con.execute(r""" - ALTER TYPE TestArrays { + ALTER TYPE TestTuples { DROP PROPERTY x; }; """) @@ -14855,10 +14899,10 @@ async def test_edgeql_ddl_collection_cleanup_01(self): ) await self.con.execute(r""" - ALTER TYPE TestArrays { + ALTER TYPE TestTuples { ALTER PROPERTY y { - SET TYPE array USING ( - >>.y); + SET TYPE tuple USING ( + >>.y); } }; """) @@ -14869,13 +14913,13 @@ async def test_edgeql_ddl_collection_cleanup_01(self): ) await self.con.execute(r""" - DROP TYPE TestArrays; + DROP TYPE TestTuples; """) self.assertEqual(await self.con.query_single(count_query), orig_count) async def test_edgeql_ddl_collection_cleanup_01b(self): - count_query = "SELECT count(schema::Array);" + count_query = "SELECT count(schema::Tuple);" orig_count = await self.con.query_single(count_query) await self.con.execute(r""" @@ -14884,10 +14928,10 @@ async def 
test_edgeql_ddl_collection_cleanup_01b(self): CREATE SCALAR TYPE b extending str; CREATE SCALAR TYPE c extending str; - CREATE TYPE TestArrays { - CREATE PROPERTY x -> array; - CREATE PROPERTY y -> array; - CREATE PROPERTY z -> array; + CREATE TYPE TestTuples { + CREATE PROPERTY x -> tuple; + CREATE PROPERTY y -> tuple; + CREATE PROPERTY z -> tuple; }; """) @@ -14897,7 +14941,7 @@ async def test_edgeql_ddl_collection_cleanup_01b(self): ) await self.con.execute(r""" - ALTER TYPE TestArrays { + ALTER TYPE TestTuples { DROP PROPERTY x; }; """) @@ -14908,10 +14952,10 @@ async def test_edgeql_ddl_collection_cleanup_01b(self): ) await self.con.execute(r""" - ALTER TYPE TestArrays { + ALTER TYPE TestTuples { ALTER PROPERTY y { - SET TYPE array USING ( - >>.y); + SET TYPE tuple USING ( + >>.y); } }; """) @@ -14922,7 +14966,7 @@ async def test_edgeql_ddl_collection_cleanup_01b(self): ) await self.con.execute(r""" - DROP TYPE TestArrays; + DROP TYPE TestTuples; """) self.assertEqual(await self.con.query_single(count_query), orig_count) @@ -14944,14 +14988,17 @@ async def test_edgeql_ddl_collection_cleanup_02(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 3 + 2, ) await self.con.execute(r""" DROP TYPE TestArrays; """) - self.assertEqual(await self.con.query_single(count_query), orig_count) + self.assertEqual( + await self.con.query_single(count_query), + orig_count + 3, + ) async def test_edgeql_ddl_collection_cleanup_03(self): count_query = "SELECT count(schema::CollectionType);" @@ -14972,7 +15019,7 @@ async def test_edgeql_ddl_collection_cleanup_03(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 4, + orig_count + 3 + 2, ) await self.con.execute(r""" @@ -14980,9 +15027,14 @@ async def test_edgeql_ddl_collection_cleanup_03(self): x: array, z: tuple, y: array>); """) - self.assertEqual(await self.con.query_single(count_query), orig_count) self.assertEqual( - await 
self.con.query_single(elem_count_query), orig_elem_count) + await self.con.query_single(count_query), + orig_count + 3, + ) + self.assertEqual( + await self.con.query_single(elem_count_query), + orig_elem_count, + ) async def test_edgeql_ddl_collection_cleanup_04(self): count_query = "SELECT count(schema::CollectionType);" @@ -15005,7 +15057,7 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 1, + orig_count + 3 + 1, ) await self.con.execute(r""" @@ -15014,7 +15066,7 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 1, + orig_count + 3 + 1, ) await self.con.execute(r""" @@ -15023,7 +15075,7 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 3 + 2, ) await self.con.execute(r""" @@ -15032,7 +15084,7 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 3 + 2, ) await self.con.execute(r""" @@ -15041,7 +15093,7 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 3 + 2, ) # Make a change that doesn't change the types @@ -15051,14 +15103,17 @@ async def test_edgeql_ddl_collection_cleanup_04(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 3 + 2, ) await self.con.execute(r""" DROP ALIAS Bar; """) - self.assertEqual(await self.con.query_single(count_query), orig_count) + self.assertEqual( + await self.con.query_single(count_query), + orig_count + 3, + ) async def test_edgeql_ddl_collection_cleanup_05(self): count_query = "SELECT count(schema::CollectionType);" @@ -15074,7 +15129,9 @@ async def test_edgeql_ddl_collection_cleanup_05(self): self.assertEqual( await 
self.con.query_single(count_query), - orig_count + 2, # one for tuple, one for TupleExprAlias + orig_count + 2 + 2, # one for tuple + # one for TupleExprAlias, + # two for implicit array and array ) await self.con.execute(r""" @@ -15083,14 +15140,17 @@ async def test_edgeql_ddl_collection_cleanup_05(self): self.assertEqual( await self.con.query_single(count_query), - orig_count + 2, + orig_count + 2 + 2, ) await self.con.execute(r""" DROP ALIAS Bar; """) - self.assertEqual(await self.con.query_single(count_query), orig_count) + self.assertEqual( + await self.con.query_single(count_query), + orig_count + 2, + ) async def test_edgeql_ddl_drop_field_01(self): await self.con.execute(r""" diff --git a/tests/test_edgeql_functions.py b/tests/test_edgeql_functions.py index bd6eb250742..b53797fef9e 100644 --- a/tests/test_edgeql_functions.py +++ b/tests/test_edgeql_functions.py @@ -22,7 +22,6 @@ import json import os.path import random -import unittest import uuid import edgedb @@ -851,6 +850,23 @@ async def test_edgeql_functions_enumerate_08(self): ]) ) + async def test_edgeql_functions_enumerate_09(self): + await self.assert_query_result( + 'SELECT enumerate(sum({1,2,3}))', + [[0, 6]] + ) + await self.assert_query_result( + 'SELECT enumerate(count(Issue))', + [[0, 4]] + ) + await self.assert_query_result( + ''' + WITH x := (SELECT enumerate(array_agg((select User)))), + SELECT (x.0, array_unpack(x.1).name) + ''', + [[0, 'Elvis'], [0, 'Yury']] + ) + async def test_edgeql_functions_array_get_01(self): await self.assert_query_result( r'''SELECT array_get([1, 2, 3], 2);''', @@ -8315,10962 +8331,3 @@ async def test_edgeql_functions_complex_types_04(self): ['https://edgedb.com', '~/screenshot.png'], sort=True, ) - - async def test_edgeql_functions_inline_basic_01(self): - await self.con.execute(''' - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await 
self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_02(self): - await self.con.execute(''' - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (x * x + 2 * x + 1); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [4], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [4, 9, 16], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [4, 9, 16], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_03(self): - await self.con.execute(''' - create function foo(x: int64, y: int64) -> int64 { - set is_inlined := true; - using (x + y); - }; - ''') - await self.assert_query_result( - 'select foo({}, {})', - [], - ) - await self.assert_query_result( - 'select foo(1, {})', - [], - ) - await self.assert_query_result( - 'select foo({}, 1)', - [], - ) - await self.assert_query_result( - 'select foo(1, 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_04(self): - await self.con.execute(''' - create function foo(x: int64 = 9) -> int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 
'select foo()', - [9], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_05(self): - await self.con.execute(''' - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_06(self): - await self.con.execute(''' - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_07(self): - await self.con.execute(''' - create function foo(x: int64, y: int64 = 90) -> int64 { - set is_inlined := true; - using (x + y); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [91], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'select foo({}, {})', - [], - ) - await self.assert_query_result( - 'select foo(1, {})', - [], - ) - await self.assert_query_result( - 'select 
foo({}, 1)', - [], - ) - await self.assert_query_result( - 'select foo(1, 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_08(self): - await self.con.execute(''' - create function foo(x: int64 = 9, y: int64 = 90) -> int64 { - set is_inlined := true; - using (x + y); - }; - ''') - await self.assert_query_result( - 'select foo()', - [99], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [91], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'select foo({}, {})', - [], - ) - await self.assert_query_result( - 'select foo(1, {})', - [], - ) - await self.assert_query_result( - 'select foo({}, 1)', - [], - ) - await self.assert_query_result( - 'select foo(1, 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} 
union (select foo(x))', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_09(self): - await self.con.execute(''' - create function foo(variadic x: int64) -> int64 { - set is_inlined := true; - using (sum(array_unpack(x))); - }; - ''') - await self.assert_query_result( - 'select foo()', - [0], - ) - await self.assert_query_result( - 'select foo(1,{})', - [], - ) - await self.assert_query_result( - 'select foo({},1)', - [], - ) - await self.assert_query_result( - 'select foo(1, 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, {10, 20, 30}, 100)', - [111, 112, 113, 121, 122, 123, 131, 132, 133], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, y, 100)' - ' )' - ')', - [111, 112, 113, 121, 122, 123, 131, 132, 133], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_10(self): - await self.con.execute(''' - create function foo(named only a: int64) -> int64 { - set is_inlined := true; - using (a); - }; - ''') - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 1)', - [1], - ) - await self.assert_query_result( - 'select foo(a := {1,2,3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(a := x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_11(self): - 
await self.con.execute(''' - create function foo(x: int64, named only a: int64) -> int64 { - set is_inlined := true; - using (x + a); - }; - ''') - await self.assert_query_result( - 'select foo({}, a := {})', - [], - ) - await self.assert_query_result( - 'select foo(1, a := {})', - [], - ) - await self.assert_query_result( - 'select foo({}, a := 10)', - [], - ) - await self.assert_query_result( - 'select foo(1, a := 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, a := {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x, a := 10))', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, a := y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, a := y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_12(self): - await self.con.execute(''' - create function foo( - x: int64 = 9, - named only a: int64 - ) -> int64 { - set is_inlined := true; - using (x + a); - }; - ''') - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 10)', - [19], - ) - await self.assert_query_result( - 'select foo(a := {10, 20, 30})', - [19, 29, 39], - sort=True, - ) - await self.assert_query_result( - 'select foo({}, a := {})', - [], - ) - await self.assert_query_result( - 'select foo(1, a := {})', - [], - ) - await self.assert_query_result( - 'select foo({}, a := 10)', - [], - ) - await self.assert_query_result( - 'select foo(1, a := 10)', - [11], - ) - await 
self.assert_query_result( - 'select foo({1, 2, 3}, a := 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, a := {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x, a := 10))', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(a := y))', - [19, 29, 39], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, a := y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, a := y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_13(self): - await self.con.execute(''' - create function foo( - x: int64, - named only a: int64 = 90 - ) -> int64 { - set is_inlined := true; - using (x + a); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [91], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'select foo({}, a := {})', - [], - ) - await self.assert_query_result( - 'select foo(1, a := {})', - [], - ) - await self.assert_query_result( - 'select foo({}, a := 10)', - [], - ) - await self.assert_query_result( - 'select foo(1, a := 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, a := {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await 
self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x, a := 10))', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, a := y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, a := y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_14(self): - await self.con.execute(''' - create function foo( - x: int64 = 9, - named only a: int64 = 90 - ) -> int64 { - set is_inlined := true; - using (x + a); - }; - ''') - await self.assert_query_result( - 'select foo()', - [99], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [91], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 10)', - [19], - ) - await self.assert_query_result( - 'select foo(a := {10, 20, 30})', - [19, 29, 39], - sort=True, - ) - await self.assert_query_result( - 'select foo({}, a := {})', - [], - ) - await self.assert_query_result( - 'select foo(1, a := {})', - [], - ) - await self.assert_query_result( - 'select foo({}, a := 10)', - [], - ) - await self.assert_query_result( - 'select foo(1, a := 10)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := 10)', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, a := {10, 20, 30})', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, a := {10, 20, 30})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} 
union (select foo(x))', - [91, 92, 93], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x, a := 10))', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(a := y))', - [19, 29, 39], - sort=True, - ) - await self.assert_query_result( - 'for y in {10, 20, 30} union (select foo(1, a := y))', - [11, 21, 31], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, a := y)' - ' )' - ')', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_15(self): - await self.con.execute(''' - create function foo( - x: int64, - y: int64 = 90, - variadic z: int64, - named only a: int64, - named only b: int64 = 90000 - ) -> int64 { - set is_inlined := true; - using (x + y + sum(array_unpack(z)) + a + b); - }; - ''') - await self.assert_query_result( - 'select foo(1, a := 1000)', - [91091], - ) - await self.assert_query_result( - 'select foo(1, 10, a := 1000)', - [91011], - ) - await self.assert_query_result( - 'select foo(1, a := 1000, b := 10000)', - [11091], - ) - await self.assert_query_result( - 'select foo(1, 10, a := 1000, b := 10000)', - [11011], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, a := 1000)', - [91111], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, a := 1000, b := 10000)', - [11111], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, 200, a := 1000)', - [91311], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, 200, a := 1000, b := 10000)', - [11311], - ) - - async def test_edgeql_functions_inline_basic_16(self): - await self.con.execute(''' - create function foo(x: optional int64) -> optional int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) 
- await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_17(self): - await self.con.execute(''' - create function foo( - x: optional int64 - ) -> int64 { - set is_inlined := true; - using (x ?? 5); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [5], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_18(self): - await self.con.execute(''' - create function foo( - x: optional int64 = 9 - ) -> int64 { - set is_inlined := true; - using (x ?? 5); - }; - ''') - await self.assert_query_result( - 'select foo()', - [9], - ) - await self.assert_query_result( - 'select foo({})', - [5], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_basic_19(self): - await self.con.execute(''' - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using (for y in {x, x + 1, x + 2} union (y)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1, 2, 3], - ) - await self.assert_query_result( - 'select foo({11, 21, 31})', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'for x in {11, 21, 31} union (select foo(x))', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - async def 
test_edgeql_functions_inline_array_01(self): - await self.con.execute(''' - create function foo(x: int64) -> array { - set is_inlined := true; - using ([x]); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [[1]], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [[1], [2], [3]], - sort=True, - ) - - async def test_edgeql_functions_inline_array_02(self): - await self.con.execute(''' - create function foo(x: array) -> array { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_array_03(self): - await self.con.execute(''' - create function foo( - x: array = [9] - ) -> array { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo()', - [[9]], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_array_04(self): - await self.con.execute(''' - create function foo(x: array) -> int64 { - set is_inlined := true; - using (sum(array_unpack(x))); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 5], - sort=True, - ) - - async def test_edgeql_functions_inline_array_05(self): - await self.con.execute(''' - create function foo(x: array) -> set of int64 { - set is_inlined := true; - using (array_unpack(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - 
await self.assert_query_result( - 'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_tuple_01(self): - await self.con.execute(''' - create function foo(x: int64) -> tuple { - set is_inlined := true; - using ((x,)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1,)], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1,), (2,), (3,)], - sort=True, - ) - - async def test_edgeql_functions_inline_tuple_02(self): - await self.con.execute(''' - create function foo( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - sort=True, - ) - - async def test_edgeql_functions_inline_tuple_03(self): - await self.con.execute(''' - create function foo( - x: tuple = (9,) - ) -> tuple { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo()', - [(9,)], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - ) - - async def test_edgeql_functions_inline_tuple_04(self): - await self.con.execute(''' - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (x.0); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [1], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_tuple_05(self): - await 
self.con.execute(''' - create function foo(x: int64) -> tuple { - set is_inlined := true; - using ((a:=x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [{'a': 1}], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [{'a': 1}, {'a': 2}, {'a': 3}], - ) - - async def test_edgeql_functions_inline_tuple_06(self): - await self.con.execute(''' - create function foo( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [{'a': 1}], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [{'a': 1}, {'a': 2}, {'a': 3}], - ) - - async def test_edgeql_functions_inline_tuple_07(self): - await self.con.execute(''' - create function foo( - x: tuple = (a:=9) - ) -> tuple { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo()', - [{'a': 9}], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [{'a': 1}], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [{'a': 1}, {'a': 2}, {'a': 3}], - ) - - async def test_edgeql_functions_inline_tuple_08(self): - await self.con.execute(''' - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (x.a); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [1], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [1, 2, 3], - ) - - async def test_edgeql_functions_inline_object_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: int64) -> optional Bar { - set is_inlined 
:= true; - using ((select Bar{a} filter .a = x limit 1)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(-1).a', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: optional Bar) -> optional Bar { - set is_inlined := true; - using (x ?? 
(select Bar filter .a = 1 limit 1)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> int64 { - set is_inlined := true; - using (x.a); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_05(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> set of Bar { - set is_inlined := true; - using ((select Bar{a} filter .a <= x.a)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 1, 1, 2, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_06(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ((select Bar{a} filter .a <= x).a); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select 
foo({1,2,3})', - [1, 1, 1, 2, 2, 3], - sort=True, - ) - - @tb.needs_factoring - async def test_edgeql_functions_inline_object_07(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo() -> int64 { - set is_inlined := true; - using (count(Bar)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [3], - ) - await self.assert_query_result( - 'select (foo(), foo())', - [[3, 3]], - sort=True, - ) - await self.assert_query_result( - 'select (Bar.a, foo())', - [[1, 3], [2, 3], [3, 3]], - sort=True, - ) - await self.assert_query_result( - 'select (foo(), Bar.a)', - [[3, 1], [3, 2], [3, 3]], - sort=True, - ) - await self.assert_query_result( - 'select (Bar.a, foo(), Bar.a, foo())', - [[1, 3, 1, 3], [2, 3, 2, 3], [3, 3, 3, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_object_08(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo() -> set of tuple { - set is_inlined := true; - using ((Bar.a, count(Bar))); - }; - ''') - await self.assert_query_result( - 'select foo()', - [[1, 1], [2, 1], [3, 1]], - ) - await self.assert_query_result( - 'select (foo(), foo())', - [ - [[1, 1], [1, 1]], [[1, 1], [2, 1]], [[1, 1], [3, 1]], - [[2, 1], [1, 1]], [[2, 1], [2, 1]], [[2, 1], [3, 1]], - [[3, 1], [1, 1]], [[3, 1], [2, 1]], [[3, 1], [3, 1]], - ], - sort=True, - ) - await self.assert_query_result( - 'select (Bar.a, foo())', - [ - [1, [1, 1]], [1, [2, 1]], [1, [3, 1]], - [2, [1, 1]], [2, [2, 1]], [2, [3, 1]], - [3, [1, 1]], [3, [2, 1]], [3, [3, 1]], - ], - sort=True, - ) - await self.assert_query_result( - 'select (foo(), Bar.a)', - [ - [[1, 1], 1], [[1, 1], 2], [[1, 1], 3], - [[2, 1], 1], [[2, 1], 2], [[2, 1], 3], - [[3, 1], 1], [[3, 1], 2], [[3, 1], 3], - ], - sort=True, - ) - - 
@tb.needs_factoring - async def test_edgeql_functions_inline_object_09(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> tuple { - set is_inlined := true; - using ((x.a, count(Bar))); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select (Bar.a, foo((select Bar filter .a = 1)))', - [[1, [1, 3]]], - ) - await self.assert_query_result( - 'select (Bar.a, foo((select detached Bar filter .a = 1)))', - [[1, [1, 3]], [2, [1, 3]], [3, [1, 3]]], - sort=True, - ) - await self.assert_query_result( - 'select (Bar.a, foo(Bar))', - [[1, [1, 3]], [2, [2, 3]], [3, [3, 3]]], - sort=True, - ) - await self.assert_query_result( - 'select (foo(Bar), foo(Bar))', - [[[1, 3], [1, 3]], [[2, 3], [2, 3]], [[3, 3], [3, 3]]], - sort=True, - ) - await self.assert_query_result( - 'select (foo(Bar), foo(detached Bar))', - [ - [[1, 3], [1, 3]], [[1, 3], [2, 3]], [[1, 3], [3, 3]], - [[2, 3], [1, 3]], [[2, 3], [2, 3]], [[2, 3], [3, 3]], - [[3, 3], [1, 3]], [[3, 3], [2, 3]], [[3, 3], [3, 3]], - ], - sort=True, - ) - - async def test_edgeql_functions_inline_object_10(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: Bar) -> set of Baz { - set is_inlined := true; - using ((select Baz filter .b <= x.a)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [4, 4, 4, 5, 5, 6], - 
sort=True, - ) - - async def test_edgeql_functions_inline_object_11(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: Bar | Baz) -> Bar | Baz { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Baz)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz})).a', - [1, 2, 3, 4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_object_12(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: int64) -> optional Bar | Baz { - set is_inlined := true; - using ((select {Bar, Baz} filter .a = x limit 1)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( 
- 'select foo({1, 4}).a', - [1, 4], - sort=True, - ) - await self.assert_query_result( - 'select foo({0, 1, 2, 3, 4, 5, 6, 7, 8}).a', - [1, 2, 3, 4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_object_13(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: Bar | Baz) -> optional Bar { - set is_inlined := true; - using (x[is Bar]); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4)).a', - [], - ) - await self.assert_query_result( - 'select foo((select Baz)).a', - [], - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz})).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_14(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: Bar | Baz) -> optional int64 { - set is_inlined := true; - using ( - x[is Baz].b - ) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select 
foo({})', - [], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4))', - [1], - ) - await self.assert_query_result( - 'select foo((select Baz))', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz}))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_object_15(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: Bar | Baz) -> optional int64 { - set is_inlined := true; - using ( - if x is Bar - then x.a*2 - else 10 + assert_exists(x[is Baz]).b - ) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [2], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4))', - [11], - ) - await self.assert_query_result( - 'select foo((select Baz))', - [11, 12, 13], - sort=True, - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz}))', - [2, 4, 6, 11, 12, 13], - sort=True, - ) - - async def test_edgeql_functions_inline_object_16(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Bar2 extending Bar; - insert 
Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar2{a := 4}; - insert Bar2{a := 5}; - insert Bar2{a := 6}; - create function foo(x: Bar) -> optional Bar2 { - set is_inlined := true; - using (x[is Bar2]); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar2 filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Bar2)).a', - [4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_object_17(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - b := 4, - bar := assert_exists((select Bar filter .a = 1 limit 1)), - }; - insert Baz{ - b := 5, - bar := assert_exists((select Bar filter .a = 2 limit 1)), - }; - insert Baz{ - b := 6, - bar := assert_exists((select Bar filter .a = 3 limit 1)), - }; - create function foo(x: Baz) -> Bar { - set is_inlined := true; - using (x.bar); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Baz filter .b = 4)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Baz)).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_shape_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create 
function foo(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select Bar{' - ' a,' - ' b := foo(.a)' - '} order by .a', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 3, 'b': 3}, - ], - ) - - async def test_edgeql_functions_inline_shape_02(self): - await self.con.execute(''' - create type Bar { - create property a -> int64; - }; - insert Bar{}; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: optional int64) -> optional int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select Bar{' - ' a,' - ' b := foo(.a)' - '} order by .a', - [ - {'a': None, 'b': None}, - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 3, 'b': 3}, - ], - ) - - async def test_edgeql_functions_inline_shape_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: optional int64) -> set of int64 { - set is_inlined := true; - using ({10 + x, 20 + x, 30 + x}); - }; - ''') - await self.assert_query_result( - 'select Bar{' - ' a,' - ' b := foo(.a)' - '} order by .a', - [ - {'a': 1, 'b': [11, 21, 31]}, - {'a': 2, 'b': [12, 22, 32]}, - {'a': 3, 'b': [13, 23, 33]}, - ], - ) - - async def test_edgeql_functions_inline_shape_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo() -> int64 { - set is_inlined := true; - using (count(Bar)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [3], - ) - await self.assert_query_result( - 'select Bar {' - ' a,' - ' n := foo(),' - '} order by .a', - [{'a': 1, 'n': 3}, {'a': 2, 'n': 3}, {'a': 3, 'n': 3}], - ) - - async def test_edgeql_functions_inline_shape_05(self): - await self.con.execute(''' - create type Bar { - create required 
property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo() -> set of tuple { - set is_inlined := true; - using ((Bar.a, count(Bar))); - }; - ''') - await self.assert_query_result( - 'select foo()', - [[1, 1], [2, 1], [3, 1]], - ) - await self.assert_query_result( - 'select Bar {' - ' a,' - ' n := foo(),' - '} order by .a', - [ - {'a': 1, 'n': [[1, 1], [2, 1], [3, 1]]}, - {'a': 2, 'n': [[1, 1], [2, 1], [3, 1]]}, - {'a': 3, 'n': [[1, 1], [2, 1], [3, 1]]}, - ], - ) - - async def test_edgeql_functions_inline_shape_06(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> tuple { - set is_inlined := true; - using ((x.a, count(Bar))); - }; - ''') - await self.assert_query_result( - 'select Bar {' - ' a,' - ' n := foo(Bar),' - '} order by .a', - [ - {'a': 1, 'n': [1, 3]}, - {'a': 2, 'n': [2, 3]}, - {'a': 3, 'n': [3, 3]}, - ], - ) - - async def test_edgeql_functions_inline_shape_07(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using (assert_exists((select Bar filter .a = x limit 1))); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a,' - ' c := foo(.b).a,' - '} order by .a', - [ - {'a': 4, 'c': 1}, - {'a': 5, 'c': 2}, - {'a': 6, 'c': 3}, - ], - ) - - async def test_edgeql_functions_inline_shape_08(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create property b -> int64; 
- }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - insert Baz{a := 7, b := 4}; - create function foo(x: int64) -> optional Bar { - set is_inlined := true; - using ((select Bar filter .a = x limit 1)); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a,' - ' c := foo(.b).a,' - '} order by .a', - [ - {'a': 4, 'c': 1}, - {'a': 5, 'c': 2}, - {'a': 6, 'c': 3}, - {'a': 7, 'c': None}, - ], - ) - - async def test_edgeql_functions_inline_shape_09(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using ((select Bar filter .a <= x)); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a,' - ' c := foo(.b).a,' - '} order by .a', - [ - {'a': 4, 'c': [1]}, - {'a': 5, 'c': [1, 2]}, - {'a': 6, 'c': [1, 2, 3]}, - ], - ) - - async def test_edgeql_functions_inline_shape_10(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - b := 4, - bar := assert_exists((select Bar filter .a = 1 limit 1)), - }; - insert Baz{ - b := 5, - bar := assert_exists((select Bar filter .a = 2 limit 1)), - }; - insert Baz{ - b := 6, - bar := assert_exists((select Bar filter .a = 3 limit 1)), - }; - create function foo(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a := foo(.bar).a,' - ' b,' - '} order 
by .a', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_shape_11(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - b := 4, - bar := assert_exists((select Bar filter .a = 1 limit 1)), - }; - insert Baz{ - b := 5, - bar := assert_exists((select Bar filter .a = 2 limit 1)), - }; - insert Baz{ - b := 6, - bar := assert_exists((select Bar filter .a = 3 limit 1)), - }; - create function foo(x: Bar) -> int64 { - set is_inlined := true; - using (x.a); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a := foo(.bar),' - ' b,' - '} order by .a', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_shape_12(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - b := 4, - bar := assert_exists((select Bar filter .a <= 1)), - }; - insert Baz{ - b := 5, - bar := assert_exists((select Bar filter .a <= 2)), - }; - insert Baz{ - b := 6, - bar := assert_exists((select Bar filter .a <= 3)), - }; - create function foo(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a := foo(.bar).a,' - ' b,' - '} order by .b', - [ - {'a': [1], 'b': 4}, - {'a': [1, 2], 'b': 5}, - {'a': [1, 2, 3], 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_shape_13(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> 
int64; - }; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - bar := assert_exists((select Bar filter .a = 1 limit 1)) { - @b := 4 - }, - }; - insert Baz{ - bar := assert_exists((select Bar filter .a = 2 limit 1)) { - @b := 5 - } - }; - insert Baz{ - bar := assert_exists((select Bar filter .a = 3 limit 1)) { - @b := 6 - } - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - ''') - await self.assert_query_result( - 'select Baz{' - ' a := .bar.a,' - ' b := foo(.bar@b),' - '} order by .a', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_global_01(self): - await self.con.execute(''' - create global a := 1; - create function foo() -> int64 { - set is_inlined := true; - using (global a); - }; - ''') - await self.assert_query_result( - 'select foo()', - [1], - ) - - async def test_edgeql_functions_inline_global_02(self): - await self.con.execute(''' - create global a -> int64; - create function foo() -> optional int64 { - set is_inlined := true; - using (global a); - }; - ''') - await self.assert_query_result( - 'select foo()', - [], - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo()', - [1], - ) - - async def test_edgeql_functions_inline_global_03(self): - await self.con.execute(''' - create global a := 1; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (global a + x); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_global_04(self): - await self.con.execute(''' - create global a -> int64; - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using (global a + x) - }; - ''') - await 
self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_01(self): - # Directly passing parameter - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x) - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(x)) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_02(self): - # Indirectly passing parameter - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(x + 1)) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [4], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [4, 9, 16], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [4, 9, 16], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_03(self): - # Calling same inner function with different parameters - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := 
true; - using (x * x) - }; - create function foo(x: int64, y: int64) -> int64 { - set is_inlined := true; - using (inner(x) + inner(y)); - }; - ''') - await self.assert_query_result( - 'select foo({}, {})', - [], - ) - await self.assert_query_result( - 'select foo(1, {})', - [], - ) - await self.assert_query_result( - 'select foo({}, 1)', - [], - ) - await self.assert_query_result( - 'select foo(1, 10)', - [101], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, 10)', - [101, 104, 109], - sort=True, - ) - await self.assert_query_result( - 'select foo(1, {10, 20, 30})', - [101, 401, 901], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}, {10, 20, 30})', - [101, 104, 109, 401, 404, 409, 901, 904, 909], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' for y in {10, 20, 30} union (' - ' select foo(x, y)' - ' )' - ')', - [101, 104, 109, 401, 404, 409, 901, 904, 909], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_04(self): - # Directly passing parameter with default - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(x: int64 = 9) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [81], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_05(self): - # Indirectly passing parameter with default - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(x: int64 = 9) -> int64 { - set is_inlined := 
true; - using (inner(x+1)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [100], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [4], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [4, 9, 16], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [4, 9, 16], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_06(self): - # Inner function with default parameter - await self.con.execute(''' - create function inner(x: int64 = 9) -> int64 { - set is_inlined := true; - using (x * x) - }; - create function foo1() -> int64 { - set is_inlined := true; - using (inner()); - }; - create function foo2(x: int64) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [81], - ) - await self.assert_query_result( - 'select foo2({})', - [], - ) - await self.assert_query_result( - 'select foo2(1)', - [1], - ) - await self.assert_query_result( - 'select foo2({1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo2(x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_07(self): - # Directly passing optional parameter - await self.con.execute(''' - create function inner(x: optional int64) -> optional int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(x: optional int64) -> int64 { - set is_inlined := true; - using (inner(x) ?? 
99); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [99], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_08(self): - # Indirectly passing optional parameter - await self.con.execute(''' - create function inner(x: optional int64) -> optional int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(x: optional int64) -> int64 { - set is_inlined := true; - using (inner(x+1) ?? 99); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [99], - ) - await self.assert_query_result( - 'select foo(1)', - [4], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [4, 9, 16], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [4, 9, 16], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_09(self): - # Inner function with optional parameter - await self.con.execute(''' - create function inner(x: optional int64) -> int64 { - set is_inlined := true; - using ((x * x) ?? 
99) - }; - create function foo1() -> int64 { - set is_inlined := true; - using (inner({})); - }; - create function foo2(x: int64) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [99], - ) - await self.assert_query_result( - 'select foo2({})', - [], - ) - await self.assert_query_result( - 'select foo2(1)', - [1], - ) - await self.assert_query_result( - 'select foo2({1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo2(x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_10(self): - # Directly passing variadic parameter - await self.con.execute(''' - create function inner(x: array) -> int64 { - set is_inlined := true; - using (sum(array_unpack(x))) - }; - create function foo(variadic x: int64) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [0], - ) - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo(1, 2, 3)', - [6], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2}, {10, 20})', - [11, 12, 21, 22], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_11(self): - # Indirectly passing variadic parameter - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x) - }; - create function foo(variadic x: int64) -> int64 { - set is_inlined := true; - using (inner(sum(array_unpack(x)))); - }; - ''') - await self.assert_query_result( - 'select foo()', - [0], - ) - await self.assert_query_result( - 'select foo({})', 
- [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo(1, 2, 3)', - [6], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2}, {10, 20})', - [11, 12, 21, 22], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_12(self): - # Inner function with variadic parameter - await self.con.execute(''' - create function inner(variadic x: int64) -> int64 { - set is_inlined := true; - using (sum(array_unpack(x))) - }; - create function foo1() -> int64 { - set is_inlined := true; - using (inner()); - }; - create function foo2(x: int64, y: int64, z: int64) -> int64 { - set is_inlined := true; - using (inner(x, y, z)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [0], - ) - await self.assert_query_result( - 'select foo2({}, {}, {})', - [], - ) - await self.assert_query_result( - 'select foo2(1, 2, 3)', - [6], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo2(x, x * 10, x * 100))', - [111, 222, 333], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_13(self): - # Directly passing named parameter - await self.con.execute(''' - create function inner(named only a: int64) -> int64 { - set is_inlined := true; - using (a * a) - }; - create function foo(named only a: int64) -> int64 { - set is_inlined := true; - using (inner(a := a)); - }; - ''') - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 1)', - [1], - ) - await self.assert_query_result( - 'select foo(a := {1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(a := x))', - [1, 4, 9], - sort=True, - ) - - async 
def test_edgeql_functions_inline_nested_basic_14(self): - # Indirectly passing named parameter - await self.con.execute(''' - create function inner(named only a: int64) -> int64 { - set is_inlined := true; - using (a * a) - }; - create function foo(named only a: int64) -> int64 { - set is_inlined := true; - using (inner(a := a + 1)); - }; - ''') - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 1)', - [4], - ) - await self.assert_query_result( - 'select foo(a := {1, 2, 3})', - [4, 9, 16], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(a := x))', - [4, 9, 16], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_15(self): - # Passing named parameter as positional - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x * x) - }; - create function foo(named only a: int64) -> int64 { - set is_inlined := true; - using (inner(a)); - }; - ''') - await self.assert_query_result( - 'select foo(a := {})', - [], - ) - await self.assert_query_result( - 'select foo(a := 1)', - [1], - ) - await self.assert_query_result( - 'select foo(a := {1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(a := x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_16(self): - # Passing positional parameter as named - await self.con.execute(''' - create function inner(named only a: int64) -> int64 { - set is_inlined := true; - using (a * a) - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(a := x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [1, 4, 9], - sort=True, - ) - await self.assert_query_result( - 'for x in 
{1, 2, 3} union (select foo(x))', - [1, 4, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_17(self): - # Variety of paremeter types - await self.con.execute(''' - create function inner1(x: int64, y: int64) -> int64 { - set is_inlined := true; - using (x + y) - }; - create function inner2(x: array) -> int64 { - set is_inlined := true; - using (sum(array_unpack(x))) - }; - create function foo( - x: int64, - y: int64 = 90, - variadic z: int64, - named only a: int64, - named only b: int64 = 90000 - ) -> int64 { - set is_inlined := true; - using (inner1(x, a) + inner1(y, b) + inner2(z)); - }; - ''') - await self.assert_query_result( - 'select foo(1, a := 1000)', - [91091], - ) - await self.assert_query_result( - 'select foo(1, 10, a := 1000)', - [91011], - ) - await self.assert_query_result( - 'select foo(1, a := 1000, b := 10000)', - [11091], - ) - await self.assert_query_result( - 'select foo(1, 10, a := 1000, b := 10000)', - [11011], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, a := 1000)', - [91111], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, a := 1000, b := 10000)', - [11111], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, 200, a := 1000)', - [91311], - ) - await self.assert_query_result( - 'select foo(1, 10, 100, 200, a := 1000, b := 10000)', - [11311], - ) - - async def test_edgeql_functions_inline_nested_basic_18(self): - # For in inner function - await self.con.execute(''' - create function inner(x: int64) -> set of int64 { - set is_inlined := true; - using (for y in {x, x + 1, x + 2} union (y)) - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(10)', - [10, 11, 12], - ) - await self.assert_query_result( - 'select foo({10, 20, 30})', - [10, 11, 12, 20, 21, 22, 30, 31, 32], - sort=True, - ) - await 
self.assert_query_result( - 'for x in {10, 20, 30} union (select foo(x))', - [10, 11, 12, 20, 21, 22, 30, 31, 32], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_19(self): - # For in outer function - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x) - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using (for y in {x, x + 1, x + 2} union (inner(y))); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(10)', - [10, 11, 12], - ) - await self.assert_query_result( - 'select foo({10, 20, 30})', - [10, 11, 12, 20, 21, 22, 30, 31, 32], - sort=True, - ) - await self.assert_query_result( - 'for x in {10, 20, 30} union (select foo(x))', - [10, 11, 12, 20, 21, 22, 30, 31, 32], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_basic_20(self): - # Deeply nested - await self.con.execute(''' - create function inner1(x: int64) -> int64 { - set is_inlined := true; - using (x+1) - }; - create function inner2(x: int64) -> int64 { - set is_inlined := true; - using (inner1(x+2)) - }; - create function inner3(x: int64) -> int64 { - set is_inlined := true; - using (inner2(x+3)) - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner3(x+4)) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [11], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [11, 12, 13], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_01(self): - # Return array from inner function - await self.con.execute(''' - create function inner(x: int64) -> array { - set is_inlined := true; - using ([x]); - }; - create function foo(x: int64) -> array { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) 
- await self.assert_query_result( - 'select foo(1)', - [[1]], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [[1], [2], [3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_02(self): - # Access array element in inner function - await self.con.execute(''' - create function inner(x: array) -> int64 { - set is_inlined := true; - using (x[0]); - }; - create function foo(x: array) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 2], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_03(self): - # Access array element in outer function - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: array) -> int64 { - set is_inlined := true; - using (inner(x[0])); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 2], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_04(self): - # Directly passing array parameter - await self.con.execute(''' - create function inner(x: array) -> array { - set is_inlined := true; - using (x); - }; - create function foo(x: array) -> array { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_05(self): - # Indirectly passing array parameter - await self.con.execute(''' - create function inner(x: array) -> 
array { - set is_inlined := true; - using (x); - }; - create function foo(x: array) -> array { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_06(self): - # Inner function with array parameter - await self.con.execute(''' - create function inner(x: array) -> array { - set is_inlined := true; - using (x); - }; - create function foo(x: int64) -> array { - set is_inlined := true; - using (inner([x])); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [[1]], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [[1], [2], [3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_07(self): - # Directly passing array parameter with default - await self.con.execute(''' - create function inner(x: array) -> array { - set is_inlined := true; - using (x); - }; - create function foo( - x: array = [9] - ) -> array { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [[9]], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_08(self): - # Directly passing array parameter with default - await self.con.execute(''' - create function inner(x: array) -> array { - set is_inlined := true; - using (x); - }; - create function foo( - x: array = [9] - ) -> array { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo()', 
- [[9]], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_09(self): - # Inner function with array parameter with default - await self.con.execute(''' - create function inner(x: array = [9]) -> array { - set is_inlined := true; - using (x); - }; - create function foo1() -> array { - set is_inlined := true; - using (inner()); - }; - create function foo2( - x: array - ) -> array { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [[9]], - ) - await self.assert_query_result( - 'select foo2(>{})', - [], - ) - await self.assert_query_result( - 'select foo2([1])', - [[1]], - ) - await self.assert_query_result( - 'select foo2({[1], [2, 3]})', - [[1], [2, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_10(self): - # Unpack array in inner function - await self.con.execute(''' - create function inner(x: array) -> set of int64 { - set is_inlined := true; - using (array_unpack(x)); - }; - create function foo(x: array) -> set of int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_array_11(self): - # Unpack array in outer function - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: array) -> set of int64 { - set is_inlined := true; - using (inner(array_unpack(x))); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 
'select foo([1])', - [1], - ) - await self.assert_query_result( - 'select foo({[1], [2, 3]})', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_01(self): - # Return tuple from inner function - await self.con.execute(''' - create function inner(x: int64) -> tuple { - set is_inlined := true; - using ((x,)); - }; - create function foo(x: int64) -> tuple { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1,)], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1,), (2,), (3,)], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3}).0', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_02(self): - # Return named tuple from inner function - await self.con.execute(''' - create function inner(x: int64) -> tuple { - set is_inlined := true; - using ((a := x)); - }; - create function foo(x: int64) -> tuple { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [{'a': 1}], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}) order by .a', - [{'a': 1}, {'a': 2}, {'a': 3}], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_03(self): - # Accessing tuple element in inner function - await self.con.execute(''' - create function inner( - x: tuple - ) -> int64 { - set is_inlined := true; - using (x.0); - }; - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [1], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - 
[1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_04(self): - # Accessing tuple element in outer function - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (inner(x.0)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [1], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [1, 2, 3], - ) - - async def test_edgeql_functions_inline_nested_tuple_05(self): - # Accessing named tuple element in inner function - await self.con.execute(''' - create function inner( - x: tuple - ) -> int64 { - set is_inlined := true; - using (x.a); - }; - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((a := 1))', - [1], - ) - await self.assert_query_result( - 'select foo({(a := 1), (a := 2), (a := 3)})', - [1, 2, 3], - ) - - async def test_edgeql_functions_inline_nested_tuple_06(self): - # Accessing named tuple element in outer function - await self.con.execute(''' - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo( - x: tuple - ) -> int64 { - set is_inlined := true; - using (inner(x.a)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((a := 1))', - [1], - ) - await self.assert_query_result( - 'select foo({(a := 1), (a := 2), (a := 3)})', - [1, 2, 3], - ) - - async def test_edgeql_functions_inline_nested_tuple_07(self): - # Directly passing tuple parameter - await self.con.execute(''' - create function inner( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function 
foo( - x: tuple - ) -> tuple { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_08(self): - # Indirectly passing tuple parameter - await self.con.execute(''' - create function inner( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function foo( - x: tuple - ) -> tuple { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_09(self): - # Inner function with tuple parameter - await self.con.execute(''' - create function inner( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function foo( - x: int64 - ) -> tuple { - set is_inlined := true; - using (inner((x,))); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1,)], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1,), (2,), (3,)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_tuple_10(self): - # Directly passing a tuple parameter with default - await self.con.execute(''' - create function inner( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function foo( - x: tuple = (9,) - ) -> tuple { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo()', - [(9,)], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await 
self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - ) - - async def test_edgeql_functions_inline_nested_tuple_11(self): - # Indirectly passing tuple parameter with default - await self.con.execute(''' - create function inner( - x: tuple - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function foo( - x: tuple = (9,) - ) -> tuple { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo()', - [(9,)], - ) - await self.assert_query_result( - 'select foo(>{})', - [], - ) - await self.assert_query_result( - 'select foo((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - ) - - async def test_edgeql_functions_inline_nested_tuple_12(self): - # Inner function with tuple parameter with default - await self.con.execute(''' - create function inner( - x: tuple = (9,) - ) -> tuple { - set is_inlined := true; - using (x); - }; - create function foo1() -> tuple { - set is_inlined := true; - using (inner()); - }; - create function foo2( - x: tuple - ) -> tuple { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [(9,)], - ) - await self.assert_query_result( - 'select foo2(>{})', - [], - ) - await self.assert_query_result( - 'select foo2((1,))', - [(1,)], - ) - await self.assert_query_result( - 'select foo2({(1,), (2,), (3,)})', - [(1,), (2,), (3,)], - ) - - async def test_edgeql_functions_inline_nested_object_01(self): - # Directly passing object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - create function foo(x: Bar) -> Bar { - set is_inlined := true; - using (inner(x)); 
- }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_02(self): - # Indirectly passing object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - create function foo(x: Bar) -> Bar { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_03(self): - # Inner function with object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: Bar) -> Bar { - set is_inlined := true; - using (x); - }; - create function foo(x: int64) -> optional Bar { - set is_inlined := true; - using (inner((select Bar filter .a = x limit 1))); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3, 4}).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_04(self): - # Inner function returning object - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> 
optional Bar { - set is_inlined := true; - using ((select Bar filter .a = x limit 1)); - }; - create function foo(x: int64) -> optional Bar { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3, 4}).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_05(self): - # Outer function returning object - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: int64) -> optional Bar { - set is_inlined := true; - using ((select Bar filter .a = inner(x) limit 1)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3, 4}).a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_06(self): - # Inner function returning set of object - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> set of Bar { - set is_inlined := true; - using ((select Bar filter .a <= x)); - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}).a', - [1, 1, 1, 2, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_07(self): - # Outer function returning set of object - await 
self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using ((select Bar filter .a <= inner(x))); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo(2).a', - [1, 2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3}).a', - [1, 1, 1, 2, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_08(self): - # Directly passing optional object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: optional Bar) -> optional int64 { - set is_inlined := true; - using (x.a ?? 99); - }; - create function foo(x: optional Bar) -> optional int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [99], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_09(self): - # Indirectly passing optional object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: optional Bar) -> optional int64 { - set is_inlined := true; - using (x.a ?? 
99); - }; - create function foo(x: optional Bar) -> optional int64 { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [99], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_10(self): - # Inner function with optional object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: optional Bar) -> int64 { - set is_inlined := true; - using (x.a ?? 99); - }; - create function foo1() -> int64 { - set is_inlined := true; - using (inner({})); - }; - create function foo2(x: Bar) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1()', - [99], - ) - await self.assert_query_result( - 'select foo2({})', - [], - ) - await self.assert_query_result( - 'select foo2((select Bar filter .a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo2((select Bar))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_11(self): - # Check path factoring - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner() -> set of tuple { - set is_inlined := true; - using ((Bar.a, count(Bar))); - }; - create function foo() -> set of tuple { - set is_inlined := true; - using (inner()); - }; - ''') - await self.assert_query_result( - 'select foo()', - [[1, 1], [2, 1], [3, 1]], - ) - await self.assert_query_result( - 'select (foo(), foo())', - [ - [[1, 1], [1, 1]], [[1, 1], [2, 1]], [[1, 1], [3, 1]], - [[2, 1], [1, 1]], [[2, 1], [2, 1]], [[2, 1], [3, 1]], - [[3, 1], [1, 
1]], [[3, 1], [2, 1]], [[3, 1], [3, 1]], - ], - sort=True, - ) - await self.assert_query_result( - 'select (Bar.a, foo())', - [ - [1, [1, 1]], [1, [2, 1]], [1, [3, 1]], - [2, [1, 1]], [2, [2, 1]], [2, [3, 1]], - [3, [1, 1]], [3, [2, 1]], [3, [3, 1]], - ], - sort=True, - ) - await self.assert_query_result( - 'select (foo(), Bar.a)', - [ - [[1, 1], 1], [[1, 1], 2], [[1, 1], 3], - [[2, 1], 1], [[2, 1], 2], [[2, 1], 3], - [[3, 1], 1], [[3, 1], 2], [[3, 1], 3], - ], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_12(self): - # Check path factoring - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner1(x: Bar) -> int64 { - set is_inlined := true; - using (x.a); - }; - create function inner2(x: Bar) -> int64 { - set is_inlined := true; - using (count(Bar)); - }; - create function foo(x: Bar) -> tuple { - set is_inlined := true; - using ((inner1(x), inner2(x))); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [[1, 3]], - ) - await self.assert_query_result( - 'select (' - ' foo((select Bar filter .a = 1)),' - ' foo((select Bar filter .a = 2)),' - ')', - [[[1, 3], [2, 3]]], - ) - await self.assert_query_result( - 'select foo((select Bar))', - [[1, 3], [2, 3], [3, 3]], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_13(self): - # Directly passing complex type object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function inner(x: Bar | Baz) -> Bar | Baz { - set is_inlined := 
true; - using (x); - }; - create function foo(x: Bar | Baz) -> Bar | Baz { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Baz)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz})).a', - [1, 2, 3, 4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_14(self): - # Indirectly passing complex type object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function inner(x: Bar | Baz) -> Bar | Baz { - set is_inlined := true; - using (x); - }; - create function foo(x: Bar | Baz) -> Bar | Baz { - set is_inlined := true; - using (inner((select x))); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Baz filter .a = 4)).a', - [4], - ) - await 
self.assert_query_result( - 'select foo((select Baz)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select {Bar, Baz})).a', - [1, 2, 3, 4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_15(self): - # Inner function with complex type object parameter - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function inner(x: Bar | Baz) -> Bar | Baz { - set is_inlined := true; - using (x); - }; - create function foo1(x: Bar) -> Bar | Baz { - set is_inlined := true; - using (inner(x)); - }; - create function foo2(x: Baz) -> Bar | Baz { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo1({}).a', - [], - ) - await self.assert_query_result( - 'select foo2({}).a', - [], - ) - await self.assert_query_result( - 'select foo1((select Bar filter .a = 1)).a', - [1], - ) - await self.assert_query_result( - 'select foo1((select Bar)).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select foo2((select Baz filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo2((select Baz)).a', - [4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_16(self): - # Type intersection in inner function - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Bar2 extending Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar2{a := 4}; - insert Bar2{a := 5}; - insert Bar2{a := 6}; - create function inner(x: Bar) -> optional Bar2 { - set is_inlined := true; - using (x[is Bar2]); - }; - create function foo(x: Bar) -> optional 
Bar2 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar2 filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Bar2)).a', - [4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_17(self): - # Type intersection in outer function - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Bar2 extending Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar2{a := 4}; - insert Bar2{a := 5}; - insert Bar2{a := 6}; - create function inner(x: Bar2) -> optional Bar2 { - set is_inlined := true; - using (x); - }; - create function foo(x: Bar) -> optional Bar2 { - set is_inlined := true; - using (inner(x[is Bar2])); - }; - ''') - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo({}).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1)).a', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar2 filter .a = 4)).a', - [4], - ) - await self.assert_query_result( - 'select foo((select Bar)).a', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select foo((select Bar2)).a', - [4, 5, 6], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_object_18(self): - # Access linked object in inner function - await self.con.execute(''' - create type Bar 
{ - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar; - }; - create type Bazz { - create required link baz -> Baz; - }; - insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 1})})}; - insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 2})})}; - insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 3})})}; - create function inner1(x: Bar) -> int64 { - set is_inlined := true; - using (x.a); - }; - create function inner2(x: Baz) -> int64 { - set is_inlined := true; - using (inner1(x.bar)); - }; - create function foo(x: Bazz) -> int64 { - set is_inlined := true; - using (inner2(x.baz)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bazz filter .baz.bar.a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo((select Bazz))', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_01(self): - # Put result of inner function taking Bar.a into Bar - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: Bar) -> int64 { - set is_inlined := true; - using ((select x{a, b := inner(x.a)}).b); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [1], - ) - await self.assert_query_result( - 'select foo(Bar)', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_02(self): - # Put result of inner function taking Bar into Bar - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: Bar) -> int64 { - set is_inlined := true; - using 
(x.a + 90); - }; - create function foo(x: Bar) -> tuple { - set is_inlined := true; - using ( - with y := (select x{a, b := inner(x)}) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo((select Bar filter .a = 1))', - [(1, 91)], - ) - await self.assert_query_result( - 'select foo(Bar)', - [(1, 91), (2, 92), (3, 93)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_03(self): - # Put result of inner function taking number into Bar - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x + 90); - }; - create function foo(x: int64) -> set of tuple { - set is_inlined := true; - using ( - with y := (select Bar{a, b := inner(x)}) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1, 91), (2, 91), (3, 91)], - sort=True, - ) - await self.assert_query_result( - 'select foo(Bar.a)', - [ - (1, 91), (1, 92), (1, 93), - (2, 91), (2, 92), (2, 93), - (3, 91), (3, 92), (3, 93), - ], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_04(self): - # Put result of inner function using Bar into Bar - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function inner() -> int64 { - set is_inlined := true; - using (count(Bar)); - }; - create function foo(x: int64) -> set of tuple { - set is_inlined := true; - using ( - with y := (select Bar{a, b := inner()} filter .a = x) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1, 3)], - sort=True, - 
) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1, 3), (2, 3), (3, 3)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_05(self): - # Put result of inner function taking Baz.b and returning Bar into Baz - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property a -> int64; - create required property b -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{a := 4, b := 1}; - insert Baz{a := 5, b := 2}; - insert Baz{a := 6, b := 3}; - create function inner(x: int64) -> Bar { - set is_inlined := true; - using (assert_exists((select Bar filter .a = x limit 1))); - }; - create function foo(x: int64) -> set of tuple { - set is_inlined := true; - using ( - with y := (select Baz{a, c := inner(.b).a} filter .b = x) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(4, 1)], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(4, 1), (5, 2), (6, 3)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_06(self): - # Put result of inner function taking Baz.bar into Baz - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - b := 4, - bar := assert_exists((select Bar filter .a = 1 limit 1)), - }; - insert Baz{ - b := 5, - bar := assert_exists((select Bar filter .a = 2 limit 1)), - }; - insert Baz{ - b := 6, - bar := assert_exists((select Bar filter .a = 3 limit 1)), - }; - create function inner(x: Bar) -> int64 { - set is_inlined := true; - using (x.a); - }; - create function foo(x: int64) -> set of tuple { - set is_inlined := true; - using ( - with 
y := (select Baz{a := inner(.bar), b} filter .a = x) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1, 4)], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1, 4), (2, 5), (3, 6)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_shape_07(self): - # Put result of inner function taking Baz.bar@b into Baz - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> int64; - }; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{ - bar := assert_exists((select Bar filter .a = 1 limit 1)) { - @b := 4 - } - }; - insert Baz{ - bar := assert_exists((select Bar filter .a = 2 limit 1)) { - @b := 5 - } - }; - insert Baz{ - bar := assert_exists((select Bar filter .a = 3 limit 1)) { - @b := 6 - } - }; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (x); - }; - create function foo(x: int64) -> set of tuple { - set is_inlined := true; - using ( - with y := ( - select Baz{a := .bar.a, b := inner(.bar@b)} - filter .a = x - ) - select (y.a, y.b) - ); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [(1, 4)], - sort=True, - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [(1, 4), (2, 5), (3, 6)], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_01(self): - # Use computed global in inner function - await self.con.execute(''' - create global a := 1; - create function inner(x: int64) -> int64 { - set is_inlined := true; - using (global a + x); - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await 
self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_02(self): - # Use non-computed global in inner function - await self.con.execute(''' - create global a -> int64; - create function inner(x: int64) -> optional int64 { - set is_inlined := true; - using (global a + x); - }; - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_03(self): - # Pass computed global to inner function - await self.con.execute(''' - create global a := 1; - create function inner(x: int64, y: int64) -> int64 { - set is_inlined := true; - using (x + y); - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(global a, x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_04(self): - # Pass non-computed global to inner function - await self.con.execute(''' - create global a -> int64; - create function inner(x: int64, y: int64) -> optional int64 { - set is_inlined := true; - using (x + y); - }; - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using 
(inner(global a, x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_05(self): - # Use computed global in inner non-inlined function - # - inlined > non-inlined - await self.con.execute(''' - create global a := 1; - create function inner(x: int64) -> int64 { - using (global a + x); - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_06(self): - # Use non-computed global in inner non-inlined function - # - inlined > non-inlined - await self.con.execute(''' - create global a -> int64; - create function inner(x: int64) -> optional int64 { - using (global a + x); - }; - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using (inner(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - 
sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_07(self): - # Use computed global nested in non-inlined function - # - non-inlined > inlined > non-inlined - await self.con.execute(''' - create global a := 1; - create function inner1(x: int64) -> int64 { - using (global a + x); - }; - create function inner2(x: int64) -> int64 { - set is_inlined := true; - using (inner1(x)); - }; - create function foo(x: int64) -> int64 { - using (inner2(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_08(self): - # Use non-computed global nested in non-inlined function - # - non-inlined > inlined > non-inlined - await self.con.execute(''' - create global a -> int64; - create function inner1(x: int64) -> optional int64 { - using (global a + x); - }; - create function inner2(x: int64) -> optional int64 { - set is_inlined := true; - using (inner1(x)); - }; - create function foo(x: int64) -> optional int64 { - using (inner2(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_09(self): - # Use computed global in deeply nested inner non-inlined function - # - inlined > inlined > inlined > non-inlined - await self.con.execute(''' - create global a := 1; - create function inner1(x: int64) -> int64 { - using (global a + x); - 
}; - create function inner2(x: int64) -> int64 { - set is_inlined := true; - using (inner1(x)); - }; - create function inner3(x: int64) -> int64 { - set is_inlined := true; - using (inner2(x)); - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using (inner3(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_nested_global_10(self): - # Use computed global in deeply nested inner non-inlined function - # - inlined > inlined > inlined > non-inlined - await self.con.execute(''' - create global a -> int64; - create function inner1(x: int64) -> optional int64 { - using (global a + x); - }; - create function inner2(x: int64) -> optional int64 { - set is_inlined := true; - using (inner1(x)); - }; - create function inner3(x: int64) -> optional int64 { - set is_inlined := true; - using (inner2(x)); - }; - create function foo(x: int64) -> optional int64 { - set is_inlined := true; - using (inner3(x)); - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [], - sort=True, - ) - - await self.con.execute(''' - set global a := 1; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - await self.assert_query_result( - 'select foo(1)', - [2], - ) - await self.assert_query_result( - 'select foo({1, 2, 3})', - [2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_modifying_cardinality_01(self): - await self.con.execute(''' - create function foo(x: int64) -> int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - await self.assert_query_result( - 'select foo(1)', - [1], - ) - - async def 
test_edgeql_functions_inline_modifying_cardinality_02(self): - await self.con.execute(''' - create function foo(x: int64) -> int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - with self.assertRaisesRegex( - edgedb.QueryError, - 'possibly an empty set passed as non-optional argument ' - 'into modifying function' - ): - await self.con.execute(''' - select foo({}) - ''') - - async def test_edgeql_functions_inline_modifying_cardinality_03(self): - await self.con.execute(''' - create function foo(x: int64) -> int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - with self.assertRaisesRegex( - edgedb.QueryError, - 'possibly more than one element passed into modifying function' - ): - await self.con.execute(''' - select foo({1, 2, 3}) - ''') - - async def test_edgeql_functions_inline_modifying_cardinality_04(self): - await self.con.execute(''' - create function foo(x: optional int64) -> optional int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - await self.assert_query_result( - 'select foo(1)', - [1], - ) - - async def test_edgeql_functions_inline_modifying_cardinality_05(self): - await self.con.execute(''' - create function foo(x: optional int64) -> optional int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - await self.assert_query_result( - 'select foo({})', - [], - ) - - async def test_edgeql_functions_inline_modifying_cardinality_06(self): - await self.con.execute(''' - create function foo(x: optional int64) -> optional int64 { - set volatility := schema::Volatility.Modifying; - using (x) - }; - ''') - with self.assertRaisesRegex( - edgedb.QueryError, - 'possibly more than one element passed into modifying function' - ): - await self.con.execute(''' - select foo({1, 2, 3}) - ''') - - async def test_edgeql_functions_inline_insert_basic_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - 
create function foo() -> Bar { - set is_inlined := true; - using ((insert Bar{ a := 1 })); - }; - ''') - - await self.assert_query_result( - 'select foo().a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - async def test_edgeql_functions_inline_insert_basic_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - async def test_edgeql_functions_inline_insert_basic_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using ((insert Bar{ a := x }).a) - }; - ''') - - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - async def test_edgeql_functions_inline_insert_basic_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x + 1 })) - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [2], - ) - await self.assert_query_result( - 'select Bar.a', - [2], - ) - - async def test_edgeql_functions_inline_insert_basic_05(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using ((insert Bar{ a := 2 * x + 1 }).a + 10) - }; - ''') - - await self.assert_query_result( - 'select foo(1)', - [13], - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - async def test_edgeql_functions_inline_insert_basic_06(self): - await self.con.execute(''' - create type Bar { - create required 
property a -> int64; - }; - create function foo(x: int64 = 0) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - ''') - - await self.assert_query_result( - 'select foo().a', - [0], - ) - await self.assert_query_result( - 'select Bar.a', - [0], - ) - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1], - ) - - async def test_edgeql_functions_inline_insert_basic_07(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: optional int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x ?? 0 })) - }; - ''') - - await self.assert_query_result( - 'select foo({}).a', - [0], - ) - await self.assert_query_result( - 'select Bar.a', - [0], - ) - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_basic_08(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(named only x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - ''') - - await self.assert_query_result( - 'select foo(x := 1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - async def test_edgeql_functions_inline_insert_basic_09(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(variadic x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := sum(array_unpack(x)) })) - }; - ''') - - await self.assert_query_result( - 'select foo().a', - [0], - ) - await self.assert_query_result( - 'select Bar.a', - [0], - ) - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1], - sort=True, - ) - - await self.assert_query_result( 
- 'select foo(2, 3).a', - [5], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 5], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_basic_10(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - create required property b -> int64; - }; - create function foo(x: int64, y: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x, b := y })) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 10){a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - - async def test_edgeql_functions_inline_insert_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (select foo(x).a)', - [2, 3, 4], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4], - sort=True, - ) - - await self.assert_query_result( - 'select if true then foo(5).a else 99', - [5], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5], - sort=True, - ) - await self.assert_query_result( - 'select if false then foo(6).a else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5], - sort=True, - ) - await self.assert_query_result( - 'select if true then 99 else foo(7).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5], - sort=True, - ) - await self.assert_query_result( - 'select if false then 99 else foo(8).a', - [8], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5, 8], - sort=True, - ) - - await 
self.assert_query_result( - 'select foo(9).a ?? 99', - [9], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5, 8, 9], - sort=True, - ) - await self.assert_query_result( - 'select 99 ?? foo(10).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4, 5, 8, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_iterator_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - create required property b -> int64; - }; - create function foo(x: int64, y: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x, b := y })) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 10){a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - - await self.assert_query_result( - 'select (' - ' for x in {2, 3} union(' - ' for y in {20, 30} union(' - ' select foo(x, y)' - ' )' - ' )' - '){a, b}' - 'order by .a then .b', - [ - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(5, 50)' - ' else (select Bar filter .a = 1)' - '){a, b}' - 'order by .a then .b', - [{'a': 5, 'b': 50}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(6, 60)' - ' else (select Bar filter .a = 1)' - '){a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 
'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then (select Bar filter .a = 1)' - ' else foo(7, 70)' - '){a, b}' - 'order by .a then .b', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then (select Bar filter .a = 1)' - ' else foo(8, 80)' - '){a, b}' - 'order by .a then .b', - [{'a': 8, 'b': 80}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - {'a': 8, 'b': 80}, - ], - ) - - await self.assert_query_result( - 'select (foo(9, 90) ?? (select Bar filter .a = 1)){a, b}', - [{'a': 9, 'b': 90}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - {'a': 8, 'b': 80}, - {'a': 9, 'b': 90}, - ], - ) - await self.assert_query_result( - 'select ((select Bar filter .a = 1) ?? 
foo(10, 100)){a, b}', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar{a, b}' - 'order by .a then .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - {'a': 5, 'b': 50}, - {'a': 8, 'b': 80}, - {'a': 9, 'b': 90}, - ], - ) - - async def test_edgeql_functions_inline_insert_iterator_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using ( - for y in {x, x + 1, x + 2} union ( - (insert Bar{ a := y }) - ) - ) - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - await self.assert_query_result( - 'for x in {11, 21, 31} union (select foo(x).a)', - [11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 11, 12, 13, 21, 22, 23, 31, 32, 33], - sort=True, - ) - - await self.assert_query_result( - 'select if true then foo(51).a else 99', - [51, 52, 53], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 23, - 31, 32, 33, - 51, 52, 53, - ], - sort=True, - ) - await self.assert_query_result( - 'select if false then foo(61).a else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 23, - 31, 32, 33, - 51, 52, 53, - ], - sort=True, - ) - await self.assert_query_result( - 'select if true then 99 else foo(71).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 23, - 31, 32, 33, - 51, 52, 53, - ], - sort=True, - ) - await self.assert_query_result( - 'select if false then 99 else foo(81).a', - [81, 82, 83], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 
23, - 31, 32, 33, - 51, 52, 53, - 81, 82, 83, - ], - sort=True, - ) - - await self.assert_query_result( - 'select foo(91).a ?? 99', - [91, 92, 93], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 23, - 31, 32, 33, - 51, 52, 53, - 81, 82, 83, - 91, 92, 93, - ], - sort=True, - ) - await self.assert_query_result( - 'select 99 ?? foo(101).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 1, 2, 3, - 11, 12, 13, - 21, 22, 23, - 31, 32, 33, - 51, 52, 53, - 81, 82, 83, - 91, 92, 93, - ], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_iterator_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: bool, y: int64) -> optional Bar { - set is_inlined := true; - using ( - if x then (insert Bar{ a := y }) else {} - ) - }; - ''') - - await self.assert_query_result( - 'select foo(false, 0).a', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - await self.assert_query_result( - 'select foo(true, 1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4, 5} union (select foo(x % 2 = 0, x).a)', - [2, 4], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4], - sort=True, - ) - - await self.assert_query_result( - 'select if true then foo(false, 6).a else 99', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4], - sort=True, - ) - await self.assert_query_result( - 'select if true then foo(true, 6).a else 99', - [6], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if false then foo(false, 7).a else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if false then foo(true, 7).a else 99', - [99], - ) - 
await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if true then 99 else foo(false, 8).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if true then 99 else foo(true, 8).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if false then 99 else foo(false, 9).a', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6], - sort=True, - ) - await self.assert_query_result( - 'select if false then 99 else foo(true, 9).a', - [9], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6, 9], - sort=True, - ) - - await self.assert_query_result( - 'select foo(false, 10).a ?? 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6, 9], - sort=True, - ) - await self.assert_query_result( - 'select foo(true, 10).a ?? 99', - [10], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6, 9, 10], - sort=True, - ) - await self.assert_query_result( - 'select 99 ?? foo(false, 11).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6, 9, 10], - sort=True, - ) - await self.assert_query_result( - 'select 99 ?? 
foo(true, 11).a', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 4, 6, 9, 10], - sort=True, - ) - - @unittest.skip('Cannot correlate same set inside and outside DML') - async def test_edgeql_functions_inline_insert_correlate_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> tuple { - set is_inlined := true; - using (((insert Bar{ a := x }), x)) - }; - ''') - - await self.assert_query_result( - 'select foo(1)', - [[[], 1]], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (select foo(x).a)', - [2, 3, 4], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4], - sort=True, - ) - - @unittest.skip('Cannot correlate same set inside and outside DML') - async def test_edgeql_functions_inline_insert_correlate_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> int64 { - set is_inlined := true; - using ((insert Bar{ a := 2 * x + 1 }).a + x * x) - }; - ''') - - await self.assert_query_result( - 'select foo(1)', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (select foo(x))', - [9, 16, 25], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3, 5, 7, 9], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_correlate_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> tuple { - set is_inlined := true; - using (( - (insert Bar{ a := x }).a, - (insert Bar{ a := x + 1 }).a, - )) - }; - ''') - - await self.assert_query_result( - 'select foo(1)', - [[1, 2]], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2], - sort=True, - ) - - await 
self.assert_query_result( - 'for x in {11, 21, 31} union (select foo(x))', - [[11, 12], [21, 22], [31, 32]], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 11, 12, 21, 22, 31, 32], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_correlate_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64, y: int64) -> tuple { - set is_inlined := true; - using (( - (insert Bar{ a := x }).a, - (insert Bar{ a := y }).a, - )) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 2)', - [[1, 2]], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2], - sort=True, - ) - - await self.assert_query_result( - 'for x in {1, 5} union (' - ' for y in {10, 20} union (' - ' select foo(x + y, x + y + 1)' - ' )' - ')', - [[11, 12], [15, 16], [21, 22], [25, 26]], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 11, 12, 15, 16, 21, 22, 25, 26], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_correlate_05(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64, y: int64) -> int64 { - set is_inlined := true; - using ((insert Bar{ a := 2 * x + 1 }).a + y) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 10)', - [13], - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - await self.assert_query_result( - 'for x in {2, 3} union(' - ' for y in {20, 30} union(' - ' select foo(x, y)' - ' )' - ')', - [25, 27, 35, 37], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3, 5, 5, 7, 7], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_conflict_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - create constraint exclusive on (.a) - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using (( - insert 
Bar{a := x} - unless conflict on .a - else ((update Bar set {a := x + 10})) - )) - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x).a)', - [2, 3, 11], - sort=True - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3, 11], - ) - - async def test_edgeql_functions_inline_insert_conflict_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create type Baz { - create link bar -> Bar; - create constraint exclusive on (.bar) - }; - create function foo(x: Bar) -> Baz { - set is_inlined := true; - using (( - insert Baz{bar := x} - unless conflict on .bar - else (( - update Baz set {bar := (insert Bar{a := x.a + 10})} - )) - )) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' assert_exists((select Bar filter .a = 1 limit 1))' - ').bar.a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - ) - await self.assert_query_result( - 'select Baz.bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {1, 2, 3} union (' - ' select foo(' - ' assert_exists((select Bar filter .a = x limit 1))' - ' ).bar.a' - ')', - [2, 3, 11], - sort=True - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 11], - ) - await self.assert_query_result( - 'select Baz.bar.a', - [2, 3, 11], - ) - - async def test_edgeql_functions_inline_insert_link_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(n: int64, x: Bar) -> Baz { - set is_inlined := true; - using ((insert Baz{ b := n, bar := x })) - }; - ''') - - 
await self.assert_query_result( - 'select foo(' - ' 4,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b}', - [{'a': 1, 'b': 4}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a', - [{'a': 1, 'b': 4}], - ) - - await self.assert_query_result( - 'select foo(' - ' 5,' - ' assert_exists((select Bar filter .a = 2 limit 1))' - '){a := .bar.a, b}', - [{'a': 2, 'b': 5}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: int64, y: int64) -> Baz { - set is_inlined := true; - using ( - (insert Baz{ - b := x, - bar := (select Bar filter .a <= y), - }) - ); - }; - ''') - - await self.assert_query_result( - 'select foo(4, 1){a := .bar.a, b}', - [{'a': [1], 'b': 4}], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [{'a': [1], 'b': 4}], - ) - - await self.assert_query_result( - 'select foo(5, 2){a := .bar.a, b}', - [{'a': [1, 2], 'b': 5}], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': [1], 'b': 4}, - {'a': [1, 2], 'b': 5}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - create function foo(x: int64, y: int64) -> Baz { - set is_inlined := true; - using ( - (insert Baz { - b := y, - bar := (insert Bar{ a := x }) - }) - ); - }; - 
''') - - await self.assert_query_result( - 'select foo(1, 4).b', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b', - [{'a': 1, 'b': 4}], - ) - - await self.assert_query_result( - 'select foo(2, 5).b', - [5], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2], - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar {a := x})) - }; - ''') - - await self.assert_query_result( - 'select (insert Baz{b := 4, bar := foo(1)})' - '{a := .bar.a, b} order by .b', - [{'a': 1, 'b': 4}], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b', - [{'a': 1, 'b': 4}], - ) - - await self.assert_query_result( - 'select (insert Baz{b := 5, bar := foo(2)})' - '{a := .bar.a, b} order by .b', - [{'a': 2, 'b': 5}], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2], - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar{a := 4}; - create function foo(n: int64, x: Bar) -> Baz { - set is_inlined := true; - using ((insert Baz{ b := n, bar := 
x })) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' 1, assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b} order by .a then .b', - [{'a': 1, 'b': 1}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [{'a': 1, 'b': 1}], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (' - ' select foo(' - ' x, assert_exists((select Bar filter .a = 2 limit 1))' - ' ).b' - ')', - [2, 3, 4], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(' - ' 5, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).b' - ' else 99' - ')', - [5], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(' - ' 6, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).b' - ' else 99' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then 99' - ' else foo(' - ' 7, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).b' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then 99' - ' else foo(' - ' 8, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).b' - ')', - [8], - ) - await 
self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - ], - ) - - await self.assert_query_result( - 'select foo(' - ' 9, assert_exists((select Bar filter .a = 4 limit 1))' - ').b ?? 99', - [9], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - {'a': 4, 'b': 9}, - ], - ) - await self.assert_query_result( - 'select 99 ?? foo(' - ' 9, assert_exists((select Bar filter .a = 4 limit 1))' - ').b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - {'a': 4, 'b': 9}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_iterator_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - create function foo(x: int64, y: int64) -> Baz { - set is_inlined := true; - using ( - (insert Baz { - b := y, - bar := (for z in {x, x + 1, x + 2} union( - (insert Bar{ a := z }) - )) - }) - ); - }; - ''') - - await self.assert_query_result( - 'select foo(10, 1).b', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [10, 11, 12], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [{'a': [10, 11, 12], 'b': 1}], - ) - - await self.assert_query_result( - 'for x in {20, 30} union (' - ' for y in {2, 3} union (' - ' select foo(x, y).b' - ' )' - ')', - [2, 2, 3, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 
32, 32, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - ], - ) - - await self.assert_query_result( - 'select if true then foo(40, 4).b else 999', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - ], - ) - await self.assert_query_result( - 'select if false then foo(50, 5).b else 999', - [999], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - ], - ) - await self.assert_query_result( - 'select if true then 999 else foo(60, 6).b', - [999], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - ], - ) - await self.assert_query_result( - 'select if false then 999 else foo(70, 
7).b', - [7], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - 70, 71, 72, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - {'a': [70, 71, 72], 'b': 7}, - ], - ) - - await self.assert_query_result( - 'select foo(80, 8).b ?? 999', - [8], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - 70, 71, 72, - 80, 81, 82, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - {'a': [70, 71, 72], 'b': 7}, - {'a': [80, 81, 82], 'b': 8}, - ], - ) - await self.assert_query_result( - 'select 999 ?? 
foo(90, 9).b', - [999], - ) - await self.assert_query_result( - 'select Bar.a', - [ - 10, 11, 12, - 20, 20, 21, 21, 22, 22, - 30, 30, 31, 31, 32, 32, - 40, 41, 42, - 70, 71, 72, - 80, 81, 82, - ], - sort=True, - ) - await self.assert_query_result( - 'select Baz {a := .bar.a, b} order by .b then sum(.a)', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [30, 31, 32], 'b': 2}, - {'a': [20, 21, 22], 'b': 3}, - {'a': [30, 31, 32], 'b': 3}, - {'a': [40, 41, 42], 'b': 4}, - {'a': [70, 71, 72], 'b': 7}, - {'a': [80, 81, 82], 'b': 8}, - ], - ) - - async def test_edgeql_functions_inline_insert_link_iterator_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar{a := 4}; - create function foo(n: int64, x: Bar, flag: bool) -> optional Baz { - set is_inlined := true; - using ( - if flag then (insert Baz{ b := n, bar := x }) else {} - ) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' 0, assert_exists((select Bar filter .a = 1 limit 1)), false' - '){a := .bar.a, b} order by .a then .b', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [], - ) - await self.assert_query_result( - 'select foo(' - ' 1, assert_exists((select Bar filter .a = 1 limit 1)), true' - '){a := .bar.a, b} order by .a then .b', - [{'a': 1, 'b': 1}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [{'a': 1, 'b': 1}], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (' - ' select foo(' - ' x,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' false,' - ' ).b' - ')', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [{'a': 1, 'b': 1}], - ) - await self.assert_query_result( - 'for 
x in {2, 3, 4} union (' - ' select foo(' - ' x,' - ' assert_exists((select Bar filter .a = 2 limit 1)),' - ' true,' - ' ).b' - ')', - [2, 3, 4], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(' - ' 5,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' false,' - ' ).b' - ' else 99' - ')', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(' - ' 6,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' false,' - ' ).b' - ' else 99' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then 99' - ' else foo(' - ' 7,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' false,' - ' ).b' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then 99' - ' else foo(' - ' 8,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' false,' - ' ).b' - ')', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(' - ' 9,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' true,' - ' ).b' - ' else 99' - ')', - [9], - 
) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' true,' - ' ).b' - ' else 99' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then 99' - ' else foo(' - ' 11,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' true,' - ' ).b' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then 99' - ' else foo(' - ' 12,' - ' assert_exists((select Bar filter .a = 3 limit 1)),' - ' true,' - ' ).b' - ')', - [12], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - {'a': 3, 'b': 12}, - ], - ) - - await self.assert_query_result( - 'select foo(' - ' 13, assert_exists((select Bar filter .a = 4 limit 1)), false' - ').b ?? 99', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - {'a': 3, 'b': 12}, - ], - ) - await self.assert_query_result( - 'select 99 ?? 
foo(' - ' 14, assert_exists((select Bar filter .a = 4 limit 1)), false' - ').b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - {'a': 3, 'b': 12}, - ], - ) - await self.assert_query_result( - 'select foo(' - ' 15, assert_exists((select Bar filter .a = 4 limit 1)), true' - ').b ?? 99', - [15], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - {'a': 3, 'b': 12}, - {'a': 4, 'b': 15}, - ], - ) - await self.assert_query_result( - 'select 99 ?? foo(' - ' 16, assert_exists((select Bar filter .a = 4 limit 1)), true' - ').b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 9}, - {'a': 3, 'b': 12}, - {'a': 4, 'b': 15}, - ], - ) - - async def test_edgeql_functions_inline_insert_linkprop_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> int64; - } - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(x: Bar) -> Baz { - set is_inlined := true; - using ((insert Baz{ bar := x { @b := 10 } })) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b := .bar@b}', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [{'a': 1, 'b': 10}], - ) - - async def test_edgeql_functions_inline_insert_linkprop_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create 
required link bar -> Bar { - create property b -> int64; - } - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - create function foo(n: int64, x: Bar) -> Baz { - set is_inlined := true; - using ((insert Baz{ bar := x { @b := n } })) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' 4,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b := .bar@b}', - [{'a': 1, 'b': 4}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [{'a': 1, 'b': 4}], - ) - - async def test_edgeql_functions_inline_insert_linkprop_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> int64; - } - }; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar{a := 4}; - create function foo(n: int64, x: Bar) -> Baz { - set is_inlined := true; - using ((insert Baz{ bar := x { @b := n } })) - }; - ''') - - await self.assert_query_result( - 'select foo(' - ' 1,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b := .bar@b}', - [{'a': 1, 'b': 1}], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [{'a': 1, 'b': 1}], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (' - ' select foo(' - ' x, assert_exists((select Bar filter .a = 2 limit 1))' - ' ).bar@b' - ')', - [2, 3, 4], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - ], - ) - - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(' - ' 5, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).bar@b' - ' else 99' - ')', - [5], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 
1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(' - ' 6, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).bar@b' - ' else 99' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if true' - ' then 99' - ' else foo(' - ' 7, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).bar@b' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select (' - ' if false' - ' then 99' - ' else foo(' - ' 8, assert_exists((select Bar filter .a = 3 limit 1))' - ' ).bar@b' - ')', - [8], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - ], - ) - - await self.assert_query_result( - 'select foo(' - ' 9, assert_exists((select Bar filter .a = 4 limit 1))' - ').bar@b ?? 99', - [9], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - {'a': 4, 'b': 9}, - ], - ) - await self.assert_query_result( - 'select 99 ?? 
foo(' - ' 9, assert_exists((select Bar filter .a = 4 limit 1))' - ').bar@b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', - [ - {'a': 1, 'b': 1}, - {'a': 2, 'b': 2}, - {'a': 2, 'b': 3}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': 5}, - {'a': 3, 'b': 8}, - {'a': 4, 'b': 9}, - ], - ) - - async def test_edgeql_functions_inline_insert_nested_01(self): - # Simple inner modifying function - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function inner(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })); - }; - create function foo(x: int64) -> Bar { - set is_inlined := true; - using (inner(x)); - }; - ''') - - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - - await self.assert_query_result( - 'for x in {2, 3, 4} union (foo(x).a)', - [2, 3, 4], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3, 4], - sort=True, - ) - - async def test_edgeql_functions_inline_insert_nested_02(self): - # Putting the result of an inner modifying function into shape - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create required link bar -> Bar; - }; - create function inner1(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - create function inner2(x: int64, y: int64) -> Baz { - set is_inlined := true; - using ((insert Baz{ b := y, bar := inner1(x) })) - }; - create function foo(x: int64, y: int64) -> Baz { - set is_inlined := true; - using (inner2(x, y)) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 10){a := .bar.a, b := .b}', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .b} order by .a', 
- [{'a': 1, 'b': 10}], - ) - - await self.assert_query_result( - 'select (' - ' for x in {2, 3} union (' - ' for y in {20, 30} union (' - ' foo(x, y){a := .bar.a, b := .b}' - ' )' - ' )' - ') order by .a then .b', - [ - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 2, 3, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .b} order by .a', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - - async def test_edgeql_functions_inline_insert_nested_03(self): - # Putting the result of an inner modifying function into shape with - # link property - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> int64; - }; - }; - create function inner1(x: int64) -> Bar { - set is_inlined := true; - using ((insert Bar{ a := x })) - }; - create function inner2(x: int64, y: int64) -> Baz { - set is_inlined := true; - using ((insert Baz{ bar := inner1(x){ @b := y } })) - }; - create function foo(x: int64, y: int64) -> Baz { - set is_inlined := true; - using (inner2(x, y)) - }; - ''') - - await self.assert_query_result( - 'select foo(1, 10){a := .bar.a, b := .bar@b}', - [{'a': 1, 'b': 10}], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [{'a': 1, 'b': 10}], - ) - - await self.assert_query_result( - 'select (' - ' for x in {2, 3} union (' - ' for y in {20, 30} union (' - ' foo(x, y){a := .bar.a, b := .bar@b}' - ' )' - ' )' - ') order by .a then .b', - [ - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 2, 3, 3], - sort=True, - ) - await 
self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': 2, 'b': 30}, - {'a': 3, 'b': 20}, - {'a': 3, 'b': 30}, - ], - ) - - async def test_edgeql_functions_inline_update_basic_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using ((update Bar set { a := x })); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1).a', - [1, 1, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 1], - sort=True, - ) - - async def test_edgeql_functions_inline_update_basic_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64, y: int64) -> set of int64 { - set is_inlined := true; - using ((update Bar filter .a <= y set { a := x }).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 3)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - 
sort=True, - ) - - async def test_edgeql_functions_inline_update_basic_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - named only m: int64, - named only n: int64, - ) -> set of int64 { - set is_inlined := true; - using ((update Bar filter .a <= n set { a := m }).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(m := 0, n := 0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 0, n := 1)', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 0, n := 2)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 0, n := 3)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - async def test_edgeql_functions_inline_update_basic_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - x: optional int64, - y: optional int64, - ) -> set of int64 { - set is_inlined := true; - using ((update Bar filter .a <= y ?? 9 set { a := x ?? 
9 }).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo({}, {})', - [9, 9, 9], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [9, 9, 9], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo({}, 2)', - [9, 9], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3, 9, 9], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2, {})', - [2, 2, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [2, 2, 2], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 3)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - async def test_edgeql_functions_inline_update_basic_05(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - x: int64, - variadic y: int64, - ) -> set of int64 { - set is_inlined := true; - using ( - ( - update Bar - filter .a <= sum(array_unpack(y)) - set { a := x } - ).a - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await 
reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1, 2)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - async def test_edgeql_functions_inline_update_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64, y: int64) -> set of int64 { - set is_inlined := true; - using ((update Bar filter .a <= y set { a := x }).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 3)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(0, x))', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await 
self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(0, x))', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, 0))', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, 3))', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1} union (select foo(x - 1, x))', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x - 1, x))', - [1, 1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 2], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, x))', - [0, 1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 2], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(0, 2) else 99', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(0, 2) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(0, 2)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(0, 2)', - [0, 0], - sort=True, - ) - await 
self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0) ?? 99', - [99], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2) ?? 99', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(0, 2)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_update_iterator_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64, y: int64) -> set of int64 { - set is_inlined := true; - using ( - for z in {0, 1} union ( - (update Bar filter .a <= y + z set { a := x + z }).a - ) - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0)', - [1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2)', - [0, 0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 1], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 3)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(0, x))', - [1, 
1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(0, x))', - [0, 1, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 1], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, 0))', - [1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, 3))', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1} union (select foo(x - 1, x))', - [0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x - 1, x))', - [1, 1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 2], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x - 1, x))', - [0, 1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 2], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(0, 1) else 99', - [0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(0, 1) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(0, 1)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) 
- await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(0, 1)', - [0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0, -1) ?? 99', - [99], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1) ?? 99', - [0, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(0, 1)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_update_iterator_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - x: int64, y: int64, z: bool - ) -> set of int64 { - set is_inlined := true; - using ( - if z - then (update Bar filter .a <= y set { a := x }).a - else {} - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0, 2, false)', - [], - ) - await self.assert_query_result( - 'select foo(0, 3, false)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2, true)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 3, true)', - [0, 0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 0], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union 
(select foo(0, x, false))', - [], - sort=True, - ) - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x - 1, x, false))', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(0, x, true))', - [0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x - 1, x, true))', - [1, 1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 2], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(0, 2, false) else 99', - [], - sort=True, - ) - await self.assert_query_result( - 'select if false then foo(0, 2, false) else 99', - [99], - ) - await self.assert_query_result( - 'select if true then 99 else foo(0, 2, false)', - [99], - ) - await self.assert_query_result( - 'select if false then 99 else foo(0, 2, false)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then foo(0, 2, true) else 99', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(0, 2, true) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(0, 2, true)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(0, 2, true)', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 
3], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0, 0, false) ?? 99', - [99], - sort=True, - ) - await self.assert_query_result( - 'select foo(0, 2, false) ?? 99', - [99], - sort=True, - ) - await self.assert_query_result( - 'select 99 ?? foo(0, 2, false)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 0, true) ?? 99', - [99], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 2, true) ?? 99', - [0, 0], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [0, 0, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(0, 2, true)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_update_link_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar; - }; - create function foo(n: int64, x: Bar) -> set of Baz { - set is_inlined := true; - using ((update Baz filter .b <= n set { bar := x })) - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{b := 4}; - insert Baz{b := 5}; - insert Baz{b := 6}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(' - ' 4,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b}', - [ - {'a': 1, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': None, 'b': 5}, - {'a': None, 'b': 6}, - ], - ) - - await reset_data() - await self.assert_query_result( 
- 'select foo(' - ' 5,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b}', - [ - {'a': 1, 'b': 4}, - {'a': 1, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 1, 'b': 5}, - {'a': None, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_update_link_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - create function foo(x: int64, y: int64) -> set of Baz { - set is_inlined := true; - using ( - (update Baz filter .b <= x set { - bar := (select Bar filter .a <= y), - }) - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{b := 4}; - insert Baz{b := 5}; - insert Baz{b := 6}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(4, 1){a := .bar.a, b}', - [ - {'a': [1], 'b': 4}, - ], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': [1], 'b': 4}, - {'a': [], 'b': 5}, - {'a': [], 'b': 6}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(5, 2){a := .bar.a, b}', - [ - {'a': [1, 2], 'b': 4}, - {'a': [1, 2], 'b': 5}, - ], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': [1, 2], 'b': 4}, - {'a': [1, 2], 'b': 5}, - {'a': [], 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_update_link_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create optional link bar -> Bar; - }; - create function foo(x: int64, y: int64) -> set of Baz { - set is_inlined := true; 
- using ( - (update Baz filter .b <= x set { - bar := (insert Bar{a := y}), - }) - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4}; - insert Baz{b := 5}; - insert Baz{b := 6}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(4, 1){a := .bar.a, b}', - [ - {'a': 1, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': None, 'b': 5}, - {'a': None, 'b': 6}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(5, 2){a := .bar.a, b}', - [ - {'a': 2, 'b': 4}, - {'a': 2, 'b': 5}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 2], - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': 2, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': None, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_update_link_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar; - }; - create function foo(n: int64, x: Bar) -> set of Baz { - set is_inlined := true; - using ((update Baz filter .b = n set { bar := x })) - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Bar{a := 4}; - insert Baz{b := 10}; - insert Baz{b := 20}; - insert Baz{b := 30}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1))' - '){a := .bar.a, b}', - [ - {'a': 1, 'b': 10}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 
10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select (' - ' for x in {1, 2} union(' - ' select foo(' - ' x * 10,' - ' assert_exists((select Bar filter .a = x limit 1))' - ' ).b' - ' )' - ')', - [10, 20], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 10}, - {'a': 2, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select (' - ' if true' - ' then foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1)),' - ' ).b' - ' else 99' - ')', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select (' - ' if false' - ' then foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1)),' - ' ).b' - ' else 99' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select (' - ' if true' - ' then 99' - ' else foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1)),' - ' ).b' - ')', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select (' - ' if false' - ' then 99' - ' else foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1)),' - ' ).b' - ')', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(' - ' 10,' - ' 
assert_exists((select Bar filter .a = 1 limit 1)),' - ').b ?? 99', - [10], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(' - ' 10,' - ' assert_exists((select Bar filter .a = 1 limit 1)),' - ').b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 10}, - {'a': None, 'b': 20}, - {'a': None, 'b': 30}, - ], - ) - - async def test_edgeql_functions_inline_update_link_iterator_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - create function foo(x: int64, y: int64) -> set of Baz { - set is_inlined := true; - using (( - update Baz filter .b = x set { - bar := (for z in {y, y + 1, y + 2} union ( - insert Bar{a := z} - ) - ) - } - )) - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 1}; - insert Baz{b := 2}; - insert Baz{b := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1, 10){a := .bar.a, b}', - [ - {'a': [10, 11, 12], 'b': 1}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1, 2} union (select foo(x, x * 10){a := .bar.a, b})', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [20, 21, 22], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(1, 10).b else 
99', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(1, 10).b else 99', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(1, 10).b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(1, 10).b', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(1, 10).b ?? 99', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [10, 11, 12], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? 
foo(1, 10).b', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': [], 'b': 1}, - {'a': [], 'b': 2}, - {'a': [], 'b': 3}, - ], - ) - - async def test_edgeql_functions_inline_update_link_iterator_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar; - }; - create function foo(x: int64, y: int64, flag: bool) -> set of Baz { - set is_inlined := true; - using (( - update Baz filter .b = x set { - bar := ( - if flag - then (insert Bar{a := y}) - else {} - ) - } - )) - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 1}; - insert Baz{b := 2}; - insert Baz{b := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1, 10, false){a := .bar.a, b}', - [ - {'a': None, 'b': 1}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(1, 10, true){a := .bar.a, b}', - [ - {'a': 10, 'b': 1}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 10, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'for x in {1, 2} union (' - ' select foo(x, x * 10, false){a := .bar.a, b}' - ')', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2} union (' - ' select foo(x, 
x * 10, true){a := .bar.a, b}' - ')', - [ - {'a': 10, 'b': 1}, - {'a': 20, 'b': 2}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 10, 'b': 1}, - {'a': 20, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(1, 10, false).bar.a else 99', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(1, 10, false).bar.a else 99', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(1, 10, false).bar.a', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(1, 10, false).bar.a', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if true then foo(1, 10, true).bar.a else 99', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 10, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(1, 10, true).bar.a else 99', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await 
self.assert_query_result( - 'select if true then 99 else foo(1, 10, true).bar.a', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(1, 10, true).bar.a', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 10, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(1, 10, false).bar.a ?? 99', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(1, 10, false).bar.a', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(1, 10, true).bar.a ?? 99', - [10], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 10, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? 
foo(1, 10, true).bar.a', - [99], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 1}, - {'a': None, 'b': 2}, - {'a': None, 'b': 3}, - ], - ) - - async def test_edgeql_functions_inline_update_linkprop_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required link bar -> Bar { - create property b -> int64; - } - }; - create function foo(x: int64, y: int64) -> set of Baz { - set is_inlined := true; - using (( - update Baz filter .bar.a <= x set { - bar := .bar { @b := y } - } - )) - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{bar := (insert Bar{a := 1})}; - insert Baz{bar := (insert Bar{a := 2})}; - insert Baz{bar := (insert Bar{a := 3})}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(2, 4){a := .bar.a, b := .bar@b}', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 4}, - ], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b := .bar@b} order by .a', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 4}, - {'a': 3, 'b': None}, - ], - ) - - async def test_edgeql_functions_inline_update_nested_01(self): - # Simple inner modifying function - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function inner(x: int64) -> set of Bar { - set is_inlined := true; - using ((update Bar set { a := x })); - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using (inner(x)); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1).a', - [1, 1, 1], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 1, 1], - sort=True, - ) - - async def test_edgeql_functions_inline_update_nested_02(self): 
- # Putting the result of an inner modifying function into shape - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create multi link bar -> Bar; - }; - create function inner1(y: int64) -> set of Bar { - set is_inlined := true; - using ((update Bar filter .a <= y set { a := .a - 1 })); - }; - create function inner2(x: int64, y: int64) -> set of Baz { - set is_inlined := true; - using ( - (update Baz filter .b <= x set { - bar := assert_distinct(inner1(y)), - }) - ); - }; - create function foo(x: int64, y: int64) -> set of Baz { - set is_inlined := true; - using (inner2(x, y)); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - insert Baz{b := 4}; - insert Baz{b := 5}; - insert Baz{b := 6}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(4, 1){a := .bar.a, b}', - [ - {'a': [0], 'b': 4}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': [0], 'b': 4}, - {'a': [], 'b': 5}, - {'a': [], 'b': 6}, - ], - ) - - # Inner update will return an empty set for all subsequent calls. 
- await reset_data() - await self.assert_query_result( - 'select foo(5, 2){a := .bar.a, b}', - [ - {'a': [0, 1], 'b': 4}, - {'a': [], 'b': 5}, - ], - ) - await self.assert_query_result( - 'select Bar.a', - [0, 1, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz {' - ' a := (select .bar order by .a).a,' - ' b,' - '} order by .b', - [ - {'a': [0, 1], 'b': 4}, - {'a': [], 'b': 5}, - {'a': [], 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_delete_basic_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using ((delete Bar filter .a <= x)); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2).a', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - async def test_edgeql_functions_inline_delete_basic_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ((delete Bar filter .a <= x).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await 
self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - async def test_edgeql_functions_inline_delete_basic_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(named only m: int64) -> set of int64 { - set is_inlined := true; - using ((delete Bar filter .a <= m).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(m := 0)', - [], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(m := 3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - async def test_edgeql_functions_inline_delete_basic_04(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: optional int64) -> set of int64 { - set is_inlined := true; - using ((delete Bar filter .a <= x ?? 
9).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo({})', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - async def test_edgeql_functions_inline_delete_basic_05(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - variadic x: int64, - ) -> set of int64 { - set is_inlined := true; - using ( - ( - delete Bar - filter .a <= sum(array_unpack(x)) - ).a - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, 1, 2)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - async def 
test_edgeql_functions_inline_delete_iterator_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ((delete Bar filter .a <= x).a); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(x))', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(2) else 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(2) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(2)', - [99], - ) - await self.assert_query_result( - 'select 
Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0) ?? 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2) ?? 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? foo(2)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_delete_iterator_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - for z in {0, 1} union ( - (delete Bar filter .a <= x).a - ) - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(x))', - [1], - sort=True, - 
) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - ) - await reset_data() - await self.assert_query_result( - 'for x in {1, 2, 3} union (select foo(x))', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(2) else 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(2) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(2)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0) ?? 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2) ?? 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? 
foo(2)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_delete_iterator_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function foo( - x: int64, y: bool - ) -> set of int64 { - set is_inlined := true; - using ( - if y - then (delete Bar filter .a <= x).a - else {} - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(2, false)', - [], - ) - await self.assert_query_result( - 'select foo(3, false)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2, true)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3, true)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(x, false))', - [], - ) - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x, false))', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {0, 1} union (select foo(x, true))', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'for x in {2, 3} union (select foo(x, true))', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - - await reset_data() - await self.assert_query_result( - 'select if true then foo(2, false) else 99', - [], - ) - await self.assert_query_result( 
- 'select if false then foo(2, false) else 99', - [99], - ) - await self.assert_query_result( - 'select if true then 99 else foo(2, false)', - [99], - ) - await self.assert_query_result( - 'select if false then 99 else foo(2, false)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then foo(2, true) else 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await reset_data() - await self.assert_query_result( - 'select if false then foo(2, true) else 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if true then 99 else foo(2, true)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select if false then 99 else foo(2, true)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - sort=True, - ) - - await reset_data() - await self.assert_query_result( - 'select foo(0, false) ?? 99', - [99], - ) - await self.assert_query_result( - 'select foo(2, false) ?? 99', - [99], - ) - await self.assert_query_result( - 'select 99 ?? foo(2, false)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(0, true) ?? 99', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2, true) ?? 99', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select 99 ?? 
foo(2, true)', - [99], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - - async def test_edgeql_functions_inline_delete_policy_target_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar { - on target delete allow; - }; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - (delete Bar filter .a <= x).a - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4, bar := (insert Bar{a := 1})}; - insert Baz{b := 5, bar := (insert Bar{a := 2})}; - insert Baz{b := 6, bar := (insert Bar{a := 3})}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 'b': 4}, - {'a': None, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': None, 
'b': 4}, - {'a': None, 'b': 5}, - {'a': None, 'b': 6}, - ], - ) - - async def test_edgeql_functions_inline_delete_policy_target_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar { - on target delete delete source; - }; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - (delete Bar filter .a <= x).a - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4, bar := (insert Bar{a := 1})}; - insert Baz{b := 5, bar := (insert Bar{a := 2})}; - insert Baz{b := 6, bar := (insert Bar{a := 3})}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b}', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(1)', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b}', - [ - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(2)', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b}', - [ - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(3)', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b}', - [], - ) - - async def test_edgeql_functions_inline_delete_policy_source_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - 
}; - create type Baz { - create required property b -> int64; - create link bar -> Bar { - on source delete allow; - }; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - (delete Baz filter .b <= x).b - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4, bar := (insert Bar{a := 1})}; - insert Baz{b := 5, bar := (insert Bar{a := 2})}; - insert Baz{b := 6, bar := (insert Bar{a := 3})}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(4)', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(5)', - [4, 5], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(6)', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [], - ) - - async def test_edgeql_functions_inline_delete_policy_source_02(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar { - on source delete delete target; - }; - }; - create 
function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - (delete Baz filter .b <= x).b - ); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4, bar := (insert Bar{a := 1})}; - insert Baz{b := 5, bar := (insert Bar{a := 2})}; - insert Baz{b := 6, bar := (insert Bar{a := 3})}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(4)', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(5)', - [4, 5], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 3, 'b': 6}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(6)', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [], - ) - - async def test_edgeql_functions_inline_delete_policy_source_03(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create type Baz { - create required property b -> int64; - create link bar -> Bar { - on source delete delete target if orphan; - }; - }; - create function foo(x: int64) -> set of int64 { - set is_inlined := true; - using ( - (delete Baz filter .b <= x).b - ); - }; - ''') - - async def reset_data(): - await 
self.con.execute(''' - delete Baz; - delete Bar; - insert Baz{b := 4, bar := (insert Bar{a := 1})}; - insert Baz{b := 5, bar := (insert Bar{a := 2})}; - insert Baz{b := 6, bar := (insert Bar{a := 3})}; - insert Baz{ - b := 7, - bar := assert_exists((select Bar filter .a = 1 limit 1)), - }; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(0)', - [], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 4}, - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - {'a': 1, 'b': 7}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(4)', - [4], - ) - await self.assert_query_result( - 'select Bar.a', - [1, 2, 3], - sort=True, - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 2, 'b': 5}, - {'a': 3, 'b': 6}, - {'a': 1, 'b': 7}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(5)', - [4, 5], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1, 3], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 3, 'b': 6}, - {'a': 1, 'b': 7}, - ], - ) - await reset_data() - await self.assert_query_result( - 'select foo(6)', - [4, 5, 6], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [1], - ) - await self.assert_query_result( - 'select Baz{a := .bar.a, b} order by .b', - [ - {'a': 1, 'b': 7}, - ], - ) - - async def test_edgeql_functions_inline_delete_nested_01(self): - await self.con.execute(''' - create type Bar { - create required property a -> int64; - }; - create function inner(x: int64) -> set of Bar { - set is_inlined := true; - using ((delete Bar filter .a <= x)); - }; - create function foo(x: int64) -> set of Bar { - set is_inlined := true; - using (inner(x)); - }; - ''') - - async def reset_data(): - await self.con.execute(''' - delete Bar; - insert 
Bar{a := 1}; - insert Bar{a := 2}; - insert Bar{a := 3}; - ''') - - await reset_data() - await self.assert_query_result( - 'select foo(1).a', - [1], - ) - await self.assert_query_result( - 'select Bar.a', - [2, 3], - sort=True, - ) - await reset_data() - await self.assert_query_result( - 'select foo(2).a', - [1, 2], - sort=True, - ) - await self.assert_query_result( - 'select Bar.a', - [3], - ) diff --git a/tests/test_edgeql_functions_inline.py b/tests/test_edgeql_functions_inline.py new file mode 100644 index 00000000000..14f0bf13241 --- /dev/null +++ b/tests/test_edgeql_functions_inline.py @@ -0,0 +1,10986 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2017-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +import edgedb + +from edb.testbase import server as tb + + +class TestEdgeQLFunctionsInline(tb.QueryTestCase): + NO_FACTOR = True + + async def test_edgeql_functions_inline_basic_01(self): + await self.con.execute(''' + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_02(self): + await self.con.execute(''' + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (x * x + 2 * x + 1); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [4], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [4, 9, 16], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [4, 9, 16], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_03(self): + await self.con.execute(''' + create function foo(x: int64, y: int64) -> int64 { + set is_inlined := true; + using (x + y); + }; + ''') + await self.assert_query_result( + 'select foo({}, {})', + [], + ) + await self.assert_query_result( + 'select foo(1, {})', + [], + ) + await self.assert_query_result( + 'select foo({}, 1)', + [], + ) + await self.assert_query_result( + 'select foo(1, 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await 
self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_04(self): + await self.con.execute(''' + create function foo(x: int64 = 9) -> int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo()', + [9], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_05(self): + await self.con.execute(''' + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_06(self): + await self.con.execute(''' + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_07(self): + await self.con.execute(''' + create function foo(x: int64, y: int64 = 90) -> int64 { + set is_inlined := true; + using (x + y); + }; + 
''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [91], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'select foo({}, {})', + [], + ) + await self.assert_query_result( + 'select foo(1, {})', + [], + ) + await self.assert_query_result( + 'select foo({}, 1)', + [], + ) + await self.assert_query_result( + 'select foo(1, 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_08(self): + await self.con.execute(''' + create function foo(x: int64 = 9, y: int64 = 90) -> int64 { + set is_inlined := true; + using (x + y); + }; + ''') + await self.assert_query_result( + 'select foo()', + [99], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [91], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'select foo({}, {})', + [], + ) + await self.assert_query_result( + 'select foo(1, {})', + [], + ) + await self.assert_query_result( + 'select foo({}, 1)', + [], + ) + await self.assert_query_result( + 'select foo(1, 
10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_09(self): + await self.con.execute(''' + create function foo(variadic x: int64) -> int64 { + set is_inlined := true; + using (sum(array_unpack(x))); + }; + ''') + await self.assert_query_result( + 'select foo()', + [0], + ) + await self.assert_query_result( + 'select foo(1,{})', + [], + ) + await self.assert_query_result( + 'select foo({},1)', + [], + ) + await self.assert_query_result( + 'select foo(1, 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, {10, 20, 30}, 100)', + [111, 112, 113, 121, 122, 123, 131, 132, 133], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, y, 100)' + ' )' + ')', + [111, 112, 113, 121, 122, 123, 131, 132, 133], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_10(self): + await self.con.execute(''' + create function foo(named only a: int64) -> int64 { + set is_inlined := true; + using (a); + }; + ''') + await 
self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 1)', + [1], + ) + await self.assert_query_result( + 'select foo(a := {1,2,3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(a := x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_11(self): + await self.con.execute(''' + create function foo(x: int64, named only a: int64) -> int64 { + set is_inlined := true; + using (x + a); + }; + ''') + await self.assert_query_result( + 'select foo({}, a := {})', + [], + ) + await self.assert_query_result( + 'select foo(1, a := {})', + [], + ) + await self.assert_query_result( + 'select foo({}, a := 10)', + [], + ) + await self.assert_query_result( + 'select foo(1, a := 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, a := {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x, a := 10))', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, a := y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, a := y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_12(self): + await self.con.execute(''' + create function foo( + x: int64 = 9, + named only a: int64 + ) -> int64 { + set is_inlined := true; + using (x + a); + }; + ''') + await self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 10)', + [19], + ) + await 
self.assert_query_result( + 'select foo(a := {10, 20, 30})', + [19, 29, 39], + sort=True, + ) + await self.assert_query_result( + 'select foo({}, a := {})', + [], + ) + await self.assert_query_result( + 'select foo(1, a := {})', + [], + ) + await self.assert_query_result( + 'select foo({}, a := 10)', + [], + ) + await self.assert_query_result( + 'select foo(1, a := 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, a := {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x, a := 10))', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(a := y))', + [19, 29, 39], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, a := y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, a := y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_13(self): + await self.con.execute(''' + create function foo( + x: int64, + named only a: int64 = 90 + ) -> int64 { + set is_inlined := true; + using (x + a); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [91], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'select foo({}, a := {})', + [], + ) + await self.assert_query_result( + 'select foo(1, a := {})', + [], + ) + await self.assert_query_result( + 'select foo({}, a := 10)', + [], + ) + await self.assert_query_result( + 
'select foo(1, a := 10)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, a := {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x, a := 10))', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, a := y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, a := y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_14(self): + await self.con.execute(''' + create function foo( + x: int64 = 9, + named only a: int64 = 90 + ) -> int64 { + set is_inlined := true; + using (x + a); + }; + ''') + await self.assert_query_result( + 'select foo()', + [99], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [91], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 10)', + [19], + ) + await self.assert_query_result( + 'select foo(a := {10, 20, 30})', + [19, 29, 39], + sort=True, + ) + await self.assert_query_result( + 'select foo({}, a := {})', + [], + ) + await self.assert_query_result( + 'select foo(1, a := {})', + [], + ) + await self.assert_query_result( + 'select foo({}, a := 10)', + [], + ) + await self.assert_query_result( + 'select foo(1, a := 10)', + [11], + ) + await 
self.assert_query_result( + 'select foo({1, 2, 3}, a := 10)', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, a := {10, 20, 30})', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, a := {10, 20, 30})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [91, 92, 93], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x, a := 10))', + [11, 12, 13], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(a := y))', + [19, 29, 39], + sort=True, + ) + await self.assert_query_result( + 'for y in {10, 20, 30} union (select foo(1, a := y))', + [11, 21, 31], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, a := y)' + ' )' + ')', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_15(self): + await self.con.execute(''' + create function foo( + x: int64, + y: int64 = 90, + variadic z: int64, + named only a: int64, + named only b: int64 = 90000 + ) -> int64 { + set is_inlined := true; + using (x + y + sum(array_unpack(z)) + a + b); + }; + ''') + await self.assert_query_result( + 'select foo(1, a := 1000)', + [91091], + ) + await self.assert_query_result( + 'select foo(1, 10, a := 1000)', + [91011], + ) + await self.assert_query_result( + 'select foo(1, a := 1000, b := 10000)', + [11091], + ) + await self.assert_query_result( + 'select foo(1, 10, a := 1000, b := 10000)', + [11011], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, a := 1000)', + [91111], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, a := 1000, b := 10000)', + [11111], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, 200, a := 1000)', + [91311], + ) + await self.assert_query_result( + 
'select foo(1, 10, 100, 200, a := 1000, b := 10000)', + [11311], + ) + + async def test_edgeql_functions_inline_basic_16(self): + await self.con.execute(''' + create function foo(x: optional int64) -> optional int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_17(self): + await self.con.execute(''' + create function foo( + x: optional int64 + ) -> int64 { + set is_inlined := true; + using (x ?? 5); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [5], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_18(self): + await self.con.execute(''' + create function foo( + x: optional int64 = 9 + ) -> int64 { + set is_inlined := true; + using (x ?? 
5); + }; + ''') + await self.assert_query_result( + 'select foo()', + [9], + ) + await self.assert_query_result( + 'select foo({})', + [5], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_basic_19(self): + await self.con.execute(''' + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using (for y in {x, x + 1, x + 2} union (y)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1, 2, 3], + ) + await self.assert_query_result( + 'select foo({11, 21, 31})', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'for x in {11, 21, 31} union (select foo(x))', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + async def test_edgeql_functions_inline_array_01(self): + await self.con.execute(''' + create function foo(x: int64) -> array { + set is_inlined := true; + using ([x]); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [[1]], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [[1], [2], [3]], + sort=True, + ) + + async def test_edgeql_functions_inline_array_02(self): + await self.con.execute(''' + create function foo(x: array) -> array { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_array_03(self): + await self.con.execute(''' + create function foo( + x: array = [9] + ) -> array { + 
set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo()', + [[9]], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_array_04(self): + await self.con.execute(''' + create function foo(x: array) -> int64 { + set is_inlined := true; + using (sum(array_unpack(x))); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_array_05(self): + await self.con.execute(''' + create function foo(x: array) -> set of int64 { + set is_inlined := true; + using (array_unpack(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_tuple_01(self): + await self.con.execute(''' + create function foo(x: int64) -> tuple { + set is_inlined := true; + using ((x,)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1,)], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1,), (2,), (3,)], + sort=True, + ) + + async def test_edgeql_functions_inline_tuple_02(self): + await self.con.execute(''' + create function foo( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), 
(3,)})', + [(1,), (2,), (3,)], + sort=True, + ) + + async def test_edgeql_functions_inline_tuple_03(self): + await self.con.execute(''' + create function foo( + x: tuple = (9,) + ) -> tuple { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo()', + [(9,)], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + ) + + async def test_edgeql_functions_inline_tuple_04(self): + await self.con.execute(''' + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (x.0); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [1], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_tuple_05(self): + await self.con.execute(''' + create function foo(x: int64) -> tuple { + set is_inlined := true; + using ((a:=x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [{'a': 1}], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [{'a': 1}, {'a': 2}, {'a': 3}], + ) + + async def test_edgeql_functions_inline_tuple_06(self): + await self.con.execute(''' + create function foo( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [{'a': 1}], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [{'a': 1}, {'a': 2}, {'a': 3}], + ) + + async def test_edgeql_functions_inline_tuple_07(self): + await self.con.execute(''' + create function foo( + x: tuple = (a:=9) + ) -> tuple { + set is_inlined := true; + using 
(x); + }; + ''') + await self.assert_query_result( + 'select foo()', + [{'a': 9}], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [{'a': 1}], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [{'a': 1}, {'a': 2}, {'a': 3}], + ) + + async def test_edgeql_functions_inline_tuple_08(self): + await self.con.execute(''' + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (x.a); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [1], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [1, 2, 3], + ) + + async def test_edgeql_functions_inline_object_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: int64) -> optional Bar { + set is_inlined := true; + using ((select Bar{a} filter .a = x limit 1)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(-1).a', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + + async def 
test_edgeql_functions_inline_object_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: optional Bar) -> optional Bar { + set is_inlined := true; + using (x ?? (select Bar filter .a = 1 limit 1)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> int64 { + set is_inlined := true; + using (x.a); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> set of Bar { + set is_inlined := true; + using ((select Bar{a} filter .a <= x.a)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 1, 1, 2, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_06(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create 
function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ((select Bar{a} filter .a <= x).a); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1,2,3})', + [1, 1, 1, 2, 2, 3], + sort=True, + ) + + @tb.needs_factoring + async def test_edgeql_functions_inline_object_07(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo() -> int64 { + set is_inlined := true; + using (count(Bar)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [3], + ) + await self.assert_query_result( + 'select (foo(), foo())', + [[3, 3]], + sort=True, + ) + await self.assert_query_result( + 'select (Bar.a, foo())', + [[1, 3], [2, 3], [3, 3]], + sort=True, + ) + await self.assert_query_result( + 'select (foo(), Bar.a)', + [[3, 1], [3, 2], [3, 3]], + sort=True, + ) + await self.assert_query_result( + 'select (Bar.a, foo(), Bar.a, foo())', + [[1, 3, 1, 3], [2, 3, 2, 3], [3, 3, 3, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_object_08(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo() -> set of tuple { + set is_inlined := true; + using ((Bar.a, count(Bar))); + }; + ''') + await self.assert_query_result( + 'select foo()', + [[1, 1], [2, 1], [3, 1]], + ) + await self.assert_query_result( + 'select (foo(), foo())', + [ + [[1, 1], [1, 1]], [[1, 1], [2, 1]], [[1, 1], [3, 1]], + [[2, 1], [1, 1]], [[2, 1], [2, 1]], [[2, 1], [3, 1]], + [[3, 1], [1, 1]], [[3, 1], [2, 1]], [[3, 1], [3, 1]], + ], + sort=True, + ) + await self.assert_query_result( + 'select (Bar.a, foo())', + [ + [1, [1, 1]], [1, [2, 1]], [1, [3, 1]], + [2, [1, 1]], [2, [2, 1]], [2, [3, 
1]], + [3, [1, 1]], [3, [2, 1]], [3, [3, 1]], + ], + sort=True, + ) + await self.assert_query_result( + 'select (foo(), Bar.a)', + [ + [[1, 1], 1], [[1, 1], 2], [[1, 1], 3], + [[2, 1], 1], [[2, 1], 2], [[2, 1], 3], + [[3, 1], 1], [[3, 1], 2], [[3, 1], 3], + ], + sort=True, + ) + + @tb.needs_factoring + async def test_edgeql_functions_inline_object_09(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> tuple { + set is_inlined := true; + using ((x.a, count(Bar))); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select (Bar.a, foo((select Bar filter .a = 1)))', + [[1, [1, 3]]], + ) + await self.assert_query_result( + 'select (Bar.a, foo((select detached Bar filter .a = 1)))', + [[1, [1, 3]], [2, [1, 3]], [3, [1, 3]]], + sort=True, + ) + await self.assert_query_result( + 'select (Bar.a, foo(Bar))', + [[1, [1, 3]], [2, [2, 3]], [3, [3, 3]]], + sort=True, + ) + await self.assert_query_result( + 'select (foo(Bar), foo(Bar))', + [[[1, 3], [1, 3]], [[2, 3], [2, 3]], [[3, 3], [3, 3]]], + sort=True, + ) + await self.assert_query_result( + 'select (foo(Bar), foo(detached Bar))', + [ + [[1, 3], [1, 3]], [[1, 3], [2, 3]], [[1, 3], [3, 3]], + [[2, 3], [1, 3]], [[2, 3], [2, 3]], [[2, 3], [3, 3]], + [[3, 3], [1, 3]], [[3, 3], [2, 3]], [[3, 3], [3, 3]], + ], + sort=True, + ) + + async def test_edgeql_functions_inline_object_10(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: Bar) -> set of Baz { + set is_inlined := true; + using ((select Baz 
filter .b <= x.a)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [4, 4, 4, 5, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_object_11(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: Bar | Baz) -> Bar | Baz { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Baz)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select {Bar, Baz})).a', + [1, 2, 3, 4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_object_12(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: int64) -> optional Bar | Baz { + set is_inlined := true; + using 
((select {Bar, Baz} filter .a = x limit 1)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 4}).a', + [1, 4], + sort=True, + ) + await self.assert_query_result( + 'select foo({0, 1, 2, 3, 4, 5, 6, 7, 8}).a', + [1, 2, 3, 4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_object_13(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: Bar | Baz) -> optional Bar { + set is_inlined := true; + using (x[is Bar]); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4)).a', + [], + ) + await self.assert_query_result( + 'select foo((select Baz)).a', + [], + ) + await self.assert_query_result( + 'select foo((select {Bar, Baz})).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_14(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert 
Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: Bar | Baz) -> optional int64 { + set is_inlined := true; + using ( + x[is Baz].b + ) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4))', + [1], + ) + await self.assert_query_result( + 'select foo((select Baz))', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo((select {Bar, Baz}))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_object_15(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: Bar | Baz) -> optional int64 { + set is_inlined := true; + using ( + if x is Bar + then x.a*2 + else 10 + assert_exists(x[is Baz]).b + ) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [2], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4))', + [11], + ) + await self.assert_query_result( + 'select foo((select Baz))', + [11, 12, 13], + sort=True, + ) + await 
self.assert_query_result( + 'select foo((select {Bar, Baz}))', + [2, 4, 6, 11, 12, 13], + sort=True, + ) + + async def test_edgeql_functions_inline_object_16(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Bar2 extending Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar2{a := 4}; + insert Bar2{a := 5}; + insert Bar2{a := 6}; + create function foo(x: Bar) -> optional Bar2 { + set is_inlined := true; + using (x[is Bar2]); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar2 filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Bar2)).a', + [4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_object_17(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + b := 4, + bar := assert_exists((select Bar filter .a = 1 limit 1)), + }; + insert Baz{ + b := 5, + bar := assert_exists((select Bar filter .a = 2 limit 1)), + }; + insert Baz{ + b := 6, + bar := assert_exists((select Bar filter .a = 3 limit 1)), + }; + create function foo(x: Baz) -> Bar { + set is_inlined := true; + using (x.bar); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Baz filter .b = 4)).a', + [1], + ) + await 
self.assert_query_result( + 'select foo((select Baz)).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_shape_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select Bar{' + ' a,' + ' b := foo(.a)' + '} order by .a', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 3, 'b': 3}, + ], + ) + + async def test_edgeql_functions_inline_shape_02(self): + await self.con.execute(''' + create type Bar { + create property a -> int64; + }; + insert Bar{}; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: optional int64) -> optional int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select Bar{' + ' a,' + ' b := foo(.a)' + '} order by .a', + [ + {'a': None, 'b': None}, + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 3, 'b': 3}, + ], + ) + + async def test_edgeql_functions_inline_shape_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: optional int64) -> set of int64 { + set is_inlined := true; + using ({10 + x, 20 + x, 30 + x}); + }; + ''') + await self.assert_query_result( + 'select Bar{' + ' a,' + ' b := foo(.a)' + '} order by .a', + [ + {'a': 1, 'b': [11, 21, 31]}, + {'a': 2, 'b': [12, 22, 32]}, + {'a': 3, 'b': [13, 23, 33]}, + ], + ) + + async def test_edgeql_functions_inline_shape_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo() -> int64 { + set is_inlined := true; + using (count(Bar)); + }; + ''') + await self.assert_query_result( + 
'select foo()', + [3], + ) + await self.assert_query_result( + 'select Bar {' + ' a,' + ' n := foo(),' + '} order by .a', + [{'a': 1, 'n': 3}, {'a': 2, 'n': 3}, {'a': 3, 'n': 3}], + ) + + async def test_edgeql_functions_inline_shape_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo() -> set of tuple { + set is_inlined := true; + using ((Bar.a, count(Bar))); + }; + ''') + await self.assert_query_result( + 'select foo()', + [[1, 1], [2, 1], [3, 1]], + ) + await self.assert_query_result( + 'select Bar {' + ' a,' + ' n := foo(),' + '} order by .a', + [ + {'a': 1, 'n': [[1, 1], [2, 1], [3, 1]]}, + {'a': 2, 'n': [[1, 1], [2, 1], [3, 1]]}, + {'a': 3, 'n': [[1, 1], [2, 1], [3, 1]]}, + ], + ) + + async def test_edgeql_functions_inline_shape_06(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> tuple { + set is_inlined := true; + using ((x.a, count(Bar))); + }; + ''') + await self.assert_query_result( + 'select Bar {' + ' a,' + ' n := foo(Bar),' + '} order by .a', + [ + {'a': 1, 'n': [1, 3]}, + {'a': 2, 'n': [2, 3]}, + {'a': 3, 'n': [3, 3]}, + ], + ) + + async def test_edgeql_functions_inline_shape_07(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using (assert_exists((select Bar filter .a = x limit 1))); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a,' + ' c := foo(.b).a,' + '} order by .a', + [ 
+ {'a': 4, 'c': 1}, + {'a': 5, 'c': 2}, + {'a': 6, 'c': 3}, + ], + ) + + async def test_edgeql_functions_inline_shape_08(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + insert Baz{a := 7, b := 4}; + create function foo(x: int64) -> optional Bar { + set is_inlined := true; + using ((select Bar filter .a = x limit 1)); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a,' + ' c := foo(.b).a,' + '} order by .a', + [ + {'a': 4, 'c': 1}, + {'a': 5, 'c': 2}, + {'a': 6, 'c': 3}, + {'a': 7, 'c': None}, + ], + ) + + async def test_edgeql_functions_inline_shape_09(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((select Bar filter .a <= x)); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a,' + ' c := foo(.b).a,' + '} order by .a', + [ + {'a': 4, 'c': [1]}, + {'a': 5, 'c': [1, 2]}, + {'a': 6, 'c': [1, 2, 3]}, + ], + ) + + async def test_edgeql_functions_inline_shape_10(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + b := 4, + bar := assert_exists((select Bar filter .a = 1 limit 1)), + }; + insert Baz{ + b := 5, + bar := 
assert_exists((select Bar filter .a = 2 limit 1)), + }; + insert Baz{ + b := 6, + bar := assert_exists((select Bar filter .a = 3 limit 1)), + }; + create function foo(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a := foo(.bar).a,' + ' b,' + '} order by .a', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_shape_11(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + b := 4, + bar := assert_exists((select Bar filter .a = 1 limit 1)), + }; + insert Baz{ + b := 5, + bar := assert_exists((select Bar filter .a = 2 limit 1)), + }; + insert Baz{ + b := 6, + bar := assert_exists((select Bar filter .a = 3 limit 1)), + }; + create function foo(x: Bar) -> int64 { + set is_inlined := true; + using (x.a); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a := foo(.bar),' + ' b,' + '} order by .a', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_shape_12(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + b := 4, + bar := assert_exists((select Bar filter .a <= 1)), + }; + insert Baz{ + b := 5, + bar := assert_exists((select Bar filter .a <= 2)), + }; + insert Baz{ + b := 6, + bar := assert_exists((select Bar filter .a <= 3)), + }; + create function foo(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a := foo(.bar).a,' + ' b,' + '} order 
by .b', + [ + {'a': [1], 'b': 4}, + {'a': [1, 2], 'b': 5}, + {'a': [1, 2, 3], 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_shape_13(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + }; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + bar := assert_exists((select Bar filter .a = 1 limit 1)) { + @b := 4 + }, + }; + insert Baz{ + bar := assert_exists((select Bar filter .a = 2 limit 1)) { + @b := 5 + } + }; + insert Baz{ + bar := assert_exists((select Bar filter .a = 3 limit 1)) { + @b := 6 + } + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + ''') + await self.assert_query_result( + 'select Baz{' + ' a := .bar.a,' + ' b := foo(.bar@b),' + '} order by .a', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_global_01(self): + await self.con.execute(''' + create global a := 1; + create function foo() -> int64 { + set is_inlined := true; + using (global a); + }; + ''') + await self.assert_query_result( + 'select foo()', + [1], + ) + + async def test_edgeql_functions_inline_global_02(self): + await self.con.execute(''' + create global a -> int64; + create function foo() -> optional int64 { + set is_inlined := true; + using (global a); + }; + ''') + await self.assert_query_result( + 'select foo()', + [], + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo()', + [1], + ) + + async def test_edgeql_functions_inline_global_03(self): + await self.con.execute(''' + create global a := 1; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (global a + x); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await 
self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_global_04(self): + await self.con.execute(''' + create global a -> int64; + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using (global a + x) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_01(self): + # Directly passing parameter + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x) + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(x)) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_02(self): + # Indirectly passing parameter + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(x + 1)) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [4], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [4, 9, 16], + sort=True, + ) + await 
self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [4, 9, 16], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_03(self): + # Calling same inner function with different parameters + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: int64, y: int64) -> int64 { + set is_inlined := true; + using (inner(x) + inner(y)); + }; + ''') + await self.assert_query_result( + 'select foo({}, {})', + [], + ) + await self.assert_query_result( + 'select foo(1, {})', + [], + ) + await self.assert_query_result( + 'select foo({}, 1)', + [], + ) + await self.assert_query_result( + 'select foo(1, 10)', + [101], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, 10)', + [101, 104, 109], + sort=True, + ) + await self.assert_query_result( + 'select foo(1, {10, 20, 30})', + [101, 401, 901], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}, {10, 20, 30})', + [101, 104, 109, 401, 404, 409, 901, 904, 909], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' for y in {10, 20, 30} union (' + ' select foo(x, y)' + ' )' + ')', + [101, 104, 109, 401, 404, 409, 901, 904, 909], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_04(self): + # Directly passing parameter with default + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: int64 = 9) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [81], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 4, 
9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_05(self): + # Indirectly passing parameter with default + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: int64 = 9) -> int64 { + set is_inlined := true; + using (inner(x+1)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [100], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [4], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [4, 9, 16], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [4, 9, 16], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_06(self): + # Inner function with default parameter + await self.con.execute(''' + create function inner(x: int64 = 9) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo1() -> int64 { + set is_inlined := true; + using (inner()); + }; + create function foo2(x: int64) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [81], + ) + await self.assert_query_result( + 'select foo2({})', + [], + ) + await self.assert_query_result( + 'select foo2(1)', + [1], + ) + await self.assert_query_result( + 'select foo2({1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo2(x))', + [1, 4, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_07(self): + # Directly passing optional parameter + await self.con.execute(''' + create function inner(x: optional int64) -> optional int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: optional int64) -> int64 { + set is_inlined := true; + using (inner(x) ?? 
99); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [99], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 4, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_08(self): + # Indirectly passing optional parameter + await self.con.execute(''' + create function inner(x: optional int64) -> optional int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(x: optional int64) -> int64 { + set is_inlined := true; + using (inner(x+1) ?? 99); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [99], + ) + await self.assert_query_result( + 'select foo(1)', + [4], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [4, 9, 16], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [4, 9, 16], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_09(self): + # Inner function with optional parameter + await self.con.execute(''' + create function inner(x: optional int64) -> int64 { + set is_inlined := true; + using ((x * x) ?? 
99) + }; + create function foo1() -> int64 { + set is_inlined := true; + using (inner({})); + }; + create function foo2(x: int64) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [99], + ) + await self.assert_query_result( + 'select foo2({})', + [], + ) + await self.assert_query_result( + 'select foo2(1)', + [1], + ) + await self.assert_query_result( + 'select foo2({1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo2(x))', + [1, 4, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_10(self): + # Directly passing variadic parameter + await self.con.execute(''' + create function inner(x: array) -> int64 { + set is_inlined := true; + using (sum(array_unpack(x))) + }; + create function foo(variadic x: int64) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [0], + ) + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo(1, 2, 3)', + [6], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2}, {10, 20})', + [11, 12, 21, 22], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_11(self): + # Indirectly passing variadic parameter + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x) + }; + create function foo(variadic x: int64) -> int64 { + set is_inlined := true; + using (inner(sum(array_unpack(x)))); + }; + ''') + await self.assert_query_result( + 'select foo()', + [0], + ) + await self.assert_query_result( + 'select foo({})', 
+ [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo(1, 2, 3)', + [6], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2}, {10, 20})', + [11, 12, 21, 22], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_12(self): + # Inner function with variadic parameter + await self.con.execute(''' + create function inner(variadic x: int64) -> int64 { + set is_inlined := true; + using (sum(array_unpack(x))) + }; + create function foo1() -> int64 { + set is_inlined := true; + using (inner()); + }; + create function foo2(x: int64, y: int64, z: int64) -> int64 { + set is_inlined := true; + using (inner(x, y, z)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [0], + ) + await self.assert_query_result( + 'select foo2({}, {}, {})', + [], + ) + await self.assert_query_result( + 'select foo2(1, 2, 3)', + [6], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo2(x, x * 10, x * 100))', + [111, 222, 333], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_13(self): + # Directly passing named parameter + await self.con.execute(''' + create function inner(named only a: int64) -> int64 { + set is_inlined := true; + using (a * a) + }; + create function foo(named only a: int64) -> int64 { + set is_inlined := true; + using (inner(a := a)); + }; + ''') + await self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 1)', + [1], + ) + await self.assert_query_result( + 'select foo(a := {1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(a := x))', + [1, 4, 9], + sort=True, + ) + + async 
def test_edgeql_functions_inline_nested_basic_14(self): + # Indirectly passing named parameter + await self.con.execute(''' + create function inner(named only a: int64) -> int64 { + set is_inlined := true; + using (a * a) + }; + create function foo(named only a: int64) -> int64 { + set is_inlined := true; + using (inner(a := a + 1)); + }; + ''') + await self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 1)', + [4], + ) + await self.assert_query_result( + 'select foo(a := {1, 2, 3})', + [4, 9, 16], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(a := x))', + [4, 9, 16], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_15(self): + # Passing named parameter as positional + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x * x) + }; + create function foo(named only a: int64) -> int64 { + set is_inlined := true; + using (inner(a)); + }; + ''') + await self.assert_query_result( + 'select foo(a := {})', + [], + ) + await self.assert_query_result( + 'select foo(a := 1)', + [1], + ) + await self.assert_query_result( + 'select foo(a := {1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(a := x))', + [1, 4, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_16(self): + # Passing positional parameter as named + await self.con.execute(''' + create function inner(named only a: int64) -> int64 { + set is_inlined := true; + using (a * a) + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(a := x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [1, 4, 9], + sort=True, + ) + await self.assert_query_result( + 'for x in 
{1, 2, 3} union (select foo(x))', + [1, 4, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_17(self): + # Variety of paremeter types + await self.con.execute(''' + create function inner1(x: int64, y: int64) -> int64 { + set is_inlined := true; + using (x + y) + }; + create function inner2(x: array) -> int64 { + set is_inlined := true; + using (sum(array_unpack(x))) + }; + create function foo( + x: int64, + y: int64 = 90, + variadic z: int64, + named only a: int64, + named only b: int64 = 90000 + ) -> int64 { + set is_inlined := true; + using (inner1(x, a) + inner1(y, b) + inner2(z)); + }; + ''') + await self.assert_query_result( + 'select foo(1, a := 1000)', + [91091], + ) + await self.assert_query_result( + 'select foo(1, 10, a := 1000)', + [91011], + ) + await self.assert_query_result( + 'select foo(1, a := 1000, b := 10000)', + [11091], + ) + await self.assert_query_result( + 'select foo(1, 10, a := 1000, b := 10000)', + [11011], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, a := 1000)', + [91111], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, a := 1000, b := 10000)', + [11111], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, 200, a := 1000)', + [91311], + ) + await self.assert_query_result( + 'select foo(1, 10, 100, 200, a := 1000, b := 10000)', + [11311], + ) + + async def test_edgeql_functions_inline_nested_basic_18(self): + # For in inner function + await self.con.execute(''' + create function inner(x: int64) -> set of int64 { + set is_inlined := true; + using (for y in {x, x + 1, x + 2} union (y)) + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(10)', + [10, 11, 12], + ) + await self.assert_query_result( + 'select foo({10, 20, 30})', + [10, 11, 12, 20, 21, 22, 30, 31, 32], + sort=True, + ) + await 
self.assert_query_result( + 'for x in {10, 20, 30} union (select foo(x))', + [10, 11, 12, 20, 21, 22, 30, 31, 32], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_19(self): + # For in outer function + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x) + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using (for y in {x, x + 1, x + 2} union (inner(y))); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(10)', + [10, 11, 12], + ) + await self.assert_query_result( + 'select foo({10, 20, 30})', + [10, 11, 12, 20, 21, 22, 30, 31, 32], + sort=True, + ) + await self.assert_query_result( + 'for x in {10, 20, 30} union (select foo(x))', + [10, 11, 12, 20, 21, 22, 30, 31, 32], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_basic_20(self): + # Deeply nested + await self.con.execute(''' + create function inner1(x: int64) -> int64 { + set is_inlined := true; + using (x+1) + }; + create function inner2(x: int64) -> int64 { + set is_inlined := true; + using (inner1(x+2)) + }; + create function inner3(x: int64) -> int64 { + set is_inlined := true; + using (inner2(x+3)) + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner3(x+4)) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [11], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [11, 12, 13], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_01(self): + # Return array from inner function + await self.con.execute(''' + create function inner(x: int64) -> array { + set is_inlined := true; + using ([x]); + }; + create function foo(x: int64) -> array { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) 
+ await self.assert_query_result( + 'select foo(1)', + [[1]], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [[1], [2], [3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_02(self): + # Access array element in inner function + await self.con.execute(''' + create function inner(x: array) -> int64 { + set is_inlined := true; + using (x[0]); + }; + create function foo(x: array) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 2], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_03(self): + # Access array element in outer function + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: array) -> int64 { + set is_inlined := true; + using (inner(x[0])); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 2], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_04(self): + # Directly passing array parameter + await self.con.execute(''' + create function inner(x: array) -> array { + set is_inlined := true; + using (x); + }; + create function foo(x: array) -> array { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_05(self): + # Indirectly passing array parameter + await self.con.execute(''' + create function inner(x: array) -> 
array { + set is_inlined := true; + using (x); + }; + create function foo(x: array) -> array { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_06(self): + # Inner function with array parameter + await self.con.execute(''' + create function inner(x: array) -> array { + set is_inlined := true; + using (x); + }; + create function foo(x: int64) -> array { + set is_inlined := true; + using (inner([x])); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [[1]], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [[1], [2], [3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_07(self): + # Directly passing array parameter with default + await self.con.execute(''' + create function inner(x: array) -> array { + set is_inlined := true; + using (x); + }; + create function foo( + x: array = [9] + ) -> array { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [[9]], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_08(self): + # Directly passing array parameter with default + await self.con.execute(''' + create function inner(x: array) -> array { + set is_inlined := true; + using (x); + }; + create function foo( + x: array = [9] + ) -> array { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo()', 
+ [[9]], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_09(self): + # Inner function with array parameter with default + await self.con.execute(''' + create function inner(x: array = [9]) -> array { + set is_inlined := true; + using (x); + }; + create function foo1() -> array { + set is_inlined := true; + using (inner()); + }; + create function foo2( + x: array + ) -> array { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [[9]], + ) + await self.assert_query_result( + 'select foo2(>{})', + [], + ) + await self.assert_query_result( + 'select foo2([1])', + [[1]], + ) + await self.assert_query_result( + 'select foo2({[1], [2, 3]})', + [[1], [2, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_10(self): + # Unpack array in inner function + await self.con.execute(''' + create function inner(x: array) -> set of int64 { + set is_inlined := true; + using (array_unpack(x)); + }; + create function foo(x: array) -> set of int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_array_11(self): + # Unpack array in outer function + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: array) -> set of int64 { + set is_inlined := true; + using (inner(array_unpack(x))); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 
'select foo([1])', + [1], + ) + await self.assert_query_result( + 'select foo({[1], [2, 3]})', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_01(self): + # Return tuple from inner function + await self.con.execute(''' + create function inner(x: int64) -> tuple { + set is_inlined := true; + using ((x,)); + }; + create function foo(x: int64) -> tuple { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1,)], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1,), (2,), (3,)], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3}).0', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_02(self): + # Return named tuple from inner function + await self.con.execute(''' + create function inner(x: int64) -> tuple { + set is_inlined := true; + using ((a := x)); + }; + create function foo(x: int64) -> tuple { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [{'a': 1}], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}) order by .a', + [{'a': 1}, {'a': 2}, {'a': 3}], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_03(self): + # Accessing tuple element in inner function + await self.con.execute(''' + create function inner( + x: tuple + ) -> int64 { + set is_inlined := true; + using (x.0); + }; + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [1], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + 
[1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_04(self): + # Accessing tuple element in outer function + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (inner(x.0)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [1], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [1, 2, 3], + ) + + async def test_edgeql_functions_inline_nested_tuple_05(self): + # Accessing named tuple element in inner function + await self.con.execute(''' + create function inner( + x: tuple + ) -> int64 { + set is_inlined := true; + using (x.a); + }; + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((a := 1))', + [1], + ) + await self.assert_query_result( + 'select foo({(a := 1), (a := 2), (a := 3)})', + [1, 2, 3], + ) + + async def test_edgeql_functions_inline_nested_tuple_06(self): + # Accessing named tuple element in outer function + await self.con.execute(''' + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo( + x: tuple + ) -> int64 { + set is_inlined := true; + using (inner(x.a)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((a := 1))', + [1], + ) + await self.assert_query_result( + 'select foo({(a := 1), (a := 2), (a := 3)})', + [1, 2, 3], + ) + + async def test_edgeql_functions_inline_nested_tuple_07(self): + # Directly passing tuple parameter + await self.con.execute(''' + create function inner( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function 
foo( + x: tuple + ) -> tuple { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_08(self): + # Indirectly passing tuple parameter + await self.con.execute(''' + create function inner( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function foo( + x: tuple + ) -> tuple { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_09(self): + # Inner function with tuple parameter + await self.con.execute(''' + create function inner( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function foo( + x: int64 + ) -> tuple { + set is_inlined := true; + using (inner((x,))); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1,)], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1,), (2,), (3,)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_tuple_10(self): + # Directly passing a tuple parameter with default + await self.con.execute(''' + create function inner( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function foo( + x: tuple = (9,) + ) -> tuple { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo()', + [(9,)], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await 
self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + ) + + async def test_edgeql_functions_inline_nested_tuple_11(self): + # Indirectly passing tuple parameter with default + await self.con.execute(''' + create function inner( + x: tuple + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function foo( + x: tuple = (9,) + ) -> tuple { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo()', + [(9,)], + ) + await self.assert_query_result( + 'select foo(>{})', + [], + ) + await self.assert_query_result( + 'select foo((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + ) + + async def test_edgeql_functions_inline_nested_tuple_12(self): + # Inner function with tuple parameter with default + await self.con.execute(''' + create function inner( + x: tuple = (9,) + ) -> tuple { + set is_inlined := true; + using (x); + }; + create function foo1() -> tuple { + set is_inlined := true; + using (inner()); + }; + create function foo2( + x: tuple + ) -> tuple { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [(9,)], + ) + await self.assert_query_result( + 'select foo2(>{})', + [], + ) + await self.assert_query_result( + 'select foo2((1,))', + [(1,)], + ) + await self.assert_query_result( + 'select foo2({(1,), (2,), (3,)})', + [(1,), (2,), (3,)], + ) + + async def test_edgeql_functions_inline_nested_object_01(self): + # Directly passing object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + create function foo(x: Bar) -> Bar { + set is_inlined := true; + using (inner(x)); 
+ }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_02(self): + # Indirectly passing object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + create function foo(x: Bar) -> Bar { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_03(self): + # Inner function with object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: Bar) -> Bar { + set is_inlined := true; + using (x); + }; + create function foo(x: int64) -> optional Bar { + set is_inlined := true; + using (inner((select Bar filter .a = x limit 1))); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3, 4}).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_04(self): + # Inner function returning object + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> 
optional Bar { + set is_inlined := true; + using ((select Bar filter .a = x limit 1)); + }; + create function foo(x: int64) -> optional Bar { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3, 4}).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_05(self): + # Outer function returning object + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: int64) -> optional Bar { + set is_inlined := true; + using ((select Bar filter .a = inner(x) limit 1)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3, 4}).a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_06(self): + # Inner function returning set of object + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> set of Bar { + set is_inlined := true; + using ((select Bar filter .a <= x)); + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}).a', + [1, 1, 1, 2, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_07(self): + # Outer function returning set of object + await 
self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((select Bar filter .a <= inner(x))); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo(2).a', + [1, 2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3}).a', + [1, 1, 1, 2, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_08(self): + # Directly passing optional object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: optional Bar) -> optional int64 { + set is_inlined := true; + using (x.a ?? 99); + }; + create function foo(x: optional Bar) -> optional int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [99], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_09(self): + # Indirectly passing optional object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: optional Bar) -> optional int64 { + set is_inlined := true; + using (x.a ?? 
99); + }; + create function foo(x: optional Bar) -> optional int64 { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [99], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_10(self): + # Inner function with optional object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: optional Bar) -> int64 { + set is_inlined := true; + using (x.a ?? 99); + }; + create function foo1() -> int64 { + set is_inlined := true; + using (inner({})); + }; + create function foo2(x: Bar) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1()', + [99], + ) + await self.assert_query_result( + 'select foo2({})', + [], + ) + await self.assert_query_result( + 'select foo2((select Bar filter .a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo2((select Bar))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_11(self): + # Check path factoring + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner() -> set of tuple { + set is_inlined := true; + using ((Bar.a, count(Bar))); + }; + create function foo() -> set of tuple { + set is_inlined := true; + using (inner()); + }; + ''') + await self.assert_query_result( + 'select foo()', + [[1, 1], [2, 1], [3, 1]], + ) + await self.assert_query_result( + 'select (foo(), foo())', + [ + [[1, 1], [1, 1]], [[1, 1], [2, 1]], [[1, 1], [3, 1]], + [[2, 1], [1, 1]], [[2, 1], [2, 1]], [[2, 1], [3, 1]], + [[3, 1], [1, 
1]], [[3, 1], [2, 1]], [[3, 1], [3, 1]], + ], + sort=True, + ) + await self.assert_query_result( + 'select (Bar.a, foo())', + [ + [1, [1, 1]], [1, [2, 1]], [1, [3, 1]], + [2, [1, 1]], [2, [2, 1]], [2, [3, 1]], + [3, [1, 1]], [3, [2, 1]], [3, [3, 1]], + ], + sort=True, + ) + await self.assert_query_result( + 'select (foo(), Bar.a)', + [ + [[1, 1], 1], [[1, 1], 2], [[1, 1], 3], + [[2, 1], 1], [[2, 1], 2], [[2, 1], 3], + [[3, 1], 1], [[3, 1], 2], [[3, 1], 3], + ], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_12(self): + # Check path factoring + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner1(x: Bar) -> int64 { + set is_inlined := true; + using (x.a); + }; + create function inner2(x: Bar) -> int64 { + set is_inlined := true; + using (count(Bar)); + }; + create function foo(x: Bar) -> tuple { + set is_inlined := true; + using ((inner1(x), inner2(x))); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [[1, 3]], + ) + await self.assert_query_result( + 'select (' + ' foo((select Bar filter .a = 1)),' + ' foo((select Bar filter .a = 2)),' + ')', + [[[1, 3], [2, 3]]], + ) + await self.assert_query_result( + 'select foo((select Bar))', + [[1, 3], [2, 3], [3, 3]], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_13(self): + # Directly passing complex type object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function inner(x: Bar | Baz) -> Bar | Baz { + set is_inlined := 
true; + using (x); + }; + create function foo(x: Bar | Baz) -> Bar | Baz { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Baz)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select {Bar, Baz})).a', + [1, 2, 3, 4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_14(self): + # Indirectly passing complex type object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function inner(x: Bar | Baz) -> Bar | Baz { + set is_inlined := true; + using (x); + }; + create function foo(x: Bar | Baz) -> Bar | Baz { + set is_inlined := true; + using (inner((select x))); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Baz filter .a = 4)).a', + [4], + ) + await 
self.assert_query_result( + 'select foo((select Baz)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select {Bar, Baz})).a', + [1, 2, 3, 4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_15(self): + # Inner function with complex type object parameter + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function inner(x: Bar | Baz) -> Bar | Baz { + set is_inlined := true; + using (x); + }; + create function foo1(x: Bar) -> Bar | Baz { + set is_inlined := true; + using (inner(x)); + }; + create function foo2(x: Baz) -> Bar | Baz { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo1({}).a', + [], + ) + await self.assert_query_result( + 'select foo2({}).a', + [], + ) + await self.assert_query_result( + 'select foo1((select Bar filter .a = 1)).a', + [1], + ) + await self.assert_query_result( + 'select foo1((select Bar)).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select foo2((select Baz filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo2((select Baz)).a', + [4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_16(self): + # Type intersection in inner function + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Bar2 extending Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar2{a := 4}; + insert Bar2{a := 5}; + insert Bar2{a := 6}; + create function inner(x: Bar) -> optional Bar2 { + set is_inlined := true; + using (x[is Bar2]); + }; + create function foo(x: Bar) -> optional 
Bar2 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar2 filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Bar2)).a', + [4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_17(self): + # Type intersection in outer function + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Bar2 extending Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar2{a := 4}; + insert Bar2{a := 5}; + insert Bar2{a := 6}; + create function inner(x: Bar2) -> optional Bar2 { + set is_inlined := true; + using (x); + }; + create function foo(x: Bar) -> optional Bar2 { + set is_inlined := true; + using (inner(x[is Bar2])); + }; + ''') + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo({}).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1)).a', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar2 filter .a = 4)).a', + [4], + ) + await self.assert_query_result( + 'select foo((select Bar)).a', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select foo((select Bar2)).a', + [4, 5, 6], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_object_18(self): + # Access linked object in inner function + await self.con.execute(''' + create type Bar 
{ + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar; + }; + create type Bazz { + create required link baz -> Baz; + }; + insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 1})})}; + insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 2})})}; + insert Bazz{baz := (insert Baz{bar := (insert Bar{a := 3})})}; + create function inner1(x: Bar) -> int64 { + set is_inlined := true; + using (x.a); + }; + create function inner2(x: Baz) -> int64 { + set is_inlined := true; + using (inner1(x.bar)); + }; + create function foo(x: Bazz) -> int64 { + set is_inlined := true; + using (inner2(x.baz)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bazz filter .baz.bar.a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo((select Bazz))', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_01(self): + # Put result of inner function taking Bar.a into Bar + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: Bar) -> int64 { + set is_inlined := true; + using ((select x{a, b := inner(x.a)}).b); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [1], + ) + await self.assert_query_result( + 'select foo(Bar)', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_02(self): + # Put result of inner function taking Bar into Bar + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: Bar) -> int64 { + set is_inlined := true; + using 
(x.a + 90); + }; + create function foo(x: Bar) -> tuple { + set is_inlined := true; + using ( + with y := (select x{a, b := inner(x)}) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo((select Bar filter .a = 1))', + [(1, 91)], + ) + await self.assert_query_result( + 'select foo(Bar)', + [(1, 91), (2, 92), (3, 93)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_03(self): + # Put result of inner function taking number into Bar + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x + 90); + }; + create function foo(x: int64) -> set of tuple { + set is_inlined := true; + using ( + with y := (select Bar{a, b := inner(x)}) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1, 91), (2, 91), (3, 91)], + sort=True, + ) + await self.assert_query_result( + 'select foo(Bar.a)', + [ + (1, 91), (1, 92), (1, 93), + (2, 91), (2, 92), (2, 93), + (3, 91), (3, 92), (3, 93), + ], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_04(self): + # Put result of inner function using Bar into Bar + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function inner() -> int64 { + set is_inlined := true; + using (count(Bar)); + }; + create function foo(x: int64) -> set of tuple { + set is_inlined := true; + using ( + with y := (select Bar{a, b := inner()} filter .a = x) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1, 3)], + sort=True, + 
) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1, 3), (2, 3), (3, 3)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_05(self): + # Put result of inner function taking Baz.b and returning Bar into Baz + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property a -> int64; + create required property b -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{a := 4, b := 1}; + insert Baz{a := 5, b := 2}; + insert Baz{a := 6, b := 3}; + create function inner(x: int64) -> Bar { + set is_inlined := true; + using (assert_exists((select Bar filter .a = x limit 1))); + }; + create function foo(x: int64) -> set of tuple { + set is_inlined := true; + using ( + with y := (select Baz{a, c := inner(.b).a} filter .b = x) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(4, 1)], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(4, 1), (5, 2), (6, 3)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_06(self): + # Put result of inner function taking Baz.bar into Baz + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + b := 4, + bar := assert_exists((select Bar filter .a = 1 limit 1)), + }; + insert Baz{ + b := 5, + bar := assert_exists((select Bar filter .a = 2 limit 1)), + }; + insert Baz{ + b := 6, + bar := assert_exists((select Bar filter .a = 3 limit 1)), + }; + create function inner(x: Bar) -> int64 { + set is_inlined := true; + using (x.a); + }; + create function foo(x: int64) -> set of tuple { + set is_inlined := true; + using ( + with 
y := (select Baz{a := inner(.bar), b} filter .a = x) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1, 4)], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1, 4), (2, 5), (3, 6)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_shape_07(self): + # Put result of inner function taking Baz.bar@b into Baz + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + }; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{ + bar := assert_exists((select Bar filter .a = 1 limit 1)) { + @b := 4 + } + }; + insert Baz{ + bar := assert_exists((select Bar filter .a = 2 limit 1)) { + @b := 5 + } + }; + insert Baz{ + bar := assert_exists((select Bar filter .a = 3 limit 1)) { + @b := 6 + } + }; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (x); + }; + create function foo(x: int64) -> set of tuple { + set is_inlined := true; + using ( + with y := ( + select Baz{a := .bar.a, b := inner(.bar@b)} + filter .a = x + ) + select (y.a, y.b) + ); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [(1, 4)], + sort=True, + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [(1, 4), (2, 5), (3, 6)], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_01(self): + # Use computed global in inner function + await self.con.execute(''' + create global a := 1; + create function inner(x: int64) -> int64 { + set is_inlined := true; + using (global a + x); + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await 
self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_02(self): + # Use non-computed global in inner function + await self.con.execute(''' + create global a -> int64; + create function inner(x: int64) -> optional int64 { + set is_inlined := true; + using (global a + x); + }; + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_03(self): + # Pass computed global to inner function + await self.con.execute(''' + create global a := 1; + create function inner(x: int64, y: int64) -> int64 { + set is_inlined := true; + using (x + y); + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(global a, x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_04(self): + # Pass non-computed global to inner function + await self.con.execute(''' + create global a -> int64; + create function inner(x: int64, y: int64) -> optional int64 { + set is_inlined := true; + using (x + y); + }; + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using 
(inner(global a, x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_05(self): + # Use computed global in inner non-inlined function + # - inlined > non-inlined + await self.con.execute(''' + create global a := 1; + create function inner(x: int64) -> int64 { + using (global a + x); + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_06(self): + # Use non-computed global in inner non-inlined function + # - inlined > non-inlined + await self.con.execute(''' + create global a -> int64; + create function inner(x: int64) -> optional int64 { + using (global a + x); + }; + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using (inner(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + 
sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_07(self): + # Use computed global nested in non-inlined function + # - non-inlined > inlined > non-inlined + await self.con.execute(''' + create global a := 1; + create function inner1(x: int64) -> int64 { + using (global a + x); + }; + create function inner2(x: int64) -> int64 { + set is_inlined := true; + using (inner1(x)); + }; + create function foo(x: int64) -> int64 { + using (inner2(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_08(self): + # Use non-computed global nested in non-inlined function + # - non-inlined > inlined > non-inlined + await self.con.execute(''' + create global a -> int64; + create function inner1(x: int64) -> optional int64 { + using (global a + x); + }; + create function inner2(x: int64) -> optional int64 { + set is_inlined := true; + using (inner1(x)); + }; + create function foo(x: int64) -> optional int64 { + using (inner2(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_09(self): + # Use computed global in deeply nested inner non-inlined function + # - inlined > inlined > inlined > non-inlined + await self.con.execute(''' + create global a := 1; + create function inner1(x: int64) -> int64 { + using (global a + x); + 
}; + create function inner2(x: int64) -> int64 { + set is_inlined := true; + using (inner1(x)); + }; + create function inner3(x: int64) -> int64 { + set is_inlined := true; + using (inner2(x)); + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using (inner3(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_nested_global_10(self): + # Use computed global in deeply nested inner non-inlined function + # - inlined > inlined > inlined > non-inlined + await self.con.execute(''' + create global a -> int64; + create function inner1(x: int64) -> optional int64 { + using (global a + x); + }; + create function inner2(x: int64) -> optional int64 { + set is_inlined := true; + using (inner1(x)); + }; + create function inner3(x: int64) -> optional int64 { + set is_inlined := true; + using (inner2(x)); + }; + create function foo(x: int64) -> optional int64 { + set is_inlined := true; + using (inner3(x)); + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [], + sort=True, + ) + + await self.con.execute(''' + set global a := 1; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + await self.assert_query_result( + 'select foo(1)', + [2], + ) + await self.assert_query_result( + 'select foo({1, 2, 3})', + [2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_modifying_cardinality_01(self): + await self.con.execute(''' + create function foo(x: int64) -> int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + await self.assert_query_result( + 'select foo(1)', + [1], + ) + + async def 
test_edgeql_functions_inline_modifying_cardinality_02(self): + await self.con.execute(''' + create function foo(x: int64) -> int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + with self.assertRaisesRegex( + edgedb.QueryError, + 'possibly an empty set passed as non-optional argument ' + 'into modifying function' + ): + await self.con.execute(''' + select foo({}) + ''') + + async def test_edgeql_functions_inline_modifying_cardinality_03(self): + await self.con.execute(''' + create function foo(x: int64) -> int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + with self.assertRaisesRegex( + edgedb.QueryError, + 'possibly more than one element passed into modifying function' + ): + await self.con.execute(''' + select foo({1, 2, 3}) + ''') + + async def test_edgeql_functions_inline_modifying_cardinality_04(self): + await self.con.execute(''' + create function foo(x: optional int64) -> optional int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + await self.assert_query_result( + 'select foo(1)', + [1], + ) + + async def test_edgeql_functions_inline_modifying_cardinality_05(self): + await self.con.execute(''' + create function foo(x: optional int64) -> optional int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + await self.assert_query_result( + 'select foo({})', + [], + ) + + async def test_edgeql_functions_inline_modifying_cardinality_06(self): + await self.con.execute(''' + create function foo(x: optional int64) -> optional int64 { + set volatility := schema::Volatility.Modifying; + using (x) + }; + ''') + with self.assertRaisesRegex( + edgedb.QueryError, + 'possibly more than one element passed into modifying function' + ): + await self.con.execute(''' + select foo({1, 2, 3}) + ''') + + async def test_edgeql_functions_inline_insert_basic_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + 
create function foo() -> Bar { + set is_inlined := true; + using ((insert Bar{ a := 1 })); + }; + ''') + + await self.assert_query_result( + 'select foo().a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + async def test_edgeql_functions_inline_insert_basic_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + async def test_edgeql_functions_inline_insert_basic_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using ((insert Bar{ a := x }).a) + }; + ''') + + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + async def test_edgeql_functions_inline_insert_basic_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x + 1 })) + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [2], + ) + await self.assert_query_result( + 'select Bar.a', + [2], + ) + + async def test_edgeql_functions_inline_insert_basic_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using ((insert Bar{ a := 2 * x + 1 }).a + 10) + }; + ''') + + await self.assert_query_result( + 'select foo(1)', + [13], + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + async def test_edgeql_functions_inline_insert_basic_06(self): + await self.con.execute(''' + create type Bar { + create required 
property a -> int64; + }; + create function foo(x: int64 = 0) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'select foo().a', + [0], + ) + await self.assert_query_result( + 'select Bar.a', + [0], + ) + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1], + ) + + async def test_edgeql_functions_inline_insert_basic_07(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: optional int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x ?? 0 })) + }; + ''') + + await self.assert_query_result( + 'select foo({}).a', + [0], + ) + await self.assert_query_result( + 'select Bar.a', + [0], + ) + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_basic_08(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(named only x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'select foo(x := 1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + async def test_edgeql_functions_inline_insert_basic_09(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(variadic x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := sum(array_unpack(x)) })) + }; + ''') + + await self.assert_query_result( + 'select foo().a', + [0], + ) + await self.assert_query_result( + 'select Bar.a', + [0], + ) + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1], + sort=True, + ) + + await self.assert_query_result( 
+ 'select foo(2, 3).a', + [5], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 5], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_basic_10(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + create required property b -> int64; + }; + create function foo(x: int64, y: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x, b := y })) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 10){a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + + async def test_edgeql_functions_inline_insert_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (select foo(x).a)', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + await self.assert_query_result( + 'select if true then foo(5).a else 99', + [5], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'select if false then foo(6).a else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'select if true then 99 else foo(7).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5], + sort=True, + ) + await self.assert_query_result( + 'select if false then 99 else foo(8).a', + [8], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8], + sort=True, + ) + + await 
self.assert_query_result( + 'select foo(9).a ?? 99', + [9], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8, 9], + sort=True, + ) + await self.assert_query_result( + 'select 99 ?? foo(10).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4, 5, 8, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_iterator_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + create required property b -> int64; + }; + create function foo(x: int64, y: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x, b := y })) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 10){a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + + await self.assert_query_result( + 'select (' + ' for x in {2, 3} union(' + ' for y in {20, 30} union(' + ' select foo(x, y)' + ' )' + ' )' + '){a, b}' + 'order by .a then .b', + [ + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(5, 50)' + ' else (select Bar filter .a = 1)' + '){a, b}' + 'order by .a then .b', + [{'a': 5, 'b': 50}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(6, 60)' + ' else (select Bar filter .a = 1)' + '){a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 
'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then (select Bar filter .a = 1)' + ' else foo(7, 70)' + '){a, b}' + 'order by .a then .b', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then (select Bar filter .a = 1)' + ' else foo(8, 80)' + '){a, b}' + 'order by .a then .b', + [{'a': 8, 'b': 80}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + {'a': 8, 'b': 80}, + ], + ) + + await self.assert_query_result( + 'select (foo(9, 90) ?? (select Bar filter .a = 1)){a, b}', + [{'a': 9, 'b': 90}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + {'a': 8, 'b': 80}, + {'a': 9, 'b': 90}, + ], + ) + await self.assert_query_result( + 'select ((select Bar filter .a = 1) ?? 
foo(10, 100)){a, b}', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar{a, b}' + 'order by .a then .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + {'a': 5, 'b': 50}, + {'a': 8, 'b': 80}, + {'a': 9, 'b': 90}, + ], + ) + + async def test_edgeql_functions_inline_insert_iterator_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ( + for y in {x, x + 1, x + 2} union ( + (insert Bar{ a := y }) + ) + ) + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + await self.assert_query_result( + 'for x in {11, 21, 31} union (select foo(x).a)', + [11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 11, 12, 13, 21, 22, 23, 31, 32, 33], + sort=True, + ) + + await self.assert_query_result( + 'select if true then foo(51).a else 99', + [51, 52, 53], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 51, 52, 53, + ], + sort=True, + ) + await self.assert_query_result( + 'select if false then foo(61).a else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 51, 52, 53, + ], + sort=True, + ) + await self.assert_query_result( + 'select if true then 99 else foo(71).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 51, 52, 53, + ], + sort=True, + ) + await self.assert_query_result( + 'select if false then 99 else foo(81).a', + [81, 82, 83], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 
23, + 31, 32, 33, + 51, 52, 53, + 81, 82, 83, + ], + sort=True, + ) + + await self.assert_query_result( + 'select foo(91).a ?? 99', + [91, 92, 93], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 51, 52, 53, + 81, 82, 83, + 91, 92, 93, + ], + sort=True, + ) + await self.assert_query_result( + 'select 99 ?? foo(101).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 1, 2, 3, + 11, 12, 13, + 21, 22, 23, + 31, 32, 33, + 51, 52, 53, + 81, 82, 83, + 91, 92, 93, + ], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_iterator_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: bool, y: int64) -> optional Bar { + set is_inlined := true; + using ( + if x then (insert Bar{ a := y }) else {} + ) + }; + ''') + + await self.assert_query_result( + 'select foo(false, 0).a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + await self.assert_query_result( + 'select foo(true, 1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4, 5} union (select foo(x % 2 = 0, x).a)', + [2, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4], + sort=True, + ) + + await self.assert_query_result( + 'select if true then foo(false, 6).a else 99', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4], + sort=True, + ) + await self.assert_query_result( + 'select if true then foo(true, 6).a else 99', + [6], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if false then foo(false, 7).a else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if false then foo(true, 7).a else 99', + [99], + ) + 
await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if true then 99 else foo(false, 8).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if true then 99 else foo(true, 8).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if false then 99 else foo(false, 9).a', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6], + sort=True, + ) + await self.assert_query_result( + 'select if false then 99 else foo(true, 9).a', + [9], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6, 9], + sort=True, + ) + + await self.assert_query_result( + 'select foo(false, 10).a ?? 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6, 9], + sort=True, + ) + await self.assert_query_result( + 'select foo(true, 10).a ?? 99', + [10], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6, 9, 10], + sort=True, + ) + await self.assert_query_result( + 'select 99 ?? foo(false, 11).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6, 9, 10], + sort=True, + ) + await self.assert_query_result( + 'select 99 ?? 
foo(true, 11).a', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 4, 6, 9, 10], + sort=True, + ) + + @unittest.skip('Cannot correlate same set inside and outside DML') + async def test_edgeql_functions_inline_insert_correlate_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> tuple { + set is_inlined := true; + using (((insert Bar{ a := x }), x)) + }; + ''') + + await self.assert_query_result( + 'select foo(1)', + [[[], 1]], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (select foo(x).a)', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + @unittest.skip('Cannot correlate same set inside and outside DML') + async def test_edgeql_functions_inline_insert_correlate_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> int64 { + set is_inlined := true; + using ((insert Bar{ a := 2 * x + 1 }).a + x * x) + }; + ''') + + await self.assert_query_result( + 'select foo(1)', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (select foo(x))', + [9, 16, 25], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 5, 7, 9], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_correlate_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> tuple { + set is_inlined := true; + using (( + (insert Bar{ a := x }).a, + (insert Bar{ a := x + 1 }).a, + )) + }; + ''') + + await self.assert_query_result( + 'select foo(1)', + [[1, 2]], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2], + sort=True, + ) + + await 
self.assert_query_result( + 'for x in {11, 21, 31} union (select foo(x))', + [[11, 12], [21, 22], [31, 32]], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 11, 12, 21, 22, 31, 32], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_correlate_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> tuple { + set is_inlined := true; + using (( + (insert Bar{ a := x }).a, + (insert Bar{ a := y }).a, + )) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 2)', + [[1, 2]], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2], + sort=True, + ) + + await self.assert_query_result( + 'for x in {1, 5} union (' + ' for y in {10, 20} union (' + ' select foo(x + y, x + y + 1)' + ' )' + ')', + [[11, 12], [15, 16], [21, 22], [25, 26]], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 11, 12, 15, 16, 21, 22, 25, 26], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_correlate_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> int64 { + set is_inlined := true; + using ((insert Bar{ a := 2 * x + 1 }).a + y) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 10)', + [13], + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + await self.assert_query_result( + 'for x in {2, 3} union(' + ' for y in {20, 30} union(' + ' select foo(x, y)' + ' )' + ')', + [25, 27, 35, 37], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 5, 5, 7, 7], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_conflict_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + create constraint exclusive on (.a) + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using (( + insert 
Bar{a := x} + unless conflict on .a + else ((update Bar set {a := x + 10})) + )) + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x).a)', + [2, 3, 11], + sort=True + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3, 11], + ) + + async def test_edgeql_functions_inline_insert_conflict_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create type Baz { + create link bar -> Bar; + create constraint exclusive on (.bar) + }; + create function foo(x: Bar) -> Baz { + set is_inlined := true; + using (( + insert Baz{bar := x} + unless conflict on .bar + else (( + update Baz set {bar := (insert Bar{a := x.a + 10})} + )) + )) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' assert_exists((select Bar filter .a = 1 limit 1))' + ').bar.a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + ) + await self.assert_query_result( + 'select Baz.bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {1, 2, 3} union (' + ' select foo(' + ' assert_exists((select Bar filter .a = x limit 1))' + ' ).bar.a' + ')', + [2, 3, 11], + sort=True + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 11], + ) + await self.assert_query_result( + 'select Baz.bar.a', + [2, 3, 11], + ) + + async def test_edgeql_functions_inline_insert_link_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(n: int64, x: Bar) -> Baz { + set is_inlined := true; + using ((insert Baz{ b := n, bar := x })) + }; + ''') + + 
await self.assert_query_result( + 'select foo(' + ' 4,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b}', + [{'a': 1, 'b': 4}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a', + [{'a': 1, 'b': 4}], + ) + + await self.assert_query_result( + 'select foo(' + ' 5,' + ' assert_exists((select Bar filter .a = 2 limit 1))' + '){a := .bar.a, b}', + [{'a': 2, 'b': 5}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: int64, y: int64) -> Baz { + set is_inlined := true; + using ( + (insert Baz{ + b := x, + bar := (select Bar filter .a <= y), + }) + ); + }; + ''') + + await self.assert_query_result( + 'select foo(4, 1){a := .bar.a, b}', + [{'a': [1], 'b': 4}], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [{'a': [1], 'b': 4}], + ) + + await self.assert_query_result( + 'select foo(5, 2){a := .bar.a, b}', + [{'a': [1, 2], 'b': 5}], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': [1], 'b': 4}, + {'a': [1, 2], 'b': 5}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + create function foo(x: int64, y: int64) -> Baz { + set is_inlined := true; + using ( + (insert Baz { + b := y, + bar := (insert Bar{ a := x }) + }) + ); + }; + 
''') + + await self.assert_query_result( + 'select foo(1, 4).b', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b', + [{'a': 1, 'b': 4}], + ) + + await self.assert_query_result( + 'select foo(2, 5).b', + [5], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2], + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar {a := x})) + }; + ''') + + await self.assert_query_result( + 'select (insert Baz{b := 4, bar := foo(1)})' + '{a := .bar.a, b} order by .b', + [{'a': 1, 'b': 4}], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b', + [{'a': 1, 'b': 4}], + ) + + await self.assert_query_result( + 'select (insert Baz{b := 5, bar := foo(2)})' + '{a := .bar.a, b} order by .b', + [{'a': 2, 'b': 5}], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2], + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + create function foo(n: int64, x: Bar) -> Baz { + set is_inlined := true; + using ((insert Baz{ b := n, bar := 
x })) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' 1, assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b} order by .a then .b', + [{'a': 1, 'b': 1}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [{'a': 1, 'b': 1}], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (' + ' select foo(' + ' x, assert_exists((select Bar filter .a = 2 limit 1))' + ' ).b' + ')', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(' + ' 5, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).b' + ' else 99' + ')', + [5], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(' + ' 6, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).b' + ' else 99' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then 99' + ' else foo(' + ' 7, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).b' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then 99' + ' else foo(' + ' 8, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).b' + ')', + [8], + ) + await 
self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + ], + ) + + await self.assert_query_result( + 'select foo(' + ' 9, assert_exists((select Bar filter .a = 4 limit 1))' + ').b ?? 99', + [9], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + {'a': 4, 'b': 9}, + ], + ) + await self.assert_query_result( + 'select 99 ?? foo(' + ' 9, assert_exists((select Bar filter .a = 4 limit 1))' + ').b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + {'a': 4, 'b': 9}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_iterator_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + create function foo(x: int64, y: int64) -> Baz { + set is_inlined := true; + using ( + (insert Baz { + b := y, + bar := (for z in {x, x + 1, x + 2} union( + (insert Bar{ a := z }) + )) + }) + ); + }; + ''') + + await self.assert_query_result( + 'select foo(10, 1).b', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [10, 11, 12], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [{'a': [10, 11, 12], 'b': 1}], + ) + + await self.assert_query_result( + 'for x in {20, 30} union (' + ' for y in {2, 3} union (' + ' select foo(x, y).b' + ' )' + ')', + [2, 2, 3, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 
32, 32, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + ], + ) + + await self.assert_query_result( + 'select if true then foo(40, 4).b else 999', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + ], + ) + await self.assert_query_result( + 'select if false then foo(50, 5).b else 999', + [999], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + ], + ) + await self.assert_query_result( + 'select if true then 999 else foo(60, 6).b', + [999], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + ], + ) + await self.assert_query_result( + 'select if false then 999 else foo(70, 
7).b', + [7], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + 70, 71, 72, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + {'a': [70, 71, 72], 'b': 7}, + ], + ) + + await self.assert_query_result( + 'select foo(80, 8).b ?? 999', + [8], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + 70, 71, 72, + 80, 81, 82, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + {'a': [70, 71, 72], 'b': 7}, + {'a': [80, 81, 82], 'b': 8}, + ], + ) + await self.assert_query_result( + 'select 999 ?? 
foo(90, 9).b', + [999], + ) + await self.assert_query_result( + 'select Bar.a', + [ + 10, 11, 12, + 20, 20, 21, 21, 22, 22, + 30, 30, 31, 31, 32, 32, + 40, 41, 42, + 70, 71, 72, + 80, 81, 82, + ], + sort=True, + ) + await self.assert_query_result( + 'select Baz {a := .bar.a, b} order by .b then sum(.a)', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [30, 31, 32], 'b': 2}, + {'a': [20, 21, 22], 'b': 3}, + {'a': [30, 31, 32], 'b': 3}, + {'a': [40, 41, 42], 'b': 4}, + {'a': [70, 71, 72], 'b': 7}, + {'a': [80, 81, 82], 'b': 8}, + ], + ) + + async def test_edgeql_functions_inline_insert_link_iterator_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + create function foo(n: int64, x: Bar, flag: bool) -> optional Baz { + set is_inlined := true; + using ( + if flag then (insert Baz{ b := n, bar := x }) else {} + ) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' 0, assert_exists((select Bar filter .a = 1 limit 1)), false' + '){a := .bar.a, b} order by .a then .b', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [], + ) + await self.assert_query_result( + 'select foo(' + ' 1, assert_exists((select Bar filter .a = 1 limit 1)), true' + '){a := .bar.a, b} order by .a then .b', + [{'a': 1, 'b': 1}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [{'a': 1, 'b': 1}], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (' + ' select foo(' + ' x,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' false,' + ' ).b' + ')', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [{'a': 1, 'b': 1}], + ) + await self.assert_query_result( + 'for 
x in {2, 3, 4} union (' + ' select foo(' + ' x,' + ' assert_exists((select Bar filter .a = 2 limit 1)),' + ' true,' + ' ).b' + ')', + [2, 3, 4], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(' + ' 5,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' false,' + ' ).b' + ' else 99' + ')', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(' + ' 6,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' false,' + ' ).b' + ' else 99' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then 99' + ' else foo(' + ' 7,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' false,' + ' ).b' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then 99' + ' else foo(' + ' 8,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' false,' + ' ).b' + ')', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(' + ' 9,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' true,' + ' ).b' + ' else 99' + ')', + [9], + 
) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' true,' + ' ).b' + ' else 99' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then 99' + ' else foo(' + ' 11,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' true,' + ' ).b' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then 99' + ' else foo(' + ' 12,' + ' assert_exists((select Bar filter .a = 3 limit 1)),' + ' true,' + ' ).b' + ')', + [12], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + {'a': 3, 'b': 12}, + ], + ) + + await self.assert_query_result( + 'select foo(' + ' 13, assert_exists((select Bar filter .a = 4 limit 1)), false' + ').b ?? 99', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + {'a': 3, 'b': 12}, + ], + ) + await self.assert_query_result( + 'select 99 ?? 
foo(' + ' 14, assert_exists((select Bar filter .a = 4 limit 1)), false' + ').b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + {'a': 3, 'b': 12}, + ], + ) + await self.assert_query_result( + 'select foo(' + ' 15, assert_exists((select Bar filter .a = 4 limit 1)), true' + ').b ?? 99', + [15], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + {'a': 3, 'b': 12}, + {'a': 4, 'b': 15}, + ], + ) + await self.assert_query_result( + 'select 99 ?? foo(' + ' 16, assert_exists((select Bar filter .a = 4 limit 1)), true' + ').b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 9}, + {'a': 3, 'b': 12}, + {'a': 4, 'b': 15}, + ], + ) + + async def test_edgeql_functions_inline_insert_linkprop_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + } + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(x: Bar) -> Baz { + set is_inlined := true; + using ((insert Baz{ bar := x { @b := 10 } })) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b := .bar@b}', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [{'a': 1, 'b': 10}], + ) + + async def test_edgeql_functions_inline_insert_linkprop_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create 
required link bar -> Bar { + create property b -> int64; + } + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + create function foo(n: int64, x: Bar) -> Baz { + set is_inlined := true; + using ((insert Baz{ bar := x { @b := n } })) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' 4,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b := .bar@b}', + [{'a': 1, 'b': 4}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [{'a': 1, 'b': 4}], + ) + + async def test_edgeql_functions_inline_insert_linkprop_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + } + }; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + create function foo(n: int64, x: Bar) -> Baz { + set is_inlined := true; + using ((insert Baz{ bar := x { @b := n } })) + }; + ''') + + await self.assert_query_result( + 'select foo(' + ' 1,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b := .bar@b}', + [{'a': 1, 'b': 1}], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [{'a': 1, 'b': 1}], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (' + ' select foo(' + ' x, assert_exists((select Bar filter .a = 2 limit 1))' + ' ).bar@b' + ')', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + ], + ) + + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(' + ' 5, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).bar@b' + ' else 99' + ')', + [5], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 
1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(' + ' 6, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).bar@b' + ' else 99' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if true' + ' then 99' + ' else foo(' + ' 7, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).bar@b' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select (' + ' if false' + ' then 99' + ' else foo(' + ' 8, assert_exists((select Bar filter .a = 3 limit 1))' + ' ).bar@b' + ')', + [8], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + ], + ) + + await self.assert_query_result( + 'select foo(' + ' 9, assert_exists((select Bar filter .a = 4 limit 1))' + ').bar@b ?? 99', + [9], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + {'a': 4, 'b': 9}, + ], + ) + await self.assert_query_result( + 'select 99 ?? 
foo(' + ' 9, assert_exists((select Bar filter .a = 4 limit 1))' + ').bar@b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a then .b', + [ + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 2, 'b': 3}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': 5}, + {'a': 3, 'b': 8}, + {'a': 4, 'b': 9}, + ], + ) + + async def test_edgeql_functions_inline_insert_nested_01(self): + # Simple inner modifying function + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function inner(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })); + }; + create function foo(x: int64) -> Bar { + set is_inlined := true; + using (inner(x)); + }; + ''') + + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + + await self.assert_query_result( + 'for x in {2, 3, 4} union (foo(x).a)', + [2, 3, 4], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3, 4], + sort=True, + ) + + async def test_edgeql_functions_inline_insert_nested_02(self): + # Putting the result of an inner modifying function into shape + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create required link bar -> Bar; + }; + create function inner1(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + create function inner2(x: int64, y: int64) -> Baz { + set is_inlined := true; + using ((insert Baz{ b := y, bar := inner1(x) })) + }; + create function foo(x: int64, y: int64) -> Baz { + set is_inlined := true; + using (inner2(x, y)) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 10){a := .bar.a, b := .b}', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .b} order by .a', 
+ [{'a': 1, 'b': 10}], + ) + + await self.assert_query_result( + 'select (' + ' for x in {2, 3} union (' + ' for y in {20, 30} union (' + ' foo(x, y){a := .bar.a, b := .b}' + ' )' + ' )' + ') order by .a then .b', + [ + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 2, 3, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .b} order by .a', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + + async def test_edgeql_functions_inline_insert_nested_03(self): + # Putting the result of an inner modifying function into shape with + # link property + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + }; + }; + create function inner1(x: int64) -> Bar { + set is_inlined := true; + using ((insert Bar{ a := x })) + }; + create function inner2(x: int64, y: int64) -> Baz { + set is_inlined := true; + using ((insert Baz{ bar := inner1(x){ @b := y } })) + }; + create function foo(x: int64, y: int64) -> Baz { + set is_inlined := true; + using (inner2(x, y)) + }; + ''') + + await self.assert_query_result( + 'select foo(1, 10){a := .bar.a, b := .bar@b}', + [{'a': 1, 'b': 10}], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [{'a': 1, 'b': 10}], + ) + + await self.assert_query_result( + 'select (' + ' for x in {2, 3} union (' + ' for y in {20, 30} union (' + ' foo(x, y){a := .bar.a, b := .bar@b}' + ' )' + ' )' + ') order by .a then .b', + [ + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 2, 3, 3], + sort=True, + ) + await 
self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': 2, 'b': 30}, + {'a': 3, 'b': 20}, + {'a': 3, 'b': 30}, + ], + ) + + async def test_edgeql_functions_inline_update_basic_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar set { a := x })); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1).a', + [1, 1, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 1], + sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of int64 { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x }).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 3)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + 
sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + named only m: int64, + named only n: int64, + ) -> set of int64 { + set is_inlined := true; + using ((update Bar filter .a <= n set { a := m }).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(m := 0, n := 0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 0, n := 1)', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 0, n := 2)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 0, n := 3)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + x: optional int64, + y: optional int64, + ) -> set of int64 { + set is_inlined := true; + using ((update Bar filter .a <= y ?? 9 set { a := x ?? 
9 }).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo({}, {})', + [9, 9, 9], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [9, 9, 9], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo({}, 2)', + [9, 9], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3, 9, 9], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2, {})', + [2, 2, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [2, 2, 2], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 3)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + async def test_edgeql_functions_inline_update_basic_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + x: int64, + variadic y: int64, + ) -> set of int64 { + set is_inlined := true; + using ( + ( + update Bar + filter .a <= sum(array_unpack(y)) + set { a := x } + ).a + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await 
reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1, 2)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + async def test_edgeql_functions_inline_update_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of int64 { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := x }).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 3)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(0, x))', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await 
self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(0, x))', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, 0))', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, 3))', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1} union (select foo(x - 1, x))', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x - 1, x))', + [1, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 2], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, x))', + [0, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(0, 2) else 99', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(0, 2) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(0, 2)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(0, 2)', + [0, 0], + sort=True, + ) + await 
self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0) ?? 99', + [99], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2) ?? 99', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(0, 2)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_update_iterator_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64, y: int64) -> set of int64 { + set is_inlined := true; + using ( + for z in {0, 1} union ( + (update Bar filter .a <= y + z set { a := x + z }).a + ) + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0)', + [1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2)', + [0, 0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 1], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 3)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(0, x))', + [1, 
1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(0, x))', + [0, 1, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 1], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, 0))', + [1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, 3))', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1} union (select foo(x - 1, x))', + [0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x - 1, x))', + [1, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 2], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x - 1, x))', + [0, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 2], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(0, 1) else 99', + [0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(0, 1) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(0, 1)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) 
+ await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(0, 1)', + [0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0, -1) ?? 99', + [99], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1) ?? 99', + [0, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(0, 1)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_update_iterator_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + x: int64, y: int64, z: bool + ) -> set of int64 { + set is_inlined := true; + using ( + if z + then (update Bar filter .a <= y set { a := x }).a + else {} + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0, 2, false)', + [], + ) + await self.assert_query_result( + 'select foo(0, 3, false)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2, true)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 3, true)', + [0, 0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 0], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union 
(select foo(0, x, false))', + [], + sort=True, + ) + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x - 1, x, false))', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(0, x, true))', + [0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x - 1, x, true))', + [1, 1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 2], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(0, 2, false) else 99', + [], + sort=True, + ) + await self.assert_query_result( + 'select if false then foo(0, 2, false) else 99', + [99], + ) + await self.assert_query_result( + 'select if true then 99 else foo(0, 2, false)', + [99], + ) + await self.assert_query_result( + 'select if false then 99 else foo(0, 2, false)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then foo(0, 2, true) else 99', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(0, 2, true) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(0, 2, true)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(0, 2, true)', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 
3], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0, 0, false) ?? 99', + [99], + sort=True, + ) + await self.assert_query_result( + 'select foo(0, 2, false) ?? 99', + [99], + sort=True, + ) + await self.assert_query_result( + 'select 99 ?? foo(0, 2, false)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 0, true) ?? 99', + [99], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 2, true) ?? 99', + [0, 0], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [0, 0, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(0, 2, true)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_update_link_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar; + }; + create function foo(n: int64, x: Bar) -> set of Baz { + set is_inlined := true; + using ((update Baz filter .b <= n set { bar := x })) + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{b := 4}; + insert Baz{b := 5}; + insert Baz{b := 6}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(' + ' 4,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b}', + [ + {'a': 1, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': None, 'b': 5}, + {'a': None, 'b': 6}, + ], + ) + + await reset_data() + await self.assert_query_result( 
+ 'select foo(' + ' 5,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b}', + [ + {'a': 1, 'b': 4}, + {'a': 1, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 1, 'b': 5}, + {'a': None, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_update_link_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + create function foo(x: int64, y: int64) -> set of Baz { + set is_inlined := true; + using ( + (update Baz filter .b <= x set { + bar := (select Bar filter .a <= y), + }) + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{b := 4}; + insert Baz{b := 5}; + insert Baz{b := 6}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(4, 1){a := .bar.a, b}', + [ + {'a': [1], 'b': 4}, + ], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': [1], 'b': 4}, + {'a': [], 'b': 5}, + {'a': [], 'b': 6}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(5, 2){a := .bar.a, b}', + [ + {'a': [1, 2], 'b': 4}, + {'a': [1, 2], 'b': 5}, + ], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': [1, 2], 'b': 4}, + {'a': [1, 2], 'b': 5}, + {'a': [], 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_update_link_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create optional link bar -> Bar; + }; + create function foo(x: int64, y: int64) -> set of Baz { + set is_inlined := true; 
+ using ( + (update Baz filter .b <= x set { + bar := (insert Bar{a := y}), + }) + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4}; + insert Baz{b := 5}; + insert Baz{b := 6}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(4, 1){a := .bar.a, b}', + [ + {'a': 1, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': None, 'b': 5}, + {'a': None, 'b': 6}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(5, 2){a := .bar.a, b}', + [ + {'a': 2, 'b': 4}, + {'a': 2, 'b': 5}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 2], + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': 2, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': None, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_update_link_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar; + }; + create function foo(n: int64, x: Bar) -> set of Baz { + set is_inlined := true; + using ((update Baz filter .b = n set { bar := x })) + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Bar{a := 4}; + insert Baz{b := 10}; + insert Baz{b := 20}; + insert Baz{b := 30}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1))' + '){a := .bar.a, b}', + [ + {'a': 1, 'b': 10}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 
10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select (' + ' for x in {1, 2} union(' + ' select foo(' + ' x * 10,' + ' assert_exists((select Bar filter .a = x limit 1))' + ' ).b' + ' )' + ')', + [10, 20], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 10}, + {'a': 2, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select (' + ' if true' + ' then foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1)),' + ' ).b' + ' else 99' + ')', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select (' + ' if false' + ' then foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1)),' + ' ).b' + ' else 99' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select (' + ' if true' + ' then 99' + ' else foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1)),' + ' ).b' + ')', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select (' + ' if false' + ' then 99' + ' else foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1)),' + ' ).b' + ')', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(' + ' 10,' + ' 
assert_exists((select Bar filter .a = 1 limit 1)),' + ').b ?? 99', + [10], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(' + ' 10,' + ' assert_exists((select Bar filter .a = 1 limit 1)),' + ').b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 10}, + {'a': None, 'b': 20}, + {'a': None, 'b': 30}, + ], + ) + + async def test_edgeql_functions_inline_update_link_iterator_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + create function foo(x: int64, y: int64) -> set of Baz { + set is_inlined := true; + using (( + update Baz filter .b = x set { + bar := (for z in {y, y + 1, y + 2} union ( + insert Bar{a := z} + ) + ) + } + )) + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 1}; + insert Baz{b := 2}; + insert Baz{b := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1, 10){a := .bar.a, b}', + [ + {'a': [10, 11, 12], 'b': 1}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1, 2} union (select foo(x, x * 10){a := .bar.a, b})', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [20, 21, 22], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(1, 10).b else 
99', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(1, 10).b else 99', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(1, 10).b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(1, 10).b', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(1, 10).b ?? 99', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [10, 11, 12], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? 
foo(1, 10).b', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': [], 'b': 1}, + {'a': [], 'b': 2}, + {'a': [], 'b': 3}, + ], + ) + + async def test_edgeql_functions_inline_update_link_iterator_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar; + }; + create function foo(x: int64, y: int64, flag: bool) -> set of Baz { + set is_inlined := true; + using (( + update Baz filter .b = x set { + bar := ( + if flag + then (insert Bar{a := y}) + else {} + ) + } + )) + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 1}; + insert Baz{b := 2}; + insert Baz{b := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1, 10, false){a := .bar.a, b}', + [ + {'a': None, 'b': 1}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(1, 10, true){a := .bar.a, b}', + [ + {'a': 10, 'b': 1}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 10, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'for x in {1, 2} union (' + ' select foo(x, x * 10, false){a := .bar.a, b}' + ')', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2} union (' + ' select foo(x, 
x * 10, true){a := .bar.a, b}' + ')', + [ + {'a': 10, 'b': 1}, + {'a': 20, 'b': 2}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 10, 'b': 1}, + {'a': 20, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(1, 10, false).bar.a else 99', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(1, 10, false).bar.a else 99', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(1, 10, false).bar.a', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(1, 10, false).bar.a', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if true then foo(1, 10, true).bar.a else 99', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 10, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(1, 10, true).bar.a else 99', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await 
self.assert_query_result( + 'select if true then 99 else foo(1, 10, true).bar.a', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(1, 10, true).bar.a', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 10, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(1, 10, false).bar.a ?? 99', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(1, 10, false).bar.a', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(1, 10, true).bar.a ?? 99', + [10], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 10, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? 
foo(1, 10, true).bar.a', + [99], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 1}, + {'a': None, 'b': 2}, + {'a': None, 'b': 3}, + ], + ) + + async def test_edgeql_functions_inline_update_linkprop_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required link bar -> Bar { + create property b -> int64; + } + }; + create function foo(x: int64, y: int64) -> set of Baz { + set is_inlined := true; + using (( + update Baz filter .bar.a <= x set { + bar := .bar { @b := y } + } + )) + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{bar := (insert Bar{a := 1})}; + insert Baz{bar := (insert Bar{a := 2})}; + insert Baz{bar := (insert Bar{a := 3})}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(2, 4){a := .bar.a, b := .bar@b}', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 4}, + ], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b := .bar@b} order by .a', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 4}, + {'a': 3, 'b': None}, + ], + ) + + async def test_edgeql_functions_inline_update_nested_01(self): + # Simple inner modifying function + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function inner(x: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar set { a := x })); + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using (inner(x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1).a', + [1, 1, 1], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 1, 1], + sort=True, + ) + + async def test_edgeql_functions_inline_update_nested_02(self): 
+ # Putting the result of an inner modifying function into shape + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create multi link bar -> Bar; + }; + create function inner1(y: int64) -> set of Bar { + set is_inlined := true; + using ((update Bar filter .a <= y set { a := .a - 1 })); + }; + create function inner2(x: int64, y: int64) -> set of Baz { + set is_inlined := true; + using ( + (update Baz filter .b <= x set { + bar := assert_distinct(inner1(y)), + }) + ); + }; + create function foo(x: int64, y: int64) -> set of Baz { + set is_inlined := true; + using (inner2(x, y)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + insert Baz{b := 4}; + insert Baz{b := 5}; + insert Baz{b := 6}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(4, 1){a := .bar.a, b}', + [ + {'a': [0], 'b': 4}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': [0], 'b': 4}, + {'a': [], 'b': 5}, + {'a': [], 'b': 6}, + ], + ) + + # Inner update will return an empty set for all subsequent calls. 
+ await reset_data() + await self.assert_query_result( + 'select foo(5, 2){a := .bar.a, b}', + [ + {'a': [0, 1], 'b': 4}, + {'a': [], 'b': 5}, + ], + ) + await self.assert_query_result( + 'select Bar.a', + [0, 1, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz {' + ' a := (select .bar order by .a).a,' + ' b,' + '} order by .b', + [ + {'a': [0, 1], 'b': 4}, + {'a': [], 'b': 5}, + {'a': [], 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_delete_basic_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2).a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + async def test_edgeql_functions_inline_delete_basic_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ((delete Bar filter .a <= x).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await 
self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + async def test_edgeql_functions_inline_delete_basic_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(named only m: int64) -> set of int64 { + set is_inlined := true; + using ((delete Bar filter .a <= m).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(m := 0)', + [], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(m := 3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + async def test_edgeql_functions_inline_delete_basic_04(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: optional int64) -> set of int64 { + set is_inlined := true; + using ((delete Bar filter .a <= x ?? 
9).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo({})', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + async def test_edgeql_functions_inline_delete_basic_05(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + variadic x: int64, + ) -> set of int64 { + set is_inlined := true; + using ( + ( + delete Bar + filter .a <= sum(array_unpack(x)) + ).a + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, 1, 2)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + async def 
test_edgeql_functions_inline_delete_iterator_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ((delete Bar filter .a <= x).a); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(x))', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(2) else 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(2) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(2)', + [99], + ) + await self.assert_query_result( + 'select 
Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0) ?? 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2) ?? 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? foo(2)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_delete_iterator_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + for z in {0, 1} union ( + (delete Bar filter .a <= x).a + ) + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(x))', + [1], + sort=True, + 
) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + ) + await reset_data() + await self.assert_query_result( + 'for x in {1, 2, 3} union (select foo(x))', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(2) else 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(2) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(2)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0) ?? 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2) ?? 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? 
foo(2)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_delete_iterator_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function foo( + x: int64, y: bool + ) -> set of int64 { + set is_inlined := true; + using ( + if y + then (delete Bar filter .a <= x).a + else {} + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(2, false)', + [], + ) + await self.assert_query_result( + 'select foo(3, false)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2, true)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3, true)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(x, false))', + [], + ) + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x, false))', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {0, 1} union (select foo(x, true))', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'for x in {2, 3} union (select foo(x, true))', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + + await reset_data() + await self.assert_query_result( + 'select if true then foo(2, false) else 99', + [], + ) + await self.assert_query_result( 
+ 'select if false then foo(2, false) else 99', + [99], + ) + await self.assert_query_result( + 'select if true then 99 else foo(2, false)', + [99], + ) + await self.assert_query_result( + 'select if false then 99 else foo(2, false)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then foo(2, true) else 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await reset_data() + await self.assert_query_result( + 'select if false then foo(2, true) else 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if true then 99 else foo(2, true)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select if false then 99 else foo(2, true)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + sort=True, + ) + + await reset_data() + await self.assert_query_result( + 'select foo(0, false) ?? 99', + [99], + ) + await self.assert_query_result( + 'select foo(2, false) ?? 99', + [99], + ) + await self.assert_query_result( + 'select 99 ?? foo(2, false)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(0, true) ?? 99', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2, true) ?? 99', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select 99 ?? 
foo(2, true)', + [99], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + + async def test_edgeql_functions_inline_delete_policy_target_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar { + on target delete allow; + }; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + (delete Bar filter .a <= x).a + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4, bar := (insert Bar{a := 1})}; + insert Baz{b := 5, bar := (insert Bar{a := 2})}; + insert Baz{b := 6, bar := (insert Bar{a := 3})}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 'b': 4}, + {'a': None, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': None, 
'b': 4}, + {'a': None, 'b': 5}, + {'a': None, 'b': 6}, + ], + ) + + async def test_edgeql_functions_inline_delete_policy_target_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar { + on target delete delete source; + }; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + (delete Bar filter .a <= x).a + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4, bar := (insert Bar{a := 1})}; + insert Baz{b := 5, bar := (insert Bar{a := 2})}; + insert Baz{b := 6, bar := (insert Bar{a := 3})}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b}', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(1)', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b}', + [ + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(2)', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b}', + [ + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(3)', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b}', + [], + ) + + async def test_edgeql_functions_inline_delete_policy_source_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + 
}; + create type Baz { + create required property b -> int64; + create link bar -> Bar { + on source delete allow; + }; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + (delete Baz filter .b <= x).b + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4, bar := (insert Bar{a := 1})}; + insert Baz{b := 5, bar := (insert Bar{a := 2})}; + insert Baz{b := 6, bar := (insert Bar{a := 3})}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(4)', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(5)', + [4, 5], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(6)', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [], + ) + + async def test_edgeql_functions_inline_delete_policy_source_02(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar { + on source delete delete target; + }; + }; + create 
function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + (delete Baz filter .b <= x).b + ); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4, bar := (insert Bar{a := 1})}; + insert Baz{b := 5, bar := (insert Bar{a := 2})}; + insert Baz{b := 6, bar := (insert Bar{a := 3})}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(4)', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(5)', + [4, 5], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 3, 'b': 6}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(6)', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [], + ) + + async def test_edgeql_functions_inline_delete_policy_source_03(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create type Baz { + create required property b -> int64; + create link bar -> Bar { + on source delete delete target if orphan; + }; + }; + create function foo(x: int64) -> set of int64 { + set is_inlined := true; + using ( + (delete Baz filter .b <= x).b + ); + }; + ''') + + async def reset_data(): + await 
self.con.execute(''' + delete Baz; + delete Bar; + insert Baz{b := 4, bar := (insert Bar{a := 1})}; + insert Baz{b := 5, bar := (insert Bar{a := 2})}; + insert Baz{b := 6, bar := (insert Bar{a := 3})}; + insert Baz{ + b := 7, + bar := assert_exists((select Bar filter .a = 1 limit 1)), + }; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(0)', + [], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 4}, + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + {'a': 1, 'b': 7}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(4)', + [4], + ) + await self.assert_query_result( + 'select Bar.a', + [1, 2, 3], + sort=True, + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 2, 'b': 5}, + {'a': 3, 'b': 6}, + {'a': 1, 'b': 7}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(5)', + [4, 5], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1, 3], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 3, 'b': 6}, + {'a': 1, 'b': 7}, + ], + ) + await reset_data() + await self.assert_query_result( + 'select foo(6)', + [4, 5, 6], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [1], + ) + await self.assert_query_result( + 'select Baz{a := .bar.a, b} order by .b', + [ + {'a': 1, 'b': 7}, + ], + ) + + async def test_edgeql_functions_inline_delete_nested_01(self): + await self.con.execute(''' + create type Bar { + create required property a -> int64; + }; + create function inner(x: int64) -> set of Bar { + set is_inlined := true; + using ((delete Bar filter .a <= x)); + }; + create function foo(x: int64) -> set of Bar { + set is_inlined := true; + using (inner(x)); + }; + ''') + + async def reset_data(): + await self.con.execute(''' + delete Bar; + insert 
Bar{a := 1}; + insert Bar{a := 2}; + insert Bar{a := 3}; + ''') + + await reset_data() + await self.assert_query_result( + 'select foo(1).a', + [1], + ) + await self.assert_query_result( + 'select Bar.a', + [2, 3], + sort=True, + ) + await reset_data() + await self.assert_query_result( + 'select foo(2).a', + [1, 2], + sort=True, + ) + await self.assert_query_result( + 'select Bar.a', + [3], + ) diff --git a/tests/test_edgeql_group.py b/tests/test_edgeql_group.py index 8e8a88168a8..90bbd714536 100644 --- a/tests/test_edgeql_group.py +++ b/tests/test_edgeql_group.py @@ -1499,14 +1499,7 @@ async def test_edgeql_group_destruct_immediately_02(self): ["element", "element", "element", "element"], ) - @test.xerror(""" - Issue #5796 - - Materialized set not finalized. - """) async def test_edgeql_group_issue_5796(self): - # Fails on assert mat_set.materialized . - # Depends on double select and deck in shape await self.assert_query_result( r''' with @@ -1528,7 +1521,6 @@ async def test_edgeql_group_issue_5796(self): ]), ) - @test.xerror("Issue #6059") async def test_edgeql_group_issue_6059(self): await self.assert_query_result( r''' @@ -1547,7 +1539,6 @@ async def test_edgeql_group_issue_6059(self): [{"keyCard": {}}] * 4, ) - @test.xerror("Issue #6060") async def test_edgeql_group_issue_6060(self): await self.assert_query_result( r''' @@ -1609,13 +1600,10 @@ async def test_edgeql_group_issue_6019_a(self): ) by .key ''') - @test.xerror(""" - Issue #6019 - - Grouping on key should probably be rejected. - (And if not, it should not ISE!) - """) async def test_edgeql_group_issue_6019_b(self): + # This didn't work because group created free objects which were then + # materialized as volatile. `group (group X by .x) by .key` has a + # different cause. 
await self.assert_query_result( ''' with diff --git a/tests/test_edgeql_insert.py b/tests/test_edgeql_insert.py index 2a6ea609752..24083714e4c 100644 --- a/tests/test_edgeql_insert.py +++ b/tests/test_edgeql_insert.py @@ -5820,35 +5820,1024 @@ async def test_edgeql_insert_cardinality_assertion(self): } ''') - @tb.needs_factoring_weakly async def test_edgeql_insert_volatile_01(self): - # Ideally we'll support these versions eventually - async with self.assertRaisesRegexTx( - edgedb.QueryError, - "cannot refer to volatile WITH bindings from DML"): - await self.con.execute(''' - WITH name := random(), - INSERT Person { name := name, tag := name }; - ''') + await self.con.execute(''' + WITH name := random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_02(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ "!", + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_03(self): + await self.con.execute(''' + WITH + x := "!", + name := x ++ random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_04(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ random(), + INSERT Person { name := name, tag := name }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT 
all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_05(self): + await self.con.execute(''' + WITH name := random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_06(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ "!", + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_07(self): + await self.con.execute(''' + WITH + x := "!", + name := x ++ random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_08(self): + await self.con.execute(''' + WITH + x := random(), + name := x ++ random(), + SELECT (INSERT Person { name := name, tag := name }); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_09(self): + await self.con.execute(''' + WITH x := random() + SELECT ( + WITH name := x ++ "!" 
+ INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_10(self): + await self.con.execute(''' + WITH x := "!" + SELECT ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_11(self): + await self.con.execute(''' + WITH x := random() + SELECT ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_12(self): + await self.con.execute(''' + WITH + x := random(), + y := x ++ random(), + SELECT ( + WITH name := y ++ random() + INSERT Person { name := name, tag := name } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], + ) + + async def test_edgeql_insert_volatile_13(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ) + SELECT ( + INSERT Person { + name := x.name ++ "!", + tag := x.tag ++ "!", + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( 
+ 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_14(self): + await self.con.execute(''' + WITH + x := "!", + y := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_15(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_16(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + SELECT ( + INSERT Person { + name := x ++ y.name, + tag := x ++ y.tag, + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_17(self): + await self.con.execute(''' + WITH + x := "!", + y := ( + WITH name := x ++ random(), + INSERT Person { name := name, 
tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_18(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := x ++ "!", + INSERT Person { name := name, tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_19(self): + await self.con.execute(''' + WITH + x := random(), + y := ( + WITH name := x ++ random(), + INSERT Person { name := name, tag := name, tag2 := x } + ), + SELECT ( + INSERT Person { + name := y.name ++ "!", + tag := y.tag ++ "!", + tag2 := y.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_20(self): + await self.con.execute(''' + WITH + x := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := random(), + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = 
.tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_21(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := "!", + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_22(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := random(), + SELECT ( + INSERT Person { + name := x.name ++ y, + tag := x.tag ++ y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_23(self): + await self.con.execute(''' + WITH + x := ( + WITH name := "!", + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ random(), + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def 
test_edgeql_insert_volatile_24(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ "!", + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_25(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { name := name, tag := name, tag2 := name } + ), + y := x.name ++ random(), + SELECT ( + INSERT Person { + name := y, + tag := y, + tag2 := x.tag2, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) - async with self.assertRaisesRegexTx( - edgedb.QueryError, - "cannot refer to volatile WITH bindings from DML"): - await self.con.execute(''' - WITH name := random(), - SELECT (INSERT Person { name := name, tag := name }); - ''') + async def test_edgeql_insert_volatile_26(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { + name := name, + tag := name, + tag2 := name, + } + ), + y := ( + WITH r := random(), + INSERT Person { + name := x.name ++ r, + tag := x.tag ++ r, + tag2 := x.tag, + } + ), + SELECT ( + WITH r := random(), + INSERT Person { + name := y.name ++ r, + tag := y.name ++ r, + tag2 := y.tag ++ r, + } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT 
count(distinct(Person.name))', + [3], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + async def test_edgeql_insert_volatile_27(self): + await self.con.execute(''' + WITH x := "!" + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'WITH N := (Note {ok := .name = .note}) SELECT all(N.ok)', + [True], + ) + + async def test_edgeql_insert_volatile_28(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'WITH N := (Note {ok := .name = .note}) SELECT all(N.ok)', + [True], + ) + + async def test_edgeql_insert_volatile_29(self): + await self.con.execute(''' + WITH x := "!", + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'WITH N := (Note {ok := .name = .note}) SELECT all(N.ok)', + [True], + ) + + async def test_edgeql_insert_volatile_30(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ "!" 
+ insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'WITH N := (Note {ok := .name = .note}) SELECT all(N.ok)', + [True], + ) + + async def test_edgeql_insert_volatile_31(self): + await self.con.execute(''' + WITH x := random(), + INSERT Person { + name := x, + tag := x, + note := ( + WITH y := x ++ random() + insert Note { name := y, note := y } + ) + }; + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'WITH N := (Note {ok := .name = .note}) SELECT all(N.ok)', + [True], + ) + async def test_edgeql_insert_volatile_32(self): await self.con.execute(''' - FOR name in {random()} + FOR name in {random(), random()} UNION (INSERT Person { name := name, tag := name }); ''') await self.assert_query_result( - r''' - SELECT all(Person.name = Person.tag) - ''', - [True] + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + + async def test_edgeql_insert_volatile_33(self): + await self.con.execute(''' + WITH x := "!" 
+ FOR y in {random(), random()} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_34(self): + await self.con.execute(''' + WITH x := random() + FOR y in {"A", "B"} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_35(self): + await self.con.execute(''' + WITH x := random() + FOR y in {random(), random()} + UNION ( + WITH name := x ++ y + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_36(self): + await self.con.execute(''' + WITH x := "!" 
+ FOR name in {x ++ random(), x ++ random()} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_37(self): + await self.con.execute(''' + WITH x := random() + FOR name in {x ++ "A", x ++ "B"} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_38(self): + await self.con.execute(''' + WITH x := random() + FOR name in {x ++ random(), x ++ random()} + UNION ( + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_volatile_39(self): + await self.con.execute(''' + FOR x in {"A", "B"} + UNION ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + async def test_edgeql_insert_volatile_40(self): + await self.con.execute(''' + FOR 
x in {random(), random()} + UNION ( + WITH name := x ++ "!" + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + async def test_edgeql_insert_volatile_41(self): + await self.con.execute(''' + FOR x in {random(), random()} + UNION ( + WITH name := x ++ random() + INSERT Person { name := name, tag := name, tag2 := x } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [2], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [2], + ) + + async def test_edgeql_insert_volatile_42(self): + await self.con.execute(''' + WITH + x := ( + WITH name := random(), + INSERT Person { + name := name, + tag := name, + tag2 := name, + } + ) + FOR y in {random(), random()} + UNION ( + WITH name := x.name ++ y + INSERT Person { name := name, tag := name, tag2 := x.tag2 } + ); + ''') + + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [3], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.tag2))', + [1], + ) + + async def test_edgeql_insert_with_freeobject_01(self): + await self.con.execute(''' + WITH free := { name := "asdf" }, + SELECT (INSERT Person { name := free.name }); + ''') + + await self.assert_query_result( + 'SELECT Person.name = "asdf"', + [True], + ) + + async def test_edgeql_insert_with_freeobject_02(self): + await self.con.execute(''' + WITH free := { name := random() }, + SELECT (INSERT Person { name := free.name, tag := free.name }); + ''') 
+ + await self.assert_query_result( + 'WITH P := (Person {ok := .name = .tag}) SELECT all(P.ok)', + [True], + ) + await self.assert_query_result( + 'SELECT count(distinct(Person.name))', + [1], ) async def test_edgeql_insert_multi_exclusive_01(self): diff --git a/tests/test_edgeql_ir_volatility_inference.py b/tests/test_edgeql_ir_volatility_inference.py index 4d56138a2a0..2cc86e8a44d 100644 --- a/tests/test_edgeql_ir_volatility_inference.py +++ b/tests/test_edgeql_ir_volatility_inference.py @@ -156,3 +156,38 @@ def test_edgeql_ir_volatility_inference_11(self): % OK % Volatile """ + + def test_edgeql_ir_volatility_inference_12(self): + """ + select AliasOne +% OK % + Immutable + """ + + def test_edgeql_ir_volatility_inference_13(self): + """ + select global GlobalOne +% OK % + Stable + """ + + def test_edgeql_ir_volatility_inference_14(self): + """ + select AirCard +% OK % + Stable + """ + + def test_edgeql_ir_volatility_inference_15(self): + """ + select global HighestCost +% OK % + Stable + """ + + def test_edgeql_ir_volatility_inference_16(self): + """ + select global CardsWithText +% OK % + Stable + """ diff --git a/tests/test_edgeql_sys.py b/tests/test_edgeql_sys.py index 08921fb52db..db3fbb5ddfa 100644 --- a/tests/test_edgeql_sys.py +++ b/tests/test_edgeql_sys.py @@ -17,12 +17,77 @@ # +import asyncpg import edgedb +from edb.pgsql import common + from edb.testbase import server as tb -class TestEdgeQLSys(tb.QueryTestCase): +class TestQueryStatsMixin: + stats_magic_word: str = NotImplemented + stats_type: str = NotImplemented + + async def _query_for_stats(self): + raise NotImplementedError + + async def _bad_query_for_stats(self): + raise NotImplementedError + + async def _test_sys_query_stats(self): + stats_query = f''' + with stats := ( + select + sys::QueryStats + filter + .query like '%{self.stats_magic_word}%' + and .query not like '%sys::%' + and .query_type = $0 + ) + select sum(stats.calls) + ''' + calls = await self.con.query_single(stats_query, 
self.stats_type) + + await self._query_for_stats() + self.assertEqual( + await self.con.query_single(stats_query, self.stats_type), + calls + 1, + ) + + await self._bad_query_for_stats() + self.assertEqual( + await self.con.query_single(stats_query, self.stats_type), + calls + 1, + ) + + self.assertIsNone( + await self.con.query_single( + "select sys::reset_query_stats(branch_name := 'non_exdb')" + ) + ) + self.assertEqual( + await self.con.query_single(stats_query, self.stats_type), + calls + 1, + ) + + self.assertIsNotNone( + await self.con.query('select sys::reset_query_stats()') + ) + self.assertEqual( + await self.con.query_single(stats_query, self.stats_type), + 0, + ) + + +class TestEdgeQLSys(tb.QueryTestCase, TestQueryStatsMixin): + stats_magic_word = 'TestEdgeQLSys' + stats_type = 'EdgeQL' + SETUP = f''' + create type {stats_magic_word} {{ + create property bar -> str; + }}; + ''' async def test_edgeql_sys_locks(self): lock_key = tb.gen_lock_key() @@ -59,3 +124,42 @@ async def test_edgeql_sys_locks(self): 'select sys::_advisory_unlock($0)', lock_key), [False]) + + async def _query_for_stats(self): + self.assertEqual( + await self.con.query(f'select {self.stats_magic_word}'), + [], + ) + + async def _bad_query_for_stats(self): + async with self.assertRaisesRegexTx( + edgedb.InvalidReferenceError, 'does not exist' + ): + await self.con.query(f'select {self.stats_magic_word}_NoSuchType') + + async def test_edgeql_sys_query_stats(self): + await self._test_sys_query_stats() + + +class TestSQLSys(tb.SQLQueryTestCase, TestQueryStatsMixin): + stats_magic_word = 'TestSQLSys' + stats_type = 'SQL' + + async def _query_for_stats(self): + self.assertEqual( + await self.squery_values( + f"select {common.quote_literal(self.stats_magic_word)}" + ), + [[self.stats_magic_word]], + ) + + async def _bad_query_for_stats(self): + with self.assertRaisesRegex( + asyncpg.InvalidColumnReferenceError, "cannot find column" + ): + await self.squery_values( + f'select 
{self.stats_magic_word}_NoSuchType' + ) + + async def test_sql_sys_query_stats(self): + await self._test_sys_query_stats() diff --git a/tests/test_http.py b/tests/test_http.py index f1438d48dd6..efbfdca275a 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -22,6 +22,7 @@ from edb.server import http from edb.testbase import http as tb +from edb.tools.test import async_timeout class HttpTest(tb.BaseHttpTest): @@ -37,8 +38,9 @@ def tearDown(self): self.mock_server = None super().tearDown() + @async_timeout(timeout=5) async def test_get(self): - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: example_request = ( 'GET', self.base_url, @@ -61,8 +63,9 @@ async def test_get(self): self.assertEqual(result.status_code, 200) self.assertEqual(result.json(), {"message": "Hello, world!"}) + @async_timeout(timeout=5) async def test_post(self): - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: example_request = ( 'POST', self.base_url, @@ -85,8 +88,9 @@ async def test_post(self): result.json(), {"message": f"Hello, world! 
{random_data}"} ) + @async_timeout(timeout=5) async def test_post_with_headers(self): - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: example_request = ( 'POST', self.base_url, @@ -112,19 +116,23 @@ async def test_post_with_headers(self): ) self.assertEqual(result.headers["X-Test"], "test!") + @async_timeout(timeout=5) async def test_bad_url(self): - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: with self.assertRaisesRegex(Exception, "Scheme"): await client.get("httpx://uh-oh") + @async_timeout(timeout=5) async def test_immediate_connection_drop(self): """Test handling of a connection that is dropped immediately by the server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -133,23 +141,47 @@ async def mock_drop_server( url = f'http://{addr[0]}:{addr[1]}/drop' try: - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.get(url) finally: server.close() await server.wait_closed() + @async_timeout(timeout=5) + async def test_streaming_get_with_no_sse(self): + async with http.HttpClient(100) as client: + example_request = ( + 'GET', + self.base_url, + '/test-get-with-sse', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + "\"ok\"", + 200, + ) + ) + result = await client.stream_sse(url, method="GET") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.json(), "ok") + + +class 
HttpSSETest(tb.BaseHttpTest): + @async_timeout(timeout=5) async def test_immediate_connection_drop_streaming(self): """Test handling of a connection that is dropped immediately by the server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -158,34 +190,17 @@ async def mock_drop_server( url = f'http://{addr[0]}:{addr[1]}/drop' try: - with http.HttpClient(100) as client: + async with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.stream_sse(url) finally: server.close() await server.wait_closed() - async def test_streaming_get_with_no_sse(self): - with http.HttpClient(100) as client: - example_request = ( - 'GET', - self.base_url, - '/test-get-with-sse', - ) - url = f"{example_request[1]}{example_request[2]}" - self.mock_server.register_route_handler(*example_request)( - lambda _handler, request: ( - "\"ok\"", - 200, - ) - ) - result = await client.stream_sse(url, method="GET") - self.assertEqual(result.status_code, 200) - self.assertEqual(result.json(), "ok") - - async def test_sse_with_mock_server(self): + @async_timeout(timeout=5) + async def test_sse_with_mock_server_client_close(self): """Since the regular mock server doesn't support SSE, we need to test with a real socket. 
We handle just enough HTTP to get the job done.""" @@ -233,26 +248,103 @@ async def mock_sse_server( url = f'http://{addr[0]}:{addr[1]}/sse' async def client_task(): - with http.HttpClient(100) as client: - response = await client.stream_sse(url, method="GET") - assert response.status_code == 200 - assert response.headers['Content-Type'] == 'text/event-stream' - assert isinstance(response, http.ResponseSSE) - - events = [] - async for event in response: - self.assertEqual(event.event, 'message') - events.append(event) - if len(events) == 3: - break - - assert len(events) == 3 - assert events[0].data == 'Event 1' - assert events[1].data == 'Event 2' - assert events[2].data == 'Event 3' + async with http.HttpClient(100) as client: + async with await client.stream_sse( + url, method="GET" + ) as response: + assert response.status_code == 200 + assert ( + response.headers['Content-Type'] == 'text/event-stream' + ) + assert isinstance(response, http.ResponseSSE) + + events = [] + async for event in response: + self.assertEqual(event.event, 'message') + events.append(event) + if len(events) == 3: + break + + assert len(events) == 3 + assert events[0].data == 'Event 1' + assert events[1].data == 'Event 2' + assert events[2].data == 'Event 3' async with server: client_future = asyncio.create_task(client_task()) await asyncio.wait_for(client_future, timeout=5.0) assert is_closed + + @async_timeout(timeout=5) + async def test_sse_with_mock_server_close(self): + """Try to close the server-side stream and see if the client detects + an end for the iterator. 
Note that this is technically not correct SSE: + the client should actually try to reconnect after the specified retry + interval, _but_ we don't handle retries yet.""" + + is_closed = False + + async def mock_sse_server( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + nonlocal is_closed + + # Read until empty line + while True: + line = await reader.readline() + if line == b'\r\n': + break + + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Cache-Control: no-cache\r\n\r\n" + ) + writer.write(headers) + await writer.drain() + + for i in range(3): + writer.write(b": test comment that should be ignored\n\n") + await writer.drain() + + writer.write( + f"event: message\ndata: Event {i + 1}\n\n".encode() + ) + await writer.drain() + await asyncio.sleep(0.1) + + await writer.drain() + writer.close() + is_closed = True + + server = await asyncio.start_server(mock_sse_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/sse' + + async def client_task(): + async with http.HttpClient(100) as client: + async with await client.stream_sse( + url, method="GET", headers={"Connection": "close"} + ) as response: + assert response.status_code == 200 + assert ( + response.headers['Content-Type'] == 'text/event-stream' + ) + assert isinstance(response, http.ResponseSSE) + + events = [] + async for event in response: + self.assertEqual(event.event, 'message') + events.append(event) + + assert len(events) == 3 + assert events[0].data == 'Event 1' + assert events[1].data == 'Event 2' + assert events[2].data == 'Event 3' + + client_future = asyncio.create_task(client_task()) + async with server: + client_future = asyncio.create_task(client_task()) + await asyncio.wait_for(client_future, timeout=5.0) + assert is_closed diff --git a/tests/test_http_auth.py b/tests/test_http_auth.py index e66fc81ba0a..4b173979ba1 100644 --- a/tests/test_http_auth.py +++ b/tests/test_http_auth.py @@ -265,6 
+265,7 @@ def test_http_binary_proto_too_old(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text="SELECT 42", + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, @@ -308,6 +309,7 @@ def test_http_binary_proto_old_supported(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text="SELECT 42", + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index eaefccde7e6..4827fdffdc0 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -31,6 +31,7 @@ from typing import Any, Optional, cast from jwcrypto import jwt, jwk +from email.message import EmailMessage from edgedb import QueryAssertionError from edb.testbase import http as tb @@ -235,7 +236,7 @@ def utcnow(): DISCORD_SECRET = 'd' * 32 SLACK_SECRET = 'd' * 32 GENERIC_OIDC_SECRET = 'e' * 32 -APP_NAME = "Test App" +APP_NAME = "Test App" * 13 LOGO_URL = "http://example.com/logo.png" DARK_LOGO_URL = "http://example.com/darklogo.png" BRAND_COLOR = "f0f8ff" @@ -248,6 +249,14 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): SETUP = [ f""" + CONFIGURE CURRENT DATABASE INSERT cfg::SMTPProviderConfig {{ + name := "email_hosting_is_easy", + sender := "{SENDER}", + }}; + + CONFIGURE CURRENT DATABASE SET + cfg::current_email_provider_name := "email_hosting_is_easy"; + CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::auth_signing_key := '{SIGNING_KEY}'; @@ -272,9 +281,6 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): redirect_to_on_signup := 'https://example.com/signup/app', }}; - CONFIGURE CURRENT DATABASE SET - ext::auth::SMTPConfig::sender := '{SENDER}'; - CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::allowed_redirect_urls := {{ 
'https://example.com/app' @@ -3203,10 +3209,18 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], form_data["email"]) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( - r'

([^<]+)', html_email + r'

([^<]+)', + html_email, ) assert match is not None verify_url = urllib.parse.urlparse(match.group(1)) @@ -3382,8 +3396,11 @@ async def test_http_auth_ext_local_webauthn_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3474,6 +3491,7 @@ async def test_http_auth_ext_token_01(self): challenge := $challenge, auth_token := $auth_token, refresh_token := $refresh_token, + id_token := $id_token, identity := ( insert ext::auth::Identity { issuer := "https://example.com", @@ -3486,12 +3504,14 @@ async def test_http_auth_ext_token_01(self): challenge, auth_token, refresh_token, + id_token, identity_id := .identity.id } """, challenge=challenge.decode(), auth_token="a_provider_token", refresh_token="a_refresh_token", + id_token="an_id_token", ) # Correct code, random verifier @@ -3530,6 +3550,7 @@ async def test_http_auth_ext_token_01(self): "identity_id": str(pkce.identity_id), "provider_token": "a_provider_token", "provider_refresh_token": "a_refresh_token", + "provider_id_token": "an_id_token", }, ) async for tr in self.try_until_succeeds( @@ -3675,8 +3696,11 @@ async def test_http_auth_ext_local_password_forgot_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3862,8 +3886,11 @@ async def test_http_auth_ext_local_password_reset_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -4031,7 +4058,7 @@ async def test_http_auth_ext_ui_signin(self): body_str = body.decode() - self.assertIn(APP_NAME, body_str) + self.assertIn(f"{APP_NAME[:100]}...", body_str) self.assertIn(LOGO_URL, body_str) self.assertIn(BRAND_COLOR, body_str) @@ -4065,7 +4092,7 @@ async def test_http_auth_ext_webauthn_register_options(self): self.assertIsInstance(body_json["rp"], dict) self.assertIn("name", body_json["rp"]) - self.assertEqual(body_json["rp"]["name"], APP_NAME) + self.assertEqual(body_json["rp"]["name"], f"{APP_NAME[:100]}...") self.assertIn("id", body_json["rp"]) self.assertEqual(body_json["rp"]["id"], "example.com") @@ -4349,8 +4376,11 @@ async def test_http_auth_ext_magic_link_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 13211593534..2de6c650c4d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -46,6 +46,7 @@ async def _execute( compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=command_text, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, @@ -150,6 +151,7 @@ async def test_proto_flush_01(self): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.BINARY, expected_cardinality=compiler.Cardinality.AT_MOST_ONE, command_text='SEL ECT 1', @@ -174,6 +176,7 @@ async def test_proto_flush_01(self): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.BINARY, expected_cardinality=compiler.Cardinality.AT_MOST_ONE, command_text='SELECT 1', @@ -425,6 +428,7 @@ async def _parse(self, query, output_format=protocol.OutputFormat.BINARY): allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, + input_language=protocol.InputLanguage.EDGEQL, output_format=output_format, expected_cardinality=compiler.Cardinality.MANY, command_text=query, @@ -540,6 +544,7 @@ async def _parse_execute(self, query, args): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=query, + input_language=protocol.InputLanguage.EDGEQL, output_format=output_format, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=res.input_typedesc_id, @@ -846,6 +851,7 @@ async def test_proto_connection_lost_cancel_query(self): UPDATE tclcq SET { p := 'inner' }; SELECT sys::_sleep(10); """, + 
input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, @@ -914,6 +920,7 @@ async def test_proto_gh3170_connection_lost_error(self): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text='START TRANSACTION', + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index 7e1cf5643fe..0c48bf40c33 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -406,6 +406,12 @@ async def test_server_auth_jwt_1(self): ''') await conn.aclose() + with self.assertRaisesRegex( + edgedb.AuthenticationError, + 'authentication failed: no authorization data provided', + ): + await sd.connect() + # bad secret keys with self.assertRaisesRegex( edgedb.AuthenticationError, diff --git a/tests/test_server_config.py b/tests/test_server_config.py index fa6c3f7af65..3f8f8f25591 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -1972,6 +1972,7 @@ async def test_server_config_idle_transaction(self): messages.Execute( annotations=[], command_text=query, + input_language=messages.InputLanguage.EDGEQL, output_format=messages.OutputFormat.NONE, expected_cardinality=messages.Cardinality.MANY, allowed_capabilities=messages.Capability.ALL, diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 0a904be368a..d8676304dde 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -42,6 +42,8 @@ import edgedb from edgedb import errors +import edb +from edb import buildmeta from edb import protocol from edb.common import devmode from edb.protocol import protocol as edb_protocol # type: ignore @@ -684,6 +686,7 @@ async def _test_connection(self, con): compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, 
command_text='SELECT 1', + input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.NONE, expected_cardinality=protocol.Cardinality.MANY, input_typedesc_id=b'\0' * 16, @@ -1645,9 +1648,6 @@ async def _test_server_ops_global_compile_cache( insert ext::auth::EmailPasswordProviderConfig {{ require_verification := false, }}; - - configure current database set - ext::auth::SMTPConfig::sender := 'noreply@example.com'; ''') finally: await conn.aclose() @@ -1700,3 +1700,39 @@ def reload_server(self): json.dump(self.conf, self.conf_file.file) self.conf_file.file.flush() self.srv.proc.send_signal(signal.SIGHUP) + + +class TestPGExtensions(tb.TestCase): + async def test_edb_stat_statements(self): + ext_home = ( + pathlib.Path(edb.__file__).parent.parent / 'edb_stat_statements' + ).resolve() + if not ext_home.exists(): + raise unittest.SkipTest("no source of edb_stat_statements") + with tempfile.TemporaryDirectory() as td: + cluster = await pgcluster.get_local_pg_cluster(td, log_level='s') + cluster.update_connection_params( + user='postgres', + database='template1', + ) + self.assertTrue(await cluster.ensure_initialized()) + await cluster.start(server_settings={ + 'edb_stat_statements.track_planning': 'false', + 'edb_stat_statements.track_unrecognized': 'true', + 'max_prepared_transactions': '5', + }) + try: + pg_config = buildmeta.get_pg_config_path() + env = os.environ.copy() + params = cluster.get_pgaddr() + env['PGHOST'] = params.host + env['PGPORT'] = params.port + env['PGUSER'] = params.user + env['PGDATABASE'] = params.database + subprocess.check_output([ + 'make', + f'PG_CONFIG={pg_config}', + 'installcheck', + ], cwd=str(ext_home), env=env) + finally: + await cluster.stop() diff --git a/tests/test_sql_dml.py b/tests/test_sql_dml.py index 66d2e9656ad..f93c731f450 100644 --- a/tests/test_sql_dml.py +++ b/tests/test_sql_dml.py @@ -1069,7 +1069,7 @@ async def test_sql_dml_delete_05(self): # delete where current of with self.assertRaisesRegex( 
asyncpg.FeatureNotSupportedError, - 'unsupported SQL feature `CurrentOfExpr`', + 'not supported: CURRENT OF', ): await self.scon.execute( ''' @@ -1480,7 +1480,7 @@ async def test_sql_dml_update_05(self): # update where current of with self.assertRaisesRegex( asyncpg.FeatureNotSupportedError, - 'unsupported SQL feature `CurrentOfExpr`', + 'not supported: CURRENT OF', ): await self.scon.execute( ''' diff --git a/tests/test_sql_query.py b/tests/test_sql_query.py index a556a90e31a..6c31ae4f8a5 100644 --- a/tests/test_sql_query.py +++ b/tests/test_sql_query.py @@ -26,6 +26,8 @@ from edb.tools import test from edb.testbase import server as tb +import edgedb + try: import asyncpg from asyncpg import serverversion @@ -225,8 +227,8 @@ async def test_sql_query_11(self): 'genre_id', 'release_year', 'title', - 'id', - '__type__', + 'g_id', + 'g___type__', 'name', ], ) @@ -775,7 +777,10 @@ async def test_sql_query_42(self): # params out of order res = await self.squery_values( - 'SELECT $2::int, $3::bool, $1::text', 'hello', 42, True, + 'SELECT $2::int, $3::bool, $1::text', + 'hello', + 42, + True, ) self.assertEqual(res, [[42, True, 'hello']]) @@ -860,6 +865,84 @@ async def test_sql_query_43(self): ) self.assertEqual(res, [[1, 1, 1], [2, None, 2], [None, 3, 3]]) + async def test_sql_query_44(self): + # range function that is an "sql value function", whatever that is + + # to be exact: User is *parsed* as function call CURRENT_USER + # we'd ideally want a message that hints that it should use quotes + + with self.assertRaisesRegex( + asyncpg.InvalidColumnReferenceError, 'cannot find column `name`' + ): + await self.squery_values('SELECT name FROM User') + + async def test_sql_query_45(self): + res = await self.scon.fetch('SELECT 1 AS a, 2 AS a') + self.assert_shape(res, 1, ['a', 'a']) + + async def test_sql_query_46(self): + res = await self.scon.fetch( + ''' + WITH + x(a) AS (VALUES (1)), + y(a) AS (VALUES (2)), + z(a) AS (VALUES (3)) + SELECT * FROM x, y JOIN z u on TRUE + 
    async def test_sql_query_50(self):
        res = await self.scon.fetch(
            '''
            WITH
              x(a) AS (VALUES (2))
            SELECT 1 as a, * FROM x
            '''
        )

        # x's column `a` collides with the explicit output column `a`,
        # so it is disambiguated by prefixing its rel var name.
        self.assert_shape(res, 1, ['a', 'x_a'])
    async def test_sql_query_unsupported_01(self):
        # Test the error messages of unsupported queries.

        # We do build an AST for this one, but throw in the resolver.
        with self.assertRaisesRegex(
            asyncpg.FeatureNotSupportedError,
            "not supported: CREATE",
            # position="14",  # TODO: this is confusing
        ):
            await self.squery_values('CREATE TABLE a();')

        # We don't even have an AST node for this.
        with self.assertRaisesRegex(
            asyncpg.FeatureNotSupportedError,
            "not supported: ALTER TABLE",
        ):
            await self.squery_values('ALTER TABLE a ADD COLUMN b INT;')

        with self.assertRaisesRegex(
            asyncpg.FeatureNotSupportedError,
            "not supported: REINDEX",
        ):
            await self.squery_values('REINDEX TABLE a;')
to_json('three') AS c, + timestamp '2000-12-16 12:21:13' AS d, + timestamp with time zone '2000-12-16 12:21:13' AS e, + date '0001-01-01 AD' AS f, + interval '2000 years' AS g, + ARRAY[1, 2, 3] AS h, + FALSE AS i + """, + [ + { + "a": 1, + "b": "two", + "c": '"three"', + "d": "2000-12-16T12:21:13", + "e": "2000-12-16T12:21:13+00:00", + "f": "0001-01-01", + "g": edgedb.RelativeDuration(months=2000 * 12), + "h": [1, 2, 3], + "i": False, + } + ], + ) + + async def test_native_sql_query_01(self): + await self.assert_sql_query_result( + """ + SELECT + "Movie".title, + "Genre".name AS genre + FROM + "Movie", + "Genre" + WHERE + "Movie".genre_id = "Genre".id + AND "Genre".name = 'Drama' + ORDER BY + title + """, + [ + { + "title": "Forrest Gump", + "genre": "Drama", + }, + { + "title": "Saving Private Ryan", + "genre": "Drama", + }, + ], + ) + + async def test_native_sql_query_02(self): + await self.assert_sql_query_result( + """ + SELECT + "Movie".title, + "Genre".name AS genre + FROM + "Movie", + "Genre" + WHERE + "Movie".genre_id = "Genre".id + AND "Genre".name = $1::text + AND length("Movie".title) > $2::int + ORDER BY + title + """, + [ + { + "title": "Saving Private Ryan", + "genre": "Drama", + } + ], + variables={ + "0": "Drama", + "1": 14, + }, + ) + + async def test_native_sql_query_03(self): + # No output at all + await self.assert_sql_query_result( + """ + SELECT + WHERE NULL + """, + [], + ) + + # Empty tuples + await self.assert_sql_query_result( + """ + SELECT + FROM "Movie" + LIMIT 1 + """, + [{}], + ) + + async def test_native_sql_query_04(self): + with self.assertRaisesRegex( + edgedb.errors.QueryError, + 'duplicate column name: `a`', + _position=16, + ): + await self.assert_sql_query_result('SELECT 1 AS a, 2 AS a', []) + + async def test_native_sql_query_05(self): + # `a` would be duplicated, + # so second and third instance are prefixed with rel var name + await self.assert_sql_query_result( + ''' + WITH + x(a) AS (VALUES (1::int)), + y(a) AS (VALUES 
    async def test_native_sql_query_10(self):
        await self.assert_sql_query_result(
            '''
            WITH
              x(b, c) AS (VALUES (2, 3))
            SELECT 1 as a, * FROM x
            ''',
            # Columns map positionally: the explicit `a` first, then
            # x's `b` and `c` from the star expansion.
            [{'a': 1, 'b': 2, 'c': 3}],
        )
ON TRUE fails, saying it expects bool, but it got an int + await self.assert_sql_query_result( + ''' + WITH + x(a) AS (VALUES (1)), + y(b) AS (VALUES (2)), + z(c) AS (VALUES (3)) + SELECT * FROM x, y JOIN z ON TRUE + ''', + [{'a': 1, 'b': 2, 'c': 3}], + ) + + async def test_native_sql_query_12(self): + await self.assert_sql_query_result( + ''' + WITH + x(a) AS (VALUES (1), (5)), + y(b) AS (VALUES (2), (3)) + SELECT * FROM x, y + ''', + [ + {'a': 1, 'b': 2}, + {'a': 1, 'b': 3}, + {'a': 5, 'b': 2}, + {'a': 5, 'b': 3}, + ], + )