Skip to content

Commit

Permalink
better support for union, list and map children
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoa-xu committed Jul 8, 2023
1 parent d9c746e commit 4ab482d
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 60 deletions.
166 changes: 124 additions & 42 deletions c_src/adbc_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ template <typename M> static ERL_NIF_TERM strings_from_buffer(
return enif_make_list_from_array(env, values.data(), (unsigned)values.size());
}

static ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level);
static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector<ERL_NIF_TERM> &out_terms, ERL_NIF_TERM &error);

static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector<ERL_NIF_TERM> &children, ERL_NIF_TERM &error) {
ERL_NIF_TERM children_term{};
Expand All @@ -118,10 +118,18 @@ static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema *
for (int64_t child_i = 0; child_i < schema->n_children; child_i++) {
struct ArrowSchema * child_schema = schema->children[child_i];
struct ArrowArray * child_values = values->children[child_i];
children[child_i] = arrow_array_to_nif_term(env, child_schema, child_values, level + 1);
std::vector<ERL_NIF_TERM> childrens;
if (arrow_array_to_nif_term(env, child_schema, child_values, level + 1, childrens, error) == 1) {
return 1;
}

if (childrens.size() == 0) {
children[child_i] = childrens[0];
} else {
children[child_i] = enif_make_tuple2(env, childrens[0], childrens[1]);
}
}
}
// children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);

return 0;
}
Expand All @@ -142,7 +150,7 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch
if (values->children == nullptr) {
return erlang::nif::error(env, "invalid ArrowArray (map), values->children == nullptr");
}
if (schema->n_children != 1) {
if (values->n_children != 1) {
return erlang::nif::error(env, "invalid ArrowArray (map), values->n_children != 1");
}

Expand All @@ -152,7 +160,6 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch
return erlang::nif::error(env, "invalid ArrowSchema (map), its single child is not named entries");
}

printf("entries_values->n_children: %d\r\n", entries_values->n_children);
std::vector<ERL_NIF_TERM> nif_keys, nif_values;
bool failed = false;
for (int64_t child_i = 0; child_i < entries_values->n_children; child_i++) {
Expand All @@ -177,7 +184,7 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch

if (!failed) {
if (nif_keys.size() != nif_values.size()) {
return erlang::nif::error(env, "map contains duplicated keys");
return erlang::nif::error(env, "number of keys and values doesn't match");
}

if (!enif_make_map_from_arrays(env, nif_keys.data(), nif_values.data(), (unsigned)nif_keys.size(), &map_out)) {
Expand All @@ -190,18 +197,100 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch
}
}

ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) {
static ERL_NIF_TERM get_arrow_array_union_children(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) {
ERL_NIF_TERM error{}, map_out{};
if (schema->n_children > 0 && schema->children == nullptr) {
return erlang::nif::error(env, "invalid ArrowSchema (union), schema->children == nullptr while schema->n_children > 0 ");
}
if (values->n_children > 0 && values->children == nullptr) {
return erlang::nif::error(env, "invalid ArrowArray (union), values->children == nullptr while values->n_children > 0");
}

std::vector<ERL_NIF_TERM> nif_keys(values->n_children), nif_values2(values->n_children);
std::vector<ERL_NIF_TERM> field_values;
bool failed = false;
for (int64_t child_i = 0; child_i < values->n_children; child_i++) {
struct ArrowSchema * entry_schema = schema->children[child_i];
struct ArrowArray * entry_values = values->children[child_i];
nif_keys[child_i] = erlang::nif::make_binary(env, entry_schema->name);
if (arrow_array_to_nif_term(env, entry_schema, entry_values, level + 1, field_values, error) == 1) {
return error;
}

if (field_values.size() == 0) {
nif_values2[child_i] = field_values[0];
} else {
nif_values2[child_i] = field_values[1];
}
}

if (!failed) {
if (!enif_make_map_from_arrays(env, nif_keys.data(), nif_values2.data(), (unsigned)nif_keys.size(), &map_out)) {
return erlang::nif::error(env, "union contains duplicated fields");
} else {
return map_out;
}
} else {
return erlang::nif::error(env, "invalid union");
}
}

static ERL_NIF_TERM get_arrow_array_list_children(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) {
ERL_NIF_TERM error{};
if (schema->children == nullptr) {
return erlang::nif::error(env, "invalid ArrowSchema (list), schema->children == nullptr");
}
if (schema->n_children != 1) {
return erlang::nif::error(env, "invalid ArrowSchema (list), schema->n_children != 1");
}
if (values->children == nullptr) {
return erlang::nif::error(env, "invalid ArrowArray (list), values->children == nullptr");
}
if (values->n_children != 1) {
return erlang::nif::error(env, "invalid ArrowArray (list), values->n_children != 1");
}

struct ArrowSchema * items_schema = schema->children[0];
struct ArrowArray * items_values = values->children[0];
if (strncmp("item", items_schema->name, 4) != 0) {
return erlang::nif::error(env, "invalid ArrowSchema (list), its single child is not named item");
}

std::vector<ERL_NIF_TERM> children(items_values->n_children);
bool failed = false;
for (int64_t child_i = 0; child_i < items_values->n_children; child_i++) {
struct ArrowSchema * item_schema = items_schema->children[child_i];
struct ArrowArray * item_values = items_values->children[child_i];

std::vector<ERL_NIF_TERM> childrens;
if (arrow_array_to_nif_term(env, item_schema, item_values, level + 1, childrens, error) == 1) {
return error;
}

if (childrens.size() == 0) {
children[child_i] = childrens[0];
} else {
children[child_i] = enif_make_tuple2(env, childrens[0], childrens[1]);
}
}

return enif_make_list_from_array(env, children.data(), (unsigned)items_values->n_children);
}

int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector<ERL_NIF_TERM> &out_terms, ERL_NIF_TERM &error) {
if (schema == nullptr) {
return erlang::nif::error(env, "invalid ArrowSchema (nullptr) when invoking next");
error = erlang::nif::error(env, "invalid ArrowSchema (nullptr) when invoking next");
return 1;
}
if (values == nullptr) {
return erlang::nif::error(env, "invalid ArrowArray (nullptr) when invoking next");
error = erlang::nif::error(env, "invalid ArrowArray (nullptr) when invoking next");
return 1;
}

const char* format = schema->format ? schema->format : "";
const char* name = schema->name ? schema->name : "";

ERL_NIF_TERM current_term{}, children_term{}, error{};
ERL_NIF_TERM current_term{}, children_term{};
std::vector<ERL_NIF_TERM> children;

bool has_validity_bitmap = values->null_count != 0 && values->null_count != -1;
Expand Down Expand Up @@ -342,39 +431,21 @@ ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema
// only handle and return children if this is a struct
is_struct = true;
if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) {
return error;
return 1;
}
children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);
//get_arrow_array_children_as_list(env, schema, values, level, children, error);
} else if (strncmp("+m", format, 2) == 0) {
// if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) {
// return error;
// }
// children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);
children_term = get_arrow_array_map_children(env, schema, values, level);
} else if (strncmp("+l", format, 2) == 0 || strncmp("+L", format, 2) == 0) {
if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) {
return error;
}
children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);
// children_term = get_arrow_array_children_as_list(env, schema, values, level);
children_term = get_arrow_array_list_children(env, schema, values, level);
} else {
format_processed = false;
}
} else if (format_len >= 4) {
if (strncmp("+w:", format, 3) == 0) {
if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) {
return error;
}
children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);
// children_term = get_arrow_array_children_as_list(env, schema, values, level);
children_term = get_arrow_array_list_children(env, schema, values, level);
} else if (format_len > 4 && (strncmp("+ud:", format, 4) == 0 || strncmp("+us:", format, 4) == 0)) {
// todo: get as map
if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) {
return error;
}
children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children);
// children_term = get_arrow_array_children_as_list(env, schema, values, level);
children_term = get_arrow_array_union_children(env, schema, values, level);
} else {
format_processed = false;
}
Expand All @@ -395,21 +466,21 @@ ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema
// printf("buffers: %p\r\n", values->buffers);
}

out_terms.clear();

if (is_struct) {
return children_term;
out_terms.emplace_back(children_term);
} else {
if (schema->children) {
return enif_make_tuple2(env,
erlang::nif::make_binary(env, name),
children_term
);
out_terms.emplace_back(erlang::nif::make_binary(env, name));
out_terms.emplace_back(children_term);
} else {
return enif_make_tuple2(env,
erlang::nif::make_binary(env, name),
current_term
);
out_terms.emplace_back(erlang::nif::make_binary(env, name));
out_terms.emplace_back(current_term);
}
}

return 0;
}

static ERL_NIF_TERM adbc_database_new(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
Expand Down Expand Up @@ -804,8 +875,19 @@ static ERL_NIF_TERM adbc_arrow_array_stream_next(ErlNifEnv *env, int argc, const
}
}

std::vector<ERL_NIF_TERM> out_terms;

auto schema = (struct ArrowSchema*)res->private_data;
ret = arrow_array_to_nif_term(env, schema, &out, 0);
if (arrow_array_to_nif_term(env, schema, &out, 0, out_terms, error) == 1) {
if (out.release) out.release(&out);
return error;
}

if (out_terms.size() == 1) {
ret = out_terms[0];
} else {
ret = enif_make_tuple2(env, out_terms[0], out_terms[1]);
}

if (out.release) {
out.release(&out);
Expand Down
35 changes: 17 additions & 18 deletions test/adbc_connection_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,15 @@ defmodule Adbc.Connection.Test do
num_rows: nil,
data: %{
"info_name" => [0, 1, 100, 101, 102],
"info_value" => [
{"string_value",
# ["SQLite", "3.39.2", "ADBC SQLite Driver", "(unknown)", "0.2.0-SNAPSHOT"]},
["SQLite", _, "ADBC SQLite Driver", _, _]},
{"bool_value", []},
{"int64_value", []},
{"int32_bitmask", []},
{"string_list", [{"item", []}]},
{"int32_to_int32_list_map", %{}}
]
"info_value" => %{
"bool_value" => [],
"int32_bitmask" => [],
"int32_to_int32_list_map" => %{},
"int64_value" => [],
"string_list" => [],
# ["SQLite", "3.39.2", "ADBC SQLite Driver", "(unknown)", "0.2.0-SNAPSHOT"]},
"string_value" => ["SQLite", _, "ADBC SQLite Driver", _, _]
}
}
}} = Connection.get_info(conn)
end
Expand All @@ -71,14 +70,14 @@ defmodule Adbc.Connection.Test do
num_rows: nil,
data: %{
"info_name" => [0],
"info_value" => [
{"string_value", ["SQLite"]},
{"bool_value", []},
{"int64_value", []},
{"int32_bitmask", []},
{"string_list", [{"item", []}]},
{"int32_to_int32_list_map", %{}}
]
"info_value" => %{
"bool_value" => [],
"int32_bitmask" => [],
"int32_to_int32_list_map" => %{},
"int64_value" => [],
"string_list" => [],
"string_value" => ["SQLite"]
}
}
}} = Connection.get_info(conn, [0])
end
Expand Down

0 comments on commit 4ab482d

Please sign in to comment.