Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support prepared statements and parameters #21

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
# Embed ./src/assets/index.html as a C++ header
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/playground.hpp playgroundContent
COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/embed.cmake ${PROJECT_SOURCE_DIR}/src/assets/index.html ${CMAKE_CURRENT_BINARY_DIR}/httpserver_extension/http_handler/playground.hpp playgroundContent
DEPENDS ${PROJECT_SOURCE_DIR}/src/assets/index.html
)

set(EXTENSION_SOURCES
src/httpserver_extension.cpp
src/http_handler/authentication.cpp
src/http_handler/bindings.cpp
src/http_handler/handler.cpp
src/http_handler/response_serializer.cpp
${CMAKE_CURRENT_BINARY_DIR}/playground.hpp
)

Expand All @@ -37,7 +41,9 @@

include_directories(${OPENSSL_INCLUDE_DIR})
target_link_libraries(${LOADABLE_EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
set_property(TARGET ${LOADABLE_EXTENSION_NAME} PROPERTY CXX_STANDARD 17)
target_link_libraries(${EXTENSION_NAME} duckdb_mbedtls ${OPENSSL_LIBRARIES})
set_property(TARGET ${EXTENSION_NAME} PROPERTY CXX_STANDARD 17)

if(MINGW)
set(WIN_LIBS crypt32 ws2_32 wsock32)
Expand Down
55 changes: 55 additions & 0 deletions src/http_handler/authentication.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/state.hpp"
#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

namespace duckdb_httpserver {

// Base64 decoding function
static std::string base64_decode(const std::string &in) {
std::string out;
std::vector<int> T(256, -1);
for (int i = 0; i < 64; i++)
T["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[i]] = i;

int val = 0, valb = -8;
for (unsigned char c : in) {
if (T[c] == -1) break;
val = (val << 6) + T[c];
valb += 6;
if (valb >= 0) {
out.push_back(char((val >> valb) & 0xFF));
valb -= 8;
}
}
return out;
}

// Check authentication
void CheckAuthentication(const duckdb_httplib_openssl::Request& req) {
if (global_state.auth_token.empty()) {
return; // No authentication required if no token is set
}

// Check for X-API-Key header
auto api_key = req.get_header_value("X-API-Key");
if (!api_key.empty() && api_key == global_state.auth_token) {
return;
}

// Check for Basic Auth
auto auth = req.get_header_value("Authorization");
if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) {
std::string decoded_auth = base64_decode(auth.substr(6));
if (decoded_auth == global_state.auth_token) {
return;
}
}

throw HttpHandlerException(401, "Unauthorized");
}

} // namespace duckdb_httpserver
69 changes: 69 additions & 0 deletions src/http_handler/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "httpserver_extension/http_handler/common.hpp"
#include "duckdb.hpp"
#include "yyjson.hpp"
#include <string>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;
using namespace duckdb_yyjson;

namespace duckdb_httpserver {

static BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) {
if (!yyjson_is_obj(parameterVal)) {
throw HttpHandlerException(400, "The parameter `" + key + "` must be an object");
}

auto typeVal = yyjson_obj_get(parameterVal, "type");
if (!typeVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `type` field");
}
if (!yyjson_is_str(typeVal)) {
throw HttpHandlerException(400, "The field `type` for the parameter `" + key + "` must be a string");
}
auto type = std::string(yyjson_get_str(typeVal));

auto valueVal = yyjson_obj_get(parameterVal, "value");
if (!valueVal) {
throw HttpHandlerException(400, "The parameter `" + key + "` does not have a `value` field");
}

if (type == "TEXT") {
if (!yyjson_is_str(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a string");
}

return BoundParameterData(Value(yyjson_get_str(valueVal)));
}
else if (type == "BOOLEAN") {
if (!yyjson_is_bool(valueVal)) {
throw HttpHandlerException(400, "The field `value` for the parameter `" + key + "` must be a boolean");
}

return BoundParameterData(Value(bool(yyjson_get_bool(valueVal))));
}

throw HttpHandlerException(400, "Unsupported type " + type + " the parameter `" + key + "`");
}

case_insensitive_map_t<BoundParameterData> ExtractQueryParameters(yyjson_val* parametersVal) {
if (!parametersVal || !yyjson_is_obj(parametersVal)) {
throw HttpHandlerException(400, "The `parameters` field must be an object");
}

case_insensitive_map_t<BoundParameterData> named_values;

size_t idx, max;
yyjson_val *parameterKeyVal, *parameterVal;
yyjson_obj_foreach(parametersVal, idx, max, parameterKeyVal, parameterVal) {
auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal));

named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal);
}

return named_values;
}

} // namespace duckdb_httpserver
220 changes: 220 additions & 0 deletions src/http_handler/handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#include "httpserver_extension/http_handler/authentication.hpp"
#include "httpserver_extension/http_handler/bindings.hpp"
#include "httpserver_extension/http_handler/common.hpp"
#include "httpserver_extension/http_handler/handler.hpp"
#include "httpserver_extension/http_handler/playground.hpp"
#include "httpserver_extension/http_handler/response_serializer.hpp"
#include "httpserver_extension/state.hpp"
#include "duckdb.hpp"
#include "yyjson.hpp"

