Skip to content

Commit

Permalink
Support extended protocol and prepared statements (#147)
Browse files Browse the repository at this point in the history
Previously prepared statements and the extended protocol were broken. We
need to support this feature because pretty much every client library
uses it. The reason that prepared statements did not work is that for
prepared statements to work, a Postgres plan needs to be fully copyable
using Postgres functionality. We were previously storing the DuckDB plan
in the Postgres plan, but since Postgres did not know how to copy the
DuckDB plan it would fail.

To work around that problem we now store the full Postgres `Query`
structure inside the DuckDB CustomScan node. Then during execution we
use this `Query` structure to re-plan the query. This
works but it has the big downside that we're now planning the query
twice in DuckDB, once during planning to get the column types of the
result and once during execution. For now that seems acceptable, so we
can at least support prepared statements and the extended protocol. In
the future we might want to do something smarter though, like
serializing and deserializing the DuckDB plan that is created at
planning time.

Note: This does not support all our supported types as arguments to a
prepared statement yet. After merging I'll open a separate issue to track
adding the rest of these types.

Fixes #118
  • Loading branch information
JelteF committed Sep 23, 2024
1 parent b3d8315 commit 007d882
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 35 deletions.
6 changes: 6 additions & 0 deletions include/pgduckdb/pgduckdb_planner.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#pragma once

#include "duckdb.hpp"

extern "C" {
#include "postgres.h"
#include "optimizer/planner.h"
}

#include "pgduckdb/pgduckdb_duckdb.hpp"

extern bool duckdb_explain_analyze;

PlannedStmt *DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params);
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(const Query *query, ParamListInfo bound_params);
1 change: 1 addition & 0 deletions include/pgduckdb/pgduckdb_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ constexpr int64_t PGDUCKDB_DUCK_TIMESTAMP_OFFSET = INT64CONST(10957) * USECS_PER

