generated from duckdb/extension-template
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support prepared statements and parameters
- Loading branch information
1 parent
bd381c1
commit 88d25a3
Showing
18 changed files
with
895 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "httpserver_extension/state.hpp" | ||
#include <string> | ||
#include <vector> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
namespace duckdb_httpserver { | ||
|
||
// Base64 decoding function | ||
static std::string base64_decode(const std::string &in) { | ||
std::string out; | ||
std::vector<int> T(256, -1); | ||
for (int i = 0; i < 64; i++) | ||
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i; | ||
|
||
int val = 0, valb = -8; | ||
for (unsigned char c : in) { | ||
if (T[c] == -1) break; | ||
val = (val << 6) + T[c]; | ||
valb += 6; | ||
if (valb >= 0) { | ||
out.push_back(char((val >> valb) & 0xFF)); | ||
valb -= 8; | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
// Check authentication | ||
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) { | ||
if (global_state.auth_token.empty()) { | ||
return; // No authentication required if no token is set | ||
} | ||
|
||
// Check for X-API-Key header | ||
auto api_key = req.get_header_value("X-API-Key"); | ||
if (!api_key.empty() && api_key == global_state.auth_token) { | ||
return; | ||
} | ||
|
||
// Check for Basic Auth | ||
auto auth = req.get_header_value("Authorization"); | ||
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) { | ||
std::string decoded_auth = base64_decode(auth.substr(6)); | ||
if (decoded_auth == global_state.auth_token) { | ||
return; | ||
} | ||
} | ||
|
||
throw HttpHandlerException(401, "Unauthorized"); | ||
} | ||
|
||
} // namespace duckdb_httpserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "duckdb.hpp" | ||
#include "yyjson.hpp" | ||
#include <string> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
using namespace duckdb; | ||
using namespace duckdb_yyjson; | ||
|
||
namespace duckdb_httpserver { | ||
|
||
static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) { | ||
if (!yyjson_is_obj(parameterVal)) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` must be an object"); | ||
} | ||
|
||
auto typeVal = yyjson_obj_get(parameterVal, "type"); | ||
if (!typeVal) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field"); | ||
} | ||
if (!yyjson_is_str(typeVal)) { | ||
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string"); | ||
} | ||
auto type = std::string(yyjson_get_str(typeVal)); | ||
|
||
auto valueVal = yyjson_obj_get(parameterVal, "value"); | ||
if (!valueVal) { | ||
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field"); | ||
} | ||
|
||
if (type == "TEXT") { | ||
if (!yyjson_is_str(valueVal)) { | ||
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string"); | ||
} | ||
|
||
return BoundParameterData(Value(yyjson_get_str(valueVal))); | ||
} | ||
else if (type == "BOOLEAN") { | ||
if (!yyjson_is_bool(valueVal)) { | ||
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean"); | ||
} | ||
|
||
return BoundParameterData(Value(bool(yyjson_get_bool(valueVal)))); | ||
} | ||
|
||
throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`"); | ||
} | ||
|
||
case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_val* parametersVal) { | ||
if (!parametersVal || !yyjson_is_obj(parametersVal)) { | ||
throw HttpHandlerException(400, "The `parameters` field must be an object"); | ||
} | ||
|
||
case_insensitive_map_t<BoundParameterData> named_values; | ||
|
||
size_t idx, max; | ||
yyjson_val *parameterKeyVal, *parameterVal; | ||
yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) { | ||
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal)); | ||
|
||
named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal); | ||
} | ||
|
||
return named_values; | ||
} | ||
|
||
} // namespace duckdb_httpserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
#include "httpserver_extension/http_handler/authentication.hpp" | ||
#include "httpserver_extension/http_handler/bindings.hpp" | ||
#include "httpserver_extension/http_handler/common.hpp" | ||
#include "httpserver_extension/http_handler/handler.hpp" | ||
#include "httpserver_extension/http_handler/playground.hpp" | ||
#include "httpserver_extension/http_handler/response_serializer.hpp" | ||
#include "httpserver_extension/state.hpp" | ||
#include "duckdb.hpp" | ||
#include "yyjson.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#define CPPHTTPLIB_OPENSSL_SUPPORT | ||
#include "httplib.hpp" | ||
|
||
using namespace duckdb; | ||
using namespace duckdb_yyjson; | ||
|
||
namespace duckdb_httpserver { | ||
|
||
// Handle both GET and POST requests | ||
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { | ||
try { | ||
// CORS allow | ||
res.set_header("Access-Control-Allow-Origin", "*"); | ||
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); | ||
res.set_header("Access-Control-Allow-Headers", "*"); | ||
res.set_header("Access-Control-Allow-Credentials", "true"); | ||
res.set_header("Access-Control-Max-Age", "86400"); | ||
|
||
// Handle preflight OPTIONS request | ||
if (req.method == "OPTIONS") { | ||
res.status = 204; // No content | ||
return; | ||
} | ||
|
||
CheckAuthentication(req); | ||
|
||
auto queryApiParameters = ExtractQueryApiParameters(req); | ||
|
||
if (!queryApiParameters.sqlQueryOpt.has_value()) { | ||
res.status = 200; | ||
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html"); | ||
return; | ||
} | ||
|
||
if (!global_state.db_instance) { | ||
throw IOException("Database instance not initialized"); | ||
} | ||
|
||
auto start = std::chrono::system_clock::now(); | ||
auto result = ExecuteQuery(req, queryApiParameters); | ||
auto end = std::chrono::system_clock::now(); | ||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start); | ||
|
||
QueryExecStats stats { | ||
static_cast<float>(elapsed.count()) / 1000, | ||
0, | ||
0 | ||
}; | ||
|
||
// Format output | ||
if (queryApiParameters.outputFormat == OutputFormat::Ndjson) { | ||
std::string json_output = ConvertResultToNDJSON(*result); | ||
res.set_content(json_output, "application/x-ndjson"); | ||
} | ||
else { | ||
auto json_output = ConvertResultToJSON(*result, stats); | ||
res.set_content(json_output, "application/json"); | ||
} | ||
} | ||
catch (const HttpHandlerException& ex) { | ||
res.status = ex.status; | ||
res.set_content(ex.message, "text/plain"); | ||
} | ||
catch (const std::exception& ex) { | ||
res.status = 500; | ||
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what()); | ||
res.set_content(error_message, "text/plain"); | ||
} | ||
} | ||
|
||
// Execute query (optionally using a prepared statement) | ||
std::unique_ptr<MaterializedQueryResult> ExecuteQuery( | ||
const duckdb_httplib_openssl::Request& req, | ||
const QueryApiParameters& queryApiParameters | ||
) { | ||
Connection con(*global_state.db_instance); | ||
std::unique_ptr<MaterializedQueryResult> result; | ||
auto query = queryApiParameters.sqlQueryOpt.value(); | ||
|
||
auto use_prepared_stmt = | ||
queryApiParameters.sqlParametersOpt.has_value() && | ||
queryApiParameters.sqlParametersOpt.value().empty() == false; | ||
|
||
if (use_prepared_stmt) { | ||
auto prepared_stmt = con.Prepare(query); | ||
if (prepared_stmt->HasError()) { | ||
throw HttpHandlerException(500, prepared_stmt->GetError()); | ||
} | ||
|
||
auto named_values = queryApiParameters.sqlParametersOpt.value(); | ||
|
||
auto prepared_stmt_result = prepared_stmt->Execute(named_values); | ||
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT); | ||
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize(); | ||
} else { | ||
result = con.Query(query); | ||
} | ||
|
||
if (result->HasError()) { | ||
throw HttpHandlerException(500, result->GetError()); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) { | ||
if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") { | ||
return ExtractQueryApiParametersComplex(req); | ||
} | ||
else { | ||
return QueryApiParameters { | ||
ExtractSqlQuerySimple(req), | ||
std::nullopt, | ||
ExtractOutputFormatSimple(req), | ||
}; | ||
} | ||
} | ||
|
||
std::optional<std::string> ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) { | ||
// Check if the query is in the URL parameters | ||
if (req.has_param("query")) { | ||
return req.get_param_value("query"); | ||
} | ||
else if (req.has_param("q")) { | ||
return req.get_param_value("q"); | ||
} | ||
|
||
// If not in URL, and it's a POST request, check the body | ||
else if (req.method == "POST" && !req.body.empty()) { | ||
return req.body; | ||
} | ||
|
||
return std::nullopt; | ||
} | ||
|
||
OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) { | ||
// Check for format in URL parameter or header | ||
if (req.has_param("default_format")) { | ||
return ParseOutputFormat(req.get_param_value("default_format")); | ||
} | ||
else if (req.has_header("X-ClickHouse-Format")) { | ||
return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format")); | ||
} | ||
else if (req.has_header("format")) { | ||
return ParseOutputFormat(req.get_header_value("format")); | ||
} | ||
else { | ||
return OutputFormat::Ndjson; | ||
} | ||
} | ||
|
||
OutputFormat ParseOutputFormat(const std::string& formatStr) { | ||
if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") { | ||
return OutputFormat::Ndjson; | ||
} | ||
else if (formatStr == "JSONCompact") { | ||
return OutputFormat::Json; | ||
} | ||
else { | ||
throw HttpHandlerException(400, "Unknown format"); | ||
} | ||
} | ||
|
||
QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) { | ||
yyjson_doc *bodyDoc = nullptr; | ||
|
||
try { | ||
auto bodyJson = req.body; | ||
auto bodyJsonCStr = bodyJson.c_str(); | ||
bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0); | ||
|
||
return ExtractQueryApiParametersComplexImpl(bodyDoc); | ||
} | ||
catch (const std::exception& exception) { | ||
yyjson_doc_free(bodyDoc); | ||
throw; | ||
} | ||
} | ||
|
||
QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) { | ||
if (!bodyDoc) { | ||
throw HttpHandlerException(400, "Unable to parse the request body"); | ||
} | ||
|
||
auto bodyRoot = yyjson_doc_get_root(bodyDoc); | ||
if (!yyjson_is_obj(bodyRoot)) { | ||
throw HttpHandlerException(400, "The request body must be an object"); | ||
} | ||
|
||
auto queryVal = yyjson_obj_get(bodyRoot, "query"); | ||
if (!queryVal || !yyjson_is_str(queryVal)) { | ||
throw HttpHandlerException(400, "The `query` field must be a string"); | ||
} | ||
|
||
auto formatVal = yyjson_obj_get(bodyRoot, "format"); | ||
if (!formatVal || !yyjson_is_str(formatVal)) { | ||
throw HttpHandlerException(400, "The `format` field must be a string"); | ||
} | ||
|
||
return QueryApiParameters { | ||
std::string(yyjson_get_str(queryVal)), | ||
ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")), | ||
ParseOutputFormat(std::string(yyjson_get_str(formatVal))), | ||
}; | ||
} | ||
|
||
} // namespace duckdb_httpserver |
Oops, something went wrong.