Skip to content

Commit

Permalink
Make DuckDB a singleton per connection process
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-- authored and JelteF committed Sep 12, 2024
1 parent 9dd113e commit 7a5087a
Show file tree
Hide file tree
Showing 15 changed files with 411 additions and 50 deletions.
25 changes: 24 additions & 1 deletion include/pgduckdb/pgduckdb_duckdb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,30 @@ extern "C" {

namespace pgduckdb {

duckdb::unique_ptr<duckdb::DuckDB> DuckdbOpenDatabase();
/*
 * Owns the DuckDB database instance used by this process.
 *
 * The constructor is private, so the only way to obtain a manager is Get(),
 * which hands out a function-local static. The database is therefore created
 * on the first call to Get() and shared by every later caller.
 */
class DuckDBManager {
public:
/* Returns the single manager instance, constructing it on the first call. */
static inline const DuckDBManager &
Get() {
static DuckDBManager instance;
return instance;
}

/* Returns a reference to the database owned by this manager. */
inline duckdb::DuckDB &
GetDatabase() const {
return *database;
}

private:
/* Private so that instances can only be created through Get(). */
DuckDBManager();
/* NOTE(review): no definition or caller is visible in this change -- confirm it is still needed. */
void InitializeDatabase();

/* Setup steps invoked by the constructor with the context of a connection to the new database. */
void LoadSecrets(duckdb::ClientContext &);
void LoadExtensions(duckdb::ClientContext &);
void LoadFunctions(duckdb::ClientContext &);

/* The database handed out by GetDatabase(); assigned in the constructor. */
duckdb::unique_ptr<duckdb::DuckDB> database;
};

duckdb::unique_ptr<duckdb::Connection> DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info,
List *needed_columns, const char *query);

Expand Down
6 changes: 6 additions & 0 deletions include/pgduckdb/pgduckdb_planner.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#pragma once

#include "duckdb.hpp"

extern "C" {
#include "postgres.h"
#include "optimizer/planner.h"
}

#include "pgduckdb/pgduckdb_duckdb.hpp"

extern bool duckdb_explain_analyze;

PlannedStmt *DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params);
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(Query *query, ParamListInfo bound_params);
1 change: 1 addition & 0 deletions include/pgduckdb/pgduckdb_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ constexpr int64_t PGDUCKDB_DUCK_TIMESTAMP_OFFSET = INT64CONST(10957) * USECS_PER