duckdb::LogicalType ConvertPostgresToDuckColumnType(Form_pg_attribute &attribute);
Oid GetPostgresDuckDBType(duckdb::LogicalType type);
duckdb::Value ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type);
void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset);
bool ConvertDuckToPostgresValue(TupleTableSlot *slot, duckdb::Value &value, idx_t col);
void InsertTupleIntoChunk(duckdb::DataChunk &output, duckdb::shared_ptr<PostgresScanGlobalState> scan_global_state,
Expand Down
11 changes: 11 additions & 0 deletions src/pgduckdb_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ extern "C" {
/*
 * ExplainOneQuery hook: records whether the current EXPLAIN was invoked with
 * ANALYZE (es->analyze) in the global duckdb_explain_analyze, then forwards
 * all arguments unchanged to prev_explain_one_query_hook.
 */
void
DuckdbExplainOneQueryHook(Query *query, int cursorOptions, IntoClause *into, ExplainState *es, const char *queryString,
ParamListInfo params, QueryEnvironment *queryEnv) {
/*
 * It might seem sensible to store this data in the custom_private
 * field of the CustomScan node, but that's not a trivial change to make.
 * Storing this in a global variable works fine, as long as we only use
 * this variable during planning when we're actually executing an EXPLAIN
 * query (this can be checked by checking the commandTag of the
 * ActivePortal). This even works when plans would normally be cached,
 * because EXPLAIN queries always execute this hook whenever they are
 * executed. EXPLAIN queries are also always re-planned (see
 * standard_ExplainOneQuery).
 */
duckdb_explain_analyze = es->analyze;
/* NOTE(review): called unconditionally; presumably the previous hook is always set at load time -- confirm. */
prev_explain_one_query_hook(query, cursorOptions, into, es, queryString, params, queryEnv);
}
Expand Down
59 changes: 51 additions & 8 deletions src/pgduckdb_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
extern "C" {
#include "postgres.h"
#include "miscadmin.h"
#include "tcop/pquery.h"
#include "nodes/params.h"
#include "utils/ruleutils.h"
}

#include "pgduckdb/pgduckdb_node.hpp"
#include "pgduckdb/pgduckdb_types.hpp"
#include "pgduckdb/pgduckdb_duckdb.hpp"
#include "pgduckdb/pgduckdb_planner.hpp"

/* global variables */
CustomScanMethods duckdb_scan_scan_methods;
Expand All @@ -17,6 +22,8 @@ static CustomExecMethods duckdb_scan_exec_methods;

typedef struct DuckdbScanState {
CustomScanState css; /* must be first field */
const Query *query;
ParamListInfo params;
duckdb::Connection *duckdb_connection;
duckdb::PreparedStatement *prepared_statement;
bool is_executed;
Expand Down Expand Up @@ -46,17 +53,28 @@ static Node *
Duckdb_CreateCustomScanState(CustomScan *cscan) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)newNode(sizeof(DuckdbScanState), T_CustomScanState);
CustomScanState *custom_scan_state = &duckdb_scan_state->css;
duckdb_scan_state->duckdb_connection = (duckdb::Connection *)linitial(cscan->custom_private);
duckdb_scan_state->prepared_statement = (duckdb::PreparedStatement *)lsecond(cscan->custom_private);
duckdb_scan_state->is_executed = false;
duckdb_scan_state->fetch_next = true;

duckdb_scan_state->query = (const Query *)linitial(cscan->custom_private);
custom_scan_state->methods = &duckdb_scan_exec_methods;
return (Node *)custom_scan_state;
}

/*
 * Begin-scan callback for the DuckDB custom scan node.
 *
 * Re-plans the Postgres Query stored in the scan state by calling
 * DuckdbPrepare with the executor's parameter list, then stores the resulting
 * DuckDB connection and prepared statement (as raw owning pointers), together
 * with the parameter list, in the scan state for use during execution.
 *
 * cscanstate: the DuckdbScanState created for this scan (css is its first field).
 * estate:     executor state; only es_param_list_info is read here.
 * eflags:     unused.
 *
 * Raises a Postgres ERROR if DuckDB fails to prepare the query.
 */
void
Duckdb_BeginCustomScan(CustomScanState *cscanstate, EState *estate, int eflags) {
	DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)cscanstate;
	auto prepare_result = DuckdbPrepare(duckdb_scan_state->query, estate->es_param_list_info);
	auto prepared_query = std::move(std::get<0>(prepare_result));
	auto duckdb_connection = std::move(std::get<1>(prepare_result));

	if (prepared_query->HasError()) {
		/*
		 * elog(ERROR) leaves this function through Postgres' non-local error
		 * exit, so C++ destructors of the locals above would not run. Copy
		 * the message into palloc'd memory and destroy the DuckDB objects by
		 * hand before raising the error, so they are not leaked.
		 */
		char *error_message = pstrdup(prepared_query->GetError().c_str());
		prepared_query.reset();
		duckdb_connection.reset();
		elog(ERROR, "DuckDB re-planning failed %s", error_message);
	}

	/* Ownership moves to the scan state, which keeps raw pointers. */
	duckdb_scan_state->duckdb_connection = duckdb_connection.release();
	duckdb_scan_state->prepared_statement = prepared_query.release();
	duckdb_scan_state->params = estate->es_param_list_info;
	duckdb_scan_state->is_executed = false;
	duckdb_scan_state->fetch_next = true;
	/* The result tuple descriptor is the scan slot's descriptor. */
	duckdb_scan_state->css.ss.ps.ps_ResultTupleDesc = duckdb_scan_state->css.ss.ss_ScanTupleSlot->tts_tupleDescriptor;
	/* NOTE(review): presumably paired with RESUME_CANCEL_INTERRUPTS() when the scan ends -- not visible here, confirm. */
	HOLD_CANCEL_INTERRUPTS();
}
Expand All @@ -66,8 +84,33 @@ ExecuteQuery(DuckdbScanState *state) {
auto &prepared = *state->prepared_statement;
auto &query_results = state->query_results;
auto &connection = state->duckdb_connection;
auto pg_params = state->params;
const auto num_params = pg_params ? pg_params->numParams : 0;
duckdb::vector<duckdb::Value> duckdb_params;
for (int i = 0; i < num_params; i++) {
ParamExternData *pg_param;
ParamExternData tmp_workspace;

/* give hook a chance in case parameter is dynamic */
if (pg_params->paramFetch != NULL)
pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace);
else
pg_param = &pg_params->params[i];

if (pg_param->isnull) {
duckdb_params.push_back(duckdb::Value());
} else {
if (!OidIsValid(pg_param->ptype)) {
elog(ERROR, "parameter with invalid type during execution");
}
duckdb_params.push_back(pgduckdb::ConvertPostgresParameterToDuckValue(pg_param->value, pg_param->ptype));
}
}

