Skip to content

Commit

Permalink
Add initial support to Decimal
Browse files Browse the repository at this point in the history
It can read decimals and load from the Rust backend to Elixir using the
"Decimal" package.

This implementation was based on alexpearce@709aa67
  • Loading branch information
philss committed Sep 17, 2024
1 parent 56026d5 commit 00569ac
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 2 deletions.
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_from_list_str(_name, _val), do: err()
def s_from_list_binary(_name, _val), do: err()
def s_from_list_categories(_name, _val), do: err()
def s_from_list_decimal(_name, _val, _precision, _scale), do: err()
def s_from_list_of_series(_name, _val, _dtype), do: err()
def s_from_list_of_series_as_structs(_name, _val, _dtype), do: err()
def s_from_binary_f32(_name, _val), do: err()
Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ defmodule Explorer.PolarsBackend.Shared do
{:duration, precision} -> apply(:s_from_list_duration, [name, list, precision])
:binary -> Native.s_from_list_binary(name, list)
:null -> Native.s_from_list_null(name, length(list))
{:decimal, precision, scale} -> Native.s_from_list_decimal(name, list, precision, scale)
end
end

Expand Down
6 changes: 5 additions & 1 deletion lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ defmodule Explorer.Shared do
within lists inside.
"""
def dtypes do
@scalar_types ++ [{:list, :any}, {:struct, :any}]
@scalar_types ++ [{:list, :any}, {:struct, :any}, {:decimal, :any, :any}]
end

@doc """
Expand Down Expand Up @@ -99,6 +99,9 @@ defmodule Explorer.Shared do
{:naive_datetime, precision}
end

def normalise_dtype({:decimal, _precision, _scale} = dtype), do: dtype
def normalise_dtype(:decimal), do: {:decimal, nil, 2}

def normalise_dtype(_dtype), do: nil

@doc """
Expand Down Expand Up @@ -494,6 +497,7 @@ defmodule Explorer.Shared do
def dtype_to_string({:f, size}), do: "f" <> Integer.to_string(size)
def dtype_to_string({:s, size}), do: "s" <> Integer.to_string(size)
def dtype_to_string({:u, size}), do: "u" <> Integer.to_string(size)
def dtype_to_string({:decimal, precision, scale}), do: "decimal[#{precision}, #{scale}]"
def dtype_to_string(other) when is_atom(other), do: Atom.to_string(other)

defp precision_string(:millisecond), do: "ms"
Expand Down
1 change: 1 addition & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ defmodule Explorer.MixProject do
{:rustler_precompiled, "~> 0.7"},
{:table, "~> 0.1.2"},
{:table_rex, "~> 3.1.1 or ~> 4.0.0"},
{:decimal, "~> 2.1"},

## Optional
{:flame, "~> 0.3", optional: true},
Expand Down
31 changes: 31 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,33 @@ impl Literal for ExTime {
}
}

#[derive(NifStruct, Copy, Clone, Debug)]
#[module = "Decimal"]
pub struct ExDecimal {
pub sign: i8,
pub coef: u64,
pub exp: i64,
}

impl ExDecimal {
pub fn signed_coef(self) -> i128 {
self.sign as i128 * self.coef as i128
}
}

impl Literal for ExDecimal {
fn lit(self) -> Expr {
Expr::Literal(LiteralValue::Decimal(
if self.sign.is_positive() {
self.coef.into()
} else {
-(self.coef as i128)
},
usize::try_from(-(self.exp)).expect("exponent should fit an usize"),
))
}
}

/// Represents valid Elixir types that can be used as literals in Polars.
pub enum ExValidValue<'a> {
I64(i64),
Expand All @@ -575,6 +602,7 @@ pub enum ExValidValue<'a> {
Time(ExTime),
DateTime(ExNaiveDateTime),
Duration(ExDuration),
Decimal(ExDecimal),
}

impl<'a> ExValidValue<'a> {
Expand All @@ -598,6 +626,7 @@ impl<'a> Literal for &ExValidValue<'a> {
ExValidValue::Time(v) => v.lit(),
ExValidValue::DateTime(v) => v.lit(),
ExValidValue::Duration(v) => v.lit(),
ExValidValue::Decimal(v) => v.lit(),
}
}
}
Expand All @@ -620,6 +649,8 @@ impl<'a> rustler::Decoder<'a> for ExValidValue<'a> {
Ok(ExValidValue::DateTime(datetime))
} else if let Ok(duration) = term.decode::<ExDuration>() {
Ok(ExValidValue::Duration(duration))
} else if let Ok(decimal) = term.decode::<ExDecimal>() {
Ok(ExValidValue::Decimal(decimal))
} else {
Err(rustler::Error::BadArg)
}
Expand Down
5 changes: 5 additions & 0 deletions native/explorer/src/datatypes/ex_dtypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ pub enum ExSeriesDtype {
Duration(ExTimeUnit),
List(Box<ExSeriesDtype>),
Struct(Vec<(String, ExSeriesDtype)>),
// Precision and scale.
Decimal(Option<usize>, Option<usize>),
}

impl TryFrom<&DataType> for ExSeriesDtype {
Expand Down Expand Up @@ -113,6 +115,8 @@ impl TryFrom<&DataType> for ExSeriesDtype {
Ok(ExSeriesDtype::Struct(struct_fields))
}

DataType::Decimal(precision, scale) => Ok(ExSeriesDtype::Decimal(*precision, *scale)),

_ => Err(ExplorerError::Other(format!(
"cannot cast to dtype: {value}"
))),
Expand Down Expand Up @@ -171,6 +175,7 @@ impl TryFrom<&ExSeriesDtype> for DataType {
.map(|(k, v)| Ok(Field::new(k.into(), v.try_into()?)))
.collect::<Result<Vec<Field>, Self::Error>>()?,
)),
ExSeriesDtype::Decimal(precision, scale) => Ok(DataType::Decimal(*precision, *scale)),
}
}
}
73 changes: 72 additions & 1 deletion native/explorer/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ fn time_unit_to_atom(time_unit: TimeUnit) -> atom::Atom {
TimeUnit::Nanoseconds => nanosecond(),
}
}

// ######### Duration ##########
macro_rules! unsafe_encode_duration {
($v: expr, $time_unit: expr, $duration_struct_keys: ident, $duration_module: ident, $env: ident) => {{
let value = $v;
Expand Down Expand Up @@ -417,6 +417,75 @@ fn duration_series_to_list<'b>(
))
}

// ######### End of Duration ##########

// ######### Decimal ##########
macro_rules! unsafe_encode_decimal {
($v: expr, $scale: expr, $decimal_struct_keys: ident, $decimal_module: ident, $env: ident) => {{
let coef = $v.abs();
let scale = -($scale as isize);
let sign = $v.signum();

unsafe {
Term::new(
$env,
map::make_map_from_arrays(
$env.as_c_arg(),
$decimal_struct_keys,
&[
$decimal_module,
coef.encode($env).as_c_arg(),
scale.encode($env).as_c_arg(),
sign.encode($env).as_c_arg(),
],
)
.unwrap(),
)
}
}};
}

// Here we build the Decimal struct manually, as it's much faster than using NifStruct
fn decimal_struct_keys(env: Env) -> [NIF_TERM; 4] {
return [
atom::__struct__().encode(env).as_c_arg(),
atoms::coef().encode(env).as_c_arg(),
atoms::exp().encode(env).as_c_arg(),
atoms::sign().encode(env).as_c_arg(),
];
}

#[inline]
pub fn encode_decimal(v: i128, scale: usize, env: Env) -> Result<Term, ExplorerError> {
let struct_keys = &decimal_struct_keys(env);
let module_atom = atoms::decimal_module().encode(env).as_c_arg();

Ok(unsafe_encode_decimal!(
v,
scale,
struct_keys,
module_atom,
env
))
}

#[inline]
fn decimal_series_to_list<'b>(s: &Series, env: Env<'b>) -> Result<Term<'b>, ExplorerError> {
let struct_keys = &decimal_struct_keys(env);
let module_atom = atoms::decimal_module().encode(env).as_c_arg();
let decimal_chunked = s.decimal()?;
let scale = decimal_chunked.scale();

Ok(unsafe_iterator_series_to_list!(
env,
decimal_chunked.into_iter().map(|option| option
.map(|v| { unsafe_encode_decimal!(v, scale, struct_keys, module_atom, env) })
.encode(env))
))
}

// ######### End of Decimal ##########

macro_rules! unsafe_encode_time {
($v: expr, $naive_time_struct_keys: ident, $calendar_iso_module: ident, $time_module: ident, $env: ident) => {{
let t = time64ns_to_time($v);
Expand Down Expand Up @@ -710,6 +779,7 @@ pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Result<Term<'b>, Explor
.map(|(value, field)| Ok((field.name.as_str(), term_from_value(value, env)?)))
.collect::<Result<HashMap<_, _>, ExplorerError>>()
.map(|map| map.encode(env)),
AnyValue::Decimal(number, scale) => encode_decimal(number, scale, env),
dt => panic!("cannot encode value {dt:?} to term"),
}
}
Expand Down Expand Up @@ -756,6 +826,7 @@ pub fn list_from_series(s: ExSeries, env: Env) -> Result<Term, ExplorerError> {
.map(|value| term_from_value(value, env))
.collect::<Result<Vec<_>, ExplorerError>>()
.map(|values| values.encode(env)),
DataType::Decimal(_precision, _scale) => decimal_series_to_list(&s, env),
dt => panic!("to_list/1 not implemented for {dt:?}"),
}
}
Expand Down
4 changes: 4 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mod atoms {
duration_module = "Elixir.Explorer.Duration",
naive_datetime_module = "Elixir.NaiveDateTime",
time_module = "Elixir.Time",
decimal_module = "Elixir.Decimal",
hour,
minute,
second,
Expand All @@ -61,6 +62,9 @@ mod atoms {
time_zone,
utc_offset,
zone_abbr,
coef,
exp,
sign,
}
}

Expand Down
47 changes: 47 additions & 0 deletions native/explorer/src/series/from_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,53 @@ pub fn s_from_list_null(name: &str, length: usize) -> ExSeries {
ExSeries::new(Series::new(name.into(), s))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_from_list_decimal(
_name: &str,
_val: Term,
_precision: Option<usize>,
_scale: Option<usize>,
) -> Result<ExSeries, ExplorerError> {
Err(ExplorerError::Other(
"from_list/2 not yet implemented for decimal lists".into(),
))
// let iterator = val
// .decode::<ListIterator>()
// .map_err(|err| ExplorerError::Other(format!("expecting list as term: {err:?}")))?;
// // let mut precision = precision;
// // let mut scale = scale;

// let values: Vec<Option<i128>> = iterator
// .map(|item| match item.get_type() {
// TermType::Integer => item.decode::<Option<i128>>().map_err(|err| {
// ExplorerError::Other(format!("int number is too big for an i128: {err:?}"))
// }),
// TermType::Map => item
// .decode::<ExDecimal>()
// .map(|ex_decimal| Some(ex_decimal.signed_coef()))
// .map_err(|error| {
// ExplorerError::Other(format!(
// "cannot decode a valid decimal from term. error: {error:?}"
// ))
// }),
// // TODO: handle float special cases
// TermType::Atom => Ok(None),
// term_type => Err(ExplorerError::Other(format!(
// "from_list/2 for decimals not implemented for {term_type:?}"
// ))),
// })
// .collect::<Result<Vec<Option<i128>>, ExplorerError>>()?;

// Series::new(name.into(), values)
// .cast(&DataType::Decimal(precision, scale))
// .map(ExSeries::new)
// .map_err(|error| {
// ExplorerError::Other(format!(
// "from_list/2 cannot cast integer series to a valid decimal series: {error:?}"
// ))
// })
}

macro_rules! from_list {
($name:ident, $type:ty) => {
#[rustler::nif(schedule = "DirtyCpu")]
Expand Down
36 changes: 36 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3899,6 +3899,42 @@ defmodule Explorer.SeriesTest do
assert Series.dtype(s3) == {:naive_datetime, :microsecond}
end

test "integer series to decimal" do
s = Series.from_list([1, 2, 3])
s1 = Series.cast(s, {:decimal, nil, 0})
assert Series.to_list(s1) == [Decimal.new("1"), Decimal.new("2"), Decimal.new("3")]
# 38 is Polars' default for precision.
assert Series.dtype(s1) == {:decimal, 38, 0}

# increased scale
s2 = Series.cast(s, {:decimal, nil, 2})
assert Series.to_list(s2) == [Decimal.new("1.00"), Decimal.new("2.00"), Decimal.new("3.00")]
assert Series.dtype(s2) == {:decimal, 38, 2}
end

test "float series to decimal" do
s = Series.from_list([1.345, 2.561, 3.97212])
s1 = Series.cast(s, {:decimal, nil, 3})

assert Series.to_list(s1) == [
Decimal.new("1.345"),
Decimal.new("2.561"),
Decimal.new("3.972")
]

assert Series.dtype(s1) == {:decimal, 38, 3}

s2 = Series.cast(s, {:decimal, nil, 4})

assert Series.to_list(s2) == [
Decimal.new("1.3450"),
Decimal.new("2.5610"),
Decimal.new("3.9721")
]

assert Series.dtype(s2) == {:decimal, 38, 4}
end

test "string series to category" do
s = Series.from_list(["apple", "banana", "apple", "lemon"])
s1 = Series.cast(s, :category)
Expand Down

0 comments on commit 00569ac

Please sign in to comment.