Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create RecordBatch in BradStatement from query result and schema exposed from underlying connections #502

Merged
merged 6 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 104 additions & 20 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>
#include <stdexcept>

#include <arrow/api.h>
#include <arrow/array/builder_binary.h>
#include "brad_sql_info.h"
#include "brad_statement.h"
Expand Down Expand Up @@ -50,23 +51,108 @@ arrow::Result<std::pair<std::string, std::string>> DecodeTransactionQuery(
return std::make_pair(std::move(autoincrement_id), std::move(transaction_id));
}

std::vector<std::vector<std::any>> TransformQueryResult(
std::vector<py::tuple> query_result) {
std::vector<std::vector<std::any>> transformed_query_result;
for (const auto &row : query_result) {
std::vector<std::any> transformed_row{};
for (const auto &field : row) {
if (py::isinstance<py::int_>(field)) {
transformed_row.push_back(std::make_any<int>(py::cast<int>(field)));
} else if (py::isinstance<py::float_>(field)) {
transformed_row.push_back(std::make_any<float>(py::cast<float>(field)));
} else {
transformed_row.push_back(std::make_any<std::string>(py::cast<std::string>(field)));
namespace {

/// Builds one Arrow column from field `field_ix` of every row.
///
/// Each value is cast to `std::optional<CType>`; an empty optional (Python
/// `None`) is appended as an Arrow null. `py::cast` may throw (pybind11
/// `cast_error`) when a value cannot be converted to `CType`; that exception is
/// not translated into an `arrow::Status` here.
template <typename CType, typename BuilderType>
arrow::Result<std::shared_ptr<arrow::Array>> BuildColumn(
    const std::vector<py::tuple> &query_result, const int field_ix,
    BuilderType *builder) {
  for (const auto &row : query_result) {
    const auto value = py::cast<std::optional<CType>>(row[field_ix]);
    if (value) {
      ARROW_RETURN_NOT_OK(builder->Append(*value));
    } else {
      ARROW_RETURN_NOT_OK(builder->AppendNull());
    }
  }
  return builder->Finish();
}

/// Builds a decimal128 column from field `field_ix` of every row, where each
/// non-null value is a decimal string.
///
/// The string is parsed together with its own scale and then rescaled to the
/// scale of `decimal_type`, so "12.5" in a scale-2 column is stored as 12.50.
/// A malformed string, or one that cannot be rescaled, yields an error Status
/// instead of aborting the process.
arrow::Result<std::shared_ptr<arrow::Array>> BuildDecimalColumn(
    const std::vector<py::tuple> &query_result, const int field_ix,
    const std::shared_ptr<arrow::DataType> &decimal_type) {
  const int32_t target_scale =
      static_cast<const arrow::Decimal128Type &>(*decimal_type).scale();

  arrow::Decimal128Builder builder(decimal_type);
  for (const auto &row : query_result) {
    const auto text = py::cast<std::optional<std::string>>(row[field_ix]);
    if (!text) {
      ARROW_RETURN_NOT_OK(builder.AppendNull());
      continue;
    }

    arrow::Decimal128 parsed;
    int32_t parsed_precision = 0;
    int32_t parsed_scale = 0;
    ARROW_RETURN_NOT_OK(arrow::Decimal128::FromString(
        *text, &parsed, &parsed_precision, &parsed_scale));
    if (parsed_scale != target_scale) {
      ARROW_ASSIGN_OR_RAISE(parsed, parsed.Rescale(parsed_scale, target_scale));
    }
    ARROW_RETURN_NOT_OK(builder.Append(parsed));
  }
  return builder.Finish();
}

}  // namespace

/// Converts a row-oriented Python query result into an Arrow RecordBatch.
///
/// @param query_result One tuple per row. Field `i` of each tuple is read for
///        column `i` of `schema`; `None` becomes an Arrow null.
/// @param schema Column types of the result. Supported types: int64, float32,
///        decimal128 (any precision/scale), utf8, date64 and null.
/// @return The record batch, or an error Status when a column has an
///         unsupported type or a decimal string cannot be parsed/rescaled.
arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
    const std::vector<py::tuple> &query_result,
    const std::shared_ptr<arrow::Schema> &schema) {
  const auto num_rows = static_cast<int64_t>(query_result.size());
  const int num_columns = schema->num_fields();

  std::vector<std::shared_ptr<arrow::Array>> columns;
  columns.reserve(static_cast<size_t>(num_columns));

  for (int field_ix = 0; field_ix < num_columns; ++field_ix) {
    const auto &field_type = schema->field(field_ix)->type();
    std::shared_ptr<arrow::Array> column;

    switch (field_type->id()) {
      case arrow::Type::INT64: {
        arrow::Int64Builder builder;
        ARROW_ASSIGN_OR_RAISE(
            column, BuildColumn<int64_t>(query_result, field_ix, &builder));
        break;
      }
      case arrow::Type::FLOAT: {
        arrow::FloatBuilder builder;
        ARROW_ASSIGN_OR_RAISE(
            column, BuildColumn<float>(query_result, field_ix, &builder));
        break;
      }
      case arrow::Type::DECIMAL128: {
        ARROW_ASSIGN_OR_RAISE(
            column, BuildDecimalColumn(query_result, field_ix, field_type));
        break;
      }
      case arrow::Type::STRING: {
        arrow::StringBuilder builder;
        ARROW_ASSIGN_OR_RAISE(
            column, BuildColumn<std::string>(query_result, field_ix, &builder));
        break;
      }
      case arrow::Type::DATE64: {
        arrow::Date64Builder builder;
        ARROW_ASSIGN_OR_RAISE(
            column, BuildColumn<int64_t>(query_result, field_ix, &builder));
        break;
      }
      case arrow::Type::NA: {
        // A null-typed column carries no values; only its length matters.
        arrow::NullBuilder builder;
        ARROW_RETURN_NOT_OK(builder.AppendNulls(num_rows));
        ARROW_ASSIGN_OR_RAISE(column, builder.Finish());
        break;
      }
      default:
        // Skipping the column would leave the batch with fewer columns than
        // the schema declares, so fail loudly instead.
        return arrow::Status::NotImplemented(
            "Unsupported column type in query result: ",
            field_type->ToString());
    }

    columns.push_back(std::move(column));
  }

  return arrow::RecordBatch::Make(schema, num_rows, std::move(columns));
}

BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {}
Expand Down Expand Up @@ -125,25 +211,23 @@ arrow::Result<std::unique_ptr<FlightInfo>>
EncodeTransactionQuery(query_ticket));

std::shared_ptr<arrow::Schema> result_schema;
std::vector<std::vector<std::any>> transformed_query_result;
std::shared_ptr<arrow::RecordBatch> result_record_batch;

{
py::gil_scoped_acquire guard;
auto result = handle_query_(query);
result_schema = ArrowSchemaFromBradSchema(result.second);
transformed_query_result = TransformQueryResult(result.first);
result_record_batch = ResultToRecordBatch(std::move(result.first), result_schema).ValueOrDie();
sopzha marked this conversation as resolved.
Show resolved Hide resolved
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result));
ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema));
sopzha marked this conversation as resolved.
Show resolved Hide resolved
query_data_.insert(query_ticket, statement);

ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema());

