Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create RecordBatch in BradStatement from query result and schema exposed from underlying connections #502

Merged
merged 6 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 66 additions & 20 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>
#include <stdexcept>

#include <arrow/api.h>
#include <arrow/array/builder_binary.h>
#include "brad_sql_info.h"
#include "brad_statement.h"
Expand Down Expand Up @@ -50,23 +51,70 @@ arrow::Result<std::pair<std::string, std::string>> DecodeTransactionQuery(
return std::make_pair(std::move(autoincrement_id), std::move(transaction_id));
}

std::vector<std::vector<std::any>> TransformQueryResult(
std::vector<py::tuple> query_result) {
std::vector<std::vector<std::any>> transformed_query_result;
for (const auto &row : query_result) {
std::vector<std::any> transformed_row{};
for (const auto &field : row) {
if (py::isinstance<py::int_>(field)) {
transformed_row.push_back(std::make_any<int>(py::cast<int>(field)));
} else if (py::isinstance<py::float_>(field)) {
transformed_row.push_back(std::make_any<float>(py::cast<float>(field)));
} else {
transformed_row.push_back(std::make_any<std::string>(py::cast<std::string>(field)));
arrow::Result<std::shared_ptr<arrow::RecordBatch>>
ResultToRecordBatch(std::vector<py::tuple> query_result, std::shared_ptr<arrow::Schema> schema) {
sopzha marked this conversation as resolved.
Show resolved Hide resolved
const int num_rows = query_result.size();

const int num_columns = schema->num_fields();
sopzha marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(num_columns);

for (int field_ix = 0; field_ix < num_columns; ++field_ix) {
const auto &field_type = schema->field(field_ix)->type();
if (field_type->Equals(arrow::int64())) {
arrow::Int64Builder int64builder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const int64_t val = py::cast<int64_t>(query_result[row_ix][field_ix]);
// TODO: How do we check for null values in ints or floats?
sopzha marked this conversation as resolved.
Show resolved Hide resolved
ARROW_RETURN_NOT_OK(int64builder.Append(val));
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::float32()) ||
// TODO: Should not hardcode precision and scale values
field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) {
sopzha marked this conversation as resolved.
Show resolved Hide resolved
arrow::FloatBuilder floatbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const float val = py::cast<float>(query_result[row_ix][field_ix]);
ARROW_RETURN_NOT_OK(floatbuilder.Append(val));
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::utf8())) {
arrow::StringBuilder stringbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const std::string str = py::cast<std::string>(query_result[row_ix][field_ix]);
sopzha marked this conversation as resolved.
Show resolved Hide resolved
if (str.empty()) {
ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size()));
} else {
ARROW_RETURN_NOT_OK(stringbuilder.AppendNull());
}
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::date64())) {
arrow::Date64Builder datebuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const int64_t val = py::cast<int64_t>(query_result[row_ix][field_ix]);
ARROW_RETURN_NOT_OK(datebuilder.Append(val));
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, datebuilder.Finish());
columns.push_back(values);

}
transformed_query_result.push_back(transformed_row);
}
return transformed_query_result;

std::shared_ptr<arrow::RecordBatch> result_record_batch =
arrow::RecordBatch::Make(schema, num_rows, columns);

return result_record_batch;
}

BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {}
Expand Down Expand Up @@ -125,25 +173,23 @@ arrow::Result<std::unique_ptr<FlightInfo>>
EncodeTransactionQuery(query_ticket));

std::shared_ptr<arrow::Schema> result_schema;
std::vector<std::vector<std::any>> transformed_query_result;
std::shared_ptr<arrow::RecordBatch> result_record_batch;

{
py::gil_scoped_acquire guard;
auto result = handle_query_(query);
result_schema = ArrowSchemaFromBradSchema(result.second);
transformed_query_result = TransformQueryResult(result.first);
result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie();
sopzha marked this conversation as resolved.
Show resolved Hide resolved
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result));
ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema));
sopzha marked this conversation as resolved.
Show resolved Hide resolved
query_data_.insert(query_ticket, statement);

ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema());

std::vector<FlightEndpoint> endpoints{
FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}};

const bool ordered = false;
ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema,
ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*result_schema,
descriptor,
endpoints,
-1,
Expand Down
89 changes: 10 additions & 79 deletions cpp/server/brad_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,96 +25,27 @@ arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
}

// Creates a BradStatement that wraps an already-materialized query result.
//
// @param result_record_batch The RecordBatch holding the query result.
// @param schema              The schema associated with that result.
// @return A shared BradStatement owning both handles.
//
// Both arguments are taken by value and moved into the new object so that no
// extra shared_ptr reference-count round trip is made.
arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
    std::shared_ptr<arrow::RecordBatch> result_record_batch,
    std::shared_ptr<arrow::Schema> schema) {
  return std::make_shared<BradStatement>(std::move(result_record_batch),
                                         std::move(schema));
}

// Constructs a statement from a finished query result.
//
// @param result_record_batch The RecordBatch later handed out by FetchResult().
// @param schema              The schema later handed out by GetSchema().
BradStatement::BradStatement(std::shared_ptr<arrow::RecordBatch> result_record_batch,
                             std::shared_ptr<arrow::Schema> schema)
    : result_record_batch_(std::move(result_record_batch)),
      schema_(std::move(schema)) {}

// Destructor. The body is empty: no explicit cleanup is performed here.
BradStatement::~BradStatement() {
}

// Returns the schema this statement was constructed with.
//
// The schema is supplied through the constructor, so it is returned as
// stored; nothing is inferred from the result rows.
arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() const {
  return schema_;
}

// Returns the RecordBatch this statement was constructed with.
//
// The batch is supplied through the constructor, so no per-call conversion
// of the query result happens here.
arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult() {
  return result_record_batch_;
}

// Returns the stored `stmt_` pointer as-is (no copy of the pointee is made).
std::string* BradStatement::GetBradStmt() const { return stmt_; }
Expand Down
8 changes: 5 additions & 3 deletions cpp/server/brad_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ class BradStatement {
const std::string& sql);

static arrow::Result<std::shared_ptr<BradStatement>> Create(
const std::vector<std::vector<std::any>>);
std::shared_ptr<arrow::RecordBatch> result_record_batch,
std::shared_ptr<arrow::Schema> schema);

BradStatement(std::vector<std::vector<std::any>>);
BradStatement(std::shared_ptr<arrow::RecordBatch>,
std::shared_ptr<arrow::Schema>);

~BradStatement();

Expand All @@ -41,7 +43,7 @@ class BradStatement {
std::string* GetBradStmt() const;

private:
std::vector<std::vector<std::any>> query_result_;
std::shared_ptr<arrow::RecordBatch> result_record_batch_;

mutable std::shared_ptr<arrow::Schema> schema_;

Expand Down