diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 0d0bc0e0f..36c157bb4 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -363,6 +363,7 @@ defmodule Explorer.PolarsBackend.Native do def s_from_list_str(_name, _val), do: err() def s_from_list_binary(_name, _val), do: err() def s_from_list_categories(_name, _val), do: err() + def s_from_list_decimal(_name, _val, _precision, _scale), do: err() def s_from_list_of_series(_name, _val, _dtype), do: err() def s_from_list_of_series_as_structs(_name, _val, _dtype), do: err() def s_from_binary_f32(_name, _val), do: err() diff --git a/lib/explorer/polars_backend/shared.ex b/lib/explorer/polars_backend/shared.ex index 1a4c63692..fe9ecdcd6 100644 --- a/lib/explorer/polars_backend/shared.ex +++ b/lib/explorer/polars_backend/shared.ex @@ -189,6 +189,7 @@ defmodule Explorer.PolarsBackend.Shared do {:duration, precision} -> apply(:s_from_list_duration, [name, list, precision]) :binary -> Native.s_from_list_binary(name, list) :null -> Native.s_from_list_null(name, length(list)) + {:decimal, precision, scale} -> Native.s_from_list_decimal(name, list, precision, scale) end end diff --git a/lib/explorer/shared.ex b/lib/explorer/shared.ex index 85f7c2516..d2d47e6e7 100644 --- a/lib/explorer/shared.ex +++ b/lib/explorer/shared.ex @@ -42,7 +42,7 @@ defmodule Explorer.Shared do within lists inside. """ def dtypes do - @scalar_types ++ [{:list, :any}, {:struct, :any}] + @scalar_types ++ [{:list, :any}, {:struct, :any}, {:decimal, :any, :any}] end @doc """ @@ -99,6 +99,9 @@ defmodule Explorer.Shared do {:naive_datetime, precision} end + def normalise_dtype({:decimal, _precision, _scale} = dtype), do: dtype + def normalise_dtype(:decimal), do: {:decimal, nil, 2} + def normalise_dtype(_dtype), do: nil @doc """ @@ -494,6 +497,7 @@ defmodule Explorer.Shared do def dtype_to_string({:f, size}), do: "f" <> Integer.to_string(size) def dtype_to_string({:s, size}), do: "s" <> Integer.to_string(size) def dtype_to_string({:u, size}), do: "u" <> Integer.to_string(size) + def dtype_to_string({:decimal, precision, scale}), do: "decimal[#{precision}, #{scale}]" def dtype_to_string(other) when is_atom(other), do: Atom.to_string(other) defp precision_string(:millisecond), do: "ms" diff --git a/mix.exs b/mix.exs index 8e118dfef..ab262e10d 100644 --- a/mix.exs +++ b/mix.exs @@ -46,6 +46,7 @@ defmodule Explorer.MixProject do {:rustler_precompiled, "~> 0.7"}, {:table, "~> 0.1.2"}, {:table_rex, "~> 3.1.1 or ~> 4.0.0"}, + {:decimal, "~> 2.1"}, ## Optional {:flame, "~> 0.3", optional: true}, diff --git a/native/explorer/src/datatypes.rs b/native/explorer/src/datatypes.rs index 4ddb84e36..53aebb09d 100644 --- a/native/explorer/src/datatypes.rs +++ b/native/explorer/src/datatypes.rs @@ -565,6 +565,33 @@ impl Literal for ExTime { } } +#[derive(NifStruct, Copy, Clone, Debug)] +#[module = "Decimal"] +pub struct ExDecimal { + pub sign: i8, + pub coef: u64, + pub exp: i64, +} + +impl ExDecimal { + pub fn signed_coef(self) -> i128 { + self.sign as i128 * self.coef as i128 + } +} + +impl Literal for ExDecimal { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Decimal( + if self.sign.is_positive() { + self.coef.into() + } else { + -(self.coef as i128) + }, + usize::try_from(-(self.exp)).expect("exponent should fit an usize"), + )) + } +} + /// Represents valid Elixir types that can be used as literals in Polars. pub enum ExValidValue<'a> { I64(i64), @@ -575,6 +602,7 @@ pub enum ExValidValue<'a> { Time(ExTime), DateTime(ExNaiveDateTime), Duration(ExDuration), + Decimal(ExDecimal), } impl<'a> ExValidValue<'a> { @@ -598,6 +626,7 @@ impl<'a> Literal for &ExValidValue<'a> { ExValidValue::Time(v) => v.lit(), ExValidValue::DateTime(v) => v.lit(), ExValidValue::Duration(v) => v.lit(), + ExValidValue::Decimal(v) => v.lit(), } } } @@ -620,6 +649,8 @@ impl<'a> rustler::Decoder<'a> for ExValidValue<'a> { Ok(ExValidValue::DateTime(datetime)) } else if let Ok(duration) = term.decode::() { Ok(ExValidValue::Duration(duration)) + } else if let Ok(decimal) = term.decode::() { + Ok(ExValidValue::Decimal(decimal)) } else { Err(rustler::Error::BadArg) } diff --git a/native/explorer/src/datatypes/ex_dtypes.rs b/native/explorer/src/datatypes/ex_dtypes.rs index c93220127..4e7025720 100644 --- a/native/explorer/src/datatypes/ex_dtypes.rs +++ b/native/explorer/src/datatypes/ex_dtypes.rs @@ -67,6 +67,8 @@ pub enum ExSeriesDtype { Duration(ExTimeUnit), List(Box), Struct(Vec<(String, ExSeriesDtype)>), + // Precision and scale. + Decimal(Option, Option), } impl TryFrom<&DataType> for ExSeriesDtype { @@ -113,6 +115,8 @@ impl TryFrom<&DataType> for ExSeriesDtype { Ok(ExSeriesDtype::Struct(struct_fields)) } + DataType::Decimal(precision, scale) => Ok(ExSeriesDtype::Decimal(*precision, *scale)), + _ => Err(ExplorerError::Other(format!( "cannot cast to dtype: {value}" ))), @@ -171,6 +175,7 @@ impl TryFrom<&ExSeriesDtype> for DataType { .map(|(k, v)| Ok(Field::new(k.into(), v.try_into()?))) .collect::, Self::Error>>()?, )), + ExSeriesDtype::Decimal(precision, scale) => Ok(DataType::Decimal(*precision, *scale)), } } } diff --git a/native/explorer/src/encoding.rs b/native/explorer/src/encoding.rs index 6145cae30..755003bde 100644 --- a/native/explorer/src/encoding.rs +++ b/native/explorer/src/encoding.rs @@ -349,7 +349,7 @@ fn time_unit_to_atom(time_unit: TimeUnit) -> atom::Atom { TimeUnit::Nanoseconds => nanosecond(), } } - +// ######### Duration ########## macro_rules! unsafe_encode_duration { ($v: expr, $time_unit: expr, $duration_struct_keys: ident, $duration_module: ident, $env: ident) => {{ let value = $v; @@ -417,6 +417,75 @@ fn duration_series_to_list<'b>( )) } +// ######### End of Duration ########## + +// ######### Decimal ########## +macro_rules! unsafe_encode_decimal { + ($v: expr, $scale: expr, $decimal_struct_keys: ident, $decimal_module: ident, $env: ident) => {{ + let coef = $v.abs(); + let scale = -($scale as isize); + let sign = $v.signum(); + + unsafe { + Term::new( + $env, + map::make_map_from_arrays( + $env.as_c_arg(), + $decimal_struct_keys, + &[ + $decimal_module, + coef.encode($env).as_c_arg(), + scale.encode($env).as_c_arg(), + sign.encode($env).as_c_arg(), + ], + ) + .unwrap(), + ) + } + }}; +} + +// Here we build the Decimal struct manually, as it's much faster than using NifStruct +fn decimal_struct_keys(env: Env) -> [NIF_TERM; 4] { + return [ + atom::__struct__().encode(env).as_c_arg(), + atoms::coef().encode(env).as_c_arg(), + atoms::exp().encode(env).as_c_arg(), + atoms::sign().encode(env).as_c_arg(), + ]; +} + +#[inline] +pub fn encode_decimal(v: i128, scale: usize, env: Env) -> Result { + let struct_keys = &decimal_struct_keys(env); + let module_atom = atoms::decimal_module().encode(env).as_c_arg(); + + Ok(unsafe_encode_decimal!( + v, + scale, + struct_keys, + module_atom, + env + )) +} + +#[inline] +fn decimal_series_to_list<'b>(s: &Series, env: Env<'b>) -> Result, ExplorerError> { + let struct_keys = &decimal_struct_keys(env); + let module_atom = atoms::decimal_module().encode(env).as_c_arg(); + let decimal_chunked = s.decimal()?; + let scale = decimal_chunked.scale(); + + Ok(unsafe_iterator_series_to_list!( + env, + decimal_chunked.into_iter().map(|option| option + .map(|v| { unsafe_encode_decimal!(v, scale, struct_keys, module_atom, env) }) + .encode(env)) + )) +} + +// ######### End of Decimal ########## + macro_rules! unsafe_encode_time { ($v: expr, $naive_time_struct_keys: ident, $calendar_iso_module: ident, $time_module: ident, $env: ident) => {{ let t = time64ns_to_time($v); @@ -710,6 +779,7 @@ pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Result, Explor .map(|(value, field)| Ok((field.name.as_str(), term_from_value(value, env)?))) .collect::, ExplorerError>>() .map(|map| map.encode(env)), + AnyValue::Decimal(number, scale) => encode_decimal(number, scale, env), dt => panic!("cannot encode value {dt:?} to term"), } } @@ -756,6 +826,7 @@ pub fn list_from_series(s: ExSeries, env: Env) -> Result { .map(|value| term_from_value(value, env)) .collect::, ExplorerError>>() .map(|values| values.encode(env)), + DataType::Decimal(_precision, _scale) => decimal_series_to_list(&s, env), dt => panic!("to_list/1 not implemented for {dt:?}"), } } diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index 0a0efb7c0..d16015f84 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -42,6 +42,7 @@ mod atoms { duration_module = "Elixir.Explorer.Duration", naive_datetime_module = "Elixir.NaiveDateTime", time_module = "Elixir.Time", + decimal_module = "Elixir.Decimal", hour, minute, second, @@ -61,6 +62,9 @@ mod atoms { time_zone, utc_offset, zone_abbr, + coef, + exp, + sign, } } diff --git a/native/explorer/src/series/from_list.rs b/native/explorer/src/series/from_list.rs index 6442233d4..cc5e92144 100644 --- a/native/explorer/src/series/from_list.rs +++ b/native/explorer/src/series/from_list.rs @@ -215,6 +215,53 @@ pub fn s_from_list_null(name: &str, length: usize) -> ExSeries { ExSeries::new(Series::new(name.into(), s)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn s_from_list_decimal( + _name: &str, + _val: Term, + _precision: Option, + _scale: Option, +) -> Result { + Err(ExplorerError::Other( + "from_list/2 not yet implemented for decimal lists".into(), + )) + // let iterator = val + // .decode::() + // .map_err(|err| ExplorerError::Other(format!("expecting list as term: {err:?}")))?; + // // let mut precision = precision; + // // let mut scale = scale; + + // let values: Vec> = iterator + // .map(|item| match item.get_type() { + // TermType::Integer => item.decode::>().map_err(|err| { + // ExplorerError::Other(format!("int number is too big for an i128: {err:?}")) + // }), + // TermType::Map => item + // .decode::() + // .map(|ex_decimal| Some(ex_decimal.signed_coef())) + // .map_err(|error| { + // ExplorerError::Other(format!( + // "cannot decode a valid decimal from term. error: {error:?}" + // )) + // }), + // // TODO: handle float special cases + // TermType::Atom => Ok(None), + // term_type => Err(ExplorerError::Other(format!( + // "from_list/2 for decimals not implemented for {term_type:?}" + // ))), + // }) + // .collect::>, ExplorerError>>()?; + + // Series::new(name.into(), values) + // .cast(&DataType::Decimal(precision, scale)) + // .map(ExSeries::new) + // .map_err(|error| { + // ExplorerError::Other(format!( + // "from_list/2 cannot cast integer series to a valid decimal series: {error:?}" + // )) + // }) +} + macro_rules! from_list { ($name:ident, $type:ty) => { #[rustler::nif(schedule = "DirtyCpu")] diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 4a5b10859..ef919a043 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3899,6 +3899,42 @@ defmodule Explorer.SeriesTest do assert Series.dtype(s3) == {:naive_datetime, :microsecond} end + test "integer series to decimal" do + s = Series.from_list([1, 2, 3]) + s1 = Series.cast(s, {:decimal, nil, 0}) + assert Series.to_list(s1) == [Decimal.new("1"), Decimal.new("2"), Decimal.new("3")] + # 38 is Polars' default for precision. + assert Series.dtype(s1) == {:decimal, 38, 0} + + # increased scale + s2 = Series.cast(s, {:decimal, nil, 2}) + assert Series.to_list(s2) == [Decimal.new("1.00"), Decimal.new("2.00"), Decimal.new("3.00")] + assert Series.dtype(s2) == {:decimal, 38, 2} + end + + test "float series to decimal" do + s = Series.from_list([1.345, 2.561, 3.97212]) + s1 = Series.cast(s, {:decimal, nil, 3}) + + assert Series.to_list(s1) == [ + Decimal.new("1.345"), + Decimal.new("2.561"), + Decimal.new("3.972") + ] + + assert Series.dtype(s1) == {:decimal, 38, 3} + + s2 = Series.cast(s, {:decimal, nil, 4}) + + assert Series.to_list(s2) == [ + Decimal.new("1.3450"), + Decimal.new("2.5610"), + Decimal.new("3.9721") + ] + + assert Series.dtype(s2) == {:decimal, 38, 4} + end + test "string series to category" do s = Series.from_list(["apple", "banana", "apple", "lemon"]) s1 = Series.cast(s, :category)