Skip to content

Commit

Permalink
Handle C++ exceptions in PG hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-- committed Oct 10, 2024
1 parent 303954a commit 692de3b
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 98 deletions.
13 changes: 12 additions & 1 deletion include/pgduckdb/pgduckdb_duckdb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,24 @@ class DuckDBManager {
static inline DuckDBManager &
Get() {
static DuckDBManager instance;
if (!instance.database) {
instance.Initialize();
}
return instance;
}

duckdb::unique_ptr<duckdb::Connection> GetConnection();
static inline duckdb::DuckDB &
GetDatabase() {
return *Get().database;
}

static duckdb::unique_ptr<duckdb::Connection>
CreateConnection();
private:
DuckDBManager();

void Initialize();

void InitializeDatabase();
bool CheckSecretsSeq();
void LoadSecrets(duckdb::ClientContext &);
Expand Down
10 changes: 10 additions & 0 deletions include/pgduckdb/pgduckdb_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ extern "C" {

#include "duckdb/common/exception.hpp"
#include "duckdb/common/error_data.hpp"
#include "pgduckdb/pgduckdb_duckdb.hpp"

#include <vector>
#include <string>
Expand Down Expand Up @@ -101,4 +102,13 @@ DuckDBFunctionGuard(FuncType duckdb_function, const char* function_name, FuncArg
std::abort(); // Cannot reach.
}

inline duckdb::unique_ptr<duckdb::QueryResult>
DuckDBQueryOrThrow(duckdb::ClientContext &context, const std::string &query) {
auto res = context.Query(query, false);
if (res->HasError()) {
res->ThrowError();
}
return res;
}

} // namespace pgduckdb
4 changes: 3 additions & 1 deletion src/catalog/pgduckdb_transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ SchemaItems::GetTable(const string &entry_name) {
if (it != tables.end()) {
return it->second.get();
}

auto snapshot = schema->snapshot;
auto &catalog = schema->catalog;

Expand All @@ -59,7 +60,7 @@ SchemaItems::GetTable(const string &entry_name) {
// Check if the Relation is a VIEW
auto tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(rel_oid));
if (!HeapTupleIsValid(tuple)) {
elog(ERROR, "Cache lookup failed for relation %u", rel_oid);
throw std::runtime_error("Cache lookup failed for relation " + std::to_string(rel_oid));
}

auto relForm = (Form_pg_class)GETSTRUCT(tuple);
Expand All @@ -71,6 +72,7 @@ SchemaItems::GetTable(const string &entry_name) {
// will get bound again and hit a PostgresIndexTable / PostgresHeapTable.
return nullptr;
}

ReleaseSysCache(tuple);

::Relation rel = PostgresTable::OpenRelation(rel_oid);
Expand Down
6 changes: 3 additions & 3 deletions src/pgduckdb_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ extern "C" {
*/
void
DuckdbTruncateTable(Oid relation_oid) {
auto connection = pgduckdb::DuckDBManager::Get().GetConnection();
auto connection = pgduckdb::DuckDBManager::CreateConnection();
auto &context = *connection->context;
auto query = std::string("TRUNCATE ") + pgduckdb_relation_name(relation_oid);
auto result = context.Query(query, false);
Expand Down Expand Up @@ -162,7 +162,7 @@ duckdb_create_table_trigger(PG_FUNCTION_ARGS) {

std::string query_string(pgduckdb_get_tabledef(relid));

auto connection = pgduckdb::DuckDBManager::Get().GetConnection();
auto connection = pgduckdb::DuckDBManager::CreateConnection();
auto &context = *connection->context;
auto result = context.Query(query_string, false);
if (result->HasError()) {
Expand Down Expand Up @@ -230,7 +230,7 @@ duckdb_drop_table_trigger(PG_FUNCTION_ARGS) {
*/
PreventInTransactionBlock(true, "DuckDB queries");

auto connection = pgduckdb::DuckDBManager::Get().GetConnection();
auto connection = pgduckdb::DuckDBManager::CreateConnection();
auto &context = *connection->context;

auto result = context.Query("BEGIN TRANSACTION", false);
Expand Down
108 changes: 50 additions & 58 deletions src/pgduckdb_duckdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,58 +30,59 @@ extern "C" {
#include <sys/stat.h>
#include <unistd.h>
#include <errno.h>
#include <filesystem>

namespace pgduckdb {

static bool
CheckDataDirectory(const char *data_directory) {
CheckDataDirectory(const std::string &data_directory) {
struct stat info;

if (lstat(data_directory, &info) != 0) {
std::ostringstream oss;
if (lstat(data_directory.c_str(), &info) != 0) {
if (errno == ENOENT) {
elog(DEBUG2, "(PGDuckDB/CheckDataDirectory) Directory `%s` doesn't exists", data_directory);
elog(DEBUG2, "(PGDuckDB/CheckDataDirectory) Directory `%s` doesn't exists", data_directory.c_str());
return false;
} else if (errno == EACCES) {
elog(ERROR, "(PGDuckDB/CheckDataDirectory) Can't access `%s` directory", data_directory);
} else {
elog(ERROR, "(PGDuckDB/CheckDataDirectory) Other error when reading `%s`", data_directory);
}
}

if (!S_ISDIR(info.st_mode)) {
elog(WARNING, "(PGDuckDB/CheckDataDirectory) `%s` is not directory", data_directory);
}
oss << "(PGDuckDB/CheckDataDirectory) ";
if (errno == EACCES) {
oss << "Can't access `" << data_directory << "` directory";
} else {
oss << "Other error when reading `" << data_directory << "`: (" << errno << ") " << strerror(errno);
}

if (access(data_directory, R_OK | W_OK)) {
elog(ERROR, "(PGDuckDB/CheckDataDirectory) Directory `%s` permission problem", data_directory);
throw std::runtime_error(oss.str());
} else if (!S_ISDIR(info.st_mode)) {
oss << "(PGDuckDB/CheckDataDirectory) `" << data_directory << "` is not directory";
throw std::runtime_error(oss.str());
} else if (access(data_directory.c_str(), R_OK | W_OK)) {
oss << "(PGDuckDB/CheckDataDirectory) Directory `" << data_directory << "` permission problem";
throw std::runtime_error(oss.str());
}

return true;
}

static std::string
GetExtensionDirectory() {
StringInfo duckdb_extension_data_directory = makeStringInfo();
appendStringInfo(duckdb_extension_data_directory, "%s/duckdb_extensions", DataDir);

if (!CheckDataDirectory(duckdb_extension_data_directory->data)) {
if (mkdir(duckdb_extension_data_directory->data, S_IRWXU | S_IRWXG | S_IRWXO) == -1) {
int error = errno;
pfree(duckdb_extension_data_directory->data);
elog(ERROR,
"(PGDuckDB/GetExtensionDirectory) Creating duckdb extensions directory failed with reason `%s`\n",
strerror(error));
}
elog(DEBUG2, "(PGDuckDB/GetExtensionDirectory) Created %s as `duckdb.data_dir`",
duckdb_extension_data_directory->data);
};
std::ostringstream oss;
oss << DataDir << "/duckdb_extensions";
std::string data_directory = oss.str();

if (!CheckDataDirectory(data_directory)) {
std::filesystem::create_directories(data_directory);
elog(DEBUG2, "(PGDuckDB/GetExtensionDirectory) Created %s as `duckdb.data_dir`", data_directory.c_str());
}

std::string duckdb_extension_directory(duckdb_extension_data_directory->data);
pfree(duckdb_extension_data_directory->data);
return duckdb_extension_directory;
return data_directory;
}

DuckDBManager::DuckDBManager() : secret_table_num_rows(0), secret_table_current_seq(0) {
}

void
DuckDBManager::Initialize() {
elog(DEBUG2, "(PGDuckDB/DuckDBManager) Creating DuckDB instance");

duckdb::DBConfig config;
Expand Down Expand Up @@ -135,34 +136,30 @@ DuckDBManager::LoadSecrets(duckdb::ClientContext &context) {

int secret_id = 0;
for (auto &secret : duckdb_secrets) {
StringInfo secret_key = makeStringInfo();
std::ostringstream query;
bool is_r2_cloud_secret = (secret.type.rfind("R2", 0) == 0);
appendStringInfo(secret_key, "CREATE SECRET pgduckb_secret_%d ", secret_id);
appendStringInfo(secret_key, "(TYPE %s, KEY_ID '%s', SECRET '%s'", secret.type.c_str(), secret.id.c_str(),
secret.secret.c_str());
query << "CREATE SECRET pgduckb_secret_" << std::to_string(secret_id) << " ";
query << "(TYPE " << secret.type << ", KEY_ID '" << secret.id << "', SECRET '" << secret.secret << "'";
if (secret.region.length() && !is_r2_cloud_secret) {
appendStringInfo(secret_key, ", REGION '%s'", secret.region.c_str());
query << ", REGION '" << secret.region << "'";
}
if (secret.session_token.length() && !is_r2_cloud_secret) {
appendStringInfo(secret_key, ", SESSION_TOKEN '%s'", secret.session_token.c_str());
query << ", SESSION_TOKEN '" << secret.session_token << "'";
}
if (secret.endpoint.length() && !is_r2_cloud_secret) {
appendStringInfo(secret_key, ", ENDPOINT '%s'", secret.endpoint.c_str());
query << ", ENDPOINT '" << secret.endpoint << "'";
}
if (is_r2_cloud_secret) {
appendStringInfo(secret_key, ", ACCOUNT_ID '%s'", secret.endpoint.c_str());
query << ", ACCOUNT_ID '" << secret.endpoint << "'";
}
if (!secret.use_ssl) {
appendStringInfo(secret_key, ", USE_SSL 'FALSE'");
}
appendStringInfo(secret_key, ");");
auto res = context.Query(secret_key->data, false);
if (res->HasError()) {
elog(ERROR, "(PGDuckDB/LoadSecrets) secret `%s` could not be loaded with DuckDB", secret.id.c_str());
query << ", USE_SSL 'FALSE'";
}
query << ");";

DuckDBQueryOrThrow(context, query.str());

pfree(secret_key->data);
secret_id++;
++secret_id;
}

secret_table_num_rows = secret_id;
Expand All @@ -186,25 +183,20 @@ DuckDBManager::LoadExtensions(duckdb::ClientContext &context) {
auto duckdb_extensions = ReadDuckdbExtensions();

for (auto &extension : duckdb_extensions) {
StringInfo duckdb_extension = makeStringInfo();
if (extension.enabled) {
appendStringInfo(duckdb_extension, "LOAD %s;", extension.name.c_str());
auto res = context.Query(duckdb_extension->data, false);
if (res->HasError()) {
elog(ERROR, "(PGDuckDB/LoadExtensions) `%s` could not be loaded with DuckDB", extension.name.c_str());
}
DuckDBQueryOrThrow(context, "LOAD " + extension.name);
}
pfree(duckdb_extension->data);
}
}

duckdb::unique_ptr<duckdb::Connection>
DuckDBManager::GetConnection() {
auto connection = duckdb::make_uniq<duckdb::Connection>(*database);
if (CheckSecretsSeq()) {
DuckDBManager::CreateConnection() {
auto& instance = Get();
auto connection = duckdb::make_uniq<duckdb::Connection>(GetDatabase());
if (instance.CheckSecretsSeq()) {
auto &context = *connection->context;
DropSecrets(context);
LoadSecrets(context);
instance.DropSecrets(context);
instance.LoadSecrets(context);
}
return connection;
}
Expand Down
20 changes: 17 additions & 3 deletions src/pgduckdb_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ extern "C" {
#include "pgduckdb/utility/copy.hpp"
#include "pgduckdb/vendor/pg_explain.hpp"
#include "pgduckdb/vendor/pg_list.hpp"
#include "pgduckdb/pgduckdb_utils.hpp"

static planner_hook_type prev_planner_hook = NULL;
static ProcessUtility_hook_type prev_process_utility_hook = NULL;
Expand Down Expand Up @@ -164,7 +165,7 @@ IsAllowedStatement(Query *query, bool throw_error = false) {
}

static PlannedStmt *
DuckdbPlannerHook(Query *parse, const char *query_string, int cursor_options, ParamListInfo bound_params) {
DuckdbPlannerHook_Unsafe(Query *parse, const char *query_string, int cursor_options, ParamListInfo bound_params) {
if (pgduckdb::IsExtensionRegistered()) {
if (duckdb_execution && IsAllowedStatement(parse)) {
PlannedStmt *duckdbPlan = DuckdbPlanNode(parse, cursor_options);
Expand All @@ -190,9 +191,15 @@ DuckdbPlannerHook(Query *parse, const char *query_string, int cursor_options, Pa
}
}

static PlannedStmt *
DuckdbPlannerHook(Query *parse, const char *query_string, int cursor_options, ParamListInfo bound_params) {
return pgduckdb::DuckDBFunctionGuard<PlannedStmt *>(DuckdbPlannerHook_Unsafe, __FUNCTION__, parse, query_string, cursor_options, bound_params);
}

static void
DuckdbUtilityHook(PlannedStmt *pstmt, const char *query_string, bool read_only_tree, ProcessUtilityContext context,
ParamListInfo params, struct QueryEnvironment *query_env, DestReceiver *dest, QueryCompletion *qc) {
DuckdbUtilityHook_Unsafe(PlannedStmt *pstmt, const char *query_string, bool read_only_tree,
ProcessUtilityContext context, ParamListInfo params, struct QueryEnvironment *query_env,
DestReceiver *dest, QueryCompletion *qc) {
Node *parsetree = pstmt->utilityStmt;
if (duckdb_execution && pgduckdb::IsExtensionRegistered() && IsA(parsetree, CopyStmt)) {
uint64 processed;
Expand All @@ -215,6 +222,13 @@ DuckdbUtilityHook(PlannedStmt *pstmt, const char *query_string, bool read_only_t
}
}

static void
DuckdbUtilityHook(PlannedStmt *pstmt, const char *query_string, bool read_only_tree, ProcessUtilityContext context,
ParamListInfo params, struct QueryEnvironment *query_env, DestReceiver *dest, QueryCompletion *qc) {
pgduckdb::DuckDBFunctionGuard<void>(DuckdbUtilityHook_Unsafe, __FUNCTION__, pstmt, query_string, read_only_tree, context, params,
query_env, dest, qc);
}

extern "C" {
#include "nodes/print.h"
}
Expand Down
Loading

0 comments on commit 692de3b

Please sign in to comment.