auto pending = prepared.PendingQuery();
auto pending = prepared.PendingQuery(duckdb_params, true);
if (pending->HasError()) {
elog(ERROR, "DuckDB execute returned an error: %s", pending->GetError().c_str());
}
duckdb::PendingExecutionResult execution_result;
do {
execution_result = pending->ExecuteTask();
Expand Down Expand Up @@ -160,14 +203,14 @@ Duckdb_ReScanCustomScan(CustomScanState *node) {
void
Duckdb_ExplainCustomScan(CustomScanState *node, List *ancestors, ExplainState *es) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)node;
auto res = duckdb_scan_state->prepared_statement->Execute();
std::string explain_output = "\n\n";
auto chunk = res->Fetch();
ExecuteQuery(duckdb_scan_state);
auto chunk = duckdb_scan_state->query_results->Fetch();
if (!chunk || chunk->size() == 0) {
return;
}
/* Is it safe to hardcode this as result of DuckDB explain? */
auto value = chunk->GetValue(1, 0);
std::string explain_output = "\n\n";
explain_output += value.GetValue<duckdb::string>();
explain_output += "\n";
ExplainPropertyText("DuckDB Execution Plan", explain_output.c_str(), es);
Expand Down
56 changes: 36 additions & 20 deletions src/pgduckdb_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern "C" {
#include "postgres.h"
#include "catalog/pg_type.h"
#include "nodes/makefuncs.h"
#include "nodes/nodes.h"
#include "nodes/params.h"
#include "optimizer/optimizer.h"
#include "tcop/pquery.h"
#include "utils/syscache.h"
Expand Down Expand Up @@ -51,21 +53,44 @@ PlanQuery(Query *parse, ParamListInfo bound_params) {
);
}

static Plan *
CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(const Query *query, ParamListInfo bound_params) {
/*
* Copy the query, so the original one is not modified by the
* subquery_planner call that PlanQuery does.
*/
Query *copied_query = (Query *)copyObjectImpl(query);
const char *query_string = pgduckdb_pg_get_querydef(copied_query, false);

List *rtables = query->rtable;
if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}

elog(DEBUG2, "(PGDuckDB/DuckdbPrepare) Preparing: %s", query_string);

List *rtables = copied_query->rtable;
/* Extract required vars for table */
int flags = PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS | PVC_RECURSE_PLACEHOLDERS;
List *vars = list_concat(pull_var_clause((Node *)query->targetList, flags),
pull_var_clause((Node *)query->jointree->quals, flags));

PlannerInfo *query_planner_info = PlanQuery(query, bound_params);
List *vars = list_concat(pull_var_clause((Node *)copied_query->targetList, flags),
pull_var_clause((Node *)copied_query->jointree->quals, flags));
PlannerInfo *query_planner_info = PlanQuery(copied_query, bound_params);
auto duckdb_connection = pgduckdb::DuckdbCreateConnection(rtables, query_planner_info, vars, query_string);
auto context = duckdb_connection->context;

auto prepared_query = context->Prepare(query_string);
return {std::move(prepared_query), std::move(duckdb_connection)};
}

static Plan *
CreatePlan(Query *query, ParamListInfo bound_params) {
/*
* Prepare the query, so we can get the returned types and column names.
*/
auto prepare_result = DuckdbPrepare(query, bound_params);
auto prepared_query = std::move(std::get<0>(prepare_result));

if (prepared_query->HasError()) {
elog(WARNING, "(PGDuckDB/CreatePlan) Prepared query returned an error: '%s",
Expand Down Expand Up @@ -101,30 +126,21 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {

duckdb_node->custom_scan_tlist =
lappend(duckdb_node->custom_scan_tlist,
makeTargetEntry((Expr *)var, i + 1, (char *)prepared_query->GetNames()[i].c_str(), false));
makeTargetEntry((Expr *)var, i + 1, (char *)pstrdup(prepared_query->GetNames()[i].c_str()), false));

ReleaseSysCache(tp);
}

duckdb_node->custom_private = list_make2(duckdb_connection.release(), prepared_query.release());
duckdb_node->custom_private = list_make1(query);
duckdb_node->methods = &duckdb_scan_scan_methods;

return (Plan *)duckdb_node;
}

PlannedStmt *
DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params) {
const char *query_string = pgduckdb_pg_get_querydef(parse, false);

if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}
/* We need to check can we DuckDB create plan */
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, query_string, bound_params));
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, bound_params));