std::vector<FlightEndpoint> endpoints{
FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}};

const bool ordered = false;
ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema,
ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*result_schema,
descriptor,
endpoints,
-1,
Expand Down
1 change: 1 addition & 0 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "libcuckoo/cuckoohash_map.hh"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace brad {

Expand Down
89 changes: 10 additions & 79 deletions cpp/server/brad_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,96 +25,27 @@ arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
}

/// Creates a statement from an already-built record batch and its schema.
///
/// Both arguments are taken by value and moved into the new statement, so no
/// extra shared_ptr reference-count copies are made.
arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
    std::shared_ptr<arrow::RecordBatch> result_record_batch,
    std::shared_ptr<arrow::Schema> schema) {
  return std::make_shared<BradStatement>(std::move(result_record_batch),
                                         std::move(schema));
}

// Stores the row-oriented query result by moving it into query_result_.
BradStatement::BradStatement(std::vector<std::vector<std::any>> rows)
    : query_result_(std::move(rows)) {}
BradStatement::BradStatement(std::shared_ptr<arrow::RecordBatch> result_record_batch,
std::shared_ptr<arrow::Schema> schema) :
result_record_batch_(std::move(result_record_batch)),
schema_(std::move(schema)) {}

// Nothing to release by hand; members clean up after themselves.
BradStatement::~BradStatement() = default;

/// Returns the statement's schema.
///
/// If schema_ is already set it is returned as-is. Otherwise a schema is
/// inferred from the first row of query_result_ (an empty schema when there
/// are no rows), cached in schema_, and returned. Values holding `int` map to
/// int8 fields, `float` to float32 fields, and everything else to utf8 fields.
arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() const {
  if (schema_) {
    return schema_;
  }

  std::vector<std::shared_ptr<arrow::Field>> fields;

  if (!query_result_.empty()) {
    const std::vector<std::any> &first_row = query_result_.front();
    fields.reserve(first_row.size());

    int counter = 0;
    for (const auto &value : first_row) {
      const std::string ordinal = std::to_string(++counter);
      // Compare type_info objects directly: type_info::name() returns an
      // implementation-defined string, so matching it against "i"/"f" only
      // works on some compilers.
      if (value.type() == typeid(int)) {
        fields.push_back(arrow::field("INT FIELD " + ordinal, arrow::int8()));
      } else if (value.type() == typeid(float)) {
        fields.push_back(arrow::field("FLOAT FIELD " + ordinal, arrow::float32()));
      } else {
        fields.push_back(arrow::field("STRING FIELD " + ordinal, arrow::utf8()));
      }
    }
  }

  schema_ = arrow::schema(fields);
  return schema_;
}

/// Returns the record batch held in result_record_batch_.
///
/// No conversion happens here; the batch is handed back exactly as stored.
arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult() {
  return result_record_batch_;
}

// Returns the value of the stmt_ member.
std::string* BradStatement::GetBradStmt() const {
  return stmt_;
}
Expand Down
8 changes: 5 additions & 3 deletions cpp/server/brad_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ class BradStatement {
const std::string& sql);

static arrow::Result<std::shared_ptr<BradStatement>> Create(
const std::vector<std::vector<std::any>>);
std::shared_ptr<arrow::RecordBatch> result_record_batch,
std::shared_ptr<arrow::Schema> schema);

BradStatement(std::vector<std::vector<std::any>>);
BradStatement(std::shared_ptr<arrow::RecordBatch>,
std::shared_ptr<arrow::Schema>);

~BradStatement();

Expand All @@ -41,7 +43,7 @@ class BradStatement {
std::string* GetBradStmt() const;

private:
std::vector<std::vector<std::any>> query_result_;
std::shared_ptr<arrow::RecordBatch> result_record_batch_;

mutable std::shared_ptr<arrow::Schema> schema_;

Expand Down