#include <string>
#include <vector>

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

using namespace duckdb;
using namespace duckdb_yyjson;

namespace duckdb_httpserver {

// Handle both GET and POST requests
void HttpHandler(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
try {
// CORS allow
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT");
res.set_header("Access-Control-Allow-Headers", "*");
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Max-Age", "86400");

// Handle preflight OPTIONS request
if (req.method == "OPTIONS") {
res.status = 204; // No content
return;
}

CheckAuthentication(req);

auto queryApiParameters = ExtractQueryApiParameters(req);

if (!queryApiParameters.sqlQueryOpt.has_value()) {
res.status = 200;
res.set_content(reinterpret_cast<char const*>(playgroundContent), sizeof(playgroundContent), "text/html");
return;
}

if (!global_state.db_instance) {
throw IOException("Database instance not initialized");
}

auto start = std::chrono::system_clock::now();
auto result = ExecuteQuery(req, queryApiParameters);
auto end = std::chrono::system_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

QueryExecStats stats {
static_cast<float>(elapsed.count()) / 1000,
0,
0
};

// Format output
if (queryApiParameters.outputFormat == OutputFormat::Ndjson) {
std::string json_output = ConvertResultToNDJSON(*result);
res.set_content(json_output, "application/x-ndjson");
}
else {
auto json_output = ConvertResultToJSON(*result, stats);
res.set_content(json_output, "application/json");
}
}
catch (const HttpHandlerException& ex) {
res.status = ex.status;
res.set_content(ex.message, "text/plain");
}
catch (const std::exception& ex) {
res.status = 500;
std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what());
res.set_content(error_message, "text/plain");
}
}

// Execute query (optionally using a prepared statement)
std::unique_ptr<MaterializedQueryResult> ExecuteQuery(
const duckdb_httplib_openssl::Request& req,
const QueryApiParameters& queryApiParameters
) {
Connection con(*global_state.db_instance);
std::unique_ptr<MaterializedQueryResult> result;
auto query = queryApiParameters.sqlQueryOpt.value();

auto use_prepared_stmt =
queryApiParameters.sqlParametersOpt.has_value() &&
queryApiParameters.sqlParametersOpt.value().empty() == false;

if (use_prepared_stmt) {
auto prepared_stmt = con.Prepare(query);
if (prepared_stmt->HasError()) {
throw HttpHandlerException(500, prepared_stmt->GetError());
}

auto named_values = queryApiParameters.sqlParametersOpt.value();

auto prepared_stmt_result = prepared_stmt->Execute(named_values);
D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT);
result = unique_ptr_cast<QueryResult, StreamQueryResult>(std::move(prepared_stmt_result))->Materialize();
} else {
result = con.Query(query);
}

if (result->HasError()) {
throw HttpHandlerException(500, result->GetError());
}

return result;
}

QueryApiParameters ExtractQueryApiParameters(const duckdb_httplib_openssl::Request& req) {
if (req.method == "POST" && req.has_header("Content-Type") && req.get_header_value("Content-Type") == "application/json") {
return ExtractQueryApiParametersComplex(req);
}
else {
return QueryApiParameters {
ExtractSqlQuerySimple(req),
std::nullopt,
ExtractOutputFormatSimple(req),
};
}
}