if (!duckdb_plan) {
return nullptr;
Expand Down
44 changes: 44 additions & 0 deletions src/pgduckdb_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ extern "C" {
#include "miscadmin.h"
#include "catalog/pg_type.h"
#include "executor/tuptable.h"
#include "utils/builtins.h"
#include "utils/numeric.h"
#include "utils/uuid.h"
#include "utils/array.h"
#include "fmgr.h"
#include "utils/lsyscache.h"
#include "utils/syscache.h"
#include "utils/date.h"
#include "utils/timestamp.h"
}

#include "pgduckdb/pgduckdb.h"
Expand Down Expand Up @@ -707,6 +710,47 @@ ConvertDecimal(const NumericVar &numeric) {
return (NumericIsNegative(numeric) ? -base_res : base_res);
}

/*
 * Convert a Postgres Datum to a DuckDB Value. This is meant to be used to
 * convert query parameters in a prepared statement to their DuckDB
 * equivalent. Passing it a Datum that is stored on disk results in undefined
 * behavior, because this function makes no effort to detoast the Datum.
 *
 * value:         the non-NULL parameter value.
 * postgres_type: the Postgres type OID describing how to interpret value.
 *
 * Supported types: bool, int2, int4, int8, bpchar, text, json, varchar, date,
 * timestamp, float4 and float8. Any other type raises a Postgres ERROR.
 */
duckdb::Value
ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type) {
	switch (postgres_type) {
	case BOOLOID:
		return duckdb::Value::BOOLEAN(DatumGetBool(value));
	case INT2OID:
		return duckdb::Value::SMALLINT(DatumGetInt16(value));
	case INT4OID:
		return duckdb::Value::INTEGER(DatumGetInt32(value));
	case INT8OID:
		return duckdb::Value::BIGINT(DatumGetInt64(value));
	case BPCHAROID:
	case TEXTOID:
	case JSONOID:
	case VARCHAROID: {
		// FIXME: TextDatumGetCString allocates so it needs a
		// guard, but it's a macro not a function, so our current guard
		// template does not handle it.
		return duckdb::Value(TextDatumGetCString(value));
	}
	case DATEOID:
		/* Shift the date by PGDUCKDB_DUCK_DATE_OFFSET to move between the two systems' epochs. */
		return duckdb::Value::DATE(duckdb::date_t(DatumGetDateADT(value) + PGDUCKDB_DUCK_DATE_OFFSET));
	case TIMESTAMPOID:
		/* Same idea as for dates, using the timestamp offset. */
		return duckdb::Value::TIMESTAMP(duckdb::timestamp_t(DatumGetTimestamp(value) + PGDUCKDB_DUCK_TIMESTAMP_OFFSET));
	case FLOAT4OID: {
		return duckdb::Value::FLOAT(DatumGetFloat4(value));
	}
	case FLOAT8OID: {
		return duckdb::Value::DOUBLE(DatumGetFloat8(value));
	}
	default:
		/* Oid is an unsigned integer type, so it is formatted with %u rather than %d. */
		elog(ERROR, "Could not convert Postgres parameter of type: %u to DuckDB type", postgres_type);
	}
}

void
ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset) {
auto &type = result.GetType();
Expand Down
16 changes: 15 additions & 1 deletion test/pycheck/explain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def test_explain(cur: Cursor):
cur.sql("CREATE TABLE test_table (id int primary key, name text)")
cur.sql("CREATE TABLE test_table (id int, name text)")
result = cur.sql("EXPLAIN SELECT count(*) FROM test_table")
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
Expand All @@ -13,3 +13,17 @@ def test_explain(cur: Cursor):
assert "Query Profiling Information" in plan
assert "UNGROUPED_AGGREGATE" in plan
assert "Total Time:" in plan

result = cur.sql("EXPLAIN SELECT count(*) FROM test_table where id = %s", (1,))
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
assert "id=1 AND id IS NOT NULL" in plan
assert "Total Time:" not in plan

result = cur.sql(
"EXPLAIN ANALYZE SELECT count(*) FROM test_table where id = %s", (1,)
)
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
assert "id=1 AND id IS NOT NULL" in plan
assert "Total Time:" in plan
Loading

0 comments on commit 007d882

Please sign in to comment.