Skip to content

Commit

Permalink
Add support for decimal values (#79)
Browse files Browse the repository at this point in the history
* added decimal to deps list

* added support for decoding decimals from arrow array

* implemented `Adbc.Column.decimal{128,256}`
  • Loading branch information
cocoa-xu authored May 17, 2024
1 parent e8545ae commit 23db3dd
Show file tree
Hide file tree
Showing 10 changed files with 592 additions and 15 deletions.
92 changes: 89 additions & 3 deletions c_src/adbc_arrow_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,46 @@ template <typename M, typename OffsetT> static ERL_NIF_TERM strings_from_buffer(
return strings_from_buffer(env, 0, length, validity_bitmap, offsets_buffer, value_buffer, value_to_nif);
}

template <typename M>
static ERL_NIF_TERM fixed_size_binary_from_buffer(
ErlNifEnv *env,
int64_t element_offset,
int64_t element_count,
size_t element_bytes,
const uint8_t * validity_bitmap,
const uint8_t* value_buffer,
const M& value_to_nif) {
std::vector<ERL_NIF_TERM> values(element_count);
if (validity_bitmap == nullptr) {
for (int64_t i = element_offset; i < element_offset + element_count; i++) {
values[i - element_offset] = value_to_nif(env, &value_buffer[element_bytes * i]);
}
} else {
int64_t index = 0;
for (int64_t i = element_offset; i < element_offset + element_count; i++) {
uint8_t vbyte = validity_bitmap[i / 8];
if (vbyte & (1 << (i % 8))) {
values[i - element_offset] = value_to_nif(env, &value_buffer[element_bytes * index]);
index++;
} else {
values[i - element_offset] = kAtomNil;
}
}
}

return enif_make_list_from_array(env, values.data(), (unsigned)values.size());
}

template <typename M> static ERL_NIF_TERM fixed_size_binary_from_buffer(
ErlNifEnv *env,
int64_t length,
size_t element_bytes,
const uint8_t * validity_bitmap,
const uint8_t* value_buffer,
const M& value_to_nif) {
return fixed_size_binary_from_buffer(env, 0, length, element_bytes, validity_bitmap, value_buffer, value_to_nif);
}

int get_arrow_array_children_as_list(ErlNifEnv *env, struct ArrowSchema * schema, struct ArrowArray * values, int64_t offset, int64_t count, uint64_t level, std::vector<ERL_NIF_TERM> &children, ERL_NIF_TERM &error) {
if (schema->n_children > 0 && schema->children == nullptr) {
error = erlang::nif::error(env, "invalid ArrowSchema, schema->children == nullptr, however, schema->n_children > 0");
Expand Down Expand Up @@ -443,6 +483,7 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
return 1;
}

char err_msg_buf[256] = { '\0' };
const char* format = schema->format ? schema->format : "";
const char* name = schema->name ? schema->name : "";
term_type = kAtomNil;
Expand Down Expand Up @@ -777,6 +818,52 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
// NANOARROW_TYPE_SPARSE_UNION
term_type = kAdbcColumnTypeSparseUnion;
children_term = get_arrow_array_sparse_union_children(env, schema, values, offset, count, level);
} else if (strncmp("d:", format, 2) == 0) {
// NANOARROW_TYPE_DECIMAL128
// NANOARROW_TYPE_DECIMAL256
//
// format should match `d:P,S[,N]`
// where P is precision, S is scale, N is bits
// N is optional and defaults to 128
int precision = 0;
int scale = 0;
int bits = 128;
int * d[3] = {&precision, &scale, &bits};
int index = 0;
for (size_t i = 2; i < format_len; i++) {
if (format[i] == ',') {
if (index < 2) {
index++;
} else {
format_processed = false;
break;
}
continue;
}

*d[index] = *d[index] * 10 + (format[i] - '0');
}

if (format_processed) {
term_type = kAdbcColumnTypeDecimal(bits, precision, scale);
if (count == -1) count = values->length;
if (values->n_buffers != 2) {
snprintf(err_msg_buf, 255, "invalid n_buffers value for ArrowArray (format=%s), values->n_buffers != 2", schema->format);
error = erlang::nif::error(env, erlang::nif::make_binary(env, err_msg_buf));
return 1;
}
current_term = fixed_size_binary_from_buffer(
env,
offset,
count,
bits / 8,
(const uint8_t *)values->buffers[bitmap_buffer_index],
(const uint8_t *)values->buffers[data_buffer_index],
[&](ErlNifEnv *env, const uint8_t * val) -> ERL_NIF_TERM {
return erlang::nif::make_binary(env, (const char *)val, bits / 8);
}
);
}
} else if (strncmp("td", format, 2) == 0) {
char unit = format[2];

Expand Down Expand Up @@ -1065,9 +1152,8 @@ int arrow_array_to_nif_term(ErlNifEnv *env, struct ArrowSchema * schema, struct
}

if (!format_processed) {
char buf[256] = { '\0' };
snprintf(buf, 255, "not yet implemented for format: `%s`", schema->format);
error = erlang::nif::error(env, erlang::nif::make_binary(env, buf));
snprintf(err_msg_buf, 255, "not yet implemented for format: `%s`", schema->format);
error = erlang::nif::error(env, erlang::nif::make_binary(env, err_msg_buf));
return 1;
// printf("not implemented for format: `%s`\r\n", schema->format);
// printf("length: %lld\r\n", values->length);
Expand Down
66 changes: 65 additions & 1 deletion c_src/adbc_column.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,55 @@ int do_get_list_float(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, ArrowTyp
}
}

int get_list_decimal(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, ArrowType nanoarrow_type, int32_t bitwidth, int32_t precision, int32_t scale, const std::function<void(struct ArrowDecimal * val, bool is_nil)> &callback) {
ERL_NIF_TERM head, tail;
tail = list;
while (enif_get_list_cell(env, tail, &head, &tail)) {
struct ArrowDecimal val{};
ArrowDecimalInit(&val, bitwidth, precision, scale);
ErlNifBinary bytes;
if (enif_is_binary(env, head) && enif_inspect_binary(env, head, &bytes)) {
if (nanoarrow_type == NANOARROW_TYPE_DECIMAL128) {
if (bytes.size != 16) {
return 1;
}
ArrowDecimalSetBytes(&val, (const uint8_t *)bytes.data);
} else if (nanoarrow_type == NANOARROW_TYPE_DECIMAL256) {
if (bytes.size != 32) {
return 1;
}
ArrowDecimalSetBytes(&val, (const uint8_t *)bytes.data);
} else {
return 1;
}
callback(&val, false);
} else if (nullable && enif_is_identical(head, kAtomNil)) {
callback(&val, true);
} else {
return 1;
}
}
return 0;
}

int do_get_list_decimal(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, ArrowType nanoarrow_type, int32_t bitwidth, int32_t precision, int32_t scale, struct ArrowArray* array_out, struct ArrowSchema* schema_out, struct ArrowError* error_out) {
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeDecimal(schema_out, nanoarrow_type, precision, scale));
NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(array_out, schema_out, error_out));
NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(array_out));
if (nullable) {
return get_list_decimal(env, list, nullable, nanoarrow_type, bitwidth, precision, scale, [&array_out](struct ArrowDecimal * val, bool is_nil) -> void {
ArrowArrayAppendDecimal(array_out, val);
if (is_nil) {
ArrowArrayAppendNull(array_out, 1);
}
});
} else {
return get_list_decimal(env, list, nullable, nanoarrow_type, bitwidth, precision, scale, [&array_out](struct ArrowDecimal * val, bool) -> void {
ArrowArrayAppendDecimal(array_out, val);
});
}
}

int get_list_string(ErlNifEnv *env, ERL_NIF_TERM list, bool nullable, const std::function<void(struct ArrowStringView val, bool is_nil)> &callback) {
ERL_NIF_TERM head, tail;
tail = list;
Expand Down Expand Up @@ -745,7 +794,22 @@ int adbc_column_to_adbc_field(ErlNifEnv *env, ERL_NIF_TERM adbc_buffer, struct A
ret = do_get_list_timestamp(env, data_term, nullable, NANOARROW_TYPE_TIMESTAMP, NANOARROW_TIME_UNIT_NANO, 1, timezone.c_str(), array_out, schema_out, error_out);
}
}
}
}
} else if (arity == 4) {
// NANOARROW_TYPE_DECIMAL128
// NANOARROW_TYPE_DECIMAL256
if (enif_is_identical(tuple[0], kAtomDecimal)) {
int bits = 0;
int precision = 0;
int scale = 0;
if (erlang::nif::get(env, tuple[1], &bits) && erlang::nif::get(env, tuple[2], &precision) && erlang::nif::get(env, tuple[3], &scale)) {
if (bits == 128) {
ret = do_get_list_decimal(env, data_term, nullable, NANOARROW_TYPE_DECIMAL128, bits, precision, scale, array_out, schema_out, error_out);
} else if (bits == 256) {
ret = do_get_list_decimal(env, data_term, nullable, NANOARROW_TYPE_DECIMAL256, bits, precision, scale, array_out, schema_out, error_out);
}
}
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion c_src/adbc_consts.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ static ERL_NIF_TERM kAtomMilliseconds;
static ERL_NIF_TERM kAtomMicroseconds;
static ERL_NIF_TERM kAtomNanoseconds;
static ERL_NIF_TERM kAtomTimestamp;
static ERL_NIF_TERM kAtomDecimal;

static ERL_NIF_TERM kAtomCalendarKey;
static ERL_NIF_TERM kAtomCalendarISO;
Expand All @@ -40,7 +41,6 @@ static ERL_NIF_TERM kAtomTypeKey;
static ERL_NIF_TERM kAtomNullableKey;
static ERL_NIF_TERM kAtomMetadataKey;
static ERL_NIF_TERM kAtomDataKey;
// static ERL_NIF_TERM kAtomPrivateKey;

static ERL_NIF_TERM kAdbcColumnTypeU8;
static ERL_NIF_TERM kAdbcColumnTypeU16;
Expand Down Expand Up @@ -77,6 +77,7 @@ static ERL_NIF_TERM kAdbcColumnTypeBool;
#define kAdbcColumnTypeDurationMilliseconds enif_make_tuple2(env, kAtomDuration, kAtomMilliseconds)
#define kAdbcColumnTypeDurationMicroseconds enif_make_tuple2(env, kAtomDuration, kAtomMicroseconds)
#define kAdbcColumnTypeDurationNanoseconds enif_make_tuple2(env, kAtomDuration, kAtomNanoseconds)
#define kAdbcColumnTypeDecimal(bitwidth, precision, scale) enif_make_tuple4(env, kAtomDecimal, enif_make_int(env, bitwidth), enif_make_int(env, precision), enif_make_int(env, scale))

// error codes
constexpr int kErrorBufferIsNotAMap = 1;
Expand Down
2 changes: 1 addition & 1 deletion c_src/adbc_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ static int on_load(ErlNifEnv *env, void **, ERL_NIF_TERM) {
kAtomMicroseconds = erlang::nif::atom(env, "microseconds");
kAtomNanoseconds = erlang::nif::atom(env, "nanoseconds");
kAtomTimestamp = erlang::nif::atom(env, "timestamp");
kAtomDecimal = erlang::nif::atom(env, "decimal");

kAtomCalendarKey = erlang::nif::atom(env, "calendar");
kAtomCalendarISO = erlang::nif::atom(env, "Elixir.Calendar.ISO");
Expand All @@ -818,7 +819,6 @@ static int on_load(ErlNifEnv *env, void **, ERL_NIF_TERM) {
kAtomNullableKey = erlang::nif::atom(env, "nullable");
kAtomMetadataKey = erlang::nif::atom(env, "metadata");
kAtomDataKey = erlang::nif::atom(env, "data");
// kAdbcBufferPrivateKey = enif_make_atom(env, "__private__");

kAdbcColumnTypeU8 = erlang::nif::atom(env, "u8");
kAdbcColumnTypeU16 = erlang::nif::atom(env, "u16");
Expand Down
95 changes: 95 additions & 0 deletions lib/adbc_column.ex
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ defmodule Adbc.Column do
@type floating ::
:f32
| :f64
@type decimal128 :: {:decimal, 128, integer(), integer()}
@type decimal256 :: {:decimal, 256, integer(), integer()}
@type decimal_t ::
decimal128
| decimal256
@type time_unit ::
:seconds
| :milliseconds
Expand All @@ -50,6 +55,7 @@ defmodule Adbc.Column do
| signed_integer
| unsigned_integer
| floating
| decimal_t
| :string
| :large_string
| :binary
Expand Down Expand Up @@ -443,6 +449,95 @@ defmodule Adbc.Column do
column(:f64, data, opts)
end

@doc """
A column that contains 128-bit decimals.
## Arguments
* `data`: a list of `Decimal.t()`
* `precision`: The precision of the decimal values
* `scale`: The scale of the decimal values
* `opts`: A keyword list of options
## Options
* `:name` - The name of the column
* `:nullable` - A boolean value indicating whether the column is nullable
* `:metadata` - A map of metadata
"""
@spec decimal128([Decimal.t()], integer(), integer(), Keyword.t()) :: %Adbc.Column{}
def decimal128(data, precision, scale, opts \\ []) do
bitwidth = 128

column(
{:decimal, bitwidth, precision, scale},
preprocess_decimal(bitwidth, precision, scale, data, []),
opts
)
end

@doc """
A column that contains 256-bit decimals.
## Arguments
* `data`: a list of `Decimal.t()`
* `precision`: The precision of the decimal values
* `scale`: The scale of the decimal values
* `opts`: A keyword list of options
## Options
* `:name` - The name of the column
* `:nullable` - A boolean value indicating whether the column is nullable
* `:metadata` - A map of metadata
"""
@spec decimal256([Decimal.t()], integer(), integer(), Keyword.t()) :: %Adbc.Column{}
def decimal256(data, precision, scale, opts \\ []) do
bitwidth = 256

column(
{:decimal, bitwidth, precision, scale},
preprocess_decimal(bitwidth, precision, scale, data, []),
opts
)
end

defp preprocess_decimal(_bitwidth, _precision, _scale, [], acc), do: Enum.reverse(acc)

defp preprocess_decimal(bitwidth, precision, scale, [nil | rest], acc) do
preprocess_decimal(bitwidth, precision, scale, rest, [nil | acc])
end

defp preprocess_decimal(
bitwidth,
_precision,
scale,
[%Decimal{exp: exp} = decimal | _rest],
_acc
)
when -exp > scale do
raise Adbc.Error,
"`#{Decimal.to_string(decimal)}` with exponent `#{exp}` cannot be represented as a valid decimal#{Integer.to_string(bitwidth)} number with scale value `#{scale}`"
end

defp preprocess_decimal(bitwidth, precision, scale, [%Decimal{exp: exp} = decimal | rest], acc)
when -exp <= scale do
if Decimal.inf?(decimal) or Decimal.nan?(decimal) do
raise Adbc.Error,
"`#{Decimal.to_string(decimal)}` cannot be represented as a valid decimal#{Integer.to_string(bitwidth)} number"
else
if Decimal.coef_length(decimal.coef) > precision do
raise Adbc.Error,
"`#{Decimal.to_string(decimal)}` cannot be fitted into a decimal#{Integer.to_string(bitwidth)} with the specified precision #{Integer.to_string(precision)}"
else
coef = trunc(decimal.coef * decimal.sign * :math.pow(10, exp + scale))
acc = [<<coef::signed-integer-little-size(bitwidth)>> | acc]
preprocess_decimal(bitwidth, precision, scale, rest, acc)
end
end
end

@doc """
A column that contains UTF-8 encoded strings.
Expand Down
27 changes: 26 additions & 1 deletion lib/adbc_connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -407,16 +407,41 @@ defmodule Adbc.Connection do
end
end

defp merge_columns([result]), do: result
defp merge_columns([result]), do: handle_decimal(result)

defp merge_columns(chucked_results) do
Enum.zip_with(chucked_results, fn columns ->
Enum.reduce(columns, fn column, merged_column ->
column = handle_decimal(column)
%{merged_column | data: merged_column.data ++ column.data}
end)
end)
end

defp handle_decimal([column | rest]) do
[handle_decimal(column) | handle_decimal(rest)]
end

defp handle_decimal(%Adbc.Column{type: {:decimal, bits, _, scale}, data: decimal_data} = column) do
%{column | data: handle_decimal(decimal_data, bits, scale)}
end

defp handle_decimal(column) do
column
end

defp handle_decimal(decimal_data, bits, scale) do
Enum.map(decimal_data, fn data ->
<<decimal::signed-integer-size(bits)-little>> = data

if decimal < 0 do
Decimal.new(-1, -decimal, -scale)
else
Decimal.new(1, decimal, -scale)
end
end)
end

## Callbacks

@impl true
Expand Down
Loading

0 comments on commit 23db3dd

Please sign in to comment.