std::optional<std::string> ExtractSqlQuerySimple(const duckdb_httplib_openssl::Request& req) {
// Check if the query is in the URL parameters
if (req.has_param("query")) {
return req.get_param_value("query");
}
else if (req.has_param("q")) {
return req.get_param_value("q");
}

// If not in URL, and it's a POST request, check the body
else if (req.method == "POST" && !req.body.empty()) {
return req.body;
}

return std::nullopt;
}

OutputFormat ExtractOutputFormatSimple(const duckdb_httplib_openssl::Request& req) {
// Check for format in URL parameter or header
if (req.has_param("default_format")) {
return ParseOutputFormat(req.get_param_value("default_format"));
}
else if (req.has_header("X-ClickHouse-Format")) {
return ParseOutputFormat(req.get_header_value("X-ClickHouse-Format"));
}
else if (req.has_header("format")) {
return ParseOutputFormat(req.get_header_value("format"));
}
else {
return OutputFormat::Ndjson;
}
}

OutputFormat ParseOutputFormat(const std::string& formatStr) {
if (formatStr == "JSONEachRow" || formatStr == "ndjson" || formatStr == "jsonl") {
return OutputFormat::Ndjson;
}
else if (formatStr == "JSONCompact") {
return OutputFormat::Json;
}
else {
throw HttpHandlerException(400, "Unknown format");
}
}

QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl::Request& req) {
yyjson_doc *bodyDoc = nullptr;

try {
auto bodyJson = req.body;
auto bodyJsonCStr = bodyJson.c_str();
bodyDoc = yyjson_read(bodyJsonCStr, strlen(bodyJsonCStr), 0);

return ExtractQueryApiParametersComplexImpl(bodyDoc);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bodyDoc is never freed

Copy link
Collaborator

@lmangani lmangani Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be wrong but possibly json_line as well?

diff --git a/src/http_handler/handler.cpp b/src/http_handler/handler.cpp
index 88d25a3..e5f7a8b 100644
--- a/src/http_handler/handler.cpp
+++ b/src/http_handler/handler.cpp
@@ -179,6 +179,7 @@ QueryApiParameters ExtractQueryApiParametersComplex(const duckdb_httplib_openssl
         return ExtractQueryApiParametersComplexImpl(bodyDoc);
     }
     catch (const std::exception& exception) {
+        yyjson_doc_free(bodyDoc);
         throw;
     }
 }
@@ -217,6 +218,7 @@ QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) {
 
 static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
     std::string ndjson_output;
+    yyjson_mut_doc *doc = nullptr;

     for (idx_t row = 0; row < result.RowCount(); ++row) {
         // Create a new JSON document for each row
@@ -226,6 +228,8 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
         yyjson_mut_doc_set_root(doc, root);
 
         for (idx_t col = 0; col < result.ColumnCount(); ++col) {
+            // Ensure doc is freed on each iteration
+            yyjson_mut_doc_free(doc);
             Value value = result.GetValue(col, row);
             const char* column_name = result.ColumnName(col).c_str();
 
@@ -246,6 +250,7 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) {
 
         ndjson_output += json_line;
         ndjson_output += "\n";
+        free(json_line);
     }
     return ndjson_output;
 }

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vhiairrassary shall I apply the patches or you prefer doing so?

}
catch (const std::exception& exception) {
yyjson_doc_free(bodyDoc);
throw;
}
}

QueryApiParameters ExtractQueryApiParametersComplexImpl(yyjson_doc* bodyDoc) {
if (!bodyDoc) {
throw HttpHandlerException(400, "Unable to parse the request body");
}

auto bodyRoot = yyjson_doc_get_root(bodyDoc);
if (!yyjson_is_obj(bodyRoot)) {
throw HttpHandlerException(400, "The request body must be an object");
}

auto queryVal = yyjson_obj_get(bodyRoot, "query");
if (!queryVal || !yyjson_is_str(queryVal)) {
throw HttpHandlerException(400, "The `query` field must be a string");
}

auto formatVal = yyjson_obj_get(bodyRoot, "format");
if (!formatVal || !yyjson_is_str(formatVal)) {
throw HttpHandlerException(400, "The `format` field must be a string");
}

return QueryApiParameters {
std::string(yyjson_get_str(queryVal)),
ExtractQueryParameters(yyjson_obj_get(bodyRoot, "parameters")),
ParseOutputFormat(std::string(yyjson_get_str(formatVal))),
};
}

} // namespace duckdb_httpserver
Loading
Loading