duckdb::LogicalType ConvertPostgresToDuckColumnType(Form_pg_attribute &attribute);
Oid GetPostgresDuckDBType(duckdb::LogicalType type);
duckdb::Value ConvertPostgresToDuckValue(Datum value, Oid postgres_type);
void ConvertPostgresToDuckValue(Datum value, duckdb::Vector &result, idx_t offset);
void ConvertDuckToPostgresValue(TupleTableSlot *slot, duckdb::Value &value, idx_t col);
void InsertTupleIntoChunk(duckdb::DataChunk &output, duckdb::shared_ptr<PostgresScanGlobalState> scan_global_state,
Expand Down
6 changes: 3 additions & 3 deletions include/pgduckdb/scan/postgres_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class PostgresScanLocalState {
bool m_exhausted_scan;
};

struct PostgresReplacementScanData : public duckdb::ReplacementScanData {
struct PostgresReplacementScanDataClientContextState : public duckdb::ClientContextState {
public:
PostgresReplacementScanData(List *rtables, PlannerInfo *query_planner_info, List *needed_columns,
PostgresReplacementScanDataClientContextState(List *rtables, PlannerInfo *query_planner_info, List *needed_columns,
const char *query_string)
: m_rtables(rtables), m_query_planner_info(query_planner_info), m_needed_columns(needed_columns),
m_query_string(query_string) {
}
~PostgresReplacementScanData() override {};
~PostgresReplacementScanDataClientContextState() override {};

public:
List *m_rtables;
Expand Down
48 changes: 30 additions & 18 deletions src/pgduckdb_duckdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,23 @@ GetExtensionDirectory() {
return duckdb_extension_directory;
}

duckdb::unique_ptr<duckdb::DuckDB>
DuckdbOpenDatabase() {
DuckDBManager::DuckDBManager() {
elog(DEBUG2, "Creating DuckDB instance");

duckdb::DBConfig config;
config.SetOptionByName("extension_directory", GetExtensionDirectory());
return duckdb::make_uniq<duckdb::DuckDB>(nullptr, &config);
}

duckdb::unique_ptr<duckdb::Connection>
DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info, List *needed_columns, const char *query) {
auto db = DuckdbOpenDatabase();

/* Add tables */
db->instance->config.replacement_scans.emplace_back(
pgduckdb::PostgresReplacementScan,
duckdb::make_uniq_base<duckdb::ReplacementScanData, PostgresReplacementScanData>(rtables, planner_info,
needed_columns, query));
database = duckdb::make_uniq<duckdb::DuckDB>(nullptr, &config);

auto connection = duckdb::make_uniq<duckdb::Connection>(*db);
auto connection = duckdb::make_uniq<duckdb::Connection>(*database);

// Add the postgres_scan inserted by the replacement scan
auto &context = *connection->context;
LoadFunctions(context);
LoadSecrets(context);
LoadExtensions(context);
}

void
DuckDBManager::LoadFunctions(duckdb::ClientContext &context) {
pgduckdb::PostgresSeqScanFunction seq_scan_fun;
duckdb::CreateTableFunctionInfo seq_scan_info(seq_scan_fun);

Expand All @@ -93,12 +88,15 @@ DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info, List *needed_co

auto &catalog = duckdb::Catalog::GetSystemCatalog(context);
context.transaction.BeginTransaction();
auto &instance = *db->instance;
auto &instance = *database->instance;
duckdb::ExtensionUtil::RegisterType(instance, "UnsupportedPostgresType", duckdb::LogicalTypeId::VARCHAR);
catalog.CreateTableFunction(context, &seq_scan_info);
catalog.CreateTableFunction(context, &index_scan_info);
context.transaction.Commit();
}

void
DuckDBManager::LoadSecrets(duckdb::ClientContext &context) {
auto duckdb_secrets = ReadDuckdbSecrets();

int secret_id = 0;
Expand Down Expand Up @@ -128,7 +126,10 @@ DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info, List *needed_co
pfree(secret_key->data);
secret_id++;
}
}

void
DuckDBManager::LoadExtensions(duckdb::ClientContext &context) {
auto duckdb_extensions = ReadDuckdbExtensions();

for (auto &extension : duckdb_extensions) {
Expand All @@ -142,8 +143,19 @@ DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info, List *needed_co
}
pfree(duckdb_extension->data);
}
}

return connection;
/*
 * Creates a DuckDB connection on the shared database for a single Postgres
 * query.
 *
 * The per-query inputs (range tables, planner info, needed columns and the
 * query string) are stored on the new connection's client context under the
 * "postgres_scan" key, so they are scoped to this connection rather than to
 * the database.
 *
 * Returns the new connection; the caller owns it.
 */
duckdb::unique_ptr<duckdb::Connection>
DuckdbCreateConnection(List *rtables, PlannerInfo *planner_info, List *needed_columns, const char *query) {
	auto &db = DuckDBManager::Get().GetDatabase();

	/*
	 * Add DuckDB replacement scan to read PG data.
	 *
	 * The replacement scan list lives in the database configuration and the
	 * database is shared by every connection created here, so the scan is
	 * registered only on the first call. Registering it on every call would
	 * add one duplicate entry per query for as long as the database exists.
	 */
	static bool replacement_scan_registered = false;
	if (!replacement_scan_registered) {
		db.instance->config.replacement_scans.emplace_back(pgduckdb::PostgresReplacementScan);
		/* Set only after a successful insert so a failed attempt is retried. */
		replacement_scan_registered = true;
	}

	auto con = duckdb::make_uniq<duckdb::Connection>(db);
	con->context->registered_state->Insert("postgres_scan",
	                                       duckdb::make_shared_ptr<PostgresReplacementScanDataClientContextState>(
	                                           rtables, planner_info, needed_columns, query));
	return con;
}

} // namespace pgduckdb
2 changes: 2 additions & 0 deletions src/pgduckdb_hooks.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "duckdb.hpp"

extern "C" {
#include "postgres.h"
#include "catalog/pg_namespace.h"
Expand Down
55 changes: 50 additions & 5 deletions src/pgduckdb_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
extern "C" {
#include "postgres.h"
#include "miscadmin.h"
#include "tcop/pquery.h"
#include "nodes/params.h"
#include "utils/ruleutils.h"
}

#include "pgduckdb/pgduckdb_node.hpp"
#include "pgduckdb/pgduckdb_types.hpp"
#include "pgduckdb/pgduckdb_duckdb.hpp"
#include "pgduckdb/pgduckdb_planner.hpp"

/* global variables */
CustomScanMethods duckdb_scan_scan_methods;
Expand All @@ -17,6 +22,8 @@ static CustomExecMethods duckdb_scan_exec_methods;

typedef struct DuckdbScanState {
CustomScanState css; /* must be first field */
Query *query;
ParamListInfo params;
duckdb::Connection *duckdb_connection;
duckdb::PreparedStatement *prepared_statement;
bool is_executed;
Expand Down Expand Up @@ -46,17 +53,29 @@ static Node *
Duckdb_CreateCustomScanState(CustomScan *cscan) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)newNode(sizeof(DuckdbScanState), T_CustomScanState);
CustomScanState *custom_scan_state = &duckdb_scan_state->css;
duckdb_scan_state->duckdb_connection = (duckdb::Connection *)linitial(cscan->custom_private);
duckdb_scan_state->prepared_statement = (duckdb::PreparedStatement *)lsecond(cscan->custom_private);
duckdb_scan_state->is_executed = false;
duckdb_scan_state->fetch_next = true;

duckdb_scan_state->query = (Query *)linitial(cscan->custom_private);
/* FIXME: We should pass a sensible bound_params, this breaks prepared statements */
custom_scan_state->methods = &duckdb_scan_exec_methods;
return (Node *)custom_scan_state;
}

/*
 * Begin-scan callback for the DuckDB custom scan node.
 *
 * Prepares the stored query again through DuckdbPrepare, using the parameter
 * list of the executor state, and stores the resulting connection, prepared
 * statement and parameters in the scan state for later execution.
 *
 * Raises a Postgres ERROR if DuckDB fails to prepare the query.
 */
void
Duckdb_BeginCustomScan(CustomScanState *cscanstate, EState *estate, int eflags) {
	DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)cscanstate;
	auto prepare_result = DuckdbPrepare(duckdb_scan_state->query, estate->es_param_list_info);
	auto prepared_query = std::move(std::get<0>(prepare_result));
	auto duckdb_connection = std::move(std::get<1>(prepare_result));

	if (prepared_query->HasError()) {
		/*
		 * elog(ERROR) does not return to this frame, so the destructors of
		 * the smart pointers above would never run. Copy the message into
		 * palloc'd memory and release the DuckDB objects explicitly before
		 * reporting, statement first and then the connection it belongs to.
		 */
		char *error_message = pstrdup(prepared_query->GetError().c_str());
		prepared_query.reset();
		duckdb_connection.reset();
		elog(ERROR, "DuckDB re-planning failed %s", error_message);
	}

	/* Ownership moves to the scan state, which holds raw pointers. */
	duckdb_scan_state->duckdb_connection = duckdb_connection.release();
	duckdb_scan_state->prepared_statement = prepared_query.release();
	duckdb_scan_state->params = estate->es_param_list_info;
	duckdb_scan_state->is_executed = false;
	duckdb_scan_state->fetch_next = true;
	duckdb_scan_state->css.ss.ps.ps_ResultTupleDesc = duckdb_scan_state->css.ss.ss_ScanTupleSlot->tts_tupleDescriptor;
	/* NOTE(review): presumably paired with a RESUME_CANCEL_INTERRUPTS() when the scan ends -- confirm. */
	HOLD_CANCEL_INTERRUPTS();
}
Expand All @@ -66,8 +85,34 @@ ExecuteQuery(DuckdbScanState *state) {
auto &prepared = *state->prepared_statement;
auto &query_results = state->query_results;
auto &connection = state->duckdb_connection;
auto pg_params = state->params;
duckdb::vector<duckdb::Value> duckdb_params;
if (pg_params) {
for (int i = 0; i < pg_params->numParams; i++) {
ParamExternData *pg_param;
ParamExternData tmp_workspace;

/* give hook a chance in case parameter is dynamic */
if (pg_params->paramFetch != NULL)
pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace);
else
pg_param = &pg_params->params[i];

if (pg_param->isnull) {
duckdb_params.push_back(duckdb::Value());
} else {
if (!OidIsValid(pg_param->ptype)) {
elog(ERROR, "parameter with invalid type during execution");
}
duckdb_params.push_back(pgduckdb::ConvertPostgresToDuckValue(pg_param->value, pg_param->ptype));
}
}
}

