Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose more APIs #64

Merged
merged 13 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 186 additions & 35 deletions c_src/adbc_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
// only handle and return children if this is a struct
is_struct = true;

if (schema->n_children == values->n_children) {
if (values->length > 0 || values->release != nullptr) {
if (count == -1) count = values->n_children;
if (get_arrow_array_children_as_list(env, schema, values, offset, count, level, children, error) == 1) {
return 1;
Expand Down Expand Up @@ -1026,32 +1026,161 @@ static ERL_NIF_TERM adbc_database_new(ErlNifEnv *env, int argc, const ERL_NIF_TE
return erlang::nif::ok(env, ret);
}

static ERL_NIF_TERM adbc_database_set_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
using res_type = NifRes<struct AdbcDatabase>;
template <typename T, typename GetString, typename GetBytes, typename GetInt, typename GetDouble>
static ERL_NIF_TERM adbc_get_option(ErlNifEnv *env, const ERL_NIF_TERM argv[], GetString& get_string, GetBytes& get_bytes, GetInt& get_int, GetDouble& get_double) {
using res_type = NifRes<T>;

ERL_NIF_TERM error{};
res_type * database = nullptr;
if ((database = res_type::get_resource(env, argv[0], error)) == nullptr) {
res_type * resource = nullptr;
if ((resource = res_type::get_resource(env, argv[0], error)) == nullptr) {
return error;
}

std::string key, value;
if (!erlang::nif::get(env, argv[1], key)) {
std::string type, key;
if (!erlang::nif::get(env, argv[1], type)) {
return enif_make_badarg(env);
}
if (!erlang::nif::get(env, argv[2], value)) {
if (!erlang::nif::get(env, argv[2], key)) {
return enif_make_badarg(env);
}

struct AdbcError adbc_error{};
AdbcStatusCode code = AdbcDatabaseSetOption(&database->val, key.c_str(), value.c_str(), &adbc_error);
if (type == "string" || type == "binary") {
int is_string = type == "string";
uint8_t value[64] = {'\0'};
constexpr size_t value_buffer_size = sizeof(value) / sizeof(value[0]);
size_t value_len = value_buffer_size;
AdbcStatusCode code;
size_t elem_size = 0;
if (is_string) {
elem_size = sizeof(char);
code = get_string(&resource->val, key.c_str(), (char *)value, &value_len, &adbc_error);
} else {
elem_size = sizeof(uint8_t);
code = get_bytes(&resource->val, key.c_str(), value, &value_len, &adbc_error);
}
if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}

if (value_len > value_buffer_size) {
uint8_t * out_value = (uint8_t *)enif_alloc(elem_size * (value_len + 1));
memset(out_value, 0, elem_size * (value_len + 1));
value_len += 1;
if (is_string) {
code = get_string(&resource->val, key.c_str(), (char *)out_value, &value_len, &adbc_error);
} else {
code = get_bytes(&resource->val, key.c_str(), out_value, &value_len, &adbc_error);
}

if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}

// minus 1 to remove the null terminator for strings
ERL_NIF_TERM ret;
ret = erlang::nif::make_binary(env, (const char *)out_value, value_len - (is_string ? 1 : 0));
enif_free(out_value);
return erlang::nif::ok(env, ret);
} else {
// minus 1 to remove the null terminator for strings
return erlang::nif::ok(env, erlang::nif::make_binary(env, (const char *)value, value_len - (is_string ? 1 : 0)));
}
} else if (type == "integer") {
int64_t value = 0;
AdbcStatusCode code = get_int(&resource->val, key.c_str(), &value, &adbc_error);
if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}

return erlang::nif::ok(env, erlang::nif::make(env, value));
} else if (type == "float") {
double value = 0;
AdbcStatusCode code = get_double(&resource->val, key.c_str(), &value, &adbc_error);
if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}

return erlang::nif::ok(env, erlang::nif::make(env, value));
} else {
return enif_make_badarg(env);
}
}

template <typename T, typename SetString, typename SetBytes, typename SetInt, typename SetDouble>
static ERL_NIF_TERM adbc_set_option(ErlNifEnv *env, const ERL_NIF_TERM argv[], SetString& set_string, SetBytes &set_bytes, SetInt &set_int, SetDouble &set_double) {
using res_type = NifRes<T>;

ERL_NIF_TERM error{};
res_type * resource = nullptr;
if ((resource = res_type::get_resource(env, argv[0], error)) == nullptr) {
return error;
}

std::string type, key;
if (!erlang::nif::get_atom(env, argv[1], type)) {
return enif_make_badarg(env);
}
if (!erlang::nif::get(env, argv[2], key)) {
return enif_make_badarg(env);
}

struct AdbcError adbc_error{};
AdbcStatusCode code;
if (type == "string" || type == "binary") {
std::string value;
if (!erlang::nif::get(env, argv[3], value)) {
return enif_make_badarg(env);
}
if (type == "string") {
code = set_string(&resource->val, key.c_str(), value.c_str(), &adbc_error);
} else {
code = set_bytes(&resource->val, key.c_str(), (const uint8_t *)value.data(), value.length(), &adbc_error);
}
} else if (type == "integer") {
int64_t value;
if (!erlang::nif::get(env, argv[3], &value)) {
return enif_make_badarg(env);
}
code = set_int(&resource->val, key.c_str(), value, &adbc_error);
} else if (type == "float") {
double value;
if (!erlang::nif::get(env, argv[3], &value)) {
return enif_make_badarg(env);
}
code = set_double(&resource->val, key.c_str(), value, &adbc_error);
} else {
return enif_make_badarg(env);
}

if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}

return erlang::nif::ok(env);
}

static ERL_NIF_TERM adbc_database_get_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_get_option<struct AdbcDatabase>(
env,
argv,
AdbcDatabaseGetOption,
AdbcDatabaseGetOptionBytes,
AdbcDatabaseGetOptionInt,
AdbcDatabaseGetOptionDouble
);
}

static ERL_NIF_TERM adbc_database_set_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_set_option<struct AdbcDatabase>(
env,
argv,
AdbcDatabaseSetOption,
AdbcDatabaseSetOptionBytes,
AdbcDatabaseSetOptionInt,
AdbcDatabaseSetOptionDouble
);
}

static ERL_NIF_TERM adbc_database_init(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
using res_type = NifRes<struct AdbcDatabase>;

Expand Down Expand Up @@ -1089,30 +1218,26 @@ static ERL_NIF_TERM adbc_connection_new(ErlNifEnv *env, int argc, const ERL_NIF_
return erlang::nif::ok(env, ret);
}

static ERL_NIF_TERM adbc_connection_set_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
using res_type = NifRes<struct AdbcConnection>;

ERL_NIF_TERM error{};
res_type * connection = res_type::get_resource(env, argv[0], error);
if (connection == nullptr) {
return error;
}

std::string key, value;
if (!erlang::nif::get(env, argv[1], key)) {
return enif_make_badarg(env);
}
if (!erlang::nif::get(env, argv[2], value)) {
return enif_make_badarg(env);
}

struct AdbcError adbc_error{};
AdbcStatusCode code = AdbcConnectionSetOption(&connection->val, key.c_str(), value.c_str(), &adbc_error);
if (code != ADBC_STATUS_OK) {
return nif_error_from_adbc_error(env, &adbc_error);
}
static ERL_NIF_TERM adbc_connection_get_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_get_option<struct AdbcConnection>(
env,
argv,
AdbcConnectionGetOption,
AdbcConnectionGetOptionBytes,
AdbcConnectionGetOptionInt,
AdbcConnectionGetOptionDouble
);
}

return erlang::nif::ok(env);
static ERL_NIF_TERM adbc_connection_set_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_set_option<struct AdbcConnection>(
env,
argv,
AdbcConnectionSetOption,
AdbcConnectionSetOptionBytes,
AdbcConnectionSetOptionInt,
AdbcConnectionSetOptionDouble
);
}

static ERL_NIF_TERM adbc_connection_init(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
Expand Down Expand Up @@ -1428,6 +1553,28 @@ static ERL_NIF_TERM adbc_statement_new(ErlNifEnv *env, int argc, const ERL_NIF_T
return erlang::nif::ok(env, ret);
}

static ERL_NIF_TERM adbc_statement_get_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_get_option<struct AdbcStatement>(
env,
argv,
AdbcStatementGetOption,
AdbcStatementGetOptionBytes,
AdbcStatementGetOptionInt,
AdbcStatementGetOptionDouble
);
}

static ERL_NIF_TERM adbc_statement_set_option(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
return adbc_set_option<struct AdbcStatement>(
env,
argv,
AdbcStatementSetOption,
AdbcStatementSetOptionBytes,
AdbcStatementSetOptionInt,
AdbcStatementSetOptionDouble
);
}

static ERL_NIF_TERM adbc_statement_execute_query(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
using res_type = NifRes<struct AdbcStatement>;
using array_stream_type = NifRes<struct ArrowArrayStream>;
Expand Down Expand Up @@ -1712,17 +1859,21 @@ static int on_upgrade(ErlNifEnv *, void **, void **, ERL_NIF_TERM) {

static ErlNifFunc nif_functions[] = {
{"adbc_database_new", 0, adbc_database_new, 0},
{"adbc_database_set_option", 3, adbc_database_set_option, 0},
{"adbc_database_get_option", 3, adbc_database_get_option, 0},
{"adbc_database_set_option", 4, adbc_database_set_option, 0},
{"adbc_database_init", 1, adbc_database_init, 0},

{"adbc_connection_new", 0, adbc_connection_new, 0},
{"adbc_connection_set_option", 3, adbc_connection_set_option, 0},
{"adbc_connection_get_option", 3, adbc_connection_get_option, 0},
{"adbc_connection_set_option", 4, adbc_connection_set_option, 0},
{"adbc_connection_init", 2, adbc_connection_init, 0},
{"adbc_connection_get_info", 2, adbc_connection_get_info, 0},
{"adbc_connection_get_objects", 7, adbc_connection_get_objects, 0},
{"adbc_connection_get_table_types", 1, adbc_connection_get_table_types, 0},

{"adbc_statement_new", 1, adbc_statement_new, 0},
{"adbc_statement_get_option", 3, adbc_statement_get_option, 0},
{"adbc_statement_set_option", 4, adbc_statement_set_option, 0},
{"adbc_statement_execute_query", 1, adbc_statement_execute_query, 0},
{"adbc_statement_prepare", 1, adbc_statement_prepare, 0},
{"adbc_statement_set_sql_query", 2, adbc_statement_set_sql_query, 0},
Expand Down
Loading