Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support extended protocol and prepared statements #147

Merged
merged 13 commits into from
Sep 23, 2024
6 changes: 6 additions & 0 deletions include/pgduckdb/pgduckdb_planner.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#pragma once

#include "duckdb.hpp"

extern "C" {
#include "postgres.h"
#include "optimizer/planner.h"
}

#include "pgduckdb/pgduckdb_duckdb.hpp"

extern bool duckdb_explain_analyze;

PlannedStmt *DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params);
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(const Query *query, ParamListInfo bound_params);
1 change: 1 addition & 0 deletions include/pgduckdb/pgduckdb_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ constexpr int64_t PGDUCKDB_DUCK_TIMESTAMP_OFFSET = INT64CONST(10957) * USECS_PER

duckdb::LogicalType ConvertPostgresToDuckColumnType(Form_pg_attribute &attribute);
Oid GetPostgresDuckDBType(duckdb::LogicalType type);
duckdb::Value ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type);
void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset);
bool ConvertDuckToPostgresValue(TupleTableSlot *slot, duckdb::Value &value, idx_t col);
void InsertTupleIntoChunk(duckdb::DataChunk &output, duckdb::shared_ptr<PostgresScanGlobalState> scan_global_state,
Expand Down
11 changes: 11 additions & 0 deletions src/pgduckdb_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ extern "C" {
/*
 * ExplainOneQuery hook. Records whether the current EXPLAIN was invoked with
 * ANALYZE (es->analyze) in the global duckdb_explain_analyze, then delegates
 * to the previously installed hook with the unmodified arguments.
 */
void
DuckdbExplainOneQueryHook(Query *query, int cursorOptions, IntoClause *into, ExplainState *es, const char *queryString,
ParamListInfo params, QueryEnvironment *queryEnv) {
/*
 * It might seem sensible to store this data in the custom_private
 * field of the CustomScan node, but that's not a trivial change to make.
 * Storing it in a global variable works fine, as long as we only read
 * the variable during planning while we're actually executing an EXPLAIN
 * query (this can be verified by checking the commandTag of the
 * ActivePortal). This even works when plans would normally be cached,
 * because EXPLAIN queries always run this hook whenever they are
 * executed. EXPLAIN queries are also always re-planned (see
 * standard_ExplainOneQuery).
 */
duckdb_explain_analyze = es->analyze;
prev_explain_one_query_hook(query, cursorOptions, into, es, queryString, params, queryEnv);
}
Expand Down
59 changes: 51 additions & 8 deletions src/pgduckdb_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
extern "C" {
#include "postgres.h"
#include "miscadmin.h"
#include "tcop/pquery.h"
#include "nodes/params.h"
#include "utils/ruleutils.h"
}

#include "pgduckdb/pgduckdb_node.hpp"
#include "pgduckdb/pgduckdb_types.hpp"
#include "pgduckdb/pgduckdb_duckdb.hpp"
#include "pgduckdb/pgduckdb_planner.hpp"

/* global variables */
CustomScanMethods duckdb_scan_scan_methods;
Expand All @@ -17,6 +22,8 @@ static CustomExecMethods duckdb_scan_exec_methods;

typedef struct DuckdbScanState {
CustomScanState css; /* must be first field */
const Query *query;
ParamListInfo params;
duckdb::Connection *duckdb_connection;
duckdb::PreparedStatement *prepared_statement;
bool is_executed;
Expand Down Expand Up @@ -46,17 +53,28 @@ static Node *
Duckdb_CreateCustomScanState(CustomScan *cscan) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)newNode(sizeof(DuckdbScanState), T_CustomScanState);
CustomScanState *custom_scan_state = &duckdb_scan_state->css;
duckdb_scan_state->duckdb_connection = (duckdb::Connection *)linitial(cscan->custom_private);
duckdb_scan_state->prepared_statement = (duckdb::PreparedStatement *)lsecond(cscan->custom_private);
duckdb_scan_state->is_executed = false;
duckdb_scan_state->fetch_next = true;

duckdb_scan_state->query = (const Query *)linitial(cscan->custom_private);
custom_scan_state->methods = &duckdb_scan_exec_methods;
return (Node *)custom_scan_state;
}

/*
 * BeginCustomScan callback for the DuckDB custom scan node.
 *
 * Re-prepares the stored query in DuckDB (via DuckdbPrepare) using the
 * executor's parameter list, then stores the resulting connection, prepared
 * statement and parameter list in the scan state so that execution can bind
 * the parameters later.
 *
 * cscanstate: the custom scan state; really a DuckdbScanState.
 * estate:     executor state; only es_param_list_info is read here.
 * eflags:     executor flags; unused.
 *
 * Raises a Postgres ERROR if DuckDB fails to prepare the query.
 */
void
Duckdb_BeginCustomScan(CustomScanState *cscanstate, EState *estate, int eflags) {
	DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)cscanstate;
	auto prepare_result = DuckdbPrepare(duckdb_scan_state->query, estate->es_param_list_info);
	auto prepared_query = std::move(std::get<0>(prepare_result));
	auto duckdb_connection = std::move(std::get<1>(prepare_result));

	if (prepared_query->HasError()) {
		/*
		 * elog(ERROR) leaves this function through a longjmp, which does not
		 * run C++ destructors. Copy the message into palloc'd memory and
		 * destroy the DuckDB objects explicitly first, so that the prepared
		 * statement and the connection are not leaked on the error path.
		 */
		char *error_message = pstrdup(prepared_query->GetError().c_str());
		prepared_query.reset();
		duckdb_connection.reset();
		elog(ERROR, "DuckDB re-planning failed %s", error_message);
	}

	/* From here on the scan state holds the raw pointers to both objects. */
	duckdb_scan_state->duckdb_connection = duckdb_connection.release();
	duckdb_scan_state->prepared_statement = prepared_query.release();
	duckdb_scan_state->params = estate->es_param_list_info;
	duckdb_scan_state->is_executed = false;
	duckdb_scan_state->fetch_next = true;
	/* The node's result tuple descriptor is the scan slot's descriptor. */
	duckdb_scan_state->css.ss.ps.ps_ResultTupleDesc = duckdb_scan_state->css.ss.ss_ScanTupleSlot->tts_tupleDescriptor;
	/* NOTE(review): the matching RESUME_CANCEL_INTERRUPTS() is not visible here; presumably issued when the scan ends. */
	HOLD_CANCEL_INTERRUPTS();
}
Expand All @@ -66,8 +84,33 @@ ExecuteQuery(DuckdbScanState *state) {
auto &prepared = *state->prepared_statement;
auto &query_results = state->query_results;
auto &connection = state->duckdb_connection;
auto pg_params = state->params;
const auto num_params = pg_params ? pg_params->numParams : 0;
duckdb::vector<duckdb::Value> duckdb_params;
for (int i = 0; i < num_params; i++) {
ParamExternData *pg_param;
ParamExternData tmp_workspace;

/* give hook a chance in case parameter is dynamic */
if (pg_params->paramFetch != NULL)
pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace);
else
pg_param = &pg_params->params[i];

if (pg_param->isnull) {
duckdb_params.push_back(duckdb::Value());
} else {
if (!OidIsValid(pg_param->ptype)) {
elog(ERROR, "parameter with invalid type during execution");
}
duckdb_params.push_back(pgduckdb::ConvertPostgresParameterToDuckValue(pg_param->value, pg_param->ptype));
}
}

auto pending = prepared.PendingQuery();
auto pending = prepared.PendingQuery(duckdb_params, true);
if (pending->HasError()) {
elog(ERROR, "DuckDB execute returned an error: %s", pending->GetError().c_str());
}
duckdb::PendingExecutionResult execution_result;
do {
execution_result = pending->ExecuteTask();
Expand Down Expand Up @@ -160,14 +203,14 @@ Duckdb_ReScanCustomScan(CustomScanState *node) {
void
Duckdb_ExplainCustomScan(CustomScanState *node, List *ancestors, ExplainState *es) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)node;
auto res = duckdb_scan_state->prepared_statement->Execute();
std::string explain_output = "\n\n";
auto chunk = res->Fetch();
ExecuteQuery(duckdb_scan_state);
auto chunk = duckdb_scan_state->query_results->Fetch();
if (!chunk || chunk->size() == 0) {
return;
}
/* Is it safe to hardcode this as result of DuckDB explain? */
auto value = chunk->GetValue(1, 0);
std::string explain_output = "\n\n";
explain_output += value.GetValue<duckdb::string>();
explain_output += "\n";
ExplainPropertyText("DuckDB Execution Plan", explain_output.c_str(), es);
Expand Down
56 changes: 36 additions & 20 deletions src/pgduckdb_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern "C" {
#include "postgres.h"
#include "catalog/pg_type.h"
#include "nodes/makefuncs.h"
#include "nodes/nodes.h"
#include "nodes/params.h"
#include "optimizer/optimizer.h"
#include "tcop/pquery.h"
#include "utils/syscache.h"
Expand Down Expand Up @@ -51,21 +53,44 @@ PlanQuery(Query *parse, ParamListInfo bound_params) {
);
}

static Plan *
CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {
/*
 * Prepare a Postgres query for execution in DuckDB.
 *
 * Works on a copy of `query`: the copy is deparsed to SQL text with
 * pgduckdb_pg_get_querydef, prefixed with EXPLAIN / EXPLAIN ANALYZE when the
 * active portal is running an EXPLAIN command, planned with PlanQuery, and
 * finally prepared on a DuckDB connection created for it.
 *
 * Returns the prepared statement together with the connection it was
 * prepared on. Preparation errors are not raised here; the visible callers
 * check HasError() on the returned statement.
 *
 * NOTE(review): `rtables`, `vars` and `query_planner_info` are each declared
 * twice below (once from `query`, once from `copied_query`). The `query`
 * variants look like leftover lines from the previous version of this
 * function -- confirm which set is intended.
 */
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(const Query *query, ParamListInfo bound_params) {
/*
* Copy the query, so the original one is not modified by the
* subquery_planner call that PlanQuery does.
*/
Query *copied_query = (Query *)copyObjectImpl(query);
const char *query_string = pgduckdb_pg_get_querydef(copied_query, false);

List *rtables = query->rtable;
/* When running under EXPLAIN, make DuckDB explain the query as well. */
if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}

elog(DEBUG2, "(PGDuckDB/DuckdbPrepare) Preparing: %s", query_string);

List *rtables = copied_query->rtable;
/* Extract required vars for table */
int flags = PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS | PVC_RECURSE_PLACEHOLDERS;
List *vars = list_concat(pull_var_clause((Node *)query->targetList, flags),
pull_var_clause((Node *)query->jointree->quals, flags));

PlannerInfo *query_planner_info = PlanQuery(query, bound_params);
List *vars = list_concat(pull_var_clause((Node *)copied_query->targetList, flags),
pull_var_clause((Node *)copied_query->jointree->quals, flags));
PlannerInfo *query_planner_info = PlanQuery(copied_query, bound_params);
auto duckdb_connection = pgduckdb::DuckdbCreateConnection(rtables, query_planner_info, vars, query_string);
auto context = duckdb_connection->context;

/* Ownership of both objects is handed to the caller. */
auto prepared_query = context->Prepare(query_string);
return {std::move(prepared_query), std::move(duckdb_connection)};
}

static Plan *
CreatePlan(Query *query, ParamListInfo bound_params) {
/*
* Prepare the query, so we can get the returned types and column names.
*/
auto prepare_result = DuckdbPrepare(query, bound_params);
auto prepared_query = std::move(std::get<0>(prepare_result));

if (prepared_query->HasError()) {
elog(WARNING, "(PGDuckDB/CreatePlan) Prepared query returned an error: '%s",
Expand Down Expand Up @@ -101,30 +126,21 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {

duckdb_node->custom_scan_tlist =
lappend(duckdb_node->custom_scan_tlist,
makeTargetEntry((Expr *)var, i + 1, (char *)prepared_query->GetNames()[i].c_str(), false));
makeTargetEntry((Expr *)var, i + 1, (char *)pstrdup(prepared_query->GetNames()[i].c_str()), false));

ReleaseSysCache(tp);
}

duckdb_node->custom_private = list_make2(duckdb_connection.release(), prepared_query.release());
duckdb_node->custom_private = list_make1(query);
duckdb_node->methods = &duckdb_scan_scan_methods;

return (Plan *)duckdb_node;
}

PlannedStmt *
DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params) {
const char *query_string = pgduckdb_pg_get_querydef(parse, false);

if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}
/* We need to check whether DuckDB can create a plan for this query */
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, query_string, bound_params));
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, bound_params));

if (!duckdb_plan) {
return nullptr;
Expand Down
44 changes: 44 additions & 0 deletions src/pgduckdb_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ extern "C" {
#include "miscadmin.h"
#include "catalog/pg_type.h"
#include "executor/tuptable.h"
#include "utils/builtins.h"
#include "utils/numeric.h"
#include "utils/uuid.h"
#include "utils/array.h"
#include "fmgr.h"
#include "utils/lsyscache.h"
#include "utils/syscache.h"
#include "utils/date.h"
#include "utils/timestamp.h"
}

#include "pgduckdb/pgduckdb.h"
Expand Down Expand Up @@ -707,6 +710,47 @@ ConvertDecimal(const NumericVar &numeric) {
return (NumericIsNegative(numeric) ? -base_res : base_res);
}

/*
 * Convert a Postgres Datum to a DuckDB Value. This is meant to be used to
 * convert query parameters in a prepared statement to their DuckDB
 * equivalent. Passing it a Datum that is stored on disk results in undefined
 * behavior, because this function makes no effort to detoast the Datum.
 *
 * value:         the parameter value; must not be a SQL NULL (the visible
 *                caller handles NULL parameters itself).
 * postgres_type: the type OID of the parameter.
 *
 * Raises a Postgres ERROR for any type OID that is not handled below.
 */
duckdb::Value
ConvertPostgresParameterToDuckValue(Datum value, Oid postgres_type) {
	switch (postgres_type) {
	case BOOLOID:
		return duckdb::Value::BOOLEAN(DatumGetBool(value));
	case INT2OID:
		return duckdb::Value::SMALLINT(DatumGetInt16(value));
	case INT4OID:
		return duckdb::Value::INTEGER(DatumGetInt32(value));
	case INT8OID:
		return duckdb::Value::BIGINT(DatumGetInt64(value));
	case BPCHAROID:
	case TEXTOID:
	case JSONOID:
	case VARCHAROID: {
		// FIXME: TextDatumGetCString allocates so it needs a
		// guard, but it's a macro not a function, so our current guard
		// template does not handle it.
		return duckdb::Value(TextDatumGetCString(value));
	}
	/*
	 * Dates and timestamps are shifted by a fixed offset between the two
	 * systems' representations.
	 * NOTE(review): infinite dates/timestamps are shifted like any other
	 * value -- confirm that this is the intended mapping.
	 */
	case DATEOID:
		return duckdb::Value::DATE(duckdb::date_t(DatumGetDateADT(value) + PGDUCKDB_DUCK_DATE_OFFSET));
	case TIMESTAMPOID:
		return duckdb::Value::TIMESTAMP(duckdb::timestamp_t(DatumGetTimestamp(value) + PGDUCKDB_DUCK_TIMESTAMP_OFFSET));
	case FLOAT4OID: {
		return duckdb::Value::FLOAT(DatumGetFloat4(value));
	}
	case FLOAT8OID: {
		return duckdb::Value::DOUBLE(DatumGetFloat8(value));
	}
	default:
		/* Oid is an unsigned type, so it has to be printed with %u. */
		elog(ERROR, "Could not convert Postgres parameter of type: %u to DuckDB type", postgres_type);
	}
}

void
ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset) {
auto &type = result.GetType();
Expand Down
16 changes: 15 additions & 1 deletion test/pycheck/explain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def test_explain(cur: Cursor):
cur.sql("CREATE TABLE test_table (id int primary key, name text)")
cur.sql("CREATE TABLE test_table (id int, name text)")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed to change this definition to not create an index because index scans are currently broken, see #183 for details.

result = cur.sql("EXPLAIN SELECT count(*) FROM test_table")
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
Expand All @@ -13,3 +13,17 @@ def test_explain(cur: Cursor):
assert "Query Profiling Information" in plan
assert "UNGROUPED_AGGREGATE" in plan
assert "Total Time:" in plan

result = cur.sql("EXPLAIN SELECT count(*) FROM test_table where id = %s", (1,))
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
assert "id=1 AND id IS NOT NULL" in plan
assert "Total Time:" not in plan

result = cur.sql(
"EXPLAIN ANALYZE SELECT count(*) FROM test_table where id = %s", (1,)
)
plan = "\n".join(result)
assert "UNGROUPED_AGGREGATE" in plan
assert "id=1 AND id IS NOT NULL" in plan
assert "Total Time:" in plan
Loading