auto pending = prepared.PendingQuery();
auto pending = prepared.PendingQuery(duckdb_params, true);
if (pending->HasError()) {
elog(ERROR, "Duckdb execute returned an error: %s", pending->GetError().c_str());
}
duckdb::PendingExecutionResult execution_result;
do {
execution_result = pending->ExecuteTask();
Expand Down
4 changes: 2 additions & 2 deletions src/pgduckdb_options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ ReadDuckdbExtensions() {

static bool
DuckdbInstallExtension(Datum name) {
auto db = DuckdbOpenDatabase();
auto connection = duckdb::make_uniq<duckdb::Connection>(*db);
auto &db = DuckDBManager::Get().GetDatabase();
auto connection = duckdb::make_uniq<duckdb::Connection>(db);
auto &context = *connection->context;

auto extension_name = DatumToString(name);
Expand Down
52 changes: 35 additions & 17 deletions src/pgduckdb_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern "C" {
#include "postgres.h"
#include "catalog/pg_type.h"
#include "nodes/makefuncs.h"
#include "nodes/nodes.h"
#include "nodes/params.h"
#include "optimizer/optimizer.h"
#include "tcop/pquery.h"
#include "utils/syscache.h"
Expand Down Expand Up @@ -51,21 +53,46 @@ PlanQuery(Query *parse, ParamListInfo bound_params) {
);
}

static Plan *
CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {
std::tuple<duckdb::unique_ptr<duckdb::PreparedStatement>, duckdb::unique_ptr<duckdb::Connection>>
DuckdbPrepare(Query *query, ParamListInfo bound_params) {
const char *query_string = pgduckdb_pg_get_querydef(query, false);

List *rtables = query->rtable;
/* TODO: Move this state into custom_private probably */
if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}

List *rtables = query->rtable;
/* Extract required vars for table */
int flags = PVC_RECURSE_AGGREGATES | PVC_RECURSE_WINDOWFUNCS | PVC_RECURSE_PLACEHOLDERS;
List *vars = list_concat(pull_var_clause((Node *)query->targetList, flags),
pull_var_clause((Node *)query->jointree->quals, flags));

PlannerInfo *query_planner_info = PlanQuery(query, bound_params);
auto duckdb_connection = pgduckdb::DuckdbCreateConnection(rtables, query_planner_info, vars, query_string);
auto context = duckdb_connection->context;

auto prepared_query = context->Prepare(query_string);
return {std::move(prepared_query), std::move(duckdb_connection)};
}

static Plan *
CreatePlan(Query *query, ParamListInfo bound_params) {
/*
* Copy the arguments so we can attach unmodified versions to the
* custom_private field at the end of the function. DuckdbPrepare will
* slightly modify the query tree, because it calls subquery_planner, and
* that slightly modifies the query tree.
*/
auto query_copy = (Query *)copyObjectImpl(query);

/*
* Prepare the query, so we can get the returned types and column names.
*/
auto prepare_result = DuckdbPrepare(query, bound_params);
auto prepared_query = std::move(std::get<0>(prepare_result));

if (prepared_query->HasError()) {
elog(WARNING, "(DuckDB) %s", prepared_query->GetError().c_str());
Expand Down Expand Up @@ -93,30 +120,21 @@ CreatePlan(Query *query, const char *query_string, ParamListInfo bound_params) {

duckdb_node->custom_scan_tlist =
lappend(duckdb_node->custom_scan_tlist,
makeTargetEntry((Expr *)var, i + 1, (char *)prepared_query->GetNames()[i].c_str(), false));
makeTargetEntry((Expr *)var, i + 1, (char *)pstrdup(prepared_query->GetNames()[i].c_str()), false));

ReleaseSysCache(tp);
}

duckdb_node->custom_private = list_make2(duckdb_connection.release(), prepared_query.release());
duckdb_node->custom_private = list_make1(query_copy);
duckdb_node->methods = &duckdb_scan_scan_methods;

return (Plan *)duckdb_node;
}

PlannedStmt *
DuckdbPlanNode(Query *parse, int cursor_options, ParamListInfo bound_params) {
const char *query_string = pgduckdb_pg_get_querydef(parse, false);

if (ActivePortal && ActivePortal->commandTag == CMDTAG_EXPLAIN) {
if (duckdb_explain_analyze) {
query_string = psprintf("EXPLAIN ANALYZE %s", query_string);
} else {
query_string = psprintf("EXPLAIN %s", query_string);
}
}
/* Check whether DuckDB can create a plan for this query */
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, query_string, bound_params));
Plan *duckdb_plan = (Plan *)castNode(CustomScan, CreatePlan(parse, bound_params));

if (!duckdb_plan) {
return nullptr;
Expand Down
Loading

0 comments on commit 7a5087a

Please sign in to comment.