diff --git a/c_src/adbc_nif.cpp b/c_src/adbc_nif.cpp index 3f4541b..0329b1f 100644 --- a/c_src/adbc_nif.cpp +++ b/c_src/adbc_nif.cpp @@ -96,7 +96,7 @@ template static ERL_NIF_TERM strings_from_buffer( return enif_make_list_from_array(env, values.data(), (unsigned)values.size()); } -static ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level); +static int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error); static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &children, ERL_NIF_TERM &error) { ERL_NIF_TERM children_term{}; @@ -118,10 +118,18 @@ static int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * for (int64_t child_i = 0; child_i < schema->n_children; child_i++) { struct ArrowSchema * child_schema = schema->children[child_i]; struct ArrowArray * child_values = values->children[child_i]; - children[child_i] = arrow_array_to_nif_term(env, child_schema, child_values, level + 1); + std::vector childrens; + if (arrow_array_to_nif_term(env, child_schema, child_values, level + 1, childrens, error) == 1) { + return 1; + } + + if (childrens.size() == 0) { + children[child_i] = childrens[0]; + } else { + children[child_i] = enif_make_tuple2(env, childrens[0], childrens[1]); + } } } - // children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); return 0; } @@ -142,7 +150,7 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch if (values->children == nullptr) { return erlang::nif::error(env, "invalid ArrowArray (map), values->children == nullptr"); } - if (schema->n_children != 1) { + if (values->n_children != 1) { return erlang::nif::error(env, "invalid ArrowArray (map), values->n_children != 1"); } @@ -152,7 +160,6 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch return erlang::nif::error(env, "invalid ArrowSchema (map), its single child is not named entries"); } - printf("entries_values->n_children: %d\r\n", entries_values->n_children); std::vector nif_keys, nif_values; bool failed = false; for (int64_t child_i = 0; child_i < entries_values->n_children; child_i++) { @@ -177,7 +184,7 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch if (!failed) { if (nif_keys.size() != nif_values.size()) { - return erlang::nif::error(env, "map contains duplicated keys"); + return erlang::nif::error(env, "number of keys and values doesn't match"); } if (!enif_make_map_from_arrays(env, nif_keys.data(), nif_values.data(), (unsigned)nif_keys.size(), &map_out)) { @@ -190,18 +197,100 @@ static ERL_NIF_TERM get_arrow_array_map_children(ErlNifEnv *env, struct ArrowSch } } -ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) { +static ERL_NIF_TERM get_arrow_array_union_children(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) { + ERL_NIF_TERM error{}, map_out{}; + if (schema->n_children > 0 && schema->children == nullptr) { + return erlang::nif::error(env, "invalid ArrowSchema (union), schema->children == nullptr while schema->n_children > 0 "); + } + if (values->n_children > 0 && values->children == nullptr) { + return erlang::nif::error(env, "invalid ArrowArray (union), values->children == nullptr while values->n_children > 0"); + } + + std::vector nif_keys(values->n_children), nif_values2(values->n_children); + std::vector field_values; + bool failed = false; + for (int64_t child_i = 0; child_i < values->n_children; child_i++) { + struct ArrowSchema * entry_schema = schema->children[child_i]; + struct ArrowArray * entry_values = values->children[child_i]; + nif_keys[child_i] = erlang::nif::make_binary(env, entry_schema->name); + if (arrow_array_to_nif_term(env, entry_schema, entry_values, level + 1, field_values, error) == 1) { + return error; + } + + if (field_values.size() == 0) { + nif_values2[child_i] = field_values[0]; + } else { + nif_values2[child_i] = field_values[1]; + } + } + + if (!failed) { + if (!enif_make_map_from_arrays(env, nif_keys.data(), nif_values2.data(), (unsigned)nif_keys.size(), &map_out)) { + return erlang::nif::error(env, "union contains duplicated fields"); + } else { + return map_out; + } + } else { + return erlang::nif::error(env, "invalid union"); + } +} + +static ERL_NIF_TERM get_arrow_array_list_children(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level) { + ERL_NIF_TERM error{}; + if (schema->children == nullptr) { + return erlang::nif::error(env, "invalid ArrowSchema (list), schema->children == nullptr"); + } + if (schema->n_children != 1) { + return erlang::nif::error(env, "invalid ArrowSchema (list), schema->n_children != 1"); + } + if (values->children == nullptr) { + return erlang::nif::error(env, "invalid ArrowArray (list), values->children == nullptr"); + } + if (values->n_children != 1) { + return erlang::nif::error(env, "invalid ArrowArray (list), values->n_children != 1"); + } + + struct ArrowSchema * items_schema = schema->children[0]; + struct ArrowArray * items_values = values->children[0]; + if (strncmp("item", items_schema->name, 4) != 0) { + return erlang::nif::error(env, "invalid ArrowSchema (list), its single child is not named item"); + } + + std::vector children(items_values->n_children); + bool failed = false; + for (int64_t child_i = 0; child_i < items_values->n_children; child_i++) { + struct ArrowSchema * item_schema = items_schema->children[child_i]; + struct ArrowArray * item_values = items_values->children[child_i]; + + std::vector childrens; + if (arrow_array_to_nif_term(env, item_schema, item_values, level + 1, childrens, error) == 1) { + return error; + } + + if (childrens.size() == 0) { + children[child_i] = childrens[0]; + } else { + children[child_i] = enif_make_tuple2(env, childrens[0], childrens[1]); + } + } + + return enif_make_list_from_array(env, children.data(), (unsigned)items_values->n_children); +} + +int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, uint64_t level, std::vector &out_terms, ERL_NIF_TERM &error) { if (schema == nullptr) { - return erlang::nif::error(env, "invalid ArrowSchema (nullptr) when invoking next"); + error = erlang::nif::error(env, "invalid ArrowSchema (nullptr) when invoking next"); + return 1; } if (values == nullptr) { - return erlang::nif::error(env, "invalid ArrowArray (nullptr) when invoking next"); + error = erlang::nif::error(env, "invalid ArrowArray (nullptr) when invoking next"); + return 1; } const char* format = schema->format ? schema->format : ""; const char* name = schema->name ? schema->name : ""; - ERL_NIF_TERM current_term{}, children_term{}, error{}; + ERL_NIF_TERM current_term{}, children_term{}; std::vector children; bool has_validity_bitmap = values->null_count != 0 && values->null_count != -1; @@ -342,39 +431,21 @@ ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema // only handle and return children if this is a struct is_struct = true; if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - return error; + return 1; } children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); - //get_arrow_array_children_as_list(env, schema, values, level, children, error); } else if (strncmp("+m", format, 2) == 0) { - // if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - // return error; - // } - // children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); children_term = get_arrow_array_map_children(env, schema, values, level); } else if (strncmp("+l", format, 2) == 0 || strncmp("+L", format, 2) == 0) { - if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - return error; - } - children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); - // children_term = get_arrow_array_children_as_list(env, schema, values, level); + children_term = get_arrow_array_list_children(env, schema, values, level); } else { format_processed = false; } } else if (format_len >= 4) { if (strncmp("+w:", format, 3) == 0) { - if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - return error; - } - children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); - // children_term = get_arrow_array_children_as_list(env, schema, values, level); + children_term = get_arrow_array_list_children(env, schema, values, level); } else if (format_len > 4 && (strncmp("+ud:", format, 4) == 0 || strncmp("+us:", format, 4) == 0)) { - // todo: get as map - if (get_arrow_array_children_as_list(env, schema, values, level, children, error) == 1) { - return error; - } - children_term = enif_make_list_from_array(env, children.data(), (unsigned)schema->n_children); - // children_term = get_arrow_array_children_as_list(env, schema, values, level); + children_term = get_arrow_array_union_children(env, schema, values, level); } else { format_processed = false; } @@ -395,21 +466,21 @@ ERL_NIF_TERM arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema // printf("buffers: %p\r\n", values->buffers); } + out_terms.clear(); + if (is_struct) { - return children_term; + out_terms.emplace_back(children_term); } else { if (schema->children) { - return enif_make_tuple2(env, - erlang::nif::make_binary(env, name), - children_term - ); + out_terms.emplace_back(erlang::nif::make_binary(env, name)); + out_terms.emplace_back(children_term); } else { - return enif_make_tuple2(env, - erlang::nif::make_binary(env, name), - current_term - ); + out_terms.emplace_back(erlang::nif::make_binary(env, name)); + out_terms.emplace_back(current_term); } } + + return 0; } static ERL_NIF_TERM adbc_database_new(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { @@ -804,8 +875,19 @@ static ERL_NIF_TERM adbc_arrow_array_stream_next(ErlNifEnv *env, int argc, const } } + std::vector out_terms; + auto schema = (struct ArrowSchema*)res->private_data; - ret = arrow_array_to_nif_term(env, schema, &out, 0); + if (arrow_array_to_nif_term(env, schema, &out, 0, out_terms, error) == 1) { + if (out.release) out.release(&out); + return error; + } + + if (out_terms.size() == 1) { + ret = out_terms[0]; + } else { + ret = enif_make_tuple2(env, out_terms[0], out_terms[1]); + } if (out.release) { out.release(&out); diff --git a/test/adbc_connection_test.exs b/test/adbc_connection_test.exs index 66155c1..95199e2 100644 --- a/test/adbc_connection_test.exs +++ b/test/adbc_connection_test.exs @@ -49,16 +49,15 @@ defmodule Adbc.Connection.Test do num_rows: nil, data: %{ "info_name" => [0, 1, 100, 101, 102], - "info_value" => [ - {"string_value", - # ["SQLite", "3.39.2", "ADBC SQLite Driver", "(unknown)", "0.2.0-SNAPSHOT"]}, - ["SQLite", _, "ADBC SQLite Driver", _, _]}, - {"bool_value", []}, - {"int64_value", []}, - {"int32_bitmask", []}, - {"string_list", [{"item", []}]}, - {"int32_to_int32_list_map", %{}} - ] + "info_value" => %{ + "bool_value" => [], + "int32_bitmask" => [], + "int32_to_int32_list_map" => %{}, + "int64_value" => [], + "string_list" => [], + # ["SQLite", "3.39.2", "ADBC SQLite Driver", "(unknown)", "0.2.0-SNAPSHOT"]}, + "string_value" => ["SQLite", _, "ADBC SQLite Driver", _, _] + } } }} = Connection.get_info(conn) end @@ -71,14 +70,14 @@ defmodule Adbc.Connection.Test do num_rows: nil, data: %{ "info_name" => [0], - "info_value" => [ - {"string_value", ["SQLite"]}, - {"bool_value", []}, - {"int64_value", []}, - {"int32_bitmask", []}, - {"string_list", [{"item", []}]}, - {"int32_to_int32_list_map", %{}} - ] + "info_value" => %{ + "bool_value" => [], + "int32_bitmask" => [], + "int32_to_int32_list_map" => %{}, + "int64_value" => [], + "string_list" => [], + "string_value" => ["SQLite"] + } } }} = Connection.get_info(conn, [0]) end