From 007d8822413c356f888e55c0fcf6bb5c34a1286f Mon Sep 17 00:00:00 2001 From: Jelte Fennema-Nio Date: Mon, 23 Sep 2024 09:06:06 +0200 Subject: [PATCH] Support extended protocol and prepared statements (#147) Previously prepared statements and the extended protocol were broken. We need to support this feature because pretty much every client library uses it. The reason that prepared statements did not work is that for prepared statements to work, a Postgres plan needs to be fully copyable using Postgres functionality. We were previously storing the DuckDB plan in the Postgres plan, but since Postgres did not know how to copy the DuckDB plan it would fail. To work around that problem we now store the full Postgres `Query` structure inside the DuckDB CustomScan node. Then during execution we use this `Query` structure to re-plan the query during execution. This works but it has the big downside that we're now planning the query twice in DuckDB, once during planning to get the the column types of the result and once during execution. For now that seems acceptable, so we can at least support prepared statements and the extended protocol. In the future we might want to do something smarter though, like serializing and deserializing the DuckDB plan that is created at planning time. Note: This does not support all our supported types as arguments to a prepared statement yet. After merging I'll open separate issue to track adding the rest of these types. Fixes #118 --- include/pgduckdb/pgduckdb_planner.hpp | 6 ++ include/pgduckdb/pgduckdb_types.hpp | 1 + src/pgduckdb_hooks.cpp | 11 ++++ src/pgduckdb_node.cpp | 59 ++++++++++++++--- src/pgduckdb_planner.cpp | 56 ++++++++++------ src/pgduckdb_types.cpp | 44 +++++++++++++ test/pycheck/explain_test.py | 16 ++++- test/pycheck/prepared_test.py | 95 +++++++++++++++++++++++++++ test/pycheck/utils.py | 12 ++-- test/regression/expected/basic.out | 22 +++++++ test/regression/sql/basic.sql | 3 + 11 files changed, 290 insertions(+), 35 deletions(-) create mode 100644 test/pycheck/prepared_test.py diff --git a/include/pgduckdb/pgduckdb_planner.hpp b/include/pgduckdb/pgduckdb_planner.hpp index b95fb65b..1784a46f 100644 --- a/include/pgduckdb/pgduckdb_planner.hpp +++ b/include/pgduckdb/pgduckdb_planner.hpp @@ -1,10 +1,16 @@ #pragma once +#include "duckdb.hpp" + extern "C" { #include "postgres.h" #include "optimizer/planner.h" } +#include "pgduckdb/pgduckdb_duckdb.hpp" + extern bool duckdb_explain_analyze; PlannedStmt *DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params); +std::tuple, duckdb::unique_ptr> +DuckdbPrepare(const Query *query, ParamListInfo bound_params); diff --git a/include/pgduckdb/pgduckdb_types.hpp b/include/pgduckdb/pgduckdb_types.hpp index 2f0d00a3..d444bef3 100644 --- a/include/pgduckdb/pgduckdb_types.hpp +++ b/include/pgduckdb/pgduckdb_types.hpp @@ -19,6 +19,7 @@ constexpr int64_t PGDUCKDB_DUCK_TIMESTAMP_OFFSET = INT64CONST(10957) * USECS_PER duckdb::LogicalType ConvertPostgresToDuckColumnType(Form_pg_attribute &attribute); Oid GetPostgresDuckDBType(duckdb::LogicalType type); +duckdb::Value ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type); void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset); bool ConvertDuckToPostgresValue(TupleTableSlot *slot, duckdb::Value &value, idx_t col); void InsertTupleIntoChunk(duckdb::DataChunk &output, duckdb::shared_ptr scan_global_state, diff --git a/src/pgduckdb_hooks.cpp b/src/pgduckdb_hooks.cpp index 4d74e8ed..aa92a764 100644 --- a/src/pgduckdb_hooks.cpp +++ b/src/pgduckdb_hooks.cpp @@ -158,6 +158,17 @@ extern "C" { void DuckdbExplainOneQueryHook(Query *query, int cursorOptions, IntoClause *into, ExplainState *es, const char *queryString, ParamListInfo params, QueryEnvironment *queryEnv) { + /* + * It might seem sensible to store this data in the custom_private + * field of the CustomScan node, but that's not a trivial change to make. + * Storing this in a global variable works fine, as long as we only use + * this variable during planning when we're actually executing an explain + * QUERY (this can be checked by checking the commandTag of the + * ActivePortal). This even works when plans would normally be cached, + * because EXPLAIN always execute this hook whenever they are executed. + * EXPLAIN queries are also always re-planned (see + * standard_ExplainOneQuery). + */ duckdb_explain_analyze = es->analyze; prev_explain_one_query_hook(query, cursorOptions, into, es, queryString, params, queryEnv); } diff --git a/src/pgduckdb_node.cpp b/src/pgduckdb_node.cpp index daf267ed..998e9a13 100644 --- a/src/pgduckdb_node.cpp +++ b/src/pgduckdb_node.cpp @@ -4,10 +4,15 @@ extern "C" { #include "postgres.h" #include "miscadmin.h" +#include "tcop/pquery.h" +#include "nodes/params.h" +#include "utils/ruleutils.h" } #include "pgduckdb/pgduckdb_node.hpp" #include "pgduckdb/pgduckdb_types.hpp" +#include "pgduckdb/pgduckdb_duckdb.hpp" +#include "pgduckdb/pgduckdb_planner.hpp" /* global variables */ CustomScanMethods duckdb_scan_scan_methods; @@ -17,6 +22,8 @@ static CustomExecMethods duckdb_scan_exec_methods; typedef struct DuckdbScanState { CustomScanState css; /* must be first field */ + const Query *query; + ParamListInfo params; duckdb::Connection *duckdb_connection; duckdb::PreparedStatement *prepared_statement; bool is_executed; @@ -46,10 +53,8 @@ static Node * Duckdb_CreateCustomScanState(CustomScan *cscan) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)newNode(sizeof(DuckdbScanState), T_CustomScanState); CustomScanState *custom_scan_state = &duckdb_scan_state->css; - duckdb_scan_state->duckdb_connection = (duckdb::Connection *)linitial(cscan->custom_private); - duckdb_scan_state->prepared_statement = (duckdb::PreparedStatement *)lsecond(cscan->custom_private); - duckdb_scan_state->is_executed = false; - duckdb_scan_state->fetch_next = true; + + duckdb_scan_state->query = (const Query *)linitial(cscan->custom_private); custom_scan_state->methods = &duckdb_scan_exec_methods; return (Node *)custom_scan_state; } @@ -57,6 +62,19 @@ Duckdb_CreateCustomScanState(CustomScan *cscan) { void Duckdb_BeginCustomScan(CustomScanState *cscanstate, EState *estate, int eflags) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)cscanstate; + auto prepare_result = DuckdbPrepare(duckdb_scan_state->query, estate->es_param_list_info); + auto prepared_query = std::move(std::get<0>(prepare_result)); + auto duckdb_connection = std::move(std::get<1>(prepare_result)); + + if (prepared_query->HasError()) { + elog(ERROR, "DuckDB re-planning failed %s", prepared_query->GetError().c_str()); + } + + duckdb_scan_state->duckdb_connection = duckdb_connection.release(); + duckdb_scan_state->prepared_statement = prepared_query.release(); + duckdb_scan_state->params = estate->es_param_list_info; + duckdb_scan_state->is_executed = false; + duckdb_scan_state->fetch_next = true; duckdb_scan_state->css.ss.ps.ps_ResultTupleDesc = duckdb_scan_state->css.ss.ss_ScanTupleSlot->tts_tupleDescriptor; HOLD_CANCEL_INTERRUPTS(); } @@ -66,8 +84,33 @@ ExecuteQuery(DuckdbScanState *state) { auto &prepared = *state->prepared_statement; auto &query_results = state->query_results; auto &connection = state->duckdb_connection; + auto pg_params = state->params; + const auto num_params = pg_params ? pg_params->numParams : 0; + duckdb::vector duckdb_params; + for (int i = 0; i < num_params; i++) { + ParamExternData *pg_param; + ParamExternData tmp_workspace; + + /* give hook a chance in case parameter is dynamic */ + if (pg_params->paramFetch != NULL) + pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace); + else + pg_param = &pg_params->params[i]; + + if (pg_param->isnull) { + duckdb_params.push_back(duckdb::Value()); + } else { + if (!OidIsValid(pg_param->ptype)) { + elog(ERROR, "parameter with invalid type during execution"); + } + duckdb_params.push_back(pgduckdb::ConvertPostgresParameterToDuckValue(pg_param->value, pg_param->ptype)); + } + } - auto pending = prepared.PendingQuery(); + auto pending = prepared.PendingQuery(duckdb_params, true); + if (pending->HasError()) { + elog(ERROR, "DuckDB execute returned an error: %s", pending->GetError().c_str()); + } duckdb::PendingExecutionResult execution_result; do { execution_result = pending->ExecuteTask(); @@ -160,14 +203,14 @@ Duckdb_ReScanCustomScan(CustomScanState *node) { void Duckdb_ExplainCustomScan(CustomScanState *node, List *ancestors, ExplainState *es) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)node; - auto res = duckdb_scan_state->prepared_statement->Execute(); - std::string explain_output = "\n\n"; - auto chunk = res->Fetch(); + ExecuteQuery(duckdb_scan_state); + auto chunk = duckdb_scan_state->query_results->Fetch(); if (!chunk || chunk->size() == 0) { return; } /* Is it safe to hardcode this as result of DuckDB explain? */ auto value = chunk->GetValue(1, 0); + std::string explain_output = "\n\n"; explain_output += value.GetValue(); explain_output += "\n"; ExplainPropertyText("DuckDB Execution Plan", explain_output.c_str(), es); diff --git a/src/pgduckdb_planner.cpp b/src/pgduckdb_planner.cpp index 68cbfdd5..d6c7fe87 100644 --- a/src/pgduckdb_planner.cpp +++ b/src/pgduckdb_planner.cpp @@ -4,6 +4,8 @@ extern "C" { #include "postgres.h" #include "catalog/pg_type.h" #include "nodes/makefuncs.h" +#include "nodes/nodes.h" +#include "nodes/params.h" #include "optimizer/optimizer.h" #include "tcop/pquery.h" #include "utils/syscache.h" @@ -51,21 +53,44 @@ PlanQuery(Query *parse, ParamListInfo bound_params) { ); } -static Plan * -CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { +std::tuple, duckdb::unique_ptr> +DuckdbPrepare(const Query *query, ParamListInfo bound_params) { + /* + * Copy the query, so the original one is not modified by the + * subquery_planner call that PlanQuery does. + */ + Query *copied_query = (Query *)copyObjectImpl(query); + const char *query_string = pgduckdb_pg_get_querydef(copied_query, false); - List *rtables = query->rtable; + if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) { + if (duckdb_explain_analyze) { + query_string = psprintf("EXPLAIN ANALYZE %s", query_string); + } else { + query_string = psprintf("EXPLAIN %s", query_string); + } + } + + elog(DEBUG2, "(PGDuckDB/DuckdbPrepare) Preparing: %s", query_string); + List *rtables = copied_query->rtable; /* Extract required vars for table */ int flags = PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS | PVC_RECURSE_PLACEHOLDERS; - List *vars = list_concat(pull_var_clause((Node *)query->targetList, flags), - pull_var_clause((Node *)query->jointree->quals, flags)); - - PlannerInfo *query_planner_info = PlanQuery(query, bound_params); + List *vars = list_concat(pull_var_clause((Node *)copied_query->targetList, flags), + pull_var_clause((Node *)copied_query->jointree->quals, flags)); + PlannerInfo *query_planner_info = PlanQuery(copied_query, bound_params); auto duckdb_connection = pgduckdb::DuckdbCreateConnection(rtables, query_planner_info, vars, query_string); auto context = duckdb_connection->context; - auto prepared_query = context->Prepare(query_string); + return {std::move(prepared_query), std::move(duckdb_connection)}; +} + +static Plan * +CreatePlan(Query *query, ParamListInfo bound_params) { + /* + * Prepare the query, se we can get the returned types and column names. + */ + auto prepare_result = DuckdbPrepare(query, bound_params); + auto prepared_query = std::move(std::get<0>(prepare_result)); if (prepared_query->HasError()) { elog(WARNING, "(PGDuckDB/CreatePlan) Prepared query returned an error: '%s", @@ -101,12 +126,12 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { duckdb_node->custom_scan_tlist = lappend(duckdb_node->custom_scan_tlist, - makeTargetEntry((Expr *)var, i + 1, (char *)prepared_query->GetNames()[i].c_str(), false)); + makeTargetEntry((Expr *)var, i + 1, (char *)pstrdup(prepared_query->GetNames()[i].c_str()), false)); ReleaseSysCache(tp); } - duckdb_node->custom_private = list_make2(duckdb_connection.release(), prepared_query.release()); + duckdb_node->custom_private = list_make1(query); duckdb_node->methods = &duckdb_scan_scan_methods; return (Plan *)duckdb_node; @@ -114,17 +139,8 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { PlannedStmt * DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params) { - const char *query_string = pgduckdb_pg_get_querydef(parse, false); - - if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) { - if (duckdb_explain_analyze) { - query_string = psprintf("EXPLAIN ANALYZE %s", query_string); - } else { - query_string = psprintf("EXPLAIN %s", query_string); - } - } /* We need to check can we DuckDB create plan */ - Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, query_string, bound_params)); + Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, bound_params)); if (!duckdb_plan) { return nullptr; diff --git a/src/pgduckdb_types.cpp b/src/pgduckdb_types.cpp index b0f1e6cd..edd433ea 100644 --- a/src/pgduckdb_types.cpp +++ b/src/pgduckdb_types.cpp @@ -9,12 +9,15 @@ extern "C" { #include "miscadmin.h" #include "catalog/pg_type.h" #include "executor/tuptable.h" +#include "utils/builtins.h" #include "utils/numeric.h" #include "utils/uuid.h" #include "utils/array.h" #include "fmgr.h" #include "utils/lsyscache.h" #include "utils/syscache.h" +#include "utils/date.h" +#include "utils/timestamp.h" } #include "pgduckdb/pgduckdb.h" @@ -707,6 +710,47 @@ ConvertDecimal(const NumericVar &numeric) { return (NumericIsNegative(numeric) ? -base_res : base_res); } +/* + * Convert a Postgres Datum to a DuckDB Value. This is meant to be used to + * covert query parameters in a prepared statement to its DuckDB equivalent. + * Passing it a Datum that is stored on disk results in undefined behavior, + * because this fuction makes no effert to detoast the Datum. + */ +duckdb::Value +ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type) { + switch (postgres_type) { + case BOOLOID: + return duckdb::Value::BOOLEAN(DatumGetBool(value)); + case INT2OID: + return duckdb::Value::SMALLINT(DatumGetInt16(value)); + case INT4OID: + return duckdb::Value::INTEGER(DatumGetInt32(value)); + case INT8OID: + return duckdb::Value::BIGINT(DatumGetInt64(value)); + case BPCHAROID: + case TEXTOID: + case JSONOID: + case VARCHAROID: { + // FIXME: TextDatumGetCstring allocates so it needs a + // guard, but it's a macro not a function, so our current gaurd + // template does not handle it. + return duckdb::Value(TextDatumGetCString(value)); + } + case DATEOID: + return duckdb::Value::DATE(duckdb::date_t(DatumGetDateADT(value) + PGDUCKDB_DUCK_DATE_OFFSET)); + case TIMESTAMPOID: + return duckdb::Value::TIMESTAMP(duckdb::timestamp_t(DatumGetTimestamp(value) + PGDUCKDB_DUCK_TIMESTAMP_OFFSET)); + case FLOAT4OID: { + return duckdb::Value::FLOAT(DatumGetFloat4(value)); + } + case FLOAT8OID: { + return duckdb::Value::DOUBLE(DatumGetFloat8(value)); + } + default: + elog(ERROR, "Could not convert Postgres parameter of type: %d to DuckDB type", postgres_type); + } +} + void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset) { auto &type = result.GetType(); diff --git a/test/pycheck/explain_test.py b/test/pycheck/explain_test.py index a1411610..3569ba11 100644 --- a/test/pycheck/explain_test.py +++ b/test/pycheck/explain_test.py @@ -2,7 +2,7 @@ def test_explain(cur: Cursor): - cur.sql("CREATE TABLE test_table (id int primary key, name text)") + cur.sql("CREATE TABLE test_table (id int, name text)") result = cur.sql("EXPLAIN SELECT count(*) FROM test_table") plan = "\n".join(result) assert "UNGROUPED_AGGREGATE" in plan @@ -13,3 +13,17 @@ def test_explain(cur: Cursor): assert "Query Profiling Information" in plan assert "UNGROUPED_AGGREGATE" in plan assert "Total Time:" in plan + + result = cur.sql("EXPLAIN SELECT count(*) FROM test_table where id = %s", (1,)) + plan = "\n".join(result) + assert "UNGROUPED_AGGREGATE" in plan + assert "id=1 AND id IS NOT NULL" in plan + assert "Total Time:" not in plan + + result = cur.sql( + "EXPLAIN ANALYZE SELECT count(*) FROM test_table where id = %s", (1,) + ) + plan = "\n".join(result) + assert "UNGROUPED_AGGREGATE" in plan + assert "id=1 AND id IS NOT NULL" in plan + assert "Total Time:" in plan diff --git a/test/pycheck/prepared_test.py b/test/pycheck/prepared_test.py new file mode 100644 index 00000000..ecf31582 --- /dev/null +++ b/test/pycheck/prepared_test.py @@ -0,0 +1,95 @@ +from .utils import Cursor + +import datetime +import psycopg.types.json + + +def test_prepared(cur: Cursor): + cur.sql("CREATE TABLE test_table (id int)") + + # Try prepared query without parameters + q1 = "SELECT count(*) FROM test_table" + assert cur.sql(q1, prepare=True) == 0 + assert cur.sql(q1) == 0 + assert cur.sql(q1) == 0 + + cur.sql("INSERT INTO test_table VALUES (1), (2), (3)") + assert cur.sql(q1) == 3 + + # The following tests a prepared query that has parameters. + # There are two ways in which prepared queries that have parameters can be + # executed: + # 1. With a custom plan, where the query is prepared with the exact values + # 2. With a generic plan, where the query is planned without the values and + # the values get only substituted at execution time + # + # The below tests both of these cases, by setting the plan_cache_mode. + q2 = "SELECT count(*) FROM test_table where id = %s" + cur.sql("SET plan_cache_mode = 'force_custom_plan'") + assert cur.sql(q2, (1,), prepare=True) == 1 + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (3,)) == 1 + assert cur.sql(q2, (4,)) == 0 + + cur.sql("SET plan_cache_mode = 'force_generic_plan'") + assert cur.sql(q2, (1,)) == 1 # creates generic plan + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (3,)) == 1 + assert cur.sql(q2, (4,)) == 0 + + +def test_extended(cur: Cursor): + cur.sql(""" + CREATE TABLE t( + bool BOOLEAN, + i2 SMALLINT, + i4 INT, + i8 BIGINT, + fl4 REAL, + fl8 DOUBLE PRECISION, + t1 TEXT, + t2 VARCHAR, + t3 BPCHAR, + d DATE, + ts TIMESTAMP, + json_obj JSON); + """) + + row = ( + True, + 2, + 4, + 8, + 4.0, + 8.0, + "t1", + "t2", + "t3", + datetime.date(2024, 5, 4), + datetime.datetime(2020, 1, 1, 1, 2, 3), + psycopg.types.json.Json({"a": 1}), + ) + cur.sql( + "INSERT INTO t VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", row + ) + + assert (True,) * len(row) == cur.sql( + """ + SELECT + bool = %s, + i2 = %s, + i4 = %s, + i8 = %s, + fl4 = %s, + fl8 = %s, + t1 = %s, + t2 = %s, + t3 = %s, + d = %s, + ts = %s, + json_obj::text = %s::text + FROM t; + """, + row, + ) diff --git a/test/pycheck/utils.py b/test/pycheck/utils.py index 89b864a1..33fbb7b0 100644 --- a/test/pycheck/utils.py +++ b/test/pycheck/utils.py @@ -237,8 +237,8 @@ def __init__(self, cursor: psycopg.Cursor): def __getattr__(self, name): return getattr(self.cursor, name) - def sql(self, query, params=None) -> Any: - self.execute(query, params) + def sql(self, query, params=None, **kwargs) -> Any: + self.execute(query, params, **kwargs) try: return simplify_query_results(self.fetchall()) except psycopg.ProgrammingError as e: @@ -256,11 +256,11 @@ def __init__(self, cursor: psycopg.AsyncCursor): def __getattr__(self, name): return getattr(self.cursor, name) - def sql(self, query, params=None): - return asyncio.ensure_future(self.sql_coroutine(query, params)) + def sql(self, query, params=None, **kwargs): + return asyncio.ensure_future(self.sql_coroutine(query, params, **kwargs)) - async def sql_coroutine(self, query, params=None) -> Any: - await self.execute(query, params) + async def sql_coroutine(self, query, params=None, **kwargs) -> Any: + await self.execute(query, params, **kwargs) try: return simplify_query_results(await self.fetchall()) except psycopg.ProgrammingError as e: diff --git a/test/regression/expected/basic.out b/test/regression/expected/basic.out index fc407602..ed181bea 100644 --- a/test/regression/expected/basic.out +++ b/test/regression/expected/basic.out @@ -16,6 +16,28 @@ SELECT a, COUNT(*) FROM t WHERE a > 5 GROUP BY a ORDER BY a; 9 | 100 (4 rows) +select COUNT(*) from t \bind \g + count +------- + 1000 +(1 row) + +select a, COUNT(*) from t WHERE a > $1 GROUP BY a ORDER BY a \bind 5 \g + a | count +---+------- + 6 | 100 + 7 | 100 + 8 | 100 + 9 | 100 +(4 rows) + +\bind 7 \g + a | count +---+------- + 8 | 100 + 9 | 100 +(2 rows) + SET duckdb.max_threads_per_query to 4; SELECT COUNT(*) FROM t; count diff --git a/test/regression/sql/basic.sql b/test/regression/sql/basic.sql index d32d3f0a..b9b3f5e9 100644 --- a/test/regression/sql/basic.sql +++ b/test/regression/sql/basic.sql @@ -6,6 +6,9 @@ SET client_min_messages to 'DEBUG1'; SELECT COUNT(*) FROM t; SELECT a, COUNT(*) FROM t WHERE a > 5 GROUP BY a ORDER BY a; +select COUNT(*) from t \bind \g +select a, COUNT(*) from t WHERE a > $1 GROUP BY a ORDER BY a \bind 5 \g +\bind 7 \g SET duckdb.max_threads_per_query to 4;