diff --git a/include/pgduckdb/pgduckdb_planner.hpp b/include/pgduckdb/pgduckdb_planner.hpp index b95fb65b..1784a46f 100644 --- a/include/pgduckdb/pgduckdb_planner.hpp +++ b/include/pgduckdb/pgduckdb_planner.hpp @@ -1,10 +1,16 @@ #pragma once +#include "duckdb.hpp" + extern "C" { #include "postgres.h" #include "optimizer/planner.h" } +#include "pgduckdb/pgduckdb_duckdb.hpp" + extern bool duckdb_explain_analyze; PlannedStmt *DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params); +std::tuple, duckdb::unique_ptr> +DuckdbPrepare(const Query *query, ParamListInfo bound_params); diff --git a/include/pgduckdb/pgduckdb_types.hpp b/include/pgduckdb/pgduckdb_types.hpp index 2f0d00a3..d444bef3 100644 --- a/include/pgduckdb/pgduckdb_types.hpp +++ b/include/pgduckdb/pgduckdb_types.hpp @@ -19,6 +19,7 @@ constexpr int64_t PGDUCKDB_DUCK_TIMESTAMP_OFFSET = INT64CONST(10957) * USECS_PER duckdb::LogicalType ConvertPostgresToDuckColumnType(Form_pg_attribute &attribute); Oid GetPostgresDuckDBType(duckdb::LogicalType type); +duckdb::Value ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type); void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset); bool ConvertDuckToPostgresValue(TupleTableSlot *slot, duckdb::Value &value, idx_t col); void InsertTupleIntoChunk(duckdb::DataChunk &output, duckdb::shared_ptr scan_global_state, diff --git a/src/pgduckdb_hooks.cpp b/src/pgduckdb_hooks.cpp index 4d74e8ed..aa92a764 100644 --- a/src/pgduckdb_hooks.cpp +++ b/src/pgduckdb_hooks.cpp @@ -158,6 +158,17 @@ extern "C" { void DuckdbExplainOneQueryHook(Query *query, int cursorOptions, IntoClause *into, ExplainState *es, const char *queryString, ParamListInfo params, QueryEnvironment *queryEnv) { + /* + * It might seem sensible to store this data in the custom_private + * field of the CustomScan node, but that's not a trivial change to make. + * Storing this in a global variable works fine, as long as we only use + * this variable during planning when we're actually executing an explain + * QUERY (this can be checked by checking the commandTag of the + * ActivePortal). This even works when plans would normally be cached, + * because EXPLAIN always execute this hook whenever they are executed. + * EXPLAIN queries are also always re-planned (see + * standard_ExplainOneQuery). + */ duckdb_explain_analyze = es->analyze; prev_explain_one_query_hook(query, cursorOptions, into, es, queryString, params, queryEnv); } diff --git a/src/pgduckdb_node.cpp b/src/pgduckdb_node.cpp index daf267ed..998e9a13 100644 --- a/src/pgduckdb_node.cpp +++ b/src/pgduckdb_node.cpp @@ -4,10 +4,15 @@ extern "C" { #include "postgres.h" #include "miscadmin.h" +#include "tcop/pquery.h" +#include "nodes/params.h" +#include "utils/ruleutils.h" } #include "pgduckdb/pgduckdb_node.hpp" #include "pgduckdb/pgduckdb_types.hpp" +#include "pgduckdb/pgduckdb_duckdb.hpp" +#include "pgduckdb/pgduckdb_planner.hpp" /* global variables */ CustomScanMethods duckdb_scan_scan_methods; @@ -17,6 +22,8 @@ static CustomExecMethods duckdb_scan_exec_methods; typedef struct DuckdbScanState { CustomScanState css; /* must be first field */ + const Query *query; + ParamListInfo params; duckdb::Connection *duckdb_connection; duckdb::PreparedStatement *prepared_statement; bool is_executed; @@ -46,10 +53,8 @@ static Node * Duckdb_CreateCustomScanState(CustomScan *cscan) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)newNode(sizeof(DuckdbScanState), T_CustomScanState); CustomScanState *custom_scan_state = &duckdb_scan_state->css; - duckdb_scan_state->duckdb_connection = (duckdb::Connection *)linitial(cscan->custom_private); - duckdb_scan_state->prepared_statement = (duckdb::PreparedStatement *)lsecond(cscan->custom_private); - duckdb_scan_state->is_executed = false; - duckdb_scan_state->fetch_next = true; + + duckdb_scan_state->query = (const Query *)linitial(cscan->custom_private); custom_scan_state->methods = &duckdb_scan_exec_methods; return (Node *)custom_scan_state; } @@ -57,6 +62,19 @@ Duckdb_CreateCustomScanState(CustomScan *cscan) { void Duckdb_BeginCustomScan(CustomScanState *cscanstate, EState *estate, int eflags) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)cscanstate; + auto prepare_result = DuckdbPrepare(duckdb_scan_state->query, estate->es_param_list_info); + auto prepared_query = std::move(std::get<0>(prepare_result)); + auto duckdb_connection = std::move(std::get<1>(prepare_result)); + + if (prepared_query->HasError()) { + elog(ERROR, "DuckDB re-planning failed %s", prepared_query->GetError().c_str()); + } + + duckdb_scan_state->duckdb_connection = duckdb_connection.release(); + duckdb_scan_state->prepared_statement = prepared_query.release(); + duckdb_scan_state->params = estate->es_param_list_info; + duckdb_scan_state->is_executed = false; + duckdb_scan_state->fetch_next = true; duckdb_scan_state->css.ss.ps.ps_ResultTupleDesc = duckdb_scan_state->css.ss.ss_ScanTupleSlot->tts_tupleDescriptor; HOLD_CANCEL_INTERRUPTS(); } @@ -66,8 +84,33 @@ ExecuteQuery(DuckdbScanState *state) { auto &prepared = *state->prepared_statement; auto &query_results = state->query_results; auto &connection = state->duckdb_connection; + auto pg_params = state->params; + const auto num_params = pg_params ? pg_params->numParams : 0; + duckdb::vector duckdb_params; + for (int i = 0; i < num_params; i++) { + ParamExternData *pg_param; + ParamExternData tmp_workspace; + + /* give hook a chance in case parameter is dynamic */ + if (pg_params->paramFetch != NULL) + pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace); + else + pg_param = &pg_params->params[i]; + + if (pg_param->isnull) { + duckdb_params.push_back(duckdb::Value()); + } else { + if (!OidIsValid(pg_param->ptype)) { + elog(ERROR, "parameter with invalid type during execution"); + } + duckdb_params.push_back(pgduckdb::ConvertPostgresParameterToDuckValue(pg_param->value, pg_param->ptype)); + } + } - auto pending = prepared.PendingQuery(); + auto pending = prepared.PendingQuery(duckdb_params, true); + if (pending->HasError()) { + elog(ERROR, "DuckDB execute returned an error: %s", pending->GetError().c_str()); + } duckdb::PendingExecutionResult execution_result; do { execution_result = pending->ExecuteTask(); @@ -160,14 +203,14 @@ Duckdb_ReScanCustomScan(CustomScanState *node) { void Duckdb_ExplainCustomScan(CustomScanState *node, List *ancestors, ExplainState *es) { DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)node; - auto res = duckdb_scan_state->prepared_statement->Execute(); - std::string explain_output = "\n\n"; - auto chunk = res->Fetch(); + ExecuteQuery(duckdb_scan_state); + auto chunk = duckdb_scan_state->query_results->Fetch(); if (!chunk || chunk->size() == 0) { return; } /* Is it safe to hardcode this as result of DuckDB explain? */ auto value = chunk->GetValue(1, 0); + std::string explain_output = "\n\n"; explain_output += value.GetValue(); explain_output += "\n"; ExplainPropertyText("DuckDB Execution Plan", explain_output.c_str(), es); diff --git a/src/pgduckdb_planner.cpp b/src/pgduckdb_planner.cpp index 68cbfdd5..d6c7fe87 100644 --- a/src/pgduckdb_planner.cpp +++ b/src/pgduckdb_planner.cpp @@ -4,6 +4,8 @@ extern "C" { #include "postgres.h" #include "catalog/pg_type.h" #include "nodes/makefuncs.h" +#include "nodes/nodes.h" +#include "nodes/params.h" #include "optimizer/optimizer.h" #include "tcop/pquery.h" #include "utils/syscache.h" @@ -51,21 +53,44 @@ PlanQuery(Query *parse, ParamListInfo bound_params) { ); } -static Plan * -CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { +std::tuple, duckdb::unique_ptr> +DuckdbPrepare(const Query *query, ParamListInfo bound_params) { + /* + * Copy the query, so the original one is not modified by the + * subquery_planner call that PlanQuery does. + */ + Query *copied_query = (Query *)copyObjectImpl(query); + const char *query_string = pgduckdb_pg_get_querydef(copied_query, false); - List *rtables = query->rtable; + if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) { + if (duckdb_explain_analyze) { + query_string = psprintf("EXPLAIN ANALYZE %s", query_string); + } else { + query_string = psprintf("EXPLAIN %s", query_string); + } + } + + elog(DEBUG2, "(PGDuckDB/DuckdbPrepare) Preparing: %s", query_string); + List *rtables = copied_query->rtable; /* Extract required vars for table */ int flags = PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS | PVC_RECURSE_PLACEHOLDERS; - List *vars = list_concat(pull_var_clause((Node *)query->targetList, flags), - pull_var_clause((Node *)query->jointree->quals, flags)); - - PlannerInfo *query_planner_info = PlanQuery(query, bound_params); + List *vars = list_concat(pull_var_clause((Node *)copied_query->targetList, flags), + pull_var_clause((Node *)copied_query->jointree->quals, flags)); + PlannerInfo *query_planner_info = PlanQuery(copied_query, bound_params); auto duckdb_connection = pgduckdb::DuckdbCreateConnection(rtables, query_planner_info, vars, query_string); auto context = duckdb_connection->context; - auto prepared_query = context->Prepare(query_string); + return {std::move(prepared_query), std::move(duckdb_connection)}; +} + +static Plan * +CreatePlan(Query *query, ParamListInfo bound_params) { + /* + * Prepare the query, se we can get the returned types and column names. + */ + auto prepare_result = DuckdbPrepare(query, bound_params); + auto prepared_query = std::move(std::get<0>(prepare_result)); if (prepared_query->HasError()) { elog(WARNING, "(PGDuckDB/CreatePlan) Prepared query returned an error: '%s", @@ -101,12 +126,12 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { duckdb_node->custom_scan_tlist = lappend(duckdb_node->custom_scan_tlist, - makeTargetEntry((Expr *)var, i + 1, (char *)prepared_query->GetNames()[i].c_str(), false)); + makeTargetEntry((Expr *)var, i + 1, (char *)pstrdup(prepared_query->GetNames()[i].c_str()), false)); ReleaseSysCache(tp); } - duckdb_node->custom_private = list_make2(duckdb_connection.release(), prepared_query.release()); + duckdb_node->custom_private = list_make1(query); duckdb_node->methods = &duckdb_scan_scan_methods; return (Plan *)duckdb_node; @@ -114,17 +139,8 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) { PlannedStmt * DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params) { - const char *query_string = pgduckdb_pg_get_querydef(parse, false); - - if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) { - if (duckdb_explain_analyze) { - query_string = psprintf("EXPLAIN ANALYZE %s", query_string); - } else { - query_string = psprintf("EXPLAIN %s", query_string); - } - } /* We need to check can we DuckDB create plan */ - Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, query_string, bound_params)); + Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, bound_params)); if (!duckdb_plan) { return nullptr; diff --git a/src/pgduckdb_types.cpp b/src/pgduckdb_types.cpp index b0f1e6cd..edd433ea 100644 --- a/src/pgduckdb_types.cpp +++ b/src/pgduckdb_types.cpp @@ -9,12 +9,15 @@ extern "C" { #include "miscadmin.h" #include "catalog/pg_type.h" #include "executor/tuptable.h" +#include "utils/builtins.h" #include "utils/numeric.h" #include "utils/uuid.h" #include "utils/array.h" #include "fmgr.h" #include "utils/lsyscache.h" #include "utils/syscache.h" +#include "utils/date.h" +#include "utils/timestamp.h" } #include "pgduckdb/pgduckdb.h" @@ -707,6 +710,47 @@ ConvertDecimal(const NumericVar &numeric) { return (NumericIsNegative(numeric) ? -base_res : base_res); } +/* + * Convert a Postgres Datum to a DuckDB Value. This is meant to be used to + * covert query parameters in a prepared statement to its DuckDB equivalent. + * Passing it a Datum that is stored on disk results in undefined behavior, + * because this fuction makes no effert to detoast the Datum. + */ +duckdb::Value +ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type) { + switch (postgres_type) { + case BOOLOID: + return duckdb::Value::BOOLEAN(DatumGetBool(value)); + case INT2OID: + return duckdb::Value::SMALLINT(DatumGetInt16(value)); + case INT4OID: + return duckdb::Value::INTEGER(DatumGetInt32(value)); + case INT8OID: + return duckdb::Value::BIGINT(DatumGetInt64(value)); + case BPCHAROID: + case TEXTOID: + case JSONOID: + case VARCHAROID: { + // FIXME: TextDatumGetCstring allocates so it needs a + // guard, but it's a macro not a function, so our current gaurd + // template does not handle it. + return duckdb::Value(TextDatumGetCString(value)); + } + case DATEOID: + return duckdb::Value::DATE(duckdb::date_t(DatumGetDateADT(value) + PGDUCKDB_DUCK_DATE_OFFSET)); + case TIMESTAMPOID: + return duckdb::Value::TIMESTAMP(duckdb::timestamp_t(DatumGetTimestamp(value) + PGDUCKDB_DUCK_TIMESTAMP_OFFSET)); + case FLOAT4OID: { + return duckdb::Value::FLOAT(DatumGetFloat4(value)); + } + case FLOAT8OID: { + return duckdb::Value::DOUBLE(DatumGetFloat8(value)); + } + default: + elog(ERROR, "Could not convert Postgres parameter of type: %d to DuckDB type", postgres_type); + } +} + void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset) { auto &type = result.GetType(); diff --git a/test/pycheck/explain_test.py b/test/pycheck/explain_test.py index a1411610..3569ba11 100644 --- a/test/pycheck/explain_test.py +++ b/test/pycheck/explain_test.py @@ -2,7 +2,7 @@ def test_explain(cur: Cursor): - cur.sql("CREATE TABLE test_table (id int primary key, name text)") + cur.sql("CREATE TABLE test_table (id int, name text)") result = cur.sql("EXPLAIN SELECT count(*) FROM test_table") plan = "\n".join(result) assert "UNGROUPED_AGGREGATE" in plan @@ -13,3 +13,17 @@ def test_explain(cur: Cursor): assert "Query Profiling Information" in plan assert "UNGROUPED_AGGREGATE" in plan assert "Total Time:" in plan + + result = cur.sql("EXPLAIN SELECT count(*) FROM test_table where id = %s", (1,)) + plan = "\n".join(result) + assert "UNGROUPED_AGGREGATE" in plan + assert "id=1 AND id IS NOT NULL" in plan + assert "Total Time:" not in plan + + result = cur.sql( + "EXPLAIN ANALYZE SELECT count(*) FROM test_table where id = %s", (1,) + ) + plan = "\n".join(result) + assert "UNGROUPED_AGGREGATE" in plan + assert "id=1 AND id IS NOT NULL" in plan + assert "Total Time:" in plan diff --git a/test/pycheck/prepared_test.py b/test/pycheck/prepared_test.py new file mode 100644 index 00000000..ecf31582 --- /dev/null +++ b/test/pycheck/prepared_test.py @@ -0,0 +1,95 @@ +from .utils import Cursor + +import datetime +import psycopg.types.json + + +def test_prepared(cur: Cursor): + cur.sql("CREATE TABLE test_table (id int)") + + # Try prepared query without parameters + q1 = "SELECT count(*) FROM test_table" + assert cur.sql(q1, prepare=True) == 0 + assert cur.sql(q1) == 0 + assert cur.sql(q1) == 0 + + cur.sql("INSERT INTO test_table VALUES (1), (2), (3)") + assert cur.sql(q1) == 3 + + # The following tests a prepared query that has parameters. + # There are two ways in which prepared queries that have parameters can be + # executed: + # 1. With a custom plan, where the query is prepared with the exact values + # 2. With a generic plan, where the query is planned without the values and + # the values get only substituted at execution time + # + # The below tests both of these cases, by setting the plan_cache_mode. + q2 = "SELECT count(*) FROM test_table where id = %s" + cur.sql("SET plan_cache_mode = 'force_custom_plan'") + assert cur.sql(q2, (1,), prepare=True) == 1 + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (3,)) == 1 + assert cur.sql(q2, (4,)) == 0 + + cur.sql("SET plan_cache_mode = 'force_generic_plan'") + assert cur.sql(q2, (1,)) == 1 # creates generic plan + assert cur.sql(q2, (1,)) == 1 + assert cur.sql(q2, (3,)) == 1 + assert cur.sql(q2, (4,)) == 0 + + +def test_extended(cur: Cursor): + cur.sql(""" + CREATE TABLE t( + bool BOOLEAN, + i2 SMALLINT, + i4 INT, + i8 BIGINT, + fl4 REAL, + fl8 DOUBLE PRECISION, + t1 TEXT, + t2 VARCHAR, + t3 BPCHAR, + d DATE, + ts TIMESTAMP, + json_obj JSON); + """) + + row = ( + True, + 2, + 4, + 8, + 4.0, + 8.0, + "t1", + "t2", + "t3", + datetime.date(2024, 5, 4), + datetime.datetime(2020, 1, 1, 1, 2, 3), + psycopg.types.json.Json({"a": 1}), + ) + cur.sql( + "INSERT INTO t VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", row + ) + + assert (True,) * len(row) == cur.sql( + """ + SELECT + bool = %s, + i2 = %s, + i4 = %s, + i8 = %s, + fl4 = %s, + fl8 = %s, + t1 = %s, + t2 = %s, + t3 = %s, + d = %s, + ts = %s, + json_obj::text = %s::text + FROM t; + """, + row, + ) diff --git a/test/pycheck/utils.py b/test/pycheck/utils.py index 89b864a1..33fbb7b0 100644 --- a/test/pycheck/utils.py +++ b/test/pycheck/utils.py @@ -237,8 +237,8 @@ def __init__(self, cursor: psycopg.Cursor): def __getattr__(self, name): return getattr(self.cursor, name) - def sql(self, query, params=None) -> Any: - self.execute(query, params) + def sql(self, query, params=None, **kwargs) -> Any: + self.execute(query, params, **kwargs) try: return simplify_query_results(self.fetchall()) except psycopg.ProgrammingError as e: @@ -256,11 +256,11 @@ def __init__(self, cursor: psycopg.AsyncCursor): def __getattr__(self, name): return getattr(self.cursor, name) - def sql(self, query, params=None): - return asyncio.ensure_future(self.sql_coroutine(query, params)) + def sql(self, query, params=None, **kwargs): + return asyncio.ensure_future(self.sql_coroutine(query, params, **kwargs)) - async def sql_coroutine(self, query, params=None) -> Any: - await self.execute(query, params) + async def sql_coroutine(self, query, params=None, **kwargs) -> Any: + await self.execute(query, params, **kwargs) try: return simplify_query_results(await self.fetchall()) except psycopg.ProgrammingError as e: diff --git a/test/regression/expected/basic.out b/test/regression/expected/basic.out index fc407602..ed181bea 100644 --- a/test/regression/expected/basic.out +++ b/test/regression/expected/basic.out @@ -16,6 +16,28 @@ SELECT a, COUNT(*) FROM t WHERE a > 5 GROUP BY a ORDER BY a; 9 | 100 (4 rows) +select COUNT(*) from t \bind \g + count +------- + 1000 +(1 row) + +select a, COUNT(*) from t WHERE a > $1 GROUP BY a ORDER BY a \bind 5 \g + a | count +---+------- + 6 | 100 + 7 | 100 + 8 | 100 + 9 | 100 +(4 rows) + +\bind 7 \g + a | count +---+------- + 8 | 100 + 9 | 100 +(2 rows) + SET duckdb.max_threads_per_query to 4; SELECT COUNT(*) FROM t; count diff --git a/test/regression/sql/basic.sql b/test/regression/sql/basic.sql index d32d3f0a..b9b3f5e9 100644 --- a/test/regression/sql/basic.sql +++ b/test/regression/sql/basic.sql @@ -6,6 +6,9 @@ SET client_min_messages to 'DEBUG1'; SELECT COUNT(*) FROM t; SELECT a, COUNT(*) FROM t WHERE a > 5 GROUP BY a ORDER BY a; +select COUNT(*) from t \bind \g +select a, COUNT(*) from t WHERE a > $1 GROUP BY a ORDER BY a \bind 5 \g +\bind 7 \g SET duckdb.max_threads_per_query to 4;