Skip to content

Commit

Permalink
Support prepared statements and parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
vhiairrassary committed Dec 17, 2024
1 parent bd381c1 commit c3d2472
Show file tree
Hide file tree
Showing 18 changed files with 894 additions and 326 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ include_directories(
# Embed ./src/assets/index.html as a C++ header
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
)

set(EXTENSION_SOURCES
src/httpserver_extension.cpp
src/http_handler/authentication.cpp
src/http_handler/bindings.cpp
src/http_handler/handler.cpp
src/http_handler/response_serializer.cpp
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
)

Expand All @@ -37,7 +41,9 @@ build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})

include_directories(${OPENSSL_INCLUDE_DIR})
target_link_libraries(${LOADABLE_EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
set_property(TARGET ${LOADABLE_EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
target_link_libraries(${EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
set_property(TARGET ${EXTENSION_NAME} PROPERTY CXX_STANDARD 17)

if(MINGW)
set(WIN_LIBS crypt32 ws2_32 wsock32)
Expand Down
55 changes: 55 additions & 0 deletions src/http_handler/authentication.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/state.hpp"
#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

namespace duckdb_httpserver {

// Base64 decoding function
static std::string base64_decode(const std::string &in) {
std::string out;
std::vector<int> T(256, -1);
for (int i = 0; i < 64; i++)
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;

int val = 0, valb = -8;
for (unsigned char c : in) {
if (T[c] == -1) break;
val = (val << 6) + T[c];
valb += 6;
if (valb >= 0) {
out.push_back(char((val >> valb) & 0xFF));
valb -= 8;
}
}
return out;
}

// Check authentication
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
if (global_state.auth_token.empty()) {
return; // No authentication required if no token is set
}

// Check for X-API-Key header
auto api_key = req.get_header_value("X-API-Key");
if (!api_key.empty() && api_key == global_state.auth_token) {
return;
}

// Check for Basic Auth
auto auth = req.get_header_value("Authorization");
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
std::string decoded_auth = base64_decode(auth.substr(6));
if (decoded_auth == global_state.auth_token) {
return;
}
}

throw HttpHandlerException(401, "Unauthorized");
}

} // namespace duckdb_httpserver
69 changes: 69 additions & 0 deletions src/http_handler/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "duckdb.hpp"
#include "yyjson.hpp"
#include <string>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;
using namespace duckdb_yyjson;

namespace duckdb_httpserver {

static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
if (!yyjson_is_obj(parameterVal)) {
throw HttpHandlerException(400, "The parameter `" + key + "` must be an object");
}

auto typeVal = yyjson_obj_get(parameterVal, "type");
if (!typeVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
}
if (!yyjson_is_str(typeVal)) {
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
}
auto type = std::string(yyjson_get_str(typeVal));

auto valueVal = yyjson_obj_get(parameterVal, "value");
if (!valueVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
}

if (type == "TEXT") {
if (!yyjson_is_str(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
}

return BoundParameterData(Value(yyjson_get_str(valueVal)));
}
else if (type == "BOOLEAN") {
if (!yyjson_is_bool(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
}

return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
}

throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
}

case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_val* parametersVal) {
if (!parametersVal || !yyjson_is_obj(parametersVal)) {
throw HttpHandlerException(400, "The `parameters` field must be an object");
}

case_insensitive_map_t<BoundParameterData> named_values;

size_t idx, max;
yyjson_val *parameterKeyVal, *parameterVal;
yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) {
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));

named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
}

return named_values;
}

} // namespace duckdb_httpserver
219 changes: 219 additions & 0 deletions src/http_handler/handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#include "httpserver_extension/http_handler/authentication.hpp"
#include "httpserver_extension/http_handler/bindings.hpp"
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/http_handler/handler.hpp"
#include "httpserver_extension/http_handler/playground.hpp"
#include "httpserver_extension/http_handler/response_serializer.hpp"
#include "httpserver_extension/state.hpp"
#include "duckdb.hpp"
#include "yyjson.hpp"

#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;
using namespace duckdb_yyjson;

namespace duckdb_httpserver {

// Handle both GET and POST requests
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
try {
// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Max-Age", "86400");

// Handle preflight OPTIONS request
if (req.method == "OPTIONS") {
res.status = 204; // No content
return;
}

CheckAuthentication(req);

auto queryApiParameters = ExtractQueryApiParameters(req);

if (!queryApiParameters.sqlQueryOpt.has_value()) {
res.status = 200;
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
return;
}

if (!global_state.db_instance) {
throw IOException("Database instance not initialized");
}

auto start = std::chrono::system_clock::now();
auto result = ExecuteQuery(req, queryApiParameters);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

QueryExecStats stats {
.elapsed_sec = static_cast<float>(elapsed.count()) / 1000,
.read_bytes = 0,
.read_rows = 0
};

// Format output
if (queryApiParameters.outputFormat == OutputFormat::Ndjson) {
std::string json_output = ConvertResultToNDJSON(*result);
res.set_content(json_output, "application/x-ndjson");
}
else {
auto json_output = ConvertResultToJSON(*result, stats);
res.set_content(json_output, "application/json");
}
}
catch (const HttpHandlerException& ex) {
res.status = ex.status;
res.set_content(ex.message, "text/plain");
}
catch (const std::exception& ex) {
res.status = 500;
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
res.set_content(error_message, "text/plain");
}
}

// Execute query (optionally using a prepared statement)
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
const duckdb_httplib_openssl::Request& req,
const QueryApiParameters& queryApiParameters
) {
Connection con(*global_state.db_instance);
std::unique_ptr<MaterializedQueryResult> result;
auto query = queryApiParameters.sqlQueryOpt.value();

auto use_prepared_stmt =
queryApiParameters.sqlParametersOpt.has_value() &&
queryApiParameters.sqlParametersOpt.value().empty() == false;

if (use_prepared_stmt) {
auto prepared_stmt = con.Prepare(query);
if (prepared_stmt->HasError()) {
throw HttpHandlerException(500, prepared_stmt->GetError());
}

auto named_values = queryApiParameters.sqlParametersOpt.value();

auto prepared_stmt_result = prepared_stmt->Execute(named_values);
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
} else {
result = con.Query(query);
}

if (result->HasError()) {
throw HttpHandlerException(500, result->GetError());
}

return result;
}

QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) {
if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") {
return ExtractQueryApiParametersComplex(req);
}
else {
return QueryApiParameters {
.sqlQueryOpt = ExtractSqlQuerySimple(req),
.outputFormat = ExtractOutputFormatSimple(req),
};
}
}

std::optional<std::string> ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) {
// Check if the query is in the URL parameters
if (req.has_param("query")) {
return req.get_param_value("query");
}
else if (req.has_param("q")) {
return req.get_param_value("q");
}

// If not in URL, and it's a POST request, check the body
else if (req.method == "POST" && !req.body.empty()) {
return req.body;
}

return std::nullopt;
}

OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) {
// Check for format in URL parameter or header
if (req.has_param("default_format")) {
return ParseOutputFormat(req.get_param_value("default_format"));
}
else if (req.has_header("X-ClickHouse-Format")) {
return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format"));
}
else if (req.has_header("format")) {
return ParseOutputFormat(req.get_header_value("format"));
}
else {
return OutputFormat::Ndjson;
}
}

OutputFormat ParseOutputFormat(const std::string& formatStr) {
if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") {
return OutputFormat::Ndjson;
}
else if (formatStr == "JSONCompact") {
return OutputFormat::Json;
}
else {
throw HttpHandlerException(400, "Unknown format");
}
}

QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) {
yyjson_doc *bodyDoc = nullptr;

try {
auto bodyJson = req.body;
auto bodyJsonCStr = bodyJson.c_str();
bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0);

return ExtractQueryApiParametersComplexImpl(bodyDoc);
}
catch (const std::exception& exception) {
yyjson_doc_free(bodyDoc);
throw;
}
}

QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) {
if (!bodyDoc) {
throw HttpHandlerException(400, "Unable to parse the request body");
}

auto bodyRoot = yyjson_doc_get_root(bodyDoc);
if (!yyjson_is_obj(bodyRoot)) {
throw HttpHandlerException(400, "The request body must be an object");
}

auto queryVal = yyjson_obj_get(bodyRoot, "query");
if (!queryVal || !yyjson_is_str(queryVal)) {
throw HttpHandlerException(400, "The `query` field must be a string");
}

auto formatVal = yyjson_obj_get(bodyRoot, "format");
if (!formatVal || !yyjson_is_str(formatVal)) {
throw HttpHandlerException(400, "The `format` field must be a string");
}

return QueryApiParameters {
.sqlQueryOpt = std::string(yyjson_get_str(queryVal)),
.sqlParametersOpt = ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")),
.outputFormat = ParseOutputFormat(std::string(yyjson_get_str(formatVal))),
};
}

} // namespace duckdb_httpserver
Loading

0 comments on commit c3d2472

Please sign